summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/db/backends/base/features.py3
-rw-r--r--django/db/backends/base/operations.py25
-rw-r--r--django/db/backends/mysql/operations.py14
-rw-r--r--django/db/backends/oracle/operations.py24
-rw-r--r--django/db/backends/postgresql/features.py1
-rw-r--r--django/db/models/lookups.py15
-rw-r--r--tests/lookup/tests.py66
7 files changed, 84 insertions, 64 deletions
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index 6770c177c1..c9e78b5746 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -78,6 +78,9 @@ class BaseDatabaseFeatures:
# Does the backend ignore unnecessary ORDER BY clauses in subqueries?
ignores_unnecessary_order_by_in_subqueries = True
+ # Is there a true datatype for boolean?
+ has_native_boolean_field = False
+
# Is there a true datatype for uuid?
has_native_uuid_field = False
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index 29ef7d93d1..db45d6922e 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -8,8 +8,10 @@ import sqlparse
from django.conf import settings
from django.db import NotSupportedError, models, transaction
-from django.db.models.expressions import Col
+from django.db.models import Exists, ExpressionWrapper, Lookup
+from django.db.models.expressions import Col, RawSQL
from django.db.models.fields.composite import CompositePrimaryKey
+from django.db.models.sql.where import WhereNode
from django.utils import timezone
from django.utils.duration import duration_microseconds
from django.utils.encoding import force_str
@@ -716,10 +718,25 @@ class BaseDatabaseOperations:
def conditional_expression_supported_in_where_clause(self, expression):
"""
- Return True, if the conditional expression is supported in the WHERE
- clause.
+ Return True, if the conditional expression is directly supported in the
+ WHERE clause.
"""
- return True
+ # If the backend supports native boolean field it can accept any
+ # direct conditional expression usage.
+ if self.connection.features.has_native_boolean_field:
+ return True
+ # Most backends support direct EXISTS and lookups usage.
+ if isinstance(expression, (Exists, Lookup, WhereNode)):
+ return True
+ # Nested expression wrappers should be unwrapped.
+ if isinstance(expression, ExpressionWrapper) and expression.conditional:
+ return self.conditional_expression_supported_in_where_clause(
+ expression.expression
+ )
+ # Trust that direct usage of RawSQL can be used by itself.
+ if isinstance(expression, RawSQL) and expression.conditional:
+ return True
+ return False
def combine_expression(self, connector, sub_expressions):
"""
diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py
index 61fc9da3f4..7dee707820 100644
--- a/django/db/backends/mysql/operations.py
+++ b/django/db/backends/mysql/operations.py
@@ -3,7 +3,6 @@ import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
-from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
@@ -393,19 +392,6 @@ class DatabaseOperations(BaseDatabaseOperations):
lookup = "JSON_UNQUOTE(%s)"
return lookup
- def conditional_expression_supported_in_where_clause(self, expression):
- # MySQL ignores indexes with boolean fields unless they're compared
- # directly to a boolean value.
- if isinstance(expression, (Exists, Lookup)):
- return True
- if isinstance(expression, ExpressionWrapper) and expression.conditional:
- return self.conditional_expression_supported_in_where_clause(
- expression.expression
- )
- if getattr(expression, "conditional", False):
- return False
- return super().conditional_expression_supported_in_where_clause(expression)
-
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.UPDATE:
conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py
index 802f27a3c6..d946e37fac 100644
--- a/django/db/backends/oracle/operations.py
+++ b/django/db/backends/oracle/operations.py
@@ -6,14 +6,7 @@ from django.conf import settings
from django.db import NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
-from django.db.models import (
- AutoField,
- Exists,
- ExpressionWrapper,
- Lookup,
-)
-from django.db.models.expressions import RawSQL
-from django.db.models.sql.where import WhereNode
+from django.db.models import AutoField
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
@@ -705,21 +698,6 @@ END;
)
return super().subtract_temporals(internal_type, lhs, rhs)
- def conditional_expression_supported_in_where_clause(self, expression):
- """
- Oracle supports only EXISTS(...) or filters in the WHERE clause, others
- must be compared with True.
- """
- if isinstance(expression, (Exists, Lookup, WhereNode)):
- return True
- if isinstance(expression, ExpressionWrapper) and expression.conditional:
- return self.conditional_expression_supported_in_where_clause(
- expression.expression
- )
- if isinstance(expression, RawSQL) and expression.conditional:
- return True
- return False
-
def format_json_path_numeric_index(self, num):
if num < 0:
return "[last-%s]" % abs(num + 1) # Indexing is zero-based.
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index d3fae82a10..b4a2575475 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -13,6 +13,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_return_rows_from_bulk_insert = True
can_return_rows_from_update = True
has_real_datatype = True
+ has_native_boolean_field = True
has_native_uuid_field = True
has_native_duration_field = True
has_native_json_field = True
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index 4c08999fb6..eef7bc93a5 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -151,11 +151,16 @@ class Lookup(Expression):
# expression unless they're wrapped in a CASE WHEN.
wrapped = False
exprs = []
- for expr in (self.lhs, self.rhs):
- if connection.ops.conditional_expression_supported_in_where_clause(expr):
- expr = Case(When(expr, then=True), default=False)
- wrapped = True
- exprs.append(expr)
+ if getattr(self.lhs, "conditional", False) and getattr(
+ self.rhs, "conditional", False
+ ):
+ for expr in (self.lhs, self.rhs):
+ if connection.ops.conditional_expression_supported_in_where_clause(
+ expr
+ ):
+ expr = Case(When(expr, then=True), default=False)
+ wrapped = True
+ exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self
return lookup.as_sql(compiler, connection)
diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py
index b154541e78..9314fa05b0 100644
--- a/tests/lookup/tests.py
+++ b/tests/lookup/tests.py
@@ -2,7 +2,7 @@ import collections.abc
from datetime import datetime
from math import ceil
from operator import attrgetter
-from unittest import mock, skipUnless
+from unittest import mock
from django.core.exceptions import FieldError
from django.db import connection, models
@@ -19,6 +19,7 @@ from django.db.models import (
Value,
When,
)
+from django.db.models.expressions import RawSQL
from django.db.models.functions import Abs, Cast, Length, Substr
from django.db.models.lookups import (
Exact,
@@ -29,7 +30,7 @@ from django.db.models.lookups import (
LessThan,
LessThanOrEqual,
)
-from django.test import TestCase, skipUnlessDBFeature
+from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import ignore_warnings, isolate_apps, register_lookup
from django.utils.deprecation import RemovedInDjango70Warning
@@ -1591,10 +1592,10 @@ class LookupTests(TestCase):
with self.assertRaisesMessage(ValueError, msg):
list(Article.objects.filter(author=Author.objects.all()[1:]))
- @skipUnless(connection.vendor == "mysql", "MySQL-specific workaround.")
+ @skipIfDBFeature("has_native_boolean_field")
def test_exact_booleanfield(self):
- # MySQL ignores indexes with boolean fields unless they're compared
- # directly to a boolean value.
+ # Most databases without a native boolean type ignore indexes on them
+ # unless they're compared directly to a literal value.
product = Product.objects.create(name="Paper", qty_target=5000)
Stock.objects.create(product=product, short=False, qty_available=5100)
stock_1 = Stock.objects.create(product=product, short=True, qty_available=180)
@@ -1605,31 +1606,60 @@ class LookupTests(TestCase):
str(qs.query),
)
- @skipUnless(connection.vendor == "mysql", "MySQL-specific workaround.")
+ @skipIfDBFeature("has_native_boolean_field")
def test_exact_booleanfield_annotation(self):
- # MySQL ignores indexes with boolean fields unless they're compared
- # directly to a boolean value.
- qs = Author.objects.annotate(
- case=Case(
- When(alias="a1", then=True),
- default=False,
+ # Most databases without a native boolean type ignore indexes on them
+ # unless they're compared directly to a literal value.
+ product = Product.objects.create(name="Paper", qty_target=5000)
+ Stock.objects.create(product=product, short=False, qty_available=5100)
+ stock_1 = Stock.objects.create(product=product, short=True, qty_available=180)
+ qs = Stock.objects.annotate(
+ short_annotation=F("short"),
+ ).filter(short_annotation=True)
+ self.assertSequenceEqual(qs, [stock_1])
+ self.assertIn(" = True", str(qs.query))
+ # ExpressionWrapper should be unwrapped.
+ qs = Stock.objects.annotate(
+ short_wrapper=ExpressionWrapper(
+ F("short"),
output_field=BooleanField(),
)
- ).filter(case=True)
- self.assertSequenceEqual(qs, [self.au1])
+ ).filter(short_wrapper=True)
+ self.assertSequenceEqual(qs, [stock_1])
self.assertIn(" = True", str(qs.query))
-
+ # Q which resolve to WhereNode should not be compared to a boolean
+ # value as it's compatible by definition.
qs = Author.objects.annotate(
- wrapped=ExpressionWrapper(Q(alias="a1"), output_field=BooleanField()),
- ).filter(wrapped=True)
+ node=Q(alias="a1"),
+ ).filter(node=True)
self.assertSequenceEqual(qs, [self.au1])
- self.assertIn(" = True", str(qs.query))
+ self.assertNotIn(" = True", str(qs.query))
# EXISTS(...) shouldn't be compared to a boolean value.
qs = Author.objects.annotate(
exists=Exists(Author.objects.filter(alias="a1", pk=OuterRef("pk"))),
).filter(exists=True)
self.assertSequenceEqual(qs, [self.au1])
self.assertNotIn(" = True", str(qs.query))
+ # CASE shouldn't be compared to a boolean value.
+ qs = Author.objects.annotate(
+ case=Case(
+ When(alias="a1", then=True),
+ default=False,
+ output_field=BooleanField(),
+ )
+ ).filter(case=True)
+ self.assertSequenceEqual(qs, [self.au1])
+ self.assertEqual(str(qs.query).count(" = True"), 1)
+ # Conditional usage of RawSQL usage should not be compared to a boolean
+ # value.
+ queryset = Author.objects.all()
+ compiler = queryset.query.get_compiler(connection=connection)
+ sql, params = compiler.compile(Q(alias="a1").resolve_expression(queryset.query))
+ qs = Author.objects.alias(
+ raw=RawSQL(sql, params, BooleanField()),
+ ).filter(raw=True)
+ self.assertSequenceEqual(qs, [self.au1])
+ self.assertNotIn(" = True", str(qs.query))
def test_custom_field_none_rhs(self):
"""