diff options
| author | Simon Charette <charette.s@gmail.com> | 2025-03-19 01:39:19 -0400 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2025-09-14 00:27:50 +0200 |
| commit | 94680437a45a71c70ca8bd2e68b72aa1e2eff337 (patch) | |
| tree | 0246e0b5a8dde81f0a327e1c7c7e3accd9257870 /tests | |
| parent | 55a0073b3beb9de8f7c1f7c44a7d0bc10126c841 (diff) | |
Fixed #27222 -- Refreshed model field values assigned expressions on save().
Removed the can_return_columns_from_insert skip gates on existing
field_defaults tests to confirm the expected number of queries are
performed and that returning field overrides are respected.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/basic/tests.py | 22 | ||||
| -rw-r--r-- | tests/expressions/tests.py | 14 | ||||
| -rw-r--r-- | tests/field_defaults/tests.py | 70 | ||||
| -rw-r--r-- | tests/update_only_fields/tests.py | 16 |
4 files changed, 84 insertions, 38 deletions
diff --git a/tests/basic/tests.py b/tests/basic/tests.py index f8ec2715f6..38e7278210 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -1,5 +1,6 @@ import inspect import threading +import time from datetime import datetime, timedelta from unittest import mock @@ -12,6 +13,7 @@ from django.db import ( models, transaction, ) +from django.db.models.functions import Now from django.db.models.manager import BaseManager from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet from django.test import ( @@ -558,6 +560,26 @@ class ModelTest(TestCase): with self.subTest(case=case): self.assertIs(case._is_pk_set(), True) + def test_save_expressions(self): + article = Article(pub_date=Now()) + article.save() + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + article_pub_date = article.pub_date + self.assertIsInstance(article_pub_date, datetime) + # Sleep slightly to ensure a different database level NOW(). + time.sleep(0.1) + article.pub_date = Now() + article.save() + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertIsInstance(article.pub_date, datetime) + self.assertGreater(article.pub_date, article_pub_date) + class ModelLookupTest(TestCase): @classmethod diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 27d88be621..6f18321aa7 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -420,8 +420,11 @@ class BasicExpressionsTests(TestCase): # F expressions can be used to update attributes on single objects self.gmbh.num_employees = F("num_employees") + 4 self.gmbh.save() - self.gmbh.refresh_from_db() - self.assertEqual(self.gmbh.num_employees, 36) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(self.gmbh.num_employees, 36) def test_new_object_save(self): # We should be able to use Funcs when inserting new data @@ -1644,8 +1647,11 @@ class ExpressionsNumericTests(TestCase): n = Number.objects.create(integer=1, decimal_value=Decimal("0.5")) n.decimal_value = F("decimal_value") - Decimal("0.4") n.save() - n.refresh_from_db() - self.assertEqual(n.decimal_value, Decimal("0.1")) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(n.decimal_value, Decimal("0.1")) class ExpressionOperatorTests(TestCase): diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py index e914adfc51..7f85d946f6 100644 --- a/tests/field_defaults/tests.py +++ b/tests/field_defaults/tests.py @@ -15,13 +15,7 @@ from django.db.models.expressions import ( ) from django.db.models.functions import Collate from django.db.models.lookups import GreaterThan -from django.test import ( - SimpleTestCase, - TestCase, - override_settings, - skipIfDBFeature, - skipUnlessDBFeature, -) +from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature from django.utils import timezone from .models import ( @@ -44,47 +38,56 @@ class DefaultTests(TestCase): self.assertEqual(a.headline, "Default headline") self.assertLess((now - a.pub_date).seconds, 5) - @skipUnlessDBFeature( - "can_return_columns_from_insert", "supports_expression_defaults" - ) + @skipUnlessDBFeature("supports_expression_defaults") def test_field_db_defaults_returning(self): a = DBArticle() a.save() self.assertIsInstance(a.id, int) - self.assertEqual(a.headline, "Default headline") - self.assertIsInstance(a.pub_date, datetime) - self.assertEqual(a.cost, Decimal("3.33")) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 3 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + self.assertEqual(a.cost, Decimal("3.33")) - @skipIfDBFeature("can_return_columns_from_insert") @skipUnlessDBFeature("supports_expression_defaults") def test_field_db_defaults_refresh(self): a = DBArticle() a.save() - a.refresh_from_db() + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 3 + ) self.assertIsInstance(a.id, int) - self.assertEqual(a.headline, "Default headline") - self.assertIsInstance(a.pub_date, datetime) - self.assertEqual(a.cost, Decimal("3.33")) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + self.assertEqual(a.cost, Decimal("3.33")) def test_null_db_default(self): obj1 = DBDefaults.objects.create() - if not connection.features.can_return_columns_from_insert: - obj1.refresh_from_db() - self.assertEqual(obj1.null, 1.1) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj1.null, 1.1) obj2 = DBDefaults.objects.create(null=None) - self.assertIsNone(obj2.null) + with self.assertNumQueries(0): + self.assertIsNone(obj2.null) @skipUnlessDBFeature("supports_expression_defaults") @override_settings(USE_TZ=True) def test_db_default_function(self): m = DBDefaultsFunction.objects.create() - if not connection.features.can_return_columns_from_insert: - m.refresh_from_db() - self.assertAlmostEqual(m.number, pi) - self.assertEqual(m.year, timezone.now().year) - self.assertAlmostEqual(m.added, pi + 4.5) - self.assertEqual(m.multiple_subfunctions, 4.5) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 4 + ) + with self.assertNumQueries(expected_num_queries): + self.assertAlmostEqual(m.number, pi) + self.assertEqual(m.year, timezone.now().year) + self.assertAlmostEqual(m.added, pi + 4.5) + self.assertEqual(m.multiple_subfunctions, 4.5) @skipUnlessDBFeature("insert_test_table_with_defaults") def test_both_default(self): @@ -125,14 +128,15 @@ class DefaultTests(TestCase): child2 = DBDefaultsFK.objects.create(language_code=parent2) self.assertEqual(child2.language_code, parent2) - @skipUnlessDBFeature( - "can_return_columns_from_insert", "supports_expression_defaults" - ) + @skipUnlessDBFeature("supports_expression_defaults") def test_case_when_db_default_returning(self): m = DBDefaultsFunction.objects.create() - self.assertEqual(m.case_when, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.case_when, 3) - @skipIfDBFeature("can_return_columns_from_insert") @skipUnlessDBFeature("supports_expression_defaults") def test_case_when_db_default_no_returning(self): m = DBDefaultsFunction.objects.create() diff --git a/tests/update_only_fields/tests.py b/tests/update_only_fields/tests.py index 9595c767eb..1c7ef88832 100644 --- a/tests/update_only_fields/tests.py +++ b/tests/update_only_fields/tests.py @@ -1,5 +1,6 @@ from django.core.exceptions import ObjectNotUpdated -from django.db import DatabaseError, transaction +from django.db import DatabaseError, connection, transaction +from django.db.models import F from django.db.models.signals import post_save, pre_save from django.test import TestCase @@ -308,3 +309,16 @@ class UpdateOnlyFieldsTests(TestCase): transaction.atomic(), ): obj.save(update_fields=["name"]) + + def test_update_fields_expression(self): + obj = Person.objects.create(name="Valerie", gender="F", pid=42) + updated_pid = F("pid") + 1 + obj.pid = updated_pid + obj.save(update_fields={"gender"}) + self.assertIs(obj.pid, updated_pid) + obj.save(update_fields={"pid"}) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.pid, 43) |
