summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMariusz Felisiak <felisiak.mariusz@gmail.com>2019-12-02 07:57:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2019-12-02 07:58:44 +0100
commit6cf3b6f5cf0cc3b11e86e511ec5201999913286f (patch)
tree7c4c08f951923bb03efffff6c8639b530073d9b7
parent9a17ae50c61a3a0ea6c552ce4e3eab27f796d094 (diff)
[2.2.x] Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance.
Thanks Abhijeet Viswa for the report and initial patch. Backport of 0107e3d1058f653f66032f7fd3a0bd61e96bf782 from master
-rw-r--r--django/db/models/sql/compiler.py69
-rw-r--r--docs/ref/models/querysets.txt8
-rw-r--r--docs/releases/2.2.8.txt6
-rw-r--r--tests/select_for_update/models.py9
-rw-r--r--tests/select_for_update/tests.py62
5 files changed, 135 insertions, 19 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index f0daffe5c5..9e709d0f6e 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -947,6 +947,21 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
+ def _get_parent_klass_info(klass_info):
+ return (
+ {
+ 'model': parent_model,
+ 'field': parent_link,
+ 'reverse': False,
+ 'select_fields': [
+ select_index
+ for select_index in klass_info['select_fields']
+ if self.select[select_index][0].target.model == parent_model
+ ],
+ }
+ for parent_model, parent_link in klass_info['model']._meta.parents.items()
+ )
+
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
queue = collections.deque([(None, self.klass_info)])
@@ -963,33 +978,51 @@ class SQLCompiler:
yield LOOKUP_SEP.join(path)
queue.extend(
(path, klass_info)
+ for klass_info in _get_parent_klass_info(klass_info)
+ )
+ queue.extend(
+ (path, klass_info)
for klass_info in klass_info.get('related_klass_infos', [])
)
result = []
invalid_names = []
for name in self.query.select_for_update_of:
- parts = [] if name == 'self' else name.split(LOOKUP_SEP)
klass_info = self.klass_info
- for part in parts:
- for related_klass_info in klass_info.get('related_klass_infos', []):
- field = related_klass_info['field']
- if related_klass_info['reverse']:
- field = field.remote_field
- if field.name == part:
- klass_info = related_klass_info
+ if name == 'self':
+ # Find the first selected column from a base model. If it
+ # doesn't exist, don't lock a base model.
+ for select_index in klass_info['select_fields']:
+ if self.select[select_index][0].target.model == klass_info['model']:
+ col = self.select[select_index][0]
break
else:
- klass_info = None
- break
- if klass_info is None:
- invalid_names.append(name)
- continue
- select_index = klass_info['select_fields'][0]
- col = self.select[select_index][0]
- if self.connection.features.select_for_update_of_column:
- result.append(self.compile(col)[0])
+ col = None
else:
- result.append(self.quote_name_unless_alias(col.alias))
+ for part in name.split(LOOKUP_SEP):
+ klass_infos = (
+ *klass_info.get('related_klass_infos', []),
+ *_get_parent_klass_info(klass_info),
+ )
+ for related_klass_info in klass_infos:
+ field = related_klass_info['field']
+ if related_klass_info['reverse']:
+ field = field.remote_field
+ if field.name == part:
+ klass_info = related_klass_info
+ break
+ else:
+ klass_info = None
+ break
+ if klass_info is None:
+ invalid_names.append(name)
+ continue
+ select_index = klass_info['select_fields'][0]
+ col = self.select[select_index][0]
+ if col is not None:
+ if self.connection.features.select_for_update_of_column:
+ result.append(self.compile(col)[0])
+ else:
+ result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index 3ccd5958fd..f23d4bc598 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -1696,6 +1696,14 @@ specify the related objects you want to lock in ``select_for_update(of=(...))``
using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
to refer to the queryset's model.
+.. admonition:: Lock parents models in ``select_for_update(of=(...))``
+
+ If you want to lock parents models when using :ref:`multi-table inheritance
+ <multi-table-inheritance>`, you must specify parent link fields (by default
+ ``<parent_model_name>_ptr``) in the ``of`` argument. For example::
+
+ Restaurant.objects.select_for_update(of=('self', 'place_ptr'))
+
You can't use ``select_for_update()`` on nullable relations::
>>> Person.objects.select_related('hometown').select_for_update()
diff --git a/docs/releases/2.2.8.txt b/docs/releases/2.2.8.txt
index 4d8f9869c5..3c5eb5c754 100644
--- a/docs/releases/2.2.8.txt
+++ b/docs/releases/2.2.8.txt
@@ -17,3 +17,9 @@ Bugfixes
* Fixed a regression in Django 2.2.1 that caused a crash when migrating
permissions for proxy models with a multiple database setup if the
``default`` entry was empty (:ticket:`31021`).
+
+* Fixed a data loss possibility in the
+ :meth:`~django.db.models.query.QuerySet.select_for_update()`. When using
+ ``'self'`` in the ``of`` argument with :ref:`multi-table inheritance
+ <multi-table-inheritance>`, a parent model was locked instead of the
+ queryset's model (:ticket:`30953`).
diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py
index b8154af3df..c84f9ad6b2 100644
--- a/tests/select_for_update/models.py
+++ b/tests/select_for_update/models.py
@@ -5,11 +5,20 @@ class Country(models.Model):
name = models.CharField(max_length=30)
+class EUCountry(Country):
+ join_date = models.DateField()
+
+
class City(models.Model):
name = models.CharField(max_length=30)
country = models.ForeignKey(Country, models.CASCADE)
+class EUCity(models.Model):
+ name = models.CharField(max_length=30)
+ country = models.ForeignKey(EUCountry, models.CASCADE)
+
+
class Person(models.Model):
name = models.CharField(max_length=30)
born = models.ForeignKey(City, models.CASCADE, related_name='+')
diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py
index f359dc2650..1f1b20e47f 100644
--- a/tests/select_for_update/tests.py
+++ b/tests/select_for_update/tests.py
@@ -15,7 +15,7 @@ from django.test import (
)
from django.test.utils import CaptureQueriesContext
-from .models import City, Country, Person, PersonProfile
+from .models import City, Country, EUCity, EUCountry, Person, PersonProfile
class SelectForUpdateTests(TransactionTestCase):
@@ -120,6 +120,47 @@ class SelectForUpdateTests(TransactionTestCase):
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
+ def test_for_update_sql_model_inheritance_generated_of(self):
+ with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+ list(EUCountry.objects.select_for_update(of=('self',)))
+ if connection.features.select_for_update_of_column:
+ expected = ['select_for_update_eucountry"."country_ptr_id']
+ else:
+ expected = ['select_for_update_eucountry']
+ expected = [connection.ops.quote_name(value) for value in expected]
+ self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
+ @skipUnlessDBFeature('has_select_for_update_of')
+ def test_for_update_sql_model_inheritance_ptr_generated_of(self):
+ with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+ list(EUCountry.objects.select_for_update(of=('self', 'country_ptr',)))
+ if connection.features.select_for_update_of_column:
+ expected = [
+ 'select_for_update_eucountry"."country_ptr_id',
+ 'select_for_update_country"."id',
+ ]
+ else:
+ expected = ['select_for_update_eucountry', 'select_for_update_country']
+ expected = [connection.ops.quote_name(value) for value in expected]
+ self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
+ @skipUnlessDBFeature('has_select_for_update_of')
+ def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
+ with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+ list(EUCity.objects.select_related('country').select_for_update(
+ of=('self', 'country__country_ptr',),
+ ))
+ if connection.features.select_for_update_of_column:
+ expected = [
+ 'select_for_update_eucity"."id',
+ 'select_for_update_country"."id',
+ ]
+ else:
+ expected = ['select_for_update_eucity', 'select_for_update_country']
+ expected = [connection.ops.quote_name(value) for value in expected]
+ self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
+ @skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_of_followed_by_values(self):
with transaction.atomic():
values = list(Person.objects.select_for_update(of=('self',)).values('pk'))
@@ -258,6 +299,25 @@ class SelectForUpdateTests(TransactionTestCase):
).exclude(profile=None).select_for_update(of=(name,)).get()
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
+ def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self):
+ msg = (
+ 'Invalid field name(s) given in select_for_update(of=(...)): '
+ 'name. Only relational fields followed in the query are allowed. '
+ 'Choices are: self, %s.'
+ )
+ with self.assertRaisesMessage(
+ FieldError,
+ msg % 'country, country__country_ptr',
+ ):
+ with transaction.atomic():
+ EUCity.objects.select_related(
+ 'country',
+ ).select_for_update(of=('name',)).get()
+ with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
+ with transaction.atomic():
+ EUCountry.objects.select_for_update(of=('name',)).get()
+
+ @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_reverse_one_to_one_of_arguments(self):
"""
Reverse OneToOneFields may be included in of=(...) as long as NULLs