diff options
Diffstat (limited to 'django/db/models/sql/query.py')
| -rw-r--r-- | django/db/models/sql/query.py | 307 |
1 files changed, 239 insertions, 68 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b12912461f..156617f807 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -12,12 +12,13 @@ from copy import deepcopy from django.utils.tree import Node from django.utils.datastructures import SortedDict from django.utils.encoding import force_unicode +from django.db.backends.util import truncate_name from django.db import connection from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend +from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR -from django.db.models.sql.datastructures import Count from django.core.exceptions import FieldError from datastructures import EmptyResultSet, Empty, MultiJoin from constants import * @@ -40,6 +41,7 @@ class BaseQuery(object): alias_prefix = 'T' query_terms = QUERY_TERMS + aggregates_module = base_aggregates_module def __init__(self, model, connection, where=WhereNode): self.model = model @@ -73,6 +75,9 @@ class BaseQuery(object): self.select_related = False self.related_select_cols = [] + # SQL aggregate-related attributes + self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function + # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). self.max_depth = 5 @@ -178,6 +183,7 @@ class BaseQuery(object): obj.distinct = self.distinct obj.select_related = self.select_related obj.related_select_cols = [] + obj.aggregate_select = self.aggregate_select.copy() obj.max_depth = self.max_depth obj.extra_select = self.extra_select.copy() obj.extra_tables = self.extra_tables @@ -194,6 +200,35 @@ class BaseQuery(object): obj._setup_query() return obj + def convert_values(self, value, field): + """Convert the database-returned value into a type that is consistent + across database backends. + + By default, this defers to the underlying backend operations, but + it can be overridden by Query classes for specific backends. + """ + return self.connection.ops.convert_values(value, field) + + def resolve_aggregate(self, value, aggregate): + """Resolve the value of aggregates returned by the database to + consistent (and reasonable) types. + + This is required because of the predisposition of certain backends + to return Decimal and long types when they are not needed. + """ + if value is None: + # Return None as-is + return value + elif aggregate.is_ordinal: + # Any ordinal aggregate (e.g., count) returns an int + return int(value) + elif aggregate.is_computed: + # Any computed aggregate (e.g., avg) returns a float + return float(value) + else: + # Return value depends on the type of the field being processed. + return self.convert_values(value, aggregate.field) + def results_iter(self): """ Returns an iterator over the results from executing this query. @@ -212,29 +247,78 @@ class BaseQuery(object): else: fields = self.model._meta.fields row = self.resolve_columns(row, fields) + + if self.aggregate_select: + aggregate_start = len(self.extra_select.keys()) + len(self.select) + row = tuple(row[:aggregate_start]) + tuple([ + self.resolve_aggregate(value, aggregate) + for (alias, aggregate), value + in zip(self.aggregate_select.items(), row[aggregate_start:]) + ]) + yield row + def get_aggregation(self): + """ + Returns the dictionary with the values of the existing aggregations. + """ + if not self.aggregate_select: + return {} + + # If there is a group by clause, aggregating does not add useful + # information but retrieves only the first row. Aggregate + # over the subquery instead. + if self.group_by: + from subqueries import AggregateQuery + query = AggregateQuery(self.model, self.connection) + + obj = self.clone() + + # Remove any aggregates marked for reduction from the subquery + # and move them to the outer AggregateQuery. + for alias, aggregate in self.aggregate_select.items(): + if aggregate.is_summary: + query.aggregate_select[alias] = aggregate + del obj.aggregate_select[alias] + + query.add_subquery(obj) + else: + query = self + self.select = [] + self.default_cols = False + self.extra_select = {} + + query.clear_ordering(True) + query.clear_limits() + query.select_related = False + query.related_select_cols = [] + query.related_select_fields = [] + + return dict([ + (alias, self.resolve_aggregate(val, aggregate)) + for (alias, aggregate), val + in zip(query.aggregate_select.items(), query.execute_sql(SINGLE)) + ]) + def get_count(self): """ Performs a COUNT() query using the current filter constraints. """ - from subqueries import CountQuery obj = self.clone() - obj.clear_ordering(True) - obj.clear_limits() - obj.select_related = False - obj.related_select_cols = [] - obj.related_select_fields = [] - if len(obj.select) > 1: - obj = self.clone(CountQuery, _query=obj, where=self.where_class(), - distinct=False) - obj.select = [] - obj.extra_select = SortedDict() + if len(self.select) > 1: + # If a select clause exists, then the query has already started to + # specify the columns that are to be returned. + # In this case, we need to use a subquery to evaluate the count. + from subqueries import AggregateQuery + subquery = obj + subquery.clear_ordering(True) + subquery.clear_limits() + + obj = AggregateQuery(obj.model, obj.connection) + obj.add_subquery(subquery) + obj.add_count_column() - data = obj.execute_sql(SINGLE) - if not data: - return 0 - number = data[0] + number = obj.get_aggregation()[None] # Apply offset and limit constraints manually, since using LIMIT/OFFSET # in SQL (in variants that provide them) doesn't change the COUNT @@ -450,25 +534,41 @@ class BaseQuery(object): for col in self.select: if isinstance(col, (list, tuple)): r = '%s.%s' % (qn(col[0]), qn(col[1])) - if with_aliases and col[1] in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) + if with_aliases: + if col[1] in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append('%s AS %s' % (r, col[1])) + aliases.add(r) + col_aliases.add(col[1]) else: result.append(r) aliases.add(r) col_aliases.add(col[1]) else: result.append(col.as_sql(quote_func=qn)) + if hasattr(col, 'alias'): aliases.add(col.alias) col_aliases.add(col.alias) + elif self.default_cols: cols, new_aliases = self.get_default_columns(with_aliases, col_aliases) result.extend(cols) aliases.update(new_aliases) + + result.extend([ + '%s%s' % ( + aggregate.as_sql(quote_func=qn), + alias is not None and ' AS %s' % qn(alias) or '' + ) + for alias, aggregate in self.aggregate_select.items() + ]) + for table, col in self.related_select_cols: r = '%s.%s' % (qn(table), qn(col)) if with_aliases and col in col_aliases: @@ -538,7 +638,7 @@ class BaseQuery(object): Returns a list of strings that are joined together to go after the "FROM" part of the query, as well as a list any extra parameters that need to be included. Sub-classes, can override this to create a - from-clause via a "select", for example (e.g. CountQuery). + from-clause via a "select". This should only be called after any SQL construction methods that might change the tables we need. This means the select columns and @@ -635,10 +735,13 @@ class BaseQuery(object): order = asc result.append('%s %s' % (field, order)) continue + col, order = get_order_dir(field, asc) + if col in self.aggregate_select: + result.append('%s %s' % (col, order)) + continue if '.' in field: # This came in through an extra(order_by=...) addition. Pass it # on verbatim. - col, order = get_order_dir(field, asc) table, col = col.split('.', 1) if (table, col) not in processed_pairs: elt = '%s.%s' % (qn(table), col) @@ -657,7 +760,6 @@ class BaseQuery(object): ordering_aliases.append(elt) result.append('%s %s' % (elt, order)) else: - col, order = get_order_dir(field, asc) elt = qn2(col) if distinct and col not in select_aliases: ordering_aliases.append(elt) @@ -1068,6 +1170,48 @@ class BaseQuery(object): self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, used, next, restricted, new_nullable, dupe_set, avoid) + def add_aggregate(self, aggregate, model, alias, is_summary): + """ + Adds a single aggregate expression to the Query + """ + opts = model._meta + field_list = aggregate.lookup.split(LOOKUP_SEP) + if (len(field_list) == 1 and + aggregate.lookup in self.aggregate_select.keys()): + # Aggregate is over an annotation + field_name = field_list[0] + col = field_name + source = self.aggregate_select[field_name] + elif (len(field_list) > 1 or + field_list[0] not in [i.name for i in opts.fields]): + field, source, opts, join_list, last, _ = self.setup_joins( + field_list, opts, self.get_initial_alias(), False) + + # Process the join chain to see if it can be trimmed + _, _, col, _, join_list = self.trim_joins(source, join_list, last, False) + + # If the aggregate references a model or field that requires a join, + # those joins must be LEFT OUTER - empty join rows must be returned + # in order for zeros to be returned for those aggregates. + for column_alias in join_list: + self.promote_alias(column_alias, unconditional=True) + + col = (join_list[-1], col) + else: + # Aggregate references a normal field + field_name = field_list[0] + source = opts.get_field(field_name) + if not (self.group_by and is_summary): + # Only use a column alias if this is a + # standalone aggregate, or an annotation + col = (opts.db_table, source.column) + else: + col = field_name + + # Add the aggregate to the query + alias = truncate_name(alias, self.connection.ops.max_name_length()) + aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) + def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, can_reuse=None, process_extras=True): """ @@ -1119,6 +1263,11 @@ class BaseQuery(object): elif callable(value): value = value() + for alias, aggregate in self.aggregate_select.items(): + if alias == parts[0]: + self.having.add((aggregate, lookup_type, value), AND) + return + opts = self.get_meta() alias = self.get_initial_alias() allow_many = trim or not negate @@ -1131,38 +1280,9 @@ class BaseQuery(object): self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), can_reuse) return - final = len(join_list) - penultimate = last.pop() - if penultimate == final: - penultimate = last.pop() - if trim and len(join_list) > 1: - extra = join_list[penultimate:] - join_list = join_list[:penultimate] - final = penultimate - penultimate = last.pop() - col = self.alias_map[extra[0]][LHS_JOIN_COL] - for alias in extra: - self.unref_alias(alias) - else: - col = target.column - alias = join_list[-1] - while final > 1: - # An optimization: if the final join is against the same column as - # we are comparing against, we can go back one step in the join - # chain and compare against the lhs of the join instead (and then - # repeat the optimization). The result, potentially, involves less - # table joins. - join = self.alias_map[alias] - if col != join[RHS_JOIN_COL]: - break - self.unref_alias(alias) - alias = join[LHS_ALIAS] - col = join[LHS_JOIN_COL] - join_list = join_list[:-1] - final -= 1 - if final == penultimate: - penultimate = last.pop() + # Process the join chain to see if it can be trimmed + final, penultimate, col, alias, join_list = self.trim_joins(target, join_list, last, trim) if (lookup_type == 'isnull' and value is True and not negate and final > 1): @@ -1313,7 +1433,7 @@ class BaseQuery(object): field, model, direct, m2m = opts.get_field_by_name(f.name) break else: - names = opts.get_all_field_names() + names = opts.get_all_field_names() + self.aggregate_select.keys() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) @@ -1462,6 +1582,43 @@ class BaseQuery(object): return field, target, opts, joins, last, extra_filters + def trim_joins(self, target, join_list, last, trim): + """An optimization: if the final join is against the same column as + we are comparing against, we can go back one step in a join + chain and compare against the LHS of the join instead (and then + repeat the optimization). The result, potentially, involves less + table joins. + + Returns a tuple + """ + final = len(join_list) + penultimate = last.pop() + if penultimate == final: + penultimate = last.pop() + if trim and len(join_list) > 1: + extra = join_list[penultimate:] + join_list = join_list[:penultimate] + final = penultimate + penultimate = last.pop() + col = self.alias_map[extra[0]][LHS_JOIN_COL] + for alias in extra: + self.unref_alias(alias) + else: + col = target.column + alias = join_list[-1] + while final > 1: + join = self.alias_map[alias] + if col != join[RHS_JOIN_COL]: + break + self.unref_alias(alias) + alias = join[LHS_ALIAS] + col = join[LHS_JOIN_COL] + join_list = join_list[:-1] + final -= 1 + if final == penultimate: + penultimate = last.pop() + return final, penultimate, col, alias, join_list + def update_dupe_avoidance(self, opts, col, alias): """ For a column that is one of multiple pointing to the same table, update @@ -1554,6 +1711,7 @@ class BaseQuery(object): """ alias = self.get_initial_alias() opts = self.get_meta() + try: for name in field_names: field, target, u2, joins, u3, u4 = self.setup_joins( @@ -1574,7 +1732,7 @@ class BaseQuery(object): except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: - names = opts.get_all_field_names() + self.extra_select.keys() + names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys() names.sort() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) @@ -1609,38 +1767,52 @@ class BaseQuery(object): if force_empty: self.default_ordering = False + def set_group_by(self): + """ + Expands the GROUP BY clause required by the query. + + This will usually be the set of all non-aggregate fields in the + return data. If the database backend supports grouping by the + primary key, and the query would be equivalent, the optimization + will be made automatically. + """ + if self.connection.features.allows_group_by_pk: + if len(self.select) == len(self.model._meta.fields): + self.group_by.append('.'.join([self.model._meta.db_table, + self.model._meta.pk.column])) + return + + for sel in self.select: + self.group_by.append(sel) + def add_count_column(self): """ Converts the query to do count(...) or count(distinct(pk)) in order to get its size. """ - # TODO: When group_by support is added, this needs to be adjusted so - # that it doesn't totally overwrite the select list. if not self.distinct: if not self.select: - select = Count() + count = self.aggregates_module.Count('*', is_summary=True) else: assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select': %r" % self.select - select = Count(self.select[0]) + count = self.aggregates_module.Count(self.select[0]) else: opts = self.model._meta if not self.select: - select = Count((self.join((None, opts.db_table, None, None)), - opts.pk.column), True) + count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column), + is_summary=True, distinct=True) else: # Because of SQL portability issues, multi-column, distinct # counts need a sub-query -- see get_count() for details. assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select'." - select = Count(self.select[0], True) + count = self.aggregates_module.Count(self.select[0], distinct=True) # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False - self.select = [select] - self.select_fields = [None] - self.extra_select = {} + self.aggregate_select = {None: count} def add_select_related(self, fields): """ @@ -1758,7 +1930,6 @@ class BaseQuery(object): return empty_iter() else: return - cursor = self.connection.cursor() cursor.execute(sql, params) |
