summaryrefslogtreecommitdiff
path: root/django/db/models/sql/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/query.py')
-rw-r--r--django/db/models/sql/query.py105
1 files changed, 56 insertions, 49 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 0a4152587d..154b6bd204 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -1422,7 +1422,9 @@ class Query(object):
query.clear_ordering(True)
# Try to have as simple as possible subquery -> trim leading joins from
# the subquery.
- trimmed_joins = query.trim_start(names_with_path)
+ trimmed_prefix, contains_louter = query.trim_start(names_with_path)
+ query.remove_inherited_models()
+
# Add extra check to make sure the selected field will not be null
# since we are adding a IN <subquery> clause. This prevents the
# database from tripping over IN (...,NULL,...) selects and returning
@@ -1431,38 +1433,20 @@ class Query(object):
alias, col = query.select[0].col
query.where.add((Constraint(alias, col, query.select[0].field), 'isnull', False), AND)
- # Still make sure that the trimmed parts in the inner query and
- # trimmed prefix are in sync. So, use the trimmed_joins to make sure
- # as many path elements are in the prefix as there were trimmed joins.
- # In addition, convert the path elements back to names so that
- # add_filter() can handle them.
- trimmed_prefix = []
- paths_in_prefix = trimmed_joins
- for name, path in names_with_path:
- if paths_in_prefix - len(path) < 0:
- break
- trimmed_prefix.append(name)
- paths_in_prefix -= len(path)
- join_field = path[paths_in_prefix].join_field
- # TODO: This should be made properly multicolumn
- # join aware. It is likely better to not use build_filter
- # at all, instead construct joins up to the correct point,
- # then construct the needed equality constraint manually,
- # or maybe using SubqueryConstraint would work, too.
- # The foreign_related_fields attribute is right here, we
- # don't ever split joins for direct case.
- trimmed_prefix.append(
- join_field.field.foreign_related_fields[0].name)
- trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
condition = self.build_filter(
('%s__in' % trimmed_prefix, query),
current_negated=True, branch_negated=True, can_reuse=can_reuse)
- # Intentionally leave the other alias as blank, if the condition
- # refers it, things will break here.
- extra_restriction = join_field.get_extra_restriction(
- self.where_class, None, [t for t in query.tables if query.alias_refcount[t]][0])
- if extra_restriction:
- query.where.add(extra_restriction, 'AND')
+ if contains_louter:
+ or_null_condition = self.build_filter(
+ ('%s__isnull' % trimmed_prefix, True),
+ current_negated=True, branch_negated=True, can_reuse=can_reuse)
+ condition.add(or_null_condition, OR)
+ # Note that the end result will be:
+ # (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
+ # This might look crazy but due to how IN works, this seems to be
+ # correct. If the IS NOT NULL check is removed then outercol NOT
+ # IN will return UNKNOWN. If the IS NULL check is removed, then if
+ # outercol IS NULL we will not match the row.
return condition
def set_empty(self):
@@ -1821,35 +1805,58 @@ class Query(object):
def trim_start(self, names_with_path):
"""
Trims joins from the start of the join path. The candidates for trim
- are the PathInfos in names_with_path structure. Outer joins are not
- eligible for removal. Also sets the select column so the start
- matches the join.
+ are the PathInfos in names_with_path structure that are m2m joins.
+
+ Also sets the select column so the start matches the join.
+
+ This method is meant to be used for generating the subquery joins &
+ cols in split_exclude().
- This method is mostly useful for generating the subquery joins & col
- in "WHERE somecol IN (subquery)". This construct is needed by
- split_exclude().
+ Returns a lookup usable for doing outerq.filter(lookup=self). Returns
+ also if the joins in the prefix contain a LEFT OUTER join.
_"""
all_paths = []
for _, paths in names_with_path:
all_paths.extend(paths)
- direct_join = True
+ contains_louter = False
for pos, path in enumerate(all_paths):
+ if path.m2m:
+ break
if self.alias_map[self.tables[pos + 1]].join_type == self.LOUTER:
- direct_join = False
- pos -= 1
+ contains_louter = True
+ self.unref_alias(self.tables[pos])
+ # The path.join_field is a Rel, lets get the other side's field
+ join_field = path.join_field.field
+ # Build the filter prefix.
+ trimmed_prefix = []
+ paths_in_prefix = pos
+ for name, path in names_with_path:
+ if paths_in_prefix - len(path) < 0:
break
+ trimmed_prefix.append(name)
+ paths_in_prefix -= len(path)
+ trimmed_prefix.append(
+ join_field.foreign_related_fields[0].name)
+ trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
+ # Lets still see if we can trim the first join from the inner query
+ # (that is, self). We can't do this for LEFT JOINs because we would
+ # miss those rows that have nothing on the outer side.
+ if self.alias_map[self.tables[pos + 1]].join_type != self.LOUTER:
+ select_fields = [r[0] for r in join_field.related_fields]
+ select_alias = self.tables[pos + 1]
self.unref_alias(self.tables[pos])
- if path.direct:
- direct_join = not direct_join
- join_side = 0 if direct_join else 1
- select_alias = self.tables[pos + 1]
- join_field = path.join_field
- if hasattr(join_field, 'field'):
- join_field = join_field.field
- select_fields = [r[join_side] for r in join_field.related_fields]
+ extra_restriction = join_field.get_extra_restriction(
+ self.where_class, None, self.tables[pos + 1])
+ if extra_restriction:
+ self.where.add(extra_restriction, AND)
+ else:
+ # TODO: It might be possible to trim more joins from the start of the
+ # inner query if it happens to have a longer join chain containing the
+ # values in select_fields. Lets punt this one for now.
+ select_fields = [r[1] for r in join_field.related_fields]
+ select_alias = self.tables[pos]
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
- self.remove_inherited_models()
- return pos
+ return trimmed_prefix, contains_louter
def is_nullable(self, field):
"""