summaryrefslogtreecommitdiff
path: root/django/db/models/sql
diff options
context:
space:
mode:
authorRussell Keith-Magee <russell@keith-magee.com>2009-01-15 11:06:34 +0000
committerRussell Keith-Magee <russell@keith-magee.com>2009-01-15 11:06:34 +0000
commitcc4e4d9aee0b3ebfb45bee01aec79edc9e144c78 (patch)
tree2cdba846a105d406ecceff2c02e071c50502d487 /django/db/models/sql
parent50a293a0c31e7325ebd520338f9c8881f951d8a7 (diff)
Fixed #3566 -- Added support for aggregation to the ORM. See the documentation for details on usage.
Many thanks to: * Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code. * Alex Gaynor for his help debugging and fixing a number of issues. * Justin Bronn for his help integrating with contrib.gis. * Karen Tracey for her help with cross-platform testing. * Ian Kelly for his help testing and fixing Oracle support. * Malcolm Tredinnick for his invaluable review notes. git-svn-id: http://code.djangoproject.com/svn/django/trunk@9742 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql')
-rw-r--r--django/db/models/sql/aggregates.py130
-rw-r--r--django/db/models/sql/datastructures.py53
-rw-r--r--django/db/models/sql/query.py307
-rw-r--r--django/db/models/sql/subqueries.py30
4 files changed, 389 insertions, 131 deletions
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
new file mode 100644
index 0000000000..6fdaf188c4
--- /dev/null
+++ b/django/db/models/sql/aggregates.py
@@ -0,0 +1,130 @@
+"""
+Classes to represent the default SQL aggregate functions
+"""
+
+class AggregateField(object):
+ """An internal field mockup used to identify aggregates in the
+ data-conversion parts of the database backend.
+ """
+ def __init__(self, internal_type):
+ self.internal_type = internal_type
+ def get_internal_type(self):
+ return self.internal_type
+
+ordinal_aggregate_field = AggregateField('IntegerField')
+computed_aggregate_field = AggregateField('FloatField')
+
+class Aggregate(object):
+ """
+ Default SQL Aggregate.
+ """
+ is_ordinal = False
+ is_computed = False
+ sql_template = '%(function)s(%(field)s)'
+
+ def __init__(self, col, source=None, is_summary=False, **extra):
+ """Instantiate an SQL aggregate
+
+ * col is a column reference describing the subject field
+ of the aggregate. It can be an alias, or a tuple describing
+ a table and column name.
+ * source is the underlying field or aggregate definition for
+ the column reference. If the aggregate is not an ordinal or
+ computed type, this reference is used to determine the coerced
+ output type of the aggregate.
+ * extra is a dictionary of additional data to provide for the
+ aggregate definition
+
+ Also utilizes the class variables:
+ * sql_function, the name of the SQL function that implements the
+ aggregate.
+ * sql_template, a template string that is used to render the
+ aggregate into SQL.
+ * is_ordinal, a boolean indicating if the output of this aggregate
+ is an integer (e.g., a count)
+ * is_computed, a boolean indicating if the output of this aggregate
+ is a computed float (e.g., an average), regardless of the input
+ type.
+
+ """
+ self.col = col
+ self.source = source
+ self.is_summary = is_summary
+ self.extra = extra
+
+ # Follow the chain of aggregate sources back until you find an
+ # actual field, or an aggregate that forces a particular output
+ # type. This type of this field will be used to coerce values
+ # retrieved from the database.
+ tmp = self
+
+ while tmp and isinstance(tmp, Aggregate):
+ if getattr(tmp, 'is_ordinal', False):
+ tmp = ordinal_aggregate_field
+ elif getattr(tmp, 'is_computed', False):
+ tmp = computed_aggregate_field
+ else:
+ tmp = tmp.source
+
+ self.field = tmp
+
+ def relabel_aliases(self, change_map):
+ if isinstance(self.col, (list, tuple)):
+ self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
+
+ def as_sql(self, quote_func=None):
+ "Return the aggregate, rendered as SQL."
+ if not quote_func:
+ quote_func = lambda x: x
+
+ if hasattr(self.col, 'as_sql'):
+ field_name = self.col.as_sql(quote_func)
+ elif isinstance(self.col, (list, tuple)):
+ field_name = '.'.join([quote_func(c) for c in self.col])
+ else:
+ field_name = self.col
+
+ params = {
+ 'function': self.sql_function,
+ 'field': field_name
+ }
+ params.update(self.extra)
+
+ return self.sql_template % params
+
+
+class Avg(Aggregate):
+ is_computed = True
+ sql_function = 'AVG'
+
+class Count(Aggregate):
+ is_ordinal = True
+ sql_function = 'COUNT'
+ sql_template = '%(function)s(%(distinct)s%(field)s)'
+
+ def __init__(self, col, distinct=False, **extra):
+ super(Count, self).__init__(col, distinct=distinct and 'DISTINCT ' or '', **extra)
+
+class Max(Aggregate):
+ sql_function = 'MAX'
+
+class Min(Aggregate):
+ sql_function = 'MIN'
+
+class StdDev(Aggregate):
+ is_computed = True
+
+ def __init__(self, col, sample=False, **extra):
+ super(StdDev, self).__init__(col, **extra)
+ self.sql_function = sample and 'STDDEV_SAMP' or 'STDDEV_POP'
+
+class Sum(Aggregate):
+ sql_function = 'SUM'
+
+class Variance(Aggregate):
+ is_computed = True
+
+ def __init__(self, col, sample=False, **extra):
+ super(Variance, self).__init__(col, **extra)
+ self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP'
+
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index 913d8fde25..4d53999c79 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -25,59 +25,6 @@ class RawValue(object):
def __init__(self, value):
self.value = value
-class Aggregate(object):
- """
- Base class for all aggregate-related classes (min, max, avg, count, sum).
- """
- def relabel_aliases(self, change_map):
- """
- Relabel the column alias, if necessary. Must be implemented by
- subclasses.
- """
- raise NotImplementedError
-
- def as_sql(self, quote_func=None):
- """
- Returns the SQL string fragment for this object.
-
- The quote_func function is used to quote the column components. If
- None, it defaults to doing nothing.
-
- Must be implemented by subclasses.
- """
- raise NotImplementedError
-
-class Count(Aggregate):
- """
- Perform a count on the given column.
- """
- def __init__(self, col='*', distinct=False):
- """
- Set the column to count on (defaults to '*') and set whether the count
- should be distinct or not.
- """
- self.col = col
- self.distinct = distinct
-
- def relabel_aliases(self, change_map):
- c = self.col
- if isinstance(c, (list, tuple)):
- self.col = (change_map.get(c[0], c[0]), c[1])
-
- def as_sql(self, quote_func=None):
- if not quote_func:
- quote_func = lambda x: x
- if isinstance(self.col, (list, tuple)):
- col = ('%s.%s' % tuple([quote_func(c) for c in self.col]))
- elif hasattr(self.col, 'as_sql'):
- col = self.col.as_sql(quote_func)
- else:
- col = self.col
- if self.distinct:
- return 'COUNT(DISTINCT %s)' % col
- else:
- return 'COUNT(%s)' % col
-
class Date(object):
"""
Add a date selection column.
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index b12912461f..156617f807 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -12,12 +12,13 @@ from copy import deepcopy
from django.utils.tree import Node
from django.utils.datastructures import SortedDict
from django.utils.encoding import force_unicode
+from django.db.backends.util import truncate_name
from django.db import connection
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import select_related_descend
+from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR
-from django.db.models.sql.datastructures import Count
from django.core.exceptions import FieldError
from datastructures import EmptyResultSet, Empty, MultiJoin
from constants import *
@@ -40,6 +41,7 @@ class BaseQuery(object):
alias_prefix = 'T'
query_terms = QUERY_TERMS
+ aggregates_module = base_aggregates_module
def __init__(self, model, connection, where=WhereNode):
self.model = model
@@ -73,6 +75,9 @@ class BaseQuery(object):
self.select_related = False
self.related_select_cols = []
+ # SQL aggregate-related attributes
+ self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
+
# Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related().
self.max_depth = 5
@@ -178,6 +183,7 @@ class BaseQuery(object):
obj.distinct = self.distinct
obj.select_related = self.select_related
obj.related_select_cols = []
+ obj.aggregate_select = self.aggregate_select.copy()
obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy()
obj.extra_tables = self.extra_tables
@@ -194,6 +200,35 @@ class BaseQuery(object):
obj._setup_query()
return obj
+ def convert_values(self, value, field):
+ """Convert the database-returned value into a type that is consistent
+ across database backends.
+
+ By default, this defers to the underlying backend operations, but
+ it can be overridden by Query classes for specific backends.
+ """
+ return self.connection.ops.convert_values(value, field)
+
+ def resolve_aggregate(self, value, aggregate):
+ """Resolve the value of aggregates returned by the database to
+ consistent (and reasonable) types.
+
+ This is required because of the predisposition of certain backends
+ to return Decimal and long types when they are not needed.
+ """
+ if value is None:
+ # Return None as-is
+ return value
+ elif aggregate.is_ordinal:
+ # Any ordinal aggregate (e.g., count) returns an int
+ return int(value)
+ elif aggregate.is_computed:
+ # Any computed aggregate (e.g., avg) returns a float
+ return float(value)
+ else:
+ # Return value depends on the type of the field being processed.
+ return self.convert_values(value, aggregate.field)
+
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
@@ -212,29 +247,78 @@ class BaseQuery(object):
else:
fields = self.model._meta.fields
row = self.resolve_columns(row, fields)
+
+ if self.aggregate_select:
+ aggregate_start = len(self.extra_select.keys()) + len(self.select)
+ row = tuple(row[:aggregate_start]) + tuple([
+ self.resolve_aggregate(value, aggregate)
+ for (alias, aggregate), value
+ in zip(self.aggregate_select.items(), row[aggregate_start:])
+ ])
+
yield row
+ def get_aggregation(self):
+ """
+ Returns the dictionary with the values of the existing aggregations.
+ """
+ if not self.aggregate_select:
+ return {}
+
+ # If there is a group by clause, aggregating does not add useful
+ # information but retrieves only the first row. Aggregate
+ # over the subquery instead.
+ if self.group_by:
+ from subqueries import AggregateQuery
+ query = AggregateQuery(self.model, self.connection)
+
+ obj = self.clone()
+
+ # Remove any aggregates marked for reduction from the subquery
+ # and move them to the outer AggregateQuery.
+ for alias, aggregate in self.aggregate_select.items():
+ if aggregate.is_summary:
+ query.aggregate_select[alias] = aggregate
+ del obj.aggregate_select[alias]
+
+ query.add_subquery(obj)
+ else:
+ query = self
+ self.select = []
+ self.default_cols = False
+ self.extra_select = {}
+
+ query.clear_ordering(True)
+ query.clear_limits()
+ query.select_related = False
+ query.related_select_cols = []
+ query.related_select_fields = []
+
+ return dict([
+ (alias, self.resolve_aggregate(val, aggregate))
+ for (alias, aggregate), val
+ in zip(query.aggregate_select.items(), query.execute_sql(SINGLE))
+ ])
+
def get_count(self):
"""
Performs a COUNT() query using the current filter constraints.
"""
- from subqueries import CountQuery
obj = self.clone()
- obj.clear_ordering(True)
- obj.clear_limits()
- obj.select_related = False
- obj.related_select_cols = []
- obj.related_select_fields = []
- if len(obj.select) > 1:
- obj = self.clone(CountQuery, _query=obj, where=self.where_class(),
- distinct=False)
- obj.select = []
- obj.extra_select = SortedDict()
+ if len(self.select) > 1:
+ # If a select clause exists, then the query has already started to
+ # specify the columns that are to be returned.
+ # In this case, we need to use a subquery to evaluate the count.
+ from subqueries import AggregateQuery
+ subquery = obj
+ subquery.clear_ordering(True)
+ subquery.clear_limits()
+
+ obj = AggregateQuery(obj.model, obj.connection)
+ obj.add_subquery(subquery)
+
obj.add_count_column()
- data = obj.execute_sql(SINGLE)
- if not data:
- return 0
- number = data[0]
+ number = obj.get_aggregation()[None]
# Apply offset and limit constraints manually, since using LIMIT/OFFSET
# in SQL (in variants that provide them) doesn't change the COUNT
@@ -450,25 +534,41 @@ class BaseQuery(object):
for col in self.select:
if isinstance(col, (list, tuple)):
r = '%s.%s' % (qn(col[0]), qn(col[1]))
- if with_aliases and col[1] in col_aliases:
- c_alias = 'Col%d' % len(col_aliases)
- result.append('%s AS %s' % (r, c_alias))
- aliases.add(c_alias)
- col_aliases.add(c_alias)
+ if with_aliases:
+ if col[1] in col_aliases:
+ c_alias = 'Col%d' % len(col_aliases)
+ result.append('%s AS %s' % (r, c_alias))
+ aliases.add(c_alias)
+ col_aliases.add(c_alias)
+ else:
+ result.append('%s AS %s' % (r, col[1]))
+ aliases.add(r)
+ col_aliases.add(col[1])
else:
result.append(r)
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(quote_func=qn))
+
if hasattr(col, 'alias'):
aliases.add(col.alias)
col_aliases.add(col.alias)
+
elif self.default_cols:
cols, new_aliases = self.get_default_columns(with_aliases,
col_aliases)
result.extend(cols)
aliases.update(new_aliases)
+
+ result.extend([
+ '%s%s' % (
+ aggregate.as_sql(quote_func=qn),
+ alias is not None and ' AS %s' % qn(alias) or ''
+ )
+ for alias, aggregate in self.aggregate_select.items()
+ ])
+
for table, col in self.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
if with_aliases and col in col_aliases:
@@ -538,7 +638,7 @@ class BaseQuery(object):
Returns a list of strings that are joined together to go after the
"FROM" part of the query, as well as a list any extra parameters that
need to be included. Sub-classes can override this to create a
- from-clause via a "select", for example (e.g. CountQuery).
+ from-clause via a "select".
This should only be called after any SQL construction methods that
might change the tables we need. This means the select columns and
@@ -635,10 +735,13 @@ class BaseQuery(object):
order = asc
result.append('%s %s' % (field, order))
continue
+ col, order = get_order_dir(field, asc)
+ if col in self.aggregate_select:
+ result.append('%s %s' % (col, order))
+ continue
if '.' in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
- col, order = get_order_dir(field, asc)
table, col = col.split('.', 1)
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), col)
@@ -657,7 +760,6 @@ class BaseQuery(object):
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
else:
- col, order = get_order_dir(field, asc)
elt = qn2(col)
if distinct and col not in select_aliases:
ordering_aliases.append(elt)
@@ -1068,6 +1170,48 @@ class BaseQuery(object):
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set, avoid)
+ def add_aggregate(self, aggregate, model, alias, is_summary):
+ """
+ Adds a single aggregate expression to the Query
+ """
+ opts = model._meta
+ field_list = aggregate.lookup.split(LOOKUP_SEP)
+ if (len(field_list) == 1 and
+ aggregate.lookup in self.aggregate_select.keys()):
+ # Aggregate is over an annotation
+ field_name = field_list[0]
+ col = field_name
+ source = self.aggregate_select[field_name]
+ elif (len(field_list) > 1 or
+ field_list[0] not in [i.name for i in opts.fields]):
+ field, source, opts, join_list, last, _ = self.setup_joins(
+ field_list, opts, self.get_initial_alias(), False)
+
+ # Process the join chain to see if it can be trimmed
+ _, _, col, _, join_list = self.trim_joins(source, join_list, last, False)
+
+ # If the aggregate references a model or field that requires a join,
+ # those joins must be LEFT OUTER - empty join rows must be returned
+ # in order for zeros to be returned for those aggregates.
+ for column_alias in join_list:
+ self.promote_alias(column_alias, unconditional=True)
+
+ col = (join_list[-1], col)
+ else:
+ # Aggregate references a normal field
+ field_name = field_list[0]
+ source = opts.get_field(field_name)
+ if not (self.group_by and is_summary):
+ # Only use a column alias if this is a
+ # standalone aggregate, or an annotation
+ col = (opts.db_table, source.column)
+ else:
+ col = field_name
+
+ # Add the aggregate to the query
+ alias = truncate_name(alias, self.connection.ops.max_name_length())
+ aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
+
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse=None, process_extras=True):
"""
@@ -1119,6 +1263,11 @@ class BaseQuery(object):
elif callable(value):
value = value()
+ for alias, aggregate in self.aggregate_select.items():
+ if alias == parts[0]:
+ self.having.add((aggregate, lookup_type, value), AND)
+ return
+
opts = self.get_meta()
alias = self.get_initial_alias()
allow_many = trim or not negate
@@ -1131,38 +1280,9 @@ class BaseQuery(object):
self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
can_reuse)
return
- final = len(join_list)
- penultimate = last.pop()
- if penultimate == final:
- penultimate = last.pop()
- if trim and len(join_list) > 1:
- extra = join_list[penultimate:]
- join_list = join_list[:penultimate]
- final = penultimate
- penultimate = last.pop()
- col = self.alias_map[extra[0]][LHS_JOIN_COL]
- for alias in extra:
- self.unref_alias(alias)
- else:
- col = target.column
- alias = join_list[-1]
- while final > 1:
- # An optimization: if the final join is against the same column as
- # we are comparing against, we can go back one step in the join
- # chain and compare against the lhs of the join instead (and then
- # repeat the optimization). The result, potentially, involves less
- # table joins.
- join = self.alias_map[alias]
- if col != join[RHS_JOIN_COL]:
- break
- self.unref_alias(alias)
- alias = join[LHS_ALIAS]
- col = join[LHS_JOIN_COL]
- join_list = join_list[:-1]
- final -= 1
- if final == penultimate:
- penultimate = last.pop()
+ # Process the join chain to see if it can be trimmed
+ final, penultimate, col, alias, join_list = self.trim_joins(target, join_list, last, trim)
if (lookup_type == 'isnull' and value is True and not negate and
final > 1):
@@ -1313,7 +1433,7 @@ class BaseQuery(object):
field, model, direct, m2m = opts.get_field_by_name(f.name)
break
else:
- names = opts.get_all_field_names()
+ names = opts.get_all_field_names() + self.aggregate_select.keys()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
@@ -1462,6 +1582,43 @@ class BaseQuery(object):
return field, target, opts, joins, last, extra_filters
+ def trim_joins(self, target, join_list, last, trim):
+ """An optimization: if the final join is against the same column as
+ we are comparing against, we can go back one step in a join
+ chain and compare against the LHS of the join instead (and then
+ repeat the optimization). The result, potentially, involves less
+ table joins.
+
+ Returns a tuple
+ """
+ final = len(join_list)
+ penultimate = last.pop()
+ if penultimate == final:
+ penultimate = last.pop()
+ if trim and len(join_list) > 1:
+ extra = join_list[penultimate:]
+ join_list = join_list[:penultimate]
+ final = penultimate
+ penultimate = last.pop()
+ col = self.alias_map[extra[0]][LHS_JOIN_COL]
+ for alias in extra:
+ self.unref_alias(alias)
+ else:
+ col = target.column
+ alias = join_list[-1]
+ while final > 1:
+ join = self.alias_map[alias]
+ if col != join[RHS_JOIN_COL]:
+ break
+ self.unref_alias(alias)
+ alias = join[LHS_ALIAS]
+ col = join[LHS_JOIN_COL]
+ join_list = join_list[:-1]
+ final -= 1
+ if final == penultimate:
+ penultimate = last.pop()
+ return final, penultimate, col, alias, join_list
+
def update_dupe_avoidance(self, opts, col, alias):
"""
For a column that is one of multiple pointing to the same table, update
@@ -1554,6 +1711,7 @@ class BaseQuery(object):
"""
alias = self.get_initial_alias()
opts = self.get_meta()
+
try:
for name in field_names:
field, target, u2, joins, u3, u4 = self.setup_joins(
@@ -1574,7 +1732,7 @@ class BaseQuery(object):
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
- names = opts.get_all_field_names() + self.extra_select.keys()
+ names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
names.sort()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
@@ -1609,38 +1767,52 @@ class BaseQuery(object):
if force_empty:
self.default_ordering = False
+ def set_group_by(self):
+ """
+ Expands the GROUP BY clause required by the query.
+
+ This will usually be the set of all non-aggregate fields in the
+ return data. If the database backend supports grouping by the
+ primary key, and the query would be equivalent, the optimization
+ will be made automatically.
+ """
+ if self.connection.features.allows_group_by_pk:
+ if len(self.select) == len(self.model._meta.fields):
+ self.group_by.append('.'.join([self.model._meta.db_table,
+ self.model._meta.pk.column]))
+ return
+
+ for sel in self.select:
+ self.group_by.append(sel)
+
def add_count_column(self):
"""
Converts the query to do count(...) or count(distinct(pk)) in order to
get its size.
"""
- # TODO: When group_by support is added, this needs to be adjusted so
- # that it doesn't totally overwrite the select list.
if not self.distinct:
if not self.select:
- select = Count()
+ count = self.aggregates_module.Count('*', is_summary=True)
else:
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select': %r" % self.select
- select = Count(self.select[0])
+ count = self.aggregates_module.Count(self.select[0])
else:
opts = self.model._meta
if not self.select:
- select = Count((self.join((None, opts.db_table, None, None)),
- opts.pk.column), True)
+ count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column),
+ is_summary=True, distinct=True)
else:
# Because of SQL portability issues, multi-column, distinct
# counts need a sub-query -- see get_count() for details.
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select'."
- select = Count(self.select[0], True)
+ count = self.aggregates_module.Count(self.select[0], distinct=True)
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
- self.select = [select]
- self.select_fields = [None]
- self.extra_select = {}
+ self.aggregate_select = {None: count}
def add_select_related(self, fields):
"""
@@ -1758,7 +1930,6 @@ class BaseQuery(object):
return empty_iter()
else:
return
-
cursor = self.connection.cursor()
cursor.execute(sql, params)
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 524b0894c5..0a59b403c8 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -9,7 +9,7 @@ from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
- 'CountQuery']
+ 'AggregateQuery']
class DeleteQuery(Query):
"""
@@ -400,15 +400,25 @@ class DateQuery(Query):
self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1]
-class CountQuery(Query):
+class AggregateQuery(Query):
"""
- A CountQuery knows how to take a normal query which would select over
- multiple distinct columns and turn it into SQL that can be used on a
- variety of backends (it requires a select in the FROM clause).
+ An AggregateQuery takes another query as a parameter to the FROM
+ clause and only selects the elements in the provided list.
"""
- def get_from_clause(self):
- result, params = self._query.as_sql()
- return ['(%s) A1' % result], params
+ def add_subquery(self, query):
+ self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)
- def get_ordering(self):
- return ()
+ def as_sql(self, quote_func=None):
+ """
+ Creates the SQL for this query. Returns the SQL string and list of
+ parameters.
+ """
+ sql = ('SELECT %s FROM (%s) subquery' % (
+ ', '.join([
+ aggregate.as_sql()
+ for aggregate in self.aggregate_select.values()
+ ]),
+ self.subquery)
+ )
+ params = self.sub_params
+ return (sql, params)