summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py69
1 files changed, 51 insertions, 18 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index f0daffe5c5..9e709d0f6e 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -947,6 +947,21 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
+ def _get_parent_klass_info(klass_info):
+ return (
+ {
+ 'model': parent_model,
+ 'field': parent_link,
+ 'reverse': False,
+ 'select_fields': [
+ select_index
+ for select_index in klass_info['select_fields']
+ if self.select[select_index][0].target.model == parent_model
+ ],
+ }
+ for parent_model, parent_link in klass_info['model']._meta.parents.items()
+ )
+
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
queue = collections.deque([(None, self.klass_info)])
@@ -963,33 +978,51 @@ class SQLCompiler:
yield LOOKUP_SEP.join(path)
queue.extend(
(path, klass_info)
+ for klass_info in _get_parent_klass_info(klass_info)
+ )
+ queue.extend(
+ (path, klass_info)
for klass_info in klass_info.get('related_klass_infos', [])
)
result = []
invalid_names = []
for name in self.query.select_for_update_of:
- parts = [] if name == 'self' else name.split(LOOKUP_SEP)
klass_info = self.klass_info
- for part in parts:
- for related_klass_info in klass_info.get('related_klass_infos', []):
- field = related_klass_info['field']
- if related_klass_info['reverse']:
- field = field.remote_field
- if field.name == part:
- klass_info = related_klass_info
+ if name == 'self':
+ # Find the first selected column from a base model. If it
+ # doesn't exist, don't lock a base model.
+ for select_index in klass_info['select_fields']:
+ if self.select[select_index][0].target.model == klass_info['model']:
+ col = self.select[select_index][0]
break
else:
- klass_info = None
- break
- if klass_info is None:
- invalid_names.append(name)
- continue
- select_index = klass_info['select_fields'][0]
- col = self.select[select_index][0]
- if self.connection.features.select_for_update_of_column:
- result.append(self.compile(col)[0])
+ col = None
else:
- result.append(self.quote_name_unless_alias(col.alias))
+ for part in name.split(LOOKUP_SEP):
+ klass_infos = (
+ *klass_info.get('related_klass_infos', []),
+ *_get_parent_klass_info(klass_info),
+ )
+ for related_klass_info in klass_infos:
+ field = related_klass_info['field']
+ if related_klass_info['reverse']:
+ field = field.remote_field
+ if field.name == part:
+ klass_info = related_klass_info
+ break
+ else:
+ klass_info = None
+ break
+ if klass_info is None:
+ invalid_names.append(name)
+ continue
+ select_index = klass_info['select_fields'][0]
+ col = self.select[select_index][0]
+ if col is not None:
+ if self.connection.features.select_for_update_of_column:
+ result.append(self.compile(col)[0])
+ else:
+ result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
'Invalid field name(s) given in select_for_update(of=(...)): %s. '