summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/db/models/sql/compiler.py72
-rw-r--r--tests/queries/test_qs_combinators.py6
2 files changed, 44 insertions, 34 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 98feb42716..262d722dc1 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -583,50 +583,28 @@ class SQLCompiler:
raise DatabaseError(
"ORDER BY not allowed in subqueries of compound statements."
)
- elif self.query.is_sliced and combinator == "union":
- for compiler in compilers:
- # A sliced union cannot have its parts elided as some of them
- # might be sliced as well and in the event where only a single
- # part produces a non-empty resultset it might be impossible to
- # generate valid SQL.
- compiler.elide_empty = False
- parts = ()
- selected = self.query.selected
+ parts = []
+ empty_compiler = None
for compiler in compilers:
try:
- # If the columns list is limited, then all combined queries
- # must have the same columns list. Set the selects defined on
- # the query on all combined queries, if not already set.
- if selected is not None and compiler.query.selected is None:
- compiler.query = compiler.query.clone()
- compiler.query.set_values(selected)
- part_sql, part_args = compiler.as_sql(with_col_aliases=True)
- if compiler.query.combinator:
- # Wrap in a subquery if wrapping in parentheses isn't
- # supported.
- if not features.supports_parentheses_in_compound:
- part_sql = "SELECT * FROM ({})".format(part_sql)
- # Add parentheses when combining with compound query if not
- # already added for all compound queries.
- elif (
- self.query.subquery
- or not features.supports_slicing_ordering_in_compound
- ):
- part_sql = "({})".format(part_sql)
- elif (
- self.query.subquery
- and features.supports_slicing_ordering_in_compound
- ):
- part_sql = "({})".format(part_sql)
- parts += ((part_sql, part_args),)
+ parts.append(self._get_combinator_part_sql(compiler))
except EmptyResultSet:
# Omit the empty queryset with UNION and with DIFFERENCE if the
# first queryset is nonempty.
if combinator == "union" or (combinator == "difference" and parts):
+ empty_compiler = compiler
continue
raise
if not parts:
raise EmptyResultSet
+ elif len(parts) == 1 and combinator == "union" and self.query.is_sliced:
+ # A sliced union cannot be composed of a single component because
+ # in the event the later is also sliced it might result in invalid
+ # SQL due to the usage of multiple LIMIT clauses. Prevent that from
+ # happening by always including an empty resultset query to force
+ # the creation of an union.
+ empty_compiler.elide_empty = False
+ parts.append(self._get_combinator_part_sql(empty_compiler))
combinator_sql = self.connection.ops.set_operators[combinator]
if all and combinator == "union":
combinator_sql += " ALL"
@@ -642,6 +620,32 @@ class SQLCompiler:
params.extend(part)
return result, params
+ def _get_combinator_part_sql(self, compiler):
+ features = self.connection.features
+ # If the columns list is limited, then all combined queries
+ # must have the same columns list. Set the selects defined on
+ # the query on all combined queries, if not already set.
+ selected = self.query.selected
+ if selected is not None and compiler.query.selected is None:
+ compiler.query = compiler.query.clone()
+ compiler.query.set_values(selected)
+ part_sql, part_args = compiler.as_sql(with_col_aliases=True)
+ if compiler.query.combinator:
+ # Wrap in a subquery if wrapping in parentheses isn't
+ # supported.
+ if not features.supports_parentheses_in_compound:
+ part_sql = "SELECT * FROM ({})".format(part_sql)
+ # Add parentheses when combining with compound query if not
+ # already added for all compound queries.
+ elif (
+ self.query.subquery
+ or not features.supports_slicing_ordering_in_compound
+ ):
+ part_sql = "({})".format(part_sql)
+ elif self.query.subquery and features.supports_slicing_ordering_in_compound:
+ part_sql = "({})".format(part_sql)
+ return part_sql, part_args
+
def get_qualify_sql(self):
where_parts = []
if self.where:
diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py
index eac1533803..ad1017c8af 100644
--- a/tests/queries/test_qs_combinators.py
+++ b/tests/queries/test_qs_combinators.py
@@ -76,6 +76,12 @@ class QuerySetSetOperationTests(TestCase):
qs3 = qs1.union(qs2)
self.assertNumbersEqual(qs3[:1], [0])
+ def test_union_all_none_slice(self):
+ qs = Number.objects.filter(id__in=[])
+ with self.assertNumQueries(0):
+ self.assertSequenceEqual(qs.union(qs), [])
+ self.assertSequenceEqual(qs.union(qs)[0:0], [])
+
def test_union_empty_filter_slice(self):
qs1 = Number.objects.filter(num__lte=0)
qs2 = Number.objects.filter(pk__in=[])