diff options
| author | Simon Charette <charette.s@gmail.com> | 2025-03-19 01:11:34 -0400 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2025-09-14 00:27:49 +0200 |
| commit | 55a0073b3beb9de8f7c1f7c44a7d0bc10126c841 (patch) | |
| tree | 616a0bf54b0d9e3d09a2d033980f07bbb2a83e0d /tests | |
| parent | c48904a225e2e8f02274257247d5b7d29c5fe183 (diff) | |
Refs #27222 -- Refreshed GeneratedFields values on save() initiated update.
This required implementing UPDATE RETURNING machinery that heavily
borrows from the INSERT one.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/model_fields/test_generatedfield.py | 110 |
1 files changed, 81 insertions, 29 deletions
diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py index b6a933451d..f0ac6eecb5 100644 --- a/tests/model_fields/test_generatedfield.py +++ b/tests/model_fields/test_generatedfield.py @@ -173,11 +173,6 @@ class BaseGeneratedFieldTests(SimpleTestCase): class GeneratedFieldTestMixin: - def _refresh_if_needed(self, m): - if not connection.features.can_return_columns_from_insert: - m.refresh_from_db() - return m - def test_unsaved_error(self): m = self.base_model(a=1, b=2) msg = "Cannot retrieve deferred field 'field' from an unsaved model." @@ -189,8 +184,11 @@ class GeneratedFieldTestMixin: # full_clean() ignores GeneratedFields. m.full_clean() m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) @skipUnlessDBFeature("supports_table_check_constraints") def test_full_clean_with_check_constraint(self): @@ -199,8 +197,11 @@ class GeneratedFieldTestMixin: m = self.check_constraint_model(a=2) m.full_clean() m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.a_squared, 4) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.a_squared, 4) m = self.check_constraint_model(a=-1) with self.assertRaises(ValidationError) as cm: @@ -217,8 +218,11 @@ class GeneratedFieldTestMixin: m = self.unique_constraint_model(a=2) m.full_clean() m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.a_squared, 4) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.a_squared, 4) m = self.unique_constraint_model(a=2) with self.assertRaises(ValidationError) as cm: @@ -230,8 +234,11 @@ class GeneratedFieldTestMixin: def test_create(self): m = self.base_model.objects.create(a=1, b=2) - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) def test_non_nullable_create(self): with self.assertRaises(IntegrityError): @@ -241,26 +248,52 @@ class GeneratedFieldTestMixin: # Insert. m = self.base_model(a=2, b=4) m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 6) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 6) # Update. m.a = 4 m.save() - m.refresh_from_db() - self.assertEqual(m.field, 8) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 8) + # Update non-dependent field. + self.base_model.objects.filter(pk=m.pk).update(a=6) + m.save(update_fields=["fk"]) + with self.assertNumQueries(0): + self.assertEqual(m.field, 8) + # Update dependent field without persisting local changes. + m.save(update_fields=["b"]) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 10) + # Update dependent field while persisting local changes. + m.a = 8 + m.save(update_fields=["a"]) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 12) def test_save_model_with_pk(self): m = self.base_model(pk=1, a=1, b=2) m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) def test_save_model_with_foreign_key(self): fk_object = Foo.objects.create(a="abc", d=Decimal("12.34")) m = self.base_model(a=1, b=2, fk=fk_object) m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) def test_generated_fields_can_be_deferred(self): fk_object = Foo.objects.create(a="abc", d=Decimal("12.34")) @@ -330,17 +363,23 @@ class GeneratedFieldTestMixin: def test_model_with_params(self): m = self.params_model.objects.create() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, "Constant") + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, "Constant") def test_nullable(self): m1 = self.nullable_model.objects.create() - m1 = self._refresh_if_needed(m1) none_val = "" if connection.features.interprets_empty_strings_as_nulls else None - self.assertEqual(m1.lower_name, none_val) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m1.lower_name, none_val) m2 = self.nullable_model.objects.create(name="NaMe") - m2 = self._refresh_if_needed(m2) - self.assertEqual(m2.lower_name, "name") + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m2.lower_name, "name") @skipUnlessDBFeature("supports_stored_generated_columns") @@ -354,8 +393,21 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): def test_create_field_with_db_converters(self): obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4()) - obj = self._refresh_if_needed(obj) - self.assertEqual(obj.field, obj.field_copy) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.field, obj.field_copy) + + def test_save_field_with_db_converters(self): + obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4()) + obj.field = uuid.uuid4() + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + obj.save(update_fields={"field"}) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.field, obj.field_copy) def test_create_with_non_auto_pk(self): obj = GeneratedModelNonAutoPk.objects.create(id=1, a=2) |
