summaryrefslogtreecommitdiff
path: root/django/db/models/sql
diff options
context:
space:
mode:
authorRussell Keith-Magee <russell@keith-magee.com>2009-04-30 15:40:09 +0000
committerRussell Keith-Magee <russell@keith-magee.com>2009-04-30 15:40:09 +0000
commit5e2d38465a661fc267676798d3aa4872e9da8265 (patch)
treef1ed2888c0d1a731b2c5a44ff449408b24d23b15 /django/db/models/sql
parent17958fa7a9a757c2b0abcdc3f931a010584234f1 (diff)
Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.
This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause. git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql')
-rw-r--r--django/db/models/sql/query.py75
-rw-r--r--django/db/models/sql/subqueries.py4
2 files changed, 57 insertions, 22 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index f4bf8b2b07..bafa1e93ea 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -88,7 +88,10 @@ class BaseQuery(object):
# These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause.
- self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params).
+ self.extra = SortedDict() # Maps col_alias -> (col_sql, params).
+ self.extra_select_mask = None
+ self._extra_select_cache = None
+
self.extra_tables = ()
self.extra_where = ()
self.extra_params = ()
@@ -214,13 +217,21 @@ class BaseQuery(object):
if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None
else:
- obj.aggregate_select_mask = self.aggregate_select_mask[:]
+ obj.aggregate_select_mask = self.aggregate_select_mask.copy()
if self._aggregate_select_cache is None:
obj._aggregate_select_cache = None
else:
obj._aggregate_select_cache = self._aggregate_select_cache.copy()
obj.max_depth = self.max_depth
- obj.extra_select = self.extra_select.copy()
+ obj.extra = self.extra.copy()
+ if self.extra_select_mask is None:
+ obj.extra_select_mask = None
+ else:
+ obj.extra_select_mask = self.extra_select_mask.copy()
+ if self._extra_select_cache is None:
+ obj._extra_select_cache = None
+ else:
+ obj._extra_select_cache = self._extra_select_cache.copy()
obj.extra_tables = self.extra_tables
obj.extra_where = self.extra_where
obj.extra_params = self.extra_params
@@ -325,7 +336,7 @@ class BaseQuery(object):
query = self
self.select = []
self.default_cols = False
- self.extra_select = {}
+ self.extra = {}
self.remove_inherited_models()
query.clear_ordering(True)
@@ -540,13 +551,20 @@ class BaseQuery(object):
# It would be nice to be able to handle this, but the queries don't
# really make sense (or return consistent value sets). Not worth
# the extra complexity when you can write a real query instead.
- if self.extra_select and rhs.extra_select:
+ if self.extra and rhs.extra:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(select=...) on both sides.")
if self.extra_where and rhs.extra_where:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(where=...) on both sides.")
- self.extra_select.update(rhs.extra_select)
+ self.extra.update(rhs.extra)
+ extra_select_mask = set()
+ if self.extra_select_mask is not None:
+ extra_select_mask.update(self.extra_select_mask)
+ if rhs.extra_select_mask is not None:
+ extra_select_mask.update(rhs.extra_select_mask)
+ if extra_select_mask:
+ self.set_extra_mask(extra_select_mask)
self.extra_tables += rhs.extra_tables
self.extra_where += rhs.extra_where
self.extra_params += rhs.extra_params
@@ -2011,7 +2029,7 @@ class BaseQuery(object):
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
- names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
+ names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys()
names.sort()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
@@ -2139,7 +2157,7 @@ class BaseQuery(object):
pos = entry.find("%s", pos + 2)
select_pairs[name] = (entry, entry_params)
# This is order preserving, since self.extra_select is a SortedDict.
- self.extra_select.update(select_pairs)
+ self.extra.update(select_pairs)
if where:
self.extra_where += tuple(where)
if params:
@@ -2213,22 +2231,26 @@ class BaseQuery(object):
"""
target[model] = set([f.name for f in fields])
- def trim_extra_select(self, names):
- """
- Removes any aliases in the extra_select dictionary that aren't in
- 'names'.
-
- This is needed if we are selecting certain values that don't incldue
- all of the extra_select names.
- """
- for key in set(self.extra_select).difference(set(names)):
- del self.extra_select[key]
-
def set_aggregate_mask(self, names):
"Set the mask of aggregates that will actually be returned by the SELECT"
- self.aggregate_select_mask = names
+ if names is None:
+ self.aggregate_select_mask = None
+ else:
+ self.aggregate_select_mask = set(names)
self._aggregate_select_cache = None
+ def set_extra_mask(self, names):
+ """
+ Set the mask of extra select items that will be returned by SELECT,
+ we don't actually remove them from the Query since they might be used
+ later
+ """
+ if names is None:
+ self.extra_select_mask = None
+ else:
+ self.extra_select_mask = set(names)
+ self._extra_select_cache = None
+
def _aggregate_select(self):
"""The SortedDict of aggregate columns that are not masked, and should
be used in the SELECT clause.
@@ -2247,6 +2269,19 @@ class BaseQuery(object):
return self.aggregates
aggregate_select = property(_aggregate_select)
+ def _extra_select(self):
+ if self._extra_select_cache is not None:
+ return self._extra_select_cache
+ elif self.extra_select_mask is not None:
+ self._extra_select_cache = SortedDict([
+ (k,v) for k,v in self.extra.items()
+ if k in self.extra_select_mask
+ ])
+ return self._extra_select_cache
+ else:
+ return self.extra
+ extra_select = property(_extra_select)
+
def set_start(self, start):
"""
Sets the table from which to start joining. The start position is
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 4c62457c57..0cd393756d 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -178,7 +178,7 @@ class UpdateQuery(Query):
# from other tables.
query = self.clone(klass=Query)
query.bump_prefix()
- query.extra_select = {}
+ query.extra = {}
query.select = []
query.add_fields([query.model._meta.pk.name])
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
@@ -409,7 +409,7 @@ class DateQuery(Query):
self.select = [select]
self.select_fields = [None]
self.select_related = False # See #7097.
- self.extra_select = {}
+ self.extra = {}
self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1]