diff options
Diffstat (limited to 'tests/queries/test_bulk_update.py')
| -rw-r--r-- | tests/queries/test_bulk_update.py | 210 |
1 files changed, 111 insertions, 99 deletions
diff --git a/tests/queries/test_bulk_update.py b/tests/queries/test_bulk_update.py index b63046f9d2..389d6c1c41 100644 --- a/tests/queries/test_bulk_update.py +++ b/tests/queries/test_bulk_update.py @@ -6,67 +6,74 @@ from django.db.models.functions import Lower from django.test import TestCase, skipUnlessDBFeature from .models import ( - Article, CustomDbColumn, CustomPk, Detail, Individual, JSONFieldNullable, - Member, Note, Number, Order, Paragraph, RelatedObject, SingleObject, - SpecialCategory, Tag, Valid, + Article, + CustomDbColumn, + CustomPk, + Detail, + Individual, + JSONFieldNullable, + Member, + Note, + Number, + Order, + Paragraph, + RelatedObject, + SingleObject, + SpecialCategory, + Tag, + Valid, ) class BulkUpdateNoteTests(TestCase): @classmethod def setUpTestData(cls): - cls.notes = [ - Note.objects.create(note=str(i), misc=str(i)) - for i in range(10) - ] + cls.notes = [Note.objects.create(note=str(i), misc=str(i)) for i in range(10)] def create_tags(self): - self.tags = [ - Tag.objects.create(name=str(i)) - for i in range(10) - ] + self.tags = [Tag.objects.create(name=str(i)) for i in range(10)] def test_simple(self): for note in self.notes: - note.note = 'test-%s' % note.id + note.note = "test-%s" % note.id with self.assertNumQueries(1): - Note.objects.bulk_update(self.notes, ['note']) + Note.objects.bulk_update(self.notes, ["note"]) self.assertCountEqual( - Note.objects.values_list('note', flat=True), - [cat.note for cat in self.notes] + Note.objects.values_list("note", flat=True), + [cat.note for cat in self.notes], ) def test_multiple_fields(self): for note in self.notes: - note.note = 'test-%s' % note.id - note.misc = 'misc-%s' % note.id + note.note = "test-%s" % note.id + note.misc = "misc-%s" % note.id with self.assertNumQueries(1): - Note.objects.bulk_update(self.notes, ['note', 'misc']) + Note.objects.bulk_update(self.notes, ["note", "misc"]) self.assertCountEqual( - Note.objects.values_list('note', flat=True), - [cat.note for cat in self.notes] + Note.objects.values_list("note", flat=True), + [cat.note for cat in self.notes], ) self.assertCountEqual( - Note.objects.values_list('misc', flat=True), - [cat.misc for cat in self.notes] + Note.objects.values_list("misc", flat=True), + [cat.misc for cat in self.notes], ) def test_batch_size(self): with self.assertNumQueries(len(self.notes)): - Note.objects.bulk_update(self.notes, fields=['note'], batch_size=1) + Note.objects.bulk_update(self.notes, fields=["note"], batch_size=1) def test_unsaved_models(self): - objs = self.notes + [Note(note='test', misc='test')] - msg = 'All bulk_update() objects must have a primary key set.' + objs = self.notes + [Note(note="test", misc="test")] + msg = "All bulk_update() objects must have a primary key set." with self.assertRaisesMessage(ValueError, msg): - Note.objects.bulk_update(objs, fields=['note']) + Note.objects.bulk_update(objs, fields=["note"]) def test_foreign_keys_do_not_lookup(self): self.create_tags() for note, tag in zip(self.notes, self.tags): note.tag = tag with self.assertNumQueries(1): - Note.objects.bulk_update(self.notes, ['tag']) + Note.objects.bulk_update(self.notes, ["tag"]) self.assertSequenceEqual(Note.objects.filter(tag__isnull=False), self.notes) def test_set_field_to_null(self): @@ -74,7 +81,7 @@ class BulkUpdateNoteTests(TestCase): Note.objects.update(tag=self.tags[0]) for note in self.notes: note.tag = None - Note.objects.bulk_update(self.notes, ['tag']) + Note.objects.bulk_update(self.notes, ["tag"]) self.assertCountEqual(Note.objects.filter(tag__isnull=True), self.notes) def test_set_mixed_fields_to_null(self): @@ -85,106 +92,106 @@ class BulkUpdateNoteTests(TestCase): note.tag = None for note in bottom: note.tag = self.tags[0] - Note.objects.bulk_update(self.notes, ['tag']) + Note.objects.bulk_update(self.notes, ["tag"]) self.assertCountEqual(Note.objects.filter(tag__isnull=True), top) self.assertCountEqual(Note.objects.filter(tag__isnull=False), bottom) def test_functions(self): - Note.objects.update(note='TEST') + Note.objects.update(note="TEST") for note in self.notes: - note.note = Lower('note') - Note.objects.bulk_update(self.notes, ['note']) - self.assertEqual(set(Note.objects.values_list('note', flat=True)), {'test'}) + note.note = Lower("note") + Note.objects.bulk_update(self.notes, ["note"]) + self.assertEqual(set(Note.objects.values_list("note", flat=True)), {"test"}) # Tests that use self.notes go here, otherwise put them in another class. class BulkUpdateTests(TestCase): def test_no_fields(self): - msg = 'Field names must be given to bulk_update().' + msg = "Field names must be given to bulk_update()." with self.assertRaisesMessage(ValueError, msg): Note.objects.bulk_update([], fields=[]) def test_invalid_batch_size(self): - msg = 'Batch size must be a positive integer.' + msg = "Batch size must be a positive integer." with self.assertRaisesMessage(ValueError, msg): - Note.objects.bulk_update([], fields=['note'], batch_size=-1) + Note.objects.bulk_update([], fields=["note"], batch_size=-1) def test_nonexistent_field(self): - with self.assertRaisesMessage(FieldDoesNotExist, "Note has no field named 'nonexistent'"): - Note.objects.bulk_update([], ['nonexistent']) + with self.assertRaisesMessage( + FieldDoesNotExist, "Note has no field named 'nonexistent'" + ): + Note.objects.bulk_update([], ["nonexistent"]) - pk_fields_error = 'bulk_update() cannot be used with primary key fields.' + pk_fields_error = "bulk_update() cannot be used with primary key fields." def test_update_primary_key(self): with self.assertRaisesMessage(ValueError, self.pk_fields_error): - Note.objects.bulk_update([], ['id']) + Note.objects.bulk_update([], ["id"]) def test_update_custom_primary_key(self): with self.assertRaisesMessage(ValueError, self.pk_fields_error): - CustomPk.objects.bulk_update([], ['name']) + CustomPk.objects.bulk_update([], ["name"]) def test_empty_objects(self): with self.assertNumQueries(0): - rows_updated = Note.objects.bulk_update([], ['note']) + rows_updated = Note.objects.bulk_update([], ["note"]) self.assertEqual(rows_updated, 0) def test_large_batch(self): - Note.objects.bulk_create([ - Note(note=str(i), misc=str(i)) - for i in range(0, 2000) - ]) + Note.objects.bulk_create( + [Note(note=str(i), misc=str(i)) for i in range(0, 2000)] + ) notes = list(Note.objects.all()) - rows_updated = Note.objects.bulk_update(notes, ['note']) + rows_updated = Note.objects.bulk_update(notes, ["note"]) self.assertEqual(rows_updated, 2000) def test_updated_rows_when_passing_duplicates(self): - note = Note.objects.create(note='test-note', misc='test') - rows_updated = Note.objects.bulk_update([note, note], ['note']) + note = Note.objects.create(note="test-note", misc="test") + rows_updated = Note.objects.bulk_update([note, note], ["note"]) self.assertEqual(rows_updated, 1) # Duplicates in different batches. - rows_updated = Note.objects.bulk_update([note, note], ['note'], batch_size=1) + rows_updated = Note.objects.bulk_update([note, note], ["note"], batch_size=1) self.assertEqual(rows_updated, 2) def test_only_concrete_fields_allowed(self): - obj = Valid.objects.create(valid='test') - detail = Detail.objects.create(data='test') - paragraph = Paragraph.objects.create(text='test') - Member.objects.create(name='test', details=detail) - msg = 'bulk_update() can only be used with concrete fields.' + obj = Valid.objects.create(valid="test") + detail = Detail.objects.create(data="test") + paragraph = Paragraph.objects.create(text="test") + Member.objects.create(name="test", details=detail) + msg = "bulk_update() can only be used with concrete fields." with self.assertRaisesMessage(ValueError, msg): - Detail.objects.bulk_update([detail], fields=['member']) + Detail.objects.bulk_update([detail], fields=["member"]) with self.assertRaisesMessage(ValueError, msg): - Paragraph.objects.bulk_update([paragraph], fields=['page']) + Paragraph.objects.bulk_update([paragraph], fields=["page"]) with self.assertRaisesMessage(ValueError, msg): - Valid.objects.bulk_update([obj], fields=['parent']) + Valid.objects.bulk_update([obj], fields=["parent"]) def test_custom_db_columns(self): model = CustomDbColumn.objects.create(custom_column=1) model.custom_column = 2 - CustomDbColumn.objects.bulk_update([model], fields=['custom_column']) + CustomDbColumn.objects.bulk_update([model], fields=["custom_column"]) model.refresh_from_db() self.assertEqual(model.custom_column, 2) def test_custom_pk(self): custom_pks = [ - CustomPk.objects.create(name='pk-%s' % i, extra='') - for i in range(10) + CustomPk.objects.create(name="pk-%s" % i, extra="") for i in range(10) ] for model in custom_pks: - model.extra = 'extra-%s' % model.pk - CustomPk.objects.bulk_update(custom_pks, ['extra']) + model.extra = "extra-%s" % model.pk + CustomPk.objects.bulk_update(custom_pks, ["extra"]) self.assertCountEqual( - CustomPk.objects.values_list('extra', flat=True), - [cat.extra for cat in custom_pks] + CustomPk.objects.values_list("extra", flat=True), + [cat.extra for cat in custom_pks], ) def test_falsey_pk_value(self): - order = Order.objects.create(pk=0, name='test') - order.name = 'updated' - Order.objects.bulk_update([order], ['name']) + order = Order.objects.create(pk=0, name="test") + order.name = "updated" + Order.objects.bulk_update([order], ["name"]) order.refresh_from_db() - self.assertEqual(order.name, 'updated') + self.assertEqual(order.name, "updated") def test_inherited_fields(self): special_categories = [ @@ -192,53 +199,56 @@ class BulkUpdateTests(TestCase): for i in range(10) ] for category in special_categories: - category.name = 'test-%s' % category.id - category.special_name = 'special-test-%s' % category.special_name - SpecialCategory.objects.bulk_update(special_categories, ['name', 'special_name']) + category.name = "test-%s" % category.id + category.special_name = "special-test-%s" % category.special_name + SpecialCategory.objects.bulk_update( + special_categories, ["name", "special_name"] + ) self.assertCountEqual( - SpecialCategory.objects.values_list('name', flat=True), - [cat.name for cat in special_categories] + SpecialCategory.objects.values_list("name", flat=True), + [cat.name for cat in special_categories], ) self.assertCountEqual( - SpecialCategory.objects.values_list('special_name', flat=True), - [cat.special_name for cat in special_categories] + SpecialCategory.objects.values_list("special_name", flat=True), + [cat.special_name for cat in special_categories], ) def test_field_references(self): numbers = [Number.objects.create(num=0) for _ in range(10)] for number in numbers: - number.num = F('num') + 1 - Number.objects.bulk_update(numbers, ['num']) + number.num = F("num") + 1 + Number.objects.bulk_update(numbers, ["num"]) self.assertCountEqual(Number.objects.filter(num=1), numbers) def test_f_expression(self): notes = [ - Note.objects.create(note='test_note', misc='test_misc') - for _ in range(10) + Note.objects.create(note="test_note", misc="test_misc") for _ in range(10) ] for note in notes: - note.misc = F('note') - Note.objects.bulk_update(notes, ['misc']) - self.assertCountEqual(Note.objects.filter(misc='test_note'), notes) + note.misc = F("note") + Note.objects.bulk_update(notes, ["misc"]) + self.assertCountEqual(Note.objects.filter(misc="test_note"), notes) def test_booleanfield(self): individuals = [Individual.objects.create(alive=False) for _ in range(10)] for individual in individuals: individual.alive = True - Individual.objects.bulk_update(individuals, ['alive']) + Individual.objects.bulk_update(individuals, ["alive"]) self.assertCountEqual(Individual.objects.filter(alive=True), individuals) def test_ipaddressfield(self): - for ip in ('2001::1', '1.2.3.4'): + for ip in ("2001::1", "1.2.3.4"): with self.subTest(ip=ip): models = [ - CustomDbColumn.objects.create(ip_address='0.0.0.0') + CustomDbColumn.objects.create(ip_address="0.0.0.0") for _ in range(10) ] for model in models: model.ip_address = ip - CustomDbColumn.objects.bulk_update(models, ['ip_address']) - self.assertCountEqual(CustomDbColumn.objects.filter(ip_address=ip), models) + CustomDbColumn.objects.bulk_update(models, ["ip_address"]) + self.assertCountEqual( + CustomDbColumn.objects.filter(ip_address=ip), models + ) def test_datetime_field(self): articles = [ @@ -248,26 +258,28 @@ class BulkUpdateTests(TestCase): point_in_time = datetime.datetime(1991, 10, 31) for article in articles: article.created = point_in_time - Article.objects.bulk_update(articles, ['created']) + Article.objects.bulk_update(articles, ["created"]) self.assertCountEqual(Article.objects.filter(created=point_in_time), articles) - @skipUnlessDBFeature('supports_json_field') + @skipUnlessDBFeature("supports_json_field") def test_json_field(self): - JSONFieldNullable.objects.bulk_create([ - JSONFieldNullable(json_field={'a': i}) for i in range(10) - ]) + JSONFieldNullable.objects.bulk_create( + [JSONFieldNullable(json_field={"a": i}) for i in range(10)] + ) objs = JSONFieldNullable.objects.all() for obj in objs: - obj.json_field = {'c': obj.json_field['a'] + 1} - JSONFieldNullable.objects.bulk_update(objs, ['json_field']) - self.assertCountEqual(JSONFieldNullable.objects.filter(json_field__has_key='c'), objs) + obj.json_field = {"c": obj.json_field["a"] + 1} + JSONFieldNullable.objects.bulk_update(objs, ["json_field"]) + self.assertCountEqual( + JSONFieldNullable.objects.filter(json_field__has_key="c"), objs + ) def test_nullable_fk_after_related_save(self): parent = RelatedObject.objects.create() child = SingleObject() parent.single = child parent.single.save() - RelatedObject.objects.bulk_update([parent], fields=['single']) + RelatedObject.objects.bulk_update([parent], fields=["single"]) self.assertEqual(parent.single_id, parent.single.pk) parent.refresh_from_db() self.assertEqual(parent.single, child) @@ -280,13 +292,13 @@ class BulkUpdateTests(TestCase): "related object 'single'." ) with self.assertRaisesMessage(ValueError, msg): - RelatedObject.objects.bulk_update([parent], fields=['single']) + RelatedObject.objects.bulk_update([parent], fields=["single"]) def test_unspecified_unsaved_parent(self): parent = RelatedObject.objects.create() parent.single = SingleObject() parent.f = 42 - RelatedObject.objects.bulk_update([parent], fields=['f']) + RelatedObject.objects.bulk_update([parent], fields=["f"]) parent.refresh_from_db() self.assertEqual(parent.f, 42) self.assertIsNone(parent.single) |
