diff options
Diffstat (limited to 'django/db/models/sql/compiler.py')
| -rw-r--r-- | django/db/models/sql/compiler.py | 69 |
1 files changed, 51 insertions, 18 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index e85e73de50..2f4c5bed58 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -964,6 +964,21 @@ class SQLCompiler: Return a quoted list of arguments for the SELECT FOR UPDATE OF part of the query. """ + def _get_parent_klass_info(klass_info): + return ( + { + 'model': parent_model, + 'field': parent_link, + 'reverse': False, + 'select_fields': [ + select_index + for select_index in klass_info['select_fields'] + if self.select[select_index][0].target.model == parent_model + ], + } + for parent_model, parent_link in klass_info['model']._meta.parents.items() + ) + def _get_field_choices(): """Yield all allowed field paths in breadth-first search order.""" queue = collections.deque([(None, self.klass_info)]) @@ -980,33 +995,51 @@ class SQLCompiler: yield LOOKUP_SEP.join(path) queue.extend( (path, klass_info) + for klass_info in _get_parent_klass_info(klass_info) + ) + queue.extend( + (path, klass_info) for klass_info in klass_info.get('related_klass_infos', []) ) result = [] invalid_names = [] for name in self.query.select_for_update_of: - parts = [] if name == 'self' else name.split(LOOKUP_SEP) klass_info = self.klass_info - for part in parts: - for related_klass_info in klass_info.get('related_klass_infos', []): - field = related_klass_info['field'] - if related_klass_info['reverse']: - field = field.remote_field - if field.name == part: - klass_info = related_klass_info + if name == 'self': + # Find the first selected column from a base model. If it + # doesn't exist, don't lock a base model. + for select_index in klass_info['select_fields']: + if self.select[select_index][0].target.model == klass_info['model']: + col = self.select[select_index][0] break else: - klass_info = None - break - if klass_info is None: - invalid_names.append(name) - continue - select_index = klass_info['select_fields'][0] - col = self.select[select_index][0] - if self.connection.features.select_for_update_of_column: - result.append(self.compile(col)[0]) + col = None else: - result.append(self.quote_name_unless_alias(col.alias)) + for part in name.split(LOOKUP_SEP): + klass_infos = ( + *klass_info.get('related_klass_infos', []), + *_get_parent_klass_info(klass_info), + ) + for related_klass_info in klass_infos: + field = related_klass_info['field'] + if related_klass_info['reverse']: + field = field.remote_field + if field.name == part: + klass_info = related_klass_info + break + else: + klass_info = None + break + if klass_info is None: + invalid_names.append(name) + continue + select_index = klass_info['select_fields'][0] + col = self.select[select_index][0] + if col is not None: + if self.connection.features.select_for_update_of_column: + result.append(self.compile(col)[0]) + else: + result.append(self.quote_name_unless_alias(col.alias)) if invalid_names: raise FieldError( 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' |
