summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClifford Gama <cliffygamy@gmail.com>2025-10-30 20:59:15 +0200
committerJacob Walls <jacobtylerwalls@gmail.com>2025-12-10 17:45:51 -0500
commit66fed37ecb78daf0a50e95151a752b5760293514 (patch)
tree5a83bb79028cdc23c390e4a1db38c684dd4ebdea
parentbbabbac936caf6db129427e6e65f03b6d0a68f62 (diff)
Fixed #36689 -- Fixed top-level JSONField __in lookup failures on MySQL and Oracle.
Added a JSONIn lookup to handle correct serialization and extraction for JSONField top-level __in queries on backends without native JSON support. KeyTransformIn now subclasses JSONIn. Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com> Thanks Jacob Walls for the report and review.
-rw-r--r--django/db/models/fields/json.py175
-rw-r--r--tests/model_fields/test_jsonfield.py20
2 files changed, 135 insertions, 60 deletions
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py
index 8ca9cc6811..c4817317f0 100644
--- a/django/db/models/fields/json.py
+++ b/django/db/models/fields/json.py
@@ -375,6 +375,114 @@ class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
pass
+class ProcessJSONLHSMixin:
+ def _get_json_path(self, connection, key_transforms):
+ if key_transforms is None:
+ return "$"
+ return connection.ops.compile_json_path(key_transforms)
+
+ def _process_as_oracle(self, sql, params, connection, key_transforms=None):
+ json_path = self._get_json_path(connection, key_transforms)
+ if connection.features.supports_primitives_in_json_field:
+ template = (
+ "COALESCE("
+ "JSON_VALUE(%s, q'\uffff%s\uffff'),"
+ "JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
+ ")"
+ )
+ else:
+ template = (
+ "COALESCE("
+ "JSON_QUERY(%s, q'\uffff%s\uffff'),"
+ "JSON_VALUE(%s, q'\uffff%s\uffff')"
+ ")"
+ )
+ # Add paths directly into SQL because path expressions cannot be passed
+ # as bind variables on Oracle. Use a custom delimiter to prevent the
+ # JSON path from escaping the SQL literal. Each key in the JSON path is
+ # passed through json.dumps() with ensure_ascii=True (the default),
+ # which converts the delimiter into the escaped \uffff format. This
+ # ensures that the delimiter is not present in the JSON path.
+ sql = template % ((sql, json_path) * 2)
+ return sql, params * 2
+
+ def _process_as_sqlite(self, sql, params, connection, key_transforms=None):
+ json_path = self._get_json_path(connection, key_transforms)
+ datatype_values = ",".join(
+ [repr(value) for value in connection.ops.jsonfield_datatype_values]
+ )
+ return (
+ "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
+ "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
+ ) % (sql, datatype_values, sql, sql), (*params, json_path) * 3
+
+ def _process_as_mysql(self, sql, params, connection, key_transforms=None):
+ json_path = self._get_json_path(connection, key_transforms)
+ return "JSON_EXTRACT(%s, %%s)" % sql, (*params, json_path)
+
+
+class JSONIn(ProcessJSONLHSMixin, lookups.In):
+ def resolve_expression_parameter(self, compiler, connection, sql, param):
+ sql, params = super().resolve_expression_parameter(
+ compiler,
+ connection,
+ sql,
+ param,
+ )
+ if (
+ not hasattr(param, "as_sql")
+ and not connection.features.has_native_json_field
+ ):
+ if connection.vendor == "oracle":
+ value = json.loads(param)
+ sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
+ if isinstance(value, (list, dict)):
+ sql %= "JSON_QUERY"
+ else:
+ sql %= "JSON_VALUE"
+ elif connection.vendor == "mysql" or (
+ connection.vendor == "sqlite"
+ and params[0] not in connection.ops.jsonfield_datatype_values
+ ):
+ sql = "JSON_EXTRACT(%s, '$')"
+ if connection.vendor == "mysql" and connection.mysql_is_mariadb:
+ sql = "JSON_UNQUOTE(%s)" % sql
+ return sql, params
+
+ def process_lhs(self, compiler, connection):
+ sql, params = super().process_lhs(compiler, connection)
+ if isinstance(self.lhs, KeyTransform):
+ return sql, params
+ if connection.vendor == "mysql":
+ return self._process_as_mysql(sql, params, connection)
+ elif connection.vendor == "oracle":
+ return self._process_as_oracle(sql, params, connection)
+ elif connection.vendor == "sqlite":
+ return self._process_as_sqlite(sql, params, connection)
+ return sql, params
+
+ def as_oracle(self, compiler, connection):
+ if (
+ connection.features.supports_primitives_in_json_field
+ and isinstance(self.rhs, expressions.ExpressionList)
+ and JSONNull() in self.rhs.get_source_expressions()
+ ):
+ # Break the lookup into multiple exact lookups combined with OR, as
+ # Oracle does not support directly extracting JSON scalar null as a
+ # value in the right-hand side of an IN clause.
+ exact_lookup = self.lhs.get_lookup("exact")
+ sql_parts = []
+ all_params = ()
+ for expr in self.rhs.get_source_expressions():
+ lookup = exact_lookup(self.lhs, expr)
+ sql, params = lookup.as_oracle(compiler, connection)
+ sql_parts.append(f"({sql})")
+ all_params = (*all_params, *params)
+ sql = " OR ".join(sql_parts)
+ return sql, all_params
+ return self.as_sql(compiler, connection)
+
+
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
@@ -382,9 +490,10 @@ JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
+JSONField.register_lookup(JSONIn)
-class KeyTransform(Transform):
+class KeyTransform(ProcessJSONLHSMixin, Transform):
postgres_operator = "->"
postgres_nested_operator = "#>"
@@ -406,33 +515,11 @@ class KeyTransform(Transform):
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
- json_path = connection.ops.compile_json_path(key_transforms)
- return "JSON_EXTRACT(%s, %%s)" % lhs, (*params, json_path)
+ return self._process_as_mysql(lhs, params, connection, key_transforms)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
- json_path = connection.ops.compile_json_path(key_transforms)
- if connection.features.supports_primitives_in_json_field:
- sql = (
- "COALESCE("
- "JSON_VALUE(%s, q'\uffff%s\uffff'),"
- "JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
- ")"
- )
- else:
- sql = (
- "COALESCE("
- "JSON_QUERY(%s, q'\uffff%s\uffff'),"
- "JSON_VALUE(%s, q'\uffff%s\uffff')"
- ")"
- )
- # Add paths directly into SQL because path expressions cannot be passed
- # as bind variables on Oracle. Use a custom delimiter to prevent the
- # JSON path from escaping the SQL literal. Each key in the JSON path is
- # passed through json.dumps() with ensure_ascii=True (the default),
- # which converts the delimiter into the escaped \uffff format. This
- # ensures that the delimiter is not present in the JSON path.
- return sql % ((lhs, json_path) * 2), tuple(params) * 2
+ return self._process_as_oracle(lhs, params, connection, key_transforms)
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
@@ -447,14 +534,7 @@ class KeyTransform(Transform):
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
- json_path = connection.ops.compile_json_path(key_transforms)
- datatype_values = ",".join(
- [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
- )
- return (
- "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
- "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
- ) % (lhs, datatype_values, lhs, lhs), (*params, json_path) * 3
+ return self._process_as_sqlite(lhs, params, connection, key_transforms)
class KeyTextTransform(KeyTransform):
@@ -535,33 +615,8 @@ class KeyTransformIsNull(lookups.IsNull):
)
-class KeyTransformIn(lookups.In):
- def resolve_expression_parameter(self, compiler, connection, sql, param):
- sql, params = super().resolve_expression_parameter(
- compiler,
- connection,
- sql,
- param,
- )
- if (
- not hasattr(param, "as_sql")
- and not connection.features.has_native_json_field
- ):
- if connection.vendor == "oracle":
- value = json.loads(param)
- sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
- if isinstance(value, (list, dict)):
- sql %= "JSON_QUERY"
- else:
- sql %= "JSON_VALUE"
- elif connection.vendor == "mysql" or (
- connection.vendor == "sqlite"
- and params[0] not in connection.ops.jsonfield_datatype_values
- ):
- sql = "JSON_EXTRACT(%s, '$')"
- if connection.vendor == "mysql" and connection.mysql_is_mariadb:
- sql = "JSON_UNQUOTE(%s)" % sql
- return sql, params
+class KeyTransformIn(JSONIn):
+ pass
class KeyTransformExact(JSONExact):
diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py
index cf8e1888e5..cd3cad734e 100644
--- a/tests/model_fields/test_jsonfield.py
+++ b/tests/model_fields/test_jsonfield.py
@@ -1010,6 +1010,19 @@ class TestQuerying(TestCase):
NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False
)
+ def test_in(self):
+ tests = [
+ ([[]], [self.objs[1]]),
+ ([{}], [self.objs[2]]),
+ ([{"a": "b", "c": 14}], [self.objs[3]]),
+ ([[1, [2]]], [self.objs[5]]),
+ ]
+ for lookup_value, expected in tests:
+ with self.subTest(value__in=lookup_value):
+ self.assertCountEqual(
+ NullableJSONModel.objects.filter(value__in=lookup_value), expected
+ )
+
def test_key_in(self):
tests = [
("value__c__in", [14], self.objs[3:5]),
@@ -1297,6 +1310,13 @@ class JSONNullTests(TestCase):
NullableJSONModel.objects.filter(value__isnull=True), [sql_null]
)
+ def test_filter_in(self):
+ obj = NullableJSONModel.objects.create(value=JSONNull())
+ self.assertSequenceEqual(
+ NullableJSONModel.objects.filter(value__in=[JSONNull()]),
+ [obj],
+ )
+
def test_bulk_update(self):
obj1 = NullableJSONModel.objects.create(value={"k": "1st"})
obj2 = NullableJSONModel.objects.create(value={"k": "2nd"})