diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-01-15 11:06:34 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-01-15 11:06:34 +0000 |
| commit | cc4e4d9aee0b3ebfb45bee01aec79edc9e144c78 (patch) | |
| tree | 2cdba846a105d406ecceff2c02e071c50502d487 /django/db/models | |
| parent | 50a293a0c31e7325ebd520338f9c8881f951d8a7 (diff) | |
Fixed #3566 -- Added support for aggregation to the ORM. See the documentation for details on usage.
Many thanks to:
* Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code.
* Alex Gaynor for his help debugging and fixing a number of issues.
* Justin Bronn for his help integrating with contrib.gis.
* Karen Tracey for her help with cross-platform testing.
* Ian Kelly for his help testing and fixing Oracle support.
* Malcolm Tredinnick for his invaluable review notes.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@9742 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | django/db/models/__init__.py | 1 |
| -rw-r--r-- | django/db/models/aggregates.py | 66 |
| -rw-r--r-- | django/db/models/manager.py | 6 |
| -rw-r--r-- | django/db/models/query.py | 89 |
| -rw-r--r-- | django/db/models/query_utils.py | 1 |
| -rw-r--r-- | django/db/models/sql/aggregates.py | 130 |
| -rw-r--r-- | django/db/models/sql/datastructures.py | 53 |
| -rw-r--r-- | django/db/models/sql/query.py | 307 |
| -rw-r--r-- | django/db/models/sql/subqueries.py | 30 |
9 files changed, 544 insertions, 139 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 5413133306..0802f8695e 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -5,6 +5,7 @@ from django.db.models.loading import get_apps, get_app, get_models, get_model, r from django.db.models.query import Q from django.db.models.manager import Manager from django.db.models.base import Model +from django.db.models.aggregates import * from django.db.models.fields import * from django.db.models.fields.subclassing import SubfieldBase from django.db.models.fields.files import FileField, ImageField diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py new file mode 100644 index 0000000000..1f676253b5 --- /dev/null +++ b/django/db/models/aggregates.py @@ -0,0 +1,66 @@ +""" +Classes to represent the definitions of aggregate functions. +""" + +class Aggregate(object): + """ + Default Aggregate definition. + """ + def __init__(self, lookup, **extra): + """Instantiate a new aggregate. + + * lookup is the field on which the aggregate operates. + * extra is a dictionary of additional data to provide for the + aggregate definition + + Also utilizes the class variables: + * name, the identifier for this aggregate function. + """ + self.lookup = lookup + self.extra = extra + + def _default_alias(self): + return '%s__%s' % (self.lookup, self.name.lower()) + default_alias = property(_default_alias) + + def add_to_query(self, query, alias, col, source, is_summary): + """Add the aggregate to the nominated query. + + This method is used to convert the generic Aggregate definition into a + backend-specific definition. + + * query is the backend-specific query instance to which the aggregate + is to be added. + * col is a column reference describing the subject field + of the aggregate. It can be an alias, or a tuple describing + a table and column name. + * source is the underlying field or aggregate definition for + the column reference. 
If the aggregate is not an ordinal or + computed type, this reference is used to determine the coerced + output type of the aggregate. + * is_summary is a boolean that is set True if the aggregate is a + summary value rather than an annotation. + """ + aggregate = getattr(query.aggregates_module, self.name) + query.aggregate_select[alias] = aggregate(col, source=source, is_summary=is_summary, **self.extra) + +class Avg(Aggregate): + name = 'Avg' + +class Count(Aggregate): + name = 'Count' + +class Max(Aggregate): + name = 'Max' + +class Min(Aggregate): + name = 'Min' + +class StdDev(Aggregate): + name = 'StdDev' + +class Sum(Aggregate): + name = 'Sum' + +class Variance(Aggregate): + name = 'Variance' diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 683a8f8d10..4e8c6e94fb 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -101,6 +101,12 @@ class Manager(object): def filter(self, *args, **kwargs): return self.get_query_set().filter(*args, **kwargs) + def aggregate(self, *args, **kwargs): + return self.get_query_set().aggregate(*args, **kwargs) + + def annotate(self, *args, **kwargs): + return self.get_query_set().annotate(*args, **kwargs) + def complex_filter(self, *args, **kwargs): return self.get_query_set().complex_filter(*args, **kwargs) diff --git a/django/db/models/query.py b/django/db/models/query.py index e2dcb1fa65..b5c9d2f25d 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -4,6 +4,7 @@ except NameError: from sets import Set as set # Python 2.3 fallback from django.db import connection, transaction, IntegrityError +from django.db.models.aggregates import Aggregate from django.db.models.fields import DateField from django.db.models.query_utils import Q, select_related_descend from django.db.models import signals, sql @@ -270,18 +271,47 @@ class QuerySet(object): else: requested = None max_depth = self.query.max_depth + extra_select = self.query.extra_select.keys() + 
aggregate_select = self.query.aggregate_select.keys() + index_start = len(extra_select) + aggregate_start = index_start + len(self.model._meta.fields) + for row in self.query.results_iter(): if fill_cache: - obj, _ = get_cached_row(self.model, row, index_start, - max_depth, requested=requested) + obj, aggregate_start = get_cached_row(self.model, row, + index_start, max_depth, requested=requested) else: - obj = self.model(*row[index_start:]) + # omit aggregates in object creation + obj = self.model(*row[index_start:aggregate_start]) + for i, k in enumerate(extra_select): setattr(obj, k, row[i]) + + # Add the aggregates to the model + for i, aggregate in enumerate(aggregate_select): + setattr(obj, aggregate, row[i+aggregate_start]) + yield obj + def aggregate(self, *args, **kwargs): + """ + Returns a dictionary containing the calculations (aggregation) + over the current queryset + + If args is present the expression is passed as a kwarg ussing + the Aggregate object's default alias. + """ + for arg in args: + kwargs[arg.default_alias] = arg + + for (alias, aggregate_expr) in kwargs.items(): + self.query.add_aggregate(aggregate_expr, self.model, alias, + is_summary=True) + + return self.query.get_aggregation() + def count(self): """ Performs a SELECT COUNT() and returns the number of records as an @@ -553,6 +583,25 @@ class QuerySet(object): """ self.query.select_related = other.query.select_related + def annotate(self, *args, **kwargs): + """ + Return a query set in which the returned objects have been annotated + with data aggregated from related fields. + """ + for arg in args: + kwargs[arg.default_alias] = arg + + obj = self._clone() + + obj._setup_aggregate_query() + + # Add the aggregates to the query + for (alias, aggregate_expr) in kwargs.items(): + obj.query.add_aggregate(aggregate_expr, self.model, alias, + is_summary=False) + + return obj + def order_by(self, *field_names): """ Returns a new QuerySet instance with the ordering changed. 
@@ -641,6 +690,16 @@ class QuerySet(object): """ pass + def _setup_aggregate_query(self): + """ + Prepare the query for computing a result that contains aggregate annotations. + """ + opts = self.model._meta + if not self.query.group_by: + field_names = [f.attname for f in opts.fields] + self.query.add_fields(field_names, False) + self.query.set_group_by() + def as_sql(self): """ Returns the internal query's SQL and parameters (as a tuple). @@ -669,6 +728,8 @@ class ValuesQuerySet(QuerySet): len(self.field_names) != len(self.model._meta.fields)): self.query.trim_extra_select(self.extra_names) names = self.query.extra_select.keys() + self.field_names + names.extend(self.query.aggregate_select.keys()) + for row in self.query.results_iter(): yield dict(zip(names, row)) @@ -682,20 +743,25 @@ class ValuesQuerySet(QuerySet): """ self.query.clear_select_fields() self.extra_names = [] + self.aggregate_names = [] + if self._fields: - if not self.query.extra_select: + if not self.query.extra_select and not self.query.aggregate_select: field_names = list(self._fields) else: field_names = [] for f in self._fields: if self.query.extra_select.has_key(f): self.extra_names.append(f) + elif self.query.aggregate_select.has_key(f): + self.aggregate_names.append(f) else: field_names.append(f) else: # Default to all fields. 
field_names = [f.attname for f in self.model._meta.fields] + self.query.select = [] self.query.add_fields(field_names, False) self.query.default_cols = False self.field_names = field_names @@ -711,6 +777,7 @@ class ValuesQuerySet(QuerySet): c._fields = self._fields[:] c.field_names = self.field_names c.extra_names = self.extra_names + c.aggregate_names = self.aggregate_names if setup and hasattr(c, '_setup_query'): c._setup_query() return c @@ -718,10 +785,18 @@ class ValuesQuerySet(QuerySet): def _merge_sanity_check(self, other): super(ValuesQuerySet, self)._merge_sanity_check(other) if (set(self.extra_names) != set(other.extra_names) or - set(self.field_names) != set(other.field_names)): + set(self.field_names) != set(other.field_names) or + self.aggregate_names != other.aggregate_names): raise TypeError("Merging '%s' classes must involve the same values in each case." % self.__class__.__name__) + def _setup_aggregate_query(self): + """ + Prepare the query for computing a result that contains aggregate annotations. + """ + self.query.set_group_by() + + super(ValuesQuerySet, self)._setup_aggregate_query() class ValuesListQuerySet(ValuesQuerySet): def iterator(self): @@ -729,14 +804,14 @@ class ValuesListQuerySet(ValuesQuerySet): if self.flat and len(self._fields) == 1: for row in self.query.results_iter(): yield row[0] - elif not self.query.extra_select: + elif not self.query.extra_select and not self.query.aggregate_select: for row in self.query.results_iter(): yield tuple(row) else: # When extra(select=...) is involved, the extra cols come are # always at the start of the row, so we need to reorder the fields # to match the order in self._fields. 
- names = self.query.extra_select.keys() + self.field_names + names = self.query.extra_select.keys() + self.field_names + self.query.aggregate_select.keys() for row in self.query.results_iter(): data = dict(zip(names, row)) yield tuple([data[f] for f in self._fields]) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 8dbb1ec667..9463283f25 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -64,4 +64,3 @@ def select_related_descend(field, restricted, requested): if not restricted and field.null: return False return True - diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py new file mode 100644 index 0000000000..6fdaf188c4 --- /dev/null +++ b/django/db/models/sql/aggregates.py @@ -0,0 +1,130 @@ +""" +Classes to represent the default SQL aggregate functions +""" + +class AggregateField(object): + """An internal field mockup used to identify aggregates in the + data-conversion parts of the database backend. + """ + def __init__(self, internal_type): + self.internal_type = internal_type + def get_internal_type(self): + return self.internal_type + +ordinal_aggregate_field = AggregateField('IntegerField') +computed_aggregate_field = AggregateField('FloatField') + +class Aggregate(object): + """ + Default SQL Aggregate. + """ + is_ordinal = False + is_computed = False + sql_template = '%(function)s(%(field)s)' + + def __init__(self, col, source=None, is_summary=False, **extra): + """Instantiate an SQL aggregate + + * col is a column reference describing the subject field + of the aggregate. It can be an alias, or a tuple describing + a table and column name. + * source is the underlying field or aggregate definition for + the column reference. If the aggregate is not an ordinal or + computed type, this reference is used to determine the coerced + output type of the aggregate. 
+ * extra is a dictionary of additional data to provide for the + aggregate definition + + Also utilizes the class variables: + * sql_function, the name of the SQL function that implements the + aggregate. + * sql_template, a template string that is used to render the + aggregate into SQL. + * is_ordinal, a boolean indicating if the output of this aggregate + is an integer (e.g., a count) + * is_computed, a boolean indicating if this output of this aggregate + is a computed float (e.g., an average), regardless of the input + type. + + """ + self.col = col + self.source = source + self.is_summary = is_summary + self.extra = extra + + # Follow the chain of aggregate sources back until you find an + # actual field, or an aggregate that forces a particular output + # type. This type of this field will be used to coerce values + # retrieved from the database. + tmp = self + + while tmp and isinstance(tmp, Aggregate): + if getattr(tmp, 'is_ordinal', False): + tmp = ordinal_aggregate_field + elif getattr(tmp, 'is_computed', False): + tmp = computed_aggregate_field + else: + tmp = tmp.source + + self.field = tmp + + def relabel_aliases(self, change_map): + if isinstance(self.col, (list, tuple)): + self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) + + def as_sql(self, quote_func=None): + "Return the aggregate, rendered as SQL." 
+ if not quote_func: + quote_func = lambda x: x + + if hasattr(self.col, 'as_sql'): + field_name = self.col.as_sql(quote_func) + elif isinstance(self.col, (list, tuple)): + field_name = '.'.join([quote_func(c) for c in self.col]) + else: + field_name = self.col + + params = { + 'function': self.sql_function, + 'field': field_name + } + params.update(self.extra) + + return self.sql_template % params + + +class Avg(Aggregate): + is_computed = True + sql_function = 'AVG' + +class Count(Aggregate): + is_ordinal = True + sql_function = 'COUNT' + sql_template = '%(function)s(%(distinct)s%(field)s)' + + def __init__(self, col, distinct=False, **extra): + super(Count, self).__init__(col, distinct=distinct and 'DISTINCT ' or '', **extra) + +class Max(Aggregate): + sql_function = 'MAX' + +class Min(Aggregate): + sql_function = 'MIN' + +class StdDev(Aggregate): + is_computed = True + + def __init__(self, col, sample=False, **extra): + super(StdDev, self).__init__(col, **extra) + self.sql_function = sample and 'STDDEV_SAMP' or 'STDDEV_POP' + +class Sum(Aggregate): + sql_function = 'SUM' + +class Variance(Aggregate): + is_computed = True + + def __init__(self, col, sample=False, **extra): + super(Variance, self).__init__(col, **extra) + self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP' + diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 913d8fde25..4d53999c79 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -25,59 +25,6 @@ class RawValue(object): def __init__(self, value): self.value = value -class Aggregate(object): - """ - Base class for all aggregate-related classes (min, max, avg, count, sum). - """ - def relabel_aliases(self, change_map): - """ - Relabel the column alias, if necessary. Must be implemented by - subclasses. - """ - raise NotImplementedError - - def as_sql(self, quote_func=None): - """ - Returns the SQL string fragment for this object. 
- - The quote_func function is used to quote the column components. If - None, it defaults to doing nothing. - - Must be implemented by subclasses. - """ - raise NotImplementedError - -class Count(Aggregate): - """ - Perform a count on the given column. - """ - def __init__(self, col='*', distinct=False): - """ - Set the column to count on (defaults to '*') and set whether the count - should be distinct or not. - """ - self.col = col - self.distinct = distinct - - def relabel_aliases(self, change_map): - c = self.col - if isinstance(c, (list, tuple)): - self.col = (change_map.get(c[0], c[0]), c[1]) - - def as_sql(self, quote_func=None): - if not quote_func: - quote_func = lambda x: x - if isinstance(self.col, (list, tuple)): - col = ('%s.%s' % tuple([quote_func(c) for c in self.col])) - elif hasattr(self.col, 'as_sql'): - col = self.col.as_sql(quote_func) - else: - col = self.col - if self.distinct: - return 'COUNT(DISTINCT %s)' % col - else: - return 'COUNT(%s)' % col - class Date(object): """ Add a date selection column. 
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b12912461f..156617f807 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -12,12 +12,13 @@ from copy import deepcopy from django.utils.tree import Node from django.utils.datastructures import SortedDict from django.utils.encoding import force_unicode +from django.db.backends.util import truncate_name from django.db import connection from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend +from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR -from django.db.models.sql.datastructures import Count from django.core.exceptions import FieldError from datastructures import EmptyResultSet, Empty, MultiJoin from constants import * @@ -40,6 +41,7 @@ class BaseQuery(object): alias_prefix = 'T' query_terms = QUERY_TERMS + aggregates_module = base_aggregates_module def __init__(self, model, connection, where=WhereNode): self.model = model @@ -73,6 +75,9 @@ class BaseQuery(object): self.select_related = False self.related_select_cols = [] + # SQL aggregate-related attributes + self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function + # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). 
self.max_depth = 5 @@ -178,6 +183,7 @@ class BaseQuery(object): obj.distinct = self.distinct obj.select_related = self.select_related obj.related_select_cols = [] + obj.aggregate_select = self.aggregate_select.copy() obj.max_depth = self.max_depth obj.extra_select = self.extra_select.copy() obj.extra_tables = self.extra_tables @@ -194,6 +200,35 @@ class BaseQuery(object): obj._setup_query() return obj + def convert_values(self, value, field): + """Convert the database-returned value into a type that is consistent + across database backends. + + By default, this defers to the underlying backend operations, but + it can be overridden by Query classes for specific backends. + """ + return self.connection.ops.convert_values(value, field) + + def resolve_aggregate(self, value, aggregate): + """Resolve the value of aggregates returned by the database to + consistent (and reasonable) types. + + This is required because of the predisposition of certain backends + to return Decimal and long types when they are not needed. + """ + if value is None: + # Return None as-is + return value + elif aggregate.is_ordinal: + # Any ordinal aggregate (e.g., count) returns an int + return int(value) + elif aggregate.is_computed: + # Any computed aggregate (e.g., avg) returns a float + return float(value) + else: + # Return value depends on the type of the field being processed. + return self.convert_values(value, aggregate.field) + def results_iter(self): """ Returns an iterator over the results from executing this query. 
@@ -212,29 +247,78 @@ class BaseQuery(object): else: fields = self.model._meta.fields row = self.resolve_columns(row, fields) + + if self.aggregate_select: + aggregate_start = len(self.extra_select.keys()) + len(self.select) + row = tuple(row[:aggregate_start]) + tuple([ + self.resolve_aggregate(value, aggregate) + for (alias, aggregate), value + in zip(self.aggregate_select.items(), row[aggregate_start:]) + ]) + yield row + def get_aggregation(self): + """ + Returns the dictionary with the values of the existing aggregations. + """ + if not self.aggregate_select: + return {} + + # If there is a group by clause, aggregating does not add useful + # information but retrieves only the first row. Aggregate + # over the subquery instead. + if self.group_by: + from subqueries import AggregateQuery + query = AggregateQuery(self.model, self.connection) + + obj = self.clone() + + # Remove any aggregates marked for reduction from the subquery + # and move them to the outer AggregateQuery. + for alias, aggregate in self.aggregate_select.items(): + if aggregate.is_summary: + query.aggregate_select[alias] = aggregate + del obj.aggregate_select[alias] + + query.add_subquery(obj) + else: + query = self + self.select = [] + self.default_cols = False + self.extra_select = {} + + query.clear_ordering(True) + query.clear_limits() + query.select_related = False + query.related_select_cols = [] + query.related_select_fields = [] + + return dict([ + (alias, self.resolve_aggregate(val, aggregate)) + for (alias, aggregate), val + in zip(query.aggregate_select.items(), query.execute_sql(SINGLE)) + ]) + def get_count(self): """ Performs a COUNT() query using the current filter constraints. 
""" - from subqueries import CountQuery obj = self.clone() - obj.clear_ordering(True) - obj.clear_limits() - obj.select_related = False - obj.related_select_cols = [] - obj.related_select_fields = [] - if len(obj.select) > 1: - obj = self.clone(CountQuery, _query=obj, where=self.where_class(), - distinct=False) - obj.select = [] - obj.extra_select = SortedDict() + if len(self.select) > 1: + # If a select clause exists, then the query has already started to + # specify the columns that are to be returned. + # In this case, we need to use a subquery to evaluate the count. + from subqueries import AggregateQuery + subquery = obj + subquery.clear_ordering(True) + subquery.clear_limits() + + obj = AggregateQuery(obj.model, obj.connection) + obj.add_subquery(subquery) + obj.add_count_column() - data = obj.execute_sql(SINGLE) - if not data: - return 0 - number = data[0] + number = obj.get_aggregation()[None] # Apply offset and limit constraints manually, since using LIMIT/OFFSET # in SQL (in variants that provide them) doesn't change the COUNT @@ -450,25 +534,41 @@ class BaseQuery(object): for col in self.select: if isinstance(col, (list, tuple)): r = '%s.%s' % (qn(col[0]), qn(col[1])) - if with_aliases and col[1] in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) + if with_aliases: + if col[1] in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append('%s AS %s' % (r, col[1])) + aliases.add(r) + col_aliases.add(col[1]) else: result.append(r) aliases.add(r) col_aliases.add(col[1]) else: result.append(col.as_sql(quote_func=qn)) + if hasattr(col, 'alias'): aliases.add(col.alias) col_aliases.add(col.alias) + elif self.default_cols: cols, new_aliases = self.get_default_columns(with_aliases, col_aliases) result.extend(cols) aliases.update(new_aliases) + + result.extend([ 
+ '%s%s' % ( + aggregate.as_sql(quote_func=qn), + alias is not None and ' AS %s' % qn(alias) or '' + ) + for alias, aggregate in self.aggregate_select.items() + ]) + for table, col in self.related_select_cols: r = '%s.%s' % (qn(table), qn(col)) if with_aliases and col in col_aliases: @@ -538,7 +638,7 @@ class BaseQuery(object): Returns a list of strings that are joined together to go after the "FROM" part of the query, as well as a list any extra parameters that need to be included. Sub-classes, can override this to create a - from-clause via a "select", for example (e.g. CountQuery). + from-clause via a "select". This should only be called after any SQL construction methods that might change the tables we need. This means the select columns and @@ -635,10 +735,13 @@ class BaseQuery(object): order = asc result.append('%s %s' % (field, order)) continue + col, order = get_order_dir(field, asc) + if col in self.aggregate_select: + result.append('%s %s' % (col, order)) + continue if '.' in field: # This came in through an extra(order_by=...) addition. Pass it # on verbatim. 
- col, order = get_order_dir(field, asc) table, col = col.split('.', 1) if (table, col) not in processed_pairs: elt = '%s.%s' % (qn(table), col) @@ -657,7 +760,6 @@ class BaseQuery(object): ordering_aliases.append(elt) result.append('%s %s' % (elt, order)) else: - col, order = get_order_dir(field, asc) elt = qn2(col) if distinct and col not in select_aliases: ordering_aliases.append(elt) @@ -1068,6 +1170,48 @@ class BaseQuery(object): self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, used, next, restricted, new_nullable, dupe_set, avoid) + def add_aggregate(self, aggregate, model, alias, is_summary): + """ + Adds a single aggregate expression to the Query + """ + opts = model._meta + field_list = aggregate.lookup.split(LOOKUP_SEP) + if (len(field_list) == 1 and + aggregate.lookup in self.aggregate_select.keys()): + # Aggregate is over an annotation + field_name = field_list[0] + col = field_name + source = self.aggregate_select[field_name] + elif (len(field_list) > 1 or + field_list[0] not in [i.name for i in opts.fields]): + field, source, opts, join_list, last, _ = self.setup_joins( + field_list, opts, self.get_initial_alias(), False) + + # Process the join chain to see if it can be trimmed + _, _, col, _, join_list = self.trim_joins(source, join_list, last, False) + + # If the aggregate references a model or field that requires a join, + # those joins must be LEFT OUTER - empty join rows must be returned + # in order for zeros to be returned for those aggregates. 
+ for column_alias in join_list: + self.promote_alias(column_alias, unconditional=True) + + col = (join_list[-1], col) + else: + # Aggregate references a normal field + field_name = field_list[0] + source = opts.get_field(field_name) + if not (self.group_by and is_summary): + # Only use a column alias if this is a + # standalone aggregate, or an annotation + col = (opts.db_table, source.column) + else: + col = field_name + + # Add the aggregate to the query + alias = truncate_name(alias, self.connection.ops.max_name_length()) + aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) + def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, can_reuse=None, process_extras=True): """ @@ -1119,6 +1263,11 @@ class BaseQuery(object): elif callable(value): value = value() + for alias, aggregate in self.aggregate_select.items(): + if alias == parts[0]: + self.having.add((aggregate, lookup_type, value), AND) + return + opts = self.get_meta() alias = self.get_initial_alias() allow_many = trim or not negate @@ -1131,38 +1280,9 @@ class BaseQuery(object): self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), can_reuse) return - final = len(join_list) - penultimate = last.pop() - if penultimate == final: - penultimate = last.pop() - if trim and len(join_list) > 1: - extra = join_list[penultimate:] - join_list = join_list[:penultimate] - final = penultimate - penultimate = last.pop() - col = self.alias_map[extra[0]][LHS_JOIN_COL] - for alias in extra: - self.unref_alias(alias) - else: - col = target.column - alias = join_list[-1] - while final > 1: - # An optimization: if the final join is against the same column as - # we are comparing against, we can go back one step in the join - # chain and compare against the lhs of the join instead (and then - # repeat the optimization). The result, potentially, involves less - # table joins. 
- join = self.alias_map[alias] - if col != join[RHS_JOIN_COL]: - break - self.unref_alias(alias) - alias = join[LHS_ALIAS] - col = join[LHS_JOIN_COL] - join_list = join_list[:-1] - final -= 1 - if final == penultimate: - penultimate = last.pop() + # Process the join chain to see if it can be trimmed + final, penultimate, col, alias, join_list = self.trim_joins(target, join_list, last, trim) if (lookup_type == 'isnull' and value is True and not negate and final > 1): @@ -1313,7 +1433,7 @@ class BaseQuery(object): field, model, direct, m2m = opts.get_field_by_name(f.name) break else: - names = opts.get_all_field_names() + names = opts.get_all_field_names() + self.aggregate_select.keys() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) @@ -1462,6 +1582,43 @@ class BaseQuery(object): return field, target, opts, joins, last, extra_filters + def trim_joins(self, target, join_list, last, trim): + """An optimization: if the final join is against the same column as + we are comparing against, we can go back one step in a join + chain and compare against the LHS of the join instead (and then + repeat the optimization). The result, potentially, involves less + table joins. 
+ + Returns a tuple + """ + final = len(join_list) + penultimate = last.pop() + if penultimate == final: + penultimate = last.pop() + if trim and len(join_list) > 1: + extra = join_list[penultimate:] + join_list = join_list[:penultimate] + final = penultimate + penultimate = last.pop() + col = self.alias_map[extra[0]][LHS_JOIN_COL] + for alias in extra: + self.unref_alias(alias) + else: + col = target.column + alias = join_list[-1] + while final > 1: + join = self.alias_map[alias] + if col != join[RHS_JOIN_COL]: + break + self.unref_alias(alias) + alias = join[LHS_ALIAS] + col = join[LHS_JOIN_COL] + join_list = join_list[:-1] + final -= 1 + if final == penultimate: + penultimate = last.pop() + return final, penultimate, col, alias, join_list + def update_dupe_avoidance(self, opts, col, alias): """ For a column that is one of multiple pointing to the same table, update @@ -1554,6 +1711,7 @@ class BaseQuery(object): """ alias = self.get_initial_alias() opts = self.get_meta() + try: for name in field_names: field, target, u2, joins, u3, u4 = self.setup_joins( @@ -1574,7 +1732,7 @@ class BaseQuery(object): except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: - names = opts.get_all_field_names() + self.extra_select.keys() + names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys() names.sort() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) @@ -1609,38 +1767,52 @@ class BaseQuery(object): if force_empty: self.default_ordering = False + def set_group_by(self): + """ + Expands the GROUP BY clause required by the query. + + This will usually be the set of all non-aggregate fields in the + return data. If the database backend supports grouping by the + primary key, and the query would be equivalent, the optimization + will be made automatically. 
+ """ + if self.connection.features.allows_group_by_pk: + if len(self.select) == len(self.model._meta.fields): + self.group_by.append('.'.join([self.model._meta.db_table, + self.model._meta.pk.column])) + return + + for sel in self.select: + self.group_by.append(sel) + def add_count_column(self): """ Converts the query to do count(...) or count(distinct(pk)) in order to get its size. """ - # TODO: When group_by support is added, this needs to be adjusted so - # that it doesn't totally overwrite the select list. if not self.distinct: if not self.select: - select = Count() + count = self.aggregates_module.Count('*', is_summary=True) else: assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select': %r" % self.select - select = Count(self.select[0]) + count = self.aggregates_module.Count(self.select[0]) else: opts = self.model._meta if not self.select: - select = Count((self.join((None, opts.db_table, None, None)), - opts.pk.column), True) + count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column), + is_summary=True, distinct=True) else: # Because of SQL portability issues, multi-column, distinct # counts need a sub-query -- see get_count() for details. assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select'." - select = Count(self.select[0], True) + count = self.aggregates_module.Count(self.select[0], distinct=True) # Distinct handling is done in Count(), so don't do it at this # level. 
self.distinct = False - self.select = [select] - self.select_fields = [None] - self.extra_select = {} + self.aggregate_select = {None: count} def add_select_related(self, fields): """ @@ -1758,7 +1930,6 @@ class BaseQuery(object): return empty_iter() else: return - cursor = self.connection.cursor() cursor.execute(sql, params) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 524b0894c5..0a59b403c8 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -9,7 +9,7 @@ from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', - 'CountQuery'] + 'AggregateQuery'] class DeleteQuery(Query): """ @@ -400,15 +400,25 @@ class DateQuery(Query): self.distinct = True self.order_by = order == 'ASC' and [1] or [-1] -class CountQuery(Query): +class AggregateQuery(Query): """ - A CountQuery knows how to take a normal query which would select over - multiple distinct columns and turn it into SQL that can be used on a - variety of backends (it requires a select in the FROM clause). + An AggregateQuery takes another query as a parameter to the FROM + clause and only selects the elements in the provided list. """ - def get_from_clause(self): - result, params = self._query.as_sql() - return ['(%s) A1' % result], params + def add_subquery(self, query): + self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) - def get_ordering(self): - return () + def as_sql(self, quote_func=None): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + sql = ('SELECT %s FROM (%s) subquery' % ( + ', '.join([ + aggregate.as_sql() + for aggregate in self.aggregate_select.values() + ]), + self.subquery) + ) + params = self.sub_params + return (sql, params) |
