summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2025-03-19 01:11:34 -0400
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2025-09-14 00:27:49 +0200
commit55a0073b3beb9de8f7c1f7c44a7d0bc10126c841 (patch)
tree616a0bf54b0d9e3d09a2d033980f07bbb2a83e0d /tests
parentc48904a225e2e8f02274257247d5b7d29c5fe183 (diff)
Refs #27222 -- Refreshed GeneratedFields values on save() initiated update.
This required implementing UPDATE RETURNING machinery that heavily borrows from the INSERT one.
Diffstat (limited to 'tests')
-rw-r--r--tests/model_fields/test_generatedfield.py110
1 files changed, 81 insertions, 29 deletions
diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py
index b6a933451d..f0ac6eecb5 100644
--- a/tests/model_fields/test_generatedfield.py
+++ b/tests/model_fields/test_generatedfield.py
@@ -173,11 +173,6 @@ class BaseGeneratedFieldTests(SimpleTestCase):
class GeneratedFieldTestMixin:
- def _refresh_if_needed(self, m):
- if not connection.features.can_return_columns_from_insert:
- m.refresh_from_db()
- return m
-
def test_unsaved_error(self):
m = self.base_model(a=1, b=2)
msg = "Cannot retrieve deferred field 'field' from an unsaved model."
@@ -189,8 +184,11 @@ class GeneratedFieldTestMixin:
# full_clean() ignores GeneratedFields.
m.full_clean()
m.save()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.field, 3)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 3)
@skipUnlessDBFeature("supports_table_check_constraints")
def test_full_clean_with_check_constraint(self):
@@ -199,8 +197,11 @@ class GeneratedFieldTestMixin:
m = self.check_constraint_model(a=2)
m.full_clean()
m.save()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.a_squared, 4)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.a_squared, 4)
m = self.check_constraint_model(a=-1)
with self.assertRaises(ValidationError) as cm:
@@ -217,8 +218,11 @@ class GeneratedFieldTestMixin:
m = self.unique_constraint_model(a=2)
m.full_clean()
m.save()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.a_squared, 4)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.a_squared, 4)
m = self.unique_constraint_model(a=2)
with self.assertRaises(ValidationError) as cm:
@@ -230,8 +234,11 @@ class GeneratedFieldTestMixin:
def test_create(self):
m = self.base_model.objects.create(a=1, b=2)
- m = self._refresh_if_needed(m)
- self.assertEqual(m.field, 3)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 3)
def test_non_nullable_create(self):
with self.assertRaises(IntegrityError):
@@ -241,26 +248,52 @@ class GeneratedFieldTestMixin:
# Insert.
m = self.base_model(a=2, b=4)
m.save()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.field, 6)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 6)
# Update.
m.a = 4
m.save()
- m.refresh_from_db()
- self.assertEqual(m.field, 8)
+ expected_num_queries = (
+ 0 if connection.features.can_return_rows_from_update else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 8)
+ # Update non-dependent field.
+ self.base_model.objects.filter(pk=m.pk).update(a=6)
+ m.save(update_fields=["fk"])
+ with self.assertNumQueries(0):
+ self.assertEqual(m.field, 8)
+ # Update dependent field without persisting local changes.
+ m.save(update_fields=["b"])
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 10)
+ # Update dependent field while persisting local changes.
+ m.a = 8
+ m.save(update_fields=["a"])
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 12)
def test_save_model_with_pk(self):
m = self.base_model(pk=1, a=1, b=2)
m.save()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.field, 3)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 3)
def test_save_model_with_foreign_key(self):
fk_object = Foo.objects.create(a="abc", d=Decimal("12.34"))
m = self.base_model(a=1, b=2, fk=fk_object)
m.save()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.field, 3)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, 3)
def test_generated_fields_can_be_deferred(self):
fk_object = Foo.objects.create(a="abc", d=Decimal("12.34"))
@@ -330,17 +363,23 @@ class GeneratedFieldTestMixin:
def test_model_with_params(self):
m = self.params_model.objects.create()
- m = self._refresh_if_needed(m)
- self.assertEqual(m.field, "Constant")
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.field, "Constant")
def test_nullable(self):
m1 = self.nullable_model.objects.create()
- m1 = self._refresh_if_needed(m1)
none_val = "" if connection.features.interprets_empty_strings_as_nulls else None
- self.assertEqual(m1.lower_name, none_val)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m1.lower_name, none_val)
m2 = self.nullable_model.objects.create(name="NaMe")
- m2 = self._refresh_if_needed(m2)
- self.assertEqual(m2.lower_name, "name")
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m2.lower_name, "name")
@skipUnlessDBFeature("supports_stored_generated_columns")
@@ -354,8 +393,21 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
def test_create_field_with_db_converters(self):
obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4())
- obj = self._refresh_if_needed(obj)
- self.assertEqual(obj.field, obj.field_copy)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(obj.field, obj.field_copy)
+
+ def test_save_field_with_db_converters(self):
+ obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4())
+ obj.field = uuid.uuid4()
+ expected_num_queries = (
+ 0 if connection.features.can_return_rows_from_update else 1
+ )
+ obj.save(update_fields={"field"})
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(obj.field, obj.field_copy)
def test_create_with_non_auto_pk(self):
obj = GeneratedModelNonAutoPk.objects.create(id=1, a=2)