diff options
| author | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2019-12-02 07:57:19 +0100 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2019-12-02 07:58:04 +0100 |
| commit | f4ed6800bd5de576816861931699ddd8377d338d (patch) | |
| tree | 96564db66230939682c37c7e511925e257dd0249 /django | |
| parent | ca9144a4a810dd509e468699c13325d8a1f5dcb1 (diff) | |
[3.0.x] Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance.
Thanks Abhijeet Viswa for the report and initial patch.
Backport of 0107e3d1058f653f66032f7fd3a0bd61e96bf782 from master
Diffstat (limited to 'django')
| -rw-r--r-- | django/db/models/sql/compiler.py | 69 |
1 files changed, 51 insertions, 18 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index e85e73de50..2f4c5bed58 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -964,6 +964,21 @@ class SQLCompiler: Return a quoted list of arguments for the SELECT FOR UPDATE OF part of the query. """ + def _get_parent_klass_info(klass_info): + return ( + { + 'model': parent_model, + 'field': parent_link, + 'reverse': False, + 'select_fields': [ + select_index + for select_index in klass_info['select_fields'] + if self.select[select_index][0].target.model == parent_model + ], + } + for parent_model, parent_link in klass_info['model']._meta.parents.items() + ) + def _get_field_choices(): """Yield all allowed field paths in breadth-first search order.""" queue = collections.deque([(None, self.klass_info)]) @@ -980,33 +995,51 @@ class SQLCompiler: yield LOOKUP_SEP.join(path) queue.extend( (path, klass_info) + for klass_info in _get_parent_klass_info(klass_info) + ) + queue.extend( + (path, klass_info) for klass_info in klass_info.get('related_klass_infos', []) ) result = [] invalid_names = [] for name in self.query.select_for_update_of: - parts = [] if name == 'self' else name.split(LOOKUP_SEP) klass_info = self.klass_info - for part in parts: - for related_klass_info in klass_info.get('related_klass_infos', []): - field = related_klass_info['field'] - if related_klass_info['reverse']: - field = field.remote_field - if field.name == part: - klass_info = related_klass_info + if name == 'self': + # Find the first selected column from a base model. If it + # doesn't exist, don't lock a base model. + for select_index in klass_info['select_fields']: + if self.select[select_index][0].target.model == klass_info['model']: + col = self.select[select_index][0] break else: - klass_info = None - break - if klass_info is None: - invalid_names.append(name) - continue - select_index = klass_info['select_fields'][0] - col = self.select[select_index][0] - if self.connection.features.select_for_update_of_column: - result.append(self.compile(col)[0]) + col = None else: - result.append(self.quote_name_unless_alias(col.alias)) + for part in name.split(LOOKUP_SEP): + klass_infos = ( + *klass_info.get('related_klass_infos', []), + *_get_parent_klass_info(klass_info), + ) + for related_klass_info in klass_infos: + field = related_klass_info['field'] + if related_klass_info['reverse']: + field = field.remote_field + if field.name == part: + klass_info = related_klass_info + break + else: + klass_info = None + break + if klass_info is None: + invalid_names.append(name) + continue + select_index = klass_info['select_fields'][0] + col = self.select[select_index][0] + if col is not None: + if self.connection.features.select_for_update_of_column: + result.append(self.compile(col)[0]) + else: + result.append(self.quote_name_unless_alias(col.alias)) if invalid_names: raise FieldError( 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' |
