summaryrefslogtreecommitdiff
path: root/django/db/models/sql/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/query.py')
-rw-r--r--django/db/models/sql/query.py57
1 files changed, 49 insertions, 8 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index c8b557103f..d39514a0a5 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL
databases). The abstraction barrier only works one way: this module has to know
all about the internals of models in order to get the information it needs.
"""
+import functools
from collections import Counter, OrderedDict, namedtuple
from collections.abc import Iterator, Mapping
from itertools import chain, count, product
@@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
+from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
@@ -56,7 +58,7 @@ def get_children_from_q(q):
JoinInfo = namedtuple(
'JoinInfo',
- ('final_field', 'targets', 'opts', 'joins', 'path')
+ ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
)
@@ -1429,8 +1431,11 @@ class Query:
generate a MultiJoin exception.
Return the final field involved in the joins, the target field (used
- for any 'where' constraint), the final 'opts' value, the joins and the
- field path travelled to generate the joins.
+ for any 'where' constraint), the final 'opts' value, the joins, the
+ field path traveled to generate the joins, and a transform function
+ that takes a field and alias and is equivalent to `field.get_col(alias)`
+ in the simple case but wraps field transforms if they were included in
+ names.
The target field is the field containing the concrete value. Final
field can be something different, for example foreign key pointing to
@@ -1439,10 +1444,46 @@ class Query:
key field for example).
"""
joins = [alias]
- # First, generate the path for the names
- path, final_field, targets, rest = self.names_to_path(
- names, opts, allow_many, fail_on_missing=True)
+ # The transform can't be applied yet, as joins must be trimmed later.
+ # To avoid making every caller of this method look up transforms
+ # directly, compute transforms here and and create a partial that
+ # converts fields to the appropriate wrapped version.
+ def final_transformer(field, alias):
+ return field.get_col(alias)
+
+ # Try resolving all the names as fields first. If there's an error,
+ # treat trailing names as lookups until a field can be resolved.
+ last_field_exception = None
+ for pivot in range(len(names), 0, -1):
+ try:
+ path, final_field, targets, rest = self.names_to_path(
+ names[:pivot], opts, allow_many, fail_on_missing=True,
+ )
+ except FieldError as exc:
+ if pivot == 1:
+ # The first item cannot be a lookup, so it's safe
+ # to raise the field error here.
+ raise
+ else:
+ last_field_exception = exc
+ else:
+ # The transforms are the remaining items that couldn't be
+ # resolved into fields.
+ transforms = names[pivot:]
+ break
+ for name in transforms:
+ def transform(field, alias, *, name, previous):
+ try:
+ wrapped = previous(field, alias)
+ return self.try_transform(wrapped, name)
+ except FieldError:
+ # FieldError is raised if the transform doesn't exist.
+ if isinstance(final_field, Field) and last_field_exception:
+ raise last_field_exception
+ else:
+ raise
+ final_transformer = functools.partial(transform, name=name, previous=final_transformer)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1470,7 +1511,7 @@ class Query:
joins.append(alias)
if filtered_relation:
filtered_relation.path = joins[:]
- return JoinInfo(final_field, targets, opts, joins, path)
+ return JoinInfo(final_field, targets, opts, joins, path, final_transformer)
def trim_joins(self, targets, joins, path):
"""
@@ -1683,7 +1724,7 @@ class Query:
join_info.path,
)
for target in targets:
- cols.append(target.get_col(final_alias))
+ cols.append(join_info.transform_function(target, final_alias))
if cols:
self.set_select(cols)
except MultiJoin: