summary | refs | log | tree | commit | diff
path: root/django/db/models
diff options
context:
space:
mode:
author Mads Jensen <mje@inducks.org> 2017-09-18 15:42:29 +0200
committer Tim Graham <timograham@gmail.com> 2017-09-18 09:42:29 -0400
commit d549b8805053d4b064bf492ba90e90db5d7e2a6b (patch)
tree 2beee237ae541804ba18367d81e82840745d6e47 /django/db/models
parent da1ba03f1dfb303df9bfb5c76d36216e45d05edc (diff)
Fixed #26608 -- Added support for window expressions (OVER clause).
Thanks to Josh Smeaton, Mariusz Felisiak, Sergey Fedoseev, Simon Charette, Adam Chainz/Johnson, and Tim Graham for comments and reviews, and to Jamie Cockburn for the initial patch.
Diffstat (limited to 'django/db/models')
-rw-r--r-- django/db/models/__init__.py 9
-rw-r--r-- django/db/models/aggregates.py 1
-rw-r--r-- django/db/models/expressions.py 198
-rw-r--r-- django/db/models/functions/__init__.py 7
-rw-r--r-- django/db/models/functions/window.py 118
-rw-r--r-- django/db/models/lookups.py 4
-rw-r--r-- django/db/models/sql/compiler.py 4
-rw-r--r-- django/db/models/sql/query.py 9
-rw-r--r-- django/db/models/sql/where.py 10
9 files changed, 355 insertions, 5 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
index 06806fcdc4..d29addd1f7 100644
--- a/django/db/models/__init__.py
+++ b/django/db/models/__init__.py
@@ -6,8 +6,8 @@ from django.db.models.deletion import (
CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
)
from django.db.models.expressions import (
- Case, Exists, Expression, ExpressionWrapper, F, Func, OuterRef, Subquery,
- Value, When,
+ Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
+ OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame,
)
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
@@ -64,8 +64,9 @@ __all__ += [
'ObjectDoesNotExist', 'signals',
'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',
'ProtectedError',
- 'Case', 'Exists', 'Expression', 'ExpressionWrapper', 'F', 'Func',
- 'OuterRef', 'Subquery', 'Value', 'When',
+ 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
+ 'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When',
+ 'Window', 'WindowFrame',
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index c91200ef7f..4ed763cfe1 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -15,6 +15,7 @@ class Aggregate(Func):
contains_aggregate = True
name = None
filter_template = '%s FILTER (WHERE %%(filter)s)'
+ window_compatible = True
def __init__(self, *args, filter=None, **kwargs):
self.filter = filter
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 1937ca16c7..49ca801924 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -3,6 +3,7 @@ import datetime
from decimal import Decimal
from django.core.exceptions import EmptyResultSet, FieldError
+from django.db import connection
from django.db.models import fields
from django.db.models.query_utils import Q
from django.utils.deconstruct import deconstructible
@@ -140,6 +141,10 @@ class BaseExpression:
# aggregate specific fields
is_summary = False
_output_field_resolved_to_none = False
+ # Can the expression be used in a WHERE clause?
+ filterable = True
+ # Can the expression be used as a source expression in Window?
+ window_compatible = False
def __init__(self, output_field=None):
if output_field is not None:
@@ -207,6 +212,13 @@ class BaseExpression:
return False
@cached_property
+ def contains_over_clause(self):
+ for expr in self.get_source_expressions():
+ if expr and expr.contains_over_clause:
+ return True
+ return False
+
+ @cached_property
def contains_column_references(self):
for expr in self.get_source_expressions():
if expr and expr.contains_column_references:
@@ -232,6 +244,7 @@ class BaseExpression:
c.is_summary = summarize
c.set_source_expressions([
expr.resolve_expression(query, allow_joins, reuse, summarize)
+ if expr else None
for expr in c.get_source_expressions()
])
return c
@@ -482,6 +495,9 @@ class TemporalSubtraction(CombinedExpression):
@deconstructible
class F(Combinable):
"""An object capable of resolving references to existing query objects."""
+ # Can the expression be used in a WHERE clause?
+ filterable = True
+
def __init__(self, name):
"""
Arguments:
@@ -767,6 +783,23 @@ class Ref(Expression):
return [self]
+class ExpressionList(Func):
+ """
+ An expression containing multiple expressions. Can be used to provide a
+ list of expressions as an argument to another expression, like an
+ ordering clause.
+ """
+ template = '%(expressions)s'
+
+ def __init__(self, *expressions, **extra):
+ if len(expressions) == 0:
+ raise ValueError('%s requires at least one expression.' % self.__class__.__name__)
+ super().__init__(*expressions, **extra)
+
+ def __str__(self):
+ return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
+
+
class ExpressionWrapper(Expression):
"""
An expression that can wrap another expression so that it can provide
@@ -1118,3 +1151,168 @@ class OrderBy(BaseExpression):
def desc(self):
self.descending = True
+
+
+class Window(Expression):
+ template = '%(expression)s OVER (%(window)s)'
+ # Although the main expression may either be an aggregate or an
+ # expression with an aggregate function, the GROUP BY that will
+ # be introduced in the query as a result is not desired.
+ contains_aggregate = False
+ contains_over_clause = True
+ filterable = False
+
+ def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
+ self.partition_by = partition_by
+ self.order_by = order_by
+ self.frame = frame
+
+ if not getattr(expression, 'window_compatible', False):
+ raise ValueError(
+ "Expression '%s' isn't compatible with OVER clauses." %
+ expression.__class__.__name__
+ )
+
+ if self.partition_by is not None:
+ if not isinstance(self.partition_by, (tuple, list)):
+ self.partition_by = (self.partition_by,)
+ self.partition_by = ExpressionList(*self.partition_by)
+
+ if self.order_by is not None:
+ if isinstance(self.order_by, (list, tuple)):
+ self.order_by = ExpressionList(*self.order_by)
+ elif not isinstance(self.order_by, BaseExpression):
+ raise ValueError(
+ 'order_by must be either an Expression or a sequence of '
+ 'expressions.'
+ )
+ super().__init__(output_field=output_field)
+ self.source_expression = self._parse_expressions(expression)[0]
+
+ def _resolve_output_field(self):
+ return self.source_expression.output_field
+
+ def get_source_expressions(self):
+ return [self.source_expression, self.partition_by, self.order_by, self.frame]
+
+ def set_source_expressions(self, exprs):
+ self.source_expression, self.partition_by, self.order_by, self.frame = exprs
+
+ def as_sql(self, compiler, connection, function=None, template=None):
+ connection.ops.check_expression_support(self)
+ expr_sql, params = compiler.compile(self.source_expression)
+ window_sql, window_params = [], []
+
+ if self.partition_by is not None:
+ sql_expr, sql_params = self.partition_by.as_sql(
+ compiler=compiler, connection=connection,
+ template='PARTITION BY %(expressions)s',
+ )
+ window_sql.extend(sql_expr)
+ window_params.extend(sql_params)
+
+ if self.order_by is not None:
+ window_sql.append(' ORDER BY ')
+ order_sql, order_params = compiler.compile(self.order_by)
+ window_sql.extend(''.join(order_sql))
+ window_params.extend(order_params)
+
+ if self.frame:
+ frame_sql, frame_params = compiler.compile(self.frame)
+ window_sql.extend(' ' + frame_sql)
+ window_params.extend(frame_params)
+
+ params.extend(window_params)
+ template = template or self.template
+
+ return template % {
+ 'expression': expr_sql,
+ 'window': ''.join(window_sql).strip()
+ }, params
+
+ def __str__(self):
+ return '{} OVER ({}{}{})'.format(
+ str(self.source_expression),
+ 'PARTITION BY ' + str(self.partition_by) if self.partition_by else '',
+ 'ORDER BY ' + str(self.order_by) if self.order_by else '',
+ str(self.frame or ''),
+ )
+
+ def __repr__(self):
+ return '<%s: %s>' % (self.__class__.__name__, self)
+
+ def get_group_by_cols(self):
+ return []
+
+
+class WindowFrame(Expression):
+ """
+ Model the frame clause in window expressions. There are two types of frame
+ clauses which are subclasses, however, all processing and validation (by no
+ means intended to be complete) is done here. Thus, providing an end for a
+ frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
+ row in the frame).
+ """
+ template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'
+
+ def __init__(self, start=None, end=None):
+ self.start = start
+ self.end = end
+
+ def set_source_expressions(self, exprs):
+ self.start, self.end = exprs
+
+ def get_source_expressions(self):
+ return [Value(self.start), Value(self.end)]
+
+ def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
+ start, end = self.window_frame_start_end(connection, self.start.value, self.end.value)
+ return self.template % {
+ 'frame_type': self.frame_type,
+ 'start': start,
+ 'end': end,
+ }, []
+
+ def __repr__(self):
+ return '<%s: %s>' % (self.__class__.__name__, self)
+
+ def get_group_by_cols(self):
+ return []
+
+ def __str__(self):
+ if self.start is not None and self.start < 0:
+ start = '%d %s' % (abs(self.start), connection.ops.PRECEDING)
+ elif self.start is not None and self.start == 0:
+ start = connection.ops.CURRENT_ROW
+ else:
+ start = connection.ops.UNBOUNDED_PRECEDING
+
+ if self.end is not None and self.end > 0:
+ end = '%d %s' % (self.end, connection.ops.FOLLOWING)
+ elif self.end is not None and self.end == 0:
+ end = connection.ops.CURRENT_ROW
+ else:
+ end = connection.ops.UNBOUNDED_FOLLOWING
+ return self.template % {
+ 'frame_type': self.frame_type,
+ 'start': start,
+ 'end': end,
+ }
+
+ def window_frame_start_end(self, connection, start, end):
+ raise NotImplementedError('Subclasses must implement window_frame_start_end().')
+
+
+class RowRange(WindowFrame):
+ frame_type = 'ROWS'
+
+ def window_frame_start_end(self, connection, start, end):
+ return connection.ops.window_frame_rows_start_end(start, end)
+
+
+class ValueRange(WindowFrame):
+ frame_type = 'RANGE'
+
+ def window_frame_start_end(self, connection, start, end):
+ return connection.ops.window_frame_range_start_end(start, end)
diff --git a/django/db/models/functions/__init__.py b/django/db/models/functions/__init__.py
index f2e59f38ff..aab74b232a 100644
--- a/django/db/models/functions/__init__.py
+++ b/django/db/models/functions/__init__.py
@@ -8,6 +8,10 @@ from .datetime import (
Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth,
TruncQuarter, TruncSecond, TruncTime, TruncYear,
)
+from .window import (
+ CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
+ PercentRank, Rank, RowNumber,
+)
__all__ = [
# base
@@ -18,4 +22,7 @@ __all__ = [
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractWeekDay',
'ExtractYear', 'Trunc', 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute',
'TruncMonth', 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncYear',
+ # window
+ 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
+ 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]
diff --git a/django/db/models/functions/window.py b/django/db/models/functions/window.py
new file mode 100644
index 0000000000..3719dfca88
--- /dev/null
+++ b/django/db/models/functions/window.py
@@ -0,0 +1,118 @@
+from django.db.models import FloatField, Func, IntegerField
+
+__all__ = [
+ 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
+ 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
+]
+
+
+class CumeDist(Func):
+ function = 'CUME_DIST'
+ name = 'CumeDist'
+ output_field = FloatField()
+ window_compatible = True
+
+
+class DenseRank(Func):
+ function = 'DENSE_RANK'
+ name = 'DenseRank'
+ output_field = IntegerField()
+ window_compatible = True
+
+
+class FirstValue(Func):
+ arity = 1
+ function = 'FIRST_VALUE'
+ name = 'FirstValue'
+ window_compatible = True
+
+
+class LagLeadFunction(Func):
+ window_compatible = True
+
+ def __init__(self, expression, offset=1, default=None, **extra):
+ if expression is None:
+ raise ValueError(
+ '%s requires a non-null source expression.' %
+ self.__class__.__name__
+ )
+ if offset is None or offset <= 0:
+ raise ValueError(
+ '%s requires a positive integer for the offset.' %
+ self.__class__.__name__
+ )
+ args = (expression, offset)
+ if default is not None:
+ args += (default,)
+ super().__init__(*args, **extra)
+
+ def _resolve_output_field(self):
+ sources = self.get_source_expressions()
+ return sources[0].output_field
+
+
+class Lag(LagLeadFunction):
+ function = 'LAG'
+ name = 'Lag'
+
+
+class LastValue(Func):
+ arity = 1
+ function = 'LAST_VALUE'
+ name = 'LastValue'
+ window_compatible = True
+
+
+class Lead(LagLeadFunction):
+ function = 'LEAD'
+ name = 'Lead'
+
+
+class NthValue(Func):
+ function = 'NTH_VALUE'
+ name = 'NthValue'
+ window_compatible = True
+
+ def __init__(self, expression, nth=1, **extra):
+ if expression is None:
+ raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
+ if nth is None or nth <= 0:
+ raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
+ super().__init__(expression, nth, **extra)
+
+ def _resolve_output_field(self):
+ sources = self.get_source_expressions()
+ return sources[0].output_field
+
+
+class Ntile(Func):
+ function = 'NTILE'
+ name = 'Ntile'
+ output_field = IntegerField()
+ window_compatible = True
+
+ def __init__(self, num_buckets=1, **extra):
+ if num_buckets <= 0:
+ raise ValueError('num_buckets must be greater than 0.')
+ super().__init__(num_buckets, **extra)
+
+
+class PercentRank(Func):
+ function = 'PERCENT_RANK'
+ name = 'PercentRank'
+ output_field = FloatField()
+ window_compatible = True
+
+
+class Rank(Func):
+ function = 'RANK'
+ name = 'Rank'
+ output_field = IntegerField()
+ window_compatible = True
+
+
+class RowNumber(Func):
+ function = 'ROW_NUMBER'
+ name = 'RowNumber'
+ output_field = IntegerField()
+ window_compatible = True
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index d82c29af66..f79f435515 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -115,6 +115,10 @@ class Lookup:
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
+ @cached_property
+ def contains_over_clause(self):
+ return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
+
@property
def is_summary(self):
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 01c303eb7e..11ff51f60f 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1107,6 +1107,8 @@ class SQLInsertCompiler(SQLCompiler):
)
if value.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query")
+ if value.contains_over_clause:
+ raise FieldError('Window expressions are not allowed in this query.')
else:
value = field.get_db_prep_save(value, connection=self.connection)
return value
@@ -1262,6 +1264,8 @@ class SQLUpdateCompiler(SQLCompiler):
val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
if val.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query")
+ if val.contains_over_clause:
+ raise FieldError('Window expressions are not allowed in this query.')
elif hasattr(val, 'prepare_database_save'):
if field.remote_field:
val = field.get_db_prep_save(
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 017edea873..4cd22c7b8a 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -13,7 +13,7 @@ from string import ascii_uppercase
from django.core.exceptions import (
EmptyResultSet, FieldDoesNotExist, FieldError,
)
-from django.db import DEFAULT_DB_ALIAS, connections
+from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
@@ -1125,6 +1125,13 @@ class Query:
if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
lookups, parts, reffed_expression = self.solve_lookup_type(arg)
+
+ if not getattr(reffed_expression, 'filterable', True):
+ raise NotSupportedError(
+ reffed_expression.__class__.__name__ + ' is disallowed in '
+ 'the filter clause.'
+ )
+
if not allow_joins and len(parts) > 1:
raise FieldError("Joined field references are not permitted in this query")
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index ed24b08bd0..0ca95f7018 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -167,6 +167,16 @@ class WhereNode(tree.Node):
def contains_aggregate(self):
return self._contains_aggregate(self)
+ @classmethod
+ def _contains_over_clause(cls, obj):
+ if isinstance(obj, tree.Node):
+ return any(cls._contains_over_clause(c) for c in obj.children)
+ return obj.contains_over_clause
+
+ @cached_property
+ def contains_over_clause(self):
+ return self._contains_over_clause(self)
+
@property
def is_summary(self):
return any(child.is_summary for child in self.children)