summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
authorMatthew Wilkes <git@matthewwilkes.name>2017-06-18 16:53:40 +0100
committerTim Graham <timograham@gmail.com>2018-02-10 19:08:55 -0500
commit2162f0983de0dfe2178531638ce7ea56f54dd4e7 (patch)
treebb1e859159200fa7ebeeaa02ec3908e1cf5d2655 /django/db
parentbf26f66029bca94b007a2452679ac004598364a6 (diff)
Fixed #24747 -- Allowed transforms in QuerySet.order_by() and distinct(*fields).
Diffstat (limited to 'django/db')
-rw-r--r--django/db/backends/base/operations.py4
-rw-r--r--django/db/backends/postgresql/operations.py7
-rw-r--r--django/db/models/sql/compiler.py33
-rw-r--r--django/db/models/sql/query.py57
4 files changed, 74 insertions, 27 deletions
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index f6d8925278..465ac70b7b 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -158,7 +158,7 @@ class BaseDatabaseOperations:
"""
return ''
- def distinct_sql(self, fields):
+ def distinct_sql(self, fields, params):
"""
Return an SQL DISTINCT clause which removes duplicate rows from the
result set. If any fields are given, only check the given fields for
@@ -167,7 +167,7 @@ class BaseDatabaseOperations:
if fields:
raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
else:
- return 'DISTINCT'
+ return ['DISTINCT'], []
def fetch_returned_insert_id(self, cursor):
"""
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 3b71cd4f2c..6f48cfa228 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -207,11 +207,12 @@ class DatabaseOperations(BaseDatabaseOperations):
"""
return 63
- def distinct_sql(self, fields):
+ def distinct_sql(self, fields, params):
if fields:
- return 'DISTINCT ON (%s)' % ', '.join(fields)
+ params = [param for param_list in params for param in param_list]
+ return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
else:
- return 'DISTINCT'
+ return ['DISTINCT'], []
def last_executed_query(self, cursor, sql, params):
# http://initd.org/psycopg/docs/cursor.html#cursor.query
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 1f78a8b5b2..1fdbd156b6 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -451,7 +451,7 @@ class SQLCompiler:
raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
else:
- distinct_fields = self.get_distinct()
+ distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause()
@@ -461,7 +461,12 @@ class SQLCompiler:
params = []
if self.query.distinct:
- result.append(self.connection.ops.distinct_sql(distinct_fields))
+ distinct_result, distinct_params = self.connection.ops.distinct_sql(
+ distinct_fields,
+ distinct_params,
+ )
+ result += distinct_result
+ params += distinct_params
out_cols = []
col_idx = 1
@@ -621,21 +626,22 @@ class SQLCompiler:
This method can alter the tables in the query, and thus it must be
called before get_from_clause().
"""
- qn = self.quote_name_unless_alias
- qn2 = self.connection.ops.quote_name
result = []
+ params = []
opts = self.query.get_meta()
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
- _, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None)
+ _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None)
targets, alias, _ = self.query.trim_joins(targets, joins, path)
for target in targets:
if name in self.query.annotation_select:
result.append(name)
else:
- result.append("%s.%s" % (qn(alias), qn2(target.column)))
- return result
+ r, p = self.compile(transform_function(target, alias))
+ result.append(r)
+ params.append(p)
+ return result, params
def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
already_seen=None):
@@ -647,7 +653,7 @@ class SQLCompiler:
name, order = get_order_dir(name, default_order)
descending = order == 'DESC'
pieces = name.split(LOOKUP_SEP)
- field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias)
+ field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model unless the attribute name
@@ -666,7 +672,7 @@ class SQLCompiler:
order, already_seen))
return results
targets, alias, _ = self.query.trim_joins(targets, joins, path)
- return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets]
+ return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -677,10 +683,9 @@ class SQLCompiler:
match. Executing SQL where this is not true is an error.
"""
alias = alias or self.query.get_initial_alias()
- field, targets, opts, joins, path = self.query.setup_joins(
- pieces, opts, alias)
+ field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias)
alias = joins[-1]
- return field, targets, alias, joins, path, opts
+ return field, targets, alias, joins, path, opts, transform_function
def get_from_clause(self):
"""
@@ -786,7 +791,7 @@ class SQLCompiler:
}
related_klass_infos.append(klass_info)
select_fields = []
- _, _, _, joins, _ = self.query.setup_joins(
+ _, _, _, joins, _, _ = self.query.setup_joins(
[f.name], opts, root_alias)
alias = joins[-1]
columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta)
@@ -843,7 +848,7 @@ class SQLCompiler:
break
if name in self.query._filtered_relations:
fields_found.add(name)
- f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
+ f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias)
model = join_opts.model
alias = joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index c8b557103f..d39514a0a5 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL
databases). The abstraction barrier only works one way: this module has to know
all about the internals of models in order to get the information it needs.
"""
+import functools
from collections import Counter, OrderedDict, namedtuple
from collections.abc import Iterator, Mapping
from itertools import chain, count, product
@@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
+from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
@@ -56,7 +58,7 @@ def get_children_from_q(q):
JoinInfo = namedtuple(
'JoinInfo',
- ('final_field', 'targets', 'opts', 'joins', 'path')
+ ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
)
@@ -1429,8 +1431,11 @@ class Query:
generate a MultiJoin exception.
Return the final field involved in the joins, the target field (used
- for any 'where' constraint), the final 'opts' value, the joins and the
- field path travelled to generate the joins.
+ for any 'where' constraint), the final 'opts' value, the joins, the
+ field path traveled to generate the joins, and a transform function
+ that takes a field and alias and is equivalent to `field.get_col(alias)`
+ in the simple case but wraps field transforms if they were included in
+ names.
The target field is the field containing the concrete value. Final
field can be something different, for example foreign key pointing to
@@ -1439,10 +1444,46 @@ class Query:
key field for example).
"""
joins = [alias]
- # First, generate the path for the names
- path, final_field, targets, rest = self.names_to_path(
- names, opts, allow_many, fail_on_missing=True)
+ # The transform can't be applied yet, as joins must be trimmed later.
+ # To avoid making every caller of this method look up transforms
+ # directly, compute transforms here and and create a partial that
+ # converts fields to the appropriate wrapped version.
+ def final_transformer(field, alias):
+ return field.get_col(alias)
+
+ # Try resolving all the names as fields first. If there's an error,
+ # treat trailing names as lookups until a field can be resolved.
+ last_field_exception = None
+ for pivot in range(len(names), 0, -1):
+ try:
+ path, final_field, targets, rest = self.names_to_path(
+ names[:pivot], opts, allow_many, fail_on_missing=True,
+ )
+ except FieldError as exc:
+ if pivot == 1:
+ # The first item cannot be a lookup, so it's safe
+ # to raise the field error here.
+ raise
+ else:
+ last_field_exception = exc
+ else:
+ # The transforms are the remaining items that couldn't be
+ # resolved into fields.
+ transforms = names[pivot:]
+ break
+ for name in transforms:
+ def transform(field, alias, *, name, previous):
+ try:
+ wrapped = previous(field, alias)
+ return self.try_transform(wrapped, name)
+ except FieldError:
+ # FieldError is raised if the transform doesn't exist.
+ if isinstance(final_field, Field) and last_field_exception:
+ raise last_field_exception
+ else:
+ raise
+ final_transformer = functools.partial(transform, name=name, previous=final_transformer)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1470,7 +1511,7 @@ class Query:
joins.append(alias)
if filtered_relation:
filtered_relation.path = joins[:]
- return JoinInfo(final_field, targets, opts, joins, path)
+ return JoinInfo(final_field, targets, opts, joins, path, final_transformer)
def trim_joins(self, targets, joins, path):
"""
@@ -1683,7 +1724,7 @@ class Query:
join_info.path,
)
for target in targets:
- cols.append(target.get_col(final_alias))
+ cols.append(join_info.transform_function(target, final_alias))
if cols:
self.set_select(cols)
except MultiJoin: