diff options
| author | Matthew Wilkes <git@matthewwilkes.name> | 2017-06-18 16:53:40 +0100 |
|---|---|---|
| committer | Tim Graham <timograham@gmail.com> | 2018-02-10 19:08:55 -0500 |
| commit | 2162f0983de0dfe2178531638ce7ea56f54dd4e7 (patch) | |
| tree | bb1e859159200fa7ebeeaa02ec3908e1cf5d2655 /django/db | |
| parent | bf26f66029bca94b007a2452679ac004598364a6 (diff) | |
Fixed #24747 -- Allowed transforms in QuerySet.order_by() and distinct(*fields).
Diffstat (limited to 'django/db')
| -rw-r--r-- | django/db/backends/base/operations.py | 4 | ||||
| -rw-r--r-- | django/db/backends/postgresql/operations.py | 7 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 33 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 57 |
4 files changed, 74 insertions, 27 deletions
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index f6d8925278..465ac70b7b 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -158,7 +158,7 @@ class BaseDatabaseOperations: """ return '' - def distinct_sql(self, fields): + def distinct_sql(self, fields, params): """ Return an SQL DISTINCT clause which removes duplicate rows from the result set. If any fields are given, only check the given fields for @@ -167,7 +167,7 @@ class BaseDatabaseOperations: if fields: raise NotSupportedError('DISTINCT ON fields is not supported by this database backend') else: - return 'DISTINCT' + return ['DISTINCT'], [] def fetch_returned_insert_id(self, cursor): """ diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 3b71cd4f2c..6f48cfa228 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -207,11 +207,12 @@ class DatabaseOperations(BaseDatabaseOperations): """ return 63 - def distinct_sql(self, fields): + def distinct_sql(self, fields, params): if fields: - return 'DISTINCT ON (%s)' % ', '.join(fields) + params = [param for param_list in params for param in param_list] + return (['DISTINCT ON (%s)' % ', '.join(fields)], params) else: - return 'DISTINCT' + return ['DISTINCT'], [] def last_executed_query(self, cursor, sql, params): # http://initd.org/psycopg/docs/cursor.html#cursor.query diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1f78a8b5b2..1fdbd156b6 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -451,7 +451,7 @@ class SQLCompiler: raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) result, params = self.get_combinator_sql(combinator, self.query.combinator_all) else: - distinct_fields = self.get_distinct() + distinct_fields, distinct_params = self.get_distinct() # This must come after 'select', 'ordering', and 'distinct' # (see docstring of get_from_clause() for details). from_, f_params = self.get_from_clause() @@ -461,7 +461,12 @@ class SQLCompiler: params = [] if self.query.distinct: - result.append(self.connection.ops.distinct_sql(distinct_fields)) + distinct_result, distinct_params = self.connection.ops.distinct_sql( + distinct_fields, + distinct_params, + ) + result += distinct_result + params += distinct_params out_cols = [] col_idx = 1 @@ -621,21 +626,22 @@ class SQLCompiler: This method can alter the tables in the query, and thus it must be called before get_from_clause(). """ - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name result = [] + params = [] opts = self.query.get_meta() for name in self.query.distinct_fields: parts = name.split(LOOKUP_SEP) - _, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None) + _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None) targets, alias, _ = self.query.trim_joins(targets, joins, path) for target in targets: if name in self.query.annotation_select: result.append(name) else: - result.append("%s.%s" % (qn(alias), qn2(target.column))) - return result + r, p = self.compile(transform_function(target, alias)) + result.append(r) + params.append(p) + return result, params def find_ordering_name(self, name, opts, alias=None, default_order='ASC', already_seen=None): @@ -647,7 +653,7 @@ class SQLCompiler: name, order = get_order_dir(name, default_order) descending = order == 'DESC' pieces = name.split(LOOKUP_SEP) - field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias) + field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias) # If we get to this point and the field is a relation to another model, # append the default ordering for that model unless the attribute name @@ -666,7 +672,7 @@ class SQLCompiler: order, already_seen)) return results targets, alias, _ = self.query.trim_joins(targets, joins, path) - return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets] + return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets] def _setup_joins(self, pieces, opts, alias): """ @@ -677,10 +683,9 @@ class SQLCompiler: match. Executing SQL where this is not true is an error. """ alias = alias or self.query.get_initial_alias() - field, targets, opts, joins, path = self.query.setup_joins( - pieces, opts, alias) + field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias) alias = joins[-1] - return field, targets, alias, joins, path, opts + return field, targets, alias, joins, path, opts, transform_function def get_from_clause(self): """ @@ -786,7 +791,7 @@ class SQLCompiler: } related_klass_infos.append(klass_info) select_fields = [] - _, _, _, joins, _ = self.query.setup_joins( + _, _, _, joins, _, _ = self.query.setup_joins( [f.name], opts, root_alias) alias = joins[-1] columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta) @@ -843,7 +848,7 @@ class SQLCompiler: break if name in self.query._filtered_relations: fields_found.add(name) - f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias) + f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias) model = join_opts.model alias = joins[-1] from_parent = issubclass(model, opts.model) and model is not opts.model diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c8b557103f..d39514a0a5 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL databases). The abstraction barrier only works one way: this module has to know all about the internals of models in order to get the information it needs. """ +import functools from collections import Counter, OrderedDict, namedtuple from collections.abc import Iterator, Mapping from itertools import chain, count, product @@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref +from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup from django.db.models.query_utils import ( @@ -56,7 +58,7 @@ def get_children_from_q(q): JoinInfo = namedtuple( 'JoinInfo', - ('final_field', 'targets', 'opts', 'joins', 'path') + ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function') ) @@ -1429,8 +1431,11 @@ class Query: generate a MultiJoin exception. Return the final field involved in the joins, the target field (used - for any 'where' constraint), the final 'opts' value, the joins and the - field path travelled to generate the joins. + for any 'where' constraint), the final 'opts' value, the joins, the + field path traveled to generate the joins, and a transform function + that takes a field and alias and is equivalent to `field.get_col(alias)` + in the simple case but wraps field transforms if they were included in + names. The target field is the field containing the concrete value. Final field can be something different, for example foreign key pointing to @@ -1439,10 +1444,46 @@ class Query: key field for example). """ joins = [alias] - # First, generate the path for the names - path, final_field, targets, rest = self.names_to_path( - names, opts, allow_many, fail_on_missing=True) + # The transform can't be applied yet, as joins must be trimmed later. + # To avoid making every caller of this method look up transforms + # directly, compute transforms here and and create a partial that + # converts fields to the appropriate wrapped version. + def final_transformer(field, alias): + return field.get_col(alias) + + # Try resolving all the names as fields first. If there's an error, + # treat trailing names as lookups until a field can be resolved. + last_field_exception = None + for pivot in range(len(names), 0, -1): + try: + path, final_field, targets, rest = self.names_to_path( + names[:pivot], opts, allow_many, fail_on_missing=True, + ) + except FieldError as exc: + if pivot == 1: + # The first item cannot be a lookup, so it's safe + # to raise the field error here. + raise + else: + last_field_exception = exc + else: + # The transforms are the remaining items that couldn't be + # resolved into fields. + transforms = names[pivot:] + break + for name in transforms: + def transform(field, alias, *, name, previous): + try: + wrapped = previous(field, alias) + return self.try_transform(wrapped, name) + except FieldError: + # FieldError is raised if the transform doesn't exist. + if isinstance(final_field, Field) and last_field_exception: + raise last_field_exception + else: + raise + final_transformer = functools.partial(transform, name=name, previous=final_transformer) # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. @@ -1470,7 +1511,7 @@ class Query: joins.append(alias) if filtered_relation: filtered_relation.path = joins[:] - return JoinInfo(final_field, targets, opts, joins, path) + return JoinInfo(final_field, targets, opts, joins, path, final_transformer) def trim_joins(self, targets, joins, path): """ @@ -1683,7 +1724,7 @@ class Query: join_info.path, ) for target in targets: - cols.append(target.get_col(final_alias)) + cols.append(join_info.transform_function(target, final_alias)) if cols: self.set_select(cols) except MultiJoin: |
