summaryrefslogtreecommitdiff
path: root/django/db/models/sql
diff options
context:
space:
mode:
authorJosh Smeaton <josh.smeaton@gmail.com>2013-12-26 00:13:18 +1100
committerMarc Tamlyn <marc.tamlyn@gmail.com>2014-11-15 14:00:43 +0000
commitf59fd15c4928caf3dfcbd50f6ab47be409a43b01 (patch)
treefe4a04d98359e1ffcbfe991303eb97d9a8e16afc /django/db/models/sql
parent39e3ef88c237e3f4cedc89cd36494a6d3f490812 (diff)
Fixed #14030 -- Allowed annotations to accept all expressions
Diffstat (limited to 'django/db/models/sql')
-rw-r--r--django/db/models/sql/aggregates.py8
-rw-r--r--django/db/models/sql/compiler.py94
-rw-r--r--django/db/models/sql/datastructures.py66
-rw-r--r--django/db/models/sql/expressions.py119
-rw-r--r--django/db/models/sql/query.py384
-rw-r--r--django/db/models/sql/subqueries.py4
-rw-r--r--django/db/models/sql/where.py9
7 files changed, 271 insertions, 413 deletions
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
index 8274d43621..6ebf5fb966 100644
--- a/django/db/models/sql/aggregates.py
+++ b/django/db/models/sql/aggregates.py
@@ -2,15 +2,23 @@
Classes to represent the default SQL aggregate functions
"""
import copy
+import warnings
from django.db.models.fields import IntegerField, FloatField
from django.db.models.lookups import RegisterLookupMixin
+from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.functional import cached_property
__all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
+warnings.warn(
+ "django.db.models.sql.aggregates is deprecated. Use "
+ "django.db.models.aggregates instead.",
+ RemovedInDjango20Warning, stacklevel=2)
+
+
class Aggregate(RegisterLookupMixin):
"""
Default SQL Aggregate.
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 33fe343b5b..5f425a7543 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -4,12 +4,10 @@ from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP
-from django.db.models.expressions import ExpressionNode
from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
-from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError
@@ -248,8 +246,8 @@ class SQLCompiler(object):
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
- for alias, aggregate in self.query.aggregate_select.items():
- agg_sql, agg_params = self.compile(aggregate)
+ for alias, annotation in self.query.annotation_select.items():
+ agg_sql, agg_params = self.compile(annotation)
if alias is None:
result.append(agg_sql)
else:
@@ -409,7 +407,7 @@ class SQLCompiler(object):
group_by.append((str(field), []))
continue
col, order = get_order_dir(field, asc)
- if col in self.query.aggregate_select:
+ if col in self.query.annotation_select:
result.append('%s %s' % (qn(col), order))
continue
if '.' in field:
@@ -718,25 +716,17 @@ class SQLCompiler(object):
"""
fields = None
converters = None
- has_aggregate_select = bool(self.query.aggregate_select)
+ has_annotation_select = bool(self.query.annotation_select)
for rows in self.execute_sql(MULTI):
for row in rows:
- if has_aggregate_select:
- loaded_fields = (
- self.query.get_loaded_field_names().get(self.query.model, set()) or
- self.query.select
- )
- aggregate_start = len(self.query.extra_select) + len(loaded_fields)
- aggregate_end = aggregate_start + len(self.query.aggregate_select)
if fields is None:
# We only set this up here because
# related_select_cols isn't populated until
# execute_sql() has been called.
- # We also include types of fields of related models that
- # will be included via select_related() for the benefit
- # of MySQL/MySQLdb when boolean fields are involved
- # (#15040).
+ # If the field was deferred, exclude it from being passed
+ # into `get_converters` because it wasn't selected.
+ only_load = self.deferred_to_columns()
# This code duplicates the logic for the order of fields
# found in get_columns(). It would be nice to clean this up.
@@ -746,30 +736,45 @@ class SQLCompiler(object):
fields = self.query.get_meta().concrete_fields
else:
fields = []
- fields = fields + [f.field for f in self.query.related_select_cols]
- # If the field was deferred, exclude it from being passed
- # into `get_converters` because it wasn't selected.
- only_load = self.deferred_to_columns()
if only_load:
- fields = [f for f in fields if f.model._meta.db_table not in only_load or
- f.column in only_load[f.model._meta.db_table]]
- if has_aggregate_select:
- # pad None in to fields for aggregates
- fields = fields[:aggregate_start] + [
- None for x in range(0, aggregate_end - aggregate_start)
- ] + fields[aggregate_start:]
+ # strip deferred fields
+ fields = [
+ f for f in fields if
+ f.model._meta.db_table not in only_load or
+ f.column in only_load[f.model._meta.db_table]
+ ]
+
+ # annotations come before the related cols
+ if has_annotation_select:
+ # extra is always at the start of the field list
+ prepended_cols = len(self.query.extra_select)
+ annotation_start = len(fields) + prepended_cols
+ fields = fields + [
+ anno.output_field for alias, anno in self.query.annotation_select.items()]
+ annotation_end = len(fields) + prepended_cols
+
+ # add related fields
+ fields = fields + [
+ # strip deferred
+ f.field for f in self.query.related_select_cols if
+ f.field.model._meta.db_table not in only_load or
+ f.field.column in only_load[f.field.model._meta.db_table]
+ ]
+
converters = self.get_converters(fields)
+ if has_annotation_select:
+ for (alias, annotation), position in zip(
+ self.query.annotation_select.items(),
+ range(annotation_start, annotation_end + 1)):
+ if position in converters:
+ # annotation conversions always run first
+ converters[position][1].insert(0, annotation.convert_value)
+ else:
+ converters[position] = ([], [annotation.convert_value], annotation.output_field)
+
if converters:
row = self.apply_converters(row, converters)
-
- if has_aggregate_select:
- row = tuple(row[:aggregate_start]) + tuple(
- self.query.resolve_aggregate(value, aggregate, self.connection)
- for (alias, aggregate), value
- in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end])
- ) + tuple(row[aggregate_end:])
-
yield row
def has_results(self):
@@ -878,7 +883,7 @@ class SQLInsertCompiler(SQLCompiler):
elif hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
- return field.get_placeholder(val, self.connection)
+ return field.get_placeholder(val, self, self.connection)
else:
# Return the common case for the placeholder
return '%s'
@@ -985,8 +990,10 @@ class SQLUpdateCompiler(SQLCompiler):
result.append('SET')
values, update_params = [], []
for field, model, val in self.query.values:
- if hasattr(val, 'prepare_database_save'):
- if field.rel or isinstance(val, ExpressionNode):
+ if hasattr(val, 'resolve_expression'):
+ val = val.resolve_expression(self.query, allow_joins=False)
+ elif hasattr(val, 'prepare_database_save'):
+ if field.rel:
val = val.prepare_database_save(field)
else:
raise TypeError("Database is trying to update a relational field "
@@ -998,12 +1005,9 @@ class SQLUpdateCompiler(SQLCompiler):
# Getting the placeholder for the field.
if hasattr(field, 'get_placeholder'):
- placeholder = field.get_placeholder(val, self.connection)
+ placeholder = field.get_placeholder(val, self, self.connection)
else:
placeholder = '%s'
-
- if hasattr(val, 'evaluate'):
- val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
sql, params = self.compile(val)
@@ -1103,8 +1107,8 @@ class SQLAggregateCompiler(SQLCompiler):
qn = self
sql, params = [], []
- for aggregate in self.query.aggregate_select.values():
- agg_sql, agg_params = self.compile(aggregate)
+ for annotation in self.query.annotation_select.values():
+ agg_sql, agg_params = self.compile(annotation)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index f9c9c259de..321451ac42 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -4,33 +4,6 @@ the SQL domain.
"""
-class Col(object):
- def __init__(self, alias, target, source):
- self.alias, self.target, self.source = alias, target, source
-
- def as_sql(self, qn, connection):
- return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
-
- @property
- def output_field(self):
- return self.source
-
- def relabeled_clone(self, relabels):
- return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
-
- def get_group_by_cols(self):
- return [(self.alias, self.target.column)]
-
- def get_lookup(self, name):
- return self.output_field.get_lookup(name)
-
- def get_transform(self, name):
- return self.output_field.get_transform(name)
-
- def prepare(self):
- return self
-
-
class EmptyResultSet(Exception):
pass
@@ -49,42 +22,3 @@ class MultiJoin(Exception):
class Empty(object):
pass
-
-
-class Date(object):
- """
- Add a date selection column.
- """
- def __init__(self, col, lookup_type):
- self.col = col
- self.lookup_type = lookup_type
-
- def relabeled_clone(self, change_map):
- return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
-
- def as_sql(self, qn, connection):
- if isinstance(self.col, (list, tuple)):
- col = '%s.%s' % tuple(qn(c) for c in self.col)
- else:
- col = self.col
- return connection.ops.date_trunc_sql(self.lookup_type, col), []
-
-
-class DateTime(object):
- """
- Add a datetime selection column.
- """
- def __init__(self, col, lookup_type, tzname):
- self.col = col
- self.lookup_type = lookup_type
- self.tzname = tzname
-
- def relabeled_clone(self, change_map):
- return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
-
- def as_sql(self, qn, connection):
- if isinstance(self.col, (list, tuple)):
- col = '%s.%s' % tuple(qn(c) for c in self.col)
- else:
- col = self.col
- return connection.ops.datetime_trunc_sql(self.lookup_type, col, self.tzname)
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
deleted file mode 100644
index e15cc2642c..0000000000
--- a/django/db/models/sql/expressions.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import copy
-
-from django.core.exceptions import FieldError
-from django.db.models.constants import LOOKUP_SEP
-from django.db.models.fields import FieldDoesNotExist
-
-
-class SQLEvaluator(object):
- def __init__(self, expression, query, allow_joins=True, reuse=None):
- self.expression = expression
- self.opts = query.get_meta()
- self.reuse = reuse
- self.cols = []
- self.expression.prepare(self, query, allow_joins)
-
- def relabeled_clone(self, change_map):
- clone = copy.copy(self)
- clone.cols = []
- for node, col in self.cols:
- if hasattr(col, 'relabeled_clone'):
- clone.cols.append((node, col.relabeled_clone(change_map)))
- else:
- clone.cols.append((node,
- (change_map.get(col[0], col[0]), col[1])))
- return clone
-
- def get_group_by_cols(self):
- cols = []
- for node, col in self.cols:
- if hasattr(node, 'get_group_by_cols'):
- cols.extend(node.get_group_by_cols())
- elif isinstance(col, tuple):
- cols.append(col)
- return cols
-
- def prepare(self):
- return self
-
- def as_sql(self, qn, connection):
- return self.expression.evaluate(self, qn, connection)
-
- #####################################################
- # Visitor methods for initial expression preparation #
- #####################################################
-
- def prepare_node(self, node, query, allow_joins):
- for child in node.children:
- if hasattr(child, 'prepare'):
- child.prepare(self, query, allow_joins)
-
- def prepare_leaf(self, node, query, allow_joins):
- if not allow_joins and LOOKUP_SEP in node.name:
- raise FieldError("Joined field references are not permitted in this query")
-
- field_list = node.name.split(LOOKUP_SEP)
- if node.name in query.aggregates:
- self.cols.append((node, query.aggregate_select[node.name]))
- else:
- try:
- _, sources, _, join_list, path = query.setup_joins(
- field_list, query.get_meta(), query.get_initial_alias(),
- can_reuse=self.reuse)
- self._used_joins = join_list
- targets, _, join_list = query.trim_joins(sources, join_list, path)
- if self.reuse is not None:
- self.reuse.update(join_list)
- for t in targets:
- self.cols.append((node, (join_list[-1], t.column)))
- except FieldDoesNotExist:
- raise FieldError("Cannot resolve keyword %r into field. "
- "Choices are: %s" % (self.name,
- [f.name for f in self.opts.fields]))
-
- ##################################################
- # Visitor methods for final expression evaluation #
- ##################################################
-
- def evaluate_node(self, node, qn, connection):
- expressions = []
- expression_params = []
- for child in node.children:
- if hasattr(child, 'evaluate'):
- sql, params = child.evaluate(self, qn, connection)
- else:
- sql, params = '%s', (child,)
-
- if len(getattr(child, 'children', [])) > 1:
- format = '(%s)'
- else:
- format = '%s'
-
- if sql:
- expressions.append(format % sql)
- expression_params.extend(params)
-
- return connection.ops.combine_expression(node.connector, expressions), expression_params
-
- def evaluate_leaf(self, node, qn, connection):
- col = None
- for n, c in self.cols:
- if n is node:
- col = c
- break
- if col is None:
- raise ValueError("Given node not found")
- if hasattr(col, 'as_sql'):
- return col.as_sql(qn, connection)
- else:
- return '%s.%s' % (qn(col[0]), qn(col[1])), []
-
- def evaluate_date_modifier_node(self, node, qn, connection):
- timedelta = node.children.pop()
- sql, params = self.evaluate_node(node, qn, connection)
- node.children.append(timedelta)
-
- if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
- return sql, params
-
- return connection.ops.date_interval_sql(sql, node.connector, timedelta), params
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 856bc51f4f..a17cd62f29 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -14,20 +14,18 @@ import warnings
from django.core.exceptions import FieldError
from django.db import connections, DEFAULT_DB_ALIAS
from django.db.models.constants import LOOKUP_SEP
-from django.db.models.aggregates import refs_aggregate
-from django.db.models.expressions import ExpressionNode
+from django.db.models.expressions import Col, Ref
from django.db.models.fields import FieldDoesNotExist
-from django.db.models.query_utils import Q
+from django.db.models.query_utils import Q, refs_aggregate
from django.db.models.related import PathInfo
-from django.db.models.sql import aggregates as base_aggregates_module
+from django.db.models.aggregates import Count
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
ORDER_PATTERN, JoinInfo, SelectInfo)
-from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin, Col
-from django.db.models.sql.expressions import SQLEvaluator
+from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
ExtraWhere, AND, OR, EmptyWhere)
from django.utils import six
-from django.utils.deprecation import RemovedInDjango19Warning
+from django.utils.deprecation import RemovedInDjango19Warning, RemovedInDjango20Warning
from django.utils.encoding import force_text
from django.utils.tree import Node
@@ -49,7 +47,7 @@ class RawQuery(object):
# the compiler can be used to process results.
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.extra_select = {}
- self.aggregate_select = {}
+ self.annotation_select = {}
def clone(self, using):
return RawQuery(self.sql, using, params=self.params)
@@ -97,7 +95,6 @@ class Query(object):
alias_prefix = 'T'
subq_aliases = frozenset([alias_prefix])
query_terms = QUERY_TERMS
- aggregates_module = base_aggregates_module
compiler = 'SQLCompiler'
@@ -140,13 +137,13 @@ class Query(object):
self.select_for_update_nowait = False
self.select_related = False
- # SQL aggregate-related attributes
- # The _aggregates will be an OrderedDict when used. Due to the cost
+ # SQL annotation-related attributes
+ # The _annotations will be an OrderedDict when used. Due to the cost
# of creating OrderedDict this attribute is created lazily (in
- # self.aggregates property).
- self._aggregates = None # Maps alias -> SQL aggregate function
- self.aggregate_select_mask = None
- self._aggregate_select_cache = None
+ # self.annotations property).
+ self._annotations = None # Maps alias -> Annotation Expression
+ self.annotation_select_mask = None
+ self._annotation_select_cache = None
# Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related().
@@ -155,7 +152,7 @@ class Query(object):
# These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause.
# The _extra attribute is an OrderedDict, lazily created similarly to
- # .aggregates
+ # .annotations
self._extra = None # Maps col_alias -> (col_sql, params).
self.extra_select_mask = None
self._extra_select_cache = None
@@ -175,10 +172,17 @@ class Query(object):
return self._extra
@property
+ def annotations(self):
+ if self._annotations is None:
+ self._annotations = OrderedDict()
+ return self._annotations
+
+ @property
def aggregates(self):
- if self._aggregates is None:
- self._aggregates = OrderedDict()
- return self._aggregates
+ warnings.warn(
+ "The aggregates property is deprecated. Use annotations instead.",
+ RemovedInDjango20Warning, stacklevel=2)
+ return self.annotations
def __str__(self):
"""
@@ -203,7 +207,7 @@ class Query(object):
memo[id(self)] = result
return result
- def prepare(self):
+ def _prepare(self):
return self
def get_compiler(self, using=None, connection=None):
@@ -213,8 +217,8 @@ class Query(object):
connection = connections[using]
# Check that the compiler will be able to execute the query
- for alias, aggregate in self.aggregate_select.items():
- connection.ops.check_aggregate_support(aggregate)
+ for alias, annotation in self.annotation_select.items():
+ connection.ops.check_aggregate_support(annotation)
return connection.ops.compiler(self.compiler)(self, connection, using)
@@ -260,17 +264,17 @@ class Query(object):
obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related
obj.related_select_cols = []
- obj._aggregates = self._aggregates.copy() if self._aggregates is not None else None
- if self.aggregate_select_mask is None:
- obj.aggregate_select_mask = None
+ obj._annotations = self._annotations.copy() if self._annotations is not None else None
+ if self.annotation_select_mask is None:
+ obj.annotation_select_mask = None
else:
- obj.aggregate_select_mask = self.aggregate_select_mask.copy()
- # _aggregate_select_cache cannot be copied, as doing so breaks the
- # (necessary) state in which both aggregates and
- # _aggregate_select_cache point to the same underlying objects.
+ obj.annotation_select_mask = self.annotation_select_mask.copy()
+ # _annotation_select_cache cannot be copied, as doing so breaks the
+ # (necessary) state in which both annotations and
+ # _annotation_select_cache point to the same underlying objects.
# It will get re-populated in the cloned queryset the next time it's
# used.
- obj._aggregate_select_cache = None
+ obj._annotation_select_cache = None
obj.max_depth = self.max_depth
obj._extra = self._extra.copy() if self._extra is not None else None
if self.extra_select_mask is None:
@@ -299,94 +303,84 @@ class Query(object):
obj._setup_query()
return obj
- def resolve_aggregate(self, value, aggregate, connection):
- """Resolve the value of aggregates returned by the database to
- consistent (and reasonable) types.
-
- This is required because of the predisposition of certain backends
- to return Decimal and long types when they are not needed.
- """
- if value is None:
- if aggregate.is_ordinal:
- return 0
- # Return None as-is
- return value
- elif aggregate.is_ordinal:
- # Any ordinal aggregate (e.g., count) returns an int
- return int(value)
- elif aggregate.is_computed:
- # Any computed aggregate (e.g., avg) returns a float
- return float(value)
- else:
- # Return value depends on the type of the field being processed.
- backend_converters = connection.ops.get_db_converters(aggregate.field.get_internal_type())
- field_converters = aggregate.field.get_db_converters(connection)
- for converter in backend_converters:
- value = converter(value, aggregate.field)
- for converter in field_converters:
- value = converter(value, connection)
- return value
-
def get_aggregation(self, using, force_subq=False):
"""
Returns the dictionary with the values of the existing aggregations.
"""
- if not self.aggregate_select:
+ if not self.annotation_select:
return {}
+ # annotations must be forced into subquery
+ has_annotation = any(
+ annotation for alias, annotation
+ in self.annotation_select.items()
+ if not annotation.contains_aggregate)
+
# If there is a group by clause, aggregating does not add useful
# information but retrieves only the first row. Aggregate
# over the subquery instead.
- if self.group_by is not None or force_subq:
+ if self.group_by is not None or force_subq or has_annotation:
from django.db.models.sql.subqueries import AggregateQuery
- query = AggregateQuery(self.model)
- obj = self.clone()
+ outer_query = AggregateQuery(self.model)
+ inner_query = self.clone()
if not force_subq:
# In forced subq case the ordering and limits will likely
# affect the results.
- obj.clear_ordering(True)
- obj.clear_limits()
- obj.select_for_update = False
- obj.select_related = False
- obj.related_select_cols = []
+ inner_query.clear_ordering(True)
+ inner_query.clear_limits()
+ inner_query.select_for_update = False
+ inner_query.select_related = False
+ inner_query.related_select_cols = []
- relabels = dict((t, 'subquery') for t in self.tables)
+ relabels = dict((t, 'subquery') for t in inner_query.tables)
+ relabels[None] = 'subquery'
# Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery.
- for alias, aggregate in self.aggregate_select.items():
- if aggregate.is_summary:
- query.aggregates[alias] = aggregate.relabeled_clone(relabels)
- del obj.aggregate_select[alias]
-
+ for alias, annotation in inner_query.annotation_select.items():
+ if annotation.is_summary:
+ # The annotation is already referring the subquery alias, so we
+ # just need to move the annotation to the outer query.
+ outer_query.annotations[alias] = annotation.relabeled_clone(relabels)
+ del inner_query.annotation_select[alias]
try:
- query.add_subquery(obj, using)
+ outer_query.add_subquery(inner_query, using)
except EmptyResultSet:
return dict(
(alias, None)
- for alias in query.aggregate_select
+ for alias in outer_query.annotation_select
)
else:
- query = self
+ outer_query = self
self.select = []
self.default_cols = False
self._extra = {}
self.remove_inherited_models()
- query.clear_ordering(True)
- query.clear_limits()
- query.select_for_update = False
- query.select_related = False
- query.related_select_cols = []
-
- result = query.get_compiler(using).execute_sql(SINGLE)
+ outer_query.clear_ordering(True)
+ outer_query.clear_limits()
+ outer_query.select_for_update = False
+ outer_query.select_related = False
+ outer_query.related_select_cols = []
+ compiler = outer_query.get_compiler(using)
+ result = compiler.execute_sql(SINGLE)
if result is None:
- result = [None for q in query.aggregate_select.items()]
+ result = [None for q in outer_query.annotation_select.items()]
+
+ fields = [annotation.output_field
+ for alias, annotation in outer_query.annotation_select.items()]
+ converters = compiler.get_converters(fields)
+ for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
+ if position in converters:
+ converters[position][1].insert(0, annotation.convert_value)
+ else:
+ converters[position] = ([], [annotation.convert_value], annotation.output_field)
+ result = compiler.apply_converters(result, converters)
return dict(
- (alias, self.resolve_aggregate(val, aggregate, connection=connections[using]))
- for (alias, aggregate), val
- in zip(query.aggregate_select.items(), result)
+ (alias, val)
+ for (alias, annotation), val
+ in zip(outer_query.annotation_select.items(), result)
)
def get_count(self, using):
@@ -394,7 +388,7 @@ class Query(object):
Performs a COUNT() query using the current filter constraints.
"""
obj = self.clone()
- if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields):
+ if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields):
# If a select clause exists, then the query has already started to
# specify the columns that are to be returned.
# In this case, we need to use a subquery to evaluate the count.
@@ -769,9 +763,9 @@ class Query(object):
self.group_by = [relabel_column(col) for col in self.group_by]
self.select = [SelectInfo(relabel_column(s.col), s.field)
for s in self.select]
- if self._aggregates:
- self._aggregates = OrderedDict(
- (key, relabel_column(col)) for key, col in self._aggregates.items())
+ if self._annotations:
+ self._annotations = OrderedDict(
+ (key, relabel_column(col)) for key, col in self._annotations.items())
# 2. Rename the alias in the internal table/alias datastructures.
for ident, aliases in self.join_map.items():
@@ -974,52 +968,18 @@ class Query(object):
self.included_inherited_models = {}
def add_aggregate(self, aggregate, model, alias, is_summary):
+ warnings.warn(
+ "add_aggregate() is deprecated. Use add_annotation() instead.",
+ RemovedInDjango20Warning, stacklevel=2)
+ self.add_annotation(aggregate, model, alias, is_summary)
+
+ def add_annotation(self, annotation, model, alias, is_summary):
"""
- Adds a single aggregate expression to the Query
+ Adds a single annotation expression to the Query
"""
- opts = model._meta
- field_list = aggregate.lookup.split(LOOKUP_SEP)
- if len(field_list) == 1 and self._aggregates and aggregate.lookup in self.aggregates:
- # Aggregate is over an annotation
- field_name = field_list[0]
- col = field_name
- source = self.aggregates[field_name]
- if not is_summary:
- raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
- aggregate.name, field_name, field_name))
- elif ((len(field_list) > 1) or
- (field_list[0] not in [i.name for i in opts.fields]) or
- self.group_by is None or
- not is_summary):
- # If:
- # - the field descriptor has more than one part (foo__bar), or
- # - the field descriptor is referencing an m2m/m2o field, or
- # - this is a reference to a model field (possibly inherited), or
- # - this is an annotation over a model field
- # then we need to explore the joins that are required.
-
- # Join promotion note - we must not remove any rows here, so use
- # outer join if there isn't any existing join.
- _, sources, opts, join_list, path = self.setup_joins(
- field_list, opts, self.get_initial_alias())
-
- # Process the join chain to see if it can be trimmed
- targets, _, join_list = self.trim_joins(sources, join_list, path)
-
- col = targets[0].column
- source = sources[0]
- col = (join_list[-1], col)
- else:
- # The simplest cases. No joins required -
- # just reference the provided column alias.
- field_name = field_list[0]
- source = opts.get_field(field_name)
- col = field_name
- # We want to have the alias in SELECT clause even if mask is set.
- self.append_aggregate_mask([alias])
-
- # Add the aggregate to the query
- aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
+ annotation = annotation.resolve_expression(self, summarize=is_summary)
+ self.append_annotation_mask([alias])
+ self.annotations[alias] = annotation
def prepare_lookup_value(self, value, lookups, can_reuse):
# Default lookup if none given is exact.
@@ -1037,9 +997,8 @@ class Query(object):
"Passing callable arguments to queryset is deprecated.",
RemovedInDjango19Warning, stacklevel=2)
value = value()
- elif isinstance(value, ExpressionNode):
- # If value is a query expression, evaluate it
- value = SQLEvaluator(value, self, reuse=can_reuse)
+ elif hasattr(value, 'resolve_expression'):
+ value = value.resolve_expression(self, reuse=can_reuse)
if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'):
value = value._clone()
value.query.bump_prefix(self)
@@ -1061,8 +1020,8 @@ class Query(object):
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
"""
lookup_splitted = lookup.split(LOOKUP_SEP)
- if self._aggregates:
- aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
+ if self._annotations:
+ aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations)
if aggregate:
return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
@@ -1232,7 +1191,11 @@ class Query(object):
lookup_type = lookups[-1]
else:
assert(len(targets) == 1)
- col = Col(alias, targets[0], field)
+ if hasattr(targets[0], 'as_sql'):
+ # handle Expressions as annotations
+ col = targets[0]
+ else:
+ col = Col(alias, targets[0], field)
condition = self.build_lookup(lookups, col, value)
if not condition:
# Backwards compat for custom lookups
@@ -1278,12 +1241,12 @@ class Query(object):
Returns whether or not all elements of this q_object need to be put
together in the HAVING clause.
"""
- if not self._aggregates:
+ if not self._annotations:
return False
if not isinstance(obj, Node):
- return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
- or (hasattr(obj[1], 'contains_aggregate')
- and obj[1].contains_aggregate(self.aggregates)))
+ return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
+ or (hasattr(obj[1], 'refs_aggregate')
+ and obj[1].refs_aggregate(self.annotations)[0]))
return any(self.need_having(c) for c in obj.children)
def split_having_parts(self, q_object, negated=False):
@@ -1390,13 +1353,21 @@ class Query(object):
if name == 'pk':
name = opts.pk.name
try:
- field, model, direct, m2m = opts.get_field_by_name(name)
+ field, model, _, _ = opts.get_field_by_name(name)
except FieldDoesNotExist:
+ # is it an annotation?
+ if self._annotations and name in self._annotations:
+ field, model = self._annotations[name], None
+ if not field.contains_aggregate:
+ # Local non-relational field.
+ final_field = field
+ targets = (field,)
+ break
# We didn't find the current field, so move position back
# one step.
pos -= 1
if pos == -1 or fail_on_missing:
- available = opts.get_all_field_names() + list(self.aggregate_select)
+ available = opts.get_all_field_names() + list(self.annotation_select)
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(available)))
break
@@ -1445,6 +1416,11 @@ class Query(object):
break
return path, final_field, targets, names[pos + 1:]
+ def raise_field_error(self, opts, name):
+ available = opts.get_all_field_names() + list(self.annotation_select)
+ raise FieldError("Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(available)))
+
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
Compute the necessary table joins for the passage through the fields
@@ -1519,6 +1495,29 @@ class Query(object):
self.unref_alias(joins.pop())
return targets, joins[-1], joins
+ def resolve_ref(self, name, allow_joins, reuse, summarize):
+ if not allow_joins and LOOKUP_SEP in name:
+ raise FieldError("Joined field references are not permitted in this query")
+ if name in self.annotations:
+ if summarize:
+ return Ref(name, self.annotation_select[name])
+ else:
+ return self.annotation_select[name]
+ else:
+ field_list = name.split(LOOKUP_SEP)
+ field, sources, opts, join_list, path = self.setup_joins(
+ field_list, self.get_meta(),
+ self.get_initial_alias(), reuse)
+ targets, _, join_list = self.trim_joins(sources, join_list, path)
+ if len(targets) > 1:
+ raise FieldError("Referencing multicolumn fields with F() objects "
+ "isn't supported")
+ if reuse is not None:
+ reuse.update(join_list)
+ col = Col(join_list[-1], targets[0], sources[0])
+ col._used_joins = join_list
+ return col
+
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
"""
When doing an exclude against any kind of N-to-many relation, we need
@@ -1633,7 +1632,7 @@ class Query(object):
self.default_cols = False
self.select_related = False
self.set_extra_mask(())
- self.set_aggregate_mask(())
+ self.set_annotation_mask(())
def clear_select_fields(self):
"""
@@ -1676,7 +1675,7 @@ class Query(object):
raise
else:
names = sorted(opts.get_all_field_names() + list(self.extra)
- + list(self.aggregate_select))
+ + list(self.annotation_select))
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
self.remove_inherited_models()
@@ -1725,39 +1724,55 @@ class Query(object):
for col, _ in self.select:
self.group_by.append(col)
+ if self._annotations:
+ for alias, annotation in six.iteritems(self.annotations):
+ for col in annotation.get_group_by_cols():
+ self.group_by.append(col)
+
def add_count_column(self):
"""
Converts the query to do count(...) or count(distinct(pk)) in order to
get its size.
"""
+ summarize = False
if not self.distinct:
if not self.select:
- count = self.aggregates_module.Count('*', is_summary=True)
+ count = Count('*')
+ summarize = True
else:
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select': %r" % self.select
- count = self.aggregates_module.Count(self.select[0].col)
+ col = self.select[0].col
+ if isinstance(col, (tuple, list)):
+ count = Count(col[1])
+ else:
+ count = Count(col)
+
else:
opts = self.get_meta()
if not self.select:
- count = self.aggregates_module.Count(
- (self.join((None, opts.db_table, None)), opts.pk.column),
- is_summary=True, distinct=True)
+ lookup = self.join((None, opts.db_table, None)), opts.pk.column
+ count = Count(lookup[1], distinct=True)
+ summarize = True
else:
# Because of SQL portability issues, multi-column, distinct
# counts need a sub-query -- see get_count() for details.
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select'."
-
- count = self.aggregates_module.Count(self.select[0].col, distinct=True)
+ col = self.select[0].col
+ if isinstance(col, (tuple, list)):
+ count = Count(col[1], distinct=True)
+ else:
+ count = Count(col, distinct=True)
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
# Set only aggregate to be the count column.
- # Clear out the select cache to reflect the new unmasked aggregates.
- self._aggregates = {None: count}
- self.set_aggregate_mask(None)
+ # Clear out the select cache to reflect the new unmasked annotations.
+ count = count.resolve_expression(self, summarize=summarize)
+ self._annotations = {None: count}
+ self.set_annotation_mask(None)
self.group_by = None
def add_select_related(self, fields):
@@ -1886,16 +1901,28 @@ class Query(object):
target[model] = set(f.name for f in fields)
def set_aggregate_mask(self, names):
- "Set the mask of aggregates that will actually be returned by the SELECT"
+ warnings.warn(
+ "set_aggregate_mask() is deprecated. Use set_annotation_mask() instead.",
+ RemovedInDjango20Warning, stacklevel=2)
+ self.set_annotation_mask(names)
+
+ def set_annotation_mask(self, names):
+ "Set the mask of annotations that will actually be returned by the SELECT"
if names is None:
- self.aggregate_select_mask = None
+ self.annotation_select_mask = None
else:
- self.aggregate_select_mask = set(names)
- self._aggregate_select_cache = None
+ self.annotation_select_mask = set(names)
+ self._annotation_select_cache = None
def append_aggregate_mask(self, names):
- if self.aggregate_select_mask is not None:
- self.set_aggregate_mask(set(names).union(self.aggregate_select_mask))
+ warnings.warn(
+ "append_aggregate_mask() is deprecated. Use append_annotation_mask() instead.",
+ RemovedInDjango20Warning, stacklevel=2)
+ self.append_annotation_mask(names)
+
+ def append_annotation_mask(self, names):
+ if self.annotation_select_mask is not None:
+ self.set_annotation_mask(set(names).union(self.annotation_select_mask))
def set_extra_mask(self, names):
"""
@@ -1910,24 +1937,31 @@ class Query(object):
self._extra_select_cache = None
@property
- def aggregate_select(self):
+ def annotation_select(self):
"""The OrderedDict of aggregate columns that are not masked, and should
be used in the SELECT clause.
This result is cached for optimization purposes.
"""
- if self._aggregate_select_cache is not None:
- return self._aggregate_select_cache
- elif not self._aggregates:
+ if self._annotation_select_cache is not None:
+ return self._annotation_select_cache
+ elif not self._annotations:
return {}
- elif self.aggregate_select_mask is not None:
- self._aggregate_select_cache = OrderedDict(
- (k, v) for k, v in self.aggregates.items()
- if k in self.aggregate_select_mask
+ elif self.annotation_select_mask is not None:
+ self._annotation_select_cache = OrderedDict(
+ (k, v) for k, v in self.annotations.items()
+ if k in self.annotation_select_mask
)
- return self._aggregate_select_cache
+ return self._annotation_select_cache
else:
- return self.aggregates
+ return self.annotations
+
+ @property
+ def aggregate_select(self):
+ warnings.warn(
+ "aggregate_select() is deprecated. Use annotation_select() instead.",
+ RemovedInDjango20Warning, stacklevel=2)
+ return self.annotation_select
@property
def extra_select(self):
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 2f0de5b80c..6f3f7358d3 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -7,9 +7,9 @@ from django.core.exceptions import FieldError
from django.db import connections
from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
+from django.db.models.expressions import Date, DateTime, Col
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
-from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.utils import six
from django.utils import timezone
@@ -229,7 +229,7 @@ class DateQuery(Query):
))
self._check_field(field) # overridden in DateTimeQuery
alias = joins[-1]
- select = self._get_select((alias, field.column), lookup_type)
+ select = self._get_select(Col(alias, field), lookup_type)
self.clear_select_clause()
self.select = [SelectInfo(select, None)]
self.distinct = True
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index f65e593a3a..13815cb68c 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -10,7 +10,6 @@ import warnings
from django.conf import settings
from django.db.models.fields import DateTimeField, Field
from django.db.models.sql.datastructures import EmptyResultSet, Empty
-from django.db.models.sql.aggregates import Aggregate
from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.six.moves import xrange
from django.utils import timezone
@@ -78,7 +77,7 @@ class WhereNode(tree.Node):
else:
value_annotation = bool(value)
- if hasattr(obj, "prepare"):
+ if hasattr(obj, 'prepare'):
value = obj.prepare(lookup_type, value)
return (obj, lookup_type, value_annotation, value)
@@ -187,11 +186,9 @@ class WhereNode(tree.Node):
lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
except EmptyShortCircuit:
raise EmptyResultSet
- elif isinstance(lvalue, Aggregate):
- params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection)
else:
- raise TypeError("'make_atom' expects a Constraint or an Aggregate "
- "as the first item of its 'child' argument.")
+ raise TypeError("'make_atom' expects a Constraint as the first "
+ "item of its 'child' argument.")
if isinstance(lvalue, tuple):
# A direct database column lookup.