summaryrefslogtreecommitdiff
path: root/django
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2022-11-06 11:19:33 -0500
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-11-07 20:23:53 +0100
commit76e37513e22f4d9a01c7f15eee36fe44388e6670 (patch)
tree575decec7547c3c128857b0444f342452865a0f9 /django
parent4b702c832cd550fe682ef37a69e93866815b9123 (diff)
Refs #33374 -- Adjusted full match condition handling.
Adjusting WhereNode.as_sql() to raise an exception when encoutering a full match just like with empty matches ensures that all case are explicitly handled.
Diffstat (limited to 'django')
-rw-r--r--django/core/exceptions.py6
-rw-r--r--django/db/backends/mysql/compiler.py14
-rw-r--r--django/db/models/aggregates.py9
-rw-r--r--django/db/models/expressions.py17
-rw-r--r--django/db/models/fields/__init__.py9
-rw-r--r--django/db/models/sql/compiler.py37
-rw-r--r--django/db/models/sql/datastructures.py8
-rw-r--r--django/db/models/sql/where.py25
8 files changed, 73 insertions, 52 deletions
diff --git a/django/core/exceptions.py b/django/core/exceptions.py
index 7be4e16bc5..646644f3e0 100644
--- a/django/core/exceptions.py
+++ b/django/core/exceptions.py
@@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
pass
+class FullResultSet(Exception):
+ """A database query predicate is matches everything."""
+
+ pass
+
+
class SynchronousOnlyOperation(Exception):
"""The user tried to call a sync-only function from an async context."""
diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py
index bd2715fb43..2ec6bea2f1 100644
--- a/django/db/backends/mysql/compiler.py
+++ b/django/db/backends/mysql/compiler.py
@@ -1,4 +1,4 @@
-from django.core.exceptions import FieldError
+from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Col
from django.db.models.sql import compiler
@@ -40,12 +40,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
"DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias())
]
- from_sql, from_params = self.get_from_clause()
+ from_sql, params = self.get_from_clause()
result.extend(from_sql)
- where_sql, where_params = self.compile(where)
- if where_sql:
+ try:
+ where_sql, where_params = self.compile(where)
+ except FullResultSet:
+ pass
+ else:
result.append("WHERE %s" % where_sql)
- return " ".join(result), tuple(from_params) + tuple(where_params)
+ params.extend(where_params)
+ return " ".join(result), tuple(params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index ab38a33bf0..7878fb6fb2 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -1,7 +1,7 @@
"""
Classes to represent the definitions of aggregate functions.
"""
-from django.core.exceptions import FieldError
+from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
@@ -104,8 +104,11 @@ class Aggregate(Func):
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter:
if connection.features.supports_aggregate_filter_clause:
- filter_sql, filter_params = self.filter.as_sql(compiler, connection)
- if filter_sql:
+ try:
+ filter_sql, filter_params = self.filter.as_sql(compiler, connection)
+ except FullResultSet:
+ pass
+ else:
template = self.filter_template % extra_context.get(
"template", self.template
)
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 8b04e1f11b..86a3a92f07 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -7,7 +7,7 @@ from collections import defaultdict
from decimal import Decimal
from uuid import UUID
-from django.core.exceptions import EmptyResultSet, FieldError
+from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
@@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression):
if empty_result_set_value is NotImplemented:
raise
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
+ except FullResultSet:
+ arg_sql, arg_params = compiler.compile(Value(True))
sql_parts.append(arg_sql)
params.extend(arg_params)
data = {**self.extra, **extra_context}
@@ -1367,14 +1369,6 @@ class When(Expression):
template_params = extra_context
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
- # Filters that match everything are handled as empty strings in the
- # WHERE clause, but in a CASE WHEN expression they must use a predicate
- # that's always True.
- if condition_sql == "":
- if connection.features.supports_boolean_expr_in_select_clause:
- condition_sql, condition_params = compiler.compile(Value(True))
- else:
- condition_sql, condition_params = "1=1", ()
template_params["condition"] = condition_sql
result_sql, result_params = compiler.compile(self.result)
template_params["result"] = result_sql
@@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression):
template_params = {**self.extra, **extra_context}
case_parts = []
sql_params = []
+ default_sql, default_params = compiler.compile(self.default)
for case in self.cases:
try:
case_sql, case_params = compiler.compile(case)
except EmptyResultSet:
continue
+ except FullResultSet:
+ default_sql, default_params = compiler.compile(case.result)
+ break
case_parts.append(case_sql)
sql_params.extend(case_params)
- default_sql, default_params = compiler.compile(self.default)
if not case_parts:
return default_sql, default_params
case_joiner = case_joiner or self.case_joiner
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 5069a491e8..2a98396cad 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -1103,15 +1103,6 @@ class BooleanField(Field):
defaults = {"form_class": form_class, "required": False}
return super().formfield(**{**defaults, **kwargs})
- def select_format(self, compiler, sql, params):
- sql, params = super().select_format(compiler, sql, params)
- # Filters that match everything are handled as empty strings in the
- # WHERE clause, but in SELECT or GROUP BY list they must use a
- # predicate that's always True.
- if sql == "":
- sql = "1"
- return sql, params
-
class CharField(Field):
description = _("String (up to %(max_length)s)")
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 97c7ba2013..170bde1d42 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -4,7 +4,7 @@ import re
from functools import partial
from itertools import chain
-from django.core.exceptions import EmptyResultSet, FieldError
+from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
@@ -169,7 +169,7 @@ class SQLCompiler:
expr = Ref(alias, expr)
try:
sql, params = self.compile(expr)
- except EmptyResultSet:
+ except (EmptyResultSet, FullResultSet):
continue
sql, params = expr.select_format(self, sql, params)
params_hash = make_hashable(params)
@@ -287,6 +287,8 @@ class SQLCompiler:
sql, params = "0", ()
else:
sql, params = self.compile(Value(empty_result_set_value))
+ except FullResultSet:
+ sql, params = self.compile(Value(True))
else:
sql, params = col.select_format(self, sql, params)
if alias is None and with_col_aliases:
@@ -721,9 +723,16 @@ class SQLCompiler:
raise
# Use a predicate that's always False.
where, w_params = "0 = 1", []
- having, h_params = (
- self.compile(self.having) if self.having is not None else ("", [])
- )
+ except FullResultSet:
+ where, w_params = "", []
+ try:
+ having, h_params = (
+ self.compile(self.having)
+ if self.having is not None
+ else ("", [])
+ )
+ except FullResultSet:
+ having, h_params = "", []
result = ["SELECT"]
params = []
@@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler):
)
def _as_sql(self, query):
- result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
- where, params = self.compile(query.where)
- if where:
- result.append("WHERE %s" % where)
- return " ".join(result), tuple(params)
+ delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
+ try:
+ where, params = self.compile(query.where)
+ except FullResultSet:
+ return delete, ()
+ return f"{delete} WHERE {where}", tuple(params)
def as_sql(self):
"""
@@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler):
"UPDATE %s SET" % qn(table),
", ".join(values),
]
- where, params = self.compile(self.query.where)
- if where:
+ try:
+ where, params = self.compile(self.query.where)
+ except FullResultSet:
+ params = []
+ else:
result.append("WHERE %s" % where)
return " ".join(result), tuple(update_params + params)
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index 1edf040e82..069eb1a301 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -2,6 +2,7 @@
Useful auxiliary data structures for query construction. Not useful outside
the SQL domain.
"""
+from django.core.exceptions import FullResultSet
from django.db.models.sql.constants import INNER, LOUTER
@@ -100,8 +101,11 @@ class Join:
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
- extra_sql, extra_params = compiler.compile(self.filtered_relation)
- if extra_sql:
+ try:
+ extra_sql, extra_params = compiler.compile(self.filtered_relation)
+ except FullResultSet:
+ pass
+ else:
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if not join_conditions:
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 63fdf58d9d..1928ba91b8 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
import operator
from functools import reduce
-from django.core.exceptions import EmptyResultSet
+from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, When
from django.db.models.lookups import Exact
from django.utils import tree
@@ -145,6 +145,8 @@ class WhereNode(tree.Node):
sql, params = compiler.compile(child)
except EmptyResultSet:
empty_needed -= 1
+ except FullResultSet:
+ full_needed -= 1
else:
if sql:
result.append(sql)
@@ -158,24 +160,25 @@ class WhereNode(tree.Node):
# counts.
if empty_needed == 0:
if self.negated:
- return "", []
+ raise FullResultSet
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
- return "", []
+ raise FullResultSet
conn = " %s " % self.connector
sql_string = conn.join(result)
- if sql_string:
- if self.negated:
- # Some backends (Oracle at least) need parentheses
- # around the inner SQL in the negated case, even if the
- # inner SQL contains just a single expression.
- sql_string = "NOT (%s)" % sql_string
- elif len(result) > 1 or self.resolved:
- sql_string = "(%s)" % sql_string
+ if not sql_string:
+ raise FullResultSet
+ if self.negated:
+ # Some backends (Oracle at least) need parentheses around the inner
+ # SQL in the negated case, even if the inner SQL contains just a
+ # single expression.
+ sql_string = "NOT (%s)" % sql_string
+ elif len(result) > 1 or self.resolved:
+ sql_string = "(%s)" % sql_string
return sql_string, result_params
def get_group_by_cols(self):