diff options
Diffstat (limited to 'django/db/models/sql/query.py')
| -rw-r--r-- | django/db/models/sql/query.py | 57 |
1 files changed, 49 insertions, 8 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c8b557103f..d39514a0a5 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL databases). The abstraction barrier only works one way: this module has to know all about the internals of models in order to get the information it needs. """ +import functools from collections import Counter, OrderedDict, namedtuple from collections.abc import Iterator, Mapping from itertools import chain, count, product @@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref +from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup from django.db.models.query_utils import ( @@ -56,7 +58,7 @@ def get_children_from_q(q): JoinInfo = namedtuple( 'JoinInfo', - ('final_field', 'targets', 'opts', 'joins', 'path') + ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function') ) @@ -1429,8 +1431,11 @@ class Query: generate a MultiJoin exception. Return the final field involved in the joins, the target field (used - for any 'where' constraint), the final 'opts' value, the joins and the - field path travelled to generate the joins. + for any 'where' constraint), the final 'opts' value, the joins, the + field path traveled to generate the joins, and a transform function + that takes a field and alias and is equivalent to `field.get_col(alias)` + in the simple case but wraps field transforms if they were included in + names. The target field is the field containing the concrete value. Final field can be something different, for example foreign key pointing to @@ -1439,10 +1444,46 @@ class Query: key field for example). """ joins = [alias] - # First, generate the path for the names - path, final_field, targets, rest = self.names_to_path( - names, opts, allow_many, fail_on_missing=True) + # The transform can't be applied yet, as joins must be trimmed later. + # To avoid making every caller of this method look up transforms + # directly, compute transforms here and and create a partial that + # converts fields to the appropriate wrapped version. + def final_transformer(field, alias): + return field.get_col(alias) + + # Try resolving all the names as fields first. If there's an error, + # treat trailing names as lookups until a field can be resolved. + last_field_exception = None + for pivot in range(len(names), 0, -1): + try: + path, final_field, targets, rest = self.names_to_path( + names[:pivot], opts, allow_many, fail_on_missing=True, + ) + except FieldError as exc: + if pivot == 1: + # The first item cannot be a lookup, so it's safe + # to raise the field error here. + raise + else: + last_field_exception = exc + else: + # The transforms are the remaining items that couldn't be + # resolved into fields. + transforms = names[pivot:] + break + for name in transforms: + def transform(field, alias, *, name, previous): + try: + wrapped = previous(field, alias) + return self.try_transform(wrapped, name) + except FieldError: + # FieldError is raised if the transform doesn't exist. + if isinstance(final_field, Field) and last_field_exception: + raise last_field_exception + else: + raise + final_transformer = functools.partial(transform, name=name, previous=final_transformer) # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. @@ -1470,7 +1511,7 @@ class Query: joins.append(alias) if filtered_relation: filtered_relation.path = joins[:] - return JoinInfo(final_field, targets, opts, joins, path) + return JoinInfo(final_field, targets, opts, joins, path, final_transformer) def trim_joins(self, targets, joins, path): """ @@ -1683,7 +1724,7 @@ class Query: join_info.path, ) for target in targets: - cols.append(target.get_col(final_alias)) + cols.append(join_info.transform_function(target, final_alias)) if cols: self.set_select(cols) except MultiJoin: |
