summaryrefslogtreecommitdiff
path: root/tests/postgres_tests
diff options
context:
space:
mode:
authorTom Forbes <tom@tomforb.es>2018-09-18 21:14:44 +0100
committerTim Graham <timograham@gmail.com>2018-09-18 16:14:44 -0400
commit9cbdb44014c8027f1b4571bac701a247b0ce02a3 (patch)
treeb7cd20864b0d06f5e08b2c98a50cd5ef2a4cd9a0 /tests/postgres_tests
parent7b159df94235036a41ee93952ff83bbc95c1da3c (diff)
Fixed #23646 -- Added QuerySet.bulk_update() to efficiently update many models.
Diffstat (limited to 'tests/postgres_tests')
-rw-r--r--tests/postgres_tests/migrations/0002_create_test_models.py6
-rw-r--r--tests/postgres_tests/models.py6
-rw-r--r--tests/postgres_tests/test_bulk_update.py34
3 files changed, 40 insertions, 6 deletions
diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py
index 9f4417b58d..0e7ba938ca 100644
--- a/tests/postgres_tests/migrations/0002_create_test_models.py
+++ b/tests/postgres_tests/migrations/0002_create_test_models.py
@@ -56,9 +56,9 @@ class Migration(migrations.Migration):
name='OtherTypesArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('ips', ArrayField(models.GenericIPAddressField(), size=None)),
- ('uuids', ArrayField(models.UUIDField(), size=None)),
- ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
+ ('ips', ArrayField(models.GenericIPAddressField(), size=None, default=list)),
+ ('uuids', ArrayField(models.UUIDField(), size=None, default=list)),
+ ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)),
('tags', ArrayField(TagField(), blank=True, null=True, size=None)),
('json', ArrayField(JSONField(default={}), default=[])),
('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)),
diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py
index cd1646a3e6..841f246c6a 100644
--- a/tests/postgres_tests/models.py
+++ b/tests/postgres_tests/models.py
@@ -63,9 +63,9 @@ class NestedIntegerArrayModel(PostgreSQLModel):
class OtherTypesArrayModel(PostgreSQLModel):
- ips = ArrayField(models.GenericIPAddressField())
- uuids = ArrayField(models.UUIDField())
- decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
+ ips = ArrayField(models.GenericIPAddressField(), default=list)
+ uuids = ArrayField(models.UUIDField(), default=list)
+ decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list)
tags = ArrayField(TagField(), blank=True, null=True)
json = ArrayField(JSONField(default=dict), default=list)
int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py
new file mode 100644
index 0000000000..6dd7036a9b
--- /dev/null
+++ b/tests/postgres_tests/test_bulk_update.py
@@ -0,0 +1,34 @@
+from datetime import date
+
+from . import PostgreSQLTestCase
+from .models import (
+ HStoreModel, IntegerArrayModel, JSONModel, NestedIntegerArrayModel,
+ NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel,
+)
+
+try:
+ from psycopg2.extras import NumericRange, DateRange
+except ImportError:
+ pass # psycopg2 isn't installed.
+
+
+class BulkSaveTests(PostgreSQLTestCase):
+ def test_bulk_update(self):
+ test_data = [
+ (IntegerArrayModel, 'field', [], [1, 2, 3]),
+ (NullableIntegerArrayModel, 'field', [1, 2, 3], None),
+ (JSONModel, 'field', {'a': 'b'}, {'c': 'd'}),
+ (NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]),
+ (HStoreModel, 'field', {}, {1: 2}),
+ (RangesModel, 'ints', None, NumericRange(lower=1, upper=10)),
+ (RangesModel, 'dates', None, DateRange(lower=date.today(), upper=date.today())),
+ (OtherTypesArrayModel, 'ips', [], ['1.2.3.4']),
+ (OtherTypesArrayModel, 'json', [], [{'a': 'b'}])
+ ]
+ for Model, field, initial, new in test_data:
+ with self.subTest(model=Model, field=field):
+ instances = Model.objects.bulk_create(Model(**{field: initial}) for _ in range(20))
+ for instance in instances:
+ setattr(instance, field, new)
+ Model.objects.bulk_update(instances, [field])
+ self.assertSequenceEqual(Model.objects.filter(**{field: new}), instances)