diff options
Diffstat (limited to 'django/db/models/sql/where.py')
| -rw-r--r-- | django/db/models/sql/where.py | 95 |
1 files changed, 75 insertions, 20 deletions
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 42a4b054a5..e2af46a309 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -35,48 +35,81 @@ class WhereNode(tree.Node): resolved = False conditional = True - def split_having(self, negated=False): + def split_having_qualify(self, negated=False, must_group_by=False): """ - Return two possibly None nodes: one for those parts of self that - should be included in the WHERE clause and one for those parts of - self that must be included in the HAVING clause. + Return three possibly None nodes: one for those parts of self that + should be included in the WHERE clause, one for those parts of self + that must be included in the HAVING clause, and one for those parts + that refer to window functions. """ - if not self.contains_aggregate: - return self, None + if not self.contains_aggregate and not self.contains_over_clause: + return self, None, None in_negated = negated ^ self.negated - # If the effective connector is OR or XOR and this node contains an - # aggregate, then we need to push the whole branch to HAVING clause. - may_need_split = ( + # Whether or not children must be connected in the same filtering + # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic. + must_remain_connected = ( (in_negated and self.connector == AND) or (not in_negated and self.connector == OR) or self.connector == XOR ) - if may_need_split and self.contains_aggregate: - return None, self + if ( + must_remain_connected + and self.contains_aggregate + and not self.contains_over_clause + ): + # It's must cheaper to short-circuit and stash everything in the + # HAVING clause than split children if possible. + return None, self, None where_parts = [] having_parts = [] + qualify_parts = [] for c in self.children: - if hasattr(c, "split_having"): - where_part, having_part = c.split_having(in_negated) + if hasattr(c, "split_having_qualify"): + where_part, having_part, qualify_part = c.split_having_qualify( + in_negated, must_group_by + ) if where_part is not None: where_parts.append(where_part) if having_part is not None: having_parts.append(having_part) + if qualify_part is not None: + qualify_parts.append(qualify_part) + elif c.contains_over_clause: + qualify_parts.append(c) elif c.contains_aggregate: having_parts.append(c) else: where_parts.append(c) + if must_remain_connected and qualify_parts: + # Disjunctive heterogeneous predicates can be pushed down to + # qualify as long as no conditional aggregation is involved. + if not where_parts or (where_parts and not must_group_by): + return None, None, self + elif where_parts: + # In theory this should only be enforced when dealing with + # where_parts containing predicates against multi-valued + # relationships that could affect aggregation results but this + # is complex to infer properly. + raise NotImplementedError( + "Heterogeneous disjunctive predicates against window functions are " + "not implemented when performing conditional aggregation." + ) + where_node = ( + self.create(where_parts, self.connector, self.negated) + if where_parts + else None + ) having_node = ( self.create(having_parts, self.connector, self.negated) if having_parts else None ) - where_node = ( - self.create(where_parts, self.connector, self.negated) - if where_parts + qualify_node = ( + self.create(qualify_parts, self.connector, self.negated) + if qualify_parts else None ) - return where_node, having_node + return where_node, having_node, qualify_node def as_sql(self, compiler, connection): """ @@ -183,6 +216,14 @@ class WhereNode(tree.Node): clone.relabel_aliases(change_map) return clone + def replace_expressions(self, replacements): + if replacement := replacements.get(self): + return replacement + clone = self.create(connector=self.connector, negated=self.negated) + for child in self.children: + clone.children.append(child.replace_expressions(replacements)) + return clone + @classmethod def _contains_aggregate(cls, obj): if isinstance(obj, tree.Node): @@ -231,6 +272,10 @@ class WhereNode(tree.Node): return BooleanField() + @property + def _output_field_or_none(self): + return self.output_field + def select_format(self, compiler, sql, params): # Wrap filters with a CASE WHEN expression if a database backend # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP @@ -245,19 +290,28 @@ class WhereNode(tree.Node): def get_lookup(self, lookup): return self.output_field.get_lookup(lookup) + def leaves(self): + for child in self.children: + if isinstance(child, WhereNode): + yield from child.leaves() + else: + yield child + class NothingNode: """A node that matches nothing.""" contains_aggregate = False + contains_over_clause = False def as_sql(self, compiler=None, connection=None): raise EmptyResultSet class ExtraWhere: - # The contents are a black box - assume no aggregates are used. + # The contents are a black box - assume no aggregates or windows are used. contains_aggregate = False + contains_over_clause = False def __init__(self, sqls, params): self.sqls = sqls @@ -269,9 +323,10 @@ class ExtraWhere: class SubqueryConstraint: - # Even if aggregates would be used in a subquery, the outer query isn't - # interested about those. + # Even if aggregates or windows would be used in a subquery, + # the outer query isn't interested about those. contains_aggregate = False + contains_over_clause = False def __init__(self, alias, columns, targets, query_object): self.alias = alias |
