summaryrefslogtreecommitdiff
path: root/django/db/models/sql/where.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/where.py')
-rw-r--r--django/db/models/sql/where.py95
1 files changed, 75 insertions, 20 deletions
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 42a4b054a5..e2af46a309 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -35,48 +35,81 @@ class WhereNode(tree.Node):
resolved = False
conditional = True
- def split_having(self, negated=False):
+ def split_having_qualify(self, negated=False, must_group_by=False):
"""
- Return two possibly None nodes: one for those parts of self that
- should be included in the WHERE clause and one for those parts of
- self that must be included in the HAVING clause.
+ Return three possibly None nodes: one for those parts of self that
+ should be included in the WHERE clause, one for those parts of self
+ that must be included in the HAVING clause, and one for those parts
+ that refer to window functions.
"""
- if not self.contains_aggregate:
- return self, None
+ if not self.contains_aggregate and not self.contains_over_clause:
+ return self, None, None
in_negated = negated ^ self.negated
- # If the effective connector is OR or XOR and this node contains an
- # aggregate, then we need to push the whole branch to HAVING clause.
- may_need_split = (
+ # Whether or not children must be connected in the same filtering
+ # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
+ must_remain_connected = (
(in_negated and self.connector == AND)
or (not in_negated and self.connector == OR)
or self.connector == XOR
)
- if may_need_split and self.contains_aggregate:
- return None, self
+ if (
+ must_remain_connected
+ and self.contains_aggregate
+ and not self.contains_over_clause
+ ):
+ # It's must cheaper to short-circuit and stash everything in the
+ # HAVING clause than split children if possible.
+ return None, self, None
where_parts = []
having_parts = []
+ qualify_parts = []
for c in self.children:
- if hasattr(c, "split_having"):
- where_part, having_part = c.split_having(in_negated)
+ if hasattr(c, "split_having_qualify"):
+ where_part, having_part, qualify_part = c.split_having_qualify(
+ in_negated, must_group_by
+ )
if where_part is not None:
where_parts.append(where_part)
if having_part is not None:
having_parts.append(having_part)
+ if qualify_part is not None:
+ qualify_parts.append(qualify_part)
+ elif c.contains_over_clause:
+ qualify_parts.append(c)
elif c.contains_aggregate:
having_parts.append(c)
else:
where_parts.append(c)
+ if must_remain_connected and qualify_parts:
+ # Disjunctive heterogeneous predicates can be pushed down to
+ # qualify as long as no conditional aggregation is involved.
+ if not where_parts or (where_parts and not must_group_by):
+ return None, None, self
+ elif where_parts:
+ # In theory this should only be enforced when dealing with
+ # where_parts containing predicates against multi-valued
+ # relationships that could affect aggregation results but this
+ # is complex to infer properly.
+ raise NotImplementedError(
+ "Heterogeneous disjunctive predicates against window functions are "
+ "not implemented when performing conditional aggregation."
+ )
+ where_node = (
+ self.create(where_parts, self.connector, self.negated)
+ if where_parts
+ else None
+ )
having_node = (
self.create(having_parts, self.connector, self.negated)
if having_parts
else None
)
- where_node = (
- self.create(where_parts, self.connector, self.negated)
- if where_parts
+ qualify_node = (
+ self.create(qualify_parts, self.connector, self.negated)
+ if qualify_parts
else None
)
- return where_node, having_node
+ return where_node, having_node, qualify_node
def as_sql(self, compiler, connection):
"""
@@ -183,6 +216,14 @@ class WhereNode(tree.Node):
clone.relabel_aliases(change_map)
return clone
+ def replace_expressions(self, replacements):
+ if replacement := replacements.get(self):
+ return replacement
+ clone = self.create(connector=self.connector, negated=self.negated)
+ for child in self.children:
+ clone.children.append(child.replace_expressions(replacements))
+ return clone
+
@classmethod
def _contains_aggregate(cls, obj):
if isinstance(obj, tree.Node):
@@ -231,6 +272,10 @@ class WhereNode(tree.Node):
return BooleanField()
+ @property
+ def _output_field_or_none(self):
+ return self.output_field
+
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
@@ -245,19 +290,28 @@ class WhereNode(tree.Node):
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
+ def leaves(self):
+ for child in self.children:
+ if isinstance(child, WhereNode):
+ yield from child.leaves()
+ else:
+ yield child
+
class NothingNode:
"""A node that matches nothing."""
contains_aggregate = False
+ contains_over_clause = False
def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet
class ExtraWhere:
- # The contents are a black box - assume no aggregates are used.
+ # The contents are a black box - assume no aggregates or windows are used.
contains_aggregate = False
+ contains_over_clause = False
def __init__(self, sqls, params):
self.sqls = sqls
@@ -269,9 +323,10 @@ class ExtraWhere:
class SubqueryConstraint:
- # Even if aggregates would be used in a subquery, the outer query isn't
- # interested about those.
+ # Even if aggregates or windows would be used in a subquery,
+ # the outer query isn't interested about those.
contains_aggregate = False
+ contains_over_clause = False
def __init__(self, alias, columns, targets, query_object):
self.alias = alias