diff options
Diffstat (limited to 'django/db/models/query.py')
| -rw-r--r-- | django/db/models/query.py | 89 |
1 files changed, 82 insertions, 7 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py index e2dcb1fa65..b5c9d2f25d 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -4,6 +4,7 @@ except NameError: from sets import Set as set # Python 2.3 fallback from django.db import connection, transaction, IntegrityError +from django.db.models.aggregates import Aggregate from django.db.models.fields import DateField from django.db.models.query_utils import Q, select_related_descend from django.db.models import signals, sql @@ -270,18 +271,47 @@ class QuerySet(object): else: requested = None max_depth = self.query.max_depth + extra_select = self.query.extra_select.keys() + aggregate_select = self.query.aggregate_select.keys() + index_start = len(extra_select) + aggregate_start = index_start + len(self.model._meta.fields) + for row in self.query.results_iter(): if fill_cache: - obj, _ = get_cached_row(self.model, row, index_start, - max_depth, requested=requested) + obj, aggregate_start = get_cached_row(self.model, row, + index_start, max_depth, requested=requested) else: - obj = self.model(*row[index_start:]) + # omit aggregates in object creation + obj = self.model(*row[index_start:aggregate_start]) + for i, k in enumerate(extra_select): setattr(obj, k, row[i]) + + # Add the aggregates to the model + for i, aggregate in enumerate(aggregate_select): + setattr(obj, aggregate, row[i+aggregate_start]) + yield obj + def aggregate(self, *args, **kwargs): + """ + Returns a dictionary containing the calculations (aggregation) + over the current queryset + + If args is present the expression is passed as a kwarg ussing + the Aggregate object's default alias. + """ + for arg in args: + kwargs[arg.default_alias] = arg + + for (alias, aggregate_expr) in kwargs.items(): + self.query.add_aggregate(aggregate_expr, self.model, alias, + is_summary=True) + + return self.query.get_aggregation() + def count(self): """ Performs a SELECT COUNT() and returns the number of records as an @@ -553,6 +583,25 @@ class QuerySet(object): """ self.query.select_related = other.query.select_related + def annotate(self, *args, **kwargs): + """ + Return a query set in which the returned objects have been annotated + with data aggregated from related fields. + """ + for arg in args: + kwargs[arg.default_alias] = arg + + obj = self._clone() + + obj._setup_aggregate_query() + + # Add the aggregates to the query + for (alias, aggregate_expr) in kwargs.items(): + obj.query.add_aggregate(aggregate_expr, self.model, alias, + is_summary=False) + + return obj + def order_by(self, *field_names): """ Returns a new QuerySet instance with the ordering changed. @@ -641,6 +690,16 @@ class QuerySet(object): """ pass + def _setup_aggregate_query(self): + """ + Prepare the query for computing a result that contains aggregate annotations. + """ + opts = self.model._meta + if not self.query.group_by: + field_names = [f.attname for f in opts.fields] + self.query.add_fields(field_names, False) + self.query.set_group_by() + def as_sql(self): """ Returns the internal query's SQL and parameters (as a tuple). @@ -669,6 +728,8 @@ class ValuesQuerySet(QuerySet): len(self.field_names) != len(self.model._meta.fields)): self.query.trim_extra_select(self.extra_names) names = self.query.extra_select.keys() + self.field_names + names.extend(self.query.aggregate_select.keys()) + for row in self.query.results_iter(): yield dict(zip(names, row)) @@ -682,20 +743,25 @@ class ValuesQuerySet(QuerySet): """ self.query.clear_select_fields() self.extra_names = [] + self.aggregate_names = [] + if self._fields: - if not self.query.extra_select: + if not self.query.extra_select and not self.query.aggregate_select: field_names = list(self._fields) else: field_names = [] for f in self._fields: if self.query.extra_select.has_key(f): self.extra_names.append(f) + elif self.query.aggregate_select.has_key(f): + self.aggregate_names.append(f) else: field_names.append(f) else: # Default to all fields. field_names = [f.attname for f in self.model._meta.fields] + self.query.select = [] self.query.add_fields(field_names, False) self.query.default_cols = False self.field_names = field_names @@ -711,6 +777,7 @@ class ValuesQuerySet(QuerySet): c._fields = self._fields[:] c.field_names = self.field_names c.extra_names = self.extra_names + c.aggregate_names = self.aggregate_names if setup and hasattr(c, '_setup_query'): c._setup_query() return c @@ -718,10 +785,18 @@ class ValuesQuerySet(QuerySet): def _merge_sanity_check(self, other): super(ValuesQuerySet, self)._merge_sanity_check(other) if (set(self.extra_names) != set(other.extra_names) or - set(self.field_names) != set(other.field_names)): + set(self.field_names) != set(other.field_names) or + self.aggregate_names != other.aggregate_names): raise TypeError("Merging '%s' classes must involve the same values in each case." % self.__class__.__name__) + def _setup_aggregate_query(self): + """ + Prepare the query for computing a result that contains aggregate annotations. + """ + self.query.set_group_by() + + super(ValuesQuerySet, self)._setup_aggregate_query() class ValuesListQuerySet(ValuesQuerySet): def iterator(self): @@ -729,14 +804,14 @@ class ValuesListQuerySet(ValuesQuerySet): if self.flat and len(self._fields) == 1: for row in self.query.results_iter(): yield row[0] - elif not self.query.extra_select: + elif not self.query.extra_select and not self.query.aggregate_select: for row in self.query.results_iter(): yield tuple(row) else: # When extra(select=...) is involved, the extra cols come are # always at the start of the row, so we need to reorder the fields # to match the order in self._fields. - names = self.query.extra_select.keys() + self.field_names + names = self.query.extra_select.keys() + self.field_names + self.query.aggregate_select.keys() for row in self.query.results_iter(): data = dict(zip(names, row)) yield tuple([data[f] for f in self._fields]) |
