summaryrefslogtreecommitdiff
path: root/django/db/models/sql
diff options
context:
space:
mode:
authorRussell Keith-Magee <russell@keith-magee.com>2009-02-23 14:47:59 +0000
committerRussell Keith-Magee <russell@keith-magee.com>2009-02-23 14:47:59 +0000
commit542709d0d1796326dd1edacf32fc1198cfad2869 (patch)
tree40578c8972862606d7ddb0dad9c7e0e163b160e0 /django/db/models/sql
parent4bd24474c02a6f3c70e8111ac262fabf2fc5f454 (diff)
Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.
This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details. git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql')
-rw-r--r--django/db/models/sql/query.py70
-rw-r--r--django/db/models/sql/where.py12
2 files changed, 63 insertions, 19 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 629afa29e7..fbc5467b3c 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -77,7 +77,9 @@ class BaseQuery(object):
self.related_select_cols = []
# SQL aggregate-related attributes
- self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
+ self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
+ self.aggregate_select_mask = None
+ self._aggregate_select_cache = None
# Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related().
@@ -187,7 +189,15 @@ class BaseQuery(object):
obj.distinct = self.distinct
obj.select_related = self.select_related
obj.related_select_cols = []
- obj.aggregate_select = self.aggregate_select.copy()
+ obj.aggregates = self.aggregates.copy()
+ if self.aggregate_select_mask is None:
+ obj.aggregate_select_mask = None
+ else:
+ obj.aggregate_select_mask = self.aggregate_select_mask[:]
+ if self._aggregate_select_cache is None:
+ obj._aggregate_select_cache = None
+ else:
+ obj._aggregate_select_cache = self._aggregate_select_cache.copy()
obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy()
obj.extra_tables = self.extra_tables
@@ -940,14 +950,17 @@ class BaseQuery(object):
"""
assert set(change_map.keys()).intersection(set(change_map.values())) == set()
- # 1. Update references in "select" and "where".
+ # 1. Update references in "select" (normal columns plus aliases),
+ # "group by", "where" and "having".
self.where.relabel_aliases(change_map)
- for pos, col in enumerate(self.select):
- if isinstance(col, (list, tuple)):
- old_alias = col[0]
- self.select[pos] = (change_map.get(old_alias, old_alias), col[1])
- else:
- col.relabel_aliases(change_map)
+ self.having.relabel_aliases(change_map)
+ for columns in (self.select, self.aggregates.values(), self.group_by or []):
+ for pos, col in enumerate(columns):
+ if isinstance(col, (list, tuple)):
+ old_alias = col[0]
+ columns[pos] = (change_map.get(old_alias, old_alias), col[1])
+ else:
+ col.relabel_aliases(change_map)
# 2. Rename the alias in the internal table/alias datastructures.
for old_alias, new_alias in change_map.iteritems():
@@ -1205,11 +1218,11 @@ class BaseQuery(object):
opts = model._meta
field_list = aggregate.lookup.split(LOOKUP_SEP)
if (len(field_list) == 1 and
- aggregate.lookup in self.aggregate_select.keys()):
+ aggregate.lookup in self.aggregates.keys()):
# Aggregate is over an annotation
field_name = field_list[0]
col = field_name
- source = self.aggregate_select[field_name]
+ source = self.aggregates[field_name]
elif (len(field_list) > 1 or
field_list[0] not in [i.name for i in opts.fields]):
field, source, opts, join_list, last, _ = self.setup_joins(
@@ -1299,7 +1312,7 @@ class BaseQuery(object):
value = SQLEvaluator(value, self)
having_clause = value.contains_aggregate
- for alias, aggregate in self.aggregate_select.items():
+ for alias, aggregate in self.aggregates.items():
if alias == parts[0]:
entry = self.where_class()
entry.add((aggregate, lookup_type, value), AND)
@@ -1824,8 +1837,8 @@ class BaseQuery(object):
self.group_by = []
if self.connection.features.allows_group_by_pk:
if len(self.select) == len(self.model._meta.fields):
- self.group_by.append('.'.join([self.model._meta.db_table,
- self.model._meta.pk.column]))
+ self.group_by.append((self.model._meta.db_table,
+ self.model._meta.pk.column))
return
for sel in self.select:
@@ -1858,7 +1871,11 @@ class BaseQuery(object):
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
- self.aggregate_select = {None: count}
+
+ # Set only aggregate to be the count column.
+ # Clear out the select cache to reflect the new unmasked aggregates.
+ self.aggregates = {None: count}
+ self.set_aggregate_mask(None)
def add_select_related(self, fields):
"""
@@ -1920,6 +1937,29 @@ class BaseQuery(object):
for key in set(self.extra_select).difference(set(names)):
del self.extra_select[key]
+ def set_aggregate_mask(self, names):
+ "Set the mask of aggregates that will actually be returned by the SELECT"
+ self.aggregate_select_mask = names
+ self._aggregate_select_cache = None
+
+ def _aggregate_select(self):
+ """The SortedDict of aggregate columns that are not masked, and should
+ be used in the SELECT clause.
+
+ This result is cached for optimization purposes.
+ """
+ if self._aggregate_select_cache is not None:
+ return self._aggregate_select_cache
+ elif self.aggregate_select_mask is not None:
+ self._aggregate_select_cache = SortedDict([
+ (k,v) for k,v in self.aggregates.items()
+ if k in self.aggregate_select_mask
+ ])
+ return self._aggregate_select_cache
+ else:
+ return self.aggregates
+ aggregate_select = property(_aggregate_select)
+
def set_start(self, start):
"""
Sets the table from which to start joining. The start position is
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 1d4df127fe..43ac42489a 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -213,10 +213,14 @@ class WhereNode(tree.Node):
elif isinstance(child, tree.Node):
self.relabel_aliases(change_map, child)
else:
- elt = list(child[0])
- if elt[0] in change_map:
- elt[0] = change_map[elt[0]]
- node.children[pos] = (tuple(elt),) + child[1:]
+ if isinstance(child[0], (list, tuple)):
+ elt = list(child[0])
+ if elt[0] in change_map:
+ elt[0] = change_map[elt[0]]
+ node.children[pos] = (tuple(elt),) + child[1:]
+ else:
+ child[0].relabel_aliases(change_map)
+
# Check if the query value also requires relabelling
if hasattr(child[3], 'relabel_aliases'):
child[3].relabel_aliases(change_map)