summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2025-03-19 01:39:19 -0400
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2025-09-14 00:27:50 +0200
commit94680437a45a71c70ca8bd2e68b72aa1e2eff337 (patch)
tree0246e0b5a8dde81f0a327e1c7c7e3accd9257870 /tests
parent55a0073b3beb9de8f7c1f7c44a7d0bc10126c841 (diff)
Fixed #27222 -- Refreshed model field values assigned expressions on save().
Removed the can_return_columns_from_insert skip gates on existing field_defaults tests to confirm the expected number of queries are performed and that returning field overrides are respected.
Diffstat (limited to 'tests')
-rw-r--r--tests/basic/tests.py22
-rw-r--r--tests/expressions/tests.py14
-rw-r--r--tests/field_defaults/tests.py70
-rw-r--r--tests/update_only_fields/tests.py16
4 files changed, 84 insertions, 38 deletions
diff --git a/tests/basic/tests.py b/tests/basic/tests.py
index f8ec2715f6..38e7278210 100644
--- a/tests/basic/tests.py
+++ b/tests/basic/tests.py
@@ -1,5 +1,6 @@
import inspect
import threading
+import time
from datetime import datetime, timedelta
from unittest import mock
@@ -12,6 +13,7 @@ from django.db import (
models,
transaction,
)
+from django.db.models.functions import Now
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
from django.test import (
@@ -558,6 +560,26 @@ class ModelTest(TestCase):
with self.subTest(case=case):
self.assertIs(case._is_pk_set(), True)
+ def test_save_expressions(self):
+ article = Article(pub_date=Now())
+ article.save()
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ article_pub_date = article.pub_date
+ self.assertIsInstance(article_pub_date, datetime)
+ # Sleep slightly to ensure a different database level NOW().
+ time.sleep(0.1)
+ article.pub_date = Now()
+ article.save()
+ expected_num_queries = (
+ 0 if connection.features.can_return_rows_from_update else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertIsInstance(article.pub_date, datetime)
+ self.assertGreater(article.pub_date, article_pub_date)
+
class ModelLookupTest(TestCase):
@classmethod
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index 27d88be621..6f18321aa7 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -420,8 +420,11 @@ class BasicExpressionsTests(TestCase):
# F expressions can be used to update attributes on single objects
self.gmbh.num_employees = F("num_employees") + 4
self.gmbh.save()
- self.gmbh.refresh_from_db()
- self.assertEqual(self.gmbh.num_employees, 36)
+ expected_num_queries = (
+ 0 if connection.features.can_return_rows_from_update else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(self.gmbh.num_employees, 36)
def test_new_object_save(self):
# We should be able to use Funcs when inserting new data
@@ -1644,8 +1647,11 @@ class ExpressionsNumericTests(TestCase):
n = Number.objects.create(integer=1, decimal_value=Decimal("0.5"))
n.decimal_value = F("decimal_value") - Decimal("0.4")
n.save()
- n.refresh_from_db()
- self.assertEqual(n.decimal_value, Decimal("0.1"))
+ expected_num_queries = (
+ 0 if connection.features.can_return_rows_from_update else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(n.decimal_value, Decimal("0.1"))
class ExpressionOperatorTests(TestCase):
diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py
index e914adfc51..7f85d946f6 100644
--- a/tests/field_defaults/tests.py
+++ b/tests/field_defaults/tests.py
@@ -15,13 +15,7 @@ from django.db.models.expressions import (
)
from django.db.models.functions import Collate
from django.db.models.lookups import GreaterThan
-from django.test import (
- SimpleTestCase,
- TestCase,
- override_settings,
- skipIfDBFeature,
- skipUnlessDBFeature,
-)
+from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature
from django.utils import timezone
from .models import (
@@ -44,47 +38,56 @@ class DefaultTests(TestCase):
self.assertEqual(a.headline, "Default headline")
self.assertLess((now - a.pub_date).seconds, 5)
- @skipUnlessDBFeature(
- "can_return_columns_from_insert", "supports_expression_defaults"
- )
+ @skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_returning(self):
a = DBArticle()
a.save()
self.assertIsInstance(a.id, int)
- self.assertEqual(a.headline, "Default headline")
- self.assertIsInstance(a.pub_date, datetime)
- self.assertEqual(a.cost, Decimal("3.33"))
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 3
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(a.headline, "Default headline")
+ self.assertIsInstance(a.pub_date, datetime)
+ self.assertEqual(a.cost, Decimal("3.33"))
- @skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults")
def test_field_db_defaults_refresh(self):
a = DBArticle()
a.save()
- a.refresh_from_db()
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 3
+ )
self.assertIsInstance(a.id, int)
- self.assertEqual(a.headline, "Default headline")
- self.assertIsInstance(a.pub_date, datetime)
- self.assertEqual(a.cost, Decimal("3.33"))
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(a.headline, "Default headline")
+ self.assertIsInstance(a.pub_date, datetime)
+ self.assertEqual(a.cost, Decimal("3.33"))
def test_null_db_default(self):
obj1 = DBDefaults.objects.create()
- if not connection.features.can_return_columns_from_insert:
- obj1.refresh_from_db()
- self.assertEqual(obj1.null, 1.1)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(obj1.null, 1.1)
obj2 = DBDefaults.objects.create(null=None)
- self.assertIsNone(obj2.null)
+ with self.assertNumQueries(0):
+ self.assertIsNone(obj2.null)
@skipUnlessDBFeature("supports_expression_defaults")
@override_settings(USE_TZ=True)
def test_db_default_function(self):
m = DBDefaultsFunction.objects.create()
- if not connection.features.can_return_columns_from_insert:
- m.refresh_from_db()
- self.assertAlmostEqual(m.number, pi)
- self.assertEqual(m.year, timezone.now().year)
- self.assertAlmostEqual(m.added, pi + 4.5)
- self.assertEqual(m.multiple_subfunctions, 4.5)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 4
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertAlmostEqual(m.number, pi)
+ self.assertEqual(m.year, timezone.now().year)
+ self.assertAlmostEqual(m.added, pi + 4.5)
+ self.assertEqual(m.multiple_subfunctions, 4.5)
@skipUnlessDBFeature("insert_test_table_with_defaults")
def test_both_default(self):
@@ -125,14 +128,15 @@ class DefaultTests(TestCase):
child2 = DBDefaultsFK.objects.create(language_code=parent2)
self.assertEqual(child2.language_code, parent2)
- @skipUnlessDBFeature(
- "can_return_columns_from_insert", "supports_expression_defaults"
- )
+ @skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_returning(self):
m = DBDefaultsFunction.objects.create()
- self.assertEqual(m.case_when, 3)
+ expected_num_queries = (
+ 0 if connection.features.can_return_columns_from_insert else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(m.case_when, 3)
- @skipIfDBFeature("can_return_columns_from_insert")
@skipUnlessDBFeature("supports_expression_defaults")
def test_case_when_db_default_no_returning(self):
m = DBDefaultsFunction.objects.create()
diff --git a/tests/update_only_fields/tests.py b/tests/update_only_fields/tests.py
index 9595c767eb..1c7ef88832 100644
--- a/tests/update_only_fields/tests.py
+++ b/tests/update_only_fields/tests.py
@@ -1,5 +1,6 @@
from django.core.exceptions import ObjectNotUpdated
-from django.db import DatabaseError, transaction
+from django.db import DatabaseError, connection, transaction
+from django.db.models import F
from django.db.models.signals import post_save, pre_save
from django.test import TestCase
@@ -308,3 +309,16 @@ class UpdateOnlyFieldsTests(TestCase):
transaction.atomic(),
):
obj.save(update_fields=["name"])
+
+ def test_update_fields_expression(self):
+ obj = Person.objects.create(name="Valerie", gender="F", pid=42)
+ updated_pid = F("pid") + 1
+ obj.pid = updated_pid
+ obj.save(update_fields={"gender"})
+ self.assertIs(obj.pid, updated_pid)
+ obj.save(update_fields={"pid"})
+ expected_num_queries = (
+ 0 if connection.features.can_return_rows_from_update else 1
+ )
+ with self.assertNumQueries(expected_num_queries):
+ self.assertEqual(obj.pid, 43)