summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Chaumeny <thomas.chaumeny.ext@gitguardian.com>2023-07-07 13:08:17 +0200
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2023-07-10 13:17:28 +0200
commit89c7454dbdae3e0df6d96aa6132205d05e4a9b3d (patch)
treee98327998df0f061c6fe2a0a0002b6b60b310260
parentb7a17b0ea0a2061bae752a3a2292007d41825814 (diff)
Fixed #34698 -- Made QuerySet.bulk_create() retrieve primary keys when updating conflicts.
-rw-r--r--django/db/models/query.py7
-rw-r--r--docs/ref/models/querysets.txt10
-rw-r--r--docs/releases/5.0.txt4
-rw-r--r--tests/bulk_create/tests.py37
4 files changed, 47 insertions, 11 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 5ac2407ea3..395ba6e404 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1837,12 +1837,17 @@ class QuerySet(AltersData):
inserted_rows = []
bulk_return = connection.features.can_return_rows_from_bulk_insert
for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
- if bulk_return and on_conflict is None:
+ if bulk_return and (
+ on_conflict is None or on_conflict == OnConflict.UPDATE
+ ):
inserted_rows.extend(
self._insert(
item,
fields=fields,
using=self.db,
+ on_conflict=on_conflict,
+ update_fields=update_fields,
+ unique_fields=unique_fields,
returning_fields=self.model._meta.db_returning_fields,
)
)
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index a754953264..fd6bb39ff8 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -2411,9 +2411,13 @@ On databases that support it (all except Oracle and SQLite < 3.24), setting the
SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may
be in conflict must be provided.
-Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable
-setting the primary key on each model instance (if the database normally
-support it).
+Enabling the ``ignore_conflicts`` parameter disables setting the primary key on
+each model instance (if the database normally supports it).
+
+.. versionchanged:: 5.0
+
+ In older versions, enabling the ``update_conflicts`` parameter prevented
+ setting the primary key on each model instance.
.. warning::
diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt
index 0ad0835c29..e4c1eac1d9 100644
--- a/docs/releases/5.0.txt
+++ b/docs/releases/5.0.txt
@@ -357,6 +357,10 @@ Models
:meth:`.Model.save` now allows specifying a tuple of parent classes that must
be forced to be inserted.
+* :meth:`.QuerySet.bulk_create` and :meth:`.QuerySet.abulk_create` methods now
+ set the primary key on each model instance when the ``update_conflicts``
+ parameter is enabled (if the database supports it).
+
Pagination
~~~~~~~~~~
diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py
index aee0cd9996..7b86a2def5 100644
--- a/tests/bulk_create/tests.py
+++ b/tests/bulk_create/tests.py
@@ -582,12 +582,16 @@ class BulkCreateTests(TestCase):
TwoFields(f1=1, f2=1, name="c"),
TwoFields(f1=2, f2=2, name="d"),
]
- TwoFields.objects.bulk_create(
+ results = TwoFields.objects.bulk_create(
conflicting_objects,
update_conflicts=True,
unique_fields=unique_fields,
update_fields=["name"],
)
+ self.assertEqual(len(results), len(conflicting_objects))
+ if connection.features.can_return_rows_from_bulk_insert:
+ for instance in results:
+ self.assertIsNotNone(instance.pk)
self.assertEqual(TwoFields.objects.count(), 2)
self.assertCountEqual(
TwoFields.objects.values("f1", "f2", "name"),
@@ -619,7 +623,6 @@ class BulkCreateTests(TestCase):
TwoFields(f1=2, f2=2, name="b"),
]
)
- self.assertEqual(TwoFields.objects.count(), 2)
obj1 = TwoFields.objects.get(f1=1)
obj2 = TwoFields.objects.get(f1=2)
@@ -627,12 +630,16 @@ class BulkCreateTests(TestCase):
TwoFields(pk=obj1.pk, f1=3, f2=3, name="c"),
TwoFields(pk=obj2.pk, f1=4, f2=4, name="d"),
]
- TwoFields.objects.bulk_create(
+ results = TwoFields.objects.bulk_create(
conflicting_objects,
update_conflicts=True,
unique_fields=["pk"],
update_fields=["name"],
)
+ self.assertEqual(len(results), len(conflicting_objects))
+ if connection.features.can_return_rows_from_bulk_insert:
+ for instance in results:
+ self.assertIsNotNone(instance.pk)
self.assertEqual(TwoFields.objects.count(), 2)
self.assertCountEqual(
TwoFields.objects.values("f1", "f2", "name"),
@@ -680,12 +687,16 @@ class BulkCreateTests(TestCase):
description=("Japan is an island country in East Asia."),
),
]
- Country.objects.bulk_create(
+ results = Country.objects.bulk_create(
new_data,
update_conflicts=True,
update_fields=["description"],
unique_fields=unique_fields,
)
+ self.assertEqual(len(results), len(new_data))
+ if connection.features.can_return_rows_from_bulk_insert:
+ for instance in results:
+ self.assertIsNotNone(instance.pk)
self.assertEqual(Country.objects.count(), 6)
self.assertCountEqual(
Country.objects.values("iso_two_letter", "description"),
@@ -743,12 +754,16 @@ class BulkCreateTests(TestCase):
UpsertConflict(number=2, rank=2, name="Olivia"),
UpsertConflict(number=3, rank=1, name="Hannah"),
]
- UpsertConflict.objects.bulk_create(
+ results = UpsertConflict.objects.bulk_create(
conflicting_objects,
update_conflicts=True,
update_fields=["name", "rank"],
unique_fields=unique_fields,
)
+ self.assertEqual(len(results), len(conflicting_objects))
+ if connection.features.can_return_rows_from_bulk_insert:
+ for instance in results:
+ self.assertIsNotNone(instance.pk)
self.assertEqual(UpsertConflict.objects.count(), 3)
self.assertCountEqual(
UpsertConflict.objects.values("number", "rank", "name"),
@@ -759,12 +774,16 @@ class BulkCreateTests(TestCase):
],
)
- UpsertConflict.objects.bulk_create(
+ results = UpsertConflict.objects.bulk_create(
conflicting_objects + [UpsertConflict(number=4, rank=4, name="Mark")],
update_conflicts=True,
update_fields=["name", "rank"],
unique_fields=unique_fields,
)
+ self.assertEqual(len(results), 4)
+ if connection.features.can_return_rows_from_bulk_insert:
+ for instance in results:
+ self.assertIsNotNone(instance.pk)
self.assertEqual(UpsertConflict.objects.count(), 4)
self.assertCountEqual(
UpsertConflict.objects.values("number", "rank", "name"),
@@ -803,12 +822,16 @@ class BulkCreateTests(TestCase):
FieldsWithDbColumns(rank=1, name="c"),
FieldsWithDbColumns(rank=2, name="d"),
]
- FieldsWithDbColumns.objects.bulk_create(
+ results = FieldsWithDbColumns.objects.bulk_create(
conflicting_objects,
update_conflicts=True,
unique_fields=["rank"],
update_fields=["name"],
)
+ self.assertEqual(len(results), len(conflicting_objects))
+ if connection.features.can_return_rows_from_bulk_insert:
+ for instance in results:
+ self.assertIsNotNone(instance.pk)
self.assertEqual(FieldsWithDbColumns.objects.count(), 2)
self.assertCountEqual(
FieldsWithDbColumns.objects.values("rank", "name"),