diff options
| author | Clifford Gama <cliffygamy@gmail.com> | 2025-10-30 20:59:15 +0200 |
|---|---|---|
| committer | Jacob Walls <jacobtylerwalls@gmail.com> | 2025-12-10 17:45:51 -0500 |
| commit | 66fed37ecb78daf0a50e95151a752b5760293514 (patch) | |
| tree | 5a83bb79028cdc23c390e4a1db38c684dd4ebdea | |
| parent | bbabbac936caf6db129427e6e65f03b6d0a68f62 (diff) | |
Fixed #36689 -- Fixed top-level JSONField __in lookup failures on MySQL and Oracle.
Added a JSONIn lookup to handle correct serialization and extraction
for JSONField top-level __in queries on backends without native JSON
support. KeyTransformIn now subclasses JSONIn.
Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
Thanks Jacob Walls for the report and review.
| -rw-r--r-- | django/db/models/fields/json.py | 175 | ||||
| -rw-r--r-- | tests/model_fields/test_jsonfield.py | 20 |
2 files changed, 135 insertions, 60 deletions
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index 8ca9cc6811..c4817317f0 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -375,6 +375,114 @@ class JSONIContains(CaseInsensitiveMixin, lookups.IContains): pass +class ProcessJSONLHSMixin: + def _get_json_path(self, connection, key_transforms): + if key_transforms is None: + return "$" + return connection.ops.compile_json_path(key_transforms) + + def _process_as_oracle(self, sql, params, connection, key_transforms=None): + json_path = self._get_json_path(connection, key_transforms) + if connection.features.supports_primitives_in_json_field: + template = ( + "COALESCE(" + "JSON_VALUE(%s, q'\uffff%s\uffff')," + "JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)" + ")" + ) + else: + template = ( + "COALESCE(" + "JSON_QUERY(%s, q'\uffff%s\uffff')," + "JSON_VALUE(%s, q'\uffff%s\uffff')" + ")" + ) + # Add paths directly into SQL because path expressions cannot be passed + # as bind variables on Oracle. Use a custom delimiter to prevent the + # JSON path from escaping the SQL literal. Each key in the JSON path is + # passed through json.dumps() with ensure_ascii=True (the default), + # which converts the delimiter into the escaped \uffff format. This + # ensures that the delimiter is not present in the JSON path. + sql = template % ((sql, json_path) * 2) + return sql, params * 2 + + def _process_as_sqlite(self, sql, params, connection, key_transforms=None): + json_path = self._get_json_path(connection, key_transforms) + datatype_values = ",".join( + [repr(value) for value in connection.ops.jsonfield_datatype_values] + ) + return ( + "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) " + "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)" + ) % (sql, datatype_values, sql, sql), (*params, json_path) * 3 + + def _process_as_mysql(self, sql, params, connection, key_transforms=None): + json_path = self._get_json_path(connection, key_transforms) + return "JSON_EXTRACT(%s, %%s)" % sql, (*params, json_path) + + +class JSONIn(ProcessJSONLHSMixin, lookups.In): + def resolve_expression_parameter(self, compiler, connection, sql, param): + sql, params = super().resolve_expression_parameter( + compiler, + connection, + sql, + param, + ) + if ( + not hasattr(param, "as_sql") + and not connection.features.has_native_json_field + ): + if connection.vendor == "oracle": + value = json.loads(param) + sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" + if isinstance(value, (list, dict)): + sql %= "JSON_QUERY" + else: + sql %= "JSON_VALUE" + elif connection.vendor == "mysql" or ( + connection.vendor == "sqlite" + and params[0] not in connection.ops.jsonfield_datatype_values + ): + sql = "JSON_EXTRACT(%s, '$')" + if connection.vendor == "mysql" and connection.mysql_is_mariadb: + sql = "JSON_UNQUOTE(%s)" % sql + return sql, params + + def process_lhs(self, compiler, connection): + sql, params = super().process_lhs(compiler, connection) + if isinstance(self.lhs, KeyTransform): + return sql, params + if connection.vendor == "mysql": + return self._process_as_mysql(sql, params, connection) + elif connection.vendor == "oracle": + return self._process_as_oracle(sql, params, connection) + elif connection.vendor == "sqlite": + return self._process_as_sqlite(sql, params, connection) + return sql, params + + def as_oracle(self, compiler, connection): + if ( + connection.features.supports_primitives_in_json_field + and isinstance(self.rhs, expressions.ExpressionList) + and JSONNull() in self.rhs.get_source_expressions() + ): + # Break the lookup into multiple exact lookups combined with OR, as + # Oracle does not support directly extracting JSON scalar null as a + # value in the right-hand side of an IN clause. + exact_lookup = self.lhs.get_lookup("exact") + sql_parts = [] + all_params = () + for expr in self.rhs.get_source_expressions(): + lookup = exact_lookup(self.lhs, expr) + sql, params = lookup.as_oracle(compiler, connection) + sql_parts.append(f"({sql})") + all_params = (*all_params, *params) + sql = " OR ".join(sql_parts) + return sql, all_params + return self.as_sql(compiler, connection) + + JSONField.register_lookup(DataContains) JSONField.register_lookup(ContainedBy) JSONField.register_lookup(HasKey) @@ -382,9 +490,10 @@ JSONField.register_lookup(HasKeys) JSONField.register_lookup(HasAnyKeys) JSONField.register_lookup(JSONExact) JSONField.register_lookup(JSONIContains) +JSONField.register_lookup(JSONIn) -class KeyTransform(Transform): +class KeyTransform(ProcessJSONLHSMixin, Transform): postgres_operator = "->" postgres_nested_operator = "#>" @@ -406,33 +515,11 @@ class KeyTransform(Transform): def as_mysql(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) - json_path = connection.ops.compile_json_path(key_transforms) - return "JSON_EXTRACT(%s, %%s)" % lhs, (*params, json_path) + return self._process_as_mysql(lhs, params, connection, key_transforms) def as_oracle(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) - json_path = connection.ops.compile_json_path(key_transforms) - if connection.features.supports_primitives_in_json_field: - sql = ( - "COALESCE(" - "JSON_VALUE(%s, q'\uffff%s\uffff')," - "JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)" - ")" - ) - else: - sql = ( - "COALESCE(" - "JSON_QUERY(%s, q'\uffff%s\uffff')," - "JSON_VALUE(%s, q'\uffff%s\uffff')" - ")" - ) - # Add paths directly into SQL because path expressions cannot be passed - # as bind variables on Oracle. Use a custom delimiter to prevent the - # JSON path from escaping the SQL literal. Each key in the JSON path is - # passed through json.dumps() with ensure_ascii=True (the default), - # which converts the delimiter into the escaped \uffff format. This - # ensures that the delimiter is not present in the JSON path. - return sql % ((lhs, json_path) * 2), tuple(params) * 2 + return self._process_as_oracle(lhs, params, connection, key_transforms) def as_postgresql(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) @@ -447,14 +534,7 @@ class KeyTransform(Transform): def as_sqlite(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) - json_path = connection.ops.compile_json_path(key_transforms) - datatype_values = ",".join( - [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values] - ) - return ( - "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) " - "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)" - ) % (lhs, datatype_values, lhs, lhs), (*params, json_path) * 3 + return self._process_as_sqlite(lhs, params, connection, key_transforms) class KeyTextTransform(KeyTransform): @@ -535,33 +615,8 @@ class KeyTransformIsNull(lookups.IsNull): ) -class KeyTransformIn(lookups.In): - def resolve_expression_parameter(self, compiler, connection, sql, param): - sql, params = super().resolve_expression_parameter( - compiler, - connection, - sql, - param, - ) - if ( - not hasattr(param, "as_sql") - and not connection.features.has_native_json_field - ): - if connection.vendor == "oracle": - value = json.loads(param) - sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" - if isinstance(value, (list, dict)): - sql %= "JSON_QUERY" - else: - sql %= "JSON_VALUE" - elif connection.vendor == "mysql" or ( - connection.vendor == "sqlite" - and params[0] not in connection.ops.jsonfield_datatype_values - ): - sql = "JSON_EXTRACT(%s, '$')" - if connection.vendor == "mysql" and connection.mysql_is_mariadb: - sql = "JSON_UNQUOTE(%s)" % sql - return sql, params +class KeyTransformIn(JSONIn): + pass class KeyTransformExact(JSONExact): diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py index cf8e1888e5..cd3cad734e 100644 --- a/tests/model_fields/test_jsonfield.py +++ b/tests/model_fields/test_jsonfield.py @@ -1010,6 +1010,19 @@ class TestQuerying(TestCase): NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False ) + def test_in(self): + tests = [ + ([[]], [self.objs[1]]), + ([{}], [self.objs[2]]), + ([{"a": "b", "c": 14}], [self.objs[3]]), + ([[1, [2]]], [self.objs[5]]), + ] + for lookup_value, expected in tests: + with self.subTest(value__in=lookup_value): + self.assertCountEqual( + NullableJSONModel.objects.filter(value__in=lookup_value), expected + ) + def test_key_in(self): tests = [ ("value__c__in", [14], self.objs[3:5]), @@ -1297,6 +1310,13 @@ class JSONNullTests(TestCase): NullableJSONModel.objects.filter(value__isnull=True), [sql_null] ) + def test_filter_in(self): + obj = NullableJSONModel.objects.create(value=JSONNull()) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__in=[JSONNull()]), + [obj], + ) + def test_bulk_update(self): obj1 = NullableJSONModel.objects.create(value={"k": "1st"}) obj2 = NullableJSONModel.objects.create(value={"k": "2nd"}) |
