summaryrefslogtreecommitdiff
path: root/django/db/models/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/query.py')
-rw-r--r--django/db/models/query.py89
1 files changed, 82 insertions, 7 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index e2dcb1fa65..b5c9d2f25d 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -4,6 +4,7 @@ except NameError:
from sets import Set as set # Python 2.3 fallback
from django.db import connection, transaction, IntegrityError
+from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField
from django.db.models.query_utils import Q, select_related_descend
from django.db.models import signals, sql
@@ -270,18 +271,47 @@ class QuerySet(object):
else:
requested = None
max_depth = self.query.max_depth
+
extra_select = self.query.extra_select.keys()
+ aggregate_select = self.query.aggregate_select.keys()
+
index_start = len(extra_select)
+ aggregate_start = index_start + len(self.model._meta.fields)
+
for row in self.query.results_iter():
if fill_cache:
- obj, _ = get_cached_row(self.model, row, index_start,
- max_depth, requested=requested)
+ obj, aggregate_start = get_cached_row(self.model, row,
+ index_start, max_depth, requested=requested)
else:
- obj = self.model(*row[index_start:])
+ # omit aggregates in object creation
+ obj = self.model(*row[index_start:aggregate_start])
+
for i, k in enumerate(extra_select):
setattr(obj, k, row[i])
+
+ # Add the aggregates to the model
+ for i, aggregate in enumerate(aggregate_select):
+ setattr(obj, aggregate, row[i+aggregate_start])
+
yield obj
+ def aggregate(self, *args, **kwargs):
+ """
+ Returns a dictionary containing the calculations (aggregation)
+ over the current queryset
+
+ If args is present the expression is passed as a kwarg ussing
+ the Aggregate object's default alias.
+ """
+ for arg in args:
+ kwargs[arg.default_alias] = arg
+
+ for (alias, aggregate_expr) in kwargs.items():
+ self.query.add_aggregate(aggregate_expr, self.model, alias,
+ is_summary=True)
+
+ return self.query.get_aggregation()
+
def count(self):
"""
Performs a SELECT COUNT() and returns the number of records as an
@@ -553,6 +583,25 @@ class QuerySet(object):
"""
self.query.select_related = other.query.select_related
+ def annotate(self, *args, **kwargs):
+ """
+ Return a query set in which the returned objects have been annotated
+ with data aggregated from related fields.
+ """
+ for arg in args:
+ kwargs[arg.default_alias] = arg
+
+ obj = self._clone()
+
+ obj._setup_aggregate_query()
+
+ # Add the aggregates to the query
+ for (alias, aggregate_expr) in kwargs.items():
+ obj.query.add_aggregate(aggregate_expr, self.model, alias,
+ is_summary=False)
+
+ return obj
+
def order_by(self, *field_names):
"""
Returns a new QuerySet instance with the ordering changed.
@@ -641,6 +690,16 @@ class QuerySet(object):
"""
pass
+ def _setup_aggregate_query(self):
+ """
+ Prepare the query for computing a result that contains aggregate annotations.
+ """
+ opts = self.model._meta
+ if not self.query.group_by:
+ field_names = [f.attname for f in opts.fields]
+ self.query.add_fields(field_names, False)
+ self.query.set_group_by()
+
def as_sql(self):
"""
Returns the internal query's SQL and parameters (as a tuple).
@@ -669,6 +728,8 @@ class ValuesQuerySet(QuerySet):
len(self.field_names) != len(self.model._meta.fields)):
self.query.trim_extra_select(self.extra_names)
names = self.query.extra_select.keys() + self.field_names
+ names.extend(self.query.aggregate_select.keys())
+
for row in self.query.results_iter():
yield dict(zip(names, row))
@@ -682,20 +743,25 @@ class ValuesQuerySet(QuerySet):
"""
self.query.clear_select_fields()
self.extra_names = []
+ self.aggregate_names = []
+
if self._fields:
- if not self.query.extra_select:
+ if not self.query.extra_select and not self.query.aggregate_select:
field_names = list(self._fields)
else:
field_names = []
for f in self._fields:
if self.query.extra_select.has_key(f):
self.extra_names.append(f)
+ elif self.query.aggregate_select.has_key(f):
+ self.aggregate_names.append(f)
else:
field_names.append(f)
else:
# Default to all fields.
field_names = [f.attname for f in self.model._meta.fields]
+ self.query.select = []
self.query.add_fields(field_names, False)
self.query.default_cols = False
self.field_names = field_names
@@ -711,6 +777,7 @@ class ValuesQuerySet(QuerySet):
c._fields = self._fields[:]
c.field_names = self.field_names
c.extra_names = self.extra_names
+ c.aggregate_names = self.aggregate_names
if setup and hasattr(c, '_setup_query'):
c._setup_query()
return c
@@ -718,10 +785,18 @@ class ValuesQuerySet(QuerySet):
def _merge_sanity_check(self, other):
super(ValuesQuerySet, self)._merge_sanity_check(other)
if (set(self.extra_names) != set(other.extra_names) or
- set(self.field_names) != set(other.field_names)):
+ set(self.field_names) != set(other.field_names) or
+ self.aggregate_names != other.aggregate_names):
raise TypeError("Merging '%s' classes must involve the same values in each case."
% self.__class__.__name__)
+ def _setup_aggregate_query(self):
+ """
+ Prepare the query for computing a result that contains aggregate annotations.
+ """
+ self.query.set_group_by()
+
+ super(ValuesQuerySet, self)._setup_aggregate_query()
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
@@ -729,14 +804,14 @@ class ValuesListQuerySet(ValuesQuerySet):
if self.flat and len(self._fields) == 1:
for row in self.query.results_iter():
yield row[0]
- elif not self.query.extra_select:
+ elif not self.query.extra_select and not self.query.aggregate_select:
for row in self.query.results_iter():
yield tuple(row)
else:
# When extra(select=...) is involved, the extra cols come are
# always at the start of the row, so we need to reorder the fields
# to match the order in self._fields.
- names = self.query.extra_select.keys() + self.field_names
+ names = self.query.extra_select.keys() + self.field_names + self.query.aggregate_select.keys()
for row in self.query.results_iter():
data = dict(zip(names, row))
yield tuple([data[f] for f in self._fields])