From 89c7454dbdae3e0df6d96aa6132205d05e4a9b3d Mon Sep 17 00:00:00 2001 From: Thomas Chaumeny Date: Fri, 7 Jul 2023 13:08:17 +0200 Subject: Fixed #34698 -- Made QuerySet.bulk_create() retrieve primary keys when updating conflicts. --- tests/bulk_create/tests.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) (limited to 'tests/bulk_create') diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index aee0cd9996..7b86a2def5 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -582,12 +582,16 @@ class BulkCreateTests(TestCase): TwoFields(f1=1, f2=1, name="c"), TwoFields(f1=2, f2=2, name="d"), ] - TwoFields.objects.bulk_create( + results = TwoFields.objects.bulk_create( conflicting_objects, update_conflicts=True, unique_fields=unique_fields, update_fields=["name"], ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(TwoFields.objects.count(), 2) self.assertCountEqual( TwoFields.objects.values("f1", "f2", "name"), @@ -619,7 +623,6 @@ class BulkCreateTests(TestCase): TwoFields(f1=2, f2=2, name="b"), ] ) - self.assertEqual(TwoFields.objects.count(), 2) obj1 = TwoFields.objects.get(f1=1) obj2 = TwoFields.objects.get(f1=2) @@ -627,12 +630,16 @@ class BulkCreateTests(TestCase): TwoFields(pk=obj1.pk, f1=3, f2=3, name="c"), TwoFields(pk=obj2.pk, f1=4, f2=4, name="d"), ] - TwoFields.objects.bulk_create( + results = TwoFields.objects.bulk_create( conflicting_objects, update_conflicts=True, unique_fields=["pk"], update_fields=["name"], ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(TwoFields.objects.count(), 2) self.assertCountEqual( TwoFields.objects.values("f1", "f2", "name"), @@ -680,12 +687,16 @@ class BulkCreateTests(TestCase): description=("Japan is an island country in East Asia."), ), ] - Country.objects.bulk_create( + results = Country.objects.bulk_create( new_data, update_conflicts=True, update_fields=["description"], unique_fields=unique_fields, ) + self.assertEqual(len(results), len(new_data)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(Country.objects.count(), 6) self.assertCountEqual( Country.objects.values("iso_two_letter", "description"), @@ -743,12 +754,16 @@ class BulkCreateTests(TestCase): UpsertConflict(number=2, rank=2, name="Olivia"), UpsertConflict(number=3, rank=1, name="Hannah"), ] - UpsertConflict.objects.bulk_create( + results = UpsertConflict.objects.bulk_create( conflicting_objects, update_conflicts=True, update_fields=["name", "rank"], unique_fields=unique_fields, ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(UpsertConflict.objects.count(), 3) self.assertCountEqual( UpsertConflict.objects.values("number", "rank", "name"), @@ -759,12 +774,16 @@ class BulkCreateTests(TestCase): ], ) - UpsertConflict.objects.bulk_create( + results = UpsertConflict.objects.bulk_create( conflicting_objects + [UpsertConflict(number=4, rank=4, name="Mark")], update_conflicts=True, update_fields=["name", "rank"], unique_fields=unique_fields, ) + self.assertEqual(len(results), 4) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(UpsertConflict.objects.count(), 4) self.assertCountEqual( UpsertConflict.objects.values("number", "rank", "name"), @@ -803,12 +822,16 @@ class BulkCreateTests(TestCase): FieldsWithDbColumns(rank=1, name="c"), FieldsWithDbColumns(rank=2, name="d"), ] - FieldsWithDbColumns.objects.bulk_create( + results = FieldsWithDbColumns.objects.bulk_create( conflicting_objects, update_conflicts=True, unique_fields=["rank"], update_fields=["name"], ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(FieldsWithDbColumns.objects.count(), 2) self.assertCountEqual( FieldsWithDbColumns.objects.values("rank", "name"), -- cgit v1.3