diff options
| author | Josh Smeaton <josh.smeaton@gmail.com> | 2013-12-26 00:13:18 +1100 |
|---|---|---|
| committer | Marc Tamlyn <marc.tamlyn@gmail.com> | 2014-11-15 14:00:43 +0000 |
| commit | f59fd15c4928caf3dfcbd50f6ab47be409a43b01 (patch) | |
| tree | fe4a04d98359e1ffcbfe991303eb97d9a8e16afc /django/db/models/sql | |
| parent | 39e3ef88c237e3f4cedc89cd36494a6d3f490812 (diff) | |
Fixed #14030 -- Allowed annotations to accept all expressions
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/aggregates.py | 8 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 94 | ||||
| -rw-r--r-- | django/db/models/sql/datastructures.py | 66 | ||||
| -rw-r--r-- | django/db/models/sql/expressions.py | 119 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 384 | ||||
| -rw-r--r-- | django/db/models/sql/subqueries.py | 4 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 9 |
7 files changed, 271 insertions, 413 deletions
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 8274d43621..6ebf5fb966 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -2,15 +2,23 @@ Classes to represent the default SQL aggregate functions """ import copy +import warnings from django.db.models.fields import IntegerField, FloatField from django.db.models.lookups import RegisterLookupMixin +from django.utils.deprecation import RemovedInDjango20Warning from django.utils.functional import cached_property __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] +warnings.warn( + "django.db.models.sql.aggregates is deprecated. Use " + "django.db.models.aggregates instead.", + RemovedInDjango20Warning, stacklevel=2) + + class Aggregate(RegisterLookupMixin): """ Default SQL Aggregate. diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 33fe343b5b..5f425a7543 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -4,12 +4,10 @@ from django.conf import settings from django.core.exceptions import FieldError from django.db.backends.utils import truncate_name from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import ExpressionNode from django.db.models.query_utils import select_related_descend, QueryWrapper from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS, ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet -from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import get_order_dir, Query from django.db.transaction import TransactionManagementError from django.db.utils import DatabaseError @@ -248,8 +246,8 @@ class SQLCompiler(object): aliases.update(new_aliases) max_name_length = self.connection.ops.max_name_length() - for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = self.compile(aggregate) + for alias, annotation in self.query.annotation_select.items(): + agg_sql, agg_params = self.compile(annotation) if alias is None: result.append(agg_sql) else: @@ -409,7 +407,7 @@ class SQLCompiler(object): group_by.append((str(field), [])) continue col, order = get_order_dir(field, asc) - if col in self.query.aggregate_select: + if col in self.query.annotation_select: result.append('%s %s' % (qn(col), order)) continue if '.' in field: @@ -718,25 +716,17 @@ class SQLCompiler(object): """ fields = None converters = None - has_aggregate_select = bool(self.query.aggregate_select) + has_annotation_select = bool(self.query.annotation_select) for rows in self.execute_sql(MULTI): for row in rows: - if has_aggregate_select: - loaded_fields = ( - self.query.get_loaded_field_names().get(self.query.model, set()) or - self.query.select - ) - aggregate_start = len(self.query.extra_select) + len(loaded_fields) - aggregate_end = aggregate_start + len(self.query.aggregate_select) if fields is None: # We only set this up here because # related_select_cols isn't populated until # execute_sql() has been called. - # We also include types of fields of related models that - # will be included via select_related() for the benefit - # of MySQL/MySQLdb when boolean fields are involved - # (#15040). + # If the field was deferred, exclude it from being passed + # into `get_converters` because it wasn't selected. + only_load = self.deferred_to_columns() # This code duplicates the logic for the order of fields # found in get_columns(). It would be nice to clean this up. @@ -746,30 +736,45 @@ class SQLCompiler(object): fields = self.query.get_meta().concrete_fields else: fields = [] - fields = fields + [f.field for f in self.query.related_select_cols] - # If the field was deferred, exclude it from being passed - # into `get_converters` because it wasn't selected. - only_load = self.deferred_to_columns() if only_load: - fields = [f for f in fields if f.model._meta.db_table not in only_load or - f.column in only_load[f.model._meta.db_table]] - if has_aggregate_select: - # pad None in to fields for aggregates - fields = fields[:aggregate_start] + [ - None for x in range(0, aggregate_end - aggregate_start) - ] + fields[aggregate_start:] + # strip deferred fields + fields = [ + f for f in fields if + f.model._meta.db_table not in only_load or + f.column in only_load[f.model._meta.db_table] + ] + + # annotations come before the related cols + if has_annotation_select: + # extra is always at the start of the field list + prepended_cols = len(self.query.extra_select) + annotation_start = len(fields) + prepended_cols + fields = fields + [ + anno.output_field for alias, anno in self.query.annotation_select.items()] + annotation_end = len(fields) + prepended_cols + + # add related fields + fields = fields + [ + # strip deferred + f.field for f in self.query.related_select_cols if + f.field.model._meta.db_table not in only_load or + f.field.column in only_load[f.field.model._meta.db_table] + ] + converters = self.get_converters(fields) + if has_annotation_select: + for (alias, annotation), position in zip( + self.query.annotation_select.items(), + range(annotation_start, annotation_end + 1)): + if position in converters: + # annotation conversions always run first + converters[position][1].insert(0, annotation.convert_value) + else: + converters[position] = ([], [annotation.convert_value], annotation.output_field) + if converters: row = self.apply_converters(row, converters) - - if has_aggregate_select: - row = tuple(row[:aggregate_start]) + tuple( - self.query.resolve_aggregate(value, aggregate, self.connection) - for (alias, aggregate), value - in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end]) - ) + tuple(row[aggregate_end:]) - yield row def has_results(self): @@ -878,7 +883,7 @@ class SQLInsertCompiler(SQLCompiler): elif hasattr(field, 'get_placeholder'): # Some fields (e.g. geo fields) need special munging before # they can be inserted. - return field.get_placeholder(val, self.connection) + return field.get_placeholder(val, self, self.connection) else: # Return the common case for the placeholder return '%s' @@ -985,8 +990,10 @@ class SQLUpdateCompiler(SQLCompiler): result.append('SET') values, update_params = [], [] for field, model, val in self.query.values: - if hasattr(val, 'prepare_database_save'): - if field.rel or isinstance(val, ExpressionNode): + if hasattr(val, 'resolve_expression'): + val = val.resolve_expression(self.query, allow_joins=False) + elif hasattr(val, 'prepare_database_save'): + if field.rel: val = val.prepare_database_save(field) else: raise TypeError("Database is trying to update a relational field " @@ -998,12 +1005,9 @@ class SQLUpdateCompiler(SQLCompiler): # Getting the placeholder for the field. if hasattr(field, 'get_placeholder'): - placeholder = field.get_placeholder(val, self.connection) + placeholder = field.get_placeholder(val, self, self.connection) else: placeholder = '%s' - - if hasattr(val, 'evaluate'): - val = SQLEvaluator(val, self.query, allow_joins=False) name = field.column if hasattr(val, 'as_sql'): sql, params = self.compile(val) @@ -1103,8 +1107,8 @@ class SQLAggregateCompiler(SQLCompiler): qn = self sql, params = [], [] - for aggregate in self.query.aggregate_select.values(): - agg_sql, agg_params = self.compile(aggregate) + for annotation in self.query.annotation_select.values(): + agg_sql, agg_params = self.compile(annotation) sql.append(agg_sql) params.extend(agg_params) sql = ', '.join(sql) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index f9c9c259de..321451ac42 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -4,33 +4,6 @@ the SQL domain. """ -class Col(object): - def __init__(self, alias, target, source): - self.alias, self.target, self.source = alias, target, source - - def as_sql(self, qn, connection): - return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] - - @property - def output_field(self): - return self.source - - def relabeled_clone(self, relabels): - return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source) - - def get_group_by_cols(self): - return [(self.alias, self.target.column)] - - def get_lookup(self, name): - return self.output_field.get_lookup(name) - - def get_transform(self, name): - return self.output_field.get_transform(name) - - def prepare(self): - return self - - class EmptyResultSet(Exception): pass @@ -49,42 +22,3 @@ class MultiJoin(Exception): class Empty(object): pass - - -class Date(object): - """ - Add a date selection column. - """ - def __init__(self, col, lookup_type): - self.col = col - self.lookup_type = lookup_type - - def relabeled_clone(self, change_map): - return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1])) - - def as_sql(self, qn, connection): - if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple(qn(c) for c in self.col) - else: - col = self.col - return connection.ops.date_trunc_sql(self.lookup_type, col), [] - - -class DateTime(object): - """ - Add a datetime selection column. - """ - def __init__(self, col, lookup_type, tzname): - self.col = col - self.lookup_type = lookup_type - self.tzname = tzname - - def relabeled_clone(self, change_map): - return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1])) - - def as_sql(self, qn, connection): - if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple(qn(c) for c in self.col) - else: - col = self.col - return connection.ops.datetime_trunc_sql(self.lookup_type, col, self.tzname) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py deleted file mode 100644 index e15cc2642c..0000000000 --- a/django/db/models/sql/expressions.py +++ /dev/null @@ -1,119 +0,0 @@ -import copy - -from django.core.exceptions import FieldError -from django.db.models.constants import LOOKUP_SEP -from django.db.models.fields import FieldDoesNotExist - - -class SQLEvaluator(object): - def __init__(self, expression, query, allow_joins=True, reuse=None): - self.expression = expression - self.opts = query.get_meta() - self.reuse = reuse - self.cols = [] - self.expression.prepare(self, query, allow_joins) - - def relabeled_clone(self, change_map): - clone = copy.copy(self) - clone.cols = [] - for node, col in self.cols: - if hasattr(col, 'relabeled_clone'): - clone.cols.append((node, col.relabeled_clone(change_map))) - else: - clone.cols.append((node, - (change_map.get(col[0], col[0]), col[1]))) - return clone - - def get_group_by_cols(self): - cols = [] - for node, col in self.cols: - if hasattr(node, 'get_group_by_cols'): - cols.extend(node.get_group_by_cols()) - elif isinstance(col, tuple): - cols.append(col) - return cols - - def prepare(self): - return self - - def as_sql(self, qn, connection): - return self.expression.evaluate(self, qn, connection) - - ##################################################### - # Visitor methods for initial expression preparation # - ##################################################### - - def prepare_node(self, node, query, allow_joins): - for child in node.children: - if hasattr(child, 'prepare'): - child.prepare(self, query, allow_joins) - - def prepare_leaf(self, node, query, allow_joins): - if not allow_joins and LOOKUP_SEP in node.name: - raise FieldError("Joined field references are not permitted in this query") - - field_list = node.name.split(LOOKUP_SEP) - if node.name in query.aggregates: - self.cols.append((node, query.aggregate_select[node.name])) - else: - try: - _, sources, _, join_list, path = query.setup_joins( - field_list, query.get_meta(), query.get_initial_alias(), - can_reuse=self.reuse) - self._used_joins = join_list - targets, _, join_list = query.trim_joins(sources, join_list, path) - if self.reuse is not None: - self.reuse.update(join_list) - for t in targets: - self.cols.append((node, (join_list[-1], t.column))) - except FieldDoesNotExist: - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (self.name, - [f.name for f in self.opts.fields])) - - ################################################## - # Visitor methods for final expression evaluation # - ################################################## - - def evaluate_node(self, node, qn, connection): - expressions = [] - expression_params = [] - for child in node.children: - if hasattr(child, 'evaluate'): - sql, params = child.evaluate(self, qn, connection) - else: - sql, params = '%s', (child,) - - if len(getattr(child, 'children', [])) > 1: - format = '(%s)' - else: - format = '%s' - - if sql: - expressions.append(format % sql) - expression_params.extend(params) - - return connection.ops.combine_expression(node.connector, expressions), expression_params - - def evaluate_leaf(self, node, qn, connection): - col = None - for n, c in self.cols: - if n is node: - col = c - break - if col is None: - raise ValueError("Given node not found") - if hasattr(col, 'as_sql'): - return col.as_sql(qn, connection) - else: - return '%s.%s' % (qn(col[0]), qn(col[1])), [] - - def evaluate_date_modifier_node(self, node, qn, connection): - timedelta = node.children.pop() - sql, params = self.evaluate_node(node, qn, connection) - node.children.append(timedelta) - - if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0): - return sql, params - - return connection.ops.date_interval_sql(sql, node.connector, timedelta), params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 856bc51f4f..a17cd62f29 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -14,20 +14,18 @@ import warnings from django.core.exceptions import FieldError from django.db import connections, DEFAULT_DB_ALIAS from django.db.models.constants import LOOKUP_SEP -from django.db.models.aggregates import refs_aggregate -from django.db.models.expressions import ExpressionNode +from django.db.models.expressions import Col, Ref from django.db.models.fields import FieldDoesNotExist -from django.db.models.query_utils import Q +from django.db.models.query_utils import Q, refs_aggregate from django.db.models.related import PathInfo -from django.db.models.sql import aggregates as base_aggregates_module +from django.db.models.aggregates import Count from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, ORDER_PATTERN, JoinInfo, SelectInfo) -from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin, Col -from django.db.models.sql.expressions import SQLEvaluator +from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, ExtraWhere, AND, OR, EmptyWhere) from django.utils import six -from django.utils.deprecation import RemovedInDjango19Warning +from django.utils.deprecation import RemovedInDjango19Warning, RemovedInDjango20Warning from django.utils.encoding import force_text from django.utils.tree import Node @@ -49,7 +47,7 @@ class RawQuery(object): # the compiler can be used to process results. self.low_mark, self.high_mark = 0, None # Used for offset/limit self.extra_select = {} - self.aggregate_select = {} + self.annotation_select = {} def clone(self, using): return RawQuery(self.sql, using, params=self.params) @@ -97,7 +95,6 @@ class Query(object): alias_prefix = 'T' subq_aliases = frozenset([alias_prefix]) query_terms = QUERY_TERMS - aggregates_module = base_aggregates_module compiler = 'SQLCompiler' @@ -140,13 +137,13 @@ class Query(object): self.select_for_update_nowait = False self.select_related = False - # SQL aggregate-related attributes - # The _aggregates will be an OrderedDict when used. Due to the cost + # SQL annotation-related attributes + # The _annotations will be an OrderedDict when used. Due to the cost # of creating OrderedDict this attribute is created lazily (in - # self.aggregates property). - self._aggregates = None # Maps alias -> SQL aggregate function - self.aggregate_select_mask = None - self._aggregate_select_cache = None + # self.annotations property). + self._annotations = None # Maps alias -> Annotation Expression + self.annotation_select_mask = None + self._annotation_select_cache = None # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). @@ -155,7 +152,7 @@ class Query(object): # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. # The _extra attribute is an OrderedDict, lazily created similarly to - # .aggregates + # .annotations self._extra = None # Maps col_alias -> (col_sql, params). self.extra_select_mask = None self._extra_select_cache = None @@ -175,10 +172,17 @@ class Query(object): return self._extra @property + def annotations(self): + if self._annotations is None: + self._annotations = OrderedDict() + return self._annotations + + @property def aggregates(self): - if self._aggregates is None: - self._aggregates = OrderedDict() - return self._aggregates + warnings.warn( + "The aggregates property is deprecated. Use annotations instead.", + RemovedInDjango20Warning, stacklevel=2) + return self.annotations def __str__(self): """ @@ -203,7 +207,7 @@ class Query(object): memo[id(self)] = result return result - def prepare(self): + def _prepare(self): return self def get_compiler(self, using=None, connection=None): @@ -213,8 +217,8 @@ class Query(object): connection = connections[using] # Check that the compiler will be able to execute the query - for alias, aggregate in self.aggregate_select.items(): - connection.ops.check_aggregate_support(aggregate) + for alias, annotation in self.annotation_select.items(): + connection.ops.check_aggregate_support(annotation) return connection.ops.compiler(self.compiler)(self, connection, using) @@ -260,17 +264,17 @@ class Query(object): obj.select_for_update_nowait = self.select_for_update_nowait obj.select_related = self.select_related obj.related_select_cols = [] - obj._aggregates = self._aggregates.copy() if self._aggregates is not None else None - if self.aggregate_select_mask is None: - obj.aggregate_select_mask = None + obj._annotations = self._annotations.copy() if self._annotations is not None else None + if self.annotation_select_mask is None: + obj.annotation_select_mask = None else: - obj.aggregate_select_mask = self.aggregate_select_mask.copy() - # _aggregate_select_cache cannot be copied, as doing so breaks the - # (necessary) state in which both aggregates and - # _aggregate_select_cache point to the same underlying objects. + obj.annotation_select_mask = self.annotation_select_mask.copy() + # _annotation_select_cache cannot be copied, as doing so breaks the + # (necessary) state in which both annotations and + # _annotation_select_cache point to the same underlying objects. # It will get re-populated in the cloned queryset the next time it's # used. - obj._aggregate_select_cache = None + obj._annotation_select_cache = None obj.max_depth = self.max_depth obj._extra = self._extra.copy() if self._extra is not None else None if self.extra_select_mask is None: @@ -299,94 +303,84 @@ class Query(object): obj._setup_query() return obj - def resolve_aggregate(self, value, aggregate, connection): - """Resolve the value of aggregates returned by the database to - consistent (and reasonable) types. - - This is required because of the predisposition of certain backends - to return Decimal and long types when they are not needed. - """ - if value is None: - if aggregate.is_ordinal: - return 0 - # Return None as-is - return value - elif aggregate.is_ordinal: - # Any ordinal aggregate (e.g., count) returns an int - return int(value) - elif aggregate.is_computed: - # Any computed aggregate (e.g., avg) returns a float - return float(value) - else: - # Return value depends on the type of the field being processed. - backend_converters = connection.ops.get_db_converters(aggregate.field.get_internal_type()) - field_converters = aggregate.field.get_db_converters(connection) - for converter in backend_converters: - value = converter(value, aggregate.field) - for converter in field_converters: - value = converter(value, connection) - return value - def get_aggregation(self, using, force_subq=False): """ Returns the dictionary with the values of the existing aggregations. """ - if not self.aggregate_select: + if not self.annotation_select: return {} + # annotations must be forced into subquery + has_annotation = any( + annotation for alias, annotation + in self.annotation_select.items() + if not annotation.contains_aggregate) + # If there is a group by clause, aggregating does not add useful # information but retrieves only the first row. Aggregate # over the subquery instead. - if self.group_by is not None or force_subq: + if self.group_by is not None or force_subq or has_annotation: from django.db.models.sql.subqueries import AggregateQuery - query = AggregateQuery(self.model) - obj = self.clone() + outer_query = AggregateQuery(self.model) + inner_query = self.clone() if not force_subq: # In forced subq case the ordering and limits will likely # affect the results. - obj.clear_ordering(True) - obj.clear_limits() - obj.select_for_update = False - obj.select_related = False - obj.related_select_cols = [] + inner_query.clear_ordering(True) + inner_query.clear_limits() + inner_query.select_for_update = False + inner_query.select_related = False + inner_query.related_select_cols = [] - relabels = dict((t, 'subquery') for t in self.tables) + relabels = dict((t, 'subquery') for t in inner_query.tables) + relabels[None] = 'subquery' # Remove any aggregates marked for reduction from the subquery # and move them to the outer AggregateQuery. - for alias, aggregate in self.aggregate_select.items(): - if aggregate.is_summary: - query.aggregates[alias] = aggregate.relabeled_clone(relabels) - del obj.aggregate_select[alias] - + for alias, annotation in inner_query.annotation_select.items(): + if annotation.is_summary: + # The annotation is already referring the subquery alias, so we + # just need to move the annotation to the outer query. + outer_query.annotations[alias] = annotation.relabeled_clone(relabels) + del inner_query.annotation_select[alias] try: - query.add_subquery(obj, using) + outer_query.add_subquery(inner_query, using) except EmptyResultSet: return dict( (alias, None) - for alias in query.aggregate_select + for alias in outer_query.annotation_select ) else: - query = self + outer_query = self self.select = [] self.default_cols = False self._extra = {} self.remove_inherited_models() - query.clear_ordering(True) - query.clear_limits() - query.select_for_update = False - query.select_related = False - query.related_select_cols = [] - - result = query.get_compiler(using).execute_sql(SINGLE) + outer_query.clear_ordering(True) + outer_query.clear_limits() + outer_query.select_for_update = False + outer_query.select_related = False + outer_query.related_select_cols = [] + compiler = outer_query.get_compiler(using) + result = compiler.execute_sql(SINGLE) if result is None: - result = [None for q in query.aggregate_select.items()] + result = [None for q in outer_query.annotation_select.items()] + + fields = [annotation.output_field + for alias, annotation in outer_query.annotation_select.items()] + converters = compiler.get_converters(fields) + for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()): + if position in converters: + converters[position][1].insert(0, annotation.convert_value) + else: + converters[position] = ([], [annotation.convert_value], annotation.output_field) + result = compiler.apply_converters(result, converters) return dict( - (alias, self.resolve_aggregate(val, aggregate, connection=connections[using])) - for (alias, aggregate), val - in zip(query.aggregate_select.items(), result) + (alias, val) + for (alias, annotation), val + in zip(outer_query.annotation_select.items(), result) ) def get_count(self, using): @@ -394,7 +388,7 @@ class Query(object): Performs a COUNT() query using the current filter constraints. """ obj = self.clone() - if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields): + if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields): # If a select clause exists, then the query has already started to # specify the columns that are to be returned. # In this case, we need to use a subquery to evaluate the count. @@ -769,9 +763,9 @@ class Query(object): self.group_by = [relabel_column(col) for col in self.group_by] self.select = [SelectInfo(relabel_column(s.col), s.field) for s in self.select] - if self._aggregates: - self._aggregates = OrderedDict( - (key, relabel_column(col)) for key, col in self._aggregates.items()) + if self._annotations: + self._annotations = OrderedDict( + (key, relabel_column(col)) for key, col in self._annotations.items()) # 2. Rename the alias in the internal table/alias datastructures. for ident, aliases in self.join_map.items(): @@ -974,52 +968,18 @@ class Query(object): self.included_inherited_models = {} def add_aggregate(self, aggregate, model, alias, is_summary): + warnings.warn( + "add_aggregate() is deprecated. Use add_annotation() instead.", + RemovedInDjango20Warning, stacklevel=2) + self.add_annotation(aggregate, model, alias, is_summary) + + def add_annotation(self, annotation, model, alias, is_summary): """ - Adds a single aggregate expression to the Query + Adds a single annotation expression to the Query """ - opts = model._meta - field_list = aggregate.lookup.split(LOOKUP_SEP) - if len(field_list) == 1 and self._aggregates and aggregate.lookup in self.aggregates: - # Aggregate is over an annotation - field_name = field_list[0] - col = field_name - source = self.aggregates[field_name] - if not is_summary: - raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( - aggregate.name, field_name, field_name)) - elif ((len(field_list) > 1) or - (field_list[0] not in [i.name for i in opts.fields]) or - self.group_by is None or - not is_summary): - # If: - # - the field descriptor has more than one part (foo__bar), or - # - the field descriptor is referencing an m2m/m2o field, or - # - this is a reference to a model field (possibly inherited), or - # - this is an annotation over a model field - # then we need to explore the joins that are required. - - # Join promotion note - we must not remove any rows here, so use - # outer join if there isn't any existing join. - _, sources, opts, join_list, path = self.setup_joins( - field_list, opts, self.get_initial_alias()) - - # Process the join chain to see if it can be trimmed - targets, _, join_list = self.trim_joins(sources, join_list, path) - - col = targets[0].column - source = sources[0] - col = (join_list[-1], col) - else: - # The simplest cases. No joins required - - # just reference the provided column alias. - field_name = field_list[0] - source = opts.get_field(field_name) - col = field_name - # We want to have the alias in SELECT clause even if mask is set. - self.append_aggregate_mask([alias]) - - # Add the aggregate to the query - aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) + annotation = annotation.resolve_expression(self, summarize=is_summary) + self.append_annotation_mask([alias]) + self.annotations[alias] = annotation def prepare_lookup_value(self, value, lookups, can_reuse): # Default lookup if none given is exact. @@ -1037,9 +997,8 @@ class Query(object): "Passing callable arguments to queryset is deprecated.", RemovedInDjango19Warning, stacklevel=2) value = value() - elif isinstance(value, ExpressionNode): - # If value is a query expression, evaluate it - value = SQLEvaluator(value, self, reuse=can_reuse) + elif hasattr(value, 'resolve_expression'): + value = value.resolve_expression(self, reuse=can_reuse) if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'): value = value._clone() value.query.bump_prefix(self) @@ -1061,8 +1020,8 @@ class Query(object): Solve the lookup type from the lookup (eg: 'foobar__id__icontains') """ lookup_splitted = lookup.split(LOOKUP_SEP) - if self._aggregates: - aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) + if self._annotations: + aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations) if aggregate: return aggregate_lookups, (), aggregate _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) @@ -1232,7 +1191,11 @@ class Query(object): lookup_type = lookups[-1] else: assert(len(targets) == 1) - col = Col(alias, targets[0], field) + if hasattr(targets[0], 'as_sql'): + # handle Expressions as annotations + col = targets[0] + else: + col = Col(alias, targets[0], field) condition = self.build_lookup(lookups, col, value) if not condition: # Backwards compat for custom lookups @@ -1278,12 +1241,12 @@ class Query(object): Returns whether or not all elements of this q_object need to be put together in the HAVING clause. """ - if not self._aggregates: + if not self._annotations: return False if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] - or (hasattr(obj[1], 'contains_aggregate') - and obj[1].contains_aggregate(self.aggregates))) + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0] + or (hasattr(obj[1], 'refs_aggregate') + and obj[1].refs_aggregate(self.annotations)[0])) return any(self.need_having(c) for c in obj.children) def split_having_parts(self, q_object, negated=False): @@ -1390,13 +1353,21 @@ class Query(object): if name == 'pk': name = opts.pk.name try: - field, model, direct, m2m = opts.get_field_by_name(name) + field, model, _, _ = opts.get_field_by_name(name) except FieldDoesNotExist: + # is it an annotation? + if self._annotations and name in self._annotations: + field, model = self._annotations[name], None + if not field.contains_aggregate: + # Local non-relational field. + final_field = field + targets = (field,) + break # We didn't find the current field, so move position back # one step. pos -= 1 if pos == -1 or fail_on_missing: - available = opts.get_all_field_names() + list(self.aggregate_select) + available = opts.get_all_field_names() + list(self.annotation_select) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(available))) break @@ -1445,6 +1416,11 @@ class Query(object): break return path, final_field, targets, names[pos + 1:] + def raise_field_error(self, opts, name): + available = opts.get_all_field_names() + list(self.annotation_select) + raise FieldError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(available))) + def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ Compute the necessary table joins for the passage through the fields @@ -1519,6 +1495,29 @@ class Query(object): self.unref_alias(joins.pop()) return targets, joins[-1], joins + def resolve_ref(self, name, allow_joins, reuse, summarize): + if not allow_joins and LOOKUP_SEP in name: + raise FieldError("Joined field references are not permitted in this query") + if name in self.annotations: + if summarize: + return Ref(name, self.annotation_select[name]) + else: + return self.annotation_select[name] + else: + field_list = name.split(LOOKUP_SEP) + field, sources, opts, join_list, path = self.setup_joins( + field_list, self.get_meta(), + self.get_initial_alias(), reuse) + targets, _, join_list = self.trim_joins(sources, join_list, path) + if len(targets) > 1: + raise FieldError("Referencing multicolumn fields with F() objects " + "isn't supported") + if reuse is not None: + reuse.update(join_list) + col = Col(join_list[-1], targets[0], sources[0]) + col._used_joins = join_list + return col + def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): """ When doing an exclude against any kind of N-to-many relation, we need @@ -1633,7 +1632,7 @@ class Query(object): self.default_cols = False self.select_related = False self.set_extra_mask(()) - self.set_aggregate_mask(()) + self.set_annotation_mask(()) def clear_select_fields(self): """ @@ -1676,7 +1675,7 @@ class Query(object): raise else: names = sorted(opts.get_all_field_names() + list(self.extra) - + list(self.aggregate_select)) + + list(self.annotation_select)) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) self.remove_inherited_models() @@ -1725,39 +1724,55 @@ class Query(object): for col, _ in self.select: self.group_by.append(col) + if self._annotations: + for alias, annotation in six.iteritems(self.annotations): + for col in annotation.get_group_by_cols(): + self.group_by.append(col) + def add_count_column(self): """ Converts the query to do count(...) or count(distinct(pk)) in order to get its size. """ + summarize = False if not self.distinct: if not self.select: - count = self.aggregates_module.Count('*', is_summary=True) + count = Count('*') + summarize = True else: assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select': %r" % self.select - count = self.aggregates_module.Count(self.select[0].col) + col = self.select[0].col + if isinstance(col, (tuple, list)): + count = Count(col[1]) + else: + count = Count(col) + else: opts = self.get_meta() if not self.select: - count = self.aggregates_module.Count( - (self.join((None, opts.db_table, None)), opts.pk.column), - is_summary=True, distinct=True) + lookup = self.join((None, opts.db_table, None)), opts.pk.column + count = Count(lookup[1], distinct=True) + summarize = True else: # Because of SQL portability issues, multi-column, distinct # counts need a sub-query -- see get_count() for details. assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select'." - - count = self.aggregates_module.Count(self.select[0].col, distinct=True) + col = self.select[0].col + if isinstance(col, (tuple, list)): + count = Count(col[1], distinct=True) + else: + count = Count(col, distinct=True) # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False # Set only aggregate to be the count column. - # Clear out the select cache to reflect the new unmasked aggregates. - self._aggregates = {None: count} - self.set_aggregate_mask(None) + # Clear out the select cache to reflect the new unmasked annotations. + count = count.resolve_expression(self, summarize=summarize) + self._annotations = {None: count} + self.set_annotation_mask(None) self.group_by = None def add_select_related(self, fields): @@ -1886,16 +1901,28 @@ class Query(object): target[model] = set(f.name for f in fields) def set_aggregate_mask(self, names): - "Set the mask of aggregates that will actually be returned by the SELECT" + warnings.warn( + "set_aggregate_mask() is deprecated. Use set_annotation_mask() instead.", + RemovedInDjango20Warning, stacklevel=2) + self.set_annotation_mask(names) + + def set_annotation_mask(self, names): + "Set the mask of annotations that will actually be returned by the SELECT" if names is None: - self.aggregate_select_mask = None + self.annotation_select_mask = None else: - self.aggregate_select_mask = set(names) - self._aggregate_select_cache = None + self.annotation_select_mask = set(names) + self._annotation_select_cache = None def append_aggregate_mask(self, names): - if self.aggregate_select_mask is not None: - self.set_aggregate_mask(set(names).union(self.aggregate_select_mask)) + warnings.warn( + "append_aggregate_mask() is deprecated. Use append_annotation_mask() instead.", + RemovedInDjango20Warning, stacklevel=2) + self.append_annotation_mask(names) + + def append_annotation_mask(self, names): + if self.annotation_select_mask is not None: + self.set_annotation_mask(set(names).union(self.annotation_select_mask)) def set_extra_mask(self, names): """ @@ -1910,24 +1937,31 @@ class Query(object): self._extra_select_cache = None @property - def aggregate_select(self): + def annotation_select(self): """The OrderedDict of aggregate columns that are not masked, and should be used in the SELECT clause. This result is cached for optimization purposes. """ - if self._aggregate_select_cache is not None: - return self._aggregate_select_cache - elif not self._aggregates: + if self._annotation_select_cache is not None: + return self._annotation_select_cache + elif not self._annotations: return {} - elif self.aggregate_select_mask is not None: - self._aggregate_select_cache = OrderedDict( - (k, v) for k, v in self.aggregates.items() - if k in self.aggregate_select_mask + elif self.annotation_select_mask is not None: + self._annotation_select_cache = OrderedDict( + (k, v) for k, v in self.annotations.items() + if k in self.annotation_select_mask ) - return self._aggregate_select_cache + return self._annotation_select_cache else: - return self.aggregates + return self.annotations + + @property + def aggregate_select(self): + warnings.warn( + "aggregate_select() is deprecated. Use annotation_select() instead.", + RemovedInDjango20Warning, stacklevel=2) + return self.annotation_select @property def extra_select(self): diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 2f0de5b80c..6f3f7358d3 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -7,9 +7,9 @@ from django.core.exceptions import FieldError from django.db import connections from django.db.models.query_utils import Q from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Date, DateTime, Col from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo -from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.utils import six from django.utils import timezone @@ -229,7 +229,7 @@ class DateQuery(Query): )) self._check_field(field) # overridden in DateTimeQuery alias = joins[-1] - select = self._get_select((alias, field.column), lookup_type) + select = self._get_select(Col(alias, field), lookup_type) self.clear_select_clause() self.select = [SelectInfo(select, None)] self.distinct = True diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index f65e593a3a..13815cb68c 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -10,7 +10,6 @@ import warnings from django.conf import settings from django.db.models.fields import DateTimeField, Field from django.db.models.sql.datastructures import EmptyResultSet, Empty -from django.db.models.sql.aggregates import Aggregate from django.utils.deprecation import RemovedInDjango19Warning from django.utils.six.moves import xrange from django.utils import timezone @@ -78,7 +77,7 @@ class WhereNode(tree.Node): else: value_annotation = bool(value) - if hasattr(obj, "prepare"): + if hasattr(obj, 'prepare'): value = obj.prepare(lookup_type, value) return (obj, lookup_type, value_annotation, value) @@ -187,11 +186,9 @@ class WhereNode(tree.Node): lvalue, params = lvalue.process(lookup_type, params_or_value, connection) except EmptyShortCircuit: raise EmptyResultSet - elif isinstance(lvalue, Aggregate): - params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection) else: - raise TypeError("'make_atom' expects a Constraint or an Aggregate " - "as the first item of its 'child' argument.") + raise TypeError("'make_atom' expects a Constraint as the first " + "item of its 'child' argument.") if isinstance(lvalue, tuple): # A direct database column lookup. |
