diff options
| author | Simon Charette <charette.s@gmail.com> | 2022-11-06 11:19:33 -0500 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-11-07 20:23:53 +0100 |
| commit | 76e37513e22f4d9a01c7f15eee36fe44388e6670 (patch) | |
| tree | 575decec7547c3c128857b0444f342452865a0f9 /django/db/models | |
| parent | 4b702c832cd550fe682ef37a69e93866815b9123 (diff) | |
Refs #33374 -- Adjusted full match condition handling.
Adjusting WhereNode.as_sql() to raise an exception when encoutering a
full match just like with empty matches ensures that all case are
explicitly handled.
Diffstat (limited to 'django/db/models')
| -rw-r--r-- | django/db/models/aggregates.py | 9 | ||||
| -rw-r--r-- | django/db/models/expressions.py | 17 | ||||
| -rw-r--r-- | django/db/models/fields/__init__.py | 9 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 37 | ||||
| -rw-r--r-- | django/db/models/sql/datastructures.py | 8 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 25 |
6 files changed, 58 insertions, 47 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index ab38a33bf0..7878fb6fb2 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -1,7 +1,7 @@ """ Classes to represent the definitions of aggregate functions. """ -from django.core.exceptions import FieldError +from django.core.exceptions import FieldError, FullResultSet from django.db.models.expressions import Case, Func, Star, When from django.db.models.fields import IntegerField from django.db.models.functions.comparison import Coalesce @@ -104,8 +104,11 @@ class Aggregate(Func): extra_context["distinct"] = "DISTINCT " if self.distinct else "" if self.filter: if connection.features.supports_aggregate_filter_clause: - filter_sql, filter_params = self.filter.as_sql(compiler, connection) - if filter_sql: + try: + filter_sql, filter_params = self.filter.as_sql(compiler, connection) + except FullResultSet: + pass + else: template = self.filter_template % extra_context.get( "template", self.template ) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 8b04e1f11b..86a3a92f07 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -7,7 +7,7 @@ from collections import defaultdict from decimal import Decimal from uuid import UUID -from django.core.exceptions import EmptyResultSet, FieldError +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError, connection from django.db.models import fields from django.db.models.constants import LOOKUP_SEP @@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression): if empty_result_set_value is NotImplemented: raise arg_sql, arg_params = compiler.compile(Value(empty_result_set_value)) + except FullResultSet: + arg_sql, arg_params = compiler.compile(Value(True)) sql_parts.append(arg_sql) params.extend(arg_params) data = {**self.extra, **extra_context} @@ -1367,14 +1369,6 @@ class When(Expression): template_params = extra_context sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) - # Filters that match everything are handled as empty strings in the - # WHERE clause, but in a CASE WHEN expression they must use a predicate - # that's always True. - if condition_sql == "": - if connection.features.supports_boolean_expr_in_select_clause: - condition_sql, condition_params = compiler.compile(Value(True)) - else: - condition_sql, condition_params = "1=1", () template_params["condition"] = condition_sql result_sql, result_params = compiler.compile(self.result) template_params["result"] = result_sql @@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression): template_params = {**self.extra, **extra_context} case_parts = [] sql_params = [] + default_sql, default_params = compiler.compile(self.default) for case in self.cases: try: case_sql, case_params = compiler.compile(case) except EmptyResultSet: continue + except FullResultSet: + default_sql, default_params = compiler.compile(case.result) + break case_parts.append(case_sql) sql_params.extend(case_params) - default_sql, default_params = compiler.compile(self.default) if not case_parts: return default_sql, default_params case_joiner = case_joiner or self.case_joiner diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 5069a491e8..2a98396cad 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -1103,15 +1103,6 @@ class BooleanField(Field): defaults = {"form_class": form_class, "required": False} return super().formfield(**{**defaults, **kwargs}) - def select_format(self, compiler, sql, params): - sql, params = super().select_format(compiler, sql, params) - # Filters that match everything are handled as empty strings in the - # WHERE clause, but in SELECT or GROUP BY list they must use a - # predicate that's always True. - if sql == "": - sql = "1" - return sql, params - class CharField(Field): description = _("String (up to %(max_length)s)") diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 97c7ba2013..170bde1d42 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -4,7 +4,7 @@ import re from functools import partial from itertools import chain -from django.core.exceptions import EmptyResultSet, FieldError +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value @@ -169,7 +169,7 @@ class SQLCompiler: expr = Ref(alias, expr) try: sql, params = self.compile(expr) - except EmptyResultSet: + except (EmptyResultSet, FullResultSet): continue sql, params = expr.select_format(self, sql, params) params_hash = make_hashable(params) @@ -287,6 +287,8 @@ class SQLCompiler: sql, params = "0", () else: sql, params = self.compile(Value(empty_result_set_value)) + except FullResultSet: + sql, params = self.compile(Value(True)) else: sql, params = col.select_format(self, sql, params) if alias is None and with_col_aliases: @@ -721,9 +723,16 @@ class SQLCompiler: raise # Use a predicate that's always False. where, w_params = "0 = 1", [] - having, h_params = ( - self.compile(self.having) if self.having is not None else ("", []) - ) + except FullResultSet: + where, w_params = "", [] + try: + having, h_params = ( + self.compile(self.having) + if self.having is not None + else ("", []) + ) + except FullResultSet: + having, h_params = "", [] result = ["SELECT"] params = [] @@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler): ) def _as_sql(self, query): - result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)] - where, params = self.compile(query.where) - if where: - result.append("WHERE %s" % where) - return " ".join(result), tuple(params) + delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table) + try: + where, params = self.compile(query.where) + except FullResultSet: + return delete, () + return f"{delete} WHERE {where}", tuple(params) def as_sql(self): """ @@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler): "UPDATE %s SET" % qn(table), ", ".join(values), ] - where, params = self.compile(self.query.where) - if where: + try: + where, params = self.compile(self.query.where) + except FullResultSet: + params = [] + else: result.append("WHERE %s" % where) return " ".join(result), tuple(update_params + params) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 1edf040e82..069eb1a301 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -2,6 +2,7 @@ Useful auxiliary data structures for query construction. Not useful outside the SQL domain. """ +from django.core.exceptions import FullResultSet from django.db.models.sql.constants import INNER, LOUTER @@ -100,8 +101,11 @@ class Join: join_conditions.append("(%s)" % extra_sql) params.extend(extra_params) if self.filtered_relation: - extra_sql, extra_params = compiler.compile(self.filtered_relation) - if extra_sql: + try: + extra_sql, extra_params = compiler.compile(self.filtered_relation) + except FullResultSet: + pass + else: join_conditions.append("(%s)" % extra_sql) params.extend(extra_params) if not join_conditions: diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 63fdf58d9d..1928ba91b8 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints. import operator from functools import reduce -from django.core.exceptions import EmptyResultSet +from django.core.exceptions import EmptyResultSet, FullResultSet from django.db.models.expressions import Case, When from django.db.models.lookups import Exact from django.utils import tree @@ -145,6 +145,8 @@ class WhereNode(tree.Node): sql, params = compiler.compile(child) except EmptyResultSet: empty_needed -= 1 + except FullResultSet: + full_needed -= 1 else: if sql: result.append(sql) @@ -158,24 +160,25 @@ class WhereNode(tree.Node): # counts. if empty_needed == 0: if self.negated: - return "", [] + raise FullResultSet else: raise EmptyResultSet if full_needed == 0: if self.negated: raise EmptyResultSet else: - return "", [] + raise FullResultSet conn = " %s " % self.connector sql_string = conn.join(result) - if sql_string: - if self.negated: - # Some backends (Oracle at least) need parentheses - # around the inner SQL in the negated case, even if the - # inner SQL contains just a single expression. - sql_string = "NOT (%s)" % sql_string - elif len(result) > 1 or self.resolved: - sql_string = "(%s)" % sql_string + if not sql_string: + raise FullResultSet + if self.negated: + # Some backends (Oracle at least) need parentheses around the inner + # SQL in the negated case, even if the inner SQL contains just a + # single expression. + sql_string = "NOT (%s)" % sql_string + elif len(result) > 1 or self.resolved: + sql_string = "(%s)" % sql_string return sql_string, result_params def get_group_by_cols(self): |
