diff options
Diffstat (limited to 'django/db/models/sql/query.py')
| -rw-r--r-- | django/db/models/sql/query.py | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 07c3fdbd34..ce97ebe1d1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -26,6 +26,7 @@ from django.db.models.expressions import ( Exists, F, OuterRef, + RawSQL, Ref, ResolvedOuterRef, Value, @@ -265,6 +266,7 @@ class Query(BaseExpression): # Holds the selects defined by a call to values() or values_list() # excluding annotation_select and extra_select. values_select = () + selected = None # SQL annotation-related attributes. annotation_select_mask = None @@ -584,6 +586,7 @@ class Query(BaseExpression): else: outer_query = self self.select = () + self.selected = None self.default_cols = False self.extra = {} if self.annotations: @@ -1194,13 +1197,10 @@ class Query(BaseExpression): if select: self.append_annotation_mask([alias]) else: - annotation_mask = ( - value - for value in dict.fromkeys(self.annotation_select) - if value != alias - ) - self.set_annotation_mask(annotation_mask) + self.set_annotation_mask(set(self.annotation_select).difference({alias})) self.annotations[alias] = annotation + if self.selected: + self.selected[alias] = alias def resolve_expression(self, query, *args, **kwargs): clone = self.clone() @@ -2153,6 +2153,7 @@ class Query(BaseExpression): self.select_related = False self.set_extra_mask(()) self.set_annotation_mask(()) + self.selected = None def clear_select_fields(self): """ @@ -2162,10 +2163,12 @@ class Query(BaseExpression): """ self.select = () self.values_select = () + self.selected = None def add_select_col(self, col, name): self.select += (col,) self.values_select += (name,) + self.selected[name] = len(self.select) - 1 def set_select(self, cols): self.default_cols = False @@ -2416,12 +2419,23 @@ class Query(BaseExpression): if names is None: self.annotation_select_mask = None else: - self.annotation_select_mask = list(dict.fromkeys(names)) + self.annotation_select_mask = set(names) + if self.selected: + # Prune the masked annotations. + self.selected = { + key: value + for key, value in self.selected.items() + if not isinstance(value, str) + or value in self.annotation_select_mask + } + # Append the unmasked annotations. + for name in names: + self.selected[name] = name self._annotation_select_cache = None def append_annotation_mask(self, names): if self.annotation_select_mask is not None: - self.set_annotation_mask((*self.annotation_select_mask, *names)) + self.set_annotation_mask(self.annotation_select_mask.union(names)) def set_extra_mask(self, names): """ @@ -2440,6 +2454,7 @@ class Query(BaseExpression): self.clear_select_fields() self.has_select_fields = True + selected = {} if fields: field_names = [] extra_names = [] @@ -2448,13 +2463,16 @@ class Query(BaseExpression): # Shortcut - if there are no extra or annotations, then # the values() clause must be just field names. field_names = list(fields) + selected = dict(zip(fields, range(len(fields)))) else: self.default_cols = False for f in fields: - if f in self.extra_select: + if extra := self.extra_select.get(f): extra_names.append(f) + selected[f] = RawSQL(*extra) elif f in self.annotation_select: annotation_names.append(f) + selected[f] = f elif f in self.annotations: raise FieldError( f"Cannot select the '{f}' alias. Use annotate() to " @@ -2466,13 +2484,13 @@ class Query(BaseExpression): # `f` is not resolvable. if self.annotation_select: self.names_to_path(f.split(LOOKUP_SEP), self.model._meta) + selected[f] = len(field_names) field_names.append(f) self.set_extra_mask(extra_names) self.set_annotation_mask(annotation_names) - selected = frozenset(field_names + extra_names + annotation_names) else: field_names = [f.attname for f in self.model._meta.concrete_fields] - selected = frozenset(field_names) + selected = dict.fromkeys(field_names, None) # Selected annotations must be known before setting the GROUP BY # clause. if self.group_by is True: @@ -2495,6 +2513,7 @@ class Query(BaseExpression): self.values_select = tuple(field_names) self.add_fields(field_names, True) + self.selected = selected if fields else None @property def annotation_select(self): @@ -2508,9 +2527,9 @@ class Query(BaseExpression): return {} elif self.annotation_select_mask is not None: self._annotation_select_cache = { - k: self.annotations[k] - for k in self.annotation_select_mask - if k in self.annotations + k: v + for k, v in self.annotations.items() + if k in self.annotation_select_mask } return self._annotation_select_cache else: |
