diff options
| author | sage <laymonage@gmail.com> | 2019-06-09 07:56:37 +0700 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2020-05-08 07:23:31 +0200 |
| commit | 6789ded0a6ab797f0dcdfa6ad5d1cfa46e23abcd (patch) | |
| tree | 1de598fc92480c64835b60b6ddbb461c3cd2e864 /django | |
| parent | f97f71f59249f1fbeebe84d4fc858d70fc456f7d (diff) | |
Fixed #12990, Refs #27694 -- Added JSONField model field.
Thanks to Adam Johnson, Carlton Gibson, Mariusz Felisiak, and Raphael
Michel for mentoring this Google Summer of Code 2019 project and
everyone else who helped with the patch.
Special thanks to Mads Jensen, Nick Pope, and Simon Charette for
extensive reviews.
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Diffstat (limited to 'django')
27 files changed, 803 insertions, 263 deletions
diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 31dd52773b..12cba62701 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -1,5 +1,5 @@ -from django.contrib.postgres.fields import ArrayField, JSONField -from django.db.models import Aggregate, Value +from django.contrib.postgres.fields import ArrayField +from django.db.models import Aggregate, JSONField, Value from .mixins import OrderableAggMixin diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py index 97475de6f7..25cfa1a814 100644 --- a/django/contrib/postgres/apps.py +++ b/django/contrib/postgres/apps.py @@ -47,7 +47,6 @@ class PostgresConfig(AppConfig): for conn in connections.all(): if conn.vendor == 'postgresql': conn.introspection.data_types_reverse.update({ - 3802: 'django.contrib.postgres.fields.JSONField', 3904: 'django.contrib.postgres.fields.IntegerRangeField', 3906: 'django.contrib.postgres.fields.DecimalRangeField', 3910: 'django.contrib.postgres.fields.DateTimeRangeField', diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index c402dd19d8..7f76b29a13 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -1,185 +1,43 @@ -import json +import warnings -from psycopg2.extras import Json - -from django.contrib.postgres import forms, lookups -from django.core import exceptions -from django.db.models import ( - Field, TextField, Transform, lookups as builtin_lookups, +from django.db.models import JSONField as BuiltinJSONField +from django.db.models.fields.json import ( + KeyTextTransform as BuiltinKeyTextTransform, + KeyTransform as BuiltinKeyTransform, ) -from django.db.models.fields.mixins import CheckFieldDefaultMixin -from django.utils.translation import gettext_lazy as _ +from django.utils.deprecation import RemovedInDjango40Warning __all__ = ['JSONField'] -class JsonAdapter(Json): - """ - Customized psycopg2.extras.Json to allow for a custom encoder. - """ - def __init__(self, adapted, dumps=None, encoder=None): - self.encoder = encoder - super().__init__(adapted, dumps=dumps) - - def dumps(self, obj): - options = {'cls': self.encoder} if self.encoder else {} - return json.dumps(obj, **options) - - -class JSONField(CheckFieldDefaultMixin, Field): - empty_strings_allowed = False - description = _('A JSON object') - default_error_messages = { - 'invalid': _("Value must be valid JSON."), +class JSONField(BuiltinJSONField): + system_check_deprecated_details = { + 'msg': ( + 'django.contrib.postgres.fields.JSONField is deprecated. Support ' + 'for it (except in historical migrations) will be removed in ' + 'Django 4.0.' + ), + 'hint': 'Use django.db.models.JSONField instead.', + 'id': 'fields.W904', } - _default_hint = ('dict', '{}') - - def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs): - if encoder and not callable(encoder): - raise ValueError("The encoder parameter must be a callable object.") - self.encoder = encoder - super().__init__(verbose_name, name, **kwargs) - - def db_type(self, connection): - return 'jsonb' - - def deconstruct(self): - name, path, args, kwargs = super().deconstruct() - if self.encoder is not None: - kwargs['encoder'] = self.encoder - return name, path, args, kwargs - - def get_transform(self, name): - transform = super().get_transform(name) - if transform: - return transform - return KeyTransformFactory(name) - - def get_prep_value(self, value): - if value is not None: - return JsonAdapter(value, encoder=self.encoder) - return value - - def validate(self, value, model_instance): - super().validate(value, model_instance) - options = {'cls': self.encoder} if self.encoder else {} - try: - json.dumps(value, **options) - except TypeError: - raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, - ) - - def value_to_string(self, obj): - return self.value_from_object(obj) - def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.JSONField, - **kwargs, - }) - -JSONField.register_lookup(lookups.DataContains) -JSONField.register_lookup(lookups.ContainedBy) -JSONField.register_lookup(lookups.HasKey) -JSONField.register_lookup(lookups.HasKeys) -JSONField.register_lookup(lookups.HasAnyKeys) -JSONField.register_lookup(lookups.JSONExact) - - -class KeyTransform(Transform): - operator = '->' - nested_operator = '#>' - - def __init__(self, key_name, *args, **kwargs): +class KeyTransform(BuiltinKeyTransform): + def __init__(self, *args, **kwargs): + warnings.warn( + 'django.contrib.postgres.fields.jsonb.KeyTransform is deprecated ' + 'in favor of django.db.models.fields.json.KeyTransform.', + RemovedInDjango40Warning, stacklevel=2, + ) super().__init__(*args, **kwargs) - self.key_name = key_name - def as_sql(self, compiler, connection): - key_transforms = [self.key_name] - previous = self.lhs - while isinstance(previous, KeyTransform): - key_transforms.insert(0, previous.key_name) - previous = previous.lhs - lhs, params = compiler.compile(previous) - if len(key_transforms) > 1: - return '(%s %s %%s)' % (lhs, self.nested_operator), params + [key_transforms] - try: - lookup = int(self.key_name) - except ValueError: - lookup = self.key_name - return '(%s %s %%s)' % (lhs, self.operator), tuple(params) + (lookup,) - -class KeyTextTransform(KeyTransform): - operator = '->>' - nested_operator = '#>>' - output_field = TextField() - - -class KeyTransformTextLookupMixin: - """ - Mixin for combining with a lookup expecting a text lhs from a JSONField - key lookup. Make use of the ->> operator instead of casting key values to - text and performing the lookup on the resulting representation. - """ - def __init__(self, key_transform, *args, **kwargs): - assert isinstance(key_transform, KeyTransform) - key_text_transform = KeyTextTransform( - key_transform.key_name, *key_transform.source_expressions, **key_transform.extra +class KeyTextTransform(BuiltinKeyTextTransform): + def __init__(self, *args, **kwargs): + warnings.warn( + 'django.contrib.postgres.fields.jsonb.KeyTextTransform is ' + 'deprecated in favor of ' + 'django.db.models.fields.json.KeyTextTransform.', + RemovedInDjango40Warning, stacklevel=2, ) - super().__init__(key_text_transform, *args, **kwargs) - - -class KeyTransformIExact(KeyTransformTextLookupMixin, builtin_lookups.IExact): - pass - - -class KeyTransformIContains(KeyTransformTextLookupMixin, builtin_lookups.IContains): - pass - - -class KeyTransformStartsWith(KeyTransformTextLookupMixin, builtin_lookups.StartsWith): - pass - - -class KeyTransformIStartsWith(KeyTransformTextLookupMixin, builtin_lookups.IStartsWith): - pass - - -class KeyTransformEndsWith(KeyTransformTextLookupMixin, builtin_lookups.EndsWith): - pass - - -class KeyTransformIEndsWith(KeyTransformTextLookupMixin, builtin_lookups.IEndsWith): - pass - - -class KeyTransformRegex(KeyTransformTextLookupMixin, builtin_lookups.Regex): - pass - - -class KeyTransformIRegex(KeyTransformTextLookupMixin, builtin_lookups.IRegex): - pass - - -KeyTransform.register_lookup(KeyTransformIExact) -KeyTransform.register_lookup(KeyTransformIContains) -KeyTransform.register_lookup(KeyTransformStartsWith) -KeyTransform.register_lookup(KeyTransformIStartsWith) -KeyTransform.register_lookup(KeyTransformEndsWith) -KeyTransform.register_lookup(KeyTransformIEndsWith) -KeyTransform.register_lookup(KeyTransformRegex) -KeyTransform.register_lookup(KeyTransformIRegex) - - -class KeyTransformFactory: - - def __init__(self, key_name): - self.key_name = key_name - - def __call__(self, *args, **kwargs): - return KeyTransform(self.key_name, *args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/django/contrib/postgres/forms/jsonb.py b/django/contrib/postgres/forms/jsonb.py index 196d2b9096..ebc85efa6f 100644 --- a/django/contrib/postgres/forms/jsonb.py +++ b/django/contrib/postgres/forms/jsonb.py @@ -1,63 +1,16 @@ -import json +import warnings -from django import forms -from django.core.exceptions import ValidationError -from django.utils.translation import gettext_lazy as _ +from django.forms import JSONField as BuiltinJSONField +from django.utils.deprecation import RemovedInDjango40Warning __all__ = ['JSONField'] -class InvalidJSONInput(str): - pass - - -class JSONString(str): - pass - - -class JSONField(forms.CharField): - default_error_messages = { - 'invalid': _('ā%(value)sā value must be valid JSON.'), - } - widget = forms.Textarea - - def to_python(self, value): - if self.disabled: - return value - if value in self.empty_values: - return None - elif isinstance(value, (list, dict, int, float, JSONString)): - return value - try: - converted = json.loads(value) - except json.JSONDecodeError: - raise ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, - ) - if isinstance(converted, str): - return JSONString(converted) - else: - return converted - - def bound_data(self, data, initial): - if self.disabled: - return initial - try: - return json.loads(data) - except json.JSONDecodeError: - return InvalidJSONInput(data) - - def prepare_value(self, value): - if isinstance(value, InvalidJSONInput): - return value - return json.dumps(value) - - def has_changed(self, initial, data): - if super().has_changed(initial, data): - return True - # For purposes of seeing whether something has changed, True isn't the - # same as 1 and the order of keys doesn't matter. - data = self.to_python(data) - return json.dumps(initial, sort_keys=True) != json.dumps(data, sort_keys=True) +class JSONField(BuiltinJSONField): + def __init__(self, *args, **kwargs): + warnings.warn( + 'django.contrib.postgres.forms.JSONField is deprecated in favor ' + 'of django.forms.JSONField.', + RemovedInDjango40Warning, stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index 360e0c6a31..28d8590e1d 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -1,5 +1,5 @@ from django.db.models import Transform -from django.db.models.lookups import Exact, PostgresOperatorLookup +from django.db.models.lookups import PostgresOperatorLookup from .search import SearchVector, SearchVectorExact, SearchVectorField @@ -58,12 +58,3 @@ class SearchLookup(SearchVectorExact): class TrigramSimilar(PostgresOperatorLookup): lookup_name = 'trigram_similar' postgres_operator = '%%' - - -class JSONExact(Exact): - can_use_none_as_rhs = True - - def process_rhs(self, compiler, connection): - result = super().process_rhs(compiler, connection) - # Treat None lookup values as null. - return ("'null'", []) if result == ('%s', [None]) else result diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index a8f55f966c..33eeff171d 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -300,6 +300,15 @@ class BaseDatabaseFeatures: # Does the backend support boolean expressions in the SELECT clause? supports_boolean_expr_in_select_clause = True + # Does the backend support JSONField? + supports_json_field = True + # Can the backend introspect a JSONField? + can_introspect_json_field = True + # Does the backend support primitives in JSONField? + supports_primitives_in_json_field = True + # Is there a true datatype for JSON? + has_native_json_field = False + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 1dbcee4637..6d0f5c68b3 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -159,6 +159,13 @@ class BaseDatabaseOperations: """ return self.date_extract_sql(lookup_type, field_name) + def json_cast_text_sql(self, field_name): + """Return the SQL to cast a JSON value to text value.""" + raise NotImplementedError( + 'subclasses of BaseDatabaseOperations may require a ' + 'json_cast_text_sql() method' + ) + def deferrable_sql(self): """ Return the SQL to make a constraint "initially deferred" during a diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 44560ccdaf..8792f3c7c5 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -118,6 +118,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'BigIntegerField': 'bigint', 'IPAddressField': 'char(15)', 'GenericIPAddressField': 'char(39)', + 'JSONField': 'json', 'NullBooleanField': 'bool', 'OneToOneField': 'integer', 'PositiveBigIntegerField': 'bigint UNSIGNED', @@ -341,11 +342,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): @cached_property def data_type_check_constraints(self): if self.features.supports_column_check_constraints: - return { + check_constraints = { 'PositiveBigIntegerField': '`%(column)s` >= 0', 'PositiveIntegerField': '`%(column)s` >= 0', 'PositiveSmallIntegerField': '`%(column)s` >= 0', } + if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3): + # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as + # a check constraint. + check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)' + return check_constraints return {} @cached_property diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 8a2a64c5e4..faa84f7d7c 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -160,3 +160,15 @@ class DatabaseFeatures(BaseDatabaseFeatures): def supports_default_in_lead_lag(self): # To be added in https://jira.mariadb.org/browse/MDEV-12981. return not self.connection.mysql_is_mariadb + + @cached_property + def supports_json_field(self): + if self.connection.mysql_is_mariadb: + return self.connection.mysql_version >= (10, 2, 7) + return self.connection.mysql_version >= (5, 7, 8) + + @cached_property + def can_introspect_json_field(self): + if self.connection.mysql_is_mariadb: + return self.supports_json_field and self.can_introspect_check_constraints + return self.supports_json_field diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 50160ba590..1a104c7810 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -9,7 +9,7 @@ from django.db.backends.base.introspection import ( from django.db.models import Index from django.utils.datastructures import OrderedSet -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned')) +FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint')) InfoLine = namedtuple('InfoLine', 'col_name data_type max_len num_prec num_scale extra column_default is_unsigned') @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): FIELD_TYPE.DOUBLE: 'FloatField', FIELD_TYPE.FLOAT: 'FloatField', FIELD_TYPE.INT24: 'IntegerField', + FIELD_TYPE.JSON: 'JSONField', FIELD_TYPE.LONG: 'IntegerField', FIELD_TYPE.LONGLONG: 'BigIntegerField', FIELD_TYPE.SHORT: 'SmallIntegerField', @@ -53,6 +54,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): return 'PositiveIntegerField' elif field_type == 'SmallIntegerField': return 'PositiveSmallIntegerField' + # JSON data type is an alias for LONGTEXT in MariaDB, use check + # constraints clauses to introspect JSONField. + if description.has_json_constraint: + return 'JSONField' return field_type def get_table_list(self, cursor): @@ -66,6 +71,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a description of the table with the DB-API cursor.description interface." """ + json_constraints = {} + if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field: + # JSON data type is an alias for LONGTEXT in MariaDB, select + # JSON_VALID() constraints to introspect JSONField. + cursor.execute(""" + SELECT c.constraint_name AS column_name + FROM information_schema.check_constraints AS c + WHERE + c.table_name = %s AND + LOWER(c.check_clause) = 'json_valid(`' + LOWER(c.constraint_name) + '`)' AND + c.constraint_schema = DATABASE() + """, [table_name]) + json_constraints = {row[0] for row in cursor.fetchall()} # information_schema database gives more accurate results for some figures: # - varchar length returned by cursor.description is an internal length, # not visible length (#5725) @@ -100,6 +118,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): info.column_default, info.extra, info.is_unsigned, + line[0] in json_constraints, )) return fields diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index d01e3bef6b..bc04739f0d 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -368,3 +368,13 @@ class DatabaseOperations(BaseDatabaseOperations): def insert_statement(self, ignore_conflicts=False): return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts) + + def lookup_cast(self, lookup_type, internal_type=None): + lookup = '%s' + if internal_type == 'JSONField': + if self.connection.mysql_is_mariadb or lookup_type in ( + 'iexact', 'contains', 'icontains', 'startswith', 'istartswith', + 'endswith', 'iendswith', 'regex', 'iregex', + ): + lookup = 'JSON_UNQUOTE(%s)' + return lookup diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index e9ec2bac51..e104530228 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -123,6 +123,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'FilePathField': 'NVARCHAR2(%(max_length)s)', 'FloatField': 'DOUBLE PRECISION', 'IntegerField': 'NUMBER(11)', + 'JSONField': 'NCLOB', 'BigIntegerField': 'NUMBER(19)', 'IPAddressField': 'VARCHAR2(15)', 'GenericIPAddressField': 'VARCHAR2(39)', @@ -141,6 +142,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): } data_type_check_constraints = { 'BooleanField': '%(qn_column)s IN (0,1)', + 'JSONField': '%(qn_column)s IS JSON', 'NullBooleanField': '%(qn_column)s IN (0,1)', 'PositiveBigIntegerField': '%(qn_column)s >= 0', 'PositiveIntegerField': '%(qn_column)s >= 0', diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 3782874512..bae09559ce 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -60,3 +60,4 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_slicing_ordering_in_compound = True allows_multiple_constraints_on_same_fields = False supports_boolean_expr_in_select_clause = False + supports_primitives_in_json_field = False diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 2322ae0b5d..3fab497b2a 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -7,7 +7,7 @@ from django.db.backends.base.introspection import ( BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, ) -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield',)) +FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json')) class DatabaseIntrospection(BaseDatabaseIntrospection): @@ -45,6 +45,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): return 'IntegerField' elif scale == -127: return 'FloatField' + elif data_type == cx_Oracle.NCLOB and description.is_json: + return 'JSONField' return super().get_field_type(data_type, description) @@ -83,12 +85,23 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 - END as is_autofield + END as is_autofield, + CASE + WHEN EXISTS ( + SELECT 1 + FROM user_json_columns + WHERE + user_json_columns.table_name = user_tab_cols.table_name AND + user_json_columns.column_name = user_tab_cols.column_name + ) + THEN 1 + ELSE 0 + END as is_json FROM user_tab_cols WHERE table_name = UPPER(%s)""", [table_name]) field_map = { - column: (internal_size, default if default != 'NULL' else None, is_autofield) - for column, default, internal_size, is_autofield in cursor.fetchall() + column: (internal_size, default if default != 'NULL' else None, is_autofield, is_json) + for column, default, internal_size, is_autofield, is_json in cursor.fetchall() } self.cache_bust_counter += 1 cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format( @@ -97,11 +110,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): description = [] for desc in cursor.description: name = desc[0] - internal_size, default, is_autofield = field_map[name] + internal_size, default, is_autofield, is_json = field_map[name] name = name % {} # cx_Oracle, for some reason, doubles percent signs. description.append(FieldInfo( self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0, - desc[5] or 0, *desc[6:], default, is_autofield, + desc[5] or 0, *desc[6:], default, is_autofield, is_json, )) return description diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 6f4121425f..9dc28c84cd 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -176,7 +176,7 @@ END; def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type == 'TextField': + if internal_type in ['JSONField', 'TextField']: converters.append(self.convert_textfield_value) elif internal_type == 'BinaryField': converters.append(self.convert_binaryfield_value) @@ -269,7 +269,7 @@ END; return tuple(columns) def field_cast_sql(self, db_type, internal_type): - if db_type and db_type.endswith('LOB'): + if db_type and db_type.endswith('LOB') and internal_type != 'JSONField': return "DBMS_LOB.SUBSTR(%s)" else: return "%s" @@ -307,6 +307,8 @@ END; def lookup_cast(self, lookup_type, internal_type=None): if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): return "UPPER(%s)" + if internal_type == 'JSONField' and lookup_type == 'exact': + return 'DBMS_LOB.SUBSTR(%s)' return "%s" def max_in_list_size(self): diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 192316d7fb..ed911a91da 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -86,6 +86,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'BigIntegerField': 'bigint', 'IPAddressField': 'inet', 'GenericIPAddressField': 'inet', + 'JSONField': 'jsonb', 'NullBooleanField': 'boolean', 'OneToOneField': 'integer', 'PositiveBigIntegerField': 'bigint', diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 3b4199fa78..00a8009cf2 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -12,6 +12,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_real_datatype = True has_native_uuid_field = True has_native_duration_field = True + has_native_json_field = True can_defer_constraint_checks = True has_select_for_update = True has_select_for_update_nowait = True diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index beec8619cc..dee305cc06 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -26,6 +26,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): 1266: 'TimeField', 1700: 'DecimalField', 2950: 'UUIDField', + 3802: 'JSONField', } ignored_tables = [] diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 70880d4179..c67062a4a7 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -74,6 +74,9 @@ class DatabaseOperations(BaseDatabaseOperations): def time_trunc_sql(self, lookup_type, field_name): return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) + def json_cast_text_sql(self, field_name): + return '(%s)::text' % field_name + def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 26968475bf..31e8a55a43 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -5,6 +5,7 @@ import datetime import decimal import functools import hashlib +import json import math import operator import re @@ -101,6 +102,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'BigIntegerField': 'bigint', 'IPAddressField': 'char(15)', 'GenericIPAddressField': 'char(39)', + 'JSONField': 'text', 'NullBooleanField': 'bool', 'OneToOneField': 'integer', 'PositiveBigIntegerField': 'bigint unsigned', @@ -115,6 +117,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): } data_type_check_constraints = { 'PositiveBigIntegerField': '"%(column)s" >= 0', + 'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', 'PositiveIntegerField': '"%(column)s" >= 0', 'PositiveSmallIntegerField': '"%(column)s" >= 0', } @@ -233,6 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): create_deterministic_function('DEGREES', 1, none_guard(math.degrees)) create_deterministic_function('EXP', 1, none_guard(math.exp)) create_deterministic_function('FLOOR', 1, none_guard(math.floor)) + create_deterministic_function('JSON_CONTAINS', 2, _sqlite_json_contains) create_deterministic_function('LN', 1, none_guard(math.log)) create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x))) create_deterministic_function('LPAD', 3, _sqlite_lpad) @@ -598,3 +602,11 @@ def _sqlite_lpad(text, length, fill_text): @none_guard def _sqlite_rpad(text, length, fill_text): return (text + fill_text * length)[:length] + + +@none_guard +def _sqlite_json_contains(haystack, needle): + target, candidate = json.loads(haystack), json.loads(needle) + if isinstance(target, dict) and isinstance(candidate, dict): + return target.items() >= candidate.items() + return target == candidate diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 817b1067e3..1b6f99a58c 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -1,4 +1,9 @@ +import operator + +from django.db import transaction from django.db.backends.base.features import BaseDatabaseFeatures +from django.db.utils import OperationalError +from django.utils.functional import cached_property from .base import Database @@ -45,3 +50,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1) supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) order_by_nulls_first = True + + @cached_property + def supports_json_field(self): + try: + with self.connection.cursor() as cursor, transaction.atomic(): + cursor.execute('SELECT JSON(\'{"a": "b"}\')') + except OperationalError: + return False + return True + + can_introspect_json_field = property(operator.attrgetter('supports_json_field')) diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index a203c454df..992e925e10 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -9,7 +9,7 @@ from django.db.backends.base.introspection import ( from django.db.models import Index from django.utils.regex_helper import _lazy_re_compile -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk',)) +FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint')) field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') @@ -63,6 +63,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # No support for BigAutoField or SmallAutoField as SQLite treats # all integer primary keys as signed 64-bit integers. return 'AutoField' + if description.has_json_constraint: + return 'JSONField' return field_type def get_table_list(self, cursor): @@ -81,12 +83,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): interface. """ cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)) + table_info = cursor.fetchall() + json_columns = set() + if self.connection.features.can_introspect_json_field: + for line in table_info: + column = line[1] + json_constraint_sql = '%%json_valid("%s")%%' % column + has_json_constraint = cursor.execute(""" + SELECT sql + FROM sqlite_master + WHERE + type = 'table' AND + name = %s AND + sql LIKE %s + """, [table_name, json_constraint_sql]).fetchone() + if has_json_constraint: + json_columns.add(column) return [ FieldInfo( name, data_type, None, get_field_size(data_type), None, None, - not notnull, default, pk == 1, + not notnull, default, pk == 1, name in json_columns ) - for cid, name, data_type, notnull, default, pk in cursor.fetchall() + for cid, name, data_type, notnull, default, pk in table_info ] def get_sequences(self, cursor, table_name, table_fields=()): diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 7af6e60c51..a583af2aff 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -18,6 +18,7 @@ from django.db.models.expressions import ( from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all from django.db.models.fields.files import FileField, ImageField +from django.db.models.fields.json import JSONField from django.db.models.fields.proxy import OrderWrt from django.db.models.indexes import * # NOQA from django.db.models.indexes import __all__ as indexes_all @@ -43,9 +44,9 @@ __all__ += [ 'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When', 'Window', 'WindowFrame', - 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', - 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', - 'FilteredRelation', + 'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform', + 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', + 'DEFERRED', 'Model', 'FilteredRelation', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', ] diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 08c2a18d94..0fd69059ee 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -496,6 +496,8 @@ class Field(RegisterLookupMixin): path = path.replace("django.db.models.fields.related", "django.db.models") elif path.startswith("django.db.models.fields.files"): path = path.replace("django.db.models.fields.files", "django.db.models") + elif path.startswith('django.db.models.fields.json'): + path = path.replace('django.db.models.fields.json', 'django.db.models') elif path.startswith("django.db.models.fields.proxy"): path = path.replace("django.db.models.fields.proxy", "django.db.models") elif path.startswith("django.db.models.fields"): diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py new file mode 100644 index 0000000000..edc5441799 --- /dev/null +++ b/django/db/models/fields/json.py @@ -0,0 +1,525 @@ +import json + +from django import forms +from django.core import checks, exceptions +from django.db import NotSupportedError, connections, router +from django.db.models import lookups +from django.db.models.lookups import PostgresOperatorLookup, Transform +from django.utils.translation import gettext_lazy as _ + +from . import Field +from .mixins import CheckFieldDefaultMixin + +__all__ = ['JSONField'] + + +class JSONField(CheckFieldDefaultMixin, Field): + empty_strings_allowed = False + description = _('A JSON object') + default_error_messages = { + 'invalid': _('Value must be valid JSON.'), + } + _default_hint = ('dict', '{}') + + def __init__( + self, verbose_name=None, name=None, encoder=None, decoder=None, + **kwargs, + ): + if encoder and not callable(encoder): + raise ValueError('The encoder parameter must be a callable object.') + if decoder and not callable(decoder): + raise ValueError('The decoder parameter must be a callable object.') + self.encoder = encoder + self.decoder = decoder + super().__init__(verbose_name, name, **kwargs) + + def check(self, **kwargs): + errors = super().check(**kwargs) + databases = kwargs.get('databases') or [] + errors.extend(self._check_supported(databases)) + return errors + + def _check_supported(self, databases): + errors = [] + for db in databases: + if not router.allow_migrate_model(db, self.model): + continue + connection = connections[db] + if not ( + 'supports_json_field' in self.model._meta.required_db_features or + connection.features.supports_json_field + ): + errors.append( + checks.Error( + '%s does not support JSONFields.' + % connection.display_name, + obj=self.model, + id='fields.E180', + ) + ) + return errors + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if self.encoder is not None: + kwargs['encoder'] = self.encoder + if self.decoder is not None: + kwargs['decoder'] = self.decoder + return name, path, args, kwargs + + def from_db_value(self, value, expression, connection): + if value is None: + return value + if connection.features.has_native_json_field and self.decoder is None: + return value + try: + return json.loads(value, cls=self.decoder) + except json.JSONDecodeError: + return value + + def get_internal_type(self): + return 'JSONField' + + def get_prep_value(self, value): + if value is None: + return value + return json.dumps(value, cls=self.encoder) + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + return KeyTransformFactory(name) + + def select_format(self, compiler, sql, params): + if ( + compiler.connection.features.has_native_json_field and + self.decoder is not None + ): + return compiler.connection.ops.json_cast_text_sql(sql), params + return super().select_format(compiler, sql, params) + + def validate(self, value, model_instance): + super().validate(value, model_instance) + try: + json.dumps(value, cls=self.encoder) + except TypeError: + raise exceptions.ValidationError( + self.error_messages['invalid'], + code='invalid', + params={'value': value}, + ) + + def value_to_string(self, obj): + return self.value_from_object(obj) + + def formfield(self, **kwargs): + return super().formfield(**{ + 'form_class': forms.JSONField, + 'encoder': self.encoder, + 'decoder': self.decoder, + **kwargs, + }) + + +def compile_json_path(key_transforms, include_root=True): + path = ['$'] if include_root else [] + for key_transform in key_transforms: + try: + num = int(key_transform) + except ValueError: # non-integer + path.append('.') + path.append(json.dumps(key_transform)) + else: + path.append('[%s]' % num) + return ''.join(path) + + +class DataContains(PostgresOperatorLookup): + lookup_name = 'contains' + postgres_operator = '@>' + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + params = tuple(lhs_params) + tuple(rhs_params) + return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params + + def as_oracle(self, compiler, connection): + if isinstance(self.rhs, KeyTransform): + return HasKey(self.lhs, self.rhs).as_oracle(compiler, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + params = tuple(lhs_params) + sql = ( + "JSON_QUERY(%s, '$%s' WITH WRAPPER) = " + "JSON_QUERY('%s', '$.value' WITH WRAPPER)" + ) + rhs = json.loads(self.rhs) + if isinstance(rhs, dict): + if not rhs: + return "DBMS_LOB.SUBSTR(%s) LIKE '{%%%%}'" % lhs, params + return ' AND '.join([ + sql % ( + lhs, '.%s' % json.dumps(key), json.dumps({'value': value}), + ) for key, value in rhs.items() + ]), params + return sql % (lhs, '', json.dumps({'value': rhs})), params + + +class ContainedBy(PostgresOperatorLookup): + lookup_name = 'contained_by' + postgres_operator = '<@' + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + params = tuple(rhs_params) + tuple(lhs_params) + return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params + + def as_oracle(self, compiler, connection): + raise NotSupportedError('contained_by lookup is not supported on Oracle.') + + +class HasKeyLookup(PostgresOperatorLookup): + logical_operator = None + + def as_sql(self, compiler, connection, template=None): + # Process JSON path from the left-hand side. + if isinstance(self.lhs, KeyTransform): + lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection) + lhs_json_path = compile_json_path(lhs_key_transforms) + else: + lhs, lhs_params = self.process_lhs(compiler, connection) + lhs_json_path = '$' + sql = template % lhs + # Process JSON path from the right-hand side. + rhs = self.rhs + rhs_params = [] + if not isinstance(rhs, (list, tuple)): + rhs = [rhs] + for key in rhs: + if isinstance(key, KeyTransform): + *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection) + else: + rhs_key_transforms = [key] + rhs_params.append('%s%s' % ( + lhs_json_path, + compile_json_path(rhs_key_transforms, include_root=False), + )) + # Add condition for each key. + if self.logical_operator: + sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params)) + return sql, tuple(lhs_params) + tuple(rhs_params) + + def as_mysql(self, compiler, connection): + return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)") + + def as_oracle(self, compiler, connection): + sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')") + # Add paths directly into SQL because path expressions cannot be passed + # as bind variables on Oracle. + return sql % tuple(params), [] + + def as_postgresql(self, compiler, connection): + if isinstance(self.rhs, KeyTransform): + *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection) + for key in rhs_key_transforms[:-1]: + self.lhs = KeyTransform(key, self.lhs) + self.rhs = rhs_key_transforms[-1] + return super().as_postgresql(compiler, connection) + + def as_sqlite(self, compiler, connection): + return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL') + + +class HasKey(HasKeyLookup): + lookup_name = 'has_key' + postgres_operator = '?' + prepare_rhs = False + + +class HasKeys(HasKeyLookup): + lookup_name = 'has_keys' + postgres_operator = '?&' + logical_operator = ' AND ' + + def get_prep_lookup(self): + return [str(item) for item in self.rhs] + + +class HasAnyKeys(HasKeys): + lookup_name = 'has_any_keys' + postgres_operator = '?|' + logical_operator = ' OR ' + + +class JSONExact(lookups.Exact): + can_use_none_as_rhs = True + + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == 'sqlite': + rhs, rhs_params = super().process_rhs(compiler, connection) + if rhs == '%s' and rhs_params == [None]: + # Use JSON_TYPE instead of JSON_EXTRACT for NULLs. + lhs = "JSON_TYPE(%s, '$')" % lhs + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + # Treat None lookup values as null. + if rhs == '%s' and rhs_params == [None]: + rhs_params = ['null'] + if connection.vendor == 'mysql': + func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params) + rhs = rhs % tuple(func) + return rhs, rhs_params + + +JSONField.register_lookup(DataContains) +JSONField.register_lookup(ContainedBy) +JSONField.register_lookup(HasKey) +JSONField.register_lookup(HasKeys) +JSONField.register_lookup(HasAnyKeys) +JSONField.register_lookup(JSONExact) + + +class KeyTransform(Transform): + postgres_operator = '->' + postgres_nested_operator = '#>' + + def __init__(self, key_name, *args, **kwargs): + super().__init__(*args, **kwargs) + self.key_name = str(key_name) + + def preprocess_lhs(self, compiler, connection, lhs_only=False): + if not lhs_only: + key_transforms = [self.key_name] + previous = self.lhs + while isinstance(previous, KeyTransform): + if not lhs_only: + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + lhs, params = compiler.compile(previous) + if connection.vendor == 'oracle': + # Escape string-formatting. + key_transforms = [key.replace('%', '%%') for key in key_transforms] + return (lhs, params, key_transforms) if not lhs_only else (lhs, params) + + def as_mysql(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + json_path = compile_json_path(key_transforms) + return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) + + def as_oracle(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + json_path = compile_json_path(key_transforms) + return ( + "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" % + ((lhs, json_path) * 2) + ), tuple(params) * 2 + + def as_postgresql(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + if len(key_transforms) > 1: + return '(%s %s %%s)' % (lhs, self.postgres_nested_operator), params + [key_transforms] + try: + lookup = int(self.key_name) + except ValueError: + lookup = self.key_name + return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,) + + def as_sqlite(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + json_path = compile_json_path(key_transforms) + return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) + + +class KeyTextTransform(KeyTransform): + postgres_operator = '->>' + postgres_nested_operator = '#>>' + + +class KeyTransformTextLookupMixin: + """ + Mixin for combining with a lookup expecting a text lhs from a JSONField + key lookup. On PostgreSQL, make use of the ->> operator instead of casting + key values to text and performing the lookup on the resulting + representation. + """ + def __init__(self, key_transform, *args, **kwargs): + if not isinstance(key_transform, KeyTransform): + raise TypeError( + 'Transform should be an instance of KeyTransform in order to ' + 'use this lookup.' + ) + key_text_transform = KeyTextTransform( + key_transform.key_name, *key_transform.source_expressions, + **key_transform.extra, + ) + super().__init__(key_text_transform, *args, **kwargs) + + +class CaseInsensitiveMixin: + """ + Mixin to allow case-insensitive comparison of JSON values on MySQL. + MySQL handles strings used in JSON context using the utf8mb4_bin collation. + Because utf8mb4_bin is a binary collation, comparison of JSON values is + case-sensitive. + """ + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == 'mysql': + return 'LOWER(%s)' % lhs, lhs_params + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if connection.vendor == 'mysql': + return 'LOWER(%s)' % rhs, rhs_params + return rhs, rhs_params + + +class KeyTransformIsNull(lookups.IsNull): + # key__isnull=False is the same as has_key='key' + def as_oracle(self, compiler, connection): + if not self.rhs: + return HasKey(self.lhs.lhs, self.lhs.key_name).as_oracle(compiler, connection) + return super().as_sql(compiler, connection) + + def as_sqlite(self, compiler, connection): + if not self.rhs: + return HasKey(self.lhs.lhs, self.lhs.key_name).as_sqlite(compiler, connection) + return super().as_sql(compiler, connection) + + +class KeyTransformExact(JSONExact): + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == 'sqlite': + rhs, rhs_params = super().process_rhs(compiler, connection) + if rhs == '%s' and rhs_params == ['null']: + lhs, _ = self.lhs.preprocess_lhs(compiler, connection, lhs_only=True) + lhs = 'JSON_TYPE(%s, %%s)' % lhs + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + if isinstance(self.rhs, KeyTransform): + return super(lookups.Exact, self).process_rhs(compiler, connection) + rhs, rhs_params = super().process_rhs(compiler, connection) + if connection.vendor == 'oracle': + func = [] + for value in rhs_params: + value = json.loads(value) + function = 'JSON_QUERY' if isinstance(value, (list, dict)) else 'JSON_VALUE' + func.append("%s('%s', '$.value')" % ( + function, + json.dumps({'value': value}), + )) + rhs = rhs % tuple(func) + rhs_params = [] + elif connection.vendor == 'sqlite': + func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params] + rhs = rhs % tuple(func) + return rhs, rhs_params + + def as_oracle(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if rhs_params == ['null']: + # Field has key and it's NULL. + has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name) + has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection) + is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True) + is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection) + return ( + '%s AND %s' % (has_key_sql, is_null_sql), + tuple(has_key_params) + tuple(is_null_params), + ) + return super().as_sql(compiler, connection) + + +class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact): + pass + + +class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains): + pass + + +class KeyTransformContains(KeyTransformTextLookupMixin, lookups.Contains): + pass + + +class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith): + pass + + +class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith): + pass + + +class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith): + pass + + +class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith): + pass + + +class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex): + pass + + +class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex): + pass + + +class KeyTransformNumericLookupMixin: + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if not connection.features.has_native_json_field: + rhs_params = [json.loads(value) for value in rhs_params] + return rhs, rhs_params + + +class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan): + pass + + +class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual): + pass + + +class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan): + pass + + +class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual): + pass + + +KeyTransform.register_lookup(KeyTransformExact) +KeyTransform.register_lookup(KeyTransformIExact) +KeyTransform.register_lookup(KeyTransformIsNull) +KeyTransform.register_lookup(KeyTransformContains) +KeyTransform.register_lookup(KeyTransformIContains) +KeyTransform.register_lookup(KeyTransformStartsWith) +KeyTransform.register_lookup(KeyTransformIStartsWith) +KeyTransform.register_lookup(KeyTransformEndsWith) +KeyTransform.register_lookup(KeyTransformIEndsWith) +KeyTransform.register_lookup(KeyTransformRegex) +KeyTransform.register_lookup(KeyTransformIRegex) + +KeyTransform.register_lookup(KeyTransformLt) +KeyTransform.register_lookup(KeyTransformLte) +KeyTransform.register_lookup(KeyTransformGt) +KeyTransform.register_lookup(KeyTransformGte) + + +class KeyTransformFactory: + + def __init__(self, key_name): + self.key_name = key_name + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, *args, **kwargs) diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py index 24c3c4b4b8..6dc235bffb 100644 --- a/django/db/models/functions/comparison.py +++ b/django/db/models/functions/comparison.py @@ -29,8 +29,14 @@ class Cast(Func): return self.as_sql(compiler, connection, **extra_context) def as_mysql(self, compiler, connection, **extra_context): + template = None + output_type = self.output_field.get_internal_type() # MySQL doesn't support explicit cast to float. - template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None + if output_type == 'FloatField': + template = '(%(expressions)s + 0.0)' + # MariaDB doesn't support explicit cast to JSON. + elif output_type == 'JSONField' and connection.mysql_is_mariadb: + template = "JSON_EXTRACT(%(expressions)s, '$')" return self.as_sql(compiler, connection, template=template, **extra_context) def as_postgresql(self, compiler, connection, **extra_context): @@ -39,6 +45,13 @@ class Cast(Func): # expression. return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context) + def as_oracle(self, compiler, connection, **extra_context): + if self.output_field.get_internal_type() == 'JSONField': + # Oracle doesn't support explicit cast to JSON. + template = "JSON_QUERY(%(expressions)s, '$')" + return super().as_sql(compiler, connection, template=template, **extra_context) + return self.as_sql(compiler, connection, **extra_context) + class Coalesce(Func): """Return, from left to right, the first non-null expression.""" diff --git a/django/forms/fields.py b/django/forms/fields.py index c5374c7e9d..36dad72704 100644 --- a/django/forms/fields.py +++ b/django/forms/fields.py @@ -4,6 +4,7 @@ Field classes. import copy import datetime +import json import math import operator import os @@ -21,8 +22,8 @@ from django.forms.widgets import ( FILE_INPUT_CONTRADICTION, CheckboxInput, ClearableFileInput, DateInput, DateTimeInput, EmailInput, FileInput, HiddenInput, MultipleHiddenInput, NullBooleanSelect, NumberInput, Select, SelectMultiple, - SplitDateTimeWidget, SplitHiddenDateTimeWidget, TextInput, TimeInput, - URLInput, + SplitDateTimeWidget, SplitHiddenDateTimeWidget, Textarea, TextInput, + TimeInput, URLInput, ) from django.utils import formats from django.utils.dateparse import parse_datetime, parse_duration @@ -38,7 +39,8 @@ __all__ = ( 'BooleanField', 'NullBooleanField', 'ChoiceField', 'MultipleChoiceField', 'ComboField', 'MultiValueField', 'FloatField', 'DecimalField', 'SplitDateTimeField', 'GenericIPAddressField', 'FilePathField', - 'SlugField', 'TypedChoiceField', 'TypedMultipleChoiceField', 'UUIDField', + 'JSONField', 'SlugField', 'TypedChoiceField', 'TypedMultipleChoiceField', + 'UUIDField', ) @@ -1211,3 +1213,66 @@ class UUIDField(CharField): except ValueError: raise ValidationError(self.error_messages['invalid'], code='invalid') return value + + +class InvalidJSONInput(str): + pass + + +class JSONString(str): + pass + + +class JSONField(CharField): + default_error_messages = { + 'invalid': _('Enter a valid JSON.'), + } + widget = Textarea + + def __init__(self, encoder=None, decoder=None, **kwargs): + self.encoder = encoder + self.decoder = decoder + super().__init__(**kwargs) + + def to_python(self, value): + if self.disabled: + return value + if value in self.empty_values: + return None + elif isinstance(value, (list, dict, int, float, JSONString)): + return value + try: + converted = json.loads(value, cls=self.decoder) + except json.JSONDecodeError: + raise ValidationError( + self.error_messages['invalid'], + code='invalid', + params={'value': value}, + ) + if isinstance(converted, str): + return JSONString(converted) + else: + return converted + + def bound_data(self, data, initial): + if self.disabled: + return initial + try: + return json.loads(data, cls=self.decoder) + except json.JSONDecodeError: + return InvalidJSONInput(data) + + def prepare_value(self, value): + if isinstance(value, InvalidJSONInput): + return value + return json.dumps(value, cls=self.encoder) + + def has_changed(self, initial, data): + if super().has_changed(initial, data): + return True + # For purposes of seeing whether something has changed, True isn't the + # same as 1 and the order of keys doesn't matter. + return ( + json.dumps(initial, sort_keys=True, cls=self.encoder) != + json.dumps(self.to_python(data), sort_keys=True, cls=self.encoder) + ) |
