summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRan Benita <ran234@gmail.com>2017-06-29 23:00:15 +0300
committerTim Graham <timograham@gmail.com>2017-06-29 16:00:15 -0400
commitb9f7dce84b7ab5e198129030eae6c1a4aec83d24 (patch)
tree8f350d29029e977c48107db898a6994c38cbfba4
parent2d18c60fbb1efcc980adfe875dadb02c749da509 (diff)
Fixed #28010 -- Added FOR UPDATE OF support to QuerySet.select_for_update().
-rw-r--r--django/db/backends/base/features.py4
-rw-r--r--django/db/backends/base/operations.py13
-rw-r--r--django/db/backends/oracle/features.py2
-rw-r--r--django/db/backends/postgresql/features.py1
-rw-r--r--django/db/models/query.py3
-rw-r--r--django/db/models/sql/compiler.py67
-rw-r--r--django/db/models/sql/query.py2
-rw-r--r--docs/ref/databases.txt6
-rw-r--r--docs/ref/models/querysets.txt23
-rw-r--r--docs/releases/2.0.txt11
-rw-r--r--tests/select_for_update/models.py11
-rw-r--r--tests/select_for_update/tests.py86
12 files changed, 206 insertions, 23 deletions
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index 4b38a2d6b1..41a456bfed 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -36,6 +36,10 @@ class BaseDatabaseFeatures:
has_select_for_update = False
has_select_for_update_nowait = False
has_select_for_update_skip_locked = False
+ has_select_for_update_of = False
+ # Does the database's SELECT FOR UPDATE OF syntax require a column rather
+ # than a table?
+ select_for_update_of_column = False
supports_select_related = True
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index c3df3465c3..cf6b5f9166 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -177,16 +177,15 @@ class BaseDatabaseOperations:
"""
return []
- def for_update_sql(self, nowait=False, skip_locked=False):
+ def for_update_sql(self, nowait=False, skip_locked=False, of=()):
"""
Return the FOR UPDATE SQL clause to lock rows for an update operation.
"""
- if nowait:
- return 'FOR UPDATE NOWAIT'
- elif skip_locked:
- return 'FOR UPDATE SKIP LOCKED'
- else:
- return 'FOR UPDATE'
+ return 'FOR UPDATE%s%s%s' % (
+ ' OF %s' % ', '.join(of) if of else '',
+ ' NOWAIT' if nowait else '',
+ ' SKIP LOCKED' if skip_locked else '',
+ )
def last_executed_query(self, cursor, sql, params):
"""
diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py
index fe6e30dc46..90584ff14f 100644
--- a/django/db/backends/oracle/features.py
+++ b/django/db/backends/oracle/features.py
@@ -9,6 +9,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_skip_locked = True
+ has_select_for_update_of = True
+ select_for_update_of_column = True
can_return_id_from_insert = True
allow_sliced_subqueries = False
can_introspect_autofield = True
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 3f6cc7894d..0f291a6586 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -13,6 +13,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
+ has_select_for_update_of = True
has_bulk_insert = True
uses_savepoints = True
can_release_savepoints = True
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 38f69f22d1..e5e1c1b9f4 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -839,7 +839,7 @@ class QuerySet:
return self
return self._combinator_query('difference', *other_qs)
- def select_for_update(self, nowait=False, skip_locked=False):
+ def select_for_update(self, nowait=False, skip_locked=False, of=()):
"""
Return a new QuerySet instance that will select objects with a
FOR UPDATE lock.
@@ -851,6 +851,7 @@ class QuerySet:
obj.query.select_for_update = True
obj.query.select_for_update_nowait = nowait
obj.query.select_for_update_skip_locked = skip_locked
+ obj.query.select_for_update_of = of
return obj
def select_related(self, *fields):
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index b4b27a5b56..c705d33af8 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1,3 +1,4 @@
+import collections
import re
from itertools import chain
@@ -472,14 +473,21 @@ class SQLCompiler:
)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
- # If it's a NOWAIT/SKIP LOCKED query but the backend
- # doesn't support it, raise a DatabaseError to prevent a
+ of = self.query.select_for_update_of
+ # If it's a NOWAIT/SKIP LOCKED/OF query but the backend
+ # doesn't support it, raise NotSupportedError to prevent a
# possible deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise NotSupportedError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
- for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)
+ elif of and not self.connection.features.has_select_for_update_of:
+ raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
+ for_update_part = self.connection.ops.for_update_sql(
+ nowait=nowait,
+ skip_locked=skip_locked,
+ of=self.get_select_for_update_of_arguments(),
+ )
if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)
@@ -832,6 +840,59 @@ class SQLCompiler:
)
return related_klass_infos
+ def get_select_for_update_of_arguments(self):
+ """
+ Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
+ the query.
+ """
+ def _get_field_choices():
+ """Yield all allowed field paths in breadth-first search order."""
+ queue = collections.deque([(None, self.klass_info)])
+ while queue:
+ parent_path, klass_info = queue.popleft()
+ if parent_path is None:
+ path = []
+ yield 'self'
+ else:
+ path = parent_path + [klass_info['field'].name]
+ yield LOOKUP_SEP.join(path)
+ queue.extend(
+ (path, klass_info)
+ for klass_info in klass_info.get('related_klass_infos', [])
+ )
+ result = []
+ invalid_names = []
+ for name in self.query.select_for_update_of:
+ parts = [] if name == 'self' else name.split(LOOKUP_SEP)
+ klass_info = self.klass_info
+ for part in parts:
+ for related_klass_info in klass_info.get('related_klass_infos', []):
+ if related_klass_info['field'].name == part:
+ klass_info = related_klass_info
+ break
+ else:
+ klass_info = None
+ break
+ if klass_info is None:
+ invalid_names.append(name)
+ continue
+ select_index = klass_info['select_fields'][0]
+ col = self.select[select_index][0]
+ if self.connection.features.select_for_update_of_column:
+ result.append(self.compile(col)[0])
+ else:
+ result.append(self.quote_name_unless_alias(col.alias))
+ if invalid_names:
+ raise FieldError(
+ 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+ 'Only relational fields followed in the query are allowed. '
+ 'Choices are: %s.' % (
+ ', '.join(invalid_names),
+ ', '.join(_get_field_choices()),
+ )
+ )
+ return result
+
def deferred_to_columns(self):
"""
Convert the self.deferred_loading data structure to mapping of table
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index b4a87938f7..70fd648c52 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -161,6 +161,7 @@ class Query:
self.select_for_update = False
self.select_for_update_nowait = False
self.select_for_update_skip_locked = False
+ self.select_for_update_of = ()
self.select_related = False
# Arbitrary limit for select_related to prevents infinite recursion.
@@ -288,6 +289,7 @@ class Query:
obj.select_for_update = self.select_for_update
obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_for_update_skip_locked = self.select_for_update_skip_locked
+ obj.select_for_update_of = self.select_for_update_of
obj.select_related = self.select_related
obj.values_select = self.values_select
obj._annotations = self._annotations.copy() if self._annotations is not None else None
diff --git a/docs/ref/databases.txt b/docs/ref/databases.txt
index 45b1772514..69921f437b 100644
--- a/docs/ref/databases.txt
+++ b/docs/ref/databases.txt
@@ -629,9 +629,9 @@ both MySQL and Django will attempt to convert the values from UTC to local time.
Row locking with ``QuerySet.select_for_update()``
-------------------------------------------------
-MySQL does not support the ``NOWAIT`` and ``SKIP LOCKED`` options to the
-``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used with
-``nowait=True`` or ``skip_locked=True``, then a
+MySQL does not support the ``NOWAIT``, ``SKIP LOCKED``, and ``OF`` options to
+the ``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used
+with ``nowait=True``, ``skip_locked=True``, or ``of`` then a
:exc:`~django.db.NotSupportedError` is raised.
Automatic typecasting can cause unexpected results
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index a9006a14a9..d8d063a7a5 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -1611,7 +1611,7 @@ For example::
``select_for_update()``
~~~~~~~~~~~~~~~~~~~~~~~
-.. method:: select_for_update(nowait=False, skip_locked=False)
+.. method:: select_for_update(nowait=False, skip_locked=False, of=())
Returns a queryset that will lock rows until the end of the transaction,
generating a ``SELECT ... FOR UPDATE`` SQL statement on supported databases.
@@ -1635,14 +1635,21 @@ queryset is evaluated. You can also ignore locked rows by using
``select_for_update()`` with both options enabled will result in a
:exc:`ValueError`.
+By default, ``select_for_update()`` locks all rows that are selected by the
+query. For example, rows of related objects specified in :meth:`select_related`
+are locked in addition to rows of the queryset's model. If this isn't desired,
+specify the related objects you want to lock in ``select_for_update(of=(...))``
+using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
+to refer to the queryset's model.
+
Currently, the ``postgresql``, ``oracle``, and ``mysql`` database
backends support ``select_for_update()``. However, MySQL doesn't support the
-``nowait`` and ``skip_locked`` arguments.
+``nowait``, ``skip_locked``, and ``of`` arguments.
-Passing ``nowait=True`` or ``skip_locked=True`` to ``select_for_update()``
-using database backends that do not support these options, such as MySQL,
-raises a :exc:`~django.db.NotSupportedError`. This prevents code from
-unexpectedly blocking.
+Passing ``nowait=True``, ``skip_locked=True``, or ``of`` to
+``select_for_update()`` using database backends that do not support these
+options, such as MySQL, raises a :exc:`~django.db.NotSupportedError`. This
+prevents code from unexpectedly blocking.
Evaluating a queryset with ``select_for_update()`` in autocommit mode on
backends which support ``SELECT ... FOR UPDATE`` is a
@@ -1670,6 +1677,10 @@ raised if ``select_for_update()`` is used in autocommit mode.
The ``skip_locked`` argument was added.
+.. versionchanged:: 2.0
+
+ The ``of`` argument was added.
+
``raw()``
~~~~~~~~~
diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt
index 078cbbdf2c..8dfc43c24d 100644
--- a/docs/releases/2.0.txt
+++ b/docs/releases/2.0.txt
@@ -252,6 +252,12 @@ Models
:class:`~django.db.models.functions.datetime.Extract` now works with
:class:`~django.db.models.DurationField`.
+* Added the ``of`` argument to :meth:`.QuerySet.select_for_update()`, supported
+ on PostgreSQL and Oracle, to lock only rows from specific tables rather than
+ all selected tables. It may be helpful particularly when
+ :meth:`~.QuerySet.select_for_update()` is used in conjunction with
+ :meth:`~.QuerySet.select_related()`.
+
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
@@ -331,6 +337,11 @@ backends.
* The first argument of ``SchemaEditor._create_index_name()`` is now
``table_name`` rather than ``model``.
+* To enable ``FOR UPDATE OF`` support, set
+ ``DatabaseFeatures.has_select_for_update_of = True``. If the database
+ requires that the arguments to ``OF`` be columns rather than tables, set
+ ``DatabaseFeatures.select_for_update_of_column = True``.
+
Dropped support for Oracle 11.2
-------------------------------
diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py
index 48ad58faa9..b04ed31b00 100644
--- a/tests/select_for_update/models.py
+++ b/tests/select_for_update/models.py
@@ -1,5 +1,16 @@
from django.db import models
+class Country(models.Model):
+ name = models.CharField(max_length=30)
+
+
+class City(models.Model):
+ name = models.CharField(max_length=30)
+ country = models.ForeignKey(Country, models.CASCADE)
+
+
class Person(models.Model):
name = models.CharField(max_length=30)
+ born = models.ForeignKey(City, models.CASCADE, related_name='+')
+ died = models.ForeignKey(City, models.CASCADE, related_name='+')
diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py
index 0c581f0f37..7228af6e8e 100644
--- a/tests/select_for_update/tests.py
+++ b/tests/select_for_update/tests.py
@@ -4,6 +4,7 @@ from unittest import mock
from multiple_database.routers import TestRouter
+from django.core.exceptions import FieldError
from django.db import (
DatabaseError, NotSupportedError, connection, connections, router,
transaction,
@@ -14,7 +15,7 @@ from django.test import (
)
from django.test.utils import CaptureQueriesContext
-from .models import Person
+from .models import City, Country, Person
class SelectForUpdateTests(TransactionTestCase):
@@ -24,7 +25,11 @@ class SelectForUpdateTests(TransactionTestCase):
def setUp(self):
# This is executed in autocommit mode so that code in
# run_select_for_update can see this data.
- self.person = Person.objects.create(name='Reinhardt')
+ self.country1 = Country.objects.create(name='Belgium')
+ self.country2 = Country.objects.create(name='France')
+ self.city1 = City.objects.create(name='Liberchies', country=self.country1)
+ self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2)
+ self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2)
# We need another database connection in transaction to test that one
# connection issuing a SELECT ... FOR UPDATE will block.
@@ -90,6 +95,29 @@ class SelectForUpdateTests(TransactionTestCase):
list(Person.objects.all().select_for_update(skip_locked=True))
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True))
+ @skipUnlessDBFeature('has_select_for_update_of')
+ def test_for_update_sql_generated_of(self):
+ """
+ The backend's FOR UPDATE OF variant appears in the generated SQL when
+ select_for_update() is invoked.
+ """
+ with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+ list(Person.objects.select_related(
+ 'born__country',
+ ).select_for_update(
+ of=('born__country',),
+ ).select_for_update(
+ of=('self', 'born__country')
+ ))
+ features = connections['default'].features
+ if features.select_for_update_of_column:
+ expected = ['"select_for_update_person"."id"', '"select_for_update_country"."id"']
+ else:
+ expected = ['"select_for_update_person"', '"select_for_update_country"']
+ if features.uppercases_column_names:
+ expected = [value.upper() for value in expected]
+ self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
@skipUnlessDBFeature('has_select_for_update_nowait')
def test_nowait_raises_error_on_block(self):
"""
@@ -152,6 +180,58 @@ class SelectForUpdateTests(TransactionTestCase):
with transaction.atomic():
Person.objects.select_for_update(skip_locked=True).get()
+ @skipIfDBFeature('has_select_for_update_of')
+ @skipUnlessDBFeature('has_select_for_update')
+ def test_unsupported_of_raises_error(self):
+ """
+ NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on
+ a database backend that supports FOR UPDATE but not OF.
+ """
+ msg = 'FOR UPDATE OF is not supported on this database backend.'
+ with self.assertRaisesMessage(NotSupportedError, msg):
+ with transaction.atomic():
+ Person.objects.select_for_update(of=('self',)).get()
+
+ @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
+ def test_unrelated_of_argument_raises_error(self):
+ """
+ FieldError is raised if a non-relation field is specified in of=(...).
+ """
+ msg = (
+ 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+ 'Only relational fields followed in the query are allowed. '
+ 'Choices are: self, born, born__country.'
+ )
+ invalid_of = [
+ ('nonexistent',),
+ ('name',),
+ ('born__nonexistent',),
+ ('born__name',),
+ ('born__nonexistent', 'born__name'),
+ ]
+ for of in invalid_of:
+ with self.subTest(of=of):
+ with self.assertRaisesMessage(FieldError, msg % ', '.join(of)):
+ with transaction.atomic():
+ Person.objects.select_related('born__country').select_for_update(of=of).get()
+
+ @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
+ def test_related_but_unselected_of_argument_raises_error(self):
+ """
+ FieldError is raised if a relation field that is not followed in the
+ query is specified in of=(...).
+ """
+ msg = (
+ 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+ 'Only relational fields followed in the query are allowed. '
+ 'Choices are: self, born.'
+ )
+ for name in ['born__country', 'died', 'died__country']:
+ with self.subTest(name=name):
+ with self.assertRaisesMessage(FieldError, msg % name):
+ with transaction.atomic():
+ Person.objects.select_related('born').select_for_update(of=(name,)).get()
+
@skipUnlessDBFeature('has_select_for_update')
def test_for_update_after_from(self):
features_class = connections['default'].features.__class__
@@ -182,7 +262,7 @@ class SelectForUpdateTests(TransactionTestCase):
@skipUnlessDBFeature('supports_select_for_update_with_limit')
def test_select_for_update_with_limit(self):
- other = Person.objects.create(name='Grappeli')
+ other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2)
with transaction.atomic():
qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2])
self.assertEqual(qs[0], other)