diff options
| author | Mads Jensen <mje@inducks.org> | 2017-09-18 15:42:29 +0200 |
|---|---|---|
| committer | Tim Graham <timograham@gmail.com> | 2017-09-18 09:42:29 -0400 |
| commit | d549b8805053d4b064bf492ba90e90db5d7e2a6b (patch) | |
| tree | 2beee237ae541804ba18367d81e82840745d6e47 /django/db/models | |
| parent | da1ba03f1dfb303df9bfb5c76d36216e45d05edc (diff) | |
Fixed #26608 -- Added support for window expressions (OVER clause).
Thanks Josh Smeaton, Mariusz Felisiak, Sergey Fedoseev, Simon Charette,
Adam Chainz/Johnson and Tim Graham for comments and reviews and Jamie
Cockburn for initial patch.
Diffstat (limited to 'django/db/models')
| -rw-r--r-- | django/db/models/__init__.py | 9 | ||||
| -rw-r--r-- | django/db/models/aggregates.py | 1 | ||||
| -rw-r--r-- | django/db/models/expressions.py | 198 | ||||
| -rw-r--r-- | django/db/models/functions/__init__.py | 7 | ||||
| -rw-r--r-- | django/db/models/functions/window.py | 118 | ||||
| -rw-r--r-- | django/db/models/lookups.py | 4 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 4 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 9 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 10 |
9 files changed, 355 insertions, 5 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 06806fcdc4..d29addd1f7 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -6,8 +6,8 @@ from django.db.models.deletion import ( CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, ) from django.db.models.expressions import ( - Case, Exists, Expression, ExpressionWrapper, F, Func, OuterRef, Subquery, - Value, When, + Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func, + OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame, ) from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all @@ -64,8 +64,9 @@ __all__ += [ 'ObjectDoesNotExist', 'signals', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', 'ProtectedError', - 'Case', 'Exists', 'Expression', 'ExpressionWrapper', 'F', 'Func', - 'OuterRef', 'Subquery', 'Value', 'When', + 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F', + 'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When', + 'Window', 'WindowFrame', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index c91200ef7f..4ed763cfe1 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -15,6 +15,7 @@ class Aggregate(Func): contains_aggregate = True name = None filter_template = '%s FILTER (WHERE %%(filter)s)' + window_compatible = True def __init__(self, *args, filter=None, **kwargs): self.filter = filter diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 1937ca16c7..49ca801924 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -3,6 +3,7 @@ import datetime 
from decimal import Decimal from django.core.exceptions import EmptyResultSet, FieldError +from django.db import connection from django.db.models import fields from django.db.models.query_utils import Q from django.utils.deconstruct import deconstructible @@ -140,6 +141,10 @@ class BaseExpression: # aggregate specific fields is_summary = False _output_field_resolved_to_none = False + # Can the expression be used in a WHERE clause? + filterable = True + # Can the expression can be used as a source expression in Window? + window_compatible = False def __init__(self, output_field=None): if output_field is not None: @@ -207,6 +212,13 @@ class BaseExpression: return False @cached_property + def contains_over_clause(self): + for expr in self.get_source_expressions(): + if expr and expr.contains_over_clause: + return True + return False + + @cached_property def contains_column_references(self): for expr in self.get_source_expressions(): if expr and expr.contains_column_references: @@ -232,6 +244,7 @@ class BaseExpression: c.is_summary = summarize c.set_source_expressions([ expr.resolve_expression(query, allow_joins, reuse, summarize) + if expr else None for expr in c.get_source_expressions() ]) return c @@ -482,6 +495,9 @@ class TemporalSubtraction(CombinedExpression): @deconstructible class F(Combinable): """An object capable of resolving references to existing query objects.""" + # Can the expression be used in a WHERE clause? + filterable = True + def __init__(self, name): """ Arguments: @@ -767,6 +783,23 @@ class Ref(Expression): return [self] +class ExpressionList(Func): + """ + An expression containing multiple expressions. Can be used to provide a + list of expressions as an argument to another expression, like an + ordering clause. + """ + template = '%(expressions)s' + + def __init__(self, *expressions, **extra): + if len(expressions) == 0: + raise ValueError('%s requires at least one expression.' 
% self.__class__.__name__) + super().__init__(*expressions, **extra) + + def __str__(self): + return self.arg_joiner.join(str(arg) for arg in self.source_expressions) + + class ExpressionWrapper(Expression): """ An expression that can wrap another expression so that it can provide @@ -1118,3 +1151,168 @@ class OrderBy(BaseExpression): def desc(self): self.descending = True + + +class Window(Expression): + template = '%(expression)s OVER (%(window)s)' + # Although the main expression may either be an aggregate or an + # expression with an aggregate function, the GROUP BY that will + # be introduced in the query as a result is not desired. + contains_aggregate = False + contains_over_clause = True + filterable = False + + def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None): + self.partition_by = partition_by + self.order_by = order_by + self.frame = frame + + if not getattr(expression, 'window_compatible', False): + raise ValueError( + "Expression '%s' isn't compatible with OVER clauses." % + expression.__class__.__name__ + ) + + if self.partition_by is not None: + if not isinstance(self.partition_by, (tuple, list)): + self.partition_by = (self.partition_by,) + self.partition_by = ExpressionList(*self.partition_by) + + if self.order_by is not None: + if isinstance(self.order_by, (list, tuple)): + self.order_by = ExpressionList(*self.order_by) + elif not isinstance(self.order_by, BaseExpression): + raise ValueError( + 'order_by must be either an Expression or a sequence of ' + 'expressions.' 
+ ) + super().__init__(output_field=output_field) + self.source_expression = self._parse_expressions(expression)[0] + + def _resolve_output_field(self): + return self.source_expression.output_field + + def get_source_expressions(self): + return [self.source_expression, self.partition_by, self.order_by, self.frame] + + def set_source_expressions(self, exprs): + self.source_expression, self.partition_by, self.order_by, self.frame = exprs + + def as_sql(self, compiler, connection, function=None, template=None): + connection.ops.check_expression_support(self) + expr_sql, params = compiler.compile(self.source_expression) + window_sql, window_params = [], [] + + if self.partition_by is not None: + sql_expr, sql_params = self.partition_by.as_sql( + compiler=compiler, connection=connection, + template='PARTITION BY %(expressions)s', + ) + window_sql.extend(sql_expr) + window_params.extend(sql_params) + + if self.order_by is not None: + window_sql.append(' ORDER BY ') + order_sql, order_params = compiler.compile(self.order_by) + window_sql.extend(''.join(order_sql)) + window_params.extend(order_params) + + if self.frame: + frame_sql, frame_params = compiler.compile(self.frame) + window_sql.extend(' ' + frame_sql) + window_params.extend(frame_params) + + params.extend(window_params) + template = template or self.template + + return template % { + 'expression': expr_sql, + 'window': ''.join(window_sql).strip() + }, params + + def __str__(self): + return '{} OVER ({}{}{})'.format( + str(self.source_expression), + 'PARTITION BY ' + str(self.partition_by) if self.partition_by else '', + 'ORDER BY ' + str(self.order_by) if self.order_by else '', + str(self.frame or ''), + ) + + def __repr__(self): + return '<%s: %s>' % (self.__class__.__name__, self) + + def get_group_by_cols(self): + return [] + + +class WindowFrame(Expression): + """ + Model the frame clause in window expressions. 
There are two types of frame + clauses which are subclasses, however, all processing and validation (by no + means intended to be complete) is done here. Thus, providing an end for a + frame is optional (the default is UNBOUNDED FOLLOWING, which is the last + row in the frame). + """ + template = '%(frame_type)s BETWEEN %(start)s AND %(end)s' + + def __init__(self, start=None, end=None): + self.start = start + self.end = end + + def set_source_expressions(self, exprs): + self.start, self.end = exprs + + def get_source_expressions(self): + return [Value(self.start), Value(self.end)] + + def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) + start, end = self.window_frame_start_end(connection, self.start.value, self.end.value) + return self.template % { + 'frame_type': self.frame_type, + 'start': start, + 'end': end, + }, [] + + def __repr__(self): + return '<%s: %s>' % (self.__class__.__name__, self) + + def get_group_by_cols(self): + return [] + + def __str__(self): + if self.start is not None and self.start < 0: + start = '%d %s' % (abs(self.start), connection.ops.PRECEDING) + elif self.start is not None and self.start == 0: + start = connection.ops.CURRENT_ROW + else: + start = connection.ops.UNBOUNDED_PRECEDING + + if self.end is not None and self.end > 0: + end = '%d %s' % (self.end, connection.ops.FOLLOWING) + elif self.end is not None and self.end == 0: + end = connection.ops.CURRENT_ROW + else: + end = connection.ops.UNBOUNDED_FOLLOWING + return self.template % { + 'frame_type': self.frame_type, + 'start': start, + 'end': end, + } + + def window_frame_start_end(self, connection, start, end): + raise NotImplementedError('Subclasses must implement window_frame_start_end().') + + +class RowRange(WindowFrame): + frame_type = 'ROWS' + + def window_frame_start_end(self, connection, start, end): + return connection.ops.window_frame_rows_start_end(start, end) + + +class ValueRange(WindowFrame): + frame_type = 'RANGE' + + def 
window_frame_start_end(self, connection, start, end): + return connection.ops.window_frame_range_start_end(start, end) diff --git a/django/db/models/functions/__init__.py b/django/db/models/functions/__init__.py index f2e59f38ff..aab74b232a 100644 --- a/django/db/models/functions/__init__.py +++ b/django/db/models/functions/__init__.py @@ -8,6 +8,10 @@ from .datetime import ( Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncYear, ) +from .window import ( + CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile, + PercentRank, Rank, RowNumber, +) __all__ = [ # base @@ -18,4 +22,7 @@ __all__ = [ 'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractWeekDay', 'ExtractYear', 'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth', 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncYear', + # window + 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', + 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', ] diff --git a/django/db/models/functions/window.py b/django/db/models/functions/window.py new file mode 100644 index 0000000000..3719dfca88 --- /dev/null +++ b/django/db/models/functions/window.py @@ -0,0 +1,118 @@ +from django.db.models import FloatField, Func, IntegerField + +__all__ = [ + 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', + 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', +] + + +class CumeDist(Func): + function = 'CUME_DIST' + name = 'CumeDist' + output_field = FloatField() + window_compatible = True + + +class DenseRank(Func): + function = 'DENSE_RANK' + name = 'DenseRank' + output_field = IntegerField() + window_compatible = True + + +class FirstValue(Func): + arity = 1 + function = 'FIRST_VALUE' + name = 'FirstValue' + window_compatible = True + + +class LagLeadFunction(Func): + window_compatible = True + + def __init__(self, expression, offset=1, default=None, **extra): + if expression is None: + raise 
ValueError( + '%s requires a non-null source expression.' % + self.__class__.__name__ + ) + if offset is None or offset <= 0: + raise ValueError( + '%s requires a positive integer for the offset.' % + self.__class__.__name__ + ) + args = (expression, offset) + if default is not None: + args += (default,) + super().__init__(*args, **extra) + + def _resolve_output_field(self): + sources = self.get_source_expressions() + return sources[0].output_field + + +class Lag(LagLeadFunction): + function = 'LAG' + name = 'Lag' + + +class LastValue(Func): + arity = 1 + function = 'LAST_VALUE' + name = 'LastValue' + window_compatible = True + + +class Lead(LagLeadFunction): + function = 'LEAD' + name = 'Lead' + + +class NthValue(Func): + function = 'NTH_VALUE' + name = 'NthValue' + window_compatible = True + + def __init__(self, expression, nth=1, **extra): + if expression is None: + raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__) + if nth is None or nth <= 0: + raise ValueError('%s requires a positive integer as for nth.' 
% self.__class__.__name__) + super().__init__(expression, nth, **extra) + + def _resolve_output_field(self): + sources = self.get_source_expressions() + return sources[0].output_field + + +class Ntile(Func): + function = 'NTILE' + name = 'Ntile' + output_field = IntegerField() + window_compatible = True + + def __init__(self, num_buckets=1, **extra): + if num_buckets <= 0: + raise ValueError('num_buckets must be greater than 0.') + super().__init__(num_buckets, **extra) + + +class PercentRank(Func): + function = 'PERCENT_RANK' + name = 'PercentRank' + output_field = FloatField() + window_compatible = True + + +class Rank(Func): + function = 'RANK' + name = 'Rank' + output_field = IntegerField() + window_compatible = True + + +class RowNumber(Func): + function = 'ROW_NUMBER' + name = 'RowNumber' + output_field = IntegerField() + window_compatible = True diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d82c29af66..f79f435515 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -115,6 +115,10 @@ class Lookup: def contains_aggregate(self): return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) + @cached_property + def contains_over_clause(self): + return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False) + @property def is_summary(self): return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 01c303eb7e..11ff51f60f 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1107,6 +1107,8 @@ class SQLInsertCompiler(SQLCompiler): ) if value.contains_aggregate: raise FieldError("Aggregate functions are not allowed in this query") + if value.contains_over_clause: + raise FieldError('Window expressions are not allowed in this query.') else: value = field.get_db_prep_save(value, connection=self.connection) return value @@ -1262,6 
+1264,8 @@ class SQLUpdateCompiler(SQLCompiler): val = val.resolve_expression(self.query, allow_joins=False, for_save=True) if val.contains_aggregate: raise FieldError("Aggregate functions are not allowed in this query") + if val.contains_over_clause: + raise FieldError('Window expressions are not allowed in this query.') elif hasattr(val, 'prepare_database_save'): if field.remote_field: val = field.get_db_prep_save( diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 017edea873..4cd22c7b8a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -13,7 +13,7 @@ from string import ascii_uppercase from django.core.exceptions import ( EmptyResultSet, FieldDoesNotExist, FieldError, ) -from django.db import DEFAULT_DB_ALIAS, connections +from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref @@ -1125,6 +1125,13 @@ class Query: if not arg: raise FieldError("Cannot parse keyword query %r" % arg) lookups, parts, reffed_expression = self.solve_lookup_type(arg) + + if not getattr(reffed_expression, 'filterable', True): + raise NotSupportedError( + reffed_expression.__class__.__name__ + ' is disallowed in ' + 'the filter clause.' 
+ ) + if not allow_joins and len(parts) > 1: raise FieldError("Joined field references are not permitted in this query") diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ed24b08bd0..0ca95f7018 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -167,6 +167,16 @@ class WhereNode(tree.Node): def contains_aggregate(self): return self._contains_aggregate(self) + @classmethod + def _contains_over_clause(cls, obj): + if isinstance(obj, tree.Node): + return any(cls._contains_over_clause(c) for c in obj.children) + return obj.contains_over_clause + + @cached_property + def contains_over_clause(self): + return self._contains_over_clause(self) + @property def is_summary(self): return any(child.is_summary for child in self.children) |
