diff options
Diffstat (limited to 'django/db/models/sql/compiler.py')
| -rw-r--r-- | django/db/models/sql/compiler.py | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 9e709d0f6e..e5c726676a 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -948,19 +948,34 @@ class SQLCompiler: the query. """ def _get_parent_klass_info(klass_info): - return ( - { + for parent_model, parent_link in klass_info['model']._meta.parents.items(): + parent_list = parent_model._meta.get_parent_list() + yield { 'model': parent_model, 'field': parent_link, 'reverse': False, 'select_fields': [ select_index for select_index in klass_info['select_fields'] - if self.select[select_index][0].target.model == parent_model + # Selected columns from a model or its parents. + if ( + self.select[select_index][0].target.model == parent_model or + self.select[select_index][0].target.model in parent_list + ) ], } - for parent_model, parent_link in klass_info['model']._meta.parents.items() - ) + + def _get_first_selected_col_from_model(klass_info): + """ + Find the first selected column from a model. If it doesn't exist, + don't lock a model. + + select_fields is filled recursively, so it also contains fields + from the parent models. + """ + for select_index in klass_info['select_fields']: + if self.select[select_index][0].target.model == klass_info['model']: + return self.select[select_index][0] def _get_field_choices(): """Yield all allowed field paths in breadth-first search order.""" @@ -989,14 +1004,7 @@ class SQLCompiler: for name in self.query.select_for_update_of: klass_info = self.klass_info if name == 'self': - # Find the first selected column from a base model. If it - # doesn't exist, don't lock a base model. - for select_index in klass_info['select_fields']: - if self.select[select_index][0].target.model == klass_info['model']: - col = self.select[select_index][0] - break - else: - col = None + col = _get_first_selected_col_from_model(klass_info) else: for part in name.split(LOOKUP_SEP): klass_infos = ( @@ -1016,8 +1024,7 @@ class SQLCompiler: if klass_info is None: invalid_names.append(name) continue - select_index = klass_info['select_fields'][0] - col = self.select[select_index][0] + col = _get_first_selected_col_from_model(klass_info) if col is not None: if self.connection.features.select_for_update_of_column: result.append(self.compile(col)[0]) |
