summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMariusz Felisiak <felisiak.mariusz@gmail.com>2021-02-26 07:52:16 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2021-03-23 08:28:47 +0100
commit71ec102b01fcc85acae3819426a4e02ef423b0fa (patch)
tree4bea91872f6ecf85f4deb2512e7d11f01d9803b0
parentc4df8b86c7fac52d95eda3440edc397fc13c3e56 (diff)
Fixed #32483 -- Fixed QuerySet.values()/values_list() on JSONField key transforms with booleans on SQLite.
Thanks Matthew Cornell for the report.
-rw-r--r--django/db/backends/sqlite3/operations.py3
-rw-r--r--django/db/models/fields/json.py38
-rw-r--r--docs/ref/models/querysets.txt12
-rw-r--r--tests/model_fields/test_jsonfield.py10
4 files changed, 30 insertions, 33 deletions
diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
index faf96a1b97..c578979777 100644
--- a/django/db/backends/sqlite3/operations.py
+++ b/django/db/backends/sqlite3/operations.py
@@ -21,6 +21,9 @@ class DatabaseOperations(BaseDatabaseOperations):
'DateTimeField': 'TEXT',
}
explain_prefix = 'EXPLAIN QUERY PLAN'
+ # List of datatypes to that cannot be extracted with JSON_EXTRACT() on
+ # SQLite. Use JSON_TYPE() instead.
+ jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
def bulk_batch_size(self, fields, objs):
"""
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py
index bd12bba7ac..efb4e2f6ed 100644
--- a/django/db/models/fields/json.py
+++ b/django/db/models/fields/json.py
@@ -260,15 +260,6 @@ class CaseInsensitiveMixin:
class JSONExact(lookups.Exact):
can_use_none_as_rhs = True
- def process_lhs(self, compiler, connection):
- lhs, lhs_params = super().process_lhs(compiler, connection)
- if connection.vendor == 'sqlite':
- rhs, rhs_params = super().process_rhs(compiler, connection)
- if rhs == '%s' and rhs_params == [None]:
- # Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
- lhs = "JSON_TYPE(%s, '$')" % lhs
- return lhs, lhs_params
-
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
@@ -340,7 +331,13 @@ class KeyTransform(Transform):
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
- return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
+ datatype_values = ','.join([
+ repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
+ ])
+ return (
+ "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
+ "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
+ ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
class KeyTextTransform(KeyTransform):
@@ -408,7 +405,10 @@ class KeyTransformIn(lookups.In):
sql = sql % 'JSON_QUERY'
else:
sql = sql % 'JSON_VALUE'
- elif connection.vendor in {'sqlite', 'mysql'}:
+ elif connection.vendor == 'mysql' or (
+ connection.vendor == 'sqlite' and
+ params[0] not in connection.ops.jsonfield_datatype_values
+ ):
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
sql = 'JSON_UNQUOTE(%s)' % sql
@@ -416,15 +416,6 @@ class KeyTransformIn(lookups.In):
class KeyTransformExact(JSONExact):
- def process_lhs(self, compiler, connection):
- lhs, lhs_params = super().process_lhs(compiler, connection)
- if connection.vendor == 'sqlite':
- rhs, rhs_params = super().process_rhs(compiler, connection)
- if rhs == '%s' and rhs_params == ['null']:
- lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
- lhs = 'JSON_TYPE(%s, %%s)' % lhs
- return lhs, lhs_params
-
def process_rhs(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
@@ -440,7 +431,12 @@ class KeyTransformExact(JSONExact):
func.append(sql % 'JSON_VALUE')
rhs = rhs % tuple(func)
elif connection.vendor == 'sqlite':
- func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
+ func = []
+ for value in rhs_params:
+ if value in connection.ops.jsonfield_datatype_values:
+ func.append('%s')
+ else:
+ func.append("JSON_EXTRACT(%s, '$')")
rhs = rhs % tuple(func)
return rhs, rhs_params
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index 68f964faf5..add56ca4fd 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -695,12 +695,6 @@ You can also refer to fields on related models with reverse relations through
pronounced if you include multiple such fields in your ``values()`` query,
in which case all possible combinations will be returned.
-.. admonition:: Boolean values for ``JSONField`` on SQLite
-
- Due to the way the ``JSON_EXTRACT`` SQL function is implemented on SQLite,
- ``values()`` will return ``1`` and ``0`` instead of ``True`` and ``False``
- for :class:`~django.db.models.JSONField` key transforms.
-
``values_list()``
~~~~~~~~~~~~~~~~~
@@ -771,12 +765,6 @@ not having any author::
>>> Entry.objects.values_list('authors')
<QuerySet [('Noam Chomsky',), ('George Orwell',), (None,)]>
-.. admonition:: Boolean values for ``JSONField`` on SQLite
-
- Due to the way the ``JSON_EXTRACT`` SQL function is implemented on SQLite,
- ``values_list()`` will return ``1`` and ``0`` instead of ``True`` and
- ``False`` for :class:`~django.db.models.JSONField` key transforms.
-
``dates()``
~~~~~~~~~~~
diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py
index c903074e70..f7721aa6e8 100644
--- a/tests/model_fields/test_jsonfield.py
+++ b/tests/model_fields/test_jsonfield.py
@@ -808,6 +808,16 @@ class TestQuerying(TestCase):
with self.subTest(lookup=lookup):
self.assertEqual(qs.values_list(lookup, flat=True).get(), expected)
+ def test_key_values_boolean(self):
+ qs = NullableJSONModel.objects.filter(value__h=True, value__i=False)
+ tests = [
+ ('value__h', True),
+ ('value__i', False),
+ ]
+ for lookup, expected in tests:
+ with self.subTest(lookup=lookup):
+ self.assertIs(qs.values_list(lookup, flat=True).get(), expected)
+
@skipUnlessDBFeature('supports_json_field_contains')
def test_key_contains(self):
self.assertIs(NullableJSONModel.objects.filter(value__foo__contains='ar').exists(), False)