diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-02-23 14:47:59 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-02-23 14:47:59 +0000 |
| commit | 542709d0d1796326dd1edacf32fc1198cfad2869 (patch) | |
| tree | 40578c8972862606d7ddb0dad9c7e0e163b160e0 /django/db/models/sql | |
| parent | 4bd24474c02a6f3c70e8111ac262fabf2fc5f454 (diff) | |
Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.
This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/query.py | 70 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 12 |
2 files changed, 63 insertions, 19 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 629afa29e7..fbc5467b3c 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -77,7 +77,9 @@ class BaseQuery(object): self.related_select_cols = [] # SQL aggregate-related attributes - self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function + self.aggregates = SortedDict() # Maps alias -> SQL aggregate function + self.aggregate_select_mask = None + self._aggregate_select_cache = None # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). @@ -187,7 +189,15 @@ class BaseQuery(object): obj.distinct = self.distinct obj.select_related = self.select_related obj.related_select_cols = [] - obj.aggregate_select = self.aggregate_select.copy() + obj.aggregates = self.aggregates.copy() + if self.aggregate_select_mask is None: + obj.aggregate_select_mask = None + else: + obj.aggregate_select_mask = self.aggregate_select_mask[:] + if self._aggregate_select_cache is None: + obj._aggregate_select_cache = None + else: + obj._aggregate_select_cache = self._aggregate_select_cache.copy() obj.max_depth = self.max_depth obj.extra_select = self.extra_select.copy() obj.extra_tables = self.extra_tables @@ -940,14 +950,17 @@ class BaseQuery(object): """ assert set(change_map.keys()).intersection(set(change_map.values())) == set() - # 1. Update references in "select" and "where". + # 1. Update references in "select" (normal columns plus aliases), + # "group by", "where" and "having". self.where.relabel_aliases(change_map) - for pos, col in enumerate(self.select): - if isinstance(col, (list, tuple)): - old_alias = col[0] - self.select[pos] = (change_map.get(old_alias, old_alias), col[1]) - else: - col.relabel_aliases(change_map) + self.having.relabel_aliases(change_map) + for columns in (self.select, self.aggregates.values(), self.group_by or []): + for pos, col in enumerate(columns): + if isinstance(col, (list, tuple)): + old_alias = col[0] + columns[pos] = (change_map.get(old_alias, old_alias), col[1]) + else: + col.relabel_aliases(change_map) # 2. Rename the alias in the internal table/alias datastructures. for old_alias, new_alias in change_map.iteritems(): @@ -1205,11 +1218,11 @@ class BaseQuery(object): opts = model._meta field_list = aggregate.lookup.split(LOOKUP_SEP) if (len(field_list) == 1 and - aggregate.lookup in self.aggregate_select.keys()): + aggregate.lookup in self.aggregates.keys()): # Aggregate is over an annotation field_name = field_list[0] col = field_name - source = self.aggregate_select[field_name] + source = self.aggregates[field_name] elif (len(field_list) > 1 or field_list[0] not in [i.name for i in opts.fields]): field, source, opts, join_list, last, _ = self.setup_joins( @@ -1299,7 +1312,7 @@ class BaseQuery(object): value = SQLEvaluator(value, self) having_clause = value.contains_aggregate - for alias, aggregate in self.aggregate_select.items(): + for alias, aggregate in self.aggregates.items(): if alias == parts[0]: entry = self.where_class() entry.add((aggregate, lookup_type, value), AND) @@ -1824,8 +1837,8 @@ class BaseQuery(object): self.group_by = [] if self.connection.features.allows_group_by_pk: if len(self.select) == len(self.model._meta.fields): - self.group_by.append('.'.join([self.model._meta.db_table, - self.model._meta.pk.column])) + self.group_by.append((self.model._meta.db_table, + self.model._meta.pk.column)) return for sel in self.select: @@ -1858,7 +1871,11 @@ class BaseQuery(object): # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False - self.aggregate_select = {None: count} + + # Set only aggregate to be the count column. + # Clear out the select cache to reflect the new unmasked aggregates. + self.aggregates = {None: count} + self.set_aggregate_mask(None) def add_select_related(self, fields): """ @@ -1920,6 +1937,29 @@ class BaseQuery(object): for key in set(self.extra_select).difference(set(names)): del self.extra_select[key] + def set_aggregate_mask(self, names): + "Set the mask of aggregates that will actually be returned by the SELECT" + self.aggregate_select_mask = names + self._aggregate_select_cache = None + + def _aggregate_select(self): + """The SortedDict of aggregate columns that are not masked, and should + be used in the SELECT clause. + + This result is cached for optimization purposes. + """ + if self._aggregate_select_cache is not None: + return self._aggregate_select_cache + elif self.aggregate_select_mask is not None: + self._aggregate_select_cache = SortedDict([ + (k,v) for k,v in self.aggregates.items() + if k in self.aggregate_select_mask + ]) + return self._aggregate_select_cache + else: + return self.aggregates + aggregate_select = property(_aggregate_select) + def set_start(self, start): """ Sets the table from which to start joining. The start position is diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 1d4df127fe..43ac42489a 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -213,10 +213,14 @@ class WhereNode(tree.Node): elif isinstance(child, tree.Node): self.relabel_aliases(change_map, child) else: - elt = list(child[0]) - if elt[0] in change_map: - elt[0] = change_map[elt[0]] - node.children[pos] = (tuple(elt),) + child[1:] + if isinstance(child[0], (list, tuple)): + elt = list(child[0]) + if elt[0] in change_map: + elt[0] = change_map[elt[0]] + node.children[pos] = (tuple(elt),) + child[1:] + else: + child[0].relabel_aliases(change_map) + # Check if the query value also requires relabelling if hasattr(child[3], 'relabel_aliases'): child[3].relabel_aliases(change_map) |
