summaryrefslogtreecommitdiff
path: root/django/db/models/sql
diff options
context:
space:
mode:
authordjango-bot <ops@djangoproject.com>2022-02-03 20:24:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-02-07 20:37:05 +0100
commit9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
treef0506b668a013d0063e5fba3dbf4863b466713ba /django/db/models/sql
parentf68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db/models/sql')
-rw-r--r--django/db/models/sql/__init__.py2
-rw-r--r--django/db/models/sql/compiler.py770
-rw-r--r--django/db/models/sql/constants.py16
-rw-r--r--django/db/models/sql/datastructures.py74
-rw-r--r--django/db/models/sql/query.py616
-rw-r--r--django/db/models/sql/subqueries.py42
-rw-r--r--django/db/models/sql/where.py57
7 files changed, 997 insertions, 580 deletions
diff --git a/django/db/models/sql/__init__.py b/django/db/models/sql/__init__.py
index 5fa52f6a1f..2956e047b1 100644
--- a/django/db/models/sql/__init__.py
+++ b/django/db/models/sql/__init__.py
@@ -3,4 +3,4 @@ from django.db.models.sql.query import Query
from django.db.models.sql.subqueries import * # NOQA
from django.db.models.sql.where import AND, OR
-__all__ = ['Query', 'AND', 'OR']
+__all__ = ["Query", "AND", "OR"]
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index d405a203ee..13a7ec7263 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -11,7 +11,12 @@ from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
from django.db.models.functions import Cast, Random
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (
- CURSOR, GET_ITERATOR_CHUNK_SIZE, MULTI, NO_RESULTS, ORDER_DIR, SINGLE,
+ CURSOR,
+ GET_ITERATOR_CHUNK_SIZE,
+ MULTI,
+ NO_RESULTS,
+ ORDER_DIR,
+ SINGLE,
)
from django.db.models.sql.query import Query, get_order_dir
from django.db.transaction import TransactionManagementError
@@ -23,7 +28,7 @@ from django.utils.regex_helper import _lazy_re_compile
class SQLCompiler:
# Multiline ordering SQL clause may appear from RawSQL.
ordering_parts = _lazy_re_compile(
- r'^(.*)\s(?:ASC|DESC).*',
+ r"^(.*)\s(?:ASC|DESC).*",
re.MULTILINE | re.DOTALL,
)
@@ -34,7 +39,7 @@ class SQLCompiler:
# Some queries, e.g. coalesced aggregation, need to be executed even if
# they would return an empty result set.
self.elide_empty = elide_empty
- self.quote_cache = {'*': '*'}
+ self.quote_cache = {"*": "*"}
# The select, klass_info, and annotations are needed by QuerySet.iterator()
# these are set as a side-effect of executing the query. Note that we calculate
# separately a list of extra select columns needed for grammatical correctness
@@ -46,9 +51,9 @@ class SQLCompiler:
def __repr__(self):
return (
- f'<{self.__class__.__qualname__} '
- f'model={self.query.model.__qualname__} '
- f'connection={self.connection!r} using={self.using!r}>'
+ f"<{self.__class__.__qualname__} "
+ f"model={self.query.model.__qualname__} "
+ f"connection={self.connection!r} using={self.using!r}>"
)
def setup_query(self):
@@ -118,16 +123,14 @@ class SQLCompiler:
# when we have public API way of forcing the GROUP BY clause.
# Converts string references to expressions.
for expr in self.query.group_by:
- if not hasattr(expr, 'as_sql'):
+ if not hasattr(expr, "as_sql"):
expressions.append(self.query.resolve_ref(expr))
else:
expressions.append(expr)
# Note that even if the group_by is set, it is only the minimal
# set to group by. So, we need to add cols in select, order_by, and
# having into the select in any case.
- ref_sources = {
- expr.source for expr in expressions if isinstance(expr, Ref)
- }
+ ref_sources = {expr.source for expr in expressions if isinstance(expr, Ref)}
for expr, _, _ in select:
# Skip members of the select clause that are already included
# by reference.
@@ -169,8 +172,10 @@ class SQLCompiler:
for expr in expressions:
# Is this a reference to query's base table primary key? If the
# expression isn't a Col-like, then skip the expression.
- if (getattr(expr, 'target', None) == self.query.model._meta.pk and
- getattr(expr, 'alias', None) == self.query.base_table):
+ if (
+ getattr(expr, "target", None) == self.query.model._meta.pk
+ and getattr(expr, "alias", None) == self.query.base_table
+ ):
pk = expr
break
# If the main model's primary key is in the query, group by that
@@ -178,13 +183,17 @@ class SQLCompiler:
# that don't have a primary key included in the grouped columns.
if pk:
pk_aliases = {
- expr.alias for expr in expressions
- if hasattr(expr, 'target') and expr.target.primary_key
+ expr.alias
+ for expr in expressions
+ if hasattr(expr, "target") and expr.target.primary_key
}
expressions = [pk] + [
- expr for expr in expressions
- if expr in having or (
- getattr(expr, 'alias', None) is not None and expr.alias not in pk_aliases
+ expr
+ for expr in expressions
+ if expr in having
+ or (
+ getattr(expr, "alias", None) is not None
+ and expr.alias not in pk_aliases
)
]
elif self.connection.features.allows_group_by_selected_pks:
@@ -195,16 +204,21 @@ class SQLCompiler:
# Unmanaged models are excluded because they could be representing
# database views on which the optimization might not be allowed.
pks = {
- expr for expr in expressions
+ expr
+ for expr in expressions
if (
- hasattr(expr, 'target') and
- expr.target.primary_key and
- self.connection.features.allows_group_by_selected_pks_on_model(expr.target.model)
+ hasattr(expr, "target")
+ and expr.target.primary_key
+ and self.connection.features.allows_group_by_selected_pks_on_model(
+ expr.target.model
+ )
)
}
aliases = {expr.alias for expr in pks}
expressions = [
- expr for expr in expressions if expr in pks or getattr(expr, 'alias', None) not in aliases
+ expr
+ for expr in expressions
+ if expr in pks or getattr(expr, "alias", None) not in aliases
]
return expressions
@@ -248,8 +262,8 @@ class SQLCompiler:
select.append((col, None))
select_idx += 1
klass_info = {
- 'model': self.query.model,
- 'select_fields': select_list,
+ "model": self.query.model,
+ "select_fields": select_list,
}
for alias, annotation in self.query.annotation_select.items():
annotations[alias] = select_idx
@@ -258,14 +272,16 @@ class SQLCompiler:
if self.query.select_related:
related_klass_infos = self.get_related_selections(select)
- klass_info['related_klass_infos'] = related_klass_infos
+ klass_info["related_klass_infos"] = related_klass_infos
def get_select_from_parent(klass_info):
- for ki in klass_info['related_klass_infos']:
- if ki['from_parent']:
- ki['select_fields'] = (klass_info['select_fields'] +
- ki['select_fields'])
+ for ki in klass_info["related_klass_infos"]:
+ if ki["from_parent"]:
+ ki["select_fields"] = (
+ klass_info["select_fields"] + ki["select_fields"]
+ )
get_select_from_parent(ki)
+
get_select_from_parent(klass_info)
ret = []
@@ -273,10 +289,12 @@ class SQLCompiler:
try:
sql, params = self.compile(col)
except EmptyResultSet:
- empty_result_set_value = getattr(col, 'empty_result_set_value', NotImplemented)
+ empty_result_set_value = getattr(
+ col, "empty_result_set_value", NotImplemented
+ )
if empty_result_set_value is NotImplemented:
# Select a predicate that's always False.
- sql, params = '0', ()
+ sql, params = "0", ()
else:
sql, params = self.compile(Value(empty_result_set_value))
else:
@@ -297,12 +315,12 @@ class SQLCompiler:
else:
ordering = []
if self.query.standard_ordering:
- default_order, _ = ORDER_DIR['ASC']
+ default_order, _ = ORDER_DIR["ASC"]
else:
- default_order, _ = ORDER_DIR['DESC']
+ default_order, _ = ORDER_DIR["DESC"]
for field in ordering:
- if hasattr(field, 'resolve_expression'):
+ if hasattr(field, "resolve_expression"):
if isinstance(field, Value):
# output_field must be resolved for constants.
field = Cast(field, field.output_field)
@@ -313,12 +331,12 @@ class SQLCompiler:
field.reverse_ordering()
yield field, False
continue
- if field == '?': # random
+ if field == "?": # random
yield OrderBy(Random()), False
continue
col, order = get_order_dir(field, default_order)
- descending = order == 'DESC'
+ descending = order == "DESC"
if col in self.query.annotation_select:
# Reference to expression in SELECT clause
@@ -345,13 +363,15 @@ class SQLCompiler:
yield OrderBy(expr, descending=descending), False
continue
- if '.' in field:
+ if "." in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
- table, col = col.split('.', 1)
+ table, col = col.split(".", 1)
yield (
OrderBy(
- RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []),
+ RawSQL(
+ "%s.%s" % (self.quote_name_unless_alias(table), col), []
+ ),
descending=descending,
),
False,
@@ -361,7 +381,10 @@ class SQLCompiler:
if self.query.extra and col in self.query.extra:
if col in self.query.extra_select:
yield (
- OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending),
+ OrderBy(
+ Ref(col, RawSQL(*self.query.extra[col])),
+ descending=descending,
+ ),
True,
)
else:
@@ -378,7 +401,9 @@ class SQLCompiler:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
yield from self.find_ordering_name(
- field, self.query.get_meta(), default_order=default_order,
+ field,
+ self.query.get_meta(),
+ default_order=default_order,
)
def get_order_by(self):
@@ -409,19 +434,21 @@ class SQLCompiler:
):
continue
if src == sel_expr:
- resolved.set_source_expressions([RawSQL('%d' % (idx + 1), ())])
+ resolved.set_source_expressions([RawSQL("%d" % (idx + 1), ())])
break
else:
if col_alias:
- raise DatabaseError('ORDER BY term does not match any column in the result set.')
+ raise DatabaseError(
+ "ORDER BY term does not match any column in the result set."
+ )
# Add column used in ORDER BY clause to the selected
# columns and to each combined query.
order_by_idx = len(self.query.select) + 1
- col_name = f'__orderbycol{order_by_idx}'
+ col_name = f"__orderbycol{order_by_idx}"
for q in self.query.combined_queries:
q.add_annotation(expr_src, col_name)
self.query.add_select_col(resolved, col_name)
- resolved.set_source_expressions([RawSQL(f'{order_by_idx}', ())])
+ resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())])
sql, params = self.compile(resolved)
# Don't add the same column twice, but the order direction is
# not taken into account so we strip it. When this entire method
@@ -453,9 +480,14 @@ class SQLCompiler:
"""
if name in self.quote_cache:
return self.quote_cache[name]
- if ((name in self.query.alias_map and name not in self.query.table_map) or
- name in self.query.extra_select or (
- self.query.external_aliases.get(name) and name not in self.query.table_map)):
+ if (
+ (name in self.query.alias_map and name not in self.query.table_map)
+ or name in self.query.extra_select
+ or (
+ self.query.external_aliases.get(name)
+ and name not in self.query.table_map
+ )
+ ):
self.quote_cache[name] = name
return name
r = self.connection.ops.quote_name(name)
@@ -463,7 +495,7 @@ class SQLCompiler:
return r
def compile(self, node):
- vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
+ vendor_impl = getattr(node, "as_" + self.connection.vendor, None)
if vendor_impl:
sql, params = vendor_impl(self, self.connection)
else:
@@ -474,14 +506,19 @@ class SQLCompiler:
features = self.connection.features
compilers = [
query.get_compiler(self.using, self.connection, self.elide_empty)
- for query in self.query.combined_queries if not query.is_empty()
+ for query in self.query.combined_queries
+ if not query.is_empty()
]
if not features.supports_slicing_ordering_in_compound:
for query, compiler in zip(self.query.combined_queries, compilers):
if query.low_mark or query.high_mark:
- raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
+ raise DatabaseError(
+ "LIMIT/OFFSET not allowed in subqueries of compound statements."
+ )
if compiler.get_order_by():
- raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
+ raise DatabaseError(
+ "ORDER BY not allowed in subqueries of compound statements."
+ )
parts = ()
for compiler in compilers:
try:
@@ -490,41 +527,45 @@ class SQLCompiler:
# the query on all combined queries, if not already set.
if not compiler.query.values_select and self.query.values_select:
compiler.query = compiler.query.clone()
- compiler.query.set_values((
- *self.query.extra_select,
- *self.query.values_select,
- *self.query.annotation_select,
- ))
+ compiler.query.set_values(
+ (
+ *self.query.extra_select,
+ *self.query.values_select,
+ *self.query.annotation_select,
+ )
+ )
part_sql, part_args = compiler.as_sql()
if compiler.query.combinator:
# Wrap in a subquery if wrapping in parentheses isn't
# supported.
if not features.supports_parentheses_in_compound:
- part_sql = 'SELECT * FROM ({})'.format(part_sql)
+ part_sql = "SELECT * FROM ({})".format(part_sql)
# Add parentheses when combining with compound query if not
# already added for all compound queries.
elif (
- self.query.subquery or
- not features.supports_slicing_ordering_in_compound
+ self.query.subquery
+ or not features.supports_slicing_ordering_in_compound
):
- part_sql = '({})'.format(part_sql)
+ part_sql = "({})".format(part_sql)
parts += ((part_sql, part_args),)
except EmptyResultSet:
# Omit the empty queryset with UNION and with DIFFERENCE if the
# first queryset is nonempty.
- if combinator == 'union' or (combinator == 'difference' and parts):
+ if combinator == "union" or (combinator == "difference" and parts):
continue
raise
if not parts:
raise EmptyResultSet
combinator_sql = self.connection.ops.set_operators[combinator]
- if all and combinator == 'union':
- combinator_sql += ' ALL'
- braces = '{}'
+ if all and combinator == "union":
+ combinator_sql += " ALL"
+ braces = "{}"
if not self.query.subquery and features.supports_slicing_ordering_in_compound:
- braces = '({})'
- sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
- result = [' {} '.format(combinator_sql).join(sql_parts)]
+ braces = "({})"
+ sql_parts, args_parts = zip(
+ *((braces.format(sql), args) for sql, args in parts)
+ )
+ result = [" {} ".format(combinator_sql).join(sql_parts)]
params = []
for part in args_parts:
params.extend(part)
@@ -543,27 +584,39 @@ class SQLCompiler:
extra_select, order_by, group_by = self.pre_sql_setup()
for_update_part = None
# Is a LIMIT/OFFSET clause needed?
- with_limit_offset = with_limits and (self.query.high_mark is not None or self.query.low_mark)
+ with_limit_offset = with_limits and (
+ self.query.high_mark is not None or self.query.low_mark
+ )
combinator = self.query.combinator
features = self.connection.features
if combinator:
- if not getattr(features, 'supports_select_{}'.format(combinator)):
- raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
- result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
+ if not getattr(features, "supports_select_{}".format(combinator)):
+ raise NotSupportedError(
+ "{} is not supported on this database backend.".format(
+ combinator
+ )
+ )
+ result, params = self.get_combinator_sql(
+ combinator, self.query.combinator_all
+ )
else:
distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause()
try:
- where, w_params = self.compile(self.where) if self.where is not None else ('', [])
+ where, w_params = (
+ self.compile(self.where) if self.where is not None else ("", [])
+ )
except EmptyResultSet:
if self.elide_empty:
raise
# Use a predicate that's always False.
- where, w_params = '0 = 1', []
- having, h_params = self.compile(self.having) if self.having is not None else ("", [])
- result = ['SELECT']
+ where, w_params = "0 = 1", []
+ having, h_params = (
+ self.compile(self.having) if self.having is not None else ("", [])
+ )
+ result = ["SELECT"]
params = []
if self.query.distinct:
@@ -578,27 +631,38 @@ class SQLCompiler:
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
- s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
+ s_sql = "%s AS %s" % (
+ s_sql,
+ self.connection.ops.quote_name(alias),
+ )
elif with_col_aliases:
- s_sql = '%s AS %s' % (
+ s_sql = "%s AS %s" % (
s_sql,
- self.connection.ops.quote_name('col%d' % col_idx),
+ self.connection.ops.quote_name("col%d" % col_idx),
)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)
- result += [', '.join(out_cols), 'FROM', *from_]
+ result += [", ".join(out_cols), "FROM", *from_]
params.extend(f_params)
- if self.query.select_for_update and self.connection.features.has_select_for_update:
+ if (
+ self.query.select_for_update
+ and self.connection.features.has_select_for_update
+ ):
if self.connection.get_autocommit():
- raise TransactionManagementError('select_for_update cannot be used outside of a transaction.')
+ raise TransactionManagementError(
+ "select_for_update cannot be used outside of a transaction."
+ )
- if with_limit_offset and not self.connection.features.supports_select_for_update_with_limit:
+ if (
+ with_limit_offset
+ and not self.connection.features.supports_select_for_update_with_limit
+ ):
raise NotSupportedError(
- 'LIMIT/OFFSET is not supported with '
- 'select_for_update on this database backend.'
+ "LIMIT/OFFSET is not supported with "
+ "select_for_update on this database backend."
)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
@@ -607,16 +671,31 @@ class SQLCompiler:
# If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the
# backend doesn't support it, raise NotSupportedError to
# prevent a possible deadlock.
- if nowait and not self.connection.features.has_select_for_update_nowait:
- raise NotSupportedError('NOWAIT is not supported on this database backend.')
- elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
- raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
+ if (
+ nowait
+ and not self.connection.features.has_select_for_update_nowait
+ ):
+ raise NotSupportedError(
+ "NOWAIT is not supported on this database backend."
+ )
+ elif (
+ skip_locked
+ and not self.connection.features.has_select_for_update_skip_locked
+ ):
+ raise NotSupportedError(
+ "SKIP LOCKED is not supported on this database backend."
+ )
elif of and not self.connection.features.has_select_for_update_of:
- raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
- elif no_key and not self.connection.features.has_select_for_no_key_update:
raise NotSupportedError(
- 'FOR NO KEY UPDATE is not supported on this '
- 'database backend.'
+ "FOR UPDATE OF is not supported on this database backend."
+ )
+ elif (
+ no_key
+ and not self.connection.features.has_select_for_no_key_update
+ ):
+ raise NotSupportedError(
+ "FOR NO KEY UPDATE is not supported on this "
+ "database backend."
)
for_update_part = self.connection.ops.for_update_sql(
nowait=nowait,
@@ -629,7 +708,7 @@ class SQLCompiler:
result.append(for_update_part)
if where:
- result.append('WHERE %s' % where)
+ result.append("WHERE %s" % where)
params.extend(w_params)
grouping = []
@@ -638,30 +717,39 @@ class SQLCompiler:
params.extend(g_params)
if grouping:
if distinct_fields:
- raise NotImplementedError('annotate() + distinct(fields) is not implemented.')
+ raise NotImplementedError(
+ "annotate() + distinct(fields) is not implemented."
+ )
order_by = order_by or self.connection.ops.force_no_ordering()
- result.append('GROUP BY %s' % ', '.join(grouping))
+ result.append("GROUP BY %s" % ", ".join(grouping))
if self._meta_ordering:
order_by = None
if having:
- result.append('HAVING %s' % having)
+ result.append("HAVING %s" % having)
params.extend(h_params)
if self.query.explain_info:
- result.insert(0, self.connection.ops.explain_query_prefix(
- self.query.explain_info.format,
- **self.query.explain_info.options
- ))
+ result.insert(
+ 0,
+ self.connection.ops.explain_query_prefix(
+ self.query.explain_info.format,
+ **self.query.explain_info.options,
+ ),
+ )
if order_by:
ordering = []
for _, (o_sql, o_params, _) in order_by:
ordering.append(o_sql)
params.extend(o_params)
- result.append('ORDER BY %s' % ', '.join(ordering))
+ result.append("ORDER BY %s" % ", ".join(ordering))
if with_limit_offset:
- result.append(self.connection.ops.limit_offset_sql(self.query.low_mark, self.query.high_mark))
+ result.append(
+ self.connection.ops.limit_offset_sql(
+ self.query.low_mark, self.query.high_mark
+ )
+ )
if for_update_part and not self.connection.features.for_update_after_from:
result.append(for_update_part)
@@ -677,23 +765,30 @@ class SQLCompiler:
sub_params = []
for index, (select, _, alias) in enumerate(self.select, start=1):
if not alias and with_col_aliases:
- alias = 'col%d' % index
+ alias = "col%d" % index
if alias:
- sub_selects.append("%s.%s" % (
- self.connection.ops.quote_name('subquery'),
- self.connection.ops.quote_name(alias),
- ))
+ sub_selects.append(
+ "%s.%s"
+ % (
+ self.connection.ops.quote_name("subquery"),
+ self.connection.ops.quote_name(alias),
+ )
+ )
else:
- select_clone = select.relabeled_clone({select.alias: 'subquery'})
- subselect, subparams = select_clone.as_sql(self, self.connection)
+ select_clone = select.relabeled_clone(
+ {select.alias: "subquery"}
+ )
+ subselect, subparams = select_clone.as_sql(
+ self, self.connection
+ )
sub_selects.append(subselect)
sub_params.extend(subparams)
- return 'SELECT %s FROM (%s) subquery' % (
- ', '.join(sub_selects),
- ' '.join(result),
+ return "SELECT %s FROM (%s) subquery" % (
+ ", ".join(sub_selects),
+ " ".join(result),
), tuple(sub_params + params)
- return ' '.join(result), tuple(params)
+ return " ".join(result), tuple(params)
finally:
# Finally do cleanup - get rid of the joins we created above.
self.query.reset_refcounts(refcounts_before)
@@ -726,8 +821,13 @@ class SQLCompiler:
# will assign None if the field belongs to this model.
if model == opts.model:
model = None
- if from_parent and model is not None and issubclass(
- from_parent._meta.concrete_model, model._meta.concrete_model):
+ if (
+ from_parent
+ and model is not None
+ and issubclass(
+ from_parent._meta.concrete_model, model._meta.concrete_model
+ )
+ ):
# Avoid loading data for already loaded parents.
# We end up here in the case select_related() resolution
# proceeds from parent model to child model. In that case the
@@ -736,8 +836,7 @@ class SQLCompiler:
continue
if field.model in only_load and field.attname not in only_load[field.model]:
continue
- alias = self.query.join_parent_model(opts, model, start_alias,
- seen_models)
+ alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias)
result.append(column)
return result
@@ -755,7 +854,9 @@ class SQLCompiler:
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
- _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None)
+ _, targets, alias, joins, path, _, transform_function = self._setup_joins(
+ parts, opts, None
+ )
targets, alias, _ = self.query.trim_joins(targets, joins, path)
for target in targets:
if name in self.query.annotation_select:
@@ -766,46 +867,63 @@ class SQLCompiler:
params.append(p)
return result, params
- def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
- already_seen=None):
+ def find_ordering_name(
+ self, name, opts, alias=None, default_order="ASC", already_seen=None
+ ):
"""
Return the table alias (the name might be ambiguous, the alias will
not be) and column name for ordering by the given 'name' parameter.
The 'name' is of the form 'field1__field2__...__fieldN'.
"""
name, order = get_order_dir(name, default_order)
- descending = order == 'DESC'
+ descending = order == "DESC"
pieces = name.split(LOOKUP_SEP)
- field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias)
+ (
+ field,
+ targets,
+ alias,
+ joins,
+ path,
+ opts,
+ transform_function,
+ ) = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model unless it is the pk
# shortcut or the attribute name of the field that is specified.
if (
- field.is_relation and
- opts.ordering and
- getattr(field, 'attname', None) != pieces[-1] and
- name != 'pk'
+ field.is_relation
+ and opts.ordering
+ and getattr(field, "attname", None) != pieces[-1]
+ and name != "pk"
):
# Firstly, avoid infinite loops.
already_seen = already_seen or set()
- join_tuple = tuple(getattr(self.query.alias_map[j], 'join_cols', None) for j in joins)
+ join_tuple = tuple(
+ getattr(self.query.alias_map[j], "join_cols", None) for j in joins
+ )
if join_tuple in already_seen:
- raise FieldError('Infinite loop caused by ordering.')
+ raise FieldError("Infinite loop caused by ordering.")
already_seen.add(join_tuple)
results = []
for item in opts.ordering:
- if hasattr(item, 'resolve_expression') and not isinstance(item, OrderBy):
+ if hasattr(item, "resolve_expression") and not isinstance(
+ item, OrderBy
+ ):
item = item.desc() if descending else item.asc()
if isinstance(item, OrderBy):
results.append((item, False))
continue
- results.extend(self.find_ordering_name(item, opts, alias,
- order, already_seen))
+ results.extend(
+ self.find_ordering_name(item, opts, alias, order, already_seen)
+ )
return results
targets, alias, _ = self.query.trim_joins(targets, joins, path)
- return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets]
+ return [
+ (OrderBy(transform_function(t, alias), descending=descending), False)
+ for t in targets
+ ]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -816,7 +934,9 @@ class SQLCompiler:
match. Executing SQL where this is not true is an error.
"""
alias = alias or self.query.get_initial_alias()
- field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias)
+ field, targets, opts, joins, path, transform_function = self.query.setup_joins(
+ pieces, opts, alias
+ )
alias = joins[-1]
return field, targets, alias, joins, path, opts, transform_function
@@ -850,25 +970,39 @@ class SQLCompiler:
# Only add the alias if it's not already present (the table_alias()
# call increments the refcount, so an alias refcount of one means
# this is the only reference).
- if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1:
- result.append(', %s' % self.quote_name_unless_alias(alias))
+ if (
+ alias not in self.query.alias_map
+ or self.query.alias_refcount[alias] == 1
+ ):
+ result.append(", %s" % self.quote_name_unless_alias(alias))
return result, params
- def get_related_selections(self, select, opts=None, root_alias=None, cur_depth=1,
- requested=None, restricted=None):
+ def get_related_selections(
+ self,
+ select,
+ opts=None,
+ root_alias=None,
+ cur_depth=1,
+ requested=None,
+ restricted=None,
+ ):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
(for example, cur_depth=1 means we are looking at models with direct
connections to the root model).
"""
+
def _get_field_choices():
direct_choices = (f.name for f in opts.fields if f.is_relation)
reverse_choices = (
f.field.related_query_name()
- for f in opts.related_objects if f.field.unique
+ for f in opts.related_objects
+ if f.field.unique
+ )
+ return chain(
+ direct_choices, reverse_choices, self.query._filtered_relations
)
- return chain(direct_choices, reverse_choices, self.query._filtered_relations)
related_klass_infos = []
if not restricted and cur_depth > self.query.max_depth:
@@ -889,7 +1023,7 @@ class SQLCompiler:
requested = self.query.select_related
def get_related_klass_infos(klass_info, related_klass_infos):
- klass_info['related_klass_infos'] = related_klass_infos
+ klass_info["related_klass_infos"] = related_klass_infos
for f in opts.fields:
field_model = f.model._meta.concrete_model
@@ -903,37 +1037,48 @@ class SQLCompiler:
if next or f.name in requested:
raise FieldError(
"Non-relational field given in select_related: '%s'. "
- "Choices are: %s" % (
+ "Choices are: %s"
+ % (
f.name,
- ", ".join(_get_field_choices()) or '(none)',
+ ", ".join(_get_field_choices()) or "(none)",
)
)
else:
next = False
- if not select_related_descend(f, restricted, requested,
- only_load.get(field_model)):
+ if not select_related_descend(
+ f, restricted, requested, only_load.get(field_model)
+ ):
continue
klass_info = {
- 'model': f.remote_field.model,
- 'field': f,
- 'reverse': False,
- 'local_setter': f.set_cached_value,
- 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
- 'from_parent': False,
+ "model": f.remote_field.model,
+ "field": f,
+ "reverse": False,
+ "local_setter": f.set_cached_value,
+ "remote_setter": f.remote_field.set_cached_value
+ if f.unique
+ else lambda x, y: None,
+ "from_parent": False,
}
related_klass_infos.append(klass_info)
select_fields = []
- _, _, _, joins, _, _ = self.query.setup_joins(
- [f.name], opts, root_alias)
+ _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
alias = joins[-1]
- columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta)
+ columns = self.get_default_columns(
+ start_alias=alias, opts=f.remote_field.model._meta
+ )
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next_klass_infos = self.get_related_selections(
- select, f.remote_field.model._meta, alias, cur_depth + 1, next, restricted)
+ select,
+ f.remote_field.model._meta,
+ alias,
+ cur_depth + 1,
+ next,
+ restricted,
+ )
get_related_klass_infos(klass_info, next_klass_infos)
if restricted:
@@ -943,36 +1088,40 @@ class SQLCompiler:
if o.field.unique and not o.many_to_many
]
for f, model in related_fields:
- if not select_related_descend(f, restricted, requested,
- only_load.get(model), reverse=True):
+ if not select_related_descend(
+ f, restricted, requested, only_load.get(model), reverse=True
+ ):
continue
related_field_name = f.related_query_name()
fields_found.add(related_field_name)
- join_info = self.query.setup_joins([related_field_name], opts, root_alias)
+ join_info = self.query.setup_joins(
+ [related_field_name], opts, root_alias
+ )
alias = join_info.joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
klass_info = {
- 'model': model,
- 'field': f,
- 'reverse': True,
- 'local_setter': f.remote_field.set_cached_value,
- 'remote_setter': f.set_cached_value,
- 'from_parent': from_parent,
+ "model": model,
+ "field": f,
+ "reverse": True,
+ "local_setter": f.remote_field.set_cached_value,
+ "remote_setter": f.set_cached_value,
+ "from_parent": from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta, from_parent=opts.model)
+ start_alias=alias, opts=model._meta, from_parent=opts.model
+ )
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next = requested.get(f.related_query_name(), {})
next_klass_infos = self.get_related_selections(
- select, model._meta, alias, cur_depth + 1,
- next, restricted)
+ select, model._meta, alias, cur_depth + 1, next, restricted
+ )
get_related_klass_infos(klass_info, next_klass_infos)
def local_setter(obj, from_obj):
@@ -989,32 +1138,40 @@ class SQLCompiler:
break
if name in self.query._filtered_relations:
fields_found.add(name)
- f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias)
+ f, _, join_opts, joins, _, _ = self.query.setup_joins(
+ [name], opts, root_alias
+ )
model = join_opts.model
alias = joins[-1]
- from_parent = issubclass(model, opts.model) and model is not opts.model
+ from_parent = (
+ issubclass(model, opts.model) and model is not opts.model
+ )
klass_info = {
- 'model': model,
- 'field': f,
- 'reverse': True,
- 'local_setter': local_setter,
- 'remote_setter': partial(remote_setter, name),
- 'from_parent': from_parent,
+ "model": model,
+ "field": f,
+ "reverse": True,
+ "local_setter": local_setter,
+ "remote_setter": partial(remote_setter, name),
+ "from_parent": from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta,
+ start_alias=alias,
+ opts=model._meta,
from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
- select, opts=model._meta, root_alias=alias,
- cur_depth=cur_depth + 1, requested=next_requested,
+ select,
+ opts=model._meta,
+ root_alias=alias,
+ cur_depth=cur_depth + 1,
+ requested=next_requested,
restricted=restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
@@ -1022,10 +1179,11 @@ class SQLCompiler:
if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError(
- 'Invalid field name(s) given in select_related: %s. '
- 'Choices are: %s' % (
- ', '.join(invalid_fields),
- ', '.join(_get_field_choices()) or '(none)',
+ "Invalid field name(s) given in select_related: %s. "
+ "Choices are: %s"
+ % (
+ ", ".join(invalid_fields),
+ ", ".join(_get_field_choices()) or "(none)",
)
)
return related_klass_infos
@@ -1035,21 +1193,22 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
+
def _get_parent_klass_info(klass_info):
- concrete_model = klass_info['model']._meta.concrete_model
+ concrete_model = klass_info["model"]._meta.concrete_model
for parent_model, parent_link in concrete_model._meta.parents.items():
parent_list = parent_model._meta.get_parent_list()
yield {
- 'model': parent_model,
- 'field': parent_link,
- 'reverse': False,
- 'select_fields': [
+ "model": parent_model,
+ "field": parent_link,
+ "reverse": False,
+ "select_fields": [
select_index
- for select_index in klass_info['select_fields']
+ for select_index in klass_info["select_fields"]
# Selected columns from a model or its parents.
if (
- self.select[select_index][0].target.model == parent_model or
- self.select[select_index][0].target.model in parent_list
+ self.select[select_index][0].target.model == parent_model
+ or self.select[select_index][0].target.model in parent_list
)
],
}
@@ -1062,8 +1221,8 @@ class SQLCompiler:
select_fields is filled recursively, so it also contains fields
from the parent models.
"""
- concrete_model = klass_info['model']._meta.concrete_model
- for select_index in klass_info['select_fields']:
+ concrete_model = klass_info["model"]._meta.concrete_model
+ for select_index in klass_info["select_fields"]:
if self.select[select_index][0].target.model == concrete_model:
return self.select[select_index][0]
@@ -1074,10 +1233,10 @@ class SQLCompiler:
parent_path, klass_info = queue.popleft()
if parent_path is None:
path = []
- yield 'self'
+ yield "self"
else:
- field = klass_info['field']
- if klass_info['reverse']:
+ field = klass_info["field"]
+ if klass_info["reverse"]:
field = field.remote_field
path = parent_path + [field.name]
yield LOOKUP_SEP.join(path)
@@ -1087,25 +1246,26 @@ class SQLCompiler:
)
queue.extend(
(path, klass_info)
- for klass_info in klass_info.get('related_klass_infos', [])
+ for klass_info in klass_info.get("related_klass_infos", [])
)
+
if not self.klass_info:
return []
result = []
invalid_names = []
for name in self.query.select_for_update_of:
klass_info = self.klass_info
- if name == 'self':
+ if name == "self":
col = _get_first_selected_col_from_model(klass_info)
else:
for part in name.split(LOOKUP_SEP):
klass_infos = (
- *klass_info.get('related_klass_infos', []),
+ *klass_info.get("related_klass_infos", []),
*_get_parent_klass_info(klass_info),
)
for related_klass_info in klass_infos:
- field = related_klass_info['field']
- if related_klass_info['reverse']:
+ field = related_klass_info["field"]
+ if related_klass_info["reverse"]:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
@@ -1124,11 +1284,12 @@ class SQLCompiler:
result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
- 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
- 'Only relational fields followed in the query are allowed. '
- 'Choices are: %s.' % (
- ', '.join(invalid_names),
- ', '.join(_get_field_choices()),
+ "Invalid field name(s) given in select_for_update(of=(...)): %s. "
+ "Only relational fields followed in the query are allowed. "
+ "Choices are: %s."
+ % (
+ ", ".join(invalid_names),
+ ", ".join(_get_field_choices()),
)
)
return result
@@ -1164,12 +1325,19 @@ class SQLCompiler:
row[pos] = value
yield row
- def results_iter(self, results=None, tuple_expected=False, chunked_fetch=False,
- chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def results_iter(
+ self,
+ results=None,
+ tuple_expected=False,
+ chunked_fetch=False,
+ chunk_size=GET_ITERATOR_CHUNK_SIZE,
+ ):
"""Return an iterator over the results from executing this query."""
if results is None:
- results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
- fields = [s[0] for s in self.select[0:self.col_count]]
+ results = self.execute_sql(
+ MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size
+ )
+ fields = [s[0] for s in self.select[0 : self.col_count]]
converters = self.get_converters(fields)
rows = chain.from_iterable(results)
if converters:
@@ -1185,7 +1353,9 @@ class SQLCompiler:
"""
return bool(self.execute_sql(SINGLE))
- def execute_sql(self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def execute_sql(
+ self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
+ ):
"""
Run the query against the database and return the result(s). The
return value is a single data item if result_type is SINGLE, or an
@@ -1226,7 +1396,7 @@ class SQLCompiler:
try:
val = cursor.fetchone()
if val:
- return val[0:self.col_count]
+ return val[0 : self.col_count]
return val
finally:
# done with the cursor
@@ -1236,7 +1406,8 @@ class SQLCompiler:
return
result = cursor_iter(
- cursor, self.connection.features.empty_fetchmany_value,
+ cursor,
+ self.connection.features.empty_fetchmany_value,
self.col_count if self.has_extra_select else None,
chunk_size,
)
@@ -1254,21 +1425,22 @@ class SQLCompiler:
for index, select_col in enumerate(self.query.select):
lhs_sql, lhs_params = self.compile(select_col)
- rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
- self.query.where.add(
- RawSQL('%s = %s' % (lhs_sql, rhs), lhs_params), 'AND')
+ rhs = "%s.%s" % (qn(alias), qn2(columns[index]))
+ self.query.where.add(RawSQL("%s = %s" % (lhs_sql, rhs), lhs_params), "AND")
sql, params = self.as_sql()
- return 'EXISTS (%s)' % sql, params
+ return "EXISTS (%s)" % sql, params
def explain_query(self):
result = list(self.execute_sql())
# Some backends return 1 item tuples with strings, and others return
# tuples with integers and strings. Flatten them out into strings.
- output_formatter = json.dumps if self.query.explain_info.format == 'json' else str
+ output_formatter = (
+ json.dumps if self.query.explain_info.format == "json" else str
+ )
for row in result[0]:
if not isinstance(row, str):
- yield ' '.join(output_formatter(c) for c in row)
+ yield " ".join(output_formatter(c) for c in row)
else:
yield row
@@ -1289,16 +1461,16 @@ class SQLInsertCompiler(SQLCompiler):
if field is None:
# A field value of None means the value is raw.
sql, params = val, []
- elif hasattr(val, 'as_sql'):
+ elif hasattr(val, "as_sql"):
# This is an expression, let's compile it.
sql, params = self.compile(val)
- elif hasattr(field, 'get_placeholder'):
+ elif hasattr(field, "get_placeholder"):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
sql, params = field.get_placeholder(val, self, self.connection), [val]
else:
# Return the common case for the placeholder
- sql, params = '%s', [val]
+ sql, params = "%s", [val]
# The following hook is only used by Oracle Spatial, which sometimes
# needs to yield 'NULL' and [] as its placeholder and params instead
@@ -1314,24 +1486,26 @@ class SQLInsertCompiler(SQLCompiler):
Prepare a value to be used in a query by resolving it if it is an
expression and otherwise calling the field's get_db_prep_save().
"""
- if hasattr(value, 'resolve_expression'):
- value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
+ if hasattr(value, "resolve_expression"):
+ value = value.resolve_expression(
+ self.query, allow_joins=False, for_save=True
+ )
# Don't allow values containing Col expressions. They refer to
# existing columns on a row, but in the case of insert the row
# doesn't exist yet.
if value.contains_column_references:
raise ValueError(
'Failed to insert expression "%s" on %s. F() expressions '
- 'can only be used to update, not to insert.' % (value, field)
+ "can only be used to update, not to insert." % (value, field)
)
if value.contains_aggregate:
raise FieldError(
- 'Aggregate functions are not allowed in this query '
- '(%s=%r).' % (field.name, value)
+ "Aggregate functions are not allowed in this query "
+ "(%s=%r)." % (field.name, value)
)
if value.contains_over_clause:
raise FieldError(
- 'Window expressions are not allowed in this query (%s=%r).'
+ "Window expressions are not allowed in this query (%s=%r)."
% (field.name, value)
)
else:
@@ -1390,25 +1564,32 @@ class SQLInsertCompiler(SQLCompiler):
insert_statement = self.connection.ops.insert_statement(
on_conflict=self.query.on_conflict,
)
- result = ['%s %s' % (insert_statement, qn(opts.db_table))]
+ result = ["%s %s" % (insert_statement, qn(opts.db_table))]
fields = self.query.fields or [opts.pk]
- result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
+ result.append("(%s)" % ", ".join(qn(f.column) for f in fields))
if self.query.fields:
value_rows = [
- [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
+ [
+ self.prepare_value(field, self.pre_save_val(field, obj))
+ for field in fields
+ ]
for obj in self.query.objs
]
else:
# An empty object.
- value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
+ value_rows = [
+ [self.connection.ops.pk_default_value()] for _ in self.query.objs
+ ]
fields = [None]
# Currently the backends just accept values when generating bulk
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
- can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert)
+ can_bulk = (
+ not self.returning_fields and self.connection.features.has_bulk_insert
+ )
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
@@ -1418,9 +1599,14 @@ class SQLInsertCompiler(SQLCompiler):
self.query.update_fields,
self.query.unique_fields,
)
- if self.returning_fields and self.connection.features.can_return_columns_from_insert:
+ if (
+ self.returning_fields
+ and self.connection.features.can_return_columns_from_insert
+ ):
if self.connection.features.can_return_rows_from_bulk_insert:
- result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
+ result.append(
+ self.connection.ops.bulk_insert_sql(fields, placeholder_rows)
+ )
params = param_rows
else:
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
@@ -1429,7 +1615,9 @@ class SQLInsertCompiler(SQLCompiler):
result.append(on_conflict_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
- r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
+ r_sql, self.returning_params = self.connection.ops.return_insert_columns(
+ self.returning_fields
+ )
if r_sql:
result.append(r_sql)
params += [self.returning_params]
@@ -1450,8 +1638,9 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, returning_fields=None):
assert not (
- returning_fields and len(self.query.objs) != 1 and
- not self.connection.features.can_return_rows_from_bulk_insert
+ returning_fields
+ and len(self.query.objs) != 1
+ and not self.connection.features.can_return_rows_from_bulk_insert
)
opts = self.query.get_meta()
self.returning_fields = returning_fields
@@ -1460,17 +1649,29 @@ class SQLInsertCompiler(SQLCompiler):
cursor.execute(sql, params)
if not self.returning_fields:
return []
- if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
+ if (
+ self.connection.features.can_return_rows_from_bulk_insert
+ and len(self.query.objs) > 1
+ ):
rows = self.connection.ops.fetch_returned_insert_rows(cursor)
elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1
- rows = [self.connection.ops.fetch_returned_insert_columns(
- cursor, self.returning_params,
- )]
+ rows = [
+ self.connection.ops.fetch_returned_insert_columns(
+ cursor,
+ self.returning_params,
+ )
+ ]
else:
- rows = [(self.connection.ops.last_insert_id(
- cursor, opts.db_table, opts.pk.column,
- ),)]
+ rows = [
+ (
+ self.connection.ops.last_insert_id(
+ cursor,
+ opts.db_table,
+ opts.pk.column,
+ ),
+ )
+ ]
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
converters = self.get_converters(cols)
if converters:
@@ -1489,7 +1690,7 @@ class SQLDeleteCompiler(SQLCompiler):
def _expr_refs_base_model(cls, expr, base_model):
if isinstance(expr, Query):
return expr.model == base_model
- if not hasattr(expr, 'get_source_expressions'):
+ if not hasattr(expr, "get_source_expressions"):
return False
return any(
cls._expr_refs_base_model(source_expr, base_model)
@@ -1500,17 +1701,17 @@ class SQLDeleteCompiler(SQLCompiler):
def contains_self_reference_subquery(self):
return any(
self._expr_refs_base_model(expr, self.query.model)
- for expr in chain(self.query.annotations.values(), self.query.where.children)
+ for expr in chain(
+ self.query.annotations.values(), self.query.where.children
+ )
)
def _as_sql(self, query):
- result = [
- 'DELETE FROM %s' % self.quote_name_unless_alias(query.base_table)
- ]
+ result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
where, params = self.compile(query.where)
if where:
- result.append('WHERE %s' % where)
- return ' '.join(result), tuple(params)
+ result.append("WHERE %s" % where)
+ return " ".join(result), tuple(params)
def as_sql(self):
"""
@@ -1523,16 +1724,14 @@ class SQLDeleteCompiler(SQLCompiler):
innerq.__class__ = Query
innerq.clear_select_clause()
pk = self.query.model._meta.pk
- innerq.select = [
- pk.get_col(self.query.get_initial_alias())
- ]
+ innerq.select = [pk.get_col(self.query.get_initial_alias())]
outerq = Query(self.query.model)
if not self.connection.features.update_can_self_select:
# Force the materialization of the inner query to allow reference
# to the target table on MySQL.
sql, params = innerq.get_compiler(connection=self.connection).as_sql()
- innerq = RawSQL('SELECT * FROM (%s) subquery' % sql, params)
- outerq.add_filter('pk__in', innerq)
+ innerq = RawSQL("SELECT * FROM (%s) subquery" % sql, params)
+ outerq.add_filter("pk__in", innerq)
return self._as_sql(outerq)
@@ -1544,23 +1743,25 @@ class SQLUpdateCompiler(SQLCompiler):
"""
self.pre_sql_setup()
if not self.query.values:
- return '', ()
+ return "", ()
qn = self.quote_name_unless_alias
values, update_params = [], []
for field, model, val in self.query.values:
- if hasattr(val, 'resolve_expression'):
- val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
+ if hasattr(val, "resolve_expression"):
+ val = val.resolve_expression(
+ self.query, allow_joins=False, for_save=True
+ )
if val.contains_aggregate:
raise FieldError(
- 'Aggregate functions are not allowed in this query '
- '(%s=%r).' % (field.name, val)
+ "Aggregate functions are not allowed in this query "
+ "(%s=%r)." % (field.name, val)
)
if val.contains_over_clause:
raise FieldError(
- 'Window expressions are not allowed in this query '
- '(%s=%r).' % (field.name, val)
+ "Window expressions are not allowed in this query "
+ "(%s=%r)." % (field.name, val)
)
- elif hasattr(val, 'prepare_database_save'):
+ elif hasattr(val, "prepare_database_save"):
if field.remote_field:
val = field.get_db_prep_save(
val.prepare_database_save(field),
@@ -1576,29 +1777,29 @@ class SQLUpdateCompiler(SQLCompiler):
val = field.get_db_prep_save(val, connection=self.connection)
# Getting the placeholder for the field.
- if hasattr(field, 'get_placeholder'):
+ if hasattr(field, "get_placeholder"):
placeholder = field.get_placeholder(val, self, self.connection)
else:
- placeholder = '%s'
+ placeholder = "%s"
name = field.column
- if hasattr(val, 'as_sql'):
+ if hasattr(val, "as_sql"):
sql, params = self.compile(val)
- values.append('%s = %s' % (qn(name), placeholder % sql))
+ values.append("%s = %s" % (qn(name), placeholder % sql))
update_params.extend(params)
elif val is not None:
- values.append('%s = %s' % (qn(name), placeholder))
+ values.append("%s = %s" % (qn(name), placeholder))
update_params.append(val)
else:
- values.append('%s = NULL' % qn(name))
+ values.append("%s = NULL" % qn(name))
table = self.query.base_table
result = [
- 'UPDATE %s SET' % qn(table),
- ', '.join(values),
+ "UPDATE %s SET" % qn(table),
+ ", ".join(values),
]
where, params = self.compile(self.query.where)
if where:
- result.append('WHERE %s' % where)
- return ' '.join(result), tuple(update_params + params)
+ result.append("WHERE %s" % where)
+ return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type):
"""
@@ -1644,7 +1845,9 @@ class SQLUpdateCompiler(SQLCompiler):
query.add_fields([query.get_meta().pk.name])
super().pre_sql_setup()
- must_pre_select = count > 1 and not self.connection.features.update_can_self_select
+ must_pre_select = (
+ count > 1 and not self.connection.features.update_can_self_select
+ )
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
@@ -1656,11 +1859,11 @@ class SQLUpdateCompiler(SQLCompiler):
idents = []
for rows in query.get_compiler(self.using).execute_sql(MULTI):
idents.extend(r[0] for r in rows)
- self.query.add_filter('pk__in', idents)
+ self.query.add_filter("pk__in", idents)
self.query.related_ids = idents
else:
# The fast path. Filters and updates in one query.
- self.query.add_filter('pk__in', query)
+ self.query.add_filter("pk__in", query)
self.query.reset_refcounts(refcounts_before)
@@ -1677,13 +1880,14 @@ class SQLAggregateCompiler(SQLCompiler):
sql.append(ann_sql)
params.extend(ann_params)
self.col_count = len(self.query.annotation_select)
- sql = ', '.join(sql)
+ sql = ", ".join(sql)
params = tuple(params)
inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
- self.using, elide_empty=self.elide_empty,
+ self.using,
+ elide_empty=self.elide_empty,
).as_sql(with_col_aliases=True)
- sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
+ sql = "SELECT %s FROM (%s) subquery" % (sql, inner_query_sql)
params = params + inner_query_params
return sql, params
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
index a1db61b9ff..fdfb2ea891 100644
--- a/django/db/models/sql/constants.py
+++ b/django/db/models/sql/constants.py
@@ -9,16 +9,16 @@ GET_ITERATOR_CHUNK_SIZE = 100
# Namedtuples for sql.* internal use.
# How many results to expect from a cursor.execute call
-MULTI = 'multi'
-SINGLE = 'single'
-CURSOR = 'cursor'
-NO_RESULTS = 'no results'
+MULTI = "multi"
+SINGLE = "single"
+CURSOR = "cursor"
+NO_RESULTS = "no results"
ORDER_DIR = {
- 'ASC': ('ASC', 'DESC'),
- 'DESC': ('DESC', 'ASC'),
+ "ASC": ("ASC", "DESC"),
+ "DESC": ("DESC", "ASC"),
}
# SQL join types.
-INNER = 'INNER JOIN'
-LOUTER = 'LEFT OUTER JOIN'
+INNER = "INNER JOIN"
+LOUTER = "LEFT OUTER JOIN"
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index e08b570350..f398074bf7 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -11,6 +11,7 @@ class MultiJoin(Exception):
multi-valued join was attempted (if the caller wants to treat that
exceptionally).
"""
+
def __init__(self, names_pos, path_with_names):
self.level = names_pos
# The path travelled, this includes the path to the multijoin.
@@ -38,8 +39,17 @@ class Join:
- as_sql()
- relabeled_clone()
"""
- def __init__(self, table_name, parent_alias, table_alias, join_type,
- join_field, nullable, filtered_relation=None):
+
+ def __init__(
+ self,
+ table_name,
+ parent_alias,
+ table_alias,
+ join_type,
+ join_field,
+ nullable,
+ filtered_relation=None,
+ ):
# Join table
self.table_name = table_name
self.parent_alias = parent_alias
@@ -69,35 +79,47 @@ class Join:
# Add a join condition for each pair of joining columns.
for lhs_col, rhs_col in self.join_cols:
- join_conditions.append('%s.%s = %s.%s' % (
- qn(self.parent_alias),
- qn2(lhs_col),
- qn(self.table_alias),
- qn2(rhs_col),
- ))
+ join_conditions.append(
+ "%s.%s = %s.%s"
+ % (
+ qn(self.parent_alias),
+ qn2(lhs_col),
+ qn(self.table_alias),
+ qn2(rhs_col),
+ )
+ )
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.
- extra_cond = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
+ extra_cond = self.join_field.get_extra_restriction(
+ self.table_alias, self.parent_alias
+ )
if extra_cond:
extra_sql, extra_params = compiler.compile(extra_cond)
- join_conditions.append('(%s)' % extra_sql)
+ join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
- join_conditions.append('(%s)' % extra_sql)
+ join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if not join_conditions:
# This might be a rel on the other end of an actual declared field.
- declared_field = getattr(self.join_field, 'field', self.join_field)
+ declared_field = getattr(self.join_field, "field", self.join_field)
raise ValueError(
"Join generated an empty ON clause. %s did not yield either "
"joining columns or extra restrictions." % declared_field.__class__
)
- on_clause_sql = ' AND '.join(join_conditions)
- alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
- sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql)
+ on_clause_sql = " AND ".join(join_conditions)
+ alias_str = (
+ "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
+ )
+ sql = "%s %s%s ON (%s)" % (
+ self.join_type,
+ qn(self.table_name),
+ alias_str,
+ on_clause_sql,
+ )
return sql, params
def relabeled_clone(self, change_map):
@@ -105,12 +127,19 @@ class Join:
new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
- filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
+ filtered_relation.path = [
+ change_map.get(p, p) for p in self.filtered_relation.path
+ ]
else:
filtered_relation = None
return self.__class__(
- self.table_name, new_parent_alias, new_table_alias, self.join_type,
- self.join_field, self.nullable, filtered_relation=filtered_relation,
+ self.table_name,
+ new_parent_alias,
+ new_table_alias,
+ self.join_type,
+ self.join_field,
+ self.nullable,
+ filtered_relation=filtered_relation,
)
@property
@@ -153,6 +182,7 @@ class BaseTable:
SELECT * FROM "foo" WHERE somecond
could be generated by this class.
"""
+
join_type = None
parent_alias = None
filtered_relation = None
@@ -162,12 +192,16 @@ class BaseTable:
self.table_alias = alias
def as_sql(self, compiler, connection):
- alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
+ alias_str = (
+ "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
+ )
base_sql = compiler.quote_name_unless_alias(self.table_name)
return base_sql + alias_str, []
def relabeled_clone(self, change_map):
- return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
+ return self.__class__(
+ self.table_name, change_map.get(self.table_alias, self.table_alias)
+ )
@property
def identity(self):
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 1dc770ae3a..242b2a1f3f 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -20,32 +20,37 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import (
- BaseExpression, Col, Exists, F, OuterRef, Ref, ResolvedOuterRef,
+ BaseExpression,
+ Col,
+ Exists,
+ F,
+ OuterRef,
+ Ref,
+ ResolvedOuterRef,
)
from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
- Q, check_rel_lookup_compatibility, refs_expression,
+ Q,
+ check_rel_lookup_compatibility,
+ refs_expression,
)
from django.db.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
-from django.db.models.sql.datastructures import (
- BaseTable, Empty, Join, MultiJoin,
-)
-from django.db.models.sql.where import (
- AND, OR, ExtraWhere, NothingNode, WhereNode,
-)
+from django.db.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin
+from django.db.models.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
from django.utils.functional import cached_property
from django.utils.tree import Node
-__all__ = ['Query', 'RawQuery']
+__all__ = ["Query", "RawQuery"]
def get_field_names_from_opts(opts):
- return set(chain.from_iterable(
- (f.name, f.attname) if f.concrete else (f.name,)
- for f in opts.get_fields()
- ))
+ return set(
+ chain.from_iterable(
+ (f.name, f.attname) if f.concrete else (f.name,) for f in opts.get_fields()
+ )
+ )
def get_children_from_q(q):
@@ -57,8 +62,8 @@ def get_children_from_q(q):
JoinInfo = namedtuple(
- 'JoinInfo',
- ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
+ "JoinInfo",
+ ("final_field", "targets", "opts", "joins", "path", "transform_function"),
)
@@ -87,8 +92,7 @@ class RawQuery:
if self.cursor is None:
self._execute_query()
converter = connections[self.using].introspection.identifier_converter
- return [converter(column_meta[0])
- for column_meta in self.cursor.description]
+ return [converter(column_meta[0]) for column_meta in self.cursor.description]
def __iter__(self):
# Always execute a new query for a new iterator.
@@ -136,17 +140,17 @@ class RawQuery:
self.cursor.execute(self.sql, params)
-ExplainInfo = namedtuple('ExplainInfo', ('format', 'options'))
+ExplainInfo = namedtuple("ExplainInfo", ("format", "options"))
class Query(BaseExpression):
"""A single SQL query."""
- alias_prefix = 'T'
+ alias_prefix = "T"
empty_result_set_value = None
subq_aliases = frozenset([alias_prefix])
- compiler = 'SQLCompiler'
+ compiler = "SQLCompiler"
base_table_class = BaseTable
join_class = Join
@@ -167,7 +171,7 @@ class Query(BaseExpression):
# aliases too.
# Map external tables to whether they are aliased.
self.external_aliases = {}
- self.table_map = {} # Maps table names to list of aliases.
+ self.table_map = {} # Maps table names to list of aliases.
self.default_cols = True
self.default_ordering = True
self.standard_ordering = True
@@ -240,13 +244,15 @@ class Query(BaseExpression):
def output_field(self):
if len(self.select) == 1:
select = self.select[0]
- return getattr(select, 'target', None) or select.field
+ return getattr(select, "target", None) or select.field
elif len(self.annotation_select) == 1:
return next(iter(self.annotation_select.values())).output_field
@property
def has_select_fields(self):
- return bool(self.select or self.annotation_select_mask or self.extra_select_mask)
+ return bool(
+ self.select or self.annotation_select_mask or self.extra_select_mask
+ )
@cached_property
def base_table(self):
@@ -282,7 +288,9 @@ class Query(BaseExpression):
raise ValueError("Need either using or connection")
if using:
connection = connections[using]
- return connection.ops.compiler(self.compiler)(self, connection, using, elide_empty)
+ return connection.ops.compiler(self.compiler)(
+ self, connection, using, elide_empty
+ )
def get_meta(self):
"""
@@ -311,9 +319,9 @@ class Query(BaseExpression):
if self.annotation_select_mask is not None:
obj.annotation_select_mask = self.annotation_select_mask.copy()
if self.combined_queries:
- obj.combined_queries = tuple([
- query.clone() for query in self.combined_queries
- ])
+ obj.combined_queries = tuple(
+ [query.clone() for query in self.combined_queries]
+ )
# _annotation_select_cache cannot be copied, as doing so breaks the
# (necessary) state in which both annotations and
# _annotation_select_cache point to the same underlying objects.
@@ -329,7 +337,7 @@ class Query(BaseExpression):
# Use deepcopy because select_related stores fields in nested
# dicts.
obj.select_related = copy.deepcopy(obj.select_related)
- if 'subq_aliases' in self.__dict__:
+ if "subq_aliases" in self.__dict__:
obj.subq_aliases = self.subq_aliases.copy()
obj.used_aliases = self.used_aliases.copy()
obj._filtered_relations = self._filtered_relations.copy()
@@ -351,7 +359,7 @@ class Query(BaseExpression):
if not obj.filter_is_sticky:
obj.used_aliases = set()
obj.filter_is_sticky = False
- if hasattr(obj, '_setup_query'):
+ if hasattr(obj, "_setup_query"):
obj._setup_query()
return obj
@@ -401,11 +409,13 @@ class Query(BaseExpression):
break
else:
# An expression that is not selected the subquery.
- if isinstance(expr, Col) or (expr.contains_aggregate and not expr.is_summary):
+ if isinstance(expr, Col) or (
+ expr.contains_aggregate and not expr.is_summary
+ ):
# Reference column or another aggregate. Select it
# under a non-conflicting alias.
col_cnt += 1
- col_alias = '__col%d' % col_cnt
+ col_alias = "__col%d" % col_cnt
self.annotations[col_alias] = expr
self.append_annotation_mask([col_alias])
new_expr = Ref(col_alias, expr)
@@ -424,8 +434,8 @@ class Query(BaseExpression):
if not self.annotation_select:
return {}
existing_annotations = [
- annotation for alias, annotation
- in self.annotations.items()
+ annotation
+ for alias, annotation in self.annotations.items()
if alias not in added_aggregate_names
]
# Decide if we need to use a subquery.
@@ -439,9 +449,15 @@ class Query(BaseExpression):
# those operations must be done in a subquery so that the query
# aggregates on the limit and/or distinct results instead of applying
# the distinct and limit after the aggregation.
- if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
- self.distinct or self.combinator):
+ if (
+ isinstance(self.group_by, tuple)
+ or self.is_sliced
+ or existing_annotations
+ or self.distinct
+ or self.combinator
+ ):
from django.db.models.sql.subqueries import AggregateQuery
+
inner_query = self.clone()
inner_query.subquery = True
outer_query = AggregateQuery(self.model, inner_query)
@@ -459,15 +475,18 @@ class Query(BaseExpression):
# clearing the select clause can alter results if distinct is
# used.
has_existing_aggregate_annotations = any(
- annotation for annotation in existing_annotations
- if getattr(annotation, 'contains_aggregate', True)
+ annotation
+ for annotation in existing_annotations
+ if getattr(annotation, "contains_aggregate", True)
)
if inner_query.default_cols and has_existing_aggregate_annotations:
- inner_query.group_by = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
+ inner_query.group_by = (
+ self.model._meta.pk.get_col(inner_query.get_initial_alias()),
+ )
inner_query.default_cols = False
- relabels = {t: 'subquery' for t in inner_query.alias_map}
- relabels[None] = 'subquery'
+ relabels = {t: "subquery" for t in inner_query.alias_map}
+ relabels[None] = "subquery"
# Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery.
col_cnt = 0
@@ -475,16 +494,24 @@ class Query(BaseExpression):
annotation_select_mask = inner_query.annotation_select_mask
if expression.is_summary:
expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt)
- outer_query.annotations[alias] = expression.relabeled_clone(relabels)
+ outer_query.annotations[alias] = expression.relabeled_clone(
+ relabels
+ )
del inner_query.annotations[alias]
annotation_select_mask.remove(alias)
# Make sure the annotation_select wont use cached results.
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
- if inner_query.select == () and not inner_query.default_cols and not inner_query.annotation_select_mask:
+ if (
+ inner_query.select == ()
+ and not inner_query.default_cols
+ and not inner_query.annotation_select_mask
+ ):
# In case of Model.objects[0:3].count(), there would be no
# field selected in the inner query, yet we must use a subquery.
# So, make sure at least one field is selected.
- inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
+ inner_query.select = (
+ self.model._meta.pk.get_col(inner_query.get_initial_alias()),
+ )
else:
outer_query = self
self.select = ()
@@ -515,8 +542,8 @@ class Query(BaseExpression):
Perform a COUNT() query using the current filter constraints.
"""
obj = self.clone()
- obj.add_annotation(Count('*'), alias='__count', is_summary=True)
- return obj.get_aggregation(using, ['__count'])['__count']
+ obj.add_annotation(Count("*"), alias="__count", is_summary=True)
+ return obj.get_aggregation(using, ["__count"])["__count"]
def has_filters(self):
return self.where
@@ -525,13 +552,17 @@ class Query(BaseExpression):
q = self.clone()
if not q.distinct:
if q.group_by is True:
- q.add_fields((f.attname for f in self.model._meta.concrete_fields), False)
+ q.add_fields(
+ (f.attname for f in self.model._meta.concrete_fields), False
+ )
# Disable GROUP BY aliases to avoid orphaning references to the
# SELECT clause which is about to be cleared.
q.set_group_by(allow_aliases=False)
q.clear_select_clause()
- if q.combined_queries and q.combinator == 'union':
- limit_combined = connections[using].features.supports_slicing_ordering_in_compound
+ if q.combined_queries and q.combinator == "union":
+ limit_combined = connections[
+ using
+ ].features.supports_slicing_ordering_in_compound
q.combined_queries = tuple(
combined_query.exists(using, limit=limit_combined)
for combined_query in q.combined_queries
@@ -539,8 +570,8 @@ class Query(BaseExpression):
q.clear_ordering(force=True)
if limit:
q.set_limits(high=1)
- q.add_extra({'a': 1}, None, None, None, None, None)
- q.set_extra_mask(['a'])
+ q.add_extra({"a": 1}, None, None, None, None, None)
+ q.set_extra_mask(["a"])
return q
def has_results(self, using):
@@ -552,7 +583,7 @@ class Query(BaseExpression):
q = self.clone()
q.explain_info = ExplainInfo(format, options)
compiler = q.get_compiler(using=using)
- return '\n'.join(compiler.explain_query())
+ return "\n".join(compiler.explain_query())
def combine(self, rhs, connector):
"""
@@ -564,13 +595,13 @@ class Query(BaseExpression):
'rhs' query.
"""
if self.model != rhs.model:
- raise TypeError('Cannot combine queries on two different base models.')
+ raise TypeError("Cannot combine queries on two different base models.")
if self.is_sliced:
- raise TypeError('Cannot combine queries once a slice has been taken.')
+ raise TypeError("Cannot combine queries once a slice has been taken.")
if self.distinct != rhs.distinct:
- raise TypeError('Cannot combine a unique query with a non-unique query.')
+ raise TypeError("Cannot combine a unique query with a non-unique query.")
if self.distinct_fields != rhs.distinct_fields:
- raise TypeError('Cannot combine queries with different distinct fields.')
+ raise TypeError("Cannot combine queries with different distinct fields.")
# If lhs and rhs shares the same alias prefix, it is possible to have
# conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
@@ -583,7 +614,7 @@ class Query(BaseExpression):
# Work out how to relabel the rhs aliases, if necessary.
change_map = {}
- conjunction = (connector == AND)
+ conjunction = connector == AND
# Determine which existing joins can be reused. When combining the
# query with AND we must recreate all joins for m2m filters. When
@@ -600,7 +631,8 @@ class Query(BaseExpression):
reuse = set() if conjunction else set(self.alias_map)
joinpromoter = JoinPromoter(connector, 2, False)
joinpromoter.add_votes(
- j for j in self.alias_map if self.alias_map[j].join_type == INNER)
+ j for j in self.alias_map if self.alias_map[j].join_type == INNER
+ )
rhs_votes = set()
# Now, add the joins from rhs query into the new query (skipping base
# table).
@@ -649,7 +681,9 @@ class Query(BaseExpression):
# really make sense (or return consistent value sets). Not worth
# the extra complexity when you can write a real query instead.
if self.extra and rhs.extra:
- raise ValueError("When merging querysets using 'or', you cannot have extra(select=...) on both sides.")
+ raise ValueError(
+ "When merging querysets using 'or', you cannot have extra(select=...) on both sides."
+ )
self.extra.update(rhs.extra)
extra_select_mask = set()
if self.extra_select_mask is not None:
@@ -767,11 +801,13 @@ class Query(BaseExpression):
# Create a new alias for this table.
if alias_list:
- alias = '%s%d' % (self.alias_prefix, len(self.alias_map) + 1)
+ alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1)
alias_list.append(alias)
else:
# The first occurrence of a table uses the table name directly.
- alias = filtered_relation.alias if filtered_relation is not None else table_name
+ alias = (
+ filtered_relation.alias if filtered_relation is not None else table_name
+ )
self.table_map[table_name] = [alias]
self.alias_refcount[alias] = 1
return alias, True
@@ -806,16 +842,19 @@ class Query(BaseExpression):
# Only the first alias (skipped above) should have None join_type
assert self.alias_map[alias].join_type is not None
parent_alias = self.alias_map[alias].parent_alias
- parent_louter = parent_alias and self.alias_map[parent_alias].join_type == LOUTER
+ parent_louter = (
+ parent_alias and self.alias_map[parent_alias].join_type == LOUTER
+ )
already_louter = self.alias_map[alias].join_type == LOUTER
- if ((self.alias_map[alias].nullable or parent_louter) and
- not already_louter):
+ if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
self.alias_map[alias] = self.alias_map[alias].promote()
# Join type of 'alias' changed, so re-examine all aliases that
# refer to this one.
aliases.extend(
- join for join in self.alias_map
- if self.alias_map[join].parent_alias == alias and join not in aliases
+ join
+ for join in self.alias_map
+ if self.alias_map[join].parent_alias == alias
+ and join not in aliases
)
def demote_joins(self, aliases):
@@ -861,10 +900,13 @@ class Query(BaseExpression):
# "group by" and "where".
self.where.relabel_aliases(change_map)
if isinstance(self.group_by, tuple):
- self.group_by = tuple([col.relabeled_clone(change_map) for col in self.group_by])
+ self.group_by = tuple(
+ [col.relabeled_clone(change_map) for col in self.group_by]
+ )
self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
self.annotations = self.annotations and {
- key: col.relabeled_clone(change_map) for key, col in self.annotations.items()
+ key: col.relabeled_clone(change_map)
+ for key, col in self.annotations.items()
}
# 2. Rename the alias in the internal table/alias datastructures.
@@ -895,6 +937,7 @@ class Query(BaseExpression):
conflict. Even tables that previously had no alias will get an alias
after this call. To prevent changing aliases use the exclude parameter.
"""
+
def prefix_gen():
"""
Generate a sequence of characters in alphabetical order:
@@ -908,9 +951,9 @@ class Query(BaseExpression):
prefix = chr(ord(self.alias_prefix) + 1)
yield prefix
for n in count(1):
- seq = alphabet[alphabet.index(prefix):] if prefix else alphabet
+ seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
for s in product(seq, repeat=n):
- yield ''.join(s)
+ yield "".join(s)
prefix = None
if self.alias_prefix != other_query.alias_prefix:
@@ -928,17 +971,19 @@ class Query(BaseExpression):
break
if pos > local_recursion_limit:
raise RecursionError(
- 'Maximum recursion depth exceeded: too many subqueries.'
+ "Maximum recursion depth exceeded: too many subqueries."
)
self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
if exclude is None:
exclude = {}
- self.change_aliases({
- alias: '%s%d' % (self.alias_prefix, pos)
- for pos, alias in enumerate(self.alias_map)
- if alias not in exclude
- })
+ self.change_aliases(
+ {
+ alias: "%s%d" % (self.alias_prefix, pos)
+ for pos, alias in enumerate(self.alias_map)
+ if alias not in exclude
+ }
+ )
def get_initial_alias(self):
"""
@@ -974,7 +1019,8 @@ class Query(BaseExpression):
joins are created as LOUTER if the join is nullable.
"""
reuse_aliases = [
- a for a, j in self.alias_map.items()
+ a
+ for a, j in self.alias_map.items()
if (reuse is None or a in reuse) and j.equals(join)
]
if reuse_aliases:
@@ -988,7 +1034,9 @@ class Query(BaseExpression):
return reuse_alias
# No reuse is possible, so we need a new alias.
- alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation)
+ alias, _ = self.table_alias(
+ join.table_name, create=True, filtered_relation=join.filtered_relation
+ )
if join.join_type:
if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
join_type = LOUTER
@@ -1034,8 +1082,9 @@ class Query(BaseExpression):
def add_annotation(self, annotation, alias, is_summary=False, select=True):
"""Add a single annotation expression to the Query."""
- annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None,
- summarize=is_summary)
+ annotation = annotation.resolve_expression(
+ self, allow_joins=True, reuse=None, summarize=is_summary
+ )
if select:
self.append_annotation_mask([alias])
else:
@@ -1050,27 +1099,32 @@ class Query(BaseExpression):
clone.where.resolve_expression(query, *args, **kwargs)
# Resolve combined queries.
if clone.combinator:
- clone.combined_queries = tuple([
- combined_query.resolve_expression(query, *args, **kwargs)
- for combined_query in clone.combined_queries
- ])
+ clone.combined_queries = tuple(
+ [
+ combined_query.resolve_expression(query, *args, **kwargs)
+ for combined_query in clone.combined_queries
+ ]
+ )
for key, value in clone.annotations.items():
resolved = value.resolve_expression(query, *args, **kwargs)
- if hasattr(resolved, 'external_aliases'):
+ if hasattr(resolved, "external_aliases"):
resolved.external_aliases.update(clone.external_aliases)
clone.annotations[key] = resolved
# Outer query's aliases are considered external.
for alias, table in query.alias_map.items():
clone.external_aliases[alias] = (
- (isinstance(table, Join) and table.join_field.related_model._meta.db_table != alias) or
- (isinstance(table, BaseTable) and table.table_name != table.table_alias)
+ isinstance(table, Join)
+ and table.join_field.related_model._meta.db_table != alias
+ ) or (
+ isinstance(table, BaseTable) and table.table_name != table.table_alias
)
return clone
def get_external_cols(self):
exprs = chain(self.annotations.values(), self.where.children)
return [
- col for col in self._gen_cols(exprs, include_external=True)
+ col
+ for col in self._gen_cols(exprs, include_external=True)
if col.alias in self.external_aliases
]
@@ -1086,19 +1140,21 @@ class Query(BaseExpression):
# Some backends (e.g. Oracle) raise an error when a subquery contains
# unnecessary ORDER BY clause.
if (
- self.subquery and
- not connection.features.ignores_unnecessary_order_by_in_subqueries
+ self.subquery
+ and not connection.features.ignores_unnecessary_order_by_in_subqueries
):
self.clear_ordering(force=False)
sql, params = self.get_compiler(connection=connection).as_sql()
if self.subquery:
- sql = '(%s)' % sql
+ sql = "(%s)" % sql
return sql, params
def resolve_lookup_value(self, value, can_reuse, allow_joins):
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
value = value.resolve_expression(
- self, reuse=can_reuse, allow_joins=allow_joins,
+ self,
+ reuse=can_reuse,
+ allow_joins=allow_joins,
)
elif isinstance(value, (list, tuple)):
# The items of the iterable may be expressions and therefore need
@@ -1108,7 +1164,7 @@ class Query(BaseExpression):
for sub_value in value
)
type_ = type(value)
- if hasattr(type_, '_make'): # namedtuple
+ if hasattr(type_, "_make"): # namedtuple
return type_(*values)
return type_(values)
return value
@@ -1119,15 +1175,17 @@ class Query(BaseExpression):
"""
lookup_splitted = lookup.split(LOOKUP_SEP)
if self.annotations:
- expression, expression_lookups = refs_expression(lookup_splitted, self.annotations)
+ expression, expression_lookups = refs_expression(
+ lookup_splitted, self.annotations
+ )
if expression:
return expression_lookups, (), expression
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
- field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
+ field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
if len(lookup_parts) > 1 and not field_parts:
raise FieldError(
- 'Invalid lookup "%s" for model %s".' %
- (lookup, self.get_meta().model.__name__)
+ 'Invalid lookup "%s" for model %s".'
+ % (lookup, self.get_meta().model.__name__)
)
return lookup_parts, field_parts, False
@@ -1136,11 +1194,12 @@ class Query(BaseExpression):
Check whether the object passed while querying is of the correct type.
If not, raise a ValueError specifying the wrong object.
"""
- if hasattr(value, '_meta'):
+ if hasattr(value, "_meta"):
if not check_rel_lookup_compatibility(value._meta.model, opts, field):
raise ValueError(
- 'Cannot query "%s": Must be "%s" instance.' %
- (value, opts.object_name))
+ 'Cannot query "%s": Must be "%s" instance.'
+ % (value, opts.object_name)
+ )
def check_related_objects(self, field, value, opts):
"""Check the type of object passed to query relations."""
@@ -1150,29 +1209,31 @@ class Query(BaseExpression):
# opts would be Author's (from the author field) and value.model
# would be Author.objects.all() queryset's .model (Author also).
# The field is the related field on the lhs side.
- if (isinstance(value, Query) and not value.has_select_fields and
- not check_rel_lookup_compatibility(value.model, opts, field)):
+ if (
+ isinstance(value, Query)
+ and not value.has_select_fields
+ and not check_rel_lookup_compatibility(value.model, opts, field)
+ ):
raise ValueError(
- 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' %
- (value.model._meta.object_name, opts.object_name)
+ 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".'
+ % (value.model._meta.object_name, opts.object_name)
)
- elif hasattr(value, '_meta'):
+ elif hasattr(value, "_meta"):
self.check_query_object_type(value, opts, field)
- elif hasattr(value, '__iter__'):
+ elif hasattr(value, "__iter__"):
for v in value:
self.check_query_object_type(v, opts, field)
def check_filterable(self, expression):
"""Raise an error if expression cannot be used in a WHERE clause."""
- if (
- hasattr(expression, 'resolve_expression') and
- not getattr(expression, 'filterable', True)
+ if hasattr(expression, "resolve_expression") and not getattr(
+ expression, "filterable", True
):
raise NotSupportedError(
- expression.__class__.__name__ + ' is disallowed in the filter '
- 'clause.'
+ expression.__class__.__name__ + " is disallowed in the filter "
+ "clause."
)
- if hasattr(expression, 'get_source_expressions'):
+ if hasattr(expression, "get_source_expressions"):
for expr in expression.get_source_expressions():
self.check_filterable(expr)
@@ -1186,7 +1247,7 @@ class Query(BaseExpression):
and get_transform().
"""
# __exact is the default lookup if one isn't given.
- *transforms, lookup_name = lookups or ['exact']
+ *transforms, lookup_name = lookups or ["exact"]
for name in transforms:
lhs = self.try_transform(lhs, name)
# First try get_lookup() so that the lookup takes precedence if the lhs
@@ -1194,11 +1255,13 @@ class Query(BaseExpression):
lookup_class = lhs.get_lookup(lookup_name)
if not lookup_class:
if lhs.field.is_relation:
- raise FieldError('Related Field got invalid lookup: {}'.format(lookup_name))
+ raise FieldError(
+ "Related Field got invalid lookup: {}".format(lookup_name)
+ )
# A lookup wasn't found. Try to interpret the name as a transform
# and do an Exact lookup against it.
lhs = self.try_transform(lhs, lookup_name)
- lookup_name = 'exact'
+ lookup_name = "exact"
lookup_class = lhs.get_lookup(lookup_name)
if not lookup_class:
return
@@ -1207,20 +1270,20 @@ class Query(BaseExpression):
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value unless the lookup supports it.
if lookup.rhs is None and not lookup.can_use_none_as_rhs:
- if lookup_name not in ('exact', 'iexact'):
+ if lookup_name not in ("exact", "iexact"):
raise ValueError("Cannot use None as a query value")
- return lhs.get_lookup('isnull')(lhs, True)
+ return lhs.get_lookup("isnull")(lhs, True)
# For Oracle '' is equivalent to null. The check must be done at this
# stage because join promotion can't be done in the compiler. Using
# DEFAULT_DB_ALIAS isn't nice but it's the best that can be done here.
# A similar thing is done in is_nullable(), too.
if (
- lookup_name == 'exact' and
- lookup.rhs == '' and
- connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
+ lookup_name == "exact"
+ and lookup.rhs == ""
+ and connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
):
- return lhs.get_lookup('isnull')(lhs, True)
+ return lhs.get_lookup("isnull")(lhs, True)
return lookup
@@ -1234,19 +1297,28 @@ class Query(BaseExpression):
return transform_class(lhs)
else:
output_field = lhs.output_field.__class__
- suggested_lookups = difflib.get_close_matches(name, output_field.get_lookups())
+ suggested_lookups = difflib.get_close_matches(
+ name, output_field.get_lookups()
+ )
if suggested_lookups:
- suggestion = ', perhaps you meant %s?' % ' or '.join(suggested_lookups)
+ suggestion = ", perhaps you meant %s?" % " or ".join(suggested_lookups)
else:
- suggestion = '.'
+ suggestion = "."
raise FieldError(
"Unsupported lookup '%s' for %s or join on the field not "
"permitted%s" % (name, output_field.__name__, suggestion)
)
- def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
- can_reuse=None, allow_joins=True, split_subq=True,
- check_filterable=True):
+ def build_filter(
+ self,
+ filter_expr,
+ branch_negated=False,
+ current_negated=False,
+ can_reuse=None,
+ allow_joins=True,
+ split_subq=True,
+ check_filterable=True,
+ ):
"""
Build a WhereNode for a single filter clause but don't add it
to this Query. Query.add_q() will then add this filter to the where
@@ -1284,12 +1356,12 @@ class Query(BaseExpression):
split_subq=split_subq,
check_filterable=check_filterable,
)
- if hasattr(filter_expr, 'resolve_expression'):
- if not getattr(filter_expr, 'conditional', False):
- raise TypeError('Cannot filter against a non-conditional expression.')
+ if hasattr(filter_expr, "resolve_expression"):
+ if not getattr(filter_expr, "conditional", False):
+ raise TypeError("Cannot filter against a non-conditional expression.")
condition = filter_expr.resolve_expression(self, allow_joins=allow_joins)
if not isinstance(condition, Lookup):
- condition = self.build_lookup(['exact'], condition, True)
+ condition = self.build_lookup(["exact"], condition, True)
return WhereNode([condition], connector=AND), []
arg, value = filter_expr
if not arg:
@@ -1304,7 +1376,9 @@ class Query(BaseExpression):
pre_joins = self.alias_refcount.copy()
value = self.resolve_lookup_value(value, can_reuse, allow_joins)
- used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)}
+ used_joins = {
+ k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
+ }
if check_filterable:
self.check_filterable(value)
@@ -1319,7 +1393,11 @@ class Query(BaseExpression):
try:
join_info = self.setup_joins(
- parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many,
+ parts,
+ opts,
+ alias,
+ can_reuse=can_reuse,
+ allow_many=allow_many,
)
# Prevent iterator from being consumed by check_related_objects()
@@ -1336,7 +1414,9 @@ class Query(BaseExpression):
# Update used_joins before trimming since they are reused to determine
# which joins could be later promoted to INNER.
used_joins.update(join_info.joins)
- targets, alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
+ targets, alias, join_list = self.trim_joins(
+ join_info.targets, join_info.joins, join_info.path
+ )
if can_reuse is not None:
can_reuse.update(join_list)
@@ -1344,11 +1424,15 @@ class Query(BaseExpression):
# No support for transforms for relational fields
num_lookups = len(lookups)
if num_lookups > 1:
- raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
+ raise FieldError(
+ "Related Field got invalid lookup: {}".format(lookups[0])
+ )
if len(targets) == 1:
col = self._get_col(targets[0], join_info.final_field, alias)
else:
- col = MultiColSource(alias, targets, join_info.targets, join_info.final_field)
+ col = MultiColSource(
+ alias, targets, join_info.targets, join_info.final_field
+ )
else:
col = self._get_col(targets[0], join_info.final_field, alias)
@@ -1356,10 +1440,16 @@ class Query(BaseExpression):
lookup_type = condition.lookup_name
clause = WhereNode([condition], connector=AND)
- require_outer = lookup_type == 'isnull' and condition.rhs is True and not current_negated
- if current_negated and (lookup_type != 'isnull' or condition.rhs is False) and condition.rhs is not None:
+ require_outer = (
+ lookup_type == "isnull" and condition.rhs is True and not current_negated
+ )
+ if (
+ current_negated
+ and (lookup_type != "isnull" or condition.rhs is False)
+ and condition.rhs is not None
+ ):
require_outer = True
- if lookup_type != 'isnull':
+ if lookup_type != "isnull":
# The condition added here will be SQL like this:
# NOT (col IS NOT NULL), where the first NOT is added in
# upper layers of code. The reason for addition is that if col
@@ -1370,16 +1460,16 @@ class Query(BaseExpression):
# <=>
# NOT (col IS NOT NULL AND col = someval).
if (
- self.is_nullable(targets[0]) or
- self.alias_map[join_list[-1]].join_type == LOUTER
+ self.is_nullable(targets[0])
+ or self.alias_map[join_list[-1]].join_type == LOUTER
):
- lookup_class = targets[0].get_lookup('isnull')
+ lookup_class = targets[0].get_lookup("isnull")
col = self._get_col(targets[0], join_info.targets[0], alias)
clause.add(lookup_class(col, False), AND)
# If someval is a nullable column, someval IS NOT NULL is
# added.
if isinstance(value, Col) and self.is_nullable(value.target):
- lookup_class = value.target.get_lookup('isnull')
+ lookup_class = value.target.get_lookup("isnull")
clause.add(lookup_class(value, False), AND)
return clause, used_joins if not require_outer else ()
@@ -1397,7 +1487,9 @@ class Query(BaseExpression):
# (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
# rel_a doesn't produce any rows, then the whole condition must fail.
# So, demotion is OK.
- existing_inner = {a for a in self.alias_map if self.alias_map[a].join_type == INNER}
+ existing_inner = {
+ a for a in self.alias_map if self.alias_map[a].join_type == INNER
+ }
clause, _ = self._add_q(q_object, self.used_aliases)
if clause:
self.where.add(clause, AND)
@@ -1409,20 +1501,33 @@ class Query(BaseExpression):
def clear_where(self):
self.where = WhereNode()
- def _add_q(self, q_object, used_aliases, branch_negated=False,
- current_negated=False, allow_joins=True, split_subq=True,
- check_filterable=True):
+ def _add_q(
+ self,
+ q_object,
+ used_aliases,
+ branch_negated=False,
+ current_negated=False,
+ allow_joins=True,
+ split_subq=True,
+ check_filterable=True,
+ ):
"""Add a Q-object to the current filter."""
connector = q_object.connector
current_negated = current_negated ^ q_object.negated
branch_negated = branch_negated or q_object.negated
target_clause = WhereNode(connector=connector, negated=q_object.negated)
- joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated)
+ joinpromoter = JoinPromoter(
+ q_object.connector, len(q_object.children), current_negated
+ )
for child in q_object.children:
child_clause, needed_inner = self.build_filter(
- child, can_reuse=used_aliases, branch_negated=branch_negated,
- current_negated=current_negated, allow_joins=allow_joins,
- split_subq=split_subq, check_filterable=check_filterable,
+ child,
+ can_reuse=used_aliases,
+ branch_negated=branch_negated,
+ current_negated=current_negated,
+ allow_joins=allow_joins,
+ split_subq=split_subq,
+ check_filterable=check_filterable,
)
joinpromoter.add_votes(needed_inner)
if child_clause:
@@ -1430,7 +1535,9 @@ class Query(BaseExpression):
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
- def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False):
+ def build_filtered_relation_q(
+ self, q_object, reuse, branch_negated=False, current_negated=False
+ ):
"""Add a FilteredRelation object to the current filter."""
connector = q_object.connector
current_negated ^= q_object.negated
@@ -1439,14 +1546,19 @@ class Query(BaseExpression):
for child in q_object.children:
if isinstance(child, Node):
child_clause = self.build_filtered_relation_q(
- child, reuse=reuse, branch_negated=branch_negated,
+ child,
+ reuse=reuse,
+ branch_negated=branch_negated,
current_negated=current_negated,
)
else:
child_clause, _ = self.build_filter(
- child, can_reuse=reuse, branch_negated=branch_negated,
+ child,
+ can_reuse=reuse,
+ branch_negated=branch_negated,
current_negated=current_negated,
- allow_joins=True, split_subq=False,
+ allow_joins=True,
+ split_subq=False,
)
target_clause.add(child_clause, connector)
return target_clause
@@ -1454,7 +1566,9 @@ class Query(BaseExpression):
def add_filtered_relation(self, filtered_relation, alias):
filtered_relation.alias = alias
lookups = dict(get_children_from_q(filtered_relation.condition))
- relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(filtered_relation.relation_name)
+ relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
+ filtered_relation.relation_name
+ )
if relation_lookup_parts:
raise ValueError(
"FilteredRelation's relation_name cannot contain lookups "
@@ -1498,7 +1612,7 @@ class Query(BaseExpression):
path, names_with_path = [], []
for pos, name in enumerate(names):
cur_names_with_path = (name, [])
- if name == 'pk':
+ if name == "pk":
name = opts.pk.name
field = None
@@ -1513,7 +1627,10 @@ class Query(BaseExpression):
if LOOKUP_SEP in filtered_relation.relation_name:
parts = filtered_relation.relation_name.split(LOOKUP_SEP)
filtered_relation_path, field, _, _ = self.names_to_path(
- parts, opts, allow_many, fail_on_missing,
+ parts,
+ opts,
+ allow_many,
+ fail_on_missing,
)
path.extend(filtered_relation_path[:-1])
else:
@@ -1540,13 +1657,17 @@ class Query(BaseExpression):
# one step.
pos -= 1
if pos == -1 or fail_on_missing:
- available = sorted([
- *get_field_names_from_opts(opts),
- *self.annotation_select,
- *self._filtered_relations,
- ])
- raise FieldError("Cannot resolve keyword '%s' into field. "
- "Choices are: %s" % (name, ", ".join(available)))
+ available = sorted(
+ [
+ *get_field_names_from_opts(opts),
+ *self.annotation_select,
+ *self._filtered_relations,
+ ]
+ )
+ raise FieldError(
+ "Cannot resolve keyword '%s' into field. "
+ "Choices are: %s" % (name, ", ".join(available))
+ )
break
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
@@ -1557,7 +1678,7 @@ class Query(BaseExpression):
path.extend(path_to_parent)
cur_names_with_path[1].extend(path_to_parent)
opts = path_to_parent[-1].to_opts
- if hasattr(field, 'path_infos'):
+ if hasattr(field, "path_infos"):
if filtered_relation:
pathinfos = field.get_path_info(filtered_relation)
else:
@@ -1565,7 +1686,7 @@ class Query(BaseExpression):
if not allow_many:
for inner_pos, p in enumerate(pathinfos):
if p.m2m:
- cur_names_with_path[1].extend(pathinfos[0:inner_pos + 1])
+ cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
names_with_path.append(cur_names_with_path)
raise MultiJoin(pos + 1, names_with_path)
last = pathinfos[-1]
@@ -1582,9 +1703,10 @@ class Query(BaseExpression):
if fail_on_missing and pos + 1 != len(names):
raise FieldError(
"Cannot resolve keyword %r into field. Join on '%s'"
- " not permitted." % (names[pos + 1], name))
+ " not permitted." % (names[pos + 1], name)
+ )
break
- return path, final_field, targets, names[pos + 1:]
+ return path, final_field, targets, names[pos + 1 :]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
@@ -1631,7 +1753,10 @@ class Query(BaseExpression):
for pivot in range(len(names), 0, -1):
try:
path, final_field, targets, rest = self.names_to_path(
- names[:pivot], opts, allow_many, fail_on_missing=True,
+ names[:pivot],
+ opts,
+ allow_many,
+ fail_on_missing=True,
)
except FieldError as exc:
if pivot == 1:
@@ -1646,6 +1771,7 @@ class Query(BaseExpression):
transforms = names[pivot:]
break
for name in transforms:
+
def transform(field, alias, *, name, previous):
try:
wrapped = previous(field, alias)
@@ -1656,7 +1782,10 @@ class Query(BaseExpression):
raise last_field_exception
else:
raise
- final_transformer = functools.partial(transform, name=name, previous=final_transformer)
+
+ final_transformer = functools.partial(
+ transform, name=name, previous=final_transformer
+ )
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1673,8 +1802,13 @@ class Query(BaseExpression):
else:
nullable = True
connection = self.join_class(
- opts.db_table, alias, table_alias, INNER, join.join_field,
- nullable, filtered_relation=filtered_relation,
+ opts.db_table,
+ alias,
+ table_alias,
+ INNER,
+ join.join_field,
+ nullable,
+ filtered_relation=filtered_relation,
)
reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse)
@@ -1706,7 +1840,11 @@ class Query(BaseExpression):
cur_targets = {t.column for t in targets}
if not cur_targets.issubset(join_targets):
break
- targets_dict = {r[1].column: r[0] for r in info.join_field.related_fields if r[1].column in cur_targets}
+ targets_dict = {
+ r[1].column: r[0]
+ for r in info.join_field.related_fields
+ if r[1].column in cur_targets
+ }
targets = tuple(targets_dict[t.column] for t in targets)
self.unref_alias(joins.pop())
return targets, joins[-1], joins
@@ -1716,9 +1854,11 @@ class Query(BaseExpression):
for expr in exprs:
if isinstance(expr, Col):
yield expr
- elif include_external and callable(getattr(expr, 'get_external_cols', None)):
+ elif include_external and callable(
+ getattr(expr, "get_external_cols", None)
+ ):
yield from expr.get_external_cols()
- elif hasattr(expr, 'get_source_expressions'):
+ elif hasattr(expr, "get_source_expressions"):
yield from cls._gen_cols(
expr.get_source_expressions(),
include_external=include_external,
@@ -1735,7 +1875,7 @@ class Query(BaseExpression):
for alias in self._gen_col_aliases([annotation]):
if isinstance(self.alias_map[alias], Join):
raise FieldError(
- 'Joined field references are not permitted in this query'
+ "Joined field references are not permitted in this query"
)
if summarize:
# Summarize currently means we are doing an aggregate() query
@@ -1757,10 +1897,16 @@ class Query(BaseExpression):
for transform in field_list[1:]:
annotation = self.try_transform(annotation, transform)
return annotation
- join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse)
- targets, final_alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
+ join_info = self.setup_joins(
+ field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse
+ )
+ targets, final_alias, join_list = self.trim_joins(
+ join_info.targets, join_info.joins, join_info.path
+ )
if not allow_joins and len(join_list) > 1:
- raise FieldError('Joined field references are not permitted in this query')
+ raise FieldError(
+ "Joined field references are not permitted in this query"
+ )
if len(targets) > 1:
raise FieldError(
"Referencing multicolumn fields with F() objects isn't supported"
@@ -1813,23 +1959,25 @@ class Query(BaseExpression):
# Need to add a restriction so that outer query's filters are in effect for
# the subquery, too.
query.bump_prefix(self)
- lookup_class = select_field.get_lookup('exact')
+ lookup_class = select_field.get_lookup("exact")
# Note that the query.select[0].alias is different from alias
# due to bump_prefix above.
- lookup = lookup_class(pk.get_col(query.select[0].alias),
- pk.get_col(alias))
+ lookup = lookup_class(pk.get_col(query.select[0].alias), pk.get_col(alias))
query.where.add(lookup, AND)
query.external_aliases[alias] = True
- lookup_class = select_field.get_lookup('exact')
+ lookup_class = select_field.get_lookup("exact")
lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
query.where.add(lookup, AND)
condition, needed_inner = self.build_filter(Exists(query))
if contains_louter:
or_null_condition, _ = self.build_filter(
- ('%s__isnull' % trimmed_prefix, True),
- current_negated=True, branch_negated=True, can_reuse=can_reuse)
+ ("%s__isnull" % trimmed_prefix, True),
+ current_negated=True,
+ branch_negated=True,
+ can_reuse=can_reuse,
+ )
condition.add(or_null_condition, OR)
# Note that the end result will be:
# (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
@@ -1907,8 +2055,8 @@ class Query(BaseExpression):
self.values_select = ()
def add_select_col(self, col, name):
- self.select += col,
- self.values_select += name,
+ self.select += (col,)
+ self.values_select += (name,)
def set_select(self, cols):
self.default_cols = False
@@ -1934,7 +2082,9 @@ class Query(BaseExpression):
for name in field_names:
# Join promotion note - we must not remove any rows here, so
# if there is no existing joins, use outer join.
- join_info = self.setup_joins(name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
+ join_info = self.setup_joins(
+ name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m
+ )
targets, final_alias, joins = self.trim_joins(
join_info.targets,
join_info.joins,
@@ -1957,12 +2107,18 @@ class Query(BaseExpression):
"it." % name
)
else:
- names = sorted([
- *get_field_names_from_opts(opts), *self.extra,
- *self.annotation_select, *self._filtered_relations
- ])
- raise FieldError("Cannot resolve keyword %r into field. "
- "Choices are: %s" % (name, ", ".join(names)))
+ names = sorted(
+ [
+ *get_field_names_from_opts(opts),
+ *self.extra,
+ *self.annotation_select,
+ *self._filtered_relations,
+ ]
+ )
+ raise FieldError(
+ "Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(names))
+ )
def add_ordering(self, *ordering):
"""
@@ -1976,9 +2132,9 @@ class Query(BaseExpression):
errors = []
for item in ordering:
if isinstance(item, str):
- if item == '?':
+ if item == "?":
continue
- if item.startswith('-'):
+ if item.startswith("-"):
item = item[1:]
if item in self.annotations:
continue
@@ -1987,15 +2143,15 @@ class Query(BaseExpression):
# names_to_path() validates the lookup. A descriptive
# FieldError will be raise if it's not.
self.names_to_path(item.split(LOOKUP_SEP), self.model._meta)
- elif not hasattr(item, 'resolve_expression'):
+ elif not hasattr(item, "resolve_expression"):
errors.append(item)
- if getattr(item, 'contains_aggregate', False):
+ if getattr(item, "contains_aggregate", False):
raise FieldError(
- 'Using an aggregate in order_by() without also including '
- 'it in annotate() is not allowed: %s' % item
+ "Using an aggregate in order_by() without also including "
+ "it in annotate() is not allowed: %s" % item
)
if errors:
- raise FieldError('Invalid order_by arguments: %s' % errors)
+ raise FieldError("Invalid order_by arguments: %s" % errors)
if ordering:
self.order_by += ordering
else:
@@ -2008,7 +2164,9 @@ class Query(BaseExpression):
If 'clear_default' is True, there will be no ordering in the resulting
query (not even the model's default).
"""
- if not force and (self.is_sliced or self.distinct_fields or self.select_for_update):
+ if not force and (
+ self.is_sliced or self.distinct_fields or self.select_for_update
+ ):
return
self.order_by = ()
self.extra_order_by = ()
@@ -2031,10 +2189,9 @@ class Query(BaseExpression):
for join in list(self.alias_map.values())[1:]: # Skip base table.
model = join.join_field.related_model
if model not in seen_models:
- column_names.update({
- field.column
- for field in model._meta.local_concrete_fields
- })
+ column_names.update(
+ {field.column for field in model._meta.local_concrete_fields}
+ )
seen_models.add(model)
group_by = list(self.select)
@@ -2082,7 +2239,7 @@ class Query(BaseExpression):
entry_params = []
pos = entry.find("%s")
while pos != -1:
- if pos == 0 or entry[pos - 1] != '%':
+ if pos == 0 or entry[pos - 1] != "%":
entry_params.append(next(param_iter))
pos = entry.find("%s", pos + 2)
select_pairs[name] = (entry, entry_params)
@@ -2135,8 +2292,8 @@ class Query(BaseExpression):
"""
existing, defer = self.deferred_loading
field_names = set(field_names)
- if 'pk' in field_names:
- field_names.remove('pk')
+ if "pk" in field_names:
+ field_names.remove("pk")
field_names.add(self.get_meta().pk.name)
if defer:
@@ -2224,7 +2381,9 @@ class Query(BaseExpression):
# Selected annotations must be known before setting the GROUP BY
# clause.
if self.group_by is True:
- self.add_fields((f.attname for f in self.model._meta.concrete_fields), False)
+ self.add_fields(
+ (f.attname for f in self.model._meta.concrete_fields), False
+ )
# Disable GROUP BY aliases to avoid orphaning references to the
# SELECT clause which is about to be cleared.
self.set_group_by(allow_aliases=False)
@@ -2254,7 +2413,8 @@ class Query(BaseExpression):
return {}
elif self.annotation_select_mask is not None:
self._annotation_select_cache = {
- k: v for k, v in self.annotations.items()
+ k: v
+ for k, v in self.annotations.items()
if k in self.annotation_select_mask
}
return self._annotation_select_cache
@@ -2269,8 +2429,7 @@ class Query(BaseExpression):
return {}
elif self.extra_select_mask is not None:
self._extra_select_cache = {
- k: v for k, v in self.extra.items()
- if k in self.extra_select_mask
+ k: v for k, v in self.extra.items() if k in self.extra_select_mask
}
return self._extra_select_cache
else:
@@ -2297,8 +2456,7 @@ class Query(BaseExpression):
# the lookup part of the query. That is, avoid trimming
# joins generated for F() expressions.
lookup_tables = [
- t for t in self.alias_map
- if t in self._lookup_joins or t == self.base_table
+ t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
]
for trimmed_paths, path in enumerate(all_paths):
if path.m2m:
@@ -2317,8 +2475,7 @@ class Query(BaseExpression):
break
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
- trimmed_prefix.append(
- join_field.foreign_related_fields[0].name)
+ trimmed_prefix.append(join_field.foreign_related_fields[0].name)
trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
# Lets still see if we can trim the first join from the inner query
# (that is, self). We can't do this for:
@@ -2331,7 +2488,9 @@ class Query(BaseExpression):
select_fields = [r[0] for r in join_field.related_fields]
select_alias = lookup_tables[trimmed_paths + 1]
self.unref_alias(lookup_tables[trimmed_paths])
- extra_restriction = join_field.get_extra_restriction(None, lookup_tables[trimmed_paths + 1])
+ extra_restriction = join_field.get_extra_restriction(
+ None, lookup_tables[trimmed_paths + 1]
+ )
if extra_restriction:
self.where.add(extra_restriction, AND)
else:
@@ -2367,12 +2526,12 @@ class Query(BaseExpression):
# is_nullable() is needed to the compiler stage, but that is not easy
# to do currently.
return field.null or (
- field.empty_strings_allowed and
- connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
+ field.empty_strings_allowed
+ and connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
)
-def get_order_dir(field, default='ASC'):
+def get_order_dir(field, default="ASC"):
"""
Return the field name and direction for an order specification. For
example, '-foo' is returned as ('foo', 'DESC').
@@ -2381,7 +2540,7 @@ def get_order_dir(field, default='ASC'):
prefix) should sort. The '-' prefix always sorts the opposite way.
"""
dirn = ORDER_DIR[default]
- if field[0] == '-':
+ if field[0] == "-":
return field[1:], dirn[1]
return field, dirn[0]
@@ -2428,8 +2587,8 @@ class JoinPromoter:
def __repr__(self):
return (
- f'{self.__class__.__qualname__}(connector={self.connector!r}, '
- f'num_children={self.num_children!r}, negated={self.negated!r})'
+ f"{self.__class__.__qualname__}(connector={self.connector!r}, "
+ f"num_children={self.num_children!r}, negated={self.negated!r})"
)
def add_votes(self, votes):
@@ -2461,7 +2620,7 @@ class JoinPromoter:
# to rel_a would remove a valid match from the query. So, we need
# to promote any existing INNER to LOUTER (it is possible this
# promotion in turn will be demoted later on).
- if self.effective_connector == 'OR' and votes < self.num_children:
+ if self.effective_connector == "OR" and votes < self.num_children:
to_promote.add(table)
# If connector is AND and there is a filter that can match only
# when there is a joinable row, then use INNER. For example, in
@@ -2473,8 +2632,9 @@ class JoinPromoter:
# (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell)
# then if rel_a doesn't produce any rows, the whole condition
# can't match. Hence we can safely use INNER join.
- if self.effective_connector == 'AND' or (
- self.effective_connector == 'OR' and votes == self.num_children):
+ if self.effective_connector == "AND" or (
+ self.effective_connector == "OR" and votes == self.num_children
+ ):
to_demote.add(table)
# Finally, what happens in cases where we have:
# (rel_a__col=1|rel_b__col=2) & rel_a__col__gte=0
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index f6a371a925..04063f73bc 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -3,18 +3,16 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.core.exceptions import FieldError
-from django.db.models.sql.constants import (
- CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS,
-)
+from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.query import Query
-__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery']
+__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
class DeleteQuery(Query):
"""A DELETE SQL query."""
- compiler = 'SQLDeleteCompiler'
+ compiler = "SQLDeleteCompiler"
def do_query(self, table, where, using):
self.alias_map = {table: self.alias_map[table]}
@@ -38,17 +36,19 @@ class DeleteQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
self.add_filter(
- f'{field.attname}__in',
- pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE],
+ f"{field.attname}__in",
+ pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
+ )
+ num_deleted += self.do_query(
+ self.get_meta().db_table, self.where, using=using
)
- num_deleted += self.do_query(self.get_meta().db_table, self.where, using=using)
return num_deleted
class UpdateQuery(Query):
"""An UPDATE SQL query."""
- compiler = 'SQLUpdateCompiler'
+ compiler = "SQLUpdateCompiler"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -72,7 +72,9 @@ class UpdateQuery(Query):
self.add_update_values(values)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
- self.add_filter('pk__in', pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])
+ self.add_filter(
+ "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
+ )
self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
@@ -84,12 +86,14 @@ class UpdateQuery(Query):
values_seq = []
for name, val in values.items():
field = self.get_meta().get_field(name)
- direct = not (field.auto_created and not field.concrete) or not field.concrete
+ direct = (
+ not (field.auto_created and not field.concrete) or not field.concrete
+ )
model = field.model._meta.concrete_model
if not direct or (field.is_relation and field.many_to_many):
raise FieldError(
- 'Cannot update model field %r (only non-relations and '
- 'foreign keys permitted).' % field
+ "Cannot update model field %r (only non-relations and "
+ "foreign keys permitted)." % field
)
if model is not self.get_meta().concrete_model:
self.add_related_update(model, field, val)
@@ -104,7 +108,7 @@ class UpdateQuery(Query):
called add_update_targets() to hint at the extra information here.
"""
for field, model, val in values_seq:
- if hasattr(val, 'resolve_expression'):
+ if hasattr(val, "resolve_expression"):
# Resolve expressions here so that annotations are no longer needed
val = val.resolve_expression(self, allow_joins=False, for_save=True)
self.values.append((field, model, val))
@@ -130,15 +134,17 @@ class UpdateQuery(Query):
query = UpdateQuery(model)
query.values = values
if self.related_ids is not None:
- query.add_filter('pk__in', self.related_ids)
+ query.add_filter("pk__in", self.related_ids)
result.append(query)
return result
class InsertQuery(Query):
- compiler = 'SQLInsertCompiler'
+ compiler = "SQLInsertCompiler"
- def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs):
+ def __init__(
+ self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
+ ):
super().__init__(*args, **kwargs)
self.fields = []
self.objs = []
@@ -158,7 +164,7 @@ class AggregateQuery(Query):
elements in the provided list.
"""
- compiler = 'SQLAggregateCompiler'
+ compiler = "SQLAggregateCompiler"
def __init__(self, model, inner_query):
self.inner_query = inner_query
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 50ff13be75..532780fd98 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -7,8 +7,8 @@ from django.utils import tree
from django.utils.functional import cached_property
# Connection types
-AND = 'AND'
-OR = 'OR'
+AND = "AND"
+OR = "OR"
class WhereNode(tree.Node):
@@ -25,6 +25,7 @@ class WhereNode(tree.Node):
relabeled_clone() method or relabel_aliases() and clone() methods and
contains_aggregate attribute.
"""
+
default = AND
resolved = False
conditional = True
@@ -40,15 +41,15 @@ class WhereNode(tree.Node):
in_negated = negated ^ self.negated
# If the effective connector is OR and this node contains an aggregate,
# then we need to push the whole branch to HAVING clause.
- may_need_split = (
- (in_negated and self.connector == AND) or
- (not in_negated and self.connector == OR))
+ may_need_split = (in_negated and self.connector == AND) or (
+ not in_negated and self.connector == OR
+ )
if may_need_split and self.contains_aggregate:
return None, self
where_parts = []
having_parts = []
for c in self.children:
- if hasattr(c, 'split_having'):
+ if hasattr(c, "split_having"):
where_part, having_part = c.split_having(in_negated)
if where_part is not None:
where_parts.append(where_part)
@@ -58,8 +59,16 @@ class WhereNode(tree.Node):
having_parts.append(c)
else:
where_parts.append(c)
- having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None
- where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None
+ having_node = (
+ self.__class__(having_parts, self.connector, self.negated)
+ if having_parts
+ else None
+ )
+ where_node = (
+ self.__class__(where_parts, self.connector, self.negated)
+ if where_parts
+ else None
+ )
return where_node, having_node
def as_sql(self, compiler, connection):
@@ -94,24 +103,24 @@ class WhereNode(tree.Node):
# counts.
if empty_needed == 0:
if self.negated:
- return '', []
+ return "", []
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
- return '', []
- conn = ' %s ' % self.connector
+ return "", []
+ conn = " %s " % self.connector
sql_string = conn.join(result)
if sql_string:
if self.negated:
# Some backends (Oracle at least) need parentheses
# around the inner SQL in the negated case, even if the
# inner SQL contains just a single expression.
- sql_string = 'NOT (%s)' % sql_string
+ sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved:
- sql_string = '(%s)' % sql_string
+ sql_string = "(%s)" % sql_string
return sql_string, result_params
def get_group_by_cols(self, alias=None):
@@ -133,10 +142,10 @@ class WhereNode(tree.Node):
mapping old (current) alias values to the new values.
"""
for pos, child in enumerate(self.children):
- if hasattr(child, 'relabel_aliases'):
+ if hasattr(child, "relabel_aliases"):
# For example another WhereNode
child.relabel_aliases(change_map)
- elif hasattr(child, 'relabeled_clone'):
+ elif hasattr(child, "relabeled_clone"):
self.children[pos] = child.relabeled_clone(change_map)
def clone(self):
@@ -146,10 +155,12 @@ class WhereNode(tree.Node):
value) tuples, or objects supporting .clone().
"""
clone = self.__class__._new_instance(
- children=None, connector=self.connector, negated=self.negated,
+ children=None,
+ connector=self.connector,
+ negated=self.negated,
)
for child in self.children:
- if hasattr(child, 'clone'):
+ if hasattr(child, "clone"):
clone.children.append(child.clone())
else:
clone.children.append(child)
@@ -185,18 +196,18 @@ class WhereNode(tree.Node):
@staticmethod
def _resolve_leaf(expr, query, *args, **kwargs):
- if hasattr(expr, 'resolve_expression'):
+ if hasattr(expr, "resolve_expression"):
expr = expr.resolve_expression(query, *args, **kwargs)
return expr
@classmethod
def _resolve_node(cls, node, query, *args, **kwargs):
- if hasattr(node, 'children'):
+ if hasattr(node, "children"):
for child in node.children:
cls._resolve_node(child, query, *args, **kwargs)
- if hasattr(node, 'lhs'):
+ if hasattr(node, "lhs"):
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
- if hasattr(node, 'rhs'):
+ if hasattr(node, "rhs"):
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
def resolve_expression(self, *args, **kwargs):
@@ -208,6 +219,7 @@ class WhereNode(tree.Node):
@cached_property
def output_field(self):
from django.db.models import BooleanField
+
return BooleanField()
def select_format(self, compiler, sql, params):
@@ -215,7 +227,7 @@ class WhereNode(tree.Node):
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
- sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
+ sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
def get_db_converters(self, connection):
@@ -227,6 +239,7 @@ class WhereNode(tree.Node):
class NothingNode:
"""A node that matches nothing."""
+
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):