diff options
| author | Nicolas Delaby <ticosax@free.fr> | 2017-09-22 17:53:17 +0200 |
|---|---|---|
| committer | Tim Graham <timograham@gmail.com> | 2017-09-22 11:53:17 -0400 |
| commit | 01d440fa1e6b5c62acfa8b3fde43dfa1505f93c6 (patch) | |
| tree | 21b1f96ecd0fca636746595bce50eb46abdde880 /django/db/models/sql/query.py | |
| parent | 3f9d85d95cab228fd881ea952c707022e9e3bdf3 (diff) | |
Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.
Thanks Anssi Kääriäinen for contributing to the patch.
Diffstat (limited to 'django/db/models/sql/query.py')
| -rw-r--r-- | django/db/models/sql/query.py | 137 |
1 files changed, 116 insertions, 21 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dfa369513b..a962aabdf1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -45,6 +45,14 @@ def get_field_names_from_opts(opts): )) +def get_children_from_q(q): + for child in q.children: + if isinstance(child, Node): + yield from get_children_from_q(child) + else: + yield child + + JoinInfo = namedtuple( 'JoinInfo', ('final_field', 'targets', 'opts', 'joins', 'path') @@ -210,6 +218,8 @@ class Query: # load. self.deferred_loading = (frozenset(), True) + self._filtered_relations = {} + @property def extra(self): if self._extra is None: @@ -311,6 +321,7 @@ class Query: if 'subq_aliases' in self.__dict__: obj.subq_aliases = self.subq_aliases.copy() obj.used_aliases = self.used_aliases.copy() + obj._filtered_relations = self._filtered_relations.copy() # Clear the cached_property try: del obj.base_table @@ -624,6 +635,8 @@ class Query: opts = orig_opts for name in parts[:-1]: old_model = cur_model + if name in self._filtered_relations: + name = self._filtered_relations[name].relation_name source = opts.get_field(name) if is_reverse_o2o(source): cur_model = source.related_model @@ -684,7 +697,7 @@ class Query: for model, values in seen.items(): callback(target, model, values) - def table_alias(self, table_name, create=False): + def table_alias(self, table_name, create=False, filtered_relation=None): """ Return a table alias for the given table_name and whether this is a new alias or not. @@ -704,8 +717,8 @@ class Query: alias_list.append(alias) else: # The first occurrence of a table uses the table name directly. - alias = table_name - self.table_map[alias] = [alias] + alias = filtered_relation.alias if filtered_relation is not None else table_name + self.table_map[table_name] = [alias] self.alias_refcount[alias] = 1 return alias, True @@ -881,7 +894,7 @@ class Query: """ return len([1 for count in self.alias_refcount.values() if count]) - def join(self, join, reuse=None): + def join(self, join, reuse=None, reuse_with_filtered_relation=False): """ Return an alias for the 'join', either reusing an existing alias for that join or creating a new one. 'join' is either a @@ -890,18 +903,29 @@ class Query: The 'reuse' parameter can be either None which means all joins are reusable, or it can be a set containing the aliases that can be reused. + The 'reuse_with_filtered_relation' parameter is used when computing + FilteredRelation instances. + A join is always created as LOUTER if the lhs alias is LOUTER to make sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new joins are created as LOUTER if the join is nullable. """ - reuse = [a for a, j in self.alias_map.items() - if (reuse is None or a in reuse) and j == join] - if reuse: - self.ref_alias(reuse[0]) - return reuse[0] + if reuse_with_filtered_relation and reuse: + reuse_aliases = [ + a for a, j in self.alias_map.items() + if a in reuse and j.equals(join, with_filtered_relation=False) + ] + else: + reuse_aliases = [ + a for a, j in self.alias_map.items() + if (reuse is None or a in reuse) and j == join + ] + if reuse_aliases: + self.ref_alias(reuse_aliases[0]) + return reuse_aliases[0] # No reuse is possible, so we need a new alias. - alias, _ = self.table_alias(join.table_name, create=True) + alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation) if join.join_type: if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: join_type = LOUTER @@ -1090,7 +1114,8 @@ class Query: (name, lhs.output_field.__class__.__name__)) def build_filter(self, filter_expr, branch_negated=False, current_negated=False, - can_reuse=None, allow_joins=True, split_subq=True): + can_reuse=None, allow_joins=True, split_subq=True, + reuse_with_filtered_relation=False): """ Build a WhereNode for a single filter clause but don't add it to this Query. Query.add_q() will then add this filter to the where @@ -1112,6 +1137,9 @@ class Query: The 'can_reuse' is a set of reusable joins for multijoins. + If 'reuse_with_filtered_relation' is True, then only joins in can_reuse + will be reused. + The method will create a filter clause that can be added to the current query. However, if the filter isn't added to the query then the caller is responsible for unreffing the joins used. @@ -1147,7 +1175,10 @@ class Query: allow_many = not branch_negated or not split_subq try: - join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) + join_info = self.setup_joins( + parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many, + reuse_with_filtered_relation=reuse_with_filtered_relation, + ) # Prevent iterator from being consumed by check_related_objects() if isinstance(value, Iterator): @@ -1250,6 +1281,41 @@ class Query: needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner + def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False): + """Add a FilteredRelation object to the current filter.""" + connector = q_object.connector + current_negated ^= q_object.negated + branch_negated = branch_negated or q_object.negated + target_clause = self.where_class(connector=connector, negated=q_object.negated) + for child in q_object.children: + if isinstance(child, Node): + child_clause = self.build_filtered_relation_q( + child, reuse=reuse, branch_negated=branch_negated, + current_negated=current_negated, + ) + else: + child_clause, _ = self.build_filter( + child, can_reuse=reuse, branch_negated=branch_negated, + current_negated=current_negated, + allow_joins=True, split_subq=False, + reuse_with_filtered_relation=True, + ) + target_clause.add(child_clause, connector) + return target_clause + + def add_filtered_relation(self, filtered_relation, alias): + filtered_relation.alias = alias + lookups = dict(get_children_from_q(filtered_relation.condition)) + for lookup in chain((filtered_relation.relation_name,), lookups): + lookup_parts, field_parts, _ = self.solve_lookup_type(lookup) + shift = 2 if not lookup_parts else 1 + if len(field_parts) > (shift + len(lookup_parts)): + raise ValueError( + "FilteredRelation's condition doesn't support nested " + "relations (got %r)." % lookup + ) + self._filtered_relations[filtered_relation.alias] = filtered_relation + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walk the list of names and turns them into PathInfo tuples. A single @@ -1272,12 +1338,15 @@ class Query: name = opts.pk.name field = None + filtered_relation = None try: field = opts.get_field(name) except FieldDoesNotExist: if name in self.annotation_select: field = self.annotation_select[name].output_field - + elif name in self._filtered_relations and pos == 0: + filtered_relation = self._filtered_relations[name] + field = opts.get_field(filtered_relation.relation_name) if field is not None: # Fields that contain one-to-many relations with a generic # model (like a GenericForeignKey) cannot generate reverse @@ -1301,7 +1370,10 @@ class Query: pos -= 1 if pos == -1 or fail_on_missing: field_names = list(get_field_names_from_opts(opts)) - available = sorted(field_names + list(self.annotation_select)) + available = sorted( + field_names + list(self.annotation_select) + + list(self._filtered_relations) + ) raise FieldError("Cannot resolve keyword '%s' into field. " "Choices are: %s" % (name, ", ".join(available))) break @@ -1315,7 +1387,7 @@ class Query: cur_names_with_path[1].extend(path_to_parent) opts = path_to_parent[-1].to_opts if hasattr(field, 'get_path_info'): - pathinfos = field.get_path_info() + pathinfos = field.get_path_info(filtered_relation) if not allow_many: for inner_pos, p in enumerate(pathinfos): if p.m2m: @@ -1340,7 +1412,8 @@ class Query: break return path, final_field, targets, names[pos + 1:] - def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): + def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, + reuse_with_filtered_relation=False): """ Compute the necessary table joins for the passage through the fields given in 'names'. 'opts' is the Options class for the current model @@ -1352,6 +1425,9 @@ class Query: that can be reused. Note that non-reverse foreign keys are always reusable when using setup_joins(). + The 'reuse_with_filtered_relation' can be used to force 'can_reuse' + parameter and force the relation on the given connections. + If 'allow_many' is False, then any reverse foreign key seen will generate a MultiJoin exception. @@ -1374,15 +1450,29 @@ class Query: # joins at this stage - we will need the information about join type # of the trimmed joins. for join in path: + if join.filtered_relation: + filtered_relation = join.filtered_relation.clone() + table_alias = filtered_relation.alias + else: + filtered_relation = None + table_alias = None opts = join.to_opts if join.direct: nullable = self.is_nullable(join.join_field) else: nullable = True - connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable) - reuse = can_reuse if join.m2m else None - alias = self.join(connection, reuse=reuse) + connection = Join( + opts.db_table, alias, table_alias, INNER, join.join_field, + nullable, filtered_relation=filtered_relation, + ) + reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None + alias = self.join( + connection, reuse=reuse, + reuse_with_filtered_relation=reuse_with_filtered_relation, + ) joins.append(alias) + if filtered_relation: + filtered_relation.path = joins[:] return JoinInfo(final_field, targets, opts, joins, path) def trim_joins(self, targets, joins, path): @@ -1402,6 +1492,8 @@ class Query: for pos, info in enumerate(reversed(path)): if len(joins) == 1 or not info.direct: break + if info.filtered_relation: + break join_targets = {t.column for t in info.join_field.foreign_related_fields} cur_targets = {t.column for t in targets} if not cur_targets.issubset(join_targets): @@ -1425,7 +1517,7 @@ class Query: return self.annotation_select[name] else: field_list = name.split(LOOKUP_SEP) - join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse) + join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) if len(targets) > 1: raise FieldError("Referencing multicolumn fields with F() objects " @@ -1602,7 +1694,10 @@ class Query: # from the model on which the lookup failed. raise else: - names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select)) + names = sorted( + list(get_field_names_from_opts(opts)) + list(self.extra) + + list(self.annotation_select) + list(self._filtered_relations) + ) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) |
