diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-04-30 15:40:09 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-04-30 15:40:09 +0000 |
| commit | 5e2d38465a661fc267676798d3aa4872e9da8265 (patch) | |
| tree | f1ed2888c0d1a731b2c5a44ff449408b24d23b15 /django/db/models/sql/query.py | |
| parent | 17958fa7a9a757c2b0abcdc3f931a010584234f1 (diff) | |
Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.
This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql/query.py')
| -rw-r--r-- | django/db/models/sql/query.py | 75 |
1 files changed, 55 insertions, 20 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f4bf8b2b07..bafa1e93ea 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -88,7 +88,10 @@ class BaseQuery(object): # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. - self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params). + self.extra = SortedDict() # Maps col_alias -> (col_sql, params). + self.extra_select_mask = None + self._extra_select_cache = None + self.extra_tables = () self.extra_where = () self.extra_params = () @@ -214,13 +217,21 @@ class BaseQuery(object): if self.aggregate_select_mask is None: obj.aggregate_select_mask = None else: - obj.aggregate_select_mask = self.aggregate_select_mask[:] + obj.aggregate_select_mask = self.aggregate_select_mask.copy() if self._aggregate_select_cache is None: obj._aggregate_select_cache = None else: obj._aggregate_select_cache = self._aggregate_select_cache.copy() obj.max_depth = self.max_depth - obj.extra_select = self.extra_select.copy() + obj.extra = self.extra.copy() + if self.extra_select_mask is None: + obj.extra_select_mask = None + else: + obj.extra_select_mask = self.extra_select_mask.copy() + if self._extra_select_cache is None: + obj._extra_select_cache = None + else: + obj._extra_select_cache = self._extra_select_cache.copy() obj.extra_tables = self.extra_tables obj.extra_where = self.extra_where obj.extra_params = self.extra_params @@ -325,7 +336,7 @@ class BaseQuery(object): query = self self.select = [] self.default_cols = False - self.extra_select = {} + self.extra = {} self.remove_inherited_models() query.clear_ordering(True) @@ -540,13 +551,20 @@ class BaseQuery(object): # It would be nice to be able to handle this, but the queries don't # really make sense (or return consistent value sets). Not worth # the extra complexity when you can write a real query instead. - if self.extra_select and rhs.extra_select: + if self.extra and rhs.extra: raise ValueError("When merging querysets using 'or', you " "cannot have extra(select=...) on both sides.") if self.extra_where and rhs.extra_where: raise ValueError("When merging querysets using 'or', you " "cannot have extra(where=...) on both sides.") - self.extra_select.update(rhs.extra_select) + self.extra.update(rhs.extra) + extra_select_mask = set() + if self.extra_select_mask is not None: + extra_select_mask.update(self.extra_select_mask) + if rhs.extra_select_mask is not None: + extra_select_mask.update(rhs.extra_select_mask) + if extra_select_mask: + self.set_extra_mask(extra_select_mask) self.extra_tables += rhs.extra_tables self.extra_where += rhs.extra_where self.extra_params += rhs.extra_params @@ -2011,7 +2029,7 @@ class BaseQuery(object): except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: - names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys() + names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys() names.sort() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) @@ -2139,7 +2157,7 @@ class BaseQuery(object): pos = entry.find("%s", pos + 2) select_pairs[name] = (entry, entry_params) # This is order preserving, since self.extra_select is a SortedDict. - self.extra_select.update(select_pairs) + self.extra.update(select_pairs) if where: self.extra_where += tuple(where) if params: @@ -2213,22 +2231,26 @@ class BaseQuery(object): """ target[model] = set([f.name for f in fields]) - def trim_extra_select(self, names): - """ - Removes any aliases in the extra_select dictionary that aren't in - 'names'. - - This is needed if we are selecting certain values that don't incldue - all of the extra_select names. - """ - for key in set(self.extra_select).difference(set(names)): - del self.extra_select[key] - def set_aggregate_mask(self, names): "Set the mask of aggregates that will actually be returned by the SELECT" - self.aggregate_select_mask = names + if names is None: + self.aggregate_select_mask = None + else: + self.aggregate_select_mask = set(names) self._aggregate_select_cache = None + def set_extra_mask(self, names): + """ + Set the mask of extra select items that will be returned by SELECT, + we don't actually remove them from the Query since they might be used + later + """ + if names is None: + self.extra_select_mask = None + else: + self.extra_select_mask = set(names) + self._extra_select_cache = None + def _aggregate_select(self): """The SortedDict of aggregate columns that are not masked, and should be used in the SELECT clause. @@ -2247,6 +2269,19 @@ class BaseQuery(object): return self.aggregates aggregate_select = property(_aggregate_select) + def _extra_select(self): + if self._extra_select_cache is not None: + return self._extra_select_cache + elif self.extra_select_mask is not None: + self._extra_select_cache = SortedDict([ + (k,v) for k,v in self.extra.items() + if k in self.extra_select_mask + ]) + return self._extra_select_cache + else: + return self.extra + extra_select = property(_extra_select) + def set_start(self, start): """ Sets the table from which to start joining. The start position is |
