summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2025-02-07 16:34:17 -0500
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2025-02-15 15:46:59 +0100
commitdf2c4952df6d93c575fb8a3c853dc9d4c2449f36 (patch)
tree16bcf023ce3fb107a97a3276f743eb1e0046a7fd
parent6fcd0440aaa7601aa258d1c956eecfaedf72fbf4 (diff)
Fixed #36173 -- Stabilized identity of Concat with an explicit output_field.
When Expression.__init__() overrides make use of *args, **kwargs captures their argument values are respectively bound as a tuple and dict instances. These composite values might themselves contain values that require special identity treatments such as Concat(output_field) as it's a Field instance. Refs #30628 which introduced bound Field differentiation but lacked argument captures handling. Thanks erchenstein for the report.
-rw-r--r--django/db/models/expressions.py23
-rw-r--r--tests/db_functions/text/test_concat.py14
-rw-r--r--tests/expressions/tests.py23
3 files changed, 53 insertions, 7 deletions
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 57ceadcec4..444e2fab7b 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -523,6 +523,18 @@ class Expression(BaseExpression, Combinable):
def _constructor_signature(cls):
return inspect.signature(cls.__init__)
+ @classmethod
+ def _identity(cls, value):
+ if isinstance(value, tuple):
+ return tuple(map(cls._identity, value))
+ if isinstance(value, dict):
+ return tuple((key, cls._identity(val)) for key, val in value.items())
+ if isinstance(value, fields.Field):
+ if value.name and value.model:
+ return value.model._meta.label, value.name
+ return type(value)
+ return make_hashable(value)
+
@cached_property
def identity(self):
args, kwargs = self._constructor_args
@@ -532,13 +544,10 @@ class Expression(BaseExpression, Combinable):
next(arguments)
identity = [self.__class__]
for arg, value in arguments:
- if isinstance(value, fields.Field):
- if value.name and value.model:
- value = (value.model._meta.label, value.name)
- else:
- value = type(value)
- else:
- value = make_hashable(value)
+ # If __init__() makes use of *args or **kwargs captures `value`
+ # will respectively be a tuple or a dict that must have its
+ # constituents unpacked (mainly if contain Field instances).
+ value = self._identity(value)
identity.append((arg, value))
return tuple(identity)
diff --git a/tests/db_functions/text/test_concat.py b/tests/db_functions/text/test_concat.py
index 6e4cb91d3a..ffcd19fad6 100644
--- a/tests/db_functions/text/test_concat.py
+++ b/tests/db_functions/text/test_concat.py
@@ -107,3 +107,17 @@ class ConcatTests(TestCase):
ctx.captured_queries[0]["sql"].count("::text"),
1 if connection.vendor == "postgresql" else 0,
)
+
+ def test_equal(self):
+ self.assertEqual(
+ Concat("foo", "bar", output_field=TextField()),
+ Concat("foo", "bar", output_field=TextField()),
+ )
+ self.assertNotEqual(
+ Concat("foo", "bar", output_field=TextField()),
+ Concat("foo", "bar", output_field=CharField()),
+ )
+ self.assertNotEqual(
+ Concat("foo", "bar", output_field=TextField()),
+ Concat("bar", "foo", output_field=TextField()),
+ )
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index cfa33b6f45..89601de85b 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -1433,6 +1433,29 @@ class SimpleExpressionTests(SimpleTestCase):
Expression(TestModel._meta.get_field("other_field")),
)
+ class InitCaptureExpression(Expression):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # The identity of expressions that obscure their __init__() signature
+ # with *args and **kwargs cannot be determined when bound with
+ # different combinations or *args and **kwargs.
+ self.assertNotEqual(
+ InitCaptureExpression(IntegerField()),
+ InitCaptureExpression(output_field=IntegerField()),
+ )
+
+ # However, they should be considered equal when their bindings are
+ # equal.
+ self.assertEqual(
+ InitCaptureExpression(IntegerField()),
+ InitCaptureExpression(IntegerField()),
+ )
+ self.assertEqual(
+ InitCaptureExpression(output_field=IntegerField()),
+ InitCaptureExpression(output_field=IntegerField()),
+ )
+
def test_hash(self):
self.assertEqual(hash(Expression()), hash(Expression()))
self.assertEqual(