summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py37
1 files changed, 22 insertions, 15 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index be2d590d84..18365f1d75 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -972,19 +972,34 @@ class SQLCompiler:
the query.
"""
def _get_parent_klass_info(klass_info):
- return (
- {
+ for parent_model, parent_link in klass_info['model']._meta.parents.items():
+ parent_list = parent_model._meta.get_parent_list()
+ yield {
'model': parent_model,
'field': parent_link,
'reverse': False,
'select_fields': [
select_index
for select_index in klass_info['select_fields']
- if self.select[select_index][0].target.model == parent_model
+ # Selected columns from a model or its parents.
+ if (
+ self.select[select_index][0].target.model == parent_model or
+ self.select[select_index][0].target.model in parent_list
+ )
],
}
- for parent_model, parent_link in klass_info['model']._meta.parents.items()
- )
+
+ def _get_first_selected_col_from_model(klass_info):
+ """
+ Find the first selected column from a model. If it doesn't exist,
+ don't lock a model.
+
+ select_fields is filled recursively, so it also contains fields
+ from the parent models.
+ """
+ for select_index in klass_info['select_fields']:
+ if self.select[select_index][0].target.model == klass_info['model']:
+ return self.select[select_index][0]
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
@@ -1013,14 +1028,7 @@ class SQLCompiler:
for name in self.query.select_for_update_of:
klass_info = self.klass_info
if name == 'self':
- # Find the first selected column from a base model. If it
- # doesn't exist, don't lock a base model.
- for select_index in klass_info['select_fields']:
- if self.select[select_index][0].target.model == klass_info['model']:
- col = self.select[select_index][0]
- break
- else:
- col = None
+ col = _get_first_selected_col_from_model(klass_info)
else:
for part in name.split(LOOKUP_SEP):
klass_infos = (
@@ -1040,8 +1048,7 @@ class SQLCompiler:
if klass_info is None:
invalid_names.append(name)
continue
- select_index = klass_info['select_fields'][0]
- col = self.select[select_index][0]
+ col = _get_first_selected_col_from_model(klass_info)
if col is not None:
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])