summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py67
1 files changed, 64 insertions, 3 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index b4b27a5b56..c705d33af8 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1,3 +1,4 @@
+import collections
import re
from itertools import chain
@@ -472,14 +473,21 @@ class SQLCompiler:
)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
- # If it's a NOWAIT/SKIP LOCKED query but the backend
- # doesn't support it, raise a DatabaseError to prevent a
+ of = self.query.select_for_update_of
+ # If it's a NOWAIT/SKIP LOCKED/OF query but the backend
+ # doesn't support it, raise NotSupportedError to prevent a
# possible deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise NotSupportedError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
- for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)
+ elif of and not self.connection.features.has_select_for_update_of:
+ raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
+ for_update_part = self.connection.ops.for_update_sql(
+ nowait=nowait,
+ skip_locked=skip_locked,
+ of=self.get_select_for_update_of_arguments(),
+ )
if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)
@@ -832,6 +840,59 @@ class SQLCompiler:
)
return related_klass_infos
+ def get_select_for_update_of_arguments(self):
+ """
+ Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
+ the query.
+ """
+ def _get_field_choices():
+ """Yield all allowed field paths in breadth-first search order."""
+ queue = collections.deque([(None, self.klass_info)])
+ while queue:
+ parent_path, klass_info = queue.popleft()
+ if parent_path is None:
+ path = []
+ yield 'self'
+ else:
+ path = parent_path + [klass_info['field'].name]
+ yield LOOKUP_SEP.join(path)
+ queue.extend(
+ (path, klass_info)
+ for klass_info in klass_info.get('related_klass_infos', [])
+ )
+ result = []
+ invalid_names = []
+ for name in self.query.select_for_update_of:
+ parts = [] if name == 'self' else name.split(LOOKUP_SEP)
+ klass_info = self.klass_info
+ for part in parts:
+ for related_klass_info in klass_info.get('related_klass_infos', []):
+ if related_klass_info['field'].name == part:
+ klass_info = related_klass_info
+ break
+ else:
+ klass_info = None
+ break
+ if klass_info is None:
+ invalid_names.append(name)
+ continue
+ select_index = klass_info['select_fields'][0]
+ col = self.select[select_index][0]
+ if self.connection.features.select_for_update_of_column:
+ result.append(self.compile(col)[0])
+ else:
+ result.append(self.quote_name_unless_alias(col.alias))
+ if invalid_names:
+ raise FieldError(
+ 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+ 'Only relational fields followed in the query are allowed. '
+ 'Choices are: %s.' % (
+ ', '.join(invalid_names),
+ ', '.join(_get_field_choices()),
+ )
+ )
+ return result
+
def deferred_to_columns(self):
"""
Convert the self.deferred_loading data structure to mapping of table