summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2019-02-15 22:02:33 -0500
committerTim Graham <timograham@gmail.com>2019-02-21 10:20:47 -0500
commitde7f6b51b21747e19e90d9e3e04e0cdbf84e8a75 (patch)
tree3d7ae393082708fbcafced7b4ae833035652876a
parent28712d8acfffa9cdabb88cb610bae14913fa185d (diff)
Refs #19544 -- Added a fast path for through additions if supported.
The single query insertion path is taken if the backend supports inserts that ignore conflicts and m2m_changed signals don't have to be sent.
-rw-r--r--django/db/models/fields/related_descriptors.py74
-rw-r--r--tests/many_to_many/tests.py17
2 files changed, 71 insertions, 20 deletions
diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py
index 52cb91d3a8..2b426c37a5 100644
--- a/django/db/models/fields/related_descriptors.py
+++ b/django/db/models/fields/related_descriptors.py
@@ -1051,10 +1051,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
return obj, created
update_or_create.alters_data = True
- def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs):
+ def _get_target_ids(self, target_field_name, objs):
"""
- Return the subset of ids of `objs` that aren't already assigned to
- this relationship.
+ Return the set of ids of `objs` that the target field references.
"""
from django.db.models import Model
target_ids = set()
@@ -1081,6 +1080,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
)
else:
target_ids.add(obj)
+ return target_ids
+
+ def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids):
+ """
+ Return the subset of ids of `objs` that aren't already assigned to
+ this relationship.
+ """
vals = self.through._default_manager.using(db).values_list(
target_field_name, flat=True
).filter(**{
@@ -1089,6 +1095,35 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
})
return target_ids.difference(vals)
+ def _get_add_plan(self, db, source_field_name):
+ """
+ Return a boolean triple of the way the add should be performed.
+
+ The first element is whether or not bulk_create(ignore_conflicts)
+ can be used, the second whether or not signals must be sent, and
+ the third element is whether or not the immediate bulk insertion
+ with conflicts ignored can be performed.
+ """
+ # Conflicts can be ignored when the intermediary model is
+ # auto-created as the only possible collision is on the
+ # (source_id, target_id) tuple. The same assertion doesn't hold for
+ # user-defined intermediary models as they could have other fields
+ # causing conflicts which must be surfaced.
+ can_ignore_conflicts = (
+ connections[db].features.supports_ignore_conflicts and
+ self.through._meta.auto_created is not False
+ )
+ # Don't send the signal when inserting duplicate data row
+ # for symmetrical reverse entries.
+ must_send_signals = (self.reverse or source_field_name == self.source_field_name) and (
+ signals.m2m_changed.has_listeners(self.through)
+ )
+ # Fast addition through bulk insertion can only be performed
+ # if no m2m_changed listeners are connected for self.through
+ # as they require the added set of ids to be provided via
+ # pk_set.
+ return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals)
+
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
# source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object
@@ -1097,37 +1132,40 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# If there aren't any objects, there is nothing to do.
if objs:
+ target_ids = self._get_target_ids(target_field_name, objs)
db = router.db_for_write(self.through, instance=self.instance)
- missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs)
+ can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
+ if can_fast_add:
+ self.through._default_manager.using(db).bulk_create([
+ self.through(**{
+ '%s_id' % source_field_name: self.related_val[0],
+ '%s_id' % target_field_name: target_id,
+ })
+ for target_id in target_ids
+ ], ignore_conflicts=True)
+ return
+ missing_target_ids = self._get_missing_target_ids(
+ source_field_name, target_field_name, db, target_ids
+ )
with transaction.atomic(using=db, savepoint=False):
- if self.reverse or source_field_name == self.source_field_name:
- # Don't send the signal when we are inserting the
- # duplicate data row for symmetrical reverse entries.
+ if must_send_signals:
signals.m2m_changed.send(
sender=self.through, action='pre_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
- # Add the ones that aren't there already. Conflicts can be
- # ignored when the intermediary model is auto-created as
- # the only possible collision is on the (sid_id, tid_id)
- # tuple. The same assertion doesn't hold for user-defined
- # intermediary models as they could have other fields
- # causing conflicts which must be surfaced.
- ignore_conflicts = self.through._meta.auto_created is not False
+ # Add the ones that aren't there already.
self.through._default_manager.using(db).bulk_create([
self.through(**through_defaults, **{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in missing_target_ids
- ], ignore_conflicts=ignore_conflicts)
+ ], ignore_conflicts=can_ignore_conflicts)
- if self.reverse or source_field_name == self.source_field_name:
- # Don't send the signal when we are inserting the
- # duplicate data row for symmetrical reverse entries.
+ if must_send_signals:
signals.m2m_changed.send(
sender=self.through, action='post_add',
instance=self.instance, reverse=self.reverse,
diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py
index adde2ac563..098cd29e46 100644
--- a/tests/many_to_many/tests.py
+++ b/tests/many_to_many/tests.py
@@ -118,13 +118,26 @@ class ManyToManyTests(TestCase):
)
@skipUnlessDBFeature('supports_ignore_conflicts')
- def test_add_ignore_conflicts(self):
+ def test_fast_add_ignore_conflicts(self):
+ """
+ A single query is necessary to add auto-created through instances if
+ the database backend supports bulk_create(ignore_conflicts) and no
+ m2m_changed signals receivers are connected.
+ """
+ with self.assertNumQueries(1):
+ self.a1.publications.add(self.p1, self.p2)
+
+ @skipUnlessDBFeature('supports_ignore_conflicts')
+ def test_slow_add_ignore_conflicts(self):
manager_cls = self.a1.publications.__class__
# Simulate a race condition between the missing ids retrieval and
# the bulk insertion attempt.
missing_target_ids = {self.p1.id}
+ # Disable fast-add to test the case where the slow add path is taken.
+ add_plan = (True, False, False)
with mock.patch.object(manager_cls, '_get_missing_target_ids', return_value=missing_target_ids) as mocked:
- self.a1.publications.add(self.p1)
+ with mock.patch.object(manager_cls, '_get_add_plan', return_value=add_plan):
+ self.a1.publications.add(self.p1)
mocked.assert_called_once()
def test_related_sets(self):