diff options
Diffstat (limited to 'django/db/models/fields/json.py')
| -rw-r--r-- | django/db/models/fields/json.py | 249 |
1 files changed, 140 insertions, 109 deletions
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index efb4e2f6ed..fdca700c9d 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -10,32 +10,36 @@ from django.utils.translation import gettext_lazy as _ from . import Field from .mixins import CheckFieldDefaultMixin -__all__ = ['JSONField'] +__all__ = ["JSONField"] class JSONField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False - description = _('A JSON object') + description = _("A JSON object") default_error_messages = { - 'invalid': _('Value must be valid JSON.'), + "invalid": _("Value must be valid JSON."), } - _default_hint = ('dict', '{}') + _default_hint = ("dict", "{}") def __init__( - self, verbose_name=None, name=None, encoder=None, decoder=None, + self, + verbose_name=None, + name=None, + encoder=None, + decoder=None, **kwargs, ): if encoder and not callable(encoder): - raise ValueError('The encoder parameter must be a callable object.') + raise ValueError("The encoder parameter must be a callable object.") if decoder and not callable(decoder): - raise ValueError('The decoder parameter must be a callable object.') + raise ValueError("The decoder parameter must be a callable object.") self.encoder = encoder self.decoder = decoder super().__init__(verbose_name, name, **kwargs) def check(self, **kwargs): errors = super().check(**kwargs) - databases = kwargs.get('databases') or [] + databases = kwargs.get("databases") or [] errors.extend(self._check_supported(databases)) return errors @@ -46,20 +50,19 @@ class JSONField(CheckFieldDefaultMixin, Field): continue connection = connections[db] if ( - self.model._meta.required_db_vendor and - self.model._meta.required_db_vendor != connection.vendor + self.model._meta.required_db_vendor + and self.model._meta.required_db_vendor != connection.vendor ): continue if not ( - 'supports_json_field' in self.model._meta.required_db_features or - connection.features.supports_json_field + "supports_json_field" in self.model._meta.required_db_features + or connection.features.supports_json_field ): errors.append( checks.Error( - '%s does not support JSONFields.' - % connection.display_name, + "%s does not support JSONFields." % connection.display_name, obj=self.model, - id='fields.E180', + id="fields.E180", ) ) return errors @@ -67,9 +70,9 @@ class JSONField(CheckFieldDefaultMixin, Field): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.encoder is not None: - kwargs['encoder'] = self.encoder + kwargs["encoder"] = self.encoder if self.decoder is not None: - kwargs['decoder'] = self.decoder + kwargs["decoder"] = self.decoder return name, path, args, kwargs def from_db_value(self, value, expression, connection): @@ -85,7 +88,7 @@ class JSONField(CheckFieldDefaultMixin, Field): return value def get_internal_type(self): - return 'JSONField' + return "JSONField" def get_prep_value(self, value): if value is None: @@ -104,64 +107,66 @@ class JSONField(CheckFieldDefaultMixin, Field): json.dumps(value, cls=self.encoder) except TypeError: raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def value_to_string(self, obj): return self.value_from_object(obj) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.JSONField, - 'encoder': self.encoder, - 'decoder': self.decoder, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.JSONField, + "encoder": self.encoder, + "decoder": self.decoder, + **kwargs, + } + ) def compile_json_path(key_transforms, include_root=True): - path = ['$'] if include_root else [] + path = ["$"] if include_root else [] for key_transform in key_transforms: try: num = int(key_transform) except ValueError: # non-integer - path.append('.') + path.append(".") path.append(json.dumps(key_transform)) else: - path.append('[%s]' % num) - return ''.join(path) + path.append("[%s]" % num) + return "".join(path) class DataContains(PostgresOperatorLookup): - lookup_name = 'contains' - postgres_operator = '@>' + lookup_name = "contains" + postgres_operator = "@>" def as_sql(self, compiler, connection): if not connection.features.supports_json_field_contains: raise NotSupportedError( - 'contains lookup is not supported on this database backend.' + "contains lookup is not supported on this database backend." ) lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) - return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params + return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params class ContainedBy(PostgresOperatorLookup): - lookup_name = 'contained_by' - postgres_operator = '<@' + lookup_name = "contained_by" + postgres_operator = "<@" def as_sql(self, compiler, connection): if not connection.features.supports_json_field_contains: raise NotSupportedError( - 'contained_by lookup is not supported on this database backend.' + "contained_by lookup is not supported on this database backend." ) lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(rhs_params) + tuple(lhs_params) - return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params + return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params class HasKeyLookup(PostgresOperatorLookup): @@ -170,11 +175,13 @@ class HasKeyLookup(PostgresOperatorLookup): def as_sql(self, compiler, connection, template=None): # Process JSON path from the left-hand side. if isinstance(self.lhs, KeyTransform): - lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection) + lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs( + compiler, connection + ) lhs_json_path = compile_json_path(lhs_key_transforms) else: lhs, lhs_params = self.process_lhs(compiler, connection) - lhs_json_path = '$' + lhs_json_path = "$" sql = template % lhs # Process JSON path from the right-hand side. rhs = self.rhs @@ -186,20 +193,27 @@ class HasKeyLookup(PostgresOperatorLookup): *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection) else: rhs_key_transforms = [key] - rhs_params.append('%s%s' % ( - lhs_json_path, - compile_json_path(rhs_key_transforms, include_root=False), - )) + rhs_params.append( + "%s%s" + % ( + lhs_json_path, + compile_json_path(rhs_key_transforms, include_root=False), + ) + ) # Add condition for each key. if self.logical_operator: - sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params)) + sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params)) return sql, tuple(lhs_params) + tuple(rhs_params) def as_mysql(self, compiler, connection): - return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)") + return self.as_sql( + compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)" + ) def as_oracle(self, compiler, connection): - sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')") + sql, params = self.as_sql( + compiler, connection, template="JSON_EXISTS(%s, '%%s')" + ) # Add paths directly into SQL because path expressions cannot be passed # as bind variables on Oracle. return sql % tuple(params), [] @@ -213,28 +227,30 @@ class HasKeyLookup(PostgresOperatorLookup): return super().as_postgresql(compiler, connection) def as_sqlite(self, compiler, connection): - return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL') + return self.as_sql( + compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL" + ) class HasKey(HasKeyLookup): - lookup_name = 'has_key' - postgres_operator = '?' + lookup_name = "has_key" + postgres_operator = "?" prepare_rhs = False class HasKeys(HasKeyLookup): - lookup_name = 'has_keys' - postgres_operator = '?&' - logical_operator = ' AND ' + lookup_name = "has_keys" + postgres_operator = "?&" + logical_operator = " AND " def get_prep_lookup(self): return [str(item) for item in self.rhs] class HasAnyKeys(HasKeys): - lookup_name = 'has_any_keys' - postgres_operator = '?|' - logical_operator = ' OR ' + lookup_name = "has_any_keys" + postgres_operator = "?|" + logical_operator = " OR " class CaseInsensitiveMixin: @@ -244,16 +260,17 @@ class CaseInsensitiveMixin: Because utf8mb4_bin is a binary collation, comparison of JSON values is case-sensitive. """ + def process_lhs(self, compiler, connection): lhs, lhs_params = super().process_lhs(compiler, connection) - if connection.vendor == 'mysql': - return 'LOWER(%s)' % lhs, lhs_params + if connection.vendor == "mysql": + return "LOWER(%s)" % lhs, lhs_params return lhs, lhs_params def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) - if connection.vendor == 'mysql': - return 'LOWER(%s)' % rhs, rhs_params + if connection.vendor == "mysql": + return "LOWER(%s)" % rhs, rhs_params return rhs, rhs_params @@ -263,9 +280,9 @@ class JSONExact(lookups.Exact): def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) # Treat None lookup values as null. - if rhs == '%s' and rhs_params == [None]: - rhs_params = ['null'] - if connection.vendor == 'mysql': + if rhs == "%s" and rhs_params == [None]: + rhs_params = ["null"] + if connection.vendor == "mysql": func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params) rhs = rhs % tuple(func) return rhs, rhs_params @@ -285,8 +302,8 @@ JSONField.register_lookup(JSONIContains) class KeyTransform(Transform): - postgres_operator = '->' - postgres_nested_operator = '#>' + postgres_operator = "->" + postgres_nested_operator = "#>" def __init__(self, key_name, *args, **kwargs): super().__init__(*args, **kwargs) @@ -299,41 +316,41 @@ class KeyTransform(Transform): key_transforms.insert(0, previous.key_name) previous = previous.lhs lhs, params = compiler.compile(previous) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": # Escape string-formatting. - key_transforms = [key.replace('%', '%%') for key in key_transforms] + key_transforms = [key.replace("%", "%%") for key in key_transforms] return lhs, params, key_transforms def as_mysql(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) json_path = compile_json_path(key_transforms) - return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) + return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,) def as_oracle(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) json_path = compile_json_path(key_transforms) return ( - "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" % - ((lhs, json_path) * 2) + "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" + % ((lhs, json_path) * 2) ), tuple(params) * 2 def as_postgresql(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) if len(key_transforms) > 1: - sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator) + sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator) return sql, tuple(params) + (key_transforms,) try: lookup = int(self.key_name) except ValueError: lookup = self.key_name - return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,) + return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,) def as_sqlite(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) json_path = compile_json_path(key_transforms) - datatype_values = ','.join([ - repr(datatype) for datatype in connection.ops.jsonfield_datatype_values - ]) + datatype_values = ",".join( + [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values] + ) return ( "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) " "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)" @@ -341,8 +358,8 @@ class KeyTransform(Transform): class KeyTextTransform(KeyTransform): - postgres_operator = '->>' - postgres_nested_operator = '#>>' + postgres_operator = "->>" + postgres_nested_operator = "#>>" class KeyTransformTextLookupMixin: @@ -352,14 +369,16 @@ class KeyTransformTextLookupMixin: key values to text and performing the lookup on the resulting representation. """ + def __init__(self, key_transform, *args, **kwargs): if not isinstance(key_transform, KeyTransform): raise TypeError( - 'Transform should be an instance of KeyTransform in order to ' - 'use this lookup.' + "Transform should be an instance of KeyTransform in order to " + "use this lookup." ) key_text_transform = KeyTextTransform( - key_transform.key_name, *key_transform.source_expressions, + key_transform.key_name, + *key_transform.source_expressions, **key_transform.extra, ) super().__init__(key_text_transform, *args, **kwargs) @@ -376,12 +395,12 @@ class KeyTransformIsNull(lookups.IsNull): return sql, params # Column doesn't have a key or IS NULL. lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection) - return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params) + return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params) def as_sqlite(self, compiler, connection): - template = 'JSON_TYPE(%s, %%s) IS NULL' + template = "JSON_TYPE(%s, %%s) IS NULL" if not self.rhs: - template = 'JSON_TYPE(%s, %%s) IS NOT NULL' + template = "JSON_TYPE(%s, %%s) IS NOT NULL" return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql( compiler, connection, @@ -392,26 +411,29 @@ class KeyTransformIsNull(lookups.IsNull): class KeyTransformIn(lookups.In): def resolve_expression_parameter(self, compiler, connection, sql, param): sql, params = super().resolve_expression_parameter( - compiler, connection, sql, param, + compiler, + connection, + sql, + param, ) if ( - not hasattr(param, 'as_sql') and - not connection.features.has_native_json_field + not hasattr(param, "as_sql") + and not connection.features.has_native_json_field ): - if connection.vendor == 'oracle': + if connection.vendor == "oracle": value = json.loads(param) sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" if isinstance(value, (list, dict)): - sql = sql % 'JSON_QUERY' + sql = sql % "JSON_QUERY" else: - sql = sql % 'JSON_VALUE' - elif connection.vendor == 'mysql' or ( - connection.vendor == 'sqlite' and - params[0] not in connection.ops.jsonfield_datatype_values + sql = sql % "JSON_VALUE" + elif connection.vendor == "mysql" or ( + connection.vendor == "sqlite" + and params[0] not in connection.ops.jsonfield_datatype_values ): sql = "JSON_EXTRACT(%s, '$')" - if connection.vendor == 'mysql' and connection.mysql_is_mariadb: - sql = 'JSON_UNQUOTE(%s)' % sql + if connection.vendor == "mysql" and connection.mysql_is_mariadb: + sql = "JSON_UNQUOTE(%s)" % sql return sql, params @@ -420,21 +442,21 @@ class KeyTransformExact(JSONExact): if isinstance(self.rhs, KeyTransform): return super(lookups.Exact, self).process_rhs(compiler, connection) rhs, rhs_params = super().process_rhs(compiler, connection) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": func = [] sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" for value in rhs_params: value = json.loads(value) if isinstance(value, (list, dict)): - func.append(sql % 'JSON_QUERY') + func.append(sql % "JSON_QUERY") else: - func.append(sql % 'JSON_VALUE') + func.append(sql % "JSON_VALUE") rhs = rhs % tuple(func) - elif connection.vendor == 'sqlite': + elif connection.vendor == "sqlite": func = [] for value in rhs_params: if value in connection.ops.jsonfield_datatype_values: - func.append('%s') + func.append("%s") else: func.append("JSON_EXTRACT(%s, '$')") rhs = rhs % tuple(func) @@ -442,24 +464,28 @@ class KeyTransformExact(JSONExact): def as_oracle(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) - if rhs_params == ['null']: + if rhs_params == ["null"]: # Field has key and it's NULL. has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name) has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection) - is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True) + is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True) is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection) return ( - '%s AND %s' % (has_key_sql, is_null_sql), + "%s AND %s" % (has_key_sql, is_null_sql), tuple(has_key_params) + tuple(is_null_params), ) return super().as_sql(compiler, connection) -class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact): +class KeyTransformIExact( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact +): pass -class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains): +class KeyTransformIContains( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains +): pass @@ -467,7 +493,9 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith): pass -class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith): +class KeyTransformIStartsWith( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith +): pass @@ -475,7 +503,9 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith): pass -class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith): +class KeyTransformIEndsWith( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith +): pass @@ -483,7 +513,9 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex): pass -class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex): +class KeyTransformIRegex( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex +): pass @@ -530,7 +562,6 @@ KeyTransform.register_lookup(KeyTransformGte) class KeyTransformFactory: - def __init__(self, key_name): self.key_name = key_name |
