summaryrefslogtreecommitdiff
path: root/django/db/models/fields/json.py
diff options
context:
space:
mode:
authordjango-bot <ops@djangoproject.com>2022-02-03 20:24:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-02-07 20:37:05 +0100
commit9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
treef0506b668a013d0063e5fba3dbf4863b466713ba /django/db/models/fields/json.py
parentf68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db/models/fields/json.py')
-rw-r--r--django/db/models/fields/json.py249
1 files changed, 140 insertions, 109 deletions
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py
index efb4e2f6ed..fdca700c9d 100644
--- a/django/db/models/fields/json.py
+++ b/django/db/models/fields/json.py
@@ -10,32 +10,36 @@ from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
-__all__ = ['JSONField']
+__all__ = ["JSONField"]
class JSONField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
- description = _('A JSON object')
+ description = _("A JSON object")
default_error_messages = {
- 'invalid': _('Value must be valid JSON.'),
+ "invalid": _("Value must be valid JSON."),
}
- _default_hint = ('dict', '{}')
+ _default_hint = ("dict", "{}")
def __init__(
- self, verbose_name=None, name=None, encoder=None, decoder=None,
+ self,
+ verbose_name=None,
+ name=None,
+ encoder=None,
+ decoder=None,
**kwargs,
):
if encoder and not callable(encoder):
- raise ValueError('The encoder parameter must be a callable object.')
+ raise ValueError("The encoder parameter must be a callable object.")
if decoder and not callable(decoder):
- raise ValueError('The decoder parameter must be a callable object.')
+ raise ValueError("The decoder parameter must be a callable object.")
self.encoder = encoder
self.decoder = decoder
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
errors = super().check(**kwargs)
- databases = kwargs.get('databases') or []
+ databases = kwargs.get("databases") or []
errors.extend(self._check_supported(databases))
return errors
@@ -46,20 +50,19 @@ class JSONField(CheckFieldDefaultMixin, Field):
continue
connection = connections[db]
if (
- self.model._meta.required_db_vendor and
- self.model._meta.required_db_vendor != connection.vendor
+ self.model._meta.required_db_vendor
+ and self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
- 'supports_json_field' in self.model._meta.required_db_features or
- connection.features.supports_json_field
+ "supports_json_field" in self.model._meta.required_db_features
+ or connection.features.supports_json_field
):
errors.append(
checks.Error(
- '%s does not support JSONFields.'
- % connection.display_name,
+ "%s does not support JSONFields." % connection.display_name,
obj=self.model,
- id='fields.E180',
+ id="fields.E180",
)
)
return errors
@@ -67,9 +70,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.encoder is not None:
- kwargs['encoder'] = self.encoder
+ kwargs["encoder"] = self.encoder
if self.decoder is not None:
- kwargs['decoder'] = self.decoder
+ kwargs["decoder"] = self.decoder
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
@@ -85,7 +88,7 @@ class JSONField(CheckFieldDefaultMixin, Field):
return value
def get_internal_type(self):
- return 'JSONField'
+ return "JSONField"
def get_prep_value(self, value):
if value is None:
@@ -104,64 +107,66 @@ class JSONField(CheckFieldDefaultMixin, Field):
json.dumps(value, cls=self.encoder)
except TypeError:
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def value_to_string(self, obj):
return self.value_from_object(obj)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.JSONField,
- 'encoder': self.encoder,
- 'decoder': self.decoder,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.JSONField,
+ "encoder": self.encoder,
+ "decoder": self.decoder,
+ **kwargs,
+ }
+ )
def compile_json_path(key_transforms, include_root=True):
- path = ['$'] if include_root else []
+ path = ["$"] if include_root else []
for key_transform in key_transforms:
try:
num = int(key_transform)
except ValueError: # non-integer
- path.append('.')
+ path.append(".")
path.append(json.dumps(key_transform))
else:
- path.append('[%s]' % num)
- return ''.join(path)
+ path.append("[%s]" % num)
+ return "".join(path)
class DataContains(PostgresOperatorLookup):
- lookup_name = 'contains'
- postgres_operator = '@>'
+ lookup_name = "contains"
+ postgres_operator = "@>"
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
- 'contains lookup is not supported on this database backend.'
+ "contains lookup is not supported on this database backend."
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
- return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
+ return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup):
- lookup_name = 'contained_by'
- postgres_operator = '<@'
+ lookup_name = "contained_by"
+ postgres_operator = "<@"
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
- 'contained_by lookup is not supported on this database backend.'
+ "contained_by lookup is not supported on this database backend."
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(rhs_params) + tuple(lhs_params)
- return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
+ return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
class HasKeyLookup(PostgresOperatorLookup):
@@ -170,11 +175,13 @@ class HasKeyLookup(PostgresOperatorLookup):
def as_sql(self, compiler, connection, template=None):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
- lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
+ lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
+ compiler, connection
+ )
lhs_json_path = compile_json_path(lhs_key_transforms)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
- lhs_json_path = '$'
+ lhs_json_path = "$"
sql = template % lhs
# Process JSON path from the right-hand side.
rhs = self.rhs
@@ -186,20 +193,27 @@ class HasKeyLookup(PostgresOperatorLookup):
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
else:
rhs_key_transforms = [key]
- rhs_params.append('%s%s' % (
- lhs_json_path,
- compile_json_path(rhs_key_transforms, include_root=False),
- ))
+ rhs_params.append(
+ "%s%s"
+ % (
+ lhs_json_path,
+ compile_json_path(rhs_key_transforms, include_root=False),
+ )
+ )
# Add condition for each key.
if self.logical_operator:
- sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
+ sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
return sql, tuple(lhs_params) + tuple(rhs_params)
def as_mysql(self, compiler, connection):
- return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
+ return self.as_sql(
+ compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
+ )
def as_oracle(self, compiler, connection):
- sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
+ sql, params = self.as_sql(
+ compiler, connection, template="JSON_EXISTS(%s, '%%s')"
+ )
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle.
return sql % tuple(params), []
@@ -213,28 +227,30 @@ class HasKeyLookup(PostgresOperatorLookup):
return super().as_postgresql(compiler, connection)
def as_sqlite(self, compiler, connection):
- return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
+ return self.as_sql(
+ compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
+ )
class HasKey(HasKeyLookup):
- lookup_name = 'has_key'
- postgres_operator = '?'
+ lookup_name = "has_key"
+ postgres_operator = "?"
prepare_rhs = False
class HasKeys(HasKeyLookup):
- lookup_name = 'has_keys'
- postgres_operator = '?&'
- logical_operator = ' AND '
+ lookup_name = "has_keys"
+ postgres_operator = "?&"
+ logical_operator = " AND "
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
- lookup_name = 'has_any_keys'
- postgres_operator = '?|'
- logical_operator = ' OR '
+ lookup_name = "has_any_keys"
+ postgres_operator = "?|"
+ logical_operator = " OR "
class CaseInsensitiveMixin:
@@ -244,16 +260,17 @@ class CaseInsensitiveMixin:
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
+
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
- if connection.vendor == 'mysql':
- return 'LOWER(%s)' % lhs, lhs_params
+ if connection.vendor == "mysql":
+ return "LOWER(%s)" % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
- if connection.vendor == 'mysql':
- return 'LOWER(%s)' % rhs, rhs_params
+ if connection.vendor == "mysql":
+ return "LOWER(%s)" % rhs, rhs_params
return rhs, rhs_params
@@ -263,9 +280,9 @@ class JSONExact(lookups.Exact):
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
- if rhs == '%s' and rhs_params == [None]:
- rhs_params = ['null']
- if connection.vendor == 'mysql':
+ if rhs == "%s" and rhs_params == [None]:
+ rhs_params = ["null"]
+ if connection.vendor == "mysql":
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs = rhs % tuple(func)
return rhs, rhs_params
@@ -285,8 +302,8 @@ JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
- postgres_operator = '->'
- postgres_nested_operator = '#>'
+ postgres_operator = "->"
+ postgres_nested_operator = "#>"
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -299,41 +316,41 @@ class KeyTransform(Transform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs, params = compiler.compile(previous)
- if connection.vendor == 'oracle':
+ if connection.vendor == "oracle":
# Escape string-formatting.
- key_transforms = [key.replace('%', '%%') for key in key_transforms]
+ key_transforms = [key.replace("%", "%%") for key in key_transforms]
return lhs, params, key_transforms
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
- return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
+ return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return (
- "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
- ((lhs, json_path) * 2)
+ "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
+ % ((lhs, json_path) * 2)
), tuple(params) * 2
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
if len(key_transforms) > 1:
- sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
+ sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
return sql, tuple(params) + (key_transforms,)
try:
lookup = int(self.key_name)
except ValueError:
lookup = self.key_name
- return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
+ return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
- datatype_values = ','.join([
- repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
- ])
+ datatype_values = ",".join(
+ [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
+ )
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
@@ -341,8 +358,8 @@ class KeyTransform(Transform):
class KeyTextTransform(KeyTransform):
- postgres_operator = '->>'
- postgres_nested_operator = '#>>'
+ postgres_operator = "->>"
+ postgres_nested_operator = "#>>"
class KeyTransformTextLookupMixin:
@@ -352,14 +369,16 @@ class KeyTransformTextLookupMixin:
key values to text and performing the lookup on the resulting
representation.
"""
+
def __init__(self, key_transform, *args, **kwargs):
if not isinstance(key_transform, KeyTransform):
raise TypeError(
- 'Transform should be an instance of KeyTransform in order to '
- 'use this lookup.'
+ "Transform should be an instance of KeyTransform in order to "
+ "use this lookup."
)
key_text_transform = KeyTextTransform(
- key_transform.key_name, *key_transform.source_expressions,
+ key_transform.key_name,
+ *key_transform.source_expressions,
**key_transform.extra,
)
super().__init__(key_text_transform, *args, **kwargs)
@@ -376,12 +395,12 @@ class KeyTransformIsNull(lookups.IsNull):
return sql, params
# Column doesn't have a key or IS NULL.
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
- return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
+ return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
def as_sqlite(self, compiler, connection):
- template = 'JSON_TYPE(%s, %%s) IS NULL'
+ template = "JSON_TYPE(%s, %%s) IS NULL"
if not self.rhs:
- template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
+ template = "JSON_TYPE(%s, %%s) IS NOT NULL"
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
compiler,
connection,
@@ -392,26 +411,29 @@ class KeyTransformIsNull(lookups.IsNull):
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
- compiler, connection, sql, param,
+ compiler,
+ connection,
+ sql,
+ param,
)
if (
- not hasattr(param, 'as_sql') and
- not connection.features.has_native_json_field
+ not hasattr(param, "as_sql")
+ and not connection.features.has_native_json_field
):
- if connection.vendor == 'oracle':
+ if connection.vendor == "oracle":
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
- sql = sql % 'JSON_QUERY'
+ sql = sql % "JSON_QUERY"
else:
- sql = sql % 'JSON_VALUE'
- elif connection.vendor == 'mysql' or (
- connection.vendor == 'sqlite' and
- params[0] not in connection.ops.jsonfield_datatype_values
+ sql = sql % "JSON_VALUE"
+ elif connection.vendor == "mysql" or (
+ connection.vendor == "sqlite"
+ and params[0] not in connection.ops.jsonfield_datatype_values
):
sql = "JSON_EXTRACT(%s, '$')"
- if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
- sql = 'JSON_UNQUOTE(%s)' % sql
+ if connection.vendor == "mysql" and connection.mysql_is_mariadb:
+ sql = "JSON_UNQUOTE(%s)" % sql
return sql, params
@@ -420,21 +442,21 @@ class KeyTransformExact(JSONExact):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super().process_rhs(compiler, connection)
- if connection.vendor == 'oracle':
+ if connection.vendor == "oracle":
func = []
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
for value in rhs_params:
value = json.loads(value)
if isinstance(value, (list, dict)):
- func.append(sql % 'JSON_QUERY')
+ func.append(sql % "JSON_QUERY")
else:
- func.append(sql % 'JSON_VALUE')
+ func.append(sql % "JSON_VALUE")
rhs = rhs % tuple(func)
- elif connection.vendor == 'sqlite':
+ elif connection.vendor == "sqlite":
func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
- func.append('%s')
+ func.append("%s")
else:
func.append("JSON_EXTRACT(%s, '$')")
rhs = rhs % tuple(func)
@@ -442,24 +464,28 @@ class KeyTransformExact(JSONExact):
def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
- if rhs_params == ['null']:
+ if rhs_params == ["null"]:
# Field has key and it's NULL.
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
- is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
+ is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
return (
- '%s AND %s' % (has_key_sql, is_null_sql),
+ "%s AND %s" % (has_key_sql, is_null_sql),
tuple(has_key_params) + tuple(is_null_params),
)
return super().as_sql(compiler, connection)
-class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
+class KeyTransformIExact(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
+):
pass
-class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
+class KeyTransformIContains(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
+):
pass
@@ -467,7 +493,9 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
pass
-class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
+class KeyTransformIStartsWith(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
+):
pass
@@ -475,7 +503,9 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
pass
-class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
+class KeyTransformIEndsWith(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
+):
pass
@@ -483,7 +513,9 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
pass
-class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
+class KeyTransformIRegex(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
+):
pass
@@ -530,7 +562,6 @@ KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
-
def __init__(self, key_name):
self.key_name = key_name