summaryrefslogtreecommitdiff
path: root/django/db/models/sql/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/query.py')
-rw-r--r--django/db/models/sql/query.py45
1 files changed, 38 insertions, 7 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 6a9348af66..a7839ccb4d 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -65,20 +65,52 @@ def get_field_names_from_opts(opts):
)
+def get_paths_from_expression(expr):
+ if isinstance(expr, F):
+ yield expr.name
+ elif hasattr(expr, "flatten"):
+ for child in expr.flatten():
+ if isinstance(child, F):
+ yield child.name
+ elif isinstance(child, Q):
+ yield from get_children_from_q(child)
+
+
def get_children_from_q(q):
for child in q.children:
if isinstance(child, Node):
yield from get_children_from_q(child)
- else:
- yield child
+ elif isinstance(child, tuple):
+ lhs, rhs = child
+ yield lhs
+ if hasattr(rhs, "resolve_expression"):
+ yield from get_paths_from_expression(rhs)
+ elif hasattr(child, "resolve_expression"):
+ yield from get_paths_from_expression(child)
def get_child_with_renamed_prefix(prefix, replacement, child):
if isinstance(child, Node):
return rename_prefix_from_q(prefix, replacement, child)
- lhs, rhs = child
- lhs = lhs.replace(prefix, replacement, 1)
- return lhs, rhs
+ if isinstance(child, tuple):
+ lhs, rhs = child
+ lhs = lhs.replace(prefix, replacement, 1)
+ if not isinstance(rhs, F) and hasattr(rhs, "resolve_expression"):
+ rhs = get_child_with_renamed_prefix(prefix, replacement, rhs)
+ return lhs, rhs
+
+ if isinstance(child, F):
+ child = child.copy()
+ child.name = child.name.replace(prefix, replacement, 1)
+ elif hasattr(child, "resolve_expression"):
+ child = child.copy()
+ child.set_source_expressions(
+ [
+ get_child_with_renamed_prefix(prefix, replacement, grand_child)
+ for grand_child in child.get_source_expressions()
+ ]
+ )
+ return child
def rename_prefix_from_q(prefix, replacement, q):
@@ -1618,7 +1650,6 @@ class Query(BaseExpression):
def add_filtered_relation(self, filtered_relation, alias):
filtered_relation.alias = alias
- lookups = dict(get_children_from_q(filtered_relation.condition))
relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
filtered_relation.relation_name
)
@@ -1627,7 +1658,7 @@ class Query(BaseExpression):
"FilteredRelation's relation_name cannot contain lookups "
"(got %r)." % filtered_relation.relation_name
)
- for lookup in chain(lookups):
+ for lookup in get_children_from_q(filtered_relation.condition):
lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
shift = 2 if not lookup_parts else 1
lookup_field_path = lookup_field_parts[:-shift]