diff options
Diffstat (limited to 'tests/postgres_tests')
26 files changed, 3412 insertions, 2359 deletions
diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py index 2b84fc25db..6f02531ed0 100644 --- a/tests/postgres_tests/__init__.py +++ b/tests/postgres_tests/__init__.py @@ -6,18 +6,18 @@ from django.db import connection from django.test import SimpleTestCase, TestCase, modify_settings -@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests") class PostgreSQLSimpleTestCase(SimpleTestCase): pass -@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests") class PostgreSQLTestCase(TestCase): pass -@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests") # To locate the widget's template. -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase): pass diff --git a/tests/postgres_tests/array_default_migrations/0001_initial.py b/tests/postgres_tests/array_default_migrations/0001_initial.py index eb523218ef..10eaef2aab 100644 --- a/tests/postgres_tests/array_default_migrations/0001_initial.py +++ b/tests/postgres_tests/array_default_migrations/0001_initial.py @@ -4,18 +4,29 @@ from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( - name='IntegerArrayDefaultModel', + name="IntegerArrayDefaultModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "field", + django.contrib.postgres.fields.ArrayField( + models.IntegerField(), size=None + ), + ), ], - options={ - }, + options={}, bases=(models.Model,), ), ] diff --git a/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py b/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py index 679c9bb0d3..b15b575e54 100644 --- a/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py +++ b/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py @@ -5,14 +5,16 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('postgres_tests', '0001_initial'), + ("postgres_tests", "0001_initial"), ] operations = [ migrations.AddField( - model_name='integerarraydefaultmodel', - name='field_2', - field=django.contrib.postgres.fields.ArrayField(models.IntegerField(), default=[], size=None), + model_name="integerarraydefaultmodel", + name="field_2", + field=django.contrib.postgres.fields.ArrayField( + models.IntegerField(), default=[], size=None + ), preserve_default=False, ), ] diff --git a/tests/postgres_tests/array_index_migrations/0001_initial.py b/tests/postgres_tests/array_index_migrations/0001_initial.py index 505e53e4e8..5c74be326a 100644 --- a/tests/postgres_tests/array_index_migrations/0001_initial.py +++ b/tests/postgres_tests/array_index_migrations/0001_initial.py @@ -4,22 +4,36 @@ from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( - name='CharTextArrayIndexModel', + name="CharTextArrayIndexModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('char', django.contrib.postgres.fields.ArrayField( - models.CharField(max_length=10), db_index=True, size=100) - ), - ('char2', models.CharField(max_length=11, db_index=True)), - ('text', django.contrib.postgres.fields.ArrayField(models.TextField(), db_index=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "char", + django.contrib.postgres.fields.ArrayField( + models.CharField(max_length=10), db_index=True, size=100 + ), + ), + ("char2", models.CharField(max_length=11, db_index=True)), + ( + "text", + django.contrib.postgres.fields.ArrayField( + models.TextField(), db_index=True + ), + ), ], - options={ - }, + options={}, bases=(models.Model,), ), ] diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py index 1c0cb05d46..2d62e26a92 100644 --- a/tests/postgres_tests/fields.py +++ b/tests/postgres_tests/fields.py @@ -8,31 +8,41 @@ from django.db import models try: from django.contrib.postgres.fields import ( - ArrayField, BigIntegerRangeField, CICharField, CIEmailField, - CITextField, DateRangeField, DateTimeRangeField, DecimalRangeField, - HStoreField, IntegerRangeField, + ArrayField, + BigIntegerRangeField, + CICharField, + CIEmailField, + CITextField, + DateRangeField, + DateTimeRangeField, + DecimalRangeField, + HStoreField, + IntegerRangeField, ) from django.contrib.postgres.search import SearchVector, SearchVectorField except ImportError: + class DummyArrayField(models.Field): def __init__(self, base_field, size=None, **kwargs): super().__init__(**kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - kwargs.update({ - 'base_field': '', - 'size': 1, - }) + kwargs.update( + { + "base_field": "", + "size": 1, + } + ) return name, path, args, kwargs class DummyContinuousRangeField(models.Field): - def __init__(self, *args, default_bounds='[)', **kwargs): + def __init__(self, *args, default_bounds="[)", **kwargs): super().__init__(**kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - kwargs['default_bounds'] = '[)' + kwargs["default_bounds"] = "[)" return name, path, args, kwargs ArrayField = DummyArrayField diff --git a/tests/postgres_tests/integration_settings.py b/tests/postgres_tests/integration_settings.py index c4ec0d1157..7e2ea9c8d0 100644 --- a/tests/postgres_tests/integration_settings.py +++ b/tests/postgres_tests/integration_settings.py @@ -1,5 +1,5 @@ -SECRET_KEY = 'abcdefg' +SECRET_KEY = "abcdefg" INSTALLED_APPS = [ - 'django.contrib.postgres', + "django.contrib.postgres", ] diff --git a/tests/postgres_tests/migrations/0001_setup_extensions.py b/tests/postgres_tests/migrations/0001_setup_extensions.py index bd5da83d15..090abf9649 100644 --- a/tests/postgres_tests/migrations/0001_setup_extensions.py +++ b/tests/postgres_tests/migrations/0001_setup_extensions.py @@ -4,8 +4,14 @@ from django.db import connection, migrations try: from django.contrib.postgres.operations import ( - BloomExtension, BtreeGinExtension, BtreeGistExtension, CITextExtension, - CreateExtension, CryptoExtension, HStoreExtension, TrigramExtension, + BloomExtension, + BtreeGinExtension, + BtreeGistExtension, + CITextExtension, + CreateExtension, + CryptoExtension, + HStoreExtension, + TrigramExtension, UnaccentExtension, ) except ImportError: @@ -20,8 +26,7 @@ except ImportError: needs_crypto_extension = False else: needs_crypto_extension = ( - connection.vendor == 'postgresql' and - not connection.features.is_postgresql_13 + connection.vendor == "postgresql" and not connection.features.is_postgresql_13 ) @@ -34,7 +39,7 @@ class Migration(migrations.Migration): CITextExtension(), # Ensure CreateExtension quotes extension names by creating one with a # dash in its name. - CreateExtension('uuid-ossp'), + CreateExtension("uuid-ossp"), # CryptoExtension is required for RandomUUID() on PostgreSQL < 13. CryptoExtension() if needs_crypto_extension else mock.Mock(), HStoreExtension(), diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 7ede593dae..a78563f80d 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -1,9 +1,18 @@ from django.db import migrations, models from ..fields import ( - ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, - DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField, - HStoreField, IntegerRangeField, SearchVectorField, + ArrayField, + BigIntegerRangeField, + CICharField, + CIEmailField, + CITextField, + DateRangeField, + DateTimeRangeField, + DecimalRangeField, + EnumField, + HStoreField, + IntegerRangeField, + SearchVectorField, ) from ..models import TagField @@ -11,305 +20,538 @@ from ..models import TagField class Migration(migrations.Migration): dependencies = [ - ('postgres_tests', '0001_setup_extensions'), + ("postgres_tests", "0001_setup_extensions"), ] operations = [ migrations.CreateModel( - name='CharArrayModel', + name="CharArrayModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', ArrayField(models.CharField(max_length=10), size=None)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("field", ArrayField(models.CharField(max_length=10), size=None)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='DateTimeArrayModel', + name="DateTimeArrayModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('datetimes', ArrayField(models.DateTimeField(), size=None)), - ('dates', ArrayField(models.DateField(), size=None)), - ('times', ArrayField(models.TimeField(), size=None)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("datetimes", ArrayField(models.DateTimeField(), size=None)), + ("dates", ArrayField(models.DateField(), size=None)), + ("times", ArrayField(models.TimeField(), size=None)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='HStoreModel', + name="HStoreModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', HStoreField(blank=True, null=True)), - ('array_field', ArrayField(HStoreField(), null=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("field", HStoreField(blank=True, null=True)), + ("array_field", ArrayField(HStoreField(), null=True)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='OtherTypesArrayModel', + name="OtherTypesArrayModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('ips', ArrayField(models.GenericIPAddressField(), size=None, default=list)), - ('uuids', ArrayField(models.UUIDField(), size=None, default=list)), - ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)), - ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), - ('json', ArrayField(models.JSONField(default={}), default=[])), - ('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)), - ('bigint_ranges', ArrayField(BigIntegerRangeField(), null=True, blank=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "ips", + ArrayField(models.GenericIPAddressField(), size=None, default=list), + ), + ("uuids", ArrayField(models.UUIDField(), size=None, default=list)), + ( + "decimals", + ArrayField( + models.DecimalField(max_digits=5, decimal_places=2), + size=None, + default=list, + ), + ), + ("tags", ArrayField(TagField(), blank=True, null=True, size=None)), + ("json", ArrayField(models.JSONField(default={}), default=[])), + ("int_ranges", ArrayField(IntegerRangeField(), null=True, blank=True)), + ( + "bigint_ranges", + ArrayField(BigIntegerRangeField(), null=True, blank=True), + ), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='IntegerArrayModel', + name="IntegerArrayModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', ArrayField(models.IntegerField(), size=None)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("field", ArrayField(models.IntegerField(), size=None)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='NestedIntegerArrayModel', + name="NestedIntegerArrayModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', ArrayField(ArrayField(models.IntegerField(), size=None), size=None)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "field", + ArrayField(ArrayField(models.IntegerField(), size=None), size=None), + ), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='NullableIntegerArrayModel', + name="NullableIntegerArrayModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', ArrayField(models.IntegerField(), size=None, null=True, blank=True)), ( - 'field_nested', - ArrayField(ArrayField(models.IntegerField(), size=None, null=True), size=None, null=True), + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), ), - ('order', models.IntegerField(null=True)), + ( + "field", + ArrayField(models.IntegerField(), size=None, null=True, blank=True), + ), + ( + "field_nested", + ArrayField( + ArrayField(models.IntegerField(), size=None, null=True), + size=None, + null=True, + ), + ), + ("order", models.IntegerField(null=True)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='CharFieldModel', + name="CharFieldModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', models.CharField(max_length=64)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("field", models.CharField(max_length=64)), ], options=None, bases=None, ), migrations.CreateModel( - name='TextFieldModel', + name="TextFieldModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', models.TextField()), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("field", models.TextField()), ], options=None, bases=None, ), migrations.CreateModel( - name='SmallAutoFieldModel', + name="SmallAutoFieldModel", fields=[ - ('id', models.SmallAutoField(verbose_name='ID', serialize=False, primary_key=True)), + ( + "id", + models.SmallAutoField( + verbose_name="ID", serialize=False, primary_key=True + ), + ), ], options=None, ), migrations.CreateModel( - name='BigAutoFieldModel', + name="BigAutoFieldModel", fields=[ - ('id', models.BigAutoField(verbose_name='ID', serialize=False, primary_key=True)), + ( + "id", + models.BigAutoField( + verbose_name="ID", serialize=False, primary_key=True + ), + ), ], options=None, ), migrations.CreateModel( - name='Scene', + name="Scene", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('scene', models.TextField()), - ('setting', models.CharField(max_length=255)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("scene", models.TextField()), + ("setting", models.CharField(max_length=255)), ], options=None, bases=None, ), migrations.CreateModel( - name='Character', + name="Character", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('name', models.CharField(max_length=255)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("name", models.CharField(max_length=255)), ], options=None, bases=None, ), migrations.CreateModel( - name='CITestModel', + name="CITestModel", fields=[ - ('name', CICharField(primary_key=True, max_length=255)), - ('email', CIEmailField()), - ('description', CITextField()), - ('array_field', ArrayField(CITextField(), null=True)), + ("name", CICharField(primary_key=True, max_length=255)), + ("email", CIEmailField()), + ("description", CITextField()), + ("array_field", ArrayField(CITextField(), null=True)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=None, ), migrations.CreateModel( - name='Line', + name="Line", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('scene', models.ForeignKey('postgres_tests.Scene', on_delete=models.SET_NULL)), - ('character', models.ForeignKey('postgres_tests.Character', on_delete=models.SET_NULL)), - ('dialogue', models.TextField(blank=True, null=True)), - ('dialogue_search_vector', SearchVectorField(blank=True, null=True)), - ('dialogue_config', models.CharField(max_length=100, blank=True, null=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "scene", + models.ForeignKey( + "postgres_tests.Scene", on_delete=models.SET_NULL + ), + ), + ( + "character", + models.ForeignKey( + "postgres_tests.Character", on_delete=models.SET_NULL + ), + ), + ("dialogue", models.TextField(blank=True, null=True)), + ("dialogue_search_vector", SearchVectorField(blank=True, null=True)), + ( + "dialogue_config", + models.CharField(max_length=100, blank=True, null=True), + ), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=None, ), migrations.CreateModel( - name='LineSavedSearch', + name="LineSavedSearch", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('line', models.ForeignKey('postgres_tests.Line', on_delete=models.CASCADE)), - ('query', models.CharField(max_length=100)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "line", + models.ForeignKey("postgres_tests.Line", on_delete=models.CASCADE), + ), + ("query", models.CharField(max_length=100)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, ), migrations.CreateModel( - name='AggregateTestModel', + name="AggregateTestModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('boolean_field', models.BooleanField(null=True)), - ('char_field', models.CharField(max_length=30, blank=True)), - ('text_field', models.TextField(blank=True)), - ('integer_field', models.IntegerField(null=True)), - ('json_field', models.JSONField(null=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("boolean_field", models.BooleanField(null=True)), + ("char_field", models.CharField(max_length=30, blank=True)), + ("text_field", models.TextField(blank=True)), + ("integer_field", models.IntegerField(null=True)), + ("json_field", models.JSONField(null=True)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, ), migrations.CreateModel( - name='StatTestModel', + name="StatTestModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('int1', models.IntegerField()), - ('int2', models.IntegerField()), - ('related_field', models.ForeignKey( - 'postgres_tests.AggregateTestModel', - models.SET_NULL, - null=True, - )), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("int1", models.IntegerField()), + ("int2", models.IntegerField()), + ( + "related_field", + models.ForeignKey( + "postgres_tests.AggregateTestModel", + models.SET_NULL, + null=True, + ), + ), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, ), migrations.CreateModel( - name='NowTestModel', + name="NowTestModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('when', models.DateTimeField(null=True, default=None)), - ] + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("when", models.DateTimeField(null=True, default=None)), + ], ), migrations.CreateModel( - name='UUIDTestModel', + name="UUIDTestModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('uuid', models.UUIDField(default=None, null=True)), - ] + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("uuid", models.UUIDField(default=None, null=True)), + ], ), migrations.CreateModel( - name='RangesModel', + name="RangesModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('ints', IntegerRangeField(null=True, blank=True)), - ('bigints', BigIntegerRangeField(null=True, blank=True)), - ('decimals', DecimalRangeField(null=True, blank=True)), - ('timestamps', DateTimeRangeField(null=True, blank=True)), - ('timestamps_inner', DateTimeRangeField(null=True, blank=True)), - ('timestamps_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')), - ('dates', DateRangeField(null=True, blank=True)), - ('dates_inner', DateRangeField(null=True, blank=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("ints", IntegerRangeField(null=True, blank=True)), + ("bigints", BigIntegerRangeField(null=True, blank=True)), + ("decimals", DecimalRangeField(null=True, blank=True)), + ("timestamps", DateTimeRangeField(null=True, blank=True)), + ("timestamps_inner", DateTimeRangeField(null=True, blank=True)), + ( + "timestamps_closed_bounds", + DateTimeRangeField(null=True, blank=True, default_bounds="[]"), + ), + ("dates", DateRangeField(null=True, blank=True)), + ("dates_inner", DateRangeField(null=True, blank=True)), ], - options={ - 'required_db_vendor': 'postgresql' - }, - bases=(models.Model,) + options={"required_db_vendor": "postgresql"}, + bases=(models.Model,), ), migrations.CreateModel( - name='RangeLookupsModel', + name="RangeLookupsModel", fields=[ - ('parent', models.ForeignKey( - 'postgres_tests.RangesModel', - models.SET_NULL, - blank=True, null=True, - )), - ('integer', models.IntegerField(blank=True, null=True)), - ('big_integer', models.BigIntegerField(blank=True, null=True)), - ('float', models.FloatField(blank=True, null=True)), - ('timestamp', models.DateTimeField(blank=True, null=True)), - ('date', models.DateField(blank=True, null=True)), - ('small_integer', models.SmallIntegerField(blank=True, null=True)), - ('decimal_field', models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True)), + ( + "parent", + models.ForeignKey( + "postgres_tests.RangesModel", + models.SET_NULL, + blank=True, + null=True, + ), + ), + ("integer", models.IntegerField(blank=True, null=True)), + ("big_integer", models.BigIntegerField(blank=True, null=True)), + ("float", models.FloatField(blank=True, null=True)), + ("timestamp", models.DateTimeField(blank=True, null=True)), + ("date", models.DateField(blank=True, null=True)), + ("small_integer", models.SmallIntegerField(blank=True, null=True)), + ( + "decimal_field", + models.DecimalField( + max_digits=5, decimal_places=2, blank=True, null=True + ), + ), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='ArrayEnumModel', + name="ArrayEnumModel", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "array_of_enums", + ArrayField(EnumField(max_length=20), null=True, blank=True), + ), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, bases=(models.Model,), ), migrations.CreateModel( - name='Room', + name="Room", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('number', models.IntegerField(unique=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("number", models.IntegerField(unique=True)), ], ), migrations.CreateModel( - name='HotelReservation', + name="HotelReservation", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('room', models.ForeignKey('postgres_tests.Room', models.CASCADE)), - ('datespan', DateRangeField()), - ('start', models.DateTimeField()), - ('end', models.DateTimeField()), - ('cancelled', models.BooleanField(default=False)), - ('requirements', models.JSONField(blank=True, null=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("room", models.ForeignKey("postgres_tests.Room", models.CASCADE)), + ("datespan", DateRangeField()), + ("start", models.DateTimeField()), + ("end", models.DateTimeField()), + ("cancelled", models.BooleanField(default=False)), + ("requirements", models.JSONField(blank=True, null=True)), ], options={ - 'required_db_vendor': 'postgresql', + "required_db_vendor": "postgresql", }, ), ] diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 444b039840..8f27838ad5 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -1,9 +1,18 @@ from django.db import models from .fields import ( - ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, - DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField, - HStoreField, IntegerRangeField, SearchVectorField, + ArrayField, + BigIntegerRangeField, + CICharField, + CIEmailField, + CITextField, + DateRangeField, + DateTimeRangeField, + DecimalRangeField, + EnumField, + HStoreField, + IntegerRangeField, + SearchVectorField, ) @@ -16,7 +25,6 @@ class Tag: class TagField(models.SmallIntegerField): - def from_db_value(self, value, expression, connection): if value is None: return value @@ -36,7 +44,7 @@ class TagField(models.SmallIntegerField): class PostgreSQLModel(models.Model): class Meta: abstract = True - required_db_vendor = 'postgresql' + required_db_vendor = "postgresql" class IntegerArrayModel(PostgreSQLModel): @@ -66,7 +74,9 @@ class NestedIntegerArrayModel(PostgreSQLModel): class OtherTypesArrayModel(PostgreSQLModel): ips = ArrayField(models.GenericIPAddressField(), default=list) uuids = ArrayField(models.UUIDField(), default=list) - decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list) + decimals = ArrayField( + models.DecimalField(max_digits=5, decimal_places=2), default=list + ) tags = ArrayField(TagField(), blank=True, null=True) json = ArrayField(models.JSONField(default=dict), default=list) int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True) @@ -117,15 +127,15 @@ class CITestModel(PostgreSQLModel): class Line(PostgreSQLModel): - scene = models.ForeignKey('Scene', models.CASCADE) - character = models.ForeignKey('Character', models.CASCADE) + scene = models.ForeignKey("Scene", models.CASCADE) + character = models.ForeignKey("Character", models.CASCADE) dialogue = models.TextField(blank=True, null=True) dialogue_search_vector = SearchVectorField(blank=True, null=True) dialogue_config = models.CharField(max_length=100, blank=True, null=True) class LineSavedSearch(PostgreSQLModel): - line = models.ForeignKey('Line', models.CASCADE) + line = models.ForeignKey("Line", models.CASCADE) query = models.CharField(max_length=100) @@ -136,7 +146,9 @@ class RangesModel(PostgreSQLModel): timestamps = DateTimeRangeField(blank=True, null=True) timestamps_inner = DateTimeRangeField(blank=True, null=True) timestamps_closed_bounds = DateTimeRangeField( - blank=True, null=True, default_bounds='[]', + blank=True, + null=True, + default_bounds="[]", ) dates = DateRangeField(blank=True, null=True) dates_inner = DateRangeField(blank=True, null=True) @@ -150,7 +162,9 @@ class RangeLookupsModel(PostgreSQLModel): timestamp = models.DateTimeField(blank=True, null=True) date = models.DateField(blank=True, null=True) small_integer = models.SmallIntegerField(blank=True, null=True) - decimal_field = models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True) + decimal_field = models.DecimalField( + max_digits=5, decimal_places=2, blank=True, null=True + ) class ArrayFieldSubclass(ArrayField): @@ -162,6 +176,7 @@ class AggregateTestModel(PostgreSQLModel): """ To test postgres-specific general aggregation functions """ + char_field = models.CharField(max_length=30, blank=True) text_field = models.TextField(blank=True) integer_field = models.IntegerField(null=True) @@ -173,6 +188,7 @@ class StatTestModel(PostgreSQLModel): """ To test postgres-specific aggregation functions for statistics """ + int1 = models.IntegerField() int2 = models.IntegerField() related_field = models.ForeignKey(AggregateTestModel, models.SET_NULL, null=True) @@ -191,7 +207,7 @@ class Room(models.Model): class HotelReservation(PostgreSQLModel): - room = models.ForeignKey('Room', on_delete=models.CASCADE) + room = models.ForeignKey("Room", on_delete=models.CASCADE) datespan = DateRangeField() start = models.DateTimeField() end = models.DateTimeField() diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index 399c4fa8a9..c3df490fcf 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -1,6 +1,13 @@ from django.db import connection from django.db.models import ( - CharField, F, Func, IntegerField, OuterRef, Q, Subquery, Value, + CharField, + F, + Func, + IntegerField, + OuterRef, + Q, + Subquery, + Value, ) from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.functions import Cast, Concat, Substr @@ -14,9 +21,26 @@ from .models import AggregateTestModel, HotelReservation, Room, StatTestModel try: from django.contrib.postgres.aggregates import ( - ArrayAgg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Corr, CovarPop, - JSONBAgg, RegrAvgX, RegrAvgY, RegrCount, RegrIntercept, RegrR2, - RegrSlope, RegrSXX, RegrSXY, RegrSYY, StatAggregate, StringAgg, + ArrayAgg, + BitAnd, + BitOr, + BitXor, + BoolAnd, + BoolOr, + Corr, + CovarPop, + JSONBAgg, + RegrAvgX, + RegrAvgY, + RegrCount, + RegrIntercept, + RegrR2, + RegrSlope, + RegrSXX, + RegrSXY, + RegrSYY, + StatAggregate, + StringAgg, ) from django.contrib.postgres.fields import ArrayField except ImportError: @@ -26,52 +50,54 @@ except ImportError: class TestGeneralAggregate(PostgreSQLTestCase): @classmethod def setUpTestData(cls): - cls.aggs = AggregateTestModel.objects.bulk_create([ - AggregateTestModel( - boolean_field=True, - char_field='Foo1', - text_field='Text1', - integer_field=0, - ), - AggregateTestModel( - boolean_field=False, - char_field='Foo2', - text_field='Text2', - integer_field=1, - json_field={'lang': 'pl'}, - ), - AggregateTestModel( - boolean_field=False, - char_field='Foo4', - text_field='Text4', - integer_field=2, - json_field={'lang': 'en'}, - ), - AggregateTestModel( - boolean_field=True, - char_field='Foo3', - text_field='Text3', - integer_field=0, - json_field={'breed': 'collie'}, - ), - ]) + cls.aggs = AggregateTestModel.objects.bulk_create( + [ + AggregateTestModel( + boolean_field=True, + char_field="Foo1", + text_field="Text1", + integer_field=0, + ), + AggregateTestModel( + boolean_field=False, + char_field="Foo2", + text_field="Text2", + integer_field=1, + json_field={"lang": "pl"}, + ), + AggregateTestModel( + boolean_field=False, + char_field="Foo4", + text_field="Text4", + integer_field=2, + json_field={"lang": "en"}, + ), + AggregateTestModel( + boolean_field=True, + char_field="Foo3", + text_field="Text3", + integer_field=0, + json_field={"breed": "collie"}, + ), + ] + ) @ignore_warnings(category=RemovedInDjango50Warning) def test_empty_result_set(self): AggregateTestModel.objects.all().delete() tests = [ - (ArrayAgg('char_field'), []), - (ArrayAgg('integer_field'), []), - (ArrayAgg('boolean_field'), []), - (BitAnd('integer_field'), None), - (BitOr('integer_field'), None), - (BoolAnd('boolean_field'), None), - (BoolOr('boolean_field'), None), - (JSONBAgg('integer_field'), []), - (StringAgg('char_field', delimiter=';'), ''), + (ArrayAgg("char_field"), []), + (ArrayAgg("integer_field"), []), + (ArrayAgg("boolean_field"), []), + (BitAnd("integer_field"), None), + (BitOr("integer_field"), None), + (BoolAnd("boolean_field"), None), + (BoolOr("boolean_field"), None), + (JSONBAgg("integer_field"), []), + (StringAgg("char_field", delimiter=";"), ""), ] if connection.features.has_bit_xor: - tests.append((BitXor('integer_field'), None)) + tests.append((BitXor("integer_field"), None)) for aggregation, expected_result in tests: with self.subTest(aggregation=aggregation): # Empty result with non-execution optimization. @@ -79,29 +105,32 @@ class TestGeneralAggregate(PostgreSQLTestCase): values = AggregateTestModel.objects.none().aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) # Empty result when query must be executed. with self.assertNumQueries(1): values = AggregateTestModel.objects.aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) def test_default_argument(self): AggregateTestModel.objects.all().delete() tests = [ - (ArrayAgg('char_field', default=['<empty>']), ['<empty>']), - (ArrayAgg('integer_field', default=[0]), [0]), - (ArrayAgg('boolean_field', default=[False]), [False]), - (BitAnd('integer_field', default=0), 0), - (BitOr('integer_field', default=0), 0), - (BoolAnd('boolean_field', default=False), False), - (BoolOr('boolean_field', default=False), False), - (JSONBAgg('integer_field', default=Value('["<empty>"]')), ['<empty>']), - (StringAgg('char_field', delimiter=';', default=Value('<empty>')), '<empty>'), + (ArrayAgg("char_field", default=["<empty>"]), ["<empty>"]), + (ArrayAgg("integer_field", default=[0]), [0]), + (ArrayAgg("boolean_field", default=[False]), [False]), + (BitAnd("integer_field", default=0), 0), + (BitOr("integer_field", default=0), 0), + (BoolAnd("boolean_field", default=False), False), + (BoolOr("boolean_field", default=False), False), + (JSONBAgg("integer_field", default=Value('["<empty>"]')), ["<empty>"]), + ( + StringAgg("char_field", delimiter=";", default=Value("<empty>")), + "<empty>", + ), ] if connection.features.has_bit_xor: - tests.append((BitXor('integer_field', default=0), 0)) + tests.append((BitXor("integer_field", default=0), 0)) for aggregation, expected_result in tests: with self.subTest(aggregation=aggregation): # Empty result with non-execution optimization. @@ -109,135 +138,159 @@ class TestGeneralAggregate(PostgreSQLTestCase): values = AggregateTestModel.objects.none().aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) # Empty result when query must be executed. with self.assertNumQueries(1): values = AggregateTestModel.objects.aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) def test_convert_value_deprecation(self): AggregateTestModel.objects.all().delete() queryset = AggregateTestModel.objects.all() - with self.assertWarnsMessage(RemovedInDjango50Warning, ArrayAgg.deprecation_msg): - queryset.aggregate(aggregation=ArrayAgg('boolean_field')) + with self.assertWarnsMessage( + RemovedInDjango50Warning, ArrayAgg.deprecation_msg + ): + queryset.aggregate(aggregation=ArrayAgg("boolean_field")) - with self.assertWarnsMessage(RemovedInDjango50Warning, JSONBAgg.deprecation_msg): - queryset.aggregate(aggregation=JSONBAgg('integer_field')) + with self.assertWarnsMessage( + RemovedInDjango50Warning, JSONBAgg.deprecation_msg + ): + queryset.aggregate(aggregation=JSONBAgg("integer_field")) - with self.assertWarnsMessage(RemovedInDjango50Warning, StringAgg.deprecation_msg): - queryset.aggregate(aggregation=StringAgg('char_field', delimiter=';')) + with self.assertWarnsMessage( + RemovedInDjango50Warning, StringAgg.deprecation_msg + ): + queryset.aggregate(aggregation=StringAgg("char_field", delimiter=";")) # No warnings raised if default argument provided. self.assertEqual( - queryset.aggregate(aggregation=ArrayAgg('boolean_field', default=None)), - {'aggregation': None}, + queryset.aggregate(aggregation=ArrayAgg("boolean_field", default=None)), + {"aggregation": None}, ) self.assertEqual( - queryset.aggregate(aggregation=JSONBAgg('integer_field', default=None)), - {'aggregation': None}, + queryset.aggregate(aggregation=JSONBAgg("integer_field", default=None)), + {"aggregation": None}, ) self.assertEqual( queryset.aggregate( - aggregation=StringAgg('char_field', delimiter=';', default=None), + aggregation=StringAgg("char_field", delimiter=";", default=None), ), - {'aggregation': None}, + {"aggregation": None}, ) self.assertEqual( - queryset.aggregate(aggregation=ArrayAgg('boolean_field', default=Value([]))), - {'aggregation': []}, + queryset.aggregate( + aggregation=ArrayAgg("boolean_field", default=Value([])) + ), + {"aggregation": []}, ) self.assertEqual( - queryset.aggregate(aggregation=JSONBAgg('integer_field', default=Value('[]'))), - {'aggregation': []}, + queryset.aggregate( + aggregation=JSONBAgg("integer_field", default=Value("[]")) + ), + {"aggregation": []}, ) self.assertEqual( queryset.aggregate( - aggregation=StringAgg('char_field', delimiter=';', default=Value('')), + aggregation=StringAgg("char_field", delimiter=";", default=Value("")), ), - {'aggregation': ''}, + {"aggregation": ""}, ) def test_array_agg_charfield(self): - values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) - self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) + values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field")) + self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]}) def test_array_agg_charfield_ordering(self): ordering_test_cases = ( - (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']), - (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - ([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']), - ((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']), - ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']), - (Concat('char_field', Value('@')), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - (Concat('char_field', Value('@')).desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']), + (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]), + (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]), + (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]), + ( + [F("boolean_field"), F("char_field").desc()], + ["Foo4", "Foo2", "Foo3", "Foo1"], + ), ( - (Substr('char_field', 1, 1), F('integer_field'), Substr('char_field', 4, 1).desc()), - ['Foo3', 'Foo1', 'Foo2', 'Foo4'], + (F("boolean_field"), F("char_field").desc()), + ["Foo4", "Foo2", "Foo3", "Foo1"], + ), + ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]), + ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]), + (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]), + (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]), + ( + ( + Substr("char_field", 1, 1), + F("integer_field"), + Substr("char_field", 4, 1).desc(), + ), + ["Foo3", "Foo1", "Foo2", "Foo4"], ), ) for ordering, expected_output in ordering_test_cases: with self.subTest(ordering=ordering, expected_output=expected_output): values = AggregateTestModel.objects.aggregate( - arrayagg=ArrayAgg('char_field', ordering=ordering) + arrayagg=ArrayAgg("char_field", ordering=ordering) ) - self.assertEqual(values, {'arrayagg': expected_output}) + self.assertEqual(values, {"arrayagg": expected_output}) def test_array_agg_integerfield(self): - values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field')) - self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]}) + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg("integer_field") + ) + self.assertEqual(values, {"arrayagg": [0, 1, 2, 0]}) def test_array_agg_integerfield_ordering(self): values = AggregateTestModel.objects.aggregate( - arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc()) + arrayagg=ArrayAgg("integer_field", ordering=F("integer_field").desc()) ) - self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]}) + self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]}) def test_array_agg_booleanfield(self): - values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field')) - self.assertEqual(values, {'arrayagg': [True, False, False, True]}) + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg("boolean_field") + ) + self.assertEqual(values, {"arrayagg": [True, False, False, True]}) def test_array_agg_booleanfield_ordering(self): ordering_test_cases = ( - (F('boolean_field').asc(), [False, False, True, True]), - (F('boolean_field').desc(), [True, True, False, False]), - (F('boolean_field'), [False, False, True, True]), + (F("boolean_field").asc(), [False, False, True, True]), + (F("boolean_field").desc(), [True, True, False, False]), + (F("boolean_field"), [False, False, True, True]), ) for ordering, expected_output in ordering_test_cases: with self.subTest(ordering=ordering, expected_output=expected_output): values = AggregateTestModel.objects.aggregate( - arrayagg=ArrayAgg('boolean_field', ordering=ordering) + arrayagg=ArrayAgg("boolean_field", ordering=ordering) ) - self.assertEqual(values, {'arrayagg': expected_output}) + self.assertEqual(values, {"arrayagg": expected_output}) def test_array_agg_jsonfield(self): values = AggregateTestModel.objects.aggregate( arrayagg=ArrayAgg( - KeyTransform('lang', 'json_field'), + KeyTransform("lang", "json_field"), filter=Q(json_field__lang__isnull=False), ), ) - self.assertEqual(values, {'arrayagg': ['pl', 'en']}) + self.assertEqual(values, {"arrayagg": ["pl", "en"]}) def test_array_agg_jsonfield_ordering(self): values = AggregateTestModel.objects.aggregate( arrayagg=ArrayAgg( - KeyTransform('lang', 'json_field'), + KeyTransform("lang", "json_field"), filter=Q(json_field__lang__isnull=False), - ordering=KeyTransform('lang', 'json_field'), + ordering=KeyTransform("lang", "json_field"), ), ) - self.assertEqual(values, {'arrayagg': ['en', 'pl']}) + self.assertEqual(values, {"arrayagg": ["en", "pl"]}) def test_array_agg_filter(self): values = AggregateTestModel.objects.aggregate( - arrayagg=ArrayAgg('integer_field', filter=Q(integer_field__gt=0)), + arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)), ) - self.assertEqual(values, {'arrayagg': [1, 2]}) + self.assertEqual(values, {"arrayagg": [1, 2]}) def test_array_agg_lookups(self): aggr1 = AggregateTestModel.objects.create() @@ -246,194 +299,207 @@ class TestGeneralAggregate(PostgreSQLTestCase): StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0) StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0) StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0) - qs = StatTestModel.objects.values('related_field').annotate( - array=ArrayAgg('int1') - ).filter(array__overlap=[2]).values_list('array', flat=True) + qs = ( + StatTestModel.objects.values("related_field") + .annotate(array=ArrayAgg("int1")) + .filter(array__overlap=[2]) + .values_list("array", flat=True) + ) self.assertCountEqual(qs.get(), [1, 2]) def test_bit_and_general(self): - values = AggregateTestModel.objects.filter( - integer_field__in=[0, 1]).aggregate(bitand=BitAnd('integer_field')) - self.assertEqual(values, {'bitand': 0}) + values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate( + bitand=BitAnd("integer_field") + ) + self.assertEqual(values, {"bitand": 0}) def test_bit_and_on_only_true_values(self): - values = AggregateTestModel.objects.filter( - integer_field=1).aggregate(bitand=BitAnd('integer_field')) - self.assertEqual(values, {'bitand': 1}) + values = AggregateTestModel.objects.filter(integer_field=1).aggregate( + bitand=BitAnd("integer_field") + ) + self.assertEqual(values, {"bitand": 1}) def test_bit_and_on_only_false_values(self): - values = AggregateTestModel.objects.filter( - integer_field=0).aggregate(bitand=BitAnd('integer_field')) - self.assertEqual(values, {'bitand': 0}) + values = AggregateTestModel.objects.filter(integer_field=0).aggregate( + bitand=BitAnd("integer_field") + ) + self.assertEqual(values, {"bitand": 0}) def test_bit_or_general(self): - values = AggregateTestModel.objects.filter( - integer_field__in=[0, 1]).aggregate(bitor=BitOr('integer_field')) - self.assertEqual(values, {'bitor': 1}) + values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate( + bitor=BitOr("integer_field") + ) + self.assertEqual(values, {"bitor": 1}) def test_bit_or_on_only_true_values(self): - values = AggregateTestModel.objects.filter( - integer_field=1).aggregate(bitor=BitOr('integer_field')) - self.assertEqual(values, {'bitor': 1}) + values = AggregateTestModel.objects.filter(integer_field=1).aggregate( + bitor=BitOr("integer_field") + ) + self.assertEqual(values, {"bitor": 1}) def test_bit_or_on_only_false_values(self): - values = AggregateTestModel.objects.filter( - integer_field=0).aggregate(bitor=BitOr('integer_field')) - self.assertEqual(values, {'bitor': 0}) + values = AggregateTestModel.objects.filter(integer_field=0).aggregate( + bitor=BitOr("integer_field") + ) + self.assertEqual(values, {"bitor": 0}) - @skipUnlessDBFeature('has_bit_xor') + @skipUnlessDBFeature("has_bit_xor") def test_bit_xor_general(self): AggregateTestModel.objects.create(integer_field=3) values = AggregateTestModel.objects.filter( integer_field__in=[1, 3], - ).aggregate(bitxor=BitXor('integer_field')) - self.assertEqual(values, {'bitxor': 2}) + ).aggregate(bitxor=BitXor("integer_field")) + self.assertEqual(values, {"bitxor": 2}) - @skipUnlessDBFeature('has_bit_xor') + @skipUnlessDBFeature("has_bit_xor") def test_bit_xor_on_only_true_values(self): values = AggregateTestModel.objects.filter( integer_field=1, - ).aggregate(bitxor=BitXor('integer_field')) - self.assertEqual(values, {'bitxor': 1}) + ).aggregate(bitxor=BitXor("integer_field")) + self.assertEqual(values, {"bitxor": 1}) - @skipUnlessDBFeature('has_bit_xor') + @skipUnlessDBFeature("has_bit_xor") def test_bit_xor_on_only_false_values(self): values = AggregateTestModel.objects.filter( integer_field=0, - ).aggregate(bitxor=BitXor('integer_field')) - self.assertEqual(values, {'bitxor': 0}) + ).aggregate(bitxor=BitXor("integer_field")) + self.assertEqual(values, {"bitxor": 0}) def test_bool_and_general(self): - values = AggregateTestModel.objects.aggregate(booland=BoolAnd('boolean_field')) - self.assertEqual(values, {'booland': False}) + values = AggregateTestModel.objects.aggregate(booland=BoolAnd("boolean_field")) + self.assertEqual(values, {"booland": False}) def test_bool_and_q_object(self): values = AggregateTestModel.objects.aggregate( booland=BoolAnd(Q(integer_field__gt=2)), ) - self.assertEqual(values, {'booland': False}) + self.assertEqual(values, {"booland": False}) def test_bool_or_general(self): - values = AggregateTestModel.objects.aggregate(boolor=BoolOr('boolean_field')) - self.assertEqual(values, {'boolor': True}) + values = AggregateTestModel.objects.aggregate(boolor=BoolOr("boolean_field")) + self.assertEqual(values, {"boolor": True}) def test_bool_or_q_object(self): values = AggregateTestModel.objects.aggregate( boolor=BoolOr(Q(integer_field__gt=2)), ) - self.assertEqual(values, {'boolor': False}) + self.assertEqual(values, {"boolor": False}) def test_string_agg_requires_delimiter(self): with self.assertRaises(TypeError): - AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field')) + AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field")) def test_string_agg_delimiter_escaping(self): - values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter="'")) - self.assertEqual(values, {'stringagg': "Foo1'Foo2'Foo4'Foo3"}) + values = AggregateTestModel.objects.aggregate( + stringagg=StringAgg("char_field", delimiter="'") + ) + self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"}) def test_string_agg_charfield(self): - values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) - self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'}) + values = AggregateTestModel.objects.aggregate( + stringagg=StringAgg("char_field", delimiter=";") + ) + self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"}) def test_string_agg_default_output_field(self): values = AggregateTestModel.objects.aggregate( - stringagg=StringAgg('text_field', delimiter=';'), + stringagg=StringAgg("text_field", delimiter=";"), ) - self.assertEqual(values, {'stringagg': 'Text1;Text2;Text4;Text3'}) + self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"}) def test_string_agg_charfield_ordering(self): ordering_test_cases = ( - (F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'), - (F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'), - (F('char_field'), 'Foo1;Foo2;Foo3;Foo4'), - ('char_field', 'Foo1;Foo2;Foo3;Foo4'), - ('-char_field', 'Foo4;Foo3;Foo2;Foo1'), - (Concat('char_field', Value('@')), 'Foo1;Foo2;Foo3;Foo4'), - (Concat('char_field', Value('@')).desc(), 'Foo4;Foo3;Foo2;Foo1'), + (F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"), + (F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"), + (F("char_field"), "Foo1;Foo2;Foo3;Foo4"), + ("char_field", "Foo1;Foo2;Foo3;Foo4"), + ("-char_field", "Foo4;Foo3;Foo2;Foo1"), + (Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"), + (Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"), ) for ordering, expected_output in ordering_test_cases: with self.subTest(ordering=ordering, expected_output=expected_output): values = AggregateTestModel.objects.aggregate( - stringagg=StringAgg('char_field', delimiter=';', ordering=ordering) + stringagg=StringAgg("char_field", delimiter=";", ordering=ordering) ) - self.assertEqual(values, {'stringagg': expected_output}) + self.assertEqual(values, {"stringagg": expected_output}) def test_string_agg_jsonfield_ordering(self): values = AggregateTestModel.objects.aggregate( stringagg=StringAgg( - KeyTextTransform('lang', 'json_field'), - delimiter=';', - ordering=KeyTextTransform('lang', 'json_field'), + KeyTextTransform("lang", "json_field"), + delimiter=";", + ordering=KeyTextTransform("lang", "json_field"), output_field=CharField(), ), ) - self.assertEqual(values, {'stringagg': 'en;pl'}) + self.assertEqual(values, {"stringagg": "en;pl"}) def test_string_agg_filter(self): values = AggregateTestModel.objects.aggregate( stringagg=StringAgg( - 'char_field', - delimiter=';', - filter=Q(char_field__endswith='3') | Q(char_field__endswith='1'), + "char_field", + delimiter=";", + filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"), ) ) - self.assertEqual(values, {'stringagg': 'Foo1;Foo3'}) + self.assertEqual(values, {"stringagg": "Foo1;Foo3"}) def test_orderable_agg_alternative_fields(self): values = AggregateTestModel.objects.aggregate( - arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc()) + arrayagg=ArrayAgg("integer_field", ordering=F("char_field").asc()) ) - self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]}) + self.assertEqual(values, {"arrayagg": [0, 1, 0, 2]}) def test_jsonb_agg(self): - values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg('char_field')) - self.assertEqual(values, {'jsonbagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) + values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg("char_field")) + self.assertEqual(values, {"jsonbagg": ["Foo1", "Foo2", "Foo4", "Foo3"]}) def test_jsonb_agg_charfield_ordering(self): ordering_test_cases = ( - (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']), - (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']), - (Concat('char_field', Value('@')), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), - (Concat('char_field', Value('@')).desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']), + (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]), + (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]), + (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]), + ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]), + ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]), + (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]), + (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]), ) for ordering, expected_output in ordering_test_cases: with self.subTest(ordering=ordering, expected_output=expected_output): values = AggregateTestModel.objects.aggregate( - jsonbagg=JSONBAgg('char_field', ordering=ordering), + jsonbagg=JSONBAgg("char_field", ordering=ordering), ) - self.assertEqual(values, {'jsonbagg': expected_output}) + self.assertEqual(values, {"jsonbagg": expected_output}) def test_jsonb_agg_integerfield_ordering(self): values = AggregateTestModel.objects.aggregate( - jsonbagg=JSONBAgg('integer_field', ordering=F('integer_field').desc()), + jsonbagg=JSONBAgg("integer_field", ordering=F("integer_field").desc()), ) - self.assertEqual(values, {'jsonbagg': [2, 1, 0, 0]}) + self.assertEqual(values, {"jsonbagg": [2, 1, 0, 0]}) def test_jsonb_agg_booleanfield_ordering(self): ordering_test_cases = ( - (F('boolean_field').asc(), [False, False, True, True]), - (F('boolean_field').desc(), [True, True, False, False]), - (F('boolean_field'), [False, False, True, True]), + (F("boolean_field").asc(), [False, False, True, True]), + (F("boolean_field").desc(), [True, True, False, False]), + (F("boolean_field"), [False, False, True, True]), ) for ordering, expected_output in ordering_test_cases: with self.subTest(ordering=ordering, expected_output=expected_output): values = AggregateTestModel.objects.aggregate( - jsonbagg=JSONBAgg('boolean_field', ordering=ordering), + jsonbagg=JSONBAgg("boolean_field", ordering=ordering), ) - self.assertEqual(values, {'jsonbagg': expected_output}) + self.assertEqual(values, {"jsonbagg": expected_output}) def test_jsonb_agg_jsonfield_ordering(self): values = AggregateTestModel.objects.aggregate( jsonbagg=JSONBAgg( - KeyTransform('lang', 'json_field'), + KeyTransform("lang", "json_field"), filter=Q(json_field__lang__isnull=False), - ordering=KeyTransform('lang', 'json_field'), + ordering=KeyTransform("lang", "json_field"), ), ) - self.assertEqual(values, {'jsonbagg': ['en', 'pl']}) + self.assertEqual(values, {"jsonbagg": ["en", "pl"]}) def test_jsonb_agg_key_index_transforms(self): room101 = Room.objects.create(number=101) @@ -448,160 +514,206 @@ class TestGeneralAggregate(PostgreSQLTestCase): start=datetimes[0], end=datetimes[1], room=room102, - requirements={'double_bed': True, 'parking': True}, + requirements={"double_bed": True, "parking": True}, ) HotelReservation.objects.create( datespan=(datetimes[1].date(), datetimes[2].date()), start=datetimes[1], end=datetimes[2], room=room102, - requirements={'double_bed': False, 'sea_view': True, 'parking': False}, + requirements={"double_bed": False, "sea_view": True, "parking": False}, ) HotelReservation.objects.create( datespan=(datetimes[0].date(), datetimes[2].date()), start=datetimes[0], end=datetimes[2], room=room101, - requirements={'sea_view': False}, + requirements={"sea_view": False}, ) - values = Room.objects.annotate( - requirements=JSONBAgg( - 'hotelreservation__requirements', - ordering='-hotelreservation__start', + values = ( + Room.objects.annotate( + requirements=JSONBAgg( + "hotelreservation__requirements", + ordering="-hotelreservation__start", + ) ) - ).filter(requirements__0__sea_view=True).values('number', 'requirements') - self.assertSequenceEqual(values, [ - {'number': 102, 'requirements': [ - {'double_bed': False, 'sea_view': True, 'parking': False}, - {'double_bed': True, 'parking': True}, - ]}, - ]) + .filter(requirements__0__sea_view=True) + .values("number", "requirements") + ) + self.assertSequenceEqual( + values, + [ + { + "number": 102, + "requirements": [ + {"double_bed": False, "sea_view": True, "parking": False}, + {"double_bed": True, "parking": True}, + ], + }, + ], + ) def test_string_agg_array_agg_ordering_in_subquery(self): stats = [] - for i, agg in enumerate(AggregateTestModel.objects.order_by('char_field')): + for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")): stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1)) stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i)) StatTestModel.objects.bulk_create(stats) for aggregate, expected_result in ( ( - ArrayAgg('stattestmodel__int1', ordering='-stattestmodel__int2'), - [('Foo1', [0, 1]), ('Foo2', [1, 2]), ('Foo3', [2, 3]), ('Foo4', [3, 4])], + ArrayAgg("stattestmodel__int1", ordering="-stattestmodel__int2"), + [ + ("Foo1", [0, 1]), + ("Foo2", [1, 2]), + ("Foo3", [2, 3]), + ("Foo4", [3, 4]), + ], ), ( StringAgg( - Cast('stattestmodel__int1', CharField()), - delimiter=';', - ordering='-stattestmodel__int2', + Cast("stattestmodel__int1", CharField()), + delimiter=";", + ordering="-stattestmodel__int2", ), - [('Foo1', '0;1'), ('Foo2', '1;2'), ('Foo3', '2;3'), ('Foo4', '3;4')], + [("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")], ), ): with self.subTest(aggregate=aggregate.__class__.__name__): - subquery = AggregateTestModel.objects.filter( - pk=OuterRef('pk'), - ).annotate(agg=aggregate).values('agg') - values = AggregateTestModel.objects.annotate( - agg=Subquery(subquery), - ).order_by('char_field').values_list('char_field', 'agg') + subquery = ( + AggregateTestModel.objects.filter( + pk=OuterRef("pk"), + ) + .annotate(agg=aggregate) + .values("agg") + ) + values = ( + AggregateTestModel.objects.annotate( + agg=Subquery(subquery), + ) + .order_by("char_field") + .values_list("char_field", "agg") + ) self.assertEqual(list(values), expected_result) def test_string_agg_array_agg_filter_in_subquery(self): - StatTestModel.objects.bulk_create([ - StatTestModel(related_field=self.aggs[0], int1=0, int2=5), - StatTestModel(related_field=self.aggs[0], int1=1, int2=4), - StatTestModel(related_field=self.aggs[0], int1=2, int2=3), - ]) + StatTestModel.objects.bulk_create( + [ + StatTestModel(related_field=self.aggs[0], int1=0, int2=5), + StatTestModel(related_field=self.aggs[0], int1=1, int2=4), + StatTestModel(related_field=self.aggs[0], int1=2, int2=3), + ] + ) for aggregate, expected_result in ( ( - ArrayAgg('stattestmodel__int1', filter=Q(stattestmodel__int2__gt=3)), - [('Foo1', [0, 1]), ('Foo2', None)], + ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)), + [("Foo1", [0, 1]), ("Foo2", None)], ), ( StringAgg( - Cast('stattestmodel__int2', CharField()), - delimiter=';', + Cast("stattestmodel__int2", CharField()), + delimiter=";", filter=Q(stattestmodel__int1__lt=2), ), - [('Foo1', '5;4'), ('Foo2', None)], + [("Foo1", "5;4"), ("Foo2", None)], ), ): with self.subTest(aggregate=aggregate.__class__.__name__): - subquery = AggregateTestModel.objects.filter( - pk=OuterRef('pk'), - ).annotate(agg=aggregate).values('agg') - values = AggregateTestModel.objects.annotate( - agg=Subquery(subquery), - ).filter( - char_field__in=['Foo1', 'Foo2'], - ).order_by('char_field').values_list('char_field', 'agg') + subquery = ( + AggregateTestModel.objects.filter( + pk=OuterRef("pk"), + ) + .annotate(agg=aggregate) + .values("agg") + ) + values = ( + AggregateTestModel.objects.annotate( + agg=Subquery(subquery), + ) + .filter( + char_field__in=["Foo1", "Foo2"], + ) + .order_by("char_field") + .values_list("char_field", "agg") + ) self.assertEqual(list(values), expected_result) def test_string_agg_filter_in_subquery_with_exclude(self): - subquery = AggregateTestModel.objects.annotate( - stringagg=StringAgg( - 'char_field', - delimiter=';', - filter=Q(char_field__endswith='1'), + subquery = ( + AggregateTestModel.objects.annotate( + stringagg=StringAgg( + "char_field", + delimiter=";", + filter=Q(char_field__endswith="1"), + ) ) - ).exclude(stringagg='').values('id') + .exclude(stringagg="") + .values("id") + ) self.assertSequenceEqual( AggregateTestModel.objects.filter(id__in=Subquery(subquery)), [self.aggs[0]], ) def test_ordering_isnt_cleared_for_array_subquery(self): - inner_qs = AggregateTestModel.objects.order_by('-integer_field') + inner_qs = AggregateTestModel.objects.order_by("-integer_field") qs = AggregateTestModel.objects.annotate( integers=Func( - Subquery(inner_qs.values('integer_field')), - function='ARRAY', + Subquery(inner_qs.values("integer_field")), + function="ARRAY", output_field=ArrayField(base_field=IntegerField()), ), ) self.assertSequenceEqual( qs.first().integers, - inner_qs.values_list('integer_field', flat=True), + inner_qs.values_list("integer_field", flat=True), ) class TestAggregateDistinct(PostgreSQLTestCase): @classmethod def setUpTestData(cls): - AggregateTestModel.objects.create(char_field='Foo') - AggregateTestModel.objects.create(char_field='Foo') - AggregateTestModel.objects.create(char_field='Bar') + AggregateTestModel.objects.create(char_field="Foo") + AggregateTestModel.objects.create(char_field="Foo") + AggregateTestModel.objects.create(char_field="Bar") def test_string_agg_distinct_false(self): - values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=False)) - self.assertEqual(values['stringagg'].count('Foo'), 2) - self.assertEqual(values['stringagg'].count('Bar'), 1) + values = AggregateTestModel.objects.aggregate( + stringagg=StringAgg("char_field", delimiter=" ", distinct=False) + ) + self.assertEqual(values["stringagg"].count("Foo"), 2) + self.assertEqual(values["stringagg"].count("Bar"), 1) def test_string_agg_distinct_true(self): - values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=True)) - self.assertEqual(values['stringagg'].count('Foo'), 1) - self.assertEqual(values['stringagg'].count('Bar'), 1) + values = AggregateTestModel.objects.aggregate( + stringagg=StringAgg("char_field", delimiter=" ", distinct=True) + ) + self.assertEqual(values["stringagg"].count("Foo"), 1) + self.assertEqual(values["stringagg"].count("Bar"), 1) def test_array_agg_distinct_false(self): - values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field', distinct=False)) - self.assertEqual(sorted(values['arrayagg']), ['Bar', 'Foo', 'Foo']) + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg("char_field", distinct=False) + ) + self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo", "Foo"]) def test_array_agg_distinct_true(self): - values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field', distinct=True)) - self.assertEqual(sorted(values['arrayagg']), ['Bar', 'Foo']) + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg("char_field", distinct=True) + ) + self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo"]) def test_jsonb_agg_distinct_false(self): values = AggregateTestModel.objects.aggregate( - jsonbagg=JSONBAgg('char_field', distinct=False), + jsonbagg=JSONBAgg("char_field", distinct=False), ) - self.assertEqual(sorted(values['jsonbagg']), ['Bar', 'Foo', 'Foo']) + self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo", "Foo"]) def test_jsonb_agg_distinct_true(self): values = AggregateTestModel.objects.aggregate( - jsonbagg=JSONBAgg('char_field', distinct=True), + jsonbagg=JSONBAgg("char_field", distinct=True), ) - self.assertEqual(sorted(values['jsonbagg']), ['Bar', 'Foo']) + self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo"]) class TestStatisticsAggregate(PostgreSQLTestCase): @@ -626,37 +738,38 @@ class TestStatisticsAggregate(PostgreSQLTestCase): # Tests for base class (StatAggregate) def test_missing_arguments_raises_exception(self): - with self.assertRaisesMessage(ValueError, 'Both y and x must be provided.'): + with self.assertRaisesMessage(ValueError, "Both y and x must be provided."): StatAggregate(x=None, y=None) def test_correct_source_expressions(self): - func = StatAggregate(x='test', y=13) + func = StatAggregate(x="test", y=13) self.assertIsInstance(func.source_expressions[0], Value) self.assertIsInstance(func.source_expressions[1], F) def test_alias_is_required(self): class SomeFunc(StatAggregate): - function = 'TEST' - with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'): - StatTestModel.objects.aggregate(SomeFunc(y='int2', x='int1')) + function = "TEST" + + with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"): + StatTestModel.objects.aggregate(SomeFunc(y="int2", x="int1")) # Test aggregates def test_empty_result_set(self): StatTestModel.objects.all().delete() tests = [ - (Corr(y='int2', x='int1'), None), - (CovarPop(y='int2', x='int1'), None), - (CovarPop(y='int2', x='int1', sample=True), None), - (RegrAvgX(y='int2', x='int1'), None), - (RegrAvgY(y='int2', x='int1'), None), - (RegrCount(y='int2', x='int1'), 0), - (RegrIntercept(y='int2', x='int1'), None), - (RegrR2(y='int2', x='int1'), None), - (RegrSlope(y='int2', x='int1'), None), - (RegrSXX(y='int2', x='int1'), None), - (RegrSXY(y='int2', x='int1'), None), - (RegrSYY(y='int2', x='int1'), None), + (Corr(y="int2", x="int1"), None), + (CovarPop(y="int2", x="int1"), None), + (CovarPop(y="int2", x="int1", sample=True), None), + (RegrAvgX(y="int2", x="int1"), None), + (RegrAvgY(y="int2", x="int1"), None), + (RegrCount(y="int2", x="int1"), 0), + (RegrIntercept(y="int2", x="int1"), None), + (RegrR2(y="int2", x="int1"), None), + (RegrSlope(y="int2", x="int1"), None), + (RegrSXX(y="int2", x="int1"), None), + (RegrSXY(y="int2", x="int1"), None), + (RegrSYY(y="int2", x="int1"), None), ] for aggregation, expected_result in tests: with self.subTest(aggregation=aggregation): @@ -665,29 +778,29 @@ class TestStatisticsAggregate(PostgreSQLTestCase): values = StatTestModel.objects.none().aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) # Empty result when query must be executed. with self.assertNumQueries(1): values = StatTestModel.objects.aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) def test_default_argument(self): StatTestModel.objects.all().delete() tests = [ - (Corr(y='int2', x='int1', default=0), 0), - (CovarPop(y='int2', x='int1', default=0), 0), - (CovarPop(y='int2', x='int1', sample=True, default=0), 0), - (RegrAvgX(y='int2', x='int1', default=0), 0), - (RegrAvgY(y='int2', x='int1', default=0), 0), + (Corr(y="int2", x="int1", default=0), 0), + (CovarPop(y="int2", x="int1", default=0), 0), + (CovarPop(y="int2", x="int1", sample=True, default=0), 0), + (RegrAvgX(y="int2", x="int1", default=0), 0), + (RegrAvgY(y="int2", x="int1", default=0), 0), # RegrCount() doesn't support the default argument. - (RegrIntercept(y='int2', x='int1', default=0), 0), - (RegrR2(y='int2', x='int1', default=0), 0), - (RegrSlope(y='int2', x='int1', default=0), 0), - (RegrSXX(y='int2', x='int1', default=0), 0), - (RegrSXY(y='int2', x='int1', default=0), 0), - (RegrSYY(y='int2', x='int1', default=0), 0), + (RegrIntercept(y="int2", x="int1", default=0), 0), + (RegrR2(y="int2", x="int1", default=0), 0), + (RegrSlope(y="int2", x="int1", default=0), 0), + (RegrSXX(y="int2", x="int1", default=0), 0), + (RegrSXY(y="int2", x="int1", default=0), 0), + (RegrSYY(y="int2", x="int1", default=0), 0), ] for aggregation, expected_result in tests: with self.subTest(aggregation=aggregation): @@ -696,71 +809,81 @@ class TestStatisticsAggregate(PostgreSQLTestCase): values = StatTestModel.objects.none().aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) # Empty result when query must be executed. with self.assertNumQueries(1): values = StatTestModel.objects.aggregate( aggregation=aggregation, ) - self.assertEqual(values, {'aggregation': expected_result}) + self.assertEqual(values, {"aggregation": expected_result}) def test_corr_general(self): - values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1')) - self.assertEqual(values, {'corr': -1.0}) + values = StatTestModel.objects.aggregate(corr=Corr(y="int2", x="int1")) + self.assertEqual(values, {"corr": -1.0}) def test_covar_pop_general(self): - values = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1')) - self.assertEqual(values, {'covarpop': Approximate(-0.66, places=1)}) + values = StatTestModel.objects.aggregate(covarpop=CovarPop(y="int2", x="int1")) + self.assertEqual(values, {"covarpop": Approximate(-0.66, places=1)}) def test_covar_pop_sample(self): - values = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1', sample=True)) - self.assertEqual(values, {'covarpop': -1.0}) + values = StatTestModel.objects.aggregate( + covarpop=CovarPop(y="int2", x="int1", sample=True) + ) + self.assertEqual(values, {"covarpop": -1.0}) def test_regr_avgx_general(self): - values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y='int2', x='int1')) - self.assertEqual(values, {'regravgx': 2.0}) + values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y="int2", x="int1")) + self.assertEqual(values, {"regravgx": 2.0}) def test_regr_avgy_general(self): - values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y='int2', x='int1')) - self.assertEqual(values, {'regravgy': 2.0}) + values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y="int2", x="int1")) + self.assertEqual(values, {"regravgy": 2.0}) def test_regr_count_general(self): - values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1')) - self.assertEqual(values, {'regrcount': 3}) + values = StatTestModel.objects.aggregate( + regrcount=RegrCount(y="int2", x="int1") + ) + self.assertEqual(values, {"regrcount": 3}) def test_regr_count_default(self): - msg = 'RegrCount does not allow default.' + msg = "RegrCount does not allow default." with self.assertRaisesMessage(TypeError, msg): - RegrCount(y='int2', x='int1', default=0) + RegrCount(y="int2", x="int1", default=0) def test_regr_intercept_general(self): - values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1')) - self.assertEqual(values, {'regrintercept': 4}) + values = StatTestModel.objects.aggregate( + regrintercept=RegrIntercept(y="int2", x="int1") + ) + self.assertEqual(values, {"regrintercept": 4}) def test_regr_r2_general(self): - values = StatTestModel.objects.aggregate(regrr2=RegrR2(y='int2', x='int1')) - self.assertEqual(values, {'regrr2': 1}) + values = StatTestModel.objects.aggregate(regrr2=RegrR2(y="int2", x="int1")) + self.assertEqual(values, {"regrr2": 1}) def test_regr_slope_general(self): - values = StatTestModel.objects.aggregate(regrslope=RegrSlope(y='int2', x='int1')) - self.assertEqual(values, {'regrslope': -1}) + values = StatTestModel.objects.aggregate( + regrslope=RegrSlope(y="int2", x="int1") + ) + self.assertEqual(values, {"regrslope": -1}) def test_regr_sxx_general(self): - values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y='int2', x='int1')) - self.assertEqual(values, {'regrsxx': 2.0}) + values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y="int2", x="int1")) + self.assertEqual(values, {"regrsxx": 2.0}) def test_regr_sxy_general(self): - values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y='int2', x='int1')) - self.assertEqual(values, {'regrsxy': -2.0}) + values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y="int2", x="int1")) + self.assertEqual(values, {"regrsxy": -2.0}) def test_regr_syy_general(self): - values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y='int2', x='int1')) - self.assertEqual(values, {'regrsyy': 2.0}) + values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y="int2", x="int1")) + self.assertEqual(values, {"regrsyy": 2.0}) def test_regr_avgx_with_related_obj_and_number_as_argument(self): """ This is more complex test to check if JOIN on field and number as argument works as expected. """ - values = StatTestModel.objects.aggregate(complex_regravgx=RegrAvgX(y=5, x='related_field__integer_field')) - self.assertEqual(values, {'complex_regravgx': 1.0}) + values = StatTestModel.objects.aggregate( + complex_regravgx=RegrAvgX(y=5, x="related_field__integer_field") + ) + self.assertEqual(values, {"complex_regravgx": 1.0}) diff --git a/tests/postgres_tests/test_apps.py b/tests/postgres_tests/test_apps.py index 94001822c2..340e555609 100644 --- a/tests/postgres_tests/test_apps.py +++ b/tests/postgres_tests/test_apps.py @@ -7,12 +7,12 @@ from django.test.utils import modify_settings from . import PostgreSQLTestCase try: - from psycopg2.extras import ( - DateRange, DateTimeRange, DateTimeTZRange, NumericRange, - ) + from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, NumericRange from django.contrib.postgres.fields import ( - DateRangeField, DateTimeRangeField, DecimalRangeField, + DateRangeField, + DateTimeRangeField, + DecimalRangeField, IntegerRangeField, ) except ImportError: @@ -22,17 +22,24 @@ except ImportError: class PostgresConfigTests(PostgreSQLTestCase): def test_register_type_handlers_connection(self): from django.contrib.postgres.signals import register_type_handlers - self.assertNotIn(register_type_handlers, connection_created._live_receivers(None)) - with modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}): - self.assertIn(register_type_handlers, connection_created._live_receivers(None)) - self.assertNotIn(register_type_handlers, connection_created._live_receivers(None)) + + self.assertNotIn( + register_type_handlers, connection_created._live_receivers(None) + ) + with modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}): + self.assertIn( + register_type_handlers, connection_created._live_receivers(None) + ) + self.assertNotIn( + register_type_handlers, connection_created._live_receivers(None) + ) def test_register_serializer_for_migrations(self): tests = ( (DateRange(empty=True), DateRangeField), (DateTimeRange(empty=True), DateRangeField), - (DateTimeTZRange(None, None, '[]'), DateTimeRangeField), - (NumericRange(Decimal('1.0'), Decimal('5.0'), '()'), DecimalRangeField), + (DateTimeTZRange(None, None, "[]"), DateTimeRangeField), + (NumericRange(Decimal("1.0"), Decimal("5.0"), "()"), DecimalRangeField), (NumericRange(1, 10), IntegerRangeField), ) @@ -40,25 +47,31 @@ class PostgresConfigTests(PostgreSQLTestCase): for default, test_field in tests: with self.subTest(default=default): field = test_field(default=default) - with self.assertRaisesMessage(ValueError, 'Cannot serialize: %s' % default.__class__.__name__): + with self.assertRaisesMessage( + ValueError, "Cannot serialize: %s" % default.__class__.__name__ + ): MigrationWriter.serialize(field) assertNotSerializable() - with self.modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}): + with self.modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}): for default, test_field in tests: with self.subTest(default=default): field = test_field(default=default) serialized_field, imports = MigrationWriter.serialize(field) - self.assertEqual(imports, { - 'import django.contrib.postgres.fields.ranges', - 'import psycopg2.extras', - }) + self.assertEqual( + imports, + { + "import django.contrib.postgres.fields.ranges", + "import psycopg2.extras", + }, + ) self.assertIn( - '%s.%s(default=psycopg2.extras.%r)' % ( + "%s.%s(default=psycopg2.extras.%r)" + % ( field.__module__, field.__class__.__name__, default, ), - serialized_field + serialized_field, ) assertNotSerializable() diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index ba04c15e24..512972b8e6 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -15,13 +15,18 @@ from django.test import TransactionTestCase, modify_settings, override_settings from django.test.utils import isolate_apps from django.utils import timezone -from . import ( - PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase, -) +from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase from .models import ( - ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, - IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, - OtherTypesArrayModel, PostgreSQLModel, Tag, + ArrayEnumModel, + ArrayFieldSubclass, + CharArrayModel, + DateTimeArrayModel, + IntegerArrayModel, + NestedIntegerArrayModel, + NullableIntegerArrayModel, + OtherTypesArrayModel, + PostgreSQLModel, + Tag, ) try: @@ -30,33 +35,33 @@ try: from django.contrib.postgres.aggregates import ArrayAgg from django.contrib.postgres.expressions import ArraySubquery from django.contrib.postgres.fields import ArrayField - from django.contrib.postgres.fields.array import ( - IndexTransform, SliceTransform, - ) + from django.contrib.postgres.fields.array import IndexTransform, SliceTransform from django.contrib.postgres.forms import ( - SimpleArrayField, SplitArrayField, SplitArrayWidget, + SimpleArrayField, + SplitArrayField, + SplitArrayWidget, ) except ImportError: pass -@isolate_apps('postgres_tests') +@isolate_apps("postgres_tests") class BasicTests(PostgreSQLSimpleTestCase): def test_get_field_display(self): class MyModel(PostgreSQLModel): field = ArrayField( models.CharField(max_length=16), choices=[ - ['Media', [(['vinyl', 'cd'], 'Audio')]], - (('mp3', 'mp4'), 'Digital'), + ["Media", [(["vinyl", "cd"], "Audio")]], + (("mp3", "mp4"), "Digital"), ], ) tests = ( - (['vinyl', 'cd'], 'Audio'), - (('mp3', 'mp4'), 'Digital'), - (('a', 'b'), "('a', 'b')"), - (['c', 'd'], "['c', 'd']"), + (["vinyl", "cd"], "Audio"), + (("mp3", "mp4"), "Digital"), + (("a", "b"), "('a', 'b')"), + (["c", "d"], "['c', 'd']"), ) for value, display in tests: with self.subTest(value=value, display=display): @@ -69,17 +74,18 @@ class BasicTests(PostgreSQLSimpleTestCase): ArrayField(models.CharField(max_length=16)), choices=[ [ - 'Media', - [([['vinyl', 'cd'], ('x',)], 'Audio')], + "Media", + [([["vinyl", "cd"], ("x",)], "Audio")], ], - ((['mp3'], ('mp4',)), 'Digital'), + ((["mp3"], ("mp4",)), "Digital"), ], ) + tests = ( - ([['vinyl', 'cd'], ('x',)], 'Audio'), - ((['mp3'], ('mp4',)), 'Digital'), - ((('a', 'b'), ('c',)), "(('a', 'b'), ('c',))"), - ([['a', 'b'], ['c']], "[['a', 'b'], ['c']]"), + ([["vinyl", "cd"], ("x",)], "Audio"), + ((["mp3"], ("mp4",)), "Digital"), + ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"), + ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"), ) for value, display in tests: with self.subTest(value=value, display=display): @@ -88,7 +94,6 @@ class BasicTests(PostgreSQLSimpleTestCase): class TestSaveLoad(PostgreSQLTestCase): - def test_integer(self): instance = IntegerArrayModel(field=[1, 2, 3]) instance.save() @@ -96,7 +101,7 @@ class TestSaveLoad(PostgreSQLTestCase): self.assertEqual(instance.field, loaded.field) def test_char(self): - instance = CharArrayModel(field=['hello', 'goodbye']) + instance = CharArrayModel(field=["hello", "goodbye"]) instance.save() loaded = CharArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) @@ -121,7 +126,7 @@ class TestSaveLoad(PostgreSQLTestCase): def test_integers_passed_as_strings(self): # This checks that get_prep_value is deferred properly - instance = IntegerArrayModel(field=['1']) + instance = IntegerArrayModel(field=["1"]) instance.save() loaded = IntegerArrayModel.objects.get() self.assertEqual(loaded.field, [1]) @@ -151,16 +156,16 @@ class TestSaveLoad(PostgreSQLTestCase): def test_other_array_types(self): instance = OtherTypesArrayModel( - ips=['192.168.0.1', '::1'], + ips=["192.168.0.1", "::1"], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=[Tag(1), Tag(2), Tag(3)], - json=[{'a': 1}, {'b': 2}], + json=[{"a": 1}, {"b": 2}], int_ranges=[NumericRange(10, 20), NumericRange(30, 40)], bigint_ranges=[ NumericRange(7000000000, 10000000000), NumericRange(50000000000, 70000000000), - ] + ], ) instance.save() loaded = OtherTypesArrayModel.objects.get() @@ -174,7 +179,7 @@ class TestSaveLoad(PostgreSQLTestCase): def test_null_from_db_value_handling(self): instance = OtherTypesArrayModel.objects.create( - ips=['192.168.0.1', '::1'], + ips=["192.168.0.1", "::1"], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=None, @@ -187,7 +192,7 @@ class TestSaveLoad(PostgreSQLTestCase): def test_model_set_on_base_field(self): instance = IntegerArrayModel() - field = instance._meta.get_field('field') + field = instance._meta.get_field("field") self.assertEqual(field.model, IntegerArrayModel) self.assertEqual(field.base_field.model, IntegerArrayModel) @@ -199,29 +204,35 @@ class TestSaveLoad(PostgreSQLTestCase): class TestQuerying(PostgreSQLTestCase): - @classmethod def setUpTestData(cls): - cls.objs = NullableIntegerArrayModel.objects.bulk_create([ - NullableIntegerArrayModel(order=1, field=[1]), - NullableIntegerArrayModel(order=2, field=[2]), - NullableIntegerArrayModel(order=3, field=[2, 3]), - NullableIntegerArrayModel(order=4, field=[20, 30, 40]), - NullableIntegerArrayModel(order=5, field=None), - ]) + cls.objs = NullableIntegerArrayModel.objects.bulk_create( + [ + NullableIntegerArrayModel(order=1, field=[1]), + NullableIntegerArrayModel(order=2, field=[2]), + NullableIntegerArrayModel(order=3, field=[2, 3]), + NullableIntegerArrayModel(order=4, field=[20, 30, 40]), + NullableIntegerArrayModel(order=5, field=None), + ] + ) def test_empty_list(self): NullableIntegerArrayModel.objects.create(field=[]) - obj = NullableIntegerArrayModel.objects.annotate( - empty_array=models.Value([], output_field=ArrayField(models.IntegerField())), - ).filter(field=models.F('empty_array')).get() + obj = ( + NullableIntegerArrayModel.objects.annotate( + empty_array=models.Value( + [], output_field=ArrayField(models.IntegerField()) + ), + ) + .filter(field=models.F("empty_array")) + .get() + ) self.assertEqual(obj.field, []) self.assertEqual(obj.empty_array, []) def test_exact(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__exact=[1]), - self.objs[:1] + NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] ) def test_exact_with_expression(self): @@ -231,50 +242,47 @@ class TestQuerying(PostgreSQLTestCase): ) def test_exact_charfield(self): - instance = CharArrayModel.objects.create(field=['text']) + instance = CharArrayModel.objects.create(field=["text"]) self.assertSequenceEqual( - CharArrayModel.objects.filter(field=['text']), - [instance] + CharArrayModel.objects.filter(field=["text"]), [instance] ) def test_exact_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( - NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), - [instance] + NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance] ) def test_isnull(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__isnull=True), - self.objs[-1:] + NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] ) def test_gt(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__gt=[0]), - self.objs[:4] + NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4] ) def test_lt(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__lt=[2]), - self.objs[:1] + NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] ) def test_in(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), - self.objs[:2] + self.objs[:2], ) def test_in_subquery(self): IntegerArrayModel.objects.create(field=[2, 3]) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter( - field__in=IntegerArrayModel.objects.all().values_list('field', flat=True) + field__in=IntegerArrayModel.objects.all().values_list( + "field", flat=True + ) ), - self.objs[2:3] + self.objs[2:3], ) @unittest.expectedFailure @@ -284,42 +292,44 @@ class TestQuerying(PostgreSQLTestCase): # psycopg2 mogrify method that generates the ARRAY() syntax is # expecting literals, not column references (#27095). self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]), - self.objs[:2] + NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]), + self.objs[:2], ) def test_in_as_F_object(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]), - self.objs[:4] + NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]), + self.objs[:4], ) def test_contained_by(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), - self.objs[:2] + self.objs[:2], ) def test_contained_by_including_F_object(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('order'), 2]), + NullableIntegerArrayModel.objects.filter( + field__contained_by=[models.F("order"), 2] + ), self.objs[:3], ) def test_contains(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contains=[2]), - self.objs[1:3] + self.objs[1:3], ) def test_contains_subquery(self): IntegerArrayModel.objects.create(field=[2, 3]) - inner_qs = IntegerArrayModel.objects.values_list('field', flat=True) + inner_qs = IntegerArrayModel.objects.values_list("field", flat=True) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contains=inner_qs[:1]), self.objs[2:3], ) - inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef('field')) + inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef("field")) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(Exists(inner_qs)), self.objs[1:3], @@ -335,93 +345,92 @@ class TestQuerying(PostgreSQLTestCase): def test_icontains(self): # Using the __icontains lookup with ArrayField is inefficient. - instance = CharArrayModel.objects.create(field=['FoO']) + instance = CharArrayModel.objects.create(field=["FoO"]) self.assertSequenceEqual( - CharArrayModel.objects.filter(field__icontains='foo'), - [instance] + CharArrayModel.objects.filter(field__icontains="foo"), [instance] ) def test_contains_charfield(self): # Regression for #22907 self.assertSequenceEqual( - CharArrayModel.objects.filter(field__contains=['text']), - [] + CharArrayModel.objects.filter(field__contains=["text"]), [] ) def test_contained_by_charfield(self): self.assertSequenceEqual( - CharArrayModel.objects.filter(field__contained_by=['text']), - [] + CharArrayModel.objects.filter(field__contained_by=["text"]), [] ) def test_overlap_charfield(self): self.assertSequenceEqual( - CharArrayModel.objects.filter(field__overlap=['text']), - [] + CharArrayModel.objects.filter(field__overlap=["text"]), [] ) def test_overlap_charfield_including_expression(self): - obj_1 = CharArrayModel.objects.create(field=['TEXT', 'lower text']) - obj_2 = CharArrayModel.objects.create(field=['lower text', 'TEXT']) - CharArrayModel.objects.create(field=['lower text', 'text']) + obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"]) + obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"]) + CharArrayModel.objects.create(field=["lower text", "text"]) self.assertSequenceEqual( - CharArrayModel.objects.filter(field__overlap=[ - Upper(Value('text')), - 'other', - ]), + CharArrayModel.objects.filter( + field__overlap=[ + Upper(Value("text")), + "other", + ] + ), [obj_1, obj_2], ) def test_lookups_autofield_array(self): - qs = NullableIntegerArrayModel.objects.filter( - field__0__isnull=False, - ).values('field__0').annotate( - arrayagg=ArrayAgg('id'), - ).order_by('field__0') + qs = ( + NullableIntegerArrayModel.objects.filter( + field__0__isnull=False, + ) + .values("field__0") + .annotate( + arrayagg=ArrayAgg("id"), + ) + .order_by("field__0") + ) tests = ( - ('contained_by', [self.objs[1].pk, self.objs[2].pk, 0], [2]), - ('contains', [self.objs[2].pk], [2]), - ('exact', [self.objs[3].pk], [20]), - ('overlap', [self.objs[1].pk, self.objs[3].pk], [2, 20]), + ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]), + ("contains", [self.objs[2].pk], [2]), + ("exact", [self.objs[3].pk], [20]), + ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]), ) for lookup, value, expected in tests: with self.subTest(lookup=lookup): self.assertSequenceEqual( qs.filter( - **{'arrayagg__' + lookup: value}, - ).values_list('field__0', flat=True), + **{"arrayagg__" + lookup: value}, + ).values_list("field__0", flat=True), expected, ) def test_index(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__0=2), - self.objs[1:3] + NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3] ) def test_index_chained(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__0__lt=3), - self.objs[0:3] + NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3] ) def test_index_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( - NestedIntegerArrayModel.objects.filter(field__0__0=1), - [instance] + NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance] ) @unittest.expectedFailure def test_index_used_on_nested_data(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( - NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), - [instance] + NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance] ) def test_index_transform_expression(self): - expr = RawSQL("string_to_array(%s, ';')", ['1;2']) + expr = RawSQL("string_to_array(%s, ';')", ["1;2"]) self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter( field__0=Cast( @@ -433,40 +442,36 @@ class TestQuerying(PostgreSQLTestCase): ) def test_index_annotation(self): - qs = NullableIntegerArrayModel.objects.annotate(second=models.F('field__1')) + qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1")) self.assertCountEqual( - qs.values_list('second', flat=True), + qs.values_list("second", flat=True), [None, None, None, 3, 30], ) def test_overlap(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), - self.objs[0:3] + self.objs[0:3], ) def test_len(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__len__lte=2), - self.objs[0:3] + NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3] ) def test_len_empty_array(self): obj = NullableIntegerArrayModel.objects.create(field=[]) self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__len=0), - [obj] + NullableIntegerArrayModel.objects.filter(field__len=0), [obj] ) def test_slice(self): self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__0_1=[2]), - self.objs[1:3] + NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3] ) self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), - self.objs[2:3] + NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3] ) def test_order_by_slice(self): @@ -477,35 +482,42 @@ class TestQuerying(PostgreSQLTestCase): NullableIntegerArrayModel.objects.create(field=[4, 2]), ) self.assertSequenceEqual( - NullableIntegerArrayModel.objects.order_by('field__1'), + NullableIntegerArrayModel.objects.order_by("field__1"), [ - more_objs[2], more_objs[1], more_objs[3], self.objs[2], - self.objs[3], more_objs[0], self.objs[4], self.objs[1], + more_objs[2], + more_objs[1], + more_objs[3], + self.objs[2], + self.objs[3], + more_objs[0], + self.objs[4], + self.objs[1], self.objs[0], - ] + ], ) @unittest.expectedFailure def test_slice_nested(self): instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) self.assertSequenceEqual( - NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), - [instance] + NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance] ) def test_slice_transform_expression(self): - expr = RawSQL("string_to_array(%s, ';')", ['9;2;3']) + expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"]) self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__0_2=SliceTransform(2, 3, expr)), + NullableIntegerArrayModel.objects.filter( + field__0_2=SliceTransform(2, 3, expr) + ), self.objs[2:3], ) def test_slice_annotation(self): qs = NullableIntegerArrayModel.objects.annotate( - first_two=models.F('field__0_2'), + first_two=models.F("field__0_2"), ) self.assertCountEqual( - qs.values_list('first_two', flat=True), + qs.values_list("first_two", flat=True), [None, [1], [2], [2, 3], [20, 30]], ) @@ -514,17 +526,17 @@ class TestQuerying(PostgreSQLTestCase): NullableIntegerArrayModel.objects.filter( id__in=NullableIntegerArrayModel.objects.filter(field__len=3) ), - [self.objs[3]] + [self.objs[3]], ) def test_enum_lookup(self): class TestEnum(enum.Enum): - VALUE_1 = 'value_1' + VALUE_1 = "value_1" instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1]) self.assertSequenceEqual( ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]), - [instance] + [instance], ) def test_unsupported_lookup(self): @@ -541,18 +553,24 @@ class TestQuerying(PostgreSQLTestCase): self.assertEqual( NullableIntegerArrayModel.objects.annotate( array_length=models.Func( - value, 1, function='ARRAY_LENGTH', output_field=models.IntegerField(), + value, + 1, + function="ARRAY_LENGTH", + output_field=models.IntegerField(), ), - ).values('array_length').annotate( - count=models.Count('pk'), - ).get()['array_length'], + ) + .values("array_length") + .annotate( + count=models.Count("pk"), + ) + .get()["array_length"], 1, ) def test_filter_by_array_subquery(self): inner_qs = NullableIntegerArrayModel.objects.filter( - field__len=models.OuterRef('field__len'), - ).values('field') + field__len=models.OuterRef("field__len"), + ).values("field") self.assertSequenceEqual( NullableIntegerArrayModel.objects.alias( same_sized_fields=ArraySubquery(inner_qs), @@ -562,56 +580,63 @@ class TestQuerying(PostgreSQLTestCase): def test_annotated_array_subquery(self): inner_qs = NullableIntegerArrayModel.objects.exclude( - pk=models.OuterRef('pk') - ).values('order') + pk=models.OuterRef("pk") + ).values("order") self.assertSequenceEqual( NullableIntegerArrayModel.objects.annotate( sibling_ids=ArraySubquery(inner_qs), - ).get(order=1).sibling_ids, + ) + .get(order=1) + .sibling_ids, [2, 3, 4, 5], ) def test_group_by_with_annotated_array_subquery(self): inner_qs = NullableIntegerArrayModel.objects.exclude( - pk=models.OuterRef('pk') - ).values('order') + pk=models.OuterRef("pk") + ).values("order") self.assertSequenceEqual( NullableIntegerArrayModel.objects.annotate( sibling_ids=ArraySubquery(inner_qs), - sibling_count=models.Max('sibling_ids__len'), - ).values_list('sibling_count', flat=True), + sibling_count=models.Max("sibling_ids__len"), + ).values_list("sibling_count", flat=True), [len(self.objs) - 1] * len(self.objs), ) def test_annotated_ordered_array_subquery(self): - inner_qs = NullableIntegerArrayModel.objects.order_by('-order').values('order') + inner_qs = NullableIntegerArrayModel.objects.order_by("-order").values("order") self.assertSequenceEqual( NullableIntegerArrayModel.objects.annotate( ids=ArraySubquery(inner_qs), - ).first().ids, + ) + .first() + .ids, [5, 4, 3, 2, 1], ) def test_annotated_array_subquery_with_json_objects(self): inner_qs = NullableIntegerArrayModel.objects.exclude( - pk=models.OuterRef('pk') - ).values(json=JSONObject(order='order', field='field')) - siblings_json = NullableIntegerArrayModel.objects.annotate( - siblings_json=ArraySubquery(inner_qs), - ).values_list('siblings_json', flat=True).get(order=1) + pk=models.OuterRef("pk") + ).values(json=JSONObject(order="order", field="field")) + siblings_json = ( + NullableIntegerArrayModel.objects.annotate( + siblings_json=ArraySubquery(inner_qs), + ) + .values_list("siblings_json", flat=True) + .get(order=1) + ) self.assertSequenceEqual( siblings_json, [ - {'field': [2], 'order': 2}, - {'field': [2, 3], 'order': 3}, - {'field': [20, 30, 40], 'order': 4}, - {'field': None, 'order': 5}, + {"field": [2], "order": 2}, + {"field": [2, 3], "order": 3}, + {"field": [20, 30, 40], "order": 4}, + {"field": None, "order": 5}, ], ) class TestDateTimeExactQuerying(PostgreSQLTestCase): - @classmethod def setUpTestData(cls): now = timezone.now() @@ -619,33 +644,31 @@ class TestDateTimeExactQuerying(PostgreSQLTestCase): cls.dates = [now.date()] cls.times = [now.time()] cls.objs = [ - DateTimeArrayModel.objects.create(datetimes=cls.datetimes, dates=cls.dates, times=cls.times), + DateTimeArrayModel.objects.create( + datetimes=cls.datetimes, dates=cls.dates, times=cls.times + ), ] def test_exact_datetimes(self): self.assertSequenceEqual( - DateTimeArrayModel.objects.filter(datetimes=self.datetimes), - self.objs + DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs ) def test_exact_dates(self): self.assertSequenceEqual( - DateTimeArrayModel.objects.filter(dates=self.dates), - self.objs + DateTimeArrayModel.objects.filter(dates=self.dates), self.objs ) def test_exact_times(self): self.assertSequenceEqual( - DateTimeArrayModel.objects.filter(times=self.times), - self.objs + DateTimeArrayModel.objects.filter(times=self.times), self.objs ) class TestOtherTypesExactQuerying(PostgreSQLTestCase): - @classmethod def setUpTestData(cls): - cls.ips = ['192.168.0.1', '::1'] + cls.ips = ["192.168.0.1", "::1"] cls.uuids = [uuid.uuid4()] cls.decimals = [decimal.Decimal(1.25), 1.75] cls.tags = [Tag(1), Tag(2), Tag(3)] @@ -660,32 +683,27 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): def test_exact_ip_addresses(self): self.assertSequenceEqual( - OtherTypesArrayModel.objects.filter(ips=self.ips), - self.objs + OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs ) def test_exact_uuids(self): self.assertSequenceEqual( - OtherTypesArrayModel.objects.filter(uuids=self.uuids), - self.objs + OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs ) def test_exact_decimals(self): self.assertSequenceEqual( - OtherTypesArrayModel.objects.filter(decimals=self.decimals), - self.objs + OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs ) def test_exact_tags(self): self.assertSequenceEqual( - OtherTypesArrayModel.objects.filter(tags=self.tags), - self.objs + OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs ) -@isolate_apps('postgres_tests') +@isolate_apps("postgres_tests") class TestChecks(PostgreSQLSimpleTestCase): - def test_field_checks(self): class MyModel(PostgreSQLModel): field = ArrayField(models.CharField()) @@ -694,35 +712,40 @@ class TestChecks(PostgreSQLSimpleTestCase): errors = model.check() self.assertEqual(len(errors), 1) # The inner CharField is missing a max_length. - self.assertEqual(errors[0].id, 'postgres.E001') - self.assertIn('max_length', errors[0].msg) + self.assertEqual(errors[0].id, "postgres.E001") + self.assertIn("max_length", errors[0].msg) def test_invalid_base_fields(self): class MyModel(PostgreSQLModel): - field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) + field = ArrayField( + models.ManyToManyField("postgres_tests.IntegerArrayModel") + ) model = MyModel() errors = model.check() self.assertEqual(len(errors), 1) - self.assertEqual(errors[0].id, 'postgres.E002') + self.assertEqual(errors[0].id, "postgres.E002") def test_invalid_default(self): class MyModel(PostgreSQLModel): field = ArrayField(models.IntegerField(), default=[]) model = MyModel() - self.assertEqual(model.check(), [ - checks.Warning( - msg=( - "ArrayField default should be a callable instead of an " - "instance so that it's not shared between all field " - "instances." - ), - hint='Use a callable instead, e.g., use `list` instead of `[]`.', - obj=MyModel._meta.get_field('field'), - id='fields.E010', - ) - ]) + self.assertEqual( + model.check(), + [ + checks.Warning( + msg=( + "ArrayField default should be a callable instead of an " + "instance so that it's not shared between all field " + "instances." + ), + hint="Use a callable instead, e.g., use `list` instead of `[]`.", + obj=MyModel._meta.get_field("field"), + id="fields.E010", + ) + ], + ) def test_valid_default(self): class MyModel(PostgreSQLModel): @@ -742,6 +765,7 @@ class TestChecks(PostgreSQLSimpleTestCase): """ Nested ArrayFields are permitted. """ + class MyModel(PostgreSQLModel): field = ArrayField(ArrayField(models.CharField())) @@ -749,8 +773,8 @@ class TestChecks(PostgreSQLSimpleTestCase): errors = model.check() self.assertEqual(len(errors), 1) # The inner CharField is missing a max_length. - self.assertEqual(errors[0].id, 'postgres.E001') - self.assertIn('max_length', errors[0].msg) + self.assertEqual(errors[0].id, "postgres.E001") + self.assertIn("max_length", errors[0].msg) def test_choices_tuple_list(self): class MyModel(PostgreSQLModel): @@ -758,19 +782,20 @@ class TestChecks(PostgreSQLSimpleTestCase): models.CharField(max_length=16), choices=[ [ - 'Media', - [(['vinyl', 'cd'], 'Audio'), (('vhs', 'dvd'), 'Video')], + "Media", + [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")], ], - (['mp3', 'mp4'], 'Digital'), + (["mp3", "mp4"], "Digital"), ], ) - self.assertEqual(MyModel._meta.get_field('field').check(), []) + + self.assertEqual(MyModel._meta.get_field("field").check(), []) -@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests") class TestMigrations(TransactionTestCase): - available_apps = ['postgres_tests'] + available_apps = ["postgres_tests"] def test_deconstruct(self): field = ArrayField(models.IntegerField()) @@ -794,84 +819,89 @@ class TestMigrations(TransactionTestCase): def test_subclass_deconstruct(self): field = ArrayField(models.IntegerField()) name, path, args, kwargs = field.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.fields.ArrayField') + self.assertEqual(path, "django.contrib.postgres.fields.ArrayField") field = ArrayFieldSubclass() name, path, args, kwargs = field.deconstruct() - self.assertEqual(path, 'postgres_tests.models.ArrayFieldSubclass') + self.assertEqual(path, "postgres_tests.models.ArrayFieldSubclass") - @override_settings(MIGRATION_MODULES={ - "postgres_tests": "postgres_tests.array_default_migrations", - }) + @override_settings( + MIGRATION_MODULES={ + "postgres_tests": "postgres_tests.array_default_migrations", + } + ) def test_adding_field_with_default(self): # See #22962 - table_name = 'postgres_tests_integerarraydefaultmodel' + table_name = "postgres_tests_integerarraydefaultmodel" with connection.cursor() as cursor: self.assertNotIn(table_name, connection.introspection.table_names(cursor)) - call_command('migrate', 'postgres_tests', verbosity=0) + call_command("migrate", "postgres_tests", verbosity=0) with connection.cursor() as cursor: self.assertIn(table_name, connection.introspection.table_names(cursor)) - call_command('migrate', 'postgres_tests', 'zero', verbosity=0) + call_command("migrate", "postgres_tests", "zero", verbosity=0) with connection.cursor() as cursor: self.assertNotIn(table_name, connection.introspection.table_names(cursor)) - @override_settings(MIGRATION_MODULES={ - "postgres_tests": "postgres_tests.array_index_migrations", - }) + @override_settings( + MIGRATION_MODULES={ + "postgres_tests": "postgres_tests.array_index_migrations", + } + ) def test_adding_arrayfield_with_index(self): """ ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes. """ - table_name = 'postgres_tests_chartextarrayindexmodel' - call_command('migrate', 'postgres_tests', verbosity=0) + table_name = "postgres_tests_chartextarrayindexmodel" + call_command("migrate", "postgres_tests", verbosity=0) with connection.cursor() as cursor: like_constraint_columns_list = [ - v['columns'] - for k, v in list(connection.introspection.get_constraints(cursor, table_name).items()) - if k.endswith('_like') + v["columns"] + for k, v in list( + connection.introspection.get_constraints(cursor, table_name).items() + ) + if k.endswith("_like") ] # Only the CharField should have a LIKE index. - self.assertEqual(like_constraint_columns_list, [['char2']]) + self.assertEqual(like_constraint_columns_list, [["char2"]]) # All fields should have regular indexes. with connection.cursor() as cursor: indexes = [ - c['columns'][0] - for c in connection.introspection.get_constraints(cursor, table_name).values() - if c['index'] and len(c['columns']) == 1 + c["columns"][0] + for c in connection.introspection.get_constraints( + cursor, table_name + ).values() + if c["index"] and len(c["columns"]) == 1 ] - self.assertIn('char', indexes) - self.assertIn('char2', indexes) - self.assertIn('text', indexes) - call_command('migrate', 'postgres_tests', 'zero', verbosity=0) + self.assertIn("char", indexes) + self.assertIn("char2", indexes) + self.assertIn("text", indexes) + call_command("migrate", "postgres_tests", "zero", verbosity=0) with connection.cursor() as cursor: self.assertNotIn(table_name, connection.introspection.table_names(cursor)) class TestSerialization(PostgreSQLSimpleTestCase): - test_data = ( - '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]' - ) + test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]' def test_dumping(self): instance = IntegerArrayModel(field=[1, 2, None]) - data = serializers.serialize('json', [instance]) + data = serializers.serialize("json", [instance]) self.assertEqual(json.loads(data), json.loads(self.test_data)) def test_loading(self): - instance = list(serializers.deserialize('json', self.test_data))[0].object + instance = list(serializers.deserialize("json", self.test_data))[0].object self.assertEqual(instance.field, [1, 2, None]) class TestValidation(PostgreSQLSimpleTestCase): - def test_unbounded(self): field = ArrayField(models.IntegerField()) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([1, None], None) - self.assertEqual(cm.exception.code, 'item_invalid') + self.assertEqual(cm.exception.code, "item_invalid") self.assertEqual( cm.exception.message % cm.exception.params, - 'Item 2 in the array did not validate: This field cannot be null.' + "Item 2 in the array did not validate: This field cannot be null.", ) def test_blank_true(self): @@ -884,31 +914,41 @@ class TestValidation(PostgreSQLSimpleTestCase): field.clean([1, 2, 3], None) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([1, 2, 3, 4], None) - self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.') + self.assertEqual( + cm.exception.messages[0], + "List contains 4 items, it should contain no more than 3.", + ) def test_nested_array_mismatch(self): field = ArrayField(ArrayField(models.IntegerField())) field.clean([[1, 2], [3, 4]], None) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([[1, 2], [3, 4, 5]], None) - self.assertEqual(cm.exception.code, 'nested_array_mismatch') - self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.') + self.assertEqual(cm.exception.code, "nested_array_mismatch") + self.assertEqual( + cm.exception.messages[0], "Nested arrays must have the same length." + ) def test_with_base_field_error_params(self): field = ArrayField(models.CharField(max_length=2)) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['abc'], None) + field.clean(["abc"], None) self.assertEqual(len(cm.exception.error_list), 1) exception = cm.exception.error_list[0] self.assertEqual( exception.message, - 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).' + "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).", + ) + self.assertEqual(exception.code, "item_invalid") + self.assertEqual( + exception.params, + {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3}, ) - self.assertEqual(exception.code, 'item_invalid') - self.assertEqual(exception.params, {'nth': 1, 'value': 'abc', 'limit_value': 2, 'show_value': 3}) def test_with_validators(self): - field = ArrayField(models.IntegerField(validators=[validators.MinValueValidator(1)])) + field = ArrayField( + models.IntegerField(validators=[validators.MinValueValidator(1)]) + ) field.clean([1, 2], None) with self.assertRaises(exceptions.ValidationError) as cm: field.clean([0], None) @@ -916,90 +956,112 @@ class TestValidation(PostgreSQLSimpleTestCase): exception = cm.exception.error_list[0] self.assertEqual( exception.message, - 'Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.' + "Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.", + ) + self.assertEqual(exception.code, "item_invalid") + self.assertEqual( + exception.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0} ) - self.assertEqual(exception.code, 'item_invalid') - self.assertEqual(exception.params, {'nth': 1, 'value': 0, 'limit_value': 1, 'show_value': 0}) class TestSimpleFormField(PostgreSQLSimpleTestCase): - def test_valid(self): field = SimpleArrayField(forms.CharField()) - value = field.clean('a,b,c') - self.assertEqual(value, ['a', 'b', 'c']) + value = field.clean("a,b,c") + self.assertEqual(value, ["a", "b", "c"]) def test_to_python_fail(self): field = SimpleArrayField(forms.IntegerField()) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('a,b,9') - self.assertEqual(cm.exception.messages[0], 'Item 1 in the array did not validate: Enter a whole number.') + field.clean("a,b,9") + self.assertEqual( + cm.exception.messages[0], + "Item 1 in the array did not validate: Enter a whole number.", + ) def test_validate_fail(self): field = SimpleArrayField(forms.CharField(required=True)) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('a,b,') - self.assertEqual(cm.exception.messages[0], 'Item 3 in the array did not validate: This field is required.') + field.clean("a,b,") + self.assertEqual( + cm.exception.messages[0], + "Item 3 in the array did not validate: This field is required.", + ) def test_validate_fail_base_field_error_params(self): field = SimpleArrayField(forms.CharField(max_length=2)) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('abc,c,defg') + field.clean("abc,c,defg") errors = cm.exception.error_list self.assertEqual(len(errors), 2) first_error = errors[0] self.assertEqual( first_error.message, - 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).' + "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).", + ) + self.assertEqual(first_error.code, "item_invalid") + self.assertEqual( + first_error.params, + {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3}, ) - self.assertEqual(first_error.code, 'item_invalid') - self.assertEqual(first_error.params, {'nth': 1, 'value': 'abc', 'limit_value': 2, 'show_value': 3}) second_error = errors[1] self.assertEqual( second_error.message, - 'Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).' + "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).", + ) + self.assertEqual(second_error.code, "item_invalid") + self.assertEqual( + second_error.params, + {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4}, ) - self.assertEqual(second_error.code, 'item_invalid') - self.assertEqual(second_error.params, {'nth': 3, 'value': 'defg', 'limit_value': 2, 'show_value': 4}) def test_validators_fail(self): - field = SimpleArrayField(forms.RegexField('[a-e]{2}')) + field = SimpleArrayField(forms.RegexField("[a-e]{2}")) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('a,bc,de') - self.assertEqual(cm.exception.messages[0], 'Item 1 in the array did not validate: Enter a valid value.') + field.clean("a,bc,de") + self.assertEqual( + cm.exception.messages[0], + "Item 1 in the array did not validate: Enter a valid value.", + ) def test_delimiter(self): - field = SimpleArrayField(forms.CharField(), delimiter='|') - value = field.clean('a|b|c') - self.assertEqual(value, ['a', 'b', 'c']) + field = SimpleArrayField(forms.CharField(), delimiter="|") + value = field.clean("a|b|c") + self.assertEqual(value, ["a", "b", "c"]) def test_delimiter_with_nesting(self): - field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|') - value = field.clean('a,b|c,d') - self.assertEqual(value, [['a', 'b'], ['c', 'd']]) + field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter="|") + value = field.clean("a,b|c,d") + self.assertEqual(value, [["a", "b"], ["c", "d"]]) def test_prepare_value(self): field = SimpleArrayField(forms.CharField()) - value = field.prepare_value(['a', 'b', 'c']) - self.assertEqual(value, 'a,b,c') + value = field.prepare_value(["a", "b", "c"]) + self.assertEqual(value, "a,b,c") def test_max_length(self): field = SimpleArrayField(forms.CharField(), max_length=2) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('a,b,c') - self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.') + field.clean("a,b,c") + self.assertEqual( + cm.exception.messages[0], + "List contains 3 items, it should contain no more than 2.", + ) def test_min_length(self): field = SimpleArrayField(forms.CharField(), min_length=4) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('a,b,c') - self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.') + field.clean("a,b,c") + self.assertEqual( + cm.exception.messages[0], + "List contains 3 items, it should contain no fewer than 4.", + ) def test_required(self): field = SimpleArrayField(forms.CharField(), required=True) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('') - self.assertEqual(cm.exception.messages[0], 'This field is required.') + field.clean("") + self.assertEqual(cm.exception.messages[0], "This field is required.") def test_model_field_formfield(self): model_field = ArrayField(models.CharField(max_length=27)) @@ -1015,59 +1077,66 @@ class TestSimpleFormField(PostgreSQLSimpleTestCase): self.assertEqual(form_field.max_length, 4) def test_model_field_choices(self): - model_field = ArrayField(models.IntegerField(choices=((1, 'A'), (2, 'B')))) + model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B")))) form_field = model_field.formfield() - self.assertEqual(form_field.clean('1,2'), [1, 2]) + self.assertEqual(form_field.clean("1,2"), [1, 2]) def test_already_converted_value(self): field = SimpleArrayField(forms.CharField()) - vals = ['a', 'b', 'c'] + vals = ["a", "b", "c"] self.assertEqual(field.clean(vals), vals) def test_has_changed(self): field = SimpleArrayField(forms.IntegerField()) self.assertIs(field.has_changed([1, 2], [1, 2]), False) - self.assertIs(field.has_changed([1, 2], '1,2'), False) - self.assertIs(field.has_changed([1, 2], '1,2,3'), True) - self.assertIs(field.has_changed([1, 2], 'a,b'), True) + self.assertIs(field.has_changed([1, 2], "1,2"), False) + self.assertIs(field.has_changed([1, 2], "1,2,3"), True) + self.assertIs(field.has_changed([1, 2], "a,b"), True) def test_has_changed_empty(self): field = SimpleArrayField(forms.CharField()) self.assertIs(field.has_changed(None, None), False) - self.assertIs(field.has_changed(None, ''), False) + self.assertIs(field.has_changed(None, ""), False) self.assertIs(field.has_changed(None, []), False) self.assertIs(field.has_changed([], None), False) - self.assertIs(field.has_changed([], ''), False) + self.assertIs(field.has_changed([], ""), False) class TestSplitFormField(PostgreSQLSimpleTestCase): - def test_valid(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), size=3) - data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'} + data = {"array_0": "a", "array_1": "b", "array_2": "c"} form = SplitForm(data) self.assertTrue(form.is_valid()) - self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']}) + self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]}) def test_required(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), required=True, size=3) - data = {'array_0': '', 'array_1': '', 'array_2': ''} + data = {"array_0": "", "array_1": "", "array_2": ""} form = SplitForm(data) self.assertFalse(form.is_valid()) - self.assertEqual(form.errors, {'array': ['This field is required.']}) + self.assertEqual(form.errors, {"array": ["This field is required."]}) def test_remove_trailing_nulls(self): class SplitForm(forms.Form): - array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True) + array = SplitArrayField( + forms.CharField(required=False), size=5, remove_trailing_nulls=True + ) - data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''} + data = { + "array_0": "a", + "array_1": "", + "array_2": "b", + "array_3": "", + "array_4": "", + } form = SplitForm(data) self.assertTrue(form.is_valid(), form.errors) - self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']}) + self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]}) def test_remove_trailing_nulls_not_required(self): class SplitForm(forms.Form): @@ -1078,32 +1147,41 @@ class TestSplitFormField(PostgreSQLSimpleTestCase): required=False, ) - data = {'array_0': '', 'array_1': ''} + data = {"array_0": "", "array_1": ""} form = SplitForm(data) self.assertTrue(form.is_valid()) - self.assertEqual(form.cleaned_data, {'array': []}) + self.assertEqual(form.cleaned_data, {"array": []}) def test_required_field(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), size=3) - data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''} + data = {"array_0": "a", "array_1": "b", "array_2": ""} form = SplitForm(data) self.assertFalse(form.is_valid()) - self.assertEqual(form.errors, {'array': ['Item 3 in the array did not validate: This field is required.']}) + self.assertEqual( + form.errors, + { + "array": [ + "Item 3 in the array did not validate: This field is required." + ] + }, + ) def test_invalid_integer(self): - msg = 'Item 2 in the array did not validate: Ensure this value is less than or equal to 100.' + msg = "Item 2 in the array did not validate: Ensure this value is less than or equal to 100." with self.assertRaisesMessage(exceptions.ValidationError, msg): SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101]) # To locate the widget's template. - @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) + @modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) def test_rendering(self): class SplitForm(forms.Form): array = SplitArrayField(forms.CharField(), size=3) - self.assertHTMLEqual(str(SplitForm()), ''' + self.assertHTMLEqual( + str(SplitForm()), + """ <tr> <th><label for="id_array_0">Array:</label></th> <td> @@ -1112,16 +1190,20 @@ class TestSplitFormField(PostgreSQLSimpleTestCase): <input id="id_array_2" name="array_2" type="text" required> </td> </tr> - ''') + """, + ) def test_invalid_char_length(self): field = SplitArrayField(forms.CharField(max_length=2), size=3) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['abc', 'c', 'defg']) - self.assertEqual(cm.exception.messages, [ - 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).', - 'Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).', - ]) + field.clean(["abc", "c", "defg"]) + self.assertEqual( + cm.exception.messages, + [ + "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).", + "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).", + ], + ) def test_splitarraywidget_value_omitted_from_data(self): class Form(forms.ModelForm): @@ -1129,9 +1211,9 @@ class TestSplitFormField(PostgreSQLSimpleTestCase): class Meta: model = IntegerArrayModel - fields = ('field',) + fields = ("field",) - form = Form({'field_0': '1', 'field_1': '2'}) + form = Form({"field_0": "1", "field_1": "2"}) self.assertEqual(form.errors, {}) obj = form.save(commit=False) self.assertEqual(obj.field, [1, 2]) @@ -1142,15 +1224,15 @@ class TestSplitFormField(PostgreSQLSimpleTestCase): class Meta: model = IntegerArrayModel - fields = ('field',) + fields = ("field",) tests = [ - ({}, {'field_0': '', 'field_1': ''}, True), - ({'field': None}, {'field_0': '', 'field_1': ''}, True), - ({'field': [1]}, {'field_0': '', 'field_1': ''}, True), - ({'field': [1]}, {'field_0': '1', 'field_1': '0'}, True), - ({'field': [1, 2]}, {'field_0': '1', 'field_1': '2'}, False), - ({'field': [1, 2]}, {'field_0': 'a', 'field_1': 'b'}, True), + ({}, {"field_0": "", "field_1": ""}, True), + ({"field": None}, {"field_0": "", "field_1": ""}, True), + ({"field": [1]}, {"field_0": "", "field_1": ""}, True), + ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True), + ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False), + ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True), ] for initial, data, expected_result in tests: with self.subTest(initial=initial, data=data): @@ -1160,17 +1242,19 @@ class TestSplitFormField(PostgreSQLSimpleTestCase): def test_splitarrayfield_remove_trailing_nulls_has_changed(self): class Form(forms.ModelForm): - field = SplitArrayField(forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True) + field = SplitArrayField( + forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True + ) class Meta: model = IntegerArrayModel - fields = ('field',) + fields = ("field",) tests = [ - ({}, {'field_0': '', 'field_1': ''}, False), - ({'field': None}, {'field_0': '', 'field_1': ''}, False), - ({'field': []}, {'field_0': '', 'field_1': ''}, False), - ({'field': [1]}, {'field_0': '1', 'field_1': ''}, False), + ({}, {"field_0": "", "field_1": ""}, False), + ({"field": None}, {"field_0": "", "field_1": ""}, False), + ({"field": []}, {"field_0": "", "field_1": ""}, False), + ({"field": [1]}, {"field_0": "1", "field_1": ""}, False), ] for initial, data, expected_result in tests: with self.subTest(initial=initial, data=data): @@ -1180,77 +1264,91 @@ class TestSplitFormField(PostgreSQLSimpleTestCase): class TestSplitFormWidget(PostgreSQLWidgetTestCase): - def test_get_context(self): self.assertEqual( - SplitArrayWidget(forms.TextInput(), size=2).get_context('name', ['val1', 'val2']), + SplitArrayWidget(forms.TextInput(), size=2).get_context( + "name", ["val1", "val2"] + ), { - 'widget': { - 'name': 'name', - 'is_hidden': False, - 'required': False, - 'value': "['val1', 'val2']", - 'attrs': {}, - 'template_name': 'postgres/widgets/split_array.html', - 'subwidgets': [ + "widget": { + "name": "name", + "is_hidden": False, + "required": False, + "value": "['val1', 'val2']", + "attrs": {}, + "template_name": "postgres/widgets/split_array.html", + "subwidgets": [ { - 'name': 'name_0', - 'is_hidden': False, - 'required': False, - 'value': 'val1', - 'attrs': {}, - 'template_name': 'django/forms/widgets/text.html', - 'type': 'text', + "name": "name_0", + "is_hidden": False, + "required": False, + "value": "val1", + "attrs": {}, + "template_name": "django/forms/widgets/text.html", + "type": "text", }, { - 'name': 'name_1', - 'is_hidden': False, - 'required': False, - 'value': 'val2', - 'attrs': {}, - 'template_name': 'django/forms/widgets/text.html', - 'type': 'text', + "name": "name_1", + "is_hidden": False, + "required": False, + "value": "val2", + "attrs": {}, + "template_name": "django/forms/widgets/text.html", + "type": "text", }, - ] + ], } - } + }, ) def test_checkbox_get_context_attrs(self): context = SplitArrayWidget( forms.CheckboxInput(), size=2, - ).get_context('name', [True, False]) - self.assertEqual(context['widget']['value'], '[True, False]') + ).get_context("name", [True, False]) + self.assertEqual(context["widget"]["value"], "[True, False]") self.assertEqual( - [subwidget['attrs'] for subwidget in context['widget']['subwidgets']], - [{'checked': True}, {}] + [subwidget["attrs"] for subwidget in context["widget"]["subwidgets"]], + [{"checked": True}, {}], ) def test_render(self): self.check_html( - SplitArrayWidget(forms.TextInput(), size=2), 'array', None, + SplitArrayWidget(forms.TextInput(), size=2), + "array", + None, """ <input name="array_0" type="text"> <input name="array_1" type="text"> - """ + """, ) def test_render_attrs(self): self.check_html( SplitArrayWidget(forms.TextInput(), size=2), - 'array', ['val1', 'val2'], attrs={'id': 'foo'}, + "array", + ["val1", "val2"], + attrs={"id": "foo"}, html=( """ <input id="foo_0" name="array_0" type="text" value="val1"> <input id="foo_1" name="array_1" type="text" value="val2"> """ - ) + ), ) def test_value_omitted_from_data(self): widget = SplitArrayWidget(forms.TextInput(), size=2) - self.assertIs(widget.value_omitted_from_data({}, {}, 'field'), True) - self.assertIs(widget.value_omitted_from_data({'field_0': 'value'}, {}, 'field'), False) - self.assertIs(widget.value_omitted_from_data({'field_1': 'value'}, {}, 'field'), False) - self.assertIs(widget.value_omitted_from_data({'field_0': 'value', 'field_1': 'value'}, {}, 'field'), False) + self.assertIs(widget.value_omitted_from_data({}, {}, "field"), True) + self.assertIs( + widget.value_omitted_from_data({"field_0": "value"}, {}, "field"), False + ) + self.assertIs( + widget.value_omitted_from_data({"field_1": "value"}, {}, "field"), False + ) + self.assertIs( + widget.value_omitted_from_data( + {"field_0": "value", "field_1": "value"}, {}, "field" + ), + False, + ) diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py index da5aee0f70..f0b473efa7 100644 --- a/tests/postgres_tests/test_bulk_update.py +++ b/tests/postgres_tests/test_bulk_update.py @@ -2,8 +2,12 @@ from datetime import date from . import PostgreSQLTestCase from .models import ( - HStoreModel, IntegerArrayModel, NestedIntegerArrayModel, - NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel, + HStoreModel, + IntegerArrayModel, + NestedIntegerArrayModel, + NullableIntegerArrayModel, + OtherTypesArrayModel, + RangesModel, ) try: @@ -15,19 +19,28 @@ except ImportError: class BulkSaveTests(PostgreSQLTestCase): def test_bulk_update(self): test_data = [ - (IntegerArrayModel, 'field', [], [1, 2, 3]), - (NullableIntegerArrayModel, 'field', [1, 2, 3], None), - (NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]), - (HStoreModel, 'field', {}, {1: 2}), - (RangesModel, 'ints', None, NumericRange(lower=1, upper=10)), - (RangesModel, 'dates', None, DateRange(lower=date.today(), upper=date.today())), - (OtherTypesArrayModel, 'ips', [], ['1.2.3.4']), - (OtherTypesArrayModel, 'json', [], [{'a': 'b'}]) + (IntegerArrayModel, "field", [], [1, 2, 3]), + (NullableIntegerArrayModel, "field", [1, 2, 3], None), + (NestedIntegerArrayModel, "field", [], [[1, 2, 3]]), + (HStoreModel, "field", {}, {1: 2}), + (RangesModel, "ints", None, NumericRange(lower=1, upper=10)), + ( + RangesModel, + "dates", + None, + DateRange(lower=date.today(), upper=date.today()), + ), + (OtherTypesArrayModel, "ips", [], ["1.2.3.4"]), + (OtherTypesArrayModel, "json", [], [{"a": "b"}]), ] for Model, field, initial, new in test_data: with self.subTest(model=Model, field=field): - instances = Model.objects.bulk_create(Model(**{field: initial}) for _ in range(20)) + instances = Model.objects.bulk_create( + Model(**{field: initial}) for _ in range(20) + ) for instance in instances: setattr(instance, field, new) Model.objects.bulk_update(instances, [field]) - self.assertSequenceEqual(Model.objects.filter(**{field: new}), instances) + self.assertSequenceEqual( + Model.objects.filter(**{field: new}), instances + ) diff --git a/tests/postgres_tests/test_citext.py b/tests/postgres_tests/test_citext.py index 350f34ba33..f1c13184b6 100644 --- a/tests/postgres_tests/test_citext.py +++ b/tests/postgres_tests/test_citext.py @@ -10,26 +10,35 @@ from . import PostgreSQLTestCase from .models import CITestModel -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class CITextTestCase(PostgreSQLTestCase): - case_sensitive_lookups = ('contains', 'startswith', 'endswith', 'regex') + case_sensitive_lookups = ("contains", "startswith", "endswith", "regex") @classmethod def setUpTestData(cls): cls.john = CITestModel.objects.create( - name='JoHn', - email='joHn@johN.com', - description='Average Joe named JoHn', - array_field=['JoE', 'jOhn'], + name="JoHn", + email="joHn@johN.com", + description="Average Joe named JoHn", + array_field=["JoE", "jOhn"], ) def test_equal_lowercase(self): """ citext removes the need for iexact as the index is case-insensitive. """ - self.assertEqual(CITestModel.objects.filter(name=self.john.name.lower()).count(), 1) - self.assertEqual(CITestModel.objects.filter(email=self.john.email.lower()).count(), 1) - self.assertEqual(CITestModel.objects.filter(description=self.john.description.lower()).count(), 1) + self.assertEqual( + CITestModel.objects.filter(name=self.john.name.lower()).count(), 1 + ) + self.assertEqual( + CITestModel.objects.filter(email=self.john.email.lower()).count(), 1 + ) + self.assertEqual( + CITestModel.objects.filter( + description=self.john.description.lower() + ).count(), + 1, + ) def test_fail_citext_primary_key(self): """ @@ -37,27 +46,39 @@ class CITextTestCase(PostgreSQLTestCase): clashes with an existing value isn't allowed. """ with self.assertRaises(IntegrityError): - CITestModel.objects.create(name='John') + CITestModel.objects.create(name="John") def test_array_field(self): instance = CITestModel.objects.get() self.assertEqual(instance.array_field, self.john.array_field) - self.assertTrue(CITestModel.objects.filter(array_field__contains=['joe']).exists()) + self.assertTrue( + CITestModel.objects.filter(array_field__contains=["joe"]).exists() + ) def test_lookups_name_char(self): for lookup in self.case_sensitive_lookups: with self.subTest(lookup=lookup): - query = {'name__{}'.format(lookup): 'john'} - self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john]) + query = {"name__{}".format(lookup): "john"} + self.assertSequenceEqual( + CITestModel.objects.filter(**query), [self.john] + ) def test_lookups_description_text(self): - for lookup, string in zip(self.case_sensitive_lookups, ('average', 'average joe', 'john', 'Joe.named')): + for lookup, string in zip( + self.case_sensitive_lookups, ("average", "average joe", "john", "Joe.named") + ): with self.subTest(lookup=lookup, string=string): - query = {'description__{}'.format(lookup): string} - self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john]) + query = {"description__{}".format(lookup): string} + self.assertSequenceEqual( + CITestModel.objects.filter(**query), [self.john] + ) def test_lookups_email(self): - for lookup, string in zip(self.case_sensitive_lookups, ('john', 'john', 'john.com', 'john.com')): + for lookup, string in zip( + self.case_sensitive_lookups, ("john", "john", "john.com", "john.com") + ): with self.subTest(lookup=lookup, string=string): - query = {'email__{}'.format(lookup): string} - self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john]) + query = {"email__{}".format(lookup): string} + self.assertSequenceEqual( + CITestModel.objects.filter(**query), [self.john] + ) diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index 22506ed62d..14b45f9b7f 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -2,11 +2,15 @@ import datetime from unittest import mock from django.contrib.postgres.indexes import OpClass -from django.db import ( - IntegrityError, NotSupportedError, connection, transaction, -) +from django.db import IntegrityError, NotSupportedError, connection, transaction from django.db.models import ( - CheckConstraint, Deferrable, F, Func, IntegerField, Q, UniqueConstraint, + CheckConstraint, + Deferrable, + F, + Func, + IntegerField, + Q, + UniqueConstraint, ) from django.db.models.fields.json import KeyTextTransform from django.db.models.functions import Cast, Left, Lower @@ -15,29 +19,29 @@ from django.utils import timezone from django.utils.deprecation import RemovedInDjango50Warning from . import PostgreSQLTestCase -from .models import ( - HotelReservation, IntegerArrayModel, RangesModel, Room, Scene, -) +from .models import HotelReservation, IntegerArrayModel, RangesModel, Room, Scene try: from psycopg2.extras import DateRange, NumericRange from django.contrib.postgres.constraints import ExclusionConstraint from django.contrib.postgres.fields import ( - DateTimeRangeField, RangeBoundary, RangeOperators, + DateTimeRangeField, + RangeBoundary, + RangeOperators, ) except ImportError: pass -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class SchemaTests(PostgreSQLTestCase): - get_opclass_query = ''' + get_opclass_query = """ SELECT opcname, c.relname FROM pg_opclass AS oc JOIN pg_index as i on oc.oid = ANY(i.indclass) JOIN pg_class as c on c.oid = i.indexrelid WHERE c.relname = %s - ''' + """ def get_constraints(self, table): """Get the constraints on the table using a new cursor.""" @@ -45,8 +49,10 @@ class SchemaTests(PostgreSQLTestCase): return connection.introspection.get_constraints(cursor, table) def test_check_constraint_range_value(self): - constraint_name = 'ints_between' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_between" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = CheckConstraint( check=Q(ints__contained_by=NumericRange(10, 30)), name=constraint_name, @@ -59,10 +65,12 @@ class SchemaTests(PostgreSQLTestCase): RangesModel.objects.create(ints=(10, 30)) def test_check_constraint_daterange_contains(self): - constraint_name = 'dates_contains' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "dates_contains" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = CheckConstraint( - check=Q(dates__contains=F('dates_inner')), + check=Q(dates__contains=F("dates_inner")), name=constraint_name, ) with connection.schema_editor() as editor: @@ -81,10 +89,12 @@ class SchemaTests(PostgreSQLTestCase): ) def test_check_constraint_datetimerange_contains(self): - constraint_name = 'timestamps_contains' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "timestamps_contains" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = CheckConstraint( - check=Q(timestamps__contains=F('timestamps_inner')), + check=Q(timestamps__contains=F("timestamps_inner")), name=constraint_name, ) with connection.schema_editor() as editor: @@ -104,9 +114,9 @@ class SchemaTests(PostgreSQLTestCase): def test_opclass(self): constraint = UniqueConstraint( - name='test_opclass', - fields=['scene'], - opclasses=['varchar_pattern_ops'], + name="test_opclass", + fields=["scene"], + opclasses=["varchar_pattern_ops"], ) with connection.schema_editor() as editor: editor.add_constraint(Scene, constraint) @@ -115,7 +125,7 @@ class SchemaTests(PostgreSQLTestCase): cursor.execute(self.get_opclass_query, [constraint.name]) self.assertEqual( cursor.fetchall(), - [('varchar_pattern_ops', constraint.name)], + [("varchar_pattern_ops", constraint.name)], ) # Drop the constraint. with connection.schema_editor() as editor: @@ -124,25 +134,25 @@ class SchemaTests(PostgreSQLTestCase): def test_opclass_multiple_columns(self): constraint = UniqueConstraint( - name='test_opclass_multiple', - fields=['scene', 'setting'], - opclasses=['varchar_pattern_ops', 'text_pattern_ops'], + name="test_opclass_multiple", + fields=["scene", "setting"], + opclasses=["varchar_pattern_ops", "text_pattern_ops"], ) with connection.schema_editor() as editor: editor.add_constraint(Scene, constraint) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [constraint.name]) expected_opclasses = ( - ('varchar_pattern_ops', constraint.name), - ('text_pattern_ops', constraint.name), + ("varchar_pattern_ops", constraint.name), + ("text_pattern_ops", constraint.name), ) self.assertCountEqual(cursor.fetchall(), expected_opclasses) def test_opclass_partial(self): constraint = UniqueConstraint( - name='test_opclass_partial', - fields=['scene'], - opclasses=['varchar_pattern_ops'], + name="test_opclass_partial", + fields=["scene"], + opclasses=["varchar_pattern_ops"], condition=Q(setting__contains="Sir Bedemir's Castle"), ) with connection.schema_editor() as editor: @@ -151,16 +161,16 @@ class SchemaTests(PostgreSQLTestCase): cursor.execute(self.get_opclass_query, [constraint.name]) self.assertCountEqual( cursor.fetchall(), - [('varchar_pattern_ops', constraint.name)], + [("varchar_pattern_ops", constraint.name)], ) - @skipUnlessDBFeature('supports_covering_indexes') + @skipUnlessDBFeature("supports_covering_indexes") def test_opclass_include(self): constraint = UniqueConstraint( - name='test_opclass_include', - fields=['scene'], - opclasses=['varchar_pattern_ops'], - include=['setting'], + name="test_opclass_include", + fields=["scene"], + opclasses=["varchar_pattern_ops"], + include=["setting"], ) with connection.schema_editor() as editor: editor.add_constraint(Scene, constraint) @@ -168,38 +178,38 @@ class SchemaTests(PostgreSQLTestCase): cursor.execute(self.get_opclass_query, [constraint.name]) self.assertCountEqual( cursor.fetchall(), - [('varchar_pattern_ops', constraint.name)], + [("varchar_pattern_ops", constraint.name)], ) - @skipUnlessDBFeature('supports_expression_indexes') + @skipUnlessDBFeature("supports_expression_indexes") def test_opclass_func(self): constraint = UniqueConstraint( - OpClass(Lower('scene'), name='text_pattern_ops'), - name='test_opclass_func', + OpClass(Lower("scene"), name="text_pattern_ops"), + name="test_opclass_func", ) with connection.schema_editor() as editor: editor.add_constraint(Scene, constraint) constraints = self.get_constraints(Scene._meta.db_table) - self.assertIs(constraints[constraint.name]['unique'], True) + self.assertIs(constraints[constraint.name]["unique"], True) self.assertIn(constraint.name, constraints) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [constraint.name]) self.assertEqual( cursor.fetchall(), - [('text_pattern_ops', constraint.name)], + [("text_pattern_ops", constraint.name)], ) - Scene.objects.create(scene='Scene 10', setting='The dark forest of Ewing') + Scene.objects.create(scene="Scene 10", setting="The dark forest of Ewing") with self.assertRaises(IntegrityError), transaction.atomic(): - Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle") - Scene.objects.create(scene='Scene 5', setting="Sir Bedemir's Castle") + Scene.objects.create(scene="ScEnE 10", setting="Sir Bedemir's Castle") + Scene.objects.create(scene="Scene 5", setting="Sir Bedemir's Castle") # Drop the constraint. with connection.schema_editor() as editor: editor.remove_constraint(Scene, constraint) self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table)) - Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle") + Scene.objects.create(scene="ScEnE 10", setting="Sir Bedemir's Castle") -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class ExclusionConstraintTests(PostgreSQLTestCase): def get_constraints(self, table): """Get the constraints on the table using a new cursor.""" @@ -207,102 +217,104 @@ class ExclusionConstraintTests(PostgreSQLTestCase): return connection.introspection.get_constraints(cursor, table) def test_invalid_condition(self): - msg = 'ExclusionConstraint.condition must be a Q instance.' + msg = "ExclusionConstraint.condition must be a Q instance." with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - index_type='GIST', - name='exclude_invalid_condition', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], - condition=F('invalid'), + index_type="GIST", + name="exclude_invalid_condition", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], + condition=F("invalid"), ) def test_invalid_index_type(self): - msg = 'Exclusion constraints only support GiST or SP-GiST indexes.' + msg = "Exclusion constraints only support GiST or SP-GiST indexes." with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - index_type='gin', - name='exclude_invalid_index_type', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], + index_type="gin", + name="exclude_invalid_index_type", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], ) def test_invalid_expressions(self): - msg = 'The expressions must be a list of 2-tuples.' - for expressions in (['foo'], [('foo')], [('foo_1', 'foo_2', 'foo_3')]): + msg = "The expressions must be a list of 2-tuples." + for expressions in (["foo"], [("foo")], [("foo_1", "foo_2", "foo_3")]): with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - index_type='GIST', - name='exclude_invalid_expressions', + index_type="GIST", + name="exclude_invalid_expressions", expressions=expressions, ) def test_empty_expressions(self): - msg = 'At least one expression is required to define an exclusion constraint.' + msg = "At least one expression is required to define an exclusion constraint." for empty_expressions in (None, []): - with self.subTest(empty_expressions), self.assertRaisesMessage(ValueError, msg): + with self.subTest(empty_expressions), self.assertRaisesMessage( + ValueError, msg + ): ExclusionConstraint( - index_type='GIST', - name='exclude_empty_expressions', + index_type="GIST", + name="exclude_empty_expressions", expressions=empty_expressions, ) def test_invalid_deferrable(self): - msg = 'ExclusionConstraint.deferrable must be a Deferrable instance.' + msg = "ExclusionConstraint.deferrable must be a Deferrable instance." with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - name='exclude_invalid_deferrable', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], - deferrable='invalid', + name="exclude_invalid_deferrable", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], + deferrable="invalid", ) def test_deferrable_with_condition(self): - msg = 'ExclusionConstraint with conditions cannot be deferred.' + msg = "ExclusionConstraint with conditions cannot be deferred." with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - name='exclude_invalid_condition', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], + name="exclude_invalid_condition", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], condition=Q(cancelled=False), deferrable=Deferrable.DEFERRED, ) def test_invalid_include_type(self): - msg = 'ExclusionConstraint.include must be a list or tuple.' + msg = "ExclusionConstraint.include must be a list or tuple." with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - name='exclude_invalid_include', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], - include='invalid', + name="exclude_invalid_include", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], + include="invalid", ) @ignore_warnings(category=RemovedInDjango50Warning) def test_invalid_opclasses_type(self): - msg = 'ExclusionConstraint.opclasses must be a list or tuple.' + msg = "ExclusionConstraint.opclasses must be a list or tuple." with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - name='exclude_invalid_opclasses', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], - opclasses='invalid', + name="exclude_invalid_opclasses", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], + opclasses="invalid", ) @ignore_warnings(category=RemovedInDjango50Warning) def test_opclasses_and_expressions_same_length(self): msg = ( - 'ExclusionConstraint.expressions and ' - 'ExclusionConstraint.opclasses must have the same number of ' - 'elements.' + "ExclusionConstraint.expressions and " + "ExclusionConstraint.opclasses must have the same number of " + "elements." ) with self.assertRaisesMessage(ValueError, msg): ExclusionConstraint( - name='exclude_invalid_expressions_opclasses_length', - expressions=[(F('datespan'), RangeOperators.OVERLAPS)], - opclasses=['foo', 'bar'], + name="exclude_invalid_expressions_opclasses_length", + expressions=[(F("datespan"), RangeOperators.OVERLAPS)], + opclasses=["foo", "bar"], ) def test_repr(self): constraint = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - (F('datespan'), RangeOperators.OVERLAPS), - (F('room'), RangeOperators.EQUAL), + (F("datespan"), RangeOperators.OVERLAPS), + (F("room"), RangeOperators.EQUAL), ], ) self.assertEqual( @@ -311,10 +323,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase): "(F(datespan), '&&'), (F(room), '=')] name='exclude_overlapping'>", ) constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)], + name="exclude_overlapping", + expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)], condition=Q(cancelled=False), - index_type='SPGiST', + index_type="SPGiST", ) self.assertEqual( repr(constraint), @@ -323,8 +335,8 @@ class ExclusionConstraintTests(PostgreSQLTestCase): "condition=(AND: ('cancelled', False))>", ) constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)], + name="exclude_overlapping", + expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)], deferrable=Deferrable.IMMEDIATE, ) self.assertEqual( @@ -334,9 +346,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase): "deferrable=Deferrable.IMMEDIATE>", ) constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)], - include=['cancelled', 'room'], + name="exclude_overlapping", + expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)], + include=["cancelled", "room"], ) self.assertEqual( repr(constraint), @@ -345,9 +357,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase): "include=('cancelled', 'room')>", ) constraint = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - (OpClass('datespan', name='range_ops'), RangeOperators.ADJACENT_TO), + (OpClass("datespan", name="range_ops"), RangeOperators.ADJACENT_TO), ], ) self.assertEqual( @@ -359,75 +371,75 @@ class ExclusionConstraintTests(PostgreSQLTestCase): def test_eq(self): constraint_1 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - (F('datespan'), RangeOperators.OVERLAPS), - (F('room'), RangeOperators.EQUAL), + (F("datespan"), RangeOperators.OVERLAPS), + (F("room"), RangeOperators.EQUAL), ], condition=Q(cancelled=False), ) constraint_2 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], ) constraint_3 = ExclusionConstraint( - name='exclude_overlapping', - expressions=[('datespan', RangeOperators.OVERLAPS)], + name="exclude_overlapping", + expressions=[("datespan", RangeOperators.OVERLAPS)], condition=Q(cancelled=False), ) constraint_4 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], deferrable=Deferrable.DEFERRED, ) constraint_5 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], deferrable=Deferrable.IMMEDIATE, ) constraint_6 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], deferrable=Deferrable.IMMEDIATE, - include=['cancelled'], + include=["cancelled"], ) constraint_7 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], - include=['cancelled'], + include=["cancelled"], ) with ignore_warnings(category=RemovedInDjango50Warning): constraint_8 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], - include=['cancelled'], - opclasses=['range_ops', 'range_ops'] + include=["cancelled"], + opclasses=["range_ops", "range_ops"], ) constraint_9 = ExclusionConstraint( - name='exclude_overlapping', + name="exclude_overlapping", expressions=[ - ('datespan', RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL), + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], - opclasses=['range_ops', 'range_ops'] + opclasses=["range_ops", "range_ops"], ) self.assertNotEqual(constraint_2, constraint_9) self.assertNotEqual(constraint_7, constraint_8) @@ -445,99 +457,151 @@ class ExclusionConstraintTests(PostgreSQLTestCase): def test_deconstruct(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + name="exclude_overlapping", + expressions=[ + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), + ], ) path, args, kwargs = constraint.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual( + path, "django.contrib.postgres.constraints.ExclusionConstraint" + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'name': 'exclude_overlapping', - 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], - }) + self.assertEqual( + kwargs, + { + "name": "exclude_overlapping", + "expressions": [ + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), + ], + }, + ) def test_deconstruct_index_type(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - index_type='SPGIST', - expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + name="exclude_overlapping", + index_type="SPGIST", + expressions=[ + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), + ], ) path, args, kwargs = constraint.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual( + path, "django.contrib.postgres.constraints.ExclusionConstraint" + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'name': 'exclude_overlapping', - 'index_type': 'SPGIST', - 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], - }) + self.assertEqual( + kwargs, + { + "name": "exclude_overlapping", + "index_type": "SPGIST", + "expressions": [ + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), + ], + }, + ) def test_deconstruct_condition(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + name="exclude_overlapping", + expressions=[ + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), + ], condition=Q(cancelled=False), ) path, args, kwargs = constraint.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual( + path, "django.contrib.postgres.constraints.ExclusionConstraint" + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'name': 'exclude_overlapping', - 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], - 'condition': Q(cancelled=False), - }) + self.assertEqual( + kwargs, + { + "name": "exclude_overlapping", + "expressions": [ + ("datespan", RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), + ], + "condition": Q(cancelled=False), + }, + ) def test_deconstruct_deferrable(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[('datespan', RangeOperators.OVERLAPS)], + name="exclude_overlapping", + expressions=[("datespan", RangeOperators.OVERLAPS)], deferrable=Deferrable.DEFERRED, ) path, args, kwargs = constraint.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual( + path, "django.contrib.postgres.constraints.ExclusionConstraint" + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'name': 'exclude_overlapping', - 'expressions': [('datespan', RangeOperators.OVERLAPS)], - 'deferrable': Deferrable.DEFERRED, - }) + self.assertEqual( + kwargs, + { + "name": "exclude_overlapping", + "expressions": [("datespan", RangeOperators.OVERLAPS)], + "deferrable": Deferrable.DEFERRED, + }, + ) def test_deconstruct_include(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[('datespan', RangeOperators.OVERLAPS)], - include=['cancelled', 'room'], + name="exclude_overlapping", + expressions=[("datespan", RangeOperators.OVERLAPS)], + include=["cancelled", "room"], ) path, args, kwargs = constraint.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual( + path, "django.contrib.postgres.constraints.ExclusionConstraint" + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'name': 'exclude_overlapping', - 'expressions': [('datespan', RangeOperators.OVERLAPS)], - 'include': ('cancelled', 'room'), - }) + self.assertEqual( + kwargs, + { + "name": "exclude_overlapping", + "expressions": [("datespan", RangeOperators.OVERLAPS)], + "include": ("cancelled", "room"), + }, + ) @ignore_warnings(category=RemovedInDjango50Warning) def test_deconstruct_opclasses(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[('datespan', RangeOperators.OVERLAPS)], - opclasses=['range_ops'], + name="exclude_overlapping", + expressions=[("datespan", RangeOperators.OVERLAPS)], + opclasses=["range_ops"], ) path, args, kwargs = constraint.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual( + path, "django.contrib.postgres.constraints.ExclusionConstraint" + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'name': 'exclude_overlapping', - 'expressions': [('datespan', RangeOperators.OVERLAPS)], - 'opclasses': ['range_ops'], - }) + self.assertEqual( + kwargs, + { + "name": "exclude_overlapping", + "expressions": [("datespan", RangeOperators.OVERLAPS)], + "opclasses": ["range_ops"], + }, + ) def _test_range_overlaps(self, constraint): # Create exclusion constraint. - self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) + self.assertNotIn( + constraint.name, self.get_constraints(HotelReservation._meta.db_table) + ) with connection.schema_editor() as editor: editor.add_constraint(HotelReservation, constraint) - self.assertIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) + self.assertIn( + constraint.name, self.get_constraints(HotelReservation._meta.db_table) + ) # Add initial reservations. room101 = Room.objects.create(number=101) room102 = Room.objects.create(number=102) @@ -570,61 +634,63 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ) reservation.save() # Valid range. - HotelReservation.objects.bulk_create([ - # Other room. - HotelReservation( - datespan=(datetimes[1].date(), datetimes[2].date()), - start=datetimes[1], - end=datetimes[2], - room=room101, - ), - # Cancelled reservation. - HotelReservation( - datespan=(datetimes[1].date(), datetimes[1].date()), - start=datetimes[1], - end=datetimes[2], - room=room102, - cancelled=True, - ), - # Other adjacent dates. - HotelReservation( - datespan=(datetimes[3].date(), datetimes[4].date()), - start=datetimes[3], - end=datetimes[4], - room=room102, - ), - ]) + HotelReservation.objects.bulk_create( + [ + # Other room. + HotelReservation( + datespan=(datetimes[1].date(), datetimes[2].date()), + start=datetimes[1], + end=datetimes[2], + room=room101, + ), + # Cancelled reservation. + HotelReservation( + datespan=(datetimes[1].date(), datetimes[1].date()), + start=datetimes[1], + end=datetimes[2], + room=room102, + cancelled=True, + ), + # Other adjacent dates. + HotelReservation( + datespan=(datetimes[3].date(), datetimes[4].date()), + start=datetimes[3], + end=datetimes[4], + room=room102, + ), + ] + ) @ignore_warnings(category=RemovedInDjango50Warning) def test_range_overlaps_custom_opclasses(self): class TsTzRange(Func): - function = 'TSTZRANGE' + function = "TSTZRANGE" output_field = DateTimeRangeField() constraint = ExclusionConstraint( - name='exclude_overlapping_reservations_custom', + name="exclude_overlapping_reservations_custom", expressions=[ - (TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL) + (TsTzRange("start", "end", RangeBoundary()), RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], condition=Q(cancelled=False), - opclasses=['range_ops', 'gist_int4_ops'], + opclasses=["range_ops", "gist_int4_ops"], ) self._test_range_overlaps(constraint) def test_range_overlaps_custom(self): class TsTzRange(Func): - function = 'TSTZRANGE' + function = "TSTZRANGE" output_field = DateTimeRangeField() constraint = ExclusionConstraint( - name='exclude_overlapping_reservations_custom_opclass', + name="exclude_overlapping_reservations_custom_opclass", expressions=[ ( - OpClass(TsTzRange('start', 'end', RangeBoundary()), 'range_ops'), + OpClass(TsTzRange("start", "end", RangeBoundary()), "range_ops"), RangeOperators.OVERLAPS, ), - (OpClass('room', 'gist_int4_ops'), RangeOperators.EQUAL), + (OpClass("room", "gist_int4_ops"), RangeOperators.EQUAL), ], condition=Q(cancelled=False), ) @@ -632,21 +698,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase): def test_range_overlaps(self): constraint = ExclusionConstraint( - name='exclude_overlapping_reservations', + name="exclude_overlapping_reservations", expressions=[ - (F('datespan'), RangeOperators.OVERLAPS), - ('room', RangeOperators.EQUAL) + (F("datespan"), RangeOperators.OVERLAPS), + ("room", RangeOperators.EQUAL), ], condition=Q(cancelled=False), ) self._test_range_overlaps(constraint) def test_range_adjacent(self): - constraint_name = 'ints_adjacent' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], + expressions=[("ints", RangeOperators.ADJACENT_TO)], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) @@ -659,26 +727,28 @@ class ExclusionConstraintTests(PostgreSQLTestCase): # Drop the constraint. with connection.schema_editor() as editor: editor.remove_constraint(RangesModel, constraint) - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) def test_expressions_with_params(self): - constraint_name = 'scene_left_equal' + constraint_name = "scene_left_equal" self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table)) constraint = ExclusionConstraint( name=constraint_name, - expressions=[(Left('scene', 4), RangeOperators.EQUAL)], + expressions=[(Left("scene", 4), RangeOperators.EQUAL)], ) with connection.schema_editor() as editor: editor.add_constraint(Scene, constraint) self.assertIn(constraint_name, self.get_constraints(Scene._meta.db_table)) def test_expressions_with_key_transform(self): - constraint_name = 'exclude_overlapping_reservations_smoking' + constraint_name = "exclude_overlapping_reservations_smoking" constraint = ExclusionConstraint( name=constraint_name, expressions=[ - (F('datespan'), RangeOperators.OVERLAPS), - (KeyTextTransform('smoking', 'requirements'), RangeOperators.EQUAL), + (F("datespan"), RangeOperators.OVERLAPS), + (KeyTextTransform("smoking", "requirements"), RangeOperators.EQUAL), ], ) with connection.schema_editor() as editor: @@ -689,10 +759,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ) def test_index_transform(self): - constraint_name = 'first_index_equal' + constraint_name = "first_index_equal" constraint = ExclusionConstraint( name=constraint_name, - expressions=[('field__0', RangeOperators.EQUAL)], + expressions=[("field__0", RangeOperators.EQUAL)], ) with connection.schema_editor() as editor: editor.add_constraint(IntegerArrayModel, constraint) @@ -702,11 +772,13 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ) def test_range_adjacent_initially_deferred(self): - constraint_name = 'ints_adjacent_deferred' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_deferred" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], + expressions=[("ints", RangeOperators.ADJACENT_TO)], deferrable=Deferrable.DEFERRED, ) with connection.schema_editor() as editor: @@ -718,21 +790,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase): with self.assertRaises(IntegrityError): with transaction.atomic(), connection.cursor() as cursor: quoted_name = connection.ops.quote_name(constraint_name) - cursor.execute('SET CONSTRAINTS %s IMMEDIATE' % quoted_name) + cursor.execute("SET CONSTRAINTS %s IMMEDIATE" % quoted_name) # Remove adjacent range before the end of transaction. adjacent_range.delete() RangesModel.objects.create(ints=(10, 19)) RangesModel.objects.create(ints=(51, 60)) - @skipUnlessDBFeature('supports_covering_gist_indexes') + @skipUnlessDBFeature("supports_covering_gist_indexes") def test_range_adjacent_gist_include(self): - constraint_name = 'ints_adjacent_gist_include' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_gist_include" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='gist', - include=['decimals', 'ints'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="gist", + include=["decimals", "ints"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) @@ -743,15 +817,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase): RangesModel.objects.create(ints=(10, 19)) RangesModel.objects.create(ints=(51, 60)) - @skipUnlessDBFeature('supports_covering_spgist_indexes') + @skipUnlessDBFeature("supports_covering_spgist_indexes") def test_range_adjacent_spgist_include(self): - constraint_name = 'ints_adjacent_spgist_include' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_spgist_include" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='spgist', - include=['decimals', 'ints'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="spgist", + include=["decimals", "ints"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) @@ -762,60 +838,68 @@ class ExclusionConstraintTests(PostgreSQLTestCase): RangesModel.objects.create(ints=(10, 19)) RangesModel.objects.create(ints=(51, 60)) - @skipUnlessDBFeature('supports_covering_gist_indexes') + @skipUnlessDBFeature("supports_covering_gist_indexes") def test_range_adjacent_gist_include_condition(self): - constraint_name = 'ints_adjacent_gist_include_condition' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_gist_include_condition" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='gist', - include=['decimals'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="gist", + include=["decimals"], condition=Q(id__gte=100), ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) - @skipUnlessDBFeature('supports_covering_spgist_indexes') + @skipUnlessDBFeature("supports_covering_spgist_indexes") def test_range_adjacent_spgist_include_condition(self): - constraint_name = 'ints_adjacent_spgist_include_condition' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_spgist_include_condition" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='spgist', - include=['decimals'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="spgist", + include=["decimals"], condition=Q(id__gte=100), ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) - @skipUnlessDBFeature('supports_covering_gist_indexes') + @skipUnlessDBFeature("supports_covering_gist_indexes") def test_range_adjacent_gist_include_deferrable(self): - constraint_name = 'ints_adjacent_gist_include_deferrable' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_gist_include_deferrable" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='gist', - include=['decimals'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="gist", + include=["decimals"], deferrable=Deferrable.DEFERRED, ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) - @skipUnlessDBFeature('supports_covering_spgist_indexes') + @skipUnlessDBFeature("supports_covering_spgist_indexes") def test_range_adjacent_spgist_include_deferrable(self): - constraint_name = 'ints_adjacent_spgist_include_deferrable' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_spgist_include_deferrable" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='spgist', - include=['decimals'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="spgist", + include=["decimals"], deferrable=Deferrable.DEFERRED, ) with connection.schema_editor() as editor: @@ -823,53 +907,55 @@ class ExclusionConstraintTests(PostgreSQLTestCase): self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) def test_gist_include_not_supported(self): - constraint_name = 'ints_adjacent_gist_include_not_supported' + constraint_name = "ints_adjacent_gist_include_not_supported" constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='gist', - include=['id'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="gist", + include=["id"], ) msg = ( - 'Covering exclusion constraints using a GiST index require ' - 'PostgreSQL 12+.' + "Covering exclusion constraints using a GiST index require " + "PostgreSQL 12+." ) with connection.schema_editor() as editor: with mock.patch( - 'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes', + "django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes", False, ): with self.assertRaisesMessage(NotSupportedError, msg): editor.add_constraint(RangesModel, constraint) def test_spgist_include_not_supported(self): - constraint_name = 'ints_adjacent_spgist_include_not_supported' + constraint_name = "ints_adjacent_spgist_include_not_supported" constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='spgist', - include=['id'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="spgist", + include=["id"], ) msg = ( - 'Covering exclusion constraints using an SP-GiST index require ' - 'PostgreSQL 14+.' + "Covering exclusion constraints using an SP-GiST index require " + "PostgreSQL 14+." ) with connection.schema_editor() as editor: with mock.patch( - 'django.db.backends.postgresql.features.DatabaseFeatures.' - 'supports_covering_spgist_indexes', + "django.db.backends.postgresql.features.DatabaseFeatures." + "supports_covering_spgist_indexes", False, ): with self.assertRaisesMessage(NotSupportedError, msg): editor.add_constraint(RangesModel, constraint) def test_range_adjacent_opclass(self): - constraint_name = 'ints_adjacent_opclass' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_opclass" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, expressions=[ - (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO), + (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO), ], ) with connection.schema_editor() as editor: @@ -880,7 +966,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase): cursor.execute(SchemaTests.get_opclass_query, [constraint_name]) self.assertEqual( cursor.fetchall(), - [('range_ops', constraint_name)], + [("range_ops", constraint_name)], ) RangesModel.objects.create(ints=(20, 50)) with self.assertRaises(IntegrityError), transaction.atomic(): @@ -890,15 +976,19 @@ class ExclusionConstraintTests(PostgreSQLTestCase): # Drop the constraint. with connection.schema_editor() as editor: editor.remove_constraint(RangesModel, constraint) - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) def test_range_adjacent_opclass_condition(self): - constraint_name = 'ints_adjacent_opclass_condition' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_opclass_condition" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, expressions=[ - (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO), + (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO), ], condition=Q(id__gte=100), ) @@ -907,12 +997,14 @@ class ExclusionConstraintTests(PostgreSQLTestCase): self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) def test_range_adjacent_opclass_deferrable(self): - constraint_name = 'ints_adjacent_opclass_deferrable' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_opclass_deferrable" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, expressions=[ - (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO), + (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO), ], deferrable=Deferrable.DEFERRED, ) @@ -920,51 +1012,55 @@ class ExclusionConstraintTests(PostgreSQLTestCase): editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) - @skipUnlessDBFeature('supports_covering_gist_indexes') + @skipUnlessDBFeature("supports_covering_gist_indexes") def test_range_adjacent_gist_opclass_include(self): - constraint_name = 'ints_adjacent_gist_opclass_include' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_gist_opclass_include" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, expressions=[ - (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO), + (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO), ], - index_type='gist', - include=['decimals'], + index_type="gist", + include=["decimals"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) - @skipUnlessDBFeature('supports_covering_spgist_indexes') + @skipUnlessDBFeature("supports_covering_spgist_indexes") def test_range_adjacent_spgist_opclass_include(self): - constraint_name = 'ints_adjacent_spgist_opclass_include' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_spgist_opclass_include" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, expressions=[ - (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO), + (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO), ], - index_type='spgist', - include=['decimals'], + index_type="spgist", + include=["decimals"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) def test_range_equal_cast(self): - constraint_name = 'exclusion_equal_room_cast' + constraint_name = "exclusion_equal_room_cast" self.assertNotIn(constraint_name, self.get_constraints(Room._meta.db_table)) constraint = ExclusionConstraint( name=constraint_name, - expressions=[(Cast('number', IntegerField()), RangeOperators.EQUAL)], + expressions=[(Cast("number", IntegerField()), RangeOperators.EQUAL)], ) with connection.schema_editor() as editor: editor.add_constraint(Room, constraint) self.assertIn(constraint_name, self.get_constraints(Room._meta.db_table)) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): def get_constraints(self, table): """Get the constraints on the table using a new cursor.""" @@ -973,23 +1069,23 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): def test_warning(self): msg = ( - 'The opclasses argument is deprecated in favor of using ' - 'django.contrib.postgres.indexes.OpClass in ' - 'ExclusionConstraint.expressions.' + "The opclasses argument is deprecated in favor of using " + "django.contrib.postgres.indexes.OpClass in " + "ExclusionConstraint.expressions." ) with self.assertWarnsMessage(RemovedInDjango50Warning, msg): ExclusionConstraint( - name='exclude_overlapping', - expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)], - opclasses=['range_ops'], + name="exclude_overlapping", + expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)], + opclasses=["range_ops"], ) @ignore_warnings(category=RemovedInDjango50Warning) def test_repr(self): constraint = ExclusionConstraint( - name='exclude_overlapping', - expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)], - opclasses=['range_ops'], + name="exclude_overlapping", + expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)], + opclasses=["range_ops"], ) self.assertEqual( repr(constraint), @@ -1000,12 +1096,14 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): @ignore_warnings(category=RemovedInDjango50Warning) def test_range_adjacent_opclasses(self): - constraint_name = 'ints_adjacent_opclasses' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_opclasses" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - opclasses=['range_ops'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + opclasses=["range_ops"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) @@ -1015,7 +1113,7 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): cursor.execute(SchemaTests.get_opclass_query, [constraint.name]) self.assertEqual( cursor.fetchall(), - [('range_ops', constraint.name)], + [("range_ops", constraint.name)], ) RangesModel.objects.create(ints=(20, 50)) with self.assertRaises(IntegrityError), transaction.atomic(): @@ -1025,16 +1123,20 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): # Drop the constraint. with connection.schema_editor() as editor: editor.remove_constraint(RangesModel, constraint) - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) @ignore_warnings(category=RemovedInDjango50Warning) def test_range_adjacent_opclasses_condition(self): - constraint_name = 'ints_adjacent_opclasses_condition' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_opclasses_condition" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - opclasses=['range_ops'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + opclasses=["range_ops"], condition=Q(id__gte=100), ) with connection.schema_editor() as editor: @@ -1043,12 +1145,14 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): @ignore_warnings(category=RemovedInDjango50Warning) def test_range_adjacent_opclasses_deferrable(self): - constraint_name = 'ints_adjacent_opclasses_deferrable' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_opclasses_deferrable" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - opclasses=['range_ops'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + opclasses=["range_ops"], deferrable=Deferrable.DEFERRED, ) with connection.schema_editor() as editor: @@ -1056,32 +1160,36 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase): self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) @ignore_warnings(category=RemovedInDjango50Warning) - @skipUnlessDBFeature('supports_covering_gist_indexes') + @skipUnlessDBFeature("supports_covering_gist_indexes") def test_range_adjacent_gist_opclasses_include(self): - constraint_name = 'ints_adjacent_gist_opclasses_include' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_gist_opclasses_include" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='gist', - opclasses=['range_ops'], - include=['decimals'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="gist", + opclasses=["range_ops"], + include=["decimals"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) @ignore_warnings(category=RemovedInDjango50Warning) - @skipUnlessDBFeature('supports_covering_spgist_indexes') + @skipUnlessDBFeature("supports_covering_spgist_indexes") def test_range_adjacent_spgist_opclasses_include(self): - constraint_name = 'ints_adjacent_spgist_opclasses_include' - self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint_name = "ints_adjacent_spgist_opclasses_include" + self.assertNotIn( + constraint_name, self.get_constraints(RangesModel._meta.db_table) + ) constraint = ExclusionConstraint( name=constraint_name, - expressions=[('ints', RangeOperators.ADJACENT_TO)], - index_type='spgist', - opclasses=['range_ops'], - include=['decimals'], + expressions=[("ints", RangeOperators.ADJACENT_TO)], + index_type="spgist", + opclasses=["range_ops"], + include=["decimals"], ) with connection.schema_editor() as editor: editor.add_constraint(RangesModel, constraint) diff --git a/tests/postgres_tests/test_functions.py b/tests/postgres_tests/test_functions.py index 875a4b9520..e500580645 100644 --- a/tests/postgres_tests/test_functions.py +++ b/tests/postgres_tests/test_functions.py @@ -9,7 +9,6 @@ from .models import NowTestModel, UUIDTestModel class TestTransactionNow(PostgreSQLTestCase): - def test_transaction_now(self): """ The test case puts everything under a transaction, so two models @@ -30,7 +29,6 @@ class TestTransactionNow(PostgreSQLTestCase): class TestRandomUUID(PostgreSQLTestCase): - def test_random_uuid(self): m1 = UUIDTestModel.objects.create() m2 = UUIDTestModel.objects.create() diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py index b6fdd2da60..2aaad637c6 100644 --- a/tests/postgres_tests/test_hstore.py +++ b/tests/postgres_tests/test_hstore.py @@ -21,7 +21,7 @@ except ImportError: class SimpleTests(PostgreSQLTestCase): def test_save_load_success(self): - value = {'a': 'b'} + value = {"a": "b"} instance = HStoreModel(field=value) instance.save() reloaded = HStoreModel.objects.get() @@ -34,15 +34,15 @@ class SimpleTests(PostgreSQLTestCase): self.assertIsNone(reloaded.field) def test_value_null(self): - value = {'a': None} + value = {"a": None} instance = HStoreModel(field=value) instance.save() reloaded = HStoreModel.objects.get() self.assertEqual(reloaded.field, value) def test_key_val_cast_to_string(self): - value = {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'} - expected_value = {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'} + value = {"a": 1, "b": "B", 2: "c", "ï": "ê"} + expected_value = {"a": "1", "b": "B", "2": "c", "ï": "ê"} instance = HStoreModel.objects.create(field=value) instance = HStoreModel.objects.get() @@ -51,17 +51,17 @@ class SimpleTests(PostgreSQLTestCase): instance = HStoreModel.objects.get(field__a=1) self.assertEqual(instance.field, expected_value) - instance = HStoreModel.objects.get(field__has_keys=[2, 'a', 'ï']) + instance = HStoreModel.objects.get(field__has_keys=[2, "a", "ï"]) self.assertEqual(instance.field, expected_value) def test_array_field(self): value = [ - {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'}, - {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'}, + {"a": 1, "b": "B", 2: "c", "ï": "ê"}, + {"a": 1, "b": "B", 2: "c", "ï": "ê"}, ] expected_value = [ - {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'}, - {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'}, + {"a": "1", "b": "B", "2": "c", "ï": "ê"}, + {"a": "1", "b": "B", "2": "c", "ï": "ê"}, ] instance = HStoreModel.objects.create(array_field=value) instance.refresh_from_db() @@ -69,231 +69,225 @@ class SimpleTests(PostgreSQLTestCase): class TestQuerying(PostgreSQLTestCase): - @classmethod def setUpTestData(cls): - cls.objs = HStoreModel.objects.bulk_create([ - HStoreModel(field={'a': 'b'}), - HStoreModel(field={'a': 'b', 'c': 'd'}), - HStoreModel(field={'c': 'd'}), - HStoreModel(field={}), - HStoreModel(field=None), - HStoreModel(field={'cat': 'TigrOu', 'breed': 'birman'}), - HStoreModel(field={'cat': 'minou', 'breed': 'ragdoll'}), - HStoreModel(field={'cat': 'kitty', 'breed': 'Persian'}), - HStoreModel(field={'cat': 'Kit Kat', 'breed': 'persian'}), - ]) + cls.objs = HStoreModel.objects.bulk_create( + [ + HStoreModel(field={"a": "b"}), + HStoreModel(field={"a": "b", "c": "d"}), + HStoreModel(field={"c": "d"}), + HStoreModel(field={}), + HStoreModel(field=None), + HStoreModel(field={"cat": "TigrOu", "breed": "birman"}), + HStoreModel(field={"cat": "minou", "breed": "ragdoll"}), + HStoreModel(field={"cat": "kitty", "breed": "Persian"}), + HStoreModel(field={"cat": "Kit Kat", "breed": "persian"}), + ] + ) def test_exact(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__exact={'a': 'b'}), - self.objs[:1] + HStoreModel.objects.filter(field__exact={"a": "b"}), self.objs[:1] ) def test_contained_by(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__contained_by={'a': 'b', 'c': 'd'}), - self.objs[:4] + HStoreModel.objects.filter(field__contained_by={"a": "b", "c": "d"}), + self.objs[:4], ) def test_contains(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__contains={'a': 'b'}), - self.objs[:2] + HStoreModel.objects.filter(field__contains={"a": "b"}), self.objs[:2] ) def test_in_generator(self): def search(): - yield {'a': 'b'} + yield {"a": "b"} + self.assertSequenceEqual( - HStoreModel.objects.filter(field__in=search()), - self.objs[:1] + HStoreModel.objects.filter(field__in=search()), self.objs[:1] ) def test_has_key(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__has_key='c'), - self.objs[1:3] + HStoreModel.objects.filter(field__has_key="c"), self.objs[1:3] ) def test_has_keys(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__has_keys=['a', 'c']), - self.objs[1:2] + HStoreModel.objects.filter(field__has_keys=["a", "c"]), self.objs[1:2] ) def test_has_any_keys(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__has_any_keys=['a', 'c']), - self.objs[:3] + HStoreModel.objects.filter(field__has_any_keys=["a", "c"]), self.objs[:3] ) def test_key_transform(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__a='b'), - self.objs[:2] + HStoreModel.objects.filter(field__a="b"), self.objs[:2] ) def test_key_transform_raw_expression(self): - expr = RawSQL('%s::hstore', ['x => b, y => c']) + expr = RawSQL("%s::hstore", ["x => b, y => c"]) self.assertSequenceEqual( - HStoreModel.objects.filter(field__a=KeyTransform('x', expr)), - self.objs[:2] + HStoreModel.objects.filter(field__a=KeyTransform("x", expr)), self.objs[:2] ) def test_key_transform_annotation(self): - qs = HStoreModel.objects.annotate(a=F('field__a')) + qs = HStoreModel.objects.annotate(a=F("field__a")) self.assertCountEqual( - qs.values_list('a', flat=True), - ['b', 'b', None, None, None, None, None, None, None], + qs.values_list("a", flat=True), + ["b", "b", None, None, None, None, None, None, None], ) def test_keys(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__keys=['a']), - self.objs[:1] + HStoreModel.objects.filter(field__keys=["a"]), self.objs[:1] ) def test_values(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__values=['b']), - self.objs[:1] + HStoreModel.objects.filter(field__values=["b"]), self.objs[:1] ) def test_field_chaining_contains(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__a__contains='b'), - self.objs[:2] + HStoreModel.objects.filter(field__a__contains="b"), self.objs[:2] ) def test_field_chaining_icontains(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__icontains='INo'), + HStoreModel.objects.filter(field__cat__icontains="INo"), [self.objs[6]], ) def test_field_chaining_startswith(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__startswith='kit'), + HStoreModel.objects.filter(field__cat__startswith="kit"), [self.objs[7]], ) def test_field_chaining_istartswith(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__istartswith='kit'), + HStoreModel.objects.filter(field__cat__istartswith="kit"), self.objs[7:], ) def test_field_chaining_endswith(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__endswith='ou'), + HStoreModel.objects.filter(field__cat__endswith="ou"), [self.objs[6]], ) def test_field_chaining_iendswith(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__iendswith='ou'), + HStoreModel.objects.filter(field__cat__iendswith="ou"), self.objs[5:7], ) def test_field_chaining_iexact(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__breed__iexact='persian'), + HStoreModel.objects.filter(field__breed__iexact="persian"), self.objs[7:], ) def test_field_chaining_regex(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__regex=r'ou$'), + HStoreModel.objects.filter(field__cat__regex=r"ou$"), [self.objs[6]], ) def test_field_chaining_iregex(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__cat__iregex=r'oU$'), + HStoreModel.objects.filter(field__cat__iregex=r"oU$"), self.objs[5:7], ) def test_order_by_field(self): more_objs = ( - HStoreModel.objects.create(field={'g': '637'}), - HStoreModel.objects.create(field={'g': '002'}), - HStoreModel.objects.create(field={'g': '042'}), - HStoreModel.objects.create(field={'g': '981'}), + HStoreModel.objects.create(field={"g": "637"}), + HStoreModel.objects.create(field={"g": "002"}), + HStoreModel.objects.create(field={"g": "042"}), + HStoreModel.objects.create(field={"g": "981"}), ) self.assertSequenceEqual( - HStoreModel.objects.filter(field__has_key='g').order_by('field__g'), - [more_objs[1], more_objs[2], more_objs[0], more_objs[3]] + HStoreModel.objects.filter(field__has_key="g").order_by("field__g"), + [more_objs[1], more_objs[2], more_objs[0], more_objs[3]], ) def test_keys_contains(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__keys__contains=['a']), - self.objs[:2] + HStoreModel.objects.filter(field__keys__contains=["a"]), self.objs[:2] ) def test_values_overlap(self): self.assertSequenceEqual( - HStoreModel.objects.filter(field__values__overlap=['b', 'd']), - self.objs[:3] + HStoreModel.objects.filter(field__values__overlap=["b", "d"]), self.objs[:3] ) def test_key_isnull(self): - obj = HStoreModel.objects.create(field={'a': None}) + obj = HStoreModel.objects.create(field={"a": None}) self.assertSequenceEqual( HStoreModel.objects.filter(field__a__isnull=True), self.objs[2:9] + [obj], ) self.assertSequenceEqual( - HStoreModel.objects.filter(field__a__isnull=False), - self.objs[:2] + HStoreModel.objects.filter(field__a__isnull=False), self.objs[:2] ) def test_usage_in_subquery(self): self.assertSequenceEqual( - HStoreModel.objects.filter(id__in=HStoreModel.objects.filter(field__a='b')), - self.objs[:2] + HStoreModel.objects.filter(id__in=HStoreModel.objects.filter(field__a="b")), + self.objs[:2], ) def test_key_sql_injection(self): with CaptureQueriesContext(connection) as queries: self.assertFalse( - HStoreModel.objects.filter(**{ - "field__test' = 'a') OR 1 = 1 OR ('d": 'x', - }).exists() + HStoreModel.objects.filter( + **{ + "field__test' = 'a') OR 1 = 1 OR ('d": "x", + } + ).exists() ) self.assertIn( """."field" -> 'test'' = ''a'') OR 1 = 1 OR (''d') = 'x' """, - queries[0]['sql'], + queries[0]["sql"], ) def test_obj_subquery_lookup(self): qs = HStoreModel.objects.annotate( - value=Subquery(HStoreModel.objects.filter(pk=OuterRef('pk')).values('field')), - ).filter(value__a='b') + value=Subquery( + HStoreModel.objects.filter(pk=OuterRef("pk")).values("field") + ), + ).filter(value__a="b") self.assertSequenceEqual(qs, self.objs[:2]) -@isolate_apps('postgres_tests') +@isolate_apps("postgres_tests") class TestChecks(PostgreSQLSimpleTestCase): - def test_invalid_default(self): class MyModel(PostgreSQLModel): field = HStoreField(default={}) model = MyModel() - self.assertEqual(model.check(), [ - checks.Warning( - msg=( - "HStoreField default should be a callable instead of an " - "instance so that it's not shared between all field " - "instances." - ), - hint='Use a callable instead, e.g., use `dict` instead of `{}`.', - obj=MyModel._meta.get_field('field'), - id='fields.E010', - ) - ]) + self.assertEqual( + model.check(), + [ + checks.Warning( + msg=( + "HStoreField default should be a callable instead of an " + "instance so that it's not shared between all field " + "instances." + ), + hint="Use a callable instead, e.g., use `dict` instead of `{}`.", + obj=MyModel._meta.get_field("field"), + id="fields.E010", + ) + ], + ) def test_valid_default(self): class MyModel(PostgreSQLModel): @@ -303,83 +297,90 @@ class TestChecks(PostgreSQLSimpleTestCase): class TestSerialization(PostgreSQLSimpleTestCase): - test_data = json.dumps([{ - 'model': 'postgres_tests.hstoremodel', - 'pk': None, - 'fields': { - 'field': json.dumps({'a': 'b'}), - 'array_field': json.dumps([ - json.dumps({'a': 'b'}), - json.dumps({'b': 'a'}), - ]), - }, - }]) + test_data = json.dumps( + [ + { + "model": "postgres_tests.hstoremodel", + "pk": None, + "fields": { + "field": json.dumps({"a": "b"}), + "array_field": json.dumps( + [ + json.dumps({"a": "b"}), + json.dumps({"b": "a"}), + ] + ), + }, + } + ] + ) def test_dumping(self): - instance = HStoreModel(field={'a': 'b'}, array_field=[{'a': 'b'}, {'b': 'a'}]) - data = serializers.serialize('json', [instance]) + instance = HStoreModel(field={"a": "b"}, array_field=[{"a": "b"}, {"b": "a"}]) + data = serializers.serialize("json", [instance]) self.assertEqual(json.loads(data), json.loads(self.test_data)) def test_loading(self): - instance = list(serializers.deserialize('json', self.test_data))[0].object - self.assertEqual(instance.field, {'a': 'b'}) - self.assertEqual(instance.array_field, [{'a': 'b'}, {'b': 'a'}]) + instance = list(serializers.deserialize("json", self.test_data))[0].object + self.assertEqual(instance.field, {"a": "b"}) + self.assertEqual(instance.array_field, [{"a": "b"}, {"b": "a"}]) def test_roundtrip_with_null(self): - instance = HStoreModel(field={'a': 'b', 'c': None}) - data = serializers.serialize('json', [instance]) - new_instance = list(serializers.deserialize('json', data))[0].object + instance = HStoreModel(field={"a": "b", "c": None}) + data = serializers.serialize("json", [instance]) + new_instance = list(serializers.deserialize("json", data))[0].object self.assertEqual(instance.field, new_instance.field) class TestValidation(PostgreSQLSimpleTestCase): - def test_not_a_string(self): field = HStoreField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean({'a': 1}, None) - self.assertEqual(cm.exception.code, 'not_a_string') - self.assertEqual(cm.exception.message % cm.exception.params, 'The value of “a” is not a string or null.') + field.clean({"a": 1}, None) + self.assertEqual(cm.exception.code, "not_a_string") + self.assertEqual( + cm.exception.message % cm.exception.params, + "The value of “a” is not a string or null.", + ) def test_none_allowed_as_value(self): field = HStoreField() - self.assertEqual(field.clean({'a': None}, None), {'a': None}) + self.assertEqual(field.clean({"a": None}, None), {"a": None}) class TestFormField(PostgreSQLSimpleTestCase): - def test_valid(self): field = forms.HStoreField() value = field.clean('{"a": "b"}') - self.assertEqual(value, {'a': 'b'}) + self.assertEqual(value, {"a": "b"}) def test_invalid_json(self): field = forms.HStoreField() with self.assertRaises(exceptions.ValidationError) as cm: field.clean('{"a": "b"') - self.assertEqual(cm.exception.messages[0], 'Could not load JSON data.') - self.assertEqual(cm.exception.code, 'invalid_json') + self.assertEqual(cm.exception.messages[0], "Could not load JSON data.") + self.assertEqual(cm.exception.code, "invalid_json") def test_non_dict_json(self): field = forms.HStoreField() - msg = 'Input must be a JSON dictionary.' + msg = "Input must be a JSON dictionary." with self.assertRaisesMessage(exceptions.ValidationError, msg) as cm: field.clean('["a", "b", 1]') - self.assertEqual(cm.exception.code, 'invalid_format') + self.assertEqual(cm.exception.code, "invalid_format") def test_not_string_values(self): field = forms.HStoreField() value = field.clean('{"a": 1}') - self.assertEqual(value, {'a': '1'}) + self.assertEqual(value, {"a": "1"}) def test_none_value(self): field = forms.HStoreField() value = field.clean('{"a": null}') - self.assertEqual(value, {'a': None}) + self.assertEqual(value, {"a": None}) def test_empty(self): field = forms.HStoreField(required=False) - value = field.clean('') + value = field.clean("") self.assertEqual(value, {}) def test_model_field_formfield(self): @@ -390,69 +391,71 @@ class TestFormField(PostgreSQLSimpleTestCase): def test_field_has_changed(self): class HStoreFormTest(Form): f1 = forms.HStoreField() + form_w_hstore = HStoreFormTest() self.assertFalse(form_w_hstore.has_changed()) - form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}) + form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}) self.assertTrue(form_w_hstore.has_changed()) - form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}, initial={'f1': '{"a": 1}'}) + form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}, initial={"f1": '{"a": 1}'}) self.assertFalse(form_w_hstore.has_changed()) - form_w_hstore = HStoreFormTest({'f1': '{"a": 2}'}, initial={'f1': '{"a": 1}'}) + form_w_hstore = HStoreFormTest({"f1": '{"a": 2}'}, initial={"f1": '{"a": 1}'}) self.assertTrue(form_w_hstore.has_changed()) - form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}, initial={'f1': {"a": 1}}) + form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}, initial={"f1": {"a": 1}}) self.assertFalse(form_w_hstore.has_changed()) - form_w_hstore = HStoreFormTest({'f1': '{"a": 2}'}, initial={'f1': {"a": 1}}) + form_w_hstore = HStoreFormTest({"f1": '{"a": 2}'}, initial={"f1": {"a": 1}}) self.assertTrue(form_w_hstore.has_changed()) class TestValidator(PostgreSQLSimpleTestCase): - def test_simple_valid(self): - validator = KeysValidator(keys=['a', 'b']) - validator({'a': 'foo', 'b': 'bar', 'c': 'baz'}) + validator = KeysValidator(keys=["a", "b"]) + validator({"a": "foo", "b": "bar", "c": "baz"}) def test_missing_keys(self): - validator = KeysValidator(keys=['a', 'b']) + validator = KeysValidator(keys=["a", "b"]) with self.assertRaises(exceptions.ValidationError) as cm: - validator({'a': 'foo', 'c': 'baz'}) - self.assertEqual(cm.exception.messages[0], 'Some keys were missing: b') - self.assertEqual(cm.exception.code, 'missing_keys') + validator({"a": "foo", "c": "baz"}) + self.assertEqual(cm.exception.messages[0], "Some keys were missing: b") + self.assertEqual(cm.exception.code, "missing_keys") def test_strict_valid(self): - validator = KeysValidator(keys=['a', 'b'], strict=True) - validator({'a': 'foo', 'b': 'bar'}) + validator = KeysValidator(keys=["a", "b"], strict=True) + validator({"a": "foo", "b": "bar"}) def test_extra_keys(self): - validator = KeysValidator(keys=['a', 'b'], strict=True) + validator = KeysValidator(keys=["a", "b"], strict=True) with self.assertRaises(exceptions.ValidationError) as cm: - validator({'a': 'foo', 'b': 'bar', 'c': 'baz'}) - self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c') - self.assertEqual(cm.exception.code, 'extra_keys') + validator({"a": "foo", "b": "bar", "c": "baz"}) + self.assertEqual(cm.exception.messages[0], "Some unknown keys were provided: c") + self.assertEqual(cm.exception.code, "extra_keys") def test_custom_messages(self): messages = { - 'missing_keys': 'Foobar', + "missing_keys": "Foobar", } - validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages) + validator = KeysValidator(keys=["a", "b"], strict=True, messages=messages) with self.assertRaises(exceptions.ValidationError) as cm: - validator({'a': 'foo', 'c': 'baz'}) - self.assertEqual(cm.exception.messages[0], 'Foobar') - self.assertEqual(cm.exception.code, 'missing_keys') + validator({"a": "foo", "c": "baz"}) + self.assertEqual(cm.exception.messages[0], "Foobar") + self.assertEqual(cm.exception.code, "missing_keys") with self.assertRaises(exceptions.ValidationError) as cm: - validator({'a': 'foo', 'b': 'bar', 'c': 'baz'}) - self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c') - self.assertEqual(cm.exception.code, 'extra_keys') + validator({"a": "foo", "b": "bar", "c": "baz"}) + self.assertEqual(cm.exception.messages[0], "Some unknown keys were provided: c") + self.assertEqual(cm.exception.code, "extra_keys") def test_deconstruct(self): messages = { - 'missing_keys': 'Foobar', + "missing_keys": "Foobar", } - validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages) + validator = KeysValidator(keys=["a", "b"], strict=True, messages=messages) path, args, kwargs = validator.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.validators.KeysValidator') + self.assertEqual(path, "django.contrib.postgres.validators.KeysValidator") self.assertEqual(args, ()) - self.assertEqual(kwargs, {'keys': ['a', 'b'], 'strict': True, 'messages': messages}) + self.assertEqual( + kwargs, {"keys": ["a", "b"], "strict": True, "messages": messages} + ) diff --git a/tests/postgres_tests/test_indexes.py b/tests/postgres_tests/test_indexes.py index 75d0640a08..f57e7a7c80 100644 --- a/tests/postgres_tests/test_indexes.py +++ b/tests/postgres_tests/test_indexes.py @@ -1,7 +1,13 @@ from unittest import mock from django.contrib.postgres.indexes import ( - BloomIndex, BrinIndex, BTreeIndex, GinIndex, GistIndex, HashIndex, OpClass, + BloomIndex, + BrinIndex, + BTreeIndex, + GinIndex, + GistIndex, + HashIndex, + OpClass, SpGistIndex, ) from django.db import NotSupportedError, connection @@ -16,191 +22,226 @@ from .models import CharFieldModel, IntegerArrayModel, Scene, TextFieldModel class IndexTestMixin: - def test_name_auto_generation(self): - index = self.index_class(fields=['field']) + index = self.index_class(fields=["field"]) index.set_name_with_model(CharFieldModel) - self.assertRegex(index.name, r'postgres_te_field_[0-9a-f]{6}_%s' % self.index_class.suffix) + self.assertRegex( + index.name, r"postgres_te_field_[0-9a-f]{6}_%s" % self.index_class.suffix + ) def test_deconstruction_no_customization(self): - index = self.index_class(fields=['title'], name='test_title_%s' % self.index_class.suffix) + index = self.index_class( + fields=["title"], name="test_title_%s" % self.index_class.suffix + ) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.%s' % self.index_class.__name__) + self.assertEqual( + path, "django.contrib.postgres.indexes.%s" % self.index_class.__name__ + ) self.assertEqual(args, ()) - self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_%s' % self.index_class.suffix}) + self.assertEqual( + kwargs, + {"fields": ["title"], "name": "test_title_%s" % self.index_class.suffix}, + ) def test_deconstruction_with_expressions_no_customization(self): - name = f'test_title_{self.index_class.suffix}' - index = self.index_class(Lower('title'), name=name) + name = f"test_title_{self.index_class.suffix}" + index = self.index_class(Lower("title"), name=name) path, args, kwargs = index.deconstruct() self.assertEqual( path, - f'django.contrib.postgres.indexes.{self.index_class.__name__}', + f"django.contrib.postgres.indexes.{self.index_class.__name__}", ) - self.assertEqual(args, (Lower('title'),)) - self.assertEqual(kwargs, {'name': name}) + self.assertEqual(args, (Lower("title"),)) + self.assertEqual(kwargs, {"name": name}) class BloomIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = BloomIndex def test_suffix(self): - self.assertEqual(BloomIndex.suffix, 'bloom') + self.assertEqual(BloomIndex.suffix, "bloom") def test_deconstruction(self): - index = BloomIndex(fields=['title'], name='test_bloom', length=80, columns=[4]) + index = BloomIndex(fields=["title"], name="test_bloom", length=80, columns=[4]) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.BloomIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.BloomIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'fields': ['title'], - 'name': 'test_bloom', - 'length': 80, - 'columns': [4], - }) + self.assertEqual( + kwargs, + { + "fields": ["title"], + "name": "test_bloom", + "length": 80, + "columns": [4], + }, + ) def test_invalid_fields(self): - msg = 'Bloom indexes support a maximum of 32 fields.' + msg = "Bloom indexes support a maximum of 32 fields." with self.assertRaisesMessage(ValueError, msg): - BloomIndex(fields=['title'] * 33, name='test_bloom') + BloomIndex(fields=["title"] * 33, name="test_bloom") def test_invalid_columns(self): - msg = 'BloomIndex.columns must be a list or tuple.' + msg = "BloomIndex.columns must be a list or tuple." with self.assertRaisesMessage(ValueError, msg): - BloomIndex(fields=['title'], name='test_bloom', columns='x') - msg = 'BloomIndex.columns cannot have more values than fields.' + BloomIndex(fields=["title"], name="test_bloom", columns="x") + msg = "BloomIndex.columns cannot have more values than fields." with self.assertRaisesMessage(ValueError, msg): - BloomIndex(fields=['title'], name='test_bloom', columns=[4, 3]) + BloomIndex(fields=["title"], name="test_bloom", columns=[4, 3]) def test_invalid_columns_value(self): - msg = 'BloomIndex.columns must contain integers from 1 to 4095.' + msg = "BloomIndex.columns must contain integers from 1 to 4095." for length in (0, 4096): with self.subTest(length), self.assertRaisesMessage(ValueError, msg): - BloomIndex(fields=['title'], name='test_bloom', columns=[length]) + BloomIndex(fields=["title"], name="test_bloom", columns=[length]) def test_invalid_length(self): - msg = 'BloomIndex.length must be None or an integer from 1 to 4096.' + msg = "BloomIndex.length must be None or an integer from 1 to 4096." for length in (0, 4097): with self.subTest(length), self.assertRaisesMessage(ValueError, msg): - BloomIndex(fields=['title'], name='test_bloom', length=length) + BloomIndex(fields=["title"], name="test_bloom", length=length) class BrinIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = BrinIndex def test_suffix(self): - self.assertEqual(BrinIndex.suffix, 'brin') + self.assertEqual(BrinIndex.suffix, "brin") def test_deconstruction(self): - index = BrinIndex(fields=['title'], name='test_title_brin', autosummarize=True, pages_per_range=16) + index = BrinIndex( + fields=["title"], + name="test_title_brin", + autosummarize=True, + pages_per_range=16, + ) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.BrinIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.BrinIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'fields': ['title'], - 'name': 'test_title_brin', - 'autosummarize': True, - 'pages_per_range': 16, - }) + self.assertEqual( + kwargs, + { + "fields": ["title"], + "name": "test_title_brin", + "autosummarize": True, + "pages_per_range": 16, + }, + ) def test_invalid_pages_per_range(self): - with self.assertRaisesMessage(ValueError, 'pages_per_range must be None or a positive integer'): - BrinIndex(fields=['title'], name='test_title_brin', pages_per_range=0) + with self.assertRaisesMessage( + ValueError, "pages_per_range must be None or a positive integer" + ): + BrinIndex(fields=["title"], name="test_title_brin", pages_per_range=0) class BTreeIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = BTreeIndex def test_suffix(self): - self.assertEqual(BTreeIndex.suffix, 'btree') + self.assertEqual(BTreeIndex.suffix, "btree") def test_deconstruction(self): - index = BTreeIndex(fields=['title'], name='test_title_btree', fillfactor=80) + index = BTreeIndex(fields=["title"], name="test_title_btree", fillfactor=80) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.BTreeIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.BTreeIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_btree', 'fillfactor': 80}) + self.assertEqual( + kwargs, {"fields": ["title"], "name": "test_title_btree", "fillfactor": 80} + ) class GinIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = GinIndex def test_suffix(self): - self.assertEqual(GinIndex.suffix, 'gin') + self.assertEqual(GinIndex.suffix, "gin") def test_deconstruction(self): index = GinIndex( - fields=['title'], - name='test_title_gin', + fields=["title"], + name="test_title_gin", fastupdate=True, gin_pending_list_limit=128, ) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.GinIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.GinIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'fields': ['title'], - 'name': 'test_title_gin', - 'fastupdate': True, - 'gin_pending_list_limit': 128, - }) + self.assertEqual( + kwargs, + { + "fields": ["title"], + "name": "test_title_gin", + "fastupdate": True, + "gin_pending_list_limit": 128, + }, + ) class GistIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = GistIndex def test_suffix(self): - self.assertEqual(GistIndex.suffix, 'gist') + self.assertEqual(GistIndex.suffix, "gist") def test_deconstruction(self): - index = GistIndex(fields=['title'], name='test_title_gist', buffering=False, fillfactor=80) + index = GistIndex( + fields=["title"], name="test_title_gist", buffering=False, fillfactor=80 + ) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.GistIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.GistIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, { - 'fields': ['title'], - 'name': 'test_title_gist', - 'buffering': False, - 'fillfactor': 80, - }) + self.assertEqual( + kwargs, + { + "fields": ["title"], + "name": "test_title_gist", + "buffering": False, + "fillfactor": 80, + }, + ) class HashIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = HashIndex def test_suffix(self): - self.assertEqual(HashIndex.suffix, 'hash') + self.assertEqual(HashIndex.suffix, "hash") def test_deconstruction(self): - index = HashIndex(fields=['title'], name='test_title_hash', fillfactor=80) + index = HashIndex(fields=["title"], name="test_title_hash", fillfactor=80) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.HashIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.HashIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_hash', 'fillfactor': 80}) + self.assertEqual( + kwargs, {"fields": ["title"], "name": "test_title_hash", "fillfactor": 80} + ) class SpGistIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase): index_class = SpGistIndex def test_suffix(self): - self.assertEqual(SpGistIndex.suffix, 'spgist') + self.assertEqual(SpGistIndex.suffix, "spgist") def test_deconstruction(self): - index = SpGistIndex(fields=['title'], name='test_title_spgist', fillfactor=80) + index = SpGistIndex(fields=["title"], name="test_title_spgist", fillfactor=80) path, args, kwargs = index.deconstruct() - self.assertEqual(path, 'django.contrib.postgres.indexes.SpGistIndex') + self.assertEqual(path, "django.contrib.postgres.indexes.SpGistIndex") self.assertEqual(args, ()) - self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_spgist', 'fillfactor': 80}) + self.assertEqual( + kwargs, {"fields": ["title"], "name": "test_title_spgist", "fillfactor": 80} + ) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class SchemaTests(PostgreSQLTestCase): - get_opclass_query = ''' + get_opclass_query = """ SELECT opcname, c.relname FROM pg_opclass AS oc JOIN pg_index as i on oc.oid = ANY(i.indclass) JOIN pg_class as c on c.oid = i.indexrelid WHERE c.relname = %s - ''' + """ def get_constraints(self, table): """ @@ -211,229 +252,274 @@ class SchemaTests(PostgreSQLTestCase): def test_gin_index(self): # Ensure the table is there and doesn't have an index. - self.assertNotIn('field', self.get_constraints(IntegerArrayModel._meta.db_table)) + self.assertNotIn( + "field", self.get_constraints(IntegerArrayModel._meta.db_table) + ) # Add the index - index_name = 'integer_array_model_field_gin' - index = GinIndex(fields=['field'], name=index_name) + index_name = "integer_array_model_field_gin" + index = GinIndex(fields=["field"], name=index_name) with connection.schema_editor() as editor: editor.add_index(IntegerArrayModel, index) constraints = self.get_constraints(IntegerArrayModel._meta.db_table) # Check gin index was added - self.assertEqual(constraints[index_name]['type'], GinIndex.suffix) + self.assertEqual(constraints[index_name]["type"], GinIndex.suffix) # Drop the index with connection.schema_editor() as editor: editor.remove_index(IntegerArrayModel, index) - self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(IntegerArrayModel._meta.db_table) + ) def test_gin_fastupdate(self): - index_name = 'integer_array_gin_fastupdate' - index = GinIndex(fields=['field'], name=index_name, fastupdate=False) + index_name = "integer_array_gin_fastupdate" + index = GinIndex(fields=["field"], name=index_name, fastupdate=False) with connection.schema_editor() as editor: editor.add_index(IntegerArrayModel, index) constraints = self.get_constraints(IntegerArrayModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], 'gin') - self.assertEqual(constraints[index_name]['options'], ['fastupdate=off']) + self.assertEqual(constraints[index_name]["type"], "gin") + self.assertEqual(constraints[index_name]["options"], ["fastupdate=off"]) with connection.schema_editor() as editor: editor.remove_index(IntegerArrayModel, index) - self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(IntegerArrayModel._meta.db_table) + ) def test_partial_gin_index(self): with register_lookup(CharField, Length): - index_name = 'char_field_gin_partial_idx' - index = GinIndex(fields=['field'], name=index_name, condition=Q(field__length=40)) + index_name = "char_field_gin_partial_idx" + index = GinIndex( + fields=["field"], name=index_name, condition=Q(field__length=40) + ) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], 'gin') + self.assertEqual(constraints[index_name]["type"], "gin") with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_partial_gin_index_with_tablespace(self): with register_lookup(CharField, Length): - index_name = 'char_field_gin_partial_idx' + index_name = "char_field_gin_partial_idx" index = GinIndex( - fields=['field'], + fields=["field"], name=index_name, condition=Q(field__length=40), - db_tablespace='pg_default', + db_tablespace="pg_default", ) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) - self.assertIn('TABLESPACE "pg_default" ', str(index.create_sql(CharFieldModel, editor))) + self.assertIn( + 'TABLESPACE "pg_default" ', + str(index.create_sql(CharFieldModel, editor)), + ) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], 'gin') + self.assertEqual(constraints[index_name]["type"], "gin") with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_gin_parameters(self): - index_name = 'integer_array_gin_params' - index = GinIndex(fields=['field'], name=index_name, fastupdate=True, gin_pending_list_limit=64) + index_name = "integer_array_gin_params" + index = GinIndex( + fields=["field"], + name=index_name, + fastupdate=True, + gin_pending_list_limit=64, + ) with connection.schema_editor() as editor: editor.add_index(IntegerArrayModel, index) constraints = self.get_constraints(IntegerArrayModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], 'gin') - self.assertEqual(constraints[index_name]['options'], ['gin_pending_list_limit=64', 'fastupdate=on']) + self.assertEqual(constraints[index_name]["type"], "gin") + self.assertEqual( + constraints[index_name]["options"], + ["gin_pending_list_limit=64", "fastupdate=on"], + ) with connection.schema_editor() as editor: editor.remove_index(IntegerArrayModel, index) - self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(IntegerArrayModel._meta.db_table) + ) def test_trigram_op_class_gin_index(self): - index_name = 'trigram_op_class_gin' - index = GinIndex(OpClass(F('scene'), name='gin_trgm_ops'), name=index_name) + index_name = "trigram_op_class_gin" + index = GinIndex(OpClass(F("scene"), name="gin_trgm_ops"), name=index_name) with connection.schema_editor() as editor: editor.add_index(Scene, index) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [index_name]) - self.assertCountEqual(cursor.fetchall(), [('gin_trgm_ops', index_name)]) + self.assertCountEqual(cursor.fetchall(), [("gin_trgm_ops", index_name)]) constraints = self.get_constraints(Scene._meta.db_table) self.assertIn(index_name, constraints) - self.assertIn(constraints[index_name]['type'], GinIndex.suffix) + self.assertIn(constraints[index_name]["type"], GinIndex.suffix) with connection.schema_editor() as editor: editor.remove_index(Scene, index) self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table)) def test_cast_search_vector_gin_index(self): - index_name = 'cast_search_vector_gin' - index = GinIndex(Cast('field', SearchVectorField()), name=index_name) + index_name = "cast_search_vector_gin" + index = GinIndex(Cast("field", SearchVectorField()), name=index_name) with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) sql = index.create_sql(TextFieldModel, editor) table = TextFieldModel._meta.db_table constraints = self.get_constraints(table) self.assertIn(index_name, constraints) - self.assertIn(constraints[index_name]['type'], GinIndex.suffix) - self.assertIs(sql.references_column(table, 'field'), True) - self.assertIn('::tsvector', str(sql)) + self.assertIn(constraints[index_name]["type"], GinIndex.suffix) + self.assertIs(sql.references_column(table, "field"), True) + self.assertIn("::tsvector", str(sql)) with connection.schema_editor() as editor: editor.remove_index(TextFieldModel, index) self.assertNotIn(index_name, self.get_constraints(table)) def test_bloom_index(self): - index_name = 'char_field_model_field_bloom' - index = BloomIndex(fields=['field'], name=index_name) + index_name = "char_field_model_field_bloom" + index = BloomIndex(fields=["field"], name=index_name) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], BloomIndex.suffix) + self.assertEqual(constraints[index_name]["type"], BloomIndex.suffix) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_bloom_parameters(self): - index_name = 'char_field_model_field_bloom_params' - index = BloomIndex(fields=['field'], name=index_name, length=512, columns=[3]) + index_name = "char_field_model_field_bloom_params" + index = BloomIndex(fields=["field"], name=index_name, length=512, columns=[3]) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], BloomIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['length=512', 'col1=3']) + self.assertEqual(constraints[index_name]["type"], BloomIndex.suffix) + self.assertEqual(constraints[index_name]["options"], ["length=512", "col1=3"]) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_brin_index(self): - index_name = 'char_field_model_field_brin' - index = BrinIndex(fields=['field'], name=index_name, pages_per_range=4) + index_name = "char_field_model_field_brin" + index = BrinIndex(fields=["field"], name=index_name, pages_per_range=4) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], BrinIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['pages_per_range=4']) + self.assertEqual(constraints[index_name]["type"], BrinIndex.suffix) + self.assertEqual(constraints[index_name]["options"], ["pages_per_range=4"]) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_brin_parameters(self): - index_name = 'char_field_brin_params' - index = BrinIndex(fields=['field'], name=index_name, autosummarize=True) + index_name = "char_field_brin_params" + index = BrinIndex(fields=["field"], name=index_name, autosummarize=True) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], BrinIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['autosummarize=on']) + self.assertEqual(constraints[index_name]["type"], BrinIndex.suffix) + self.assertEqual(constraints[index_name]["options"], ["autosummarize=on"]) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_btree_index(self): # Ensure the table is there and doesn't have an index. - self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table)) # Add the index. - index_name = 'char_field_model_field_btree' - index = BTreeIndex(fields=['field'], name=index_name) + index_name = "char_field_model_field_btree" + index = BTreeIndex(fields=["field"], name=index_name) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) # The index was added. - self.assertEqual(constraints[index_name]['type'], BTreeIndex.suffix) + self.assertEqual(constraints[index_name]["type"], BTreeIndex.suffix) # Drop the index. with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_btree_parameters(self): - index_name = 'integer_array_btree_fillfactor' - index = BTreeIndex(fields=['field'], name=index_name, fillfactor=80) + index_name = "integer_array_btree_fillfactor" + index = BTreeIndex(fields=["field"], name=index_name, fillfactor=80) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], BTreeIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['fillfactor=80']) + self.assertEqual(constraints[index_name]["type"], BTreeIndex.suffix) + self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"]) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_gist_index(self): # Ensure the table is there and doesn't have an index. - self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table)) # Add the index. - index_name = 'char_field_model_field_gist' - index = GistIndex(fields=['field'], name=index_name) + index_name = "char_field_model_field_gist" + index = GistIndex(fields=["field"], name=index_name) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) # The index was added. - self.assertEqual(constraints[index_name]['type'], GistIndex.suffix) + self.assertEqual(constraints[index_name]["type"], GistIndex.suffix) # Drop the index. with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_gist_parameters(self): - index_name = 'integer_array_gist_buffering' - index = GistIndex(fields=['field'], name=index_name, buffering=True, fillfactor=80) + index_name = "integer_array_gist_buffering" + index = GistIndex( + fields=["field"], name=index_name, buffering=True, fillfactor=80 + ) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], GistIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['buffering=on', 'fillfactor=80']) + self.assertEqual(constraints[index_name]["type"], GistIndex.suffix) + self.assertEqual( + constraints[index_name]["options"], ["buffering=on", "fillfactor=80"] + ) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) - @skipUnlessDBFeature('supports_covering_gist_indexes') + @skipUnlessDBFeature("supports_covering_gist_indexes") def test_gist_include(self): - index_name = 'scene_gist_include_setting' - index = GistIndex(name=index_name, fields=['scene'], include=['setting']) + index_name = "scene_gist_include_setting" + index = GistIndex(name=index_name, fields=["scene"], include=["setting"]) with connection.schema_editor() as editor: editor.add_index(Scene, index) constraints = self.get_constraints(Scene._meta.db_table) self.assertIn(index_name, constraints) - self.assertEqual(constraints[index_name]['type'], GistIndex.suffix) - self.assertEqual(constraints[index_name]['columns'], ['scene', 'setting']) + self.assertEqual(constraints[index_name]["type"], GistIndex.suffix) + self.assertEqual(constraints[index_name]["columns"], ["scene", "setting"]) with connection.schema_editor() as editor: editor.remove_index(Scene, index) self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table)) def test_gist_include_not_supported(self): - index_name = 'gist_include_exception' - index = GistIndex(fields=['scene'], name=index_name, include=['setting']) - msg = 'Covering GiST indexes require PostgreSQL 12+.' + index_name = "gist_include_exception" + index = GistIndex(fields=["scene"], name=index_name, include=["setting"]) + msg = "Covering GiST indexes require PostgreSQL 12+." with self.assertRaisesMessage(NotSupportedError, msg): with mock.patch( - 'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes', + "django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes", False, ): with connection.schema_editor() as editor: @@ -441,11 +527,11 @@ class SchemaTests(PostgreSQLTestCase): self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table)) def test_tsvector_op_class_gist_index(self): - index_name = 'tsvector_op_class_gist' + index_name = "tsvector_op_class_gist" index = GistIndex( OpClass( - SearchVector('scene', 'setting', config='english'), - name='tsvector_ops', + SearchVector("scene", "setting", config="english"), + name="tsvector_ops", ), name=index_name, ) @@ -455,90 +541,98 @@ class SchemaTests(PostgreSQLTestCase): table = Scene._meta.db_table constraints = self.get_constraints(table) self.assertIn(index_name, constraints) - self.assertIn(constraints[index_name]['type'], GistIndex.suffix) - self.assertIs(sql.references_column(table, 'scene'), True) - self.assertIs(sql.references_column(table, 'setting'), True) + self.assertIn(constraints[index_name]["type"], GistIndex.suffix) + self.assertIs(sql.references_column(table, "scene"), True) + self.assertIs(sql.references_column(table, "setting"), True) with connection.schema_editor() as editor: editor.remove_index(Scene, index) self.assertNotIn(index_name, self.get_constraints(table)) def test_hash_index(self): # Ensure the table is there and doesn't have an index. - self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table)) # Add the index. - index_name = 'char_field_model_field_hash' - index = HashIndex(fields=['field'], name=index_name) + index_name = "char_field_model_field_hash" + index = HashIndex(fields=["field"], name=index_name) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) # The index was added. - self.assertEqual(constraints[index_name]['type'], HashIndex.suffix) + self.assertEqual(constraints[index_name]["type"], HashIndex.suffix) # Drop the index. with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_hash_parameters(self): - index_name = 'integer_array_hash_fillfactor' - index = HashIndex(fields=['field'], name=index_name, fillfactor=80) + index_name = "integer_array_hash_fillfactor" + index = HashIndex(fields=["field"], name=index_name, fillfactor=80) with connection.schema_editor() as editor: editor.add_index(CharFieldModel, index) constraints = self.get_constraints(CharFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], HashIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['fillfactor=80']) + self.assertEqual(constraints[index_name]["type"], HashIndex.suffix) + self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"]) with connection.schema_editor() as editor: editor.remove_index(CharFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(CharFieldModel._meta.db_table) + ) def test_spgist_index(self): # Ensure the table is there and doesn't have an index. - self.assertNotIn('field', self.get_constraints(TextFieldModel._meta.db_table)) + self.assertNotIn("field", self.get_constraints(TextFieldModel._meta.db_table)) # Add the index. - index_name = 'text_field_model_field_spgist' - index = SpGistIndex(fields=['field'], name=index_name) + index_name = "text_field_model_field_spgist" + index = SpGistIndex(fields=["field"], name=index_name) with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) constraints = self.get_constraints(TextFieldModel._meta.db_table) # The index was added. - self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix) + self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix) # Drop the index. with connection.schema_editor() as editor: editor.remove_index(TextFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(TextFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(TextFieldModel._meta.db_table) + ) def test_spgist_parameters(self): - index_name = 'text_field_model_spgist_fillfactor' - index = SpGistIndex(fields=['field'], name=index_name, fillfactor=80) + index_name = "text_field_model_spgist_fillfactor" + index = SpGistIndex(fields=["field"], name=index_name, fillfactor=80) with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) constraints = self.get_constraints(TextFieldModel._meta.db_table) - self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix) - self.assertEqual(constraints[index_name]['options'], ['fillfactor=80']) + self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix) + self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"]) with connection.schema_editor() as editor: editor.remove_index(TextFieldModel, index) - self.assertNotIn(index_name, self.get_constraints(TextFieldModel._meta.db_table)) + self.assertNotIn( + index_name, self.get_constraints(TextFieldModel._meta.db_table) + ) - @skipUnlessDBFeature('supports_covering_spgist_indexes') + @skipUnlessDBFeature("supports_covering_spgist_indexes") def test_spgist_include(self): - index_name = 'scene_spgist_include_setting' - index = SpGistIndex(name=index_name, fields=['scene'], include=['setting']) + index_name = "scene_spgist_include_setting" + index = SpGistIndex(name=index_name, fields=["scene"], include=["setting"]) with connection.schema_editor() as editor: editor.add_index(Scene, index) constraints = self.get_constraints(Scene._meta.db_table) self.assertIn(index_name, constraints) - self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix) - self.assertEqual(constraints[index_name]['columns'], ['scene', 'setting']) + self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix) + self.assertEqual(constraints[index_name]["columns"], ["scene", "setting"]) with connection.schema_editor() as editor: editor.remove_index(Scene, index) self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table)) def test_spgist_include_not_supported(self): - index_name = 'spgist_include_exception' - index = SpGistIndex(fields=['scene'], name=index_name, include=['setting']) - msg = 'Covering SP-GiST indexes require PostgreSQL 14+.' + index_name = "spgist_include_exception" + index = SpGistIndex(fields=["scene"], name=index_name, include=["setting"]) + msg = "Covering SP-GiST indexes require PostgreSQL 14+." with self.assertRaisesMessage(NotSupportedError, msg): with mock.patch( - 'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_spgist_indexes', + "django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_spgist_indexes", False, ): with connection.schema_editor() as editor: @@ -546,27 +640,25 @@ class SchemaTests(PostgreSQLTestCase): self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table)) def test_op_class(self): - index_name = 'test_op_class' + index_name = "test_op_class" index = Index( - OpClass(Lower('field'), name='text_pattern_ops'), + OpClass(Lower("field"), name="text_pattern_ops"), name=index_name, ) with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [index_name]) - self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)]) + self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)]) def test_op_class_descending_collation(self): - collation = connection.features.test_collations.get('non_default') + collation = connection.features.test_collations.get("non_default") if not collation: - self.skipTest( - 'This backend does not support case-insensitive collations.' - ) - index_name = 'test_op_class_descending_collation' + self.skipTest("This backend does not support case-insensitive collations.") + index_name = "test_op_class_descending_collation" index = Index( Collate( - OpClass(Lower('field'), name='text_pattern_ops').desc(nulls_last=True), + OpClass(Lower("field"), name="text_pattern_ops").desc(nulls_last=True), collation=collation, ), name=index_name, @@ -574,53 +666,53 @@ class SchemaTests(PostgreSQLTestCase): with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) self.assertIn( - 'COLLATE %s' % editor.quote_name(collation), + "COLLATE %s" % editor.quote_name(collation), str(index.create_sql(TextFieldModel, editor)), ) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [index_name]) - self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)]) + self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)]) table = TextFieldModel._meta.db_table constraints = self.get_constraints(table) self.assertIn(index_name, constraints) - self.assertEqual(constraints[index_name]['orders'], ['DESC']) + self.assertEqual(constraints[index_name]["orders"], ["DESC"]) with connection.schema_editor() as editor: editor.remove_index(TextFieldModel, index) self.assertNotIn(index_name, self.get_constraints(table)) def test_op_class_descending_partial(self): - index_name = 'test_op_class_descending_partial' + index_name = "test_op_class_descending_partial" index = Index( - OpClass(Lower('field'), name='text_pattern_ops').desc(), + OpClass(Lower("field"), name="text_pattern_ops").desc(), name=index_name, - condition=Q(field__contains='China'), + condition=Q(field__contains="China"), ) with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [index_name]) - self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)]) + self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)]) constraints = self.get_constraints(TextFieldModel._meta.db_table) self.assertIn(index_name, constraints) - self.assertEqual(constraints[index_name]['orders'], ['DESC']) + self.assertEqual(constraints[index_name]["orders"], ["DESC"]) def test_op_class_descending_partial_tablespace(self): - index_name = 'test_op_class_descending_partial_tablespace' + index_name = "test_op_class_descending_partial_tablespace" index = Index( - OpClass(Lower('field').desc(), name='text_pattern_ops'), + OpClass(Lower("field").desc(), name="text_pattern_ops"), name=index_name, - condition=Q(field__contains='China'), - db_tablespace='pg_default', + condition=Q(field__contains="China"), + db_tablespace="pg_default", ) with connection.schema_editor() as editor: editor.add_index(TextFieldModel, index) self.assertIn( 'TABLESPACE "pg_default" ', - str(index.create_sql(TextFieldModel, editor)) + str(index.create_sql(TextFieldModel, editor)), ) with editor.connection.cursor() as cursor: cursor.execute(self.get_opclass_query, [index_name]) - self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)]) + self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)]) constraints = self.get_constraints(TextFieldModel._meta.db_table) self.assertIn(index_name, constraints) - self.assertEqual(constraints[index_name]['orders'], ['DESC']) + self.assertEqual(constraints[index_name]["orders"], ["DESC"]) diff --git a/tests/postgres_tests/test_integration.py b/tests/postgres_tests/test_integration.py index db082a6495..e3c62a93e9 100644 --- a/tests/postgres_tests/test_integration.py +++ b/tests/postgres_tests/test_integration.py @@ -8,15 +8,22 @@ from . import PostgreSQLSimpleTestCase class PostgresIntegrationTests(PostgreSQLSimpleTestCase): def test_check(self): test_environ = os.environ.copy() - if 'DJANGO_SETTINGS_MODULE' in test_environ: - del test_environ['DJANGO_SETTINGS_MODULE'] - test_environ['PYTHONPATH'] = os.path.join(os.path.dirname(__file__), '../../') + if "DJANGO_SETTINGS_MODULE" in test_environ: + del test_environ["DJANGO_SETTINGS_MODULE"] + test_environ["PYTHONPATH"] = os.path.join(os.path.dirname(__file__), "../../") result = subprocess.run( - [sys.executable, '-m', 'django', 'check', '--settings', 'integration_settings'], + [ + sys.executable, + "-m", + "django", + "check", + "--settings", + "integration_settings", + ], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, cwd=os.path.dirname(__file__), env=test_environ, - encoding='utf-8', + encoding="utf-8", ) self.assertEqual(result.returncode, 0, msg=result.stderr) diff --git a/tests/postgres_tests/test_introspection.py b/tests/postgres_tests/test_introspection.py index 50cb9b2828..670be46536 100644 --- a/tests/postgres_tests/test_introspection.py +++ b/tests/postgres_tests/test_introspection.py @@ -6,12 +6,12 @@ from django.test.utils import modify_settings from . import PostgreSQLTestCase -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class InspectDBTests(PostgreSQLTestCase): def assertFieldsInModel(self, model, field_outputs): out = StringIO() call_command( - 'inspectdb', + "inspectdb", table_name_filter=lambda tn: tn.startswith(model), stdout=out, ) @@ -21,12 +21,12 @@ class InspectDBTests(PostgreSQLTestCase): def test_range_fields(self): self.assertFieldsInModel( - 'postgres_tests_rangesmodel', + "postgres_tests_rangesmodel", [ - 'ints = django.contrib.postgres.fields.IntegerRangeField(blank=True, null=True)', - 'bigints = django.contrib.postgres.fields.BigIntegerRangeField(blank=True, null=True)', - 'decimals = django.contrib.postgres.fields.DecimalRangeField(blank=True, null=True)', - 'timestamps = django.contrib.postgres.fields.DateTimeRangeField(blank=True, null=True)', - 'dates = django.contrib.postgres.fields.DateRangeField(blank=True, null=True)', + "ints = django.contrib.postgres.fields.IntegerRangeField(blank=True, null=True)", + "bigints = django.contrib.postgres.fields.BigIntegerRangeField(blank=True, null=True)", + "decimals = django.contrib.postgres.fields.DecimalRangeField(blank=True, null=True)", + "timestamps = django.contrib.postgres.fields.DateTimeRangeField(blank=True, null=True)", + "dates = django.contrib.postgres.fields.DateRangeField(blank=True, null=True)", ], ) diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py index 790fef8332..a54f8811ba 100644 --- a/tests/postgres_tests/test_operations.py +++ b/tests/postgres_tests/test_operations.py @@ -3,9 +3,7 @@ from unittest import mock from migrations.test_base import OperationTestBase -from django.db import ( - IntegrityError, NotSupportedError, connection, transaction, -) +from django.db import IntegrityError, NotSupportedError, connection, transaction from django.db.migrations.state import ProjectState from django.db.models import CheckConstraint, Index, Q, UniqueConstraint from django.db.utils import ProgrammingError @@ -17,262 +15,315 @@ from . import PostgreSQLTestCase try: from django.contrib.postgres.indexes import BrinIndex, BTreeIndex from django.contrib.postgres.operations import ( - AddConstraintNotValid, AddIndexConcurrently, BloomExtension, - CreateCollation, CreateExtension, RemoveCollation, - RemoveIndexConcurrently, ValidateConstraint, + AddConstraintNotValid, + AddIndexConcurrently, + BloomExtension, + CreateCollation, + CreateExtension, + RemoveCollation, + RemoveIndexConcurrently, + ValidateConstraint, ) except ImportError: pass -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') -@modify_settings(INSTALLED_APPS={'append': 'migrations'}) +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") +@modify_settings(INSTALLED_APPS={"append": "migrations"}) class AddIndexConcurrentlyTests(OperationTestBase): - app_label = 'test_add_concurrently' + app_label = "test_add_concurrently" def test_requires_atomic_false(self): project_state = self.set_up_test_model(self.app_label) new_state = project_state.clone() operation = AddIndexConcurrently( - 'Pony', - Index(fields=['pink'], name='pony_pink_idx'), + "Pony", + Index(fields=["pink"], name="pony_pink_idx"), ) msg = ( - 'The AddIndexConcurrently operation cannot be executed inside ' - 'a transaction (set atomic = False on the migration).' + "The AddIndexConcurrently operation cannot be executed inside " + "a transaction (set atomic = False on the migration)." ) with self.assertRaisesMessage(NotSupportedError, msg): with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) def test_add(self): project_state = self.set_up_test_model(self.app_label, index=False) - table_name = '%s_pony' % self.app_label - index = Index(fields=['pink'], name='pony_pink_idx') + table_name = "%s_pony" % self.app_label + index = Index(fields=["pink"], name="pony_pink_idx") new_state = project_state.clone() - operation = AddIndexConcurrently('Pony', index) + operation = AddIndexConcurrently("Pony", index) self.assertEqual( operation.describe(), - 'Concurrently create index pony_pink_idx on field(s) pink of model Pony', + "Concurrently create index pony_pink_idx on field(s) pink of model Pony", ) operation.state_forwards(self.app_label, new_state) - self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 1) - self.assertIndexNotExists(table_name, ['pink']) + self.assertEqual( + len(new_state.models[self.app_label, "pony"].options["indexes"]), 1 + ) + self.assertIndexNotExists(table_name, ["pink"]) # Add index. with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) - self.assertIndexExists(table_name, ['pink']) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) + self.assertIndexExists(table_name, ["pink"]) # Reversal. with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) - self.assertIndexNotExists(table_name, ['pink']) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) + self.assertIndexNotExists(table_name, ["pink"]) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'AddIndexConcurrently') + self.assertEqual(name, "AddIndexConcurrently") self.assertEqual(args, []) - self.assertEqual(kwargs, {'model_name': 'Pony', 'index': index}) + self.assertEqual(kwargs, {"model_name": "Pony", "index": index}) def test_add_other_index_type(self): project_state = self.set_up_test_model(self.app_label, index=False) - table_name = '%s_pony' % self.app_label + table_name = "%s_pony" % self.app_label new_state = project_state.clone() operation = AddIndexConcurrently( - 'Pony', - BrinIndex(fields=['pink'], name='pony_pink_brin_idx'), + "Pony", + BrinIndex(fields=["pink"], name="pony_pink_brin_idx"), ) - self.assertIndexNotExists(table_name, ['pink']) + self.assertIndexNotExists(table_name, ["pink"]) # Add index. with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) - self.assertIndexExists(table_name, ['pink'], index_type='brin') + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) + self.assertIndexExists(table_name, ["pink"], index_type="brin") # Reversal. with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) - self.assertIndexNotExists(table_name, ['pink']) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) + self.assertIndexNotExists(table_name, ["pink"]) def test_add_with_options(self): project_state = self.set_up_test_model(self.app_label, index=False) - table_name = '%s_pony' % self.app_label + table_name = "%s_pony" % self.app_label new_state = project_state.clone() - index = BTreeIndex(fields=['pink'], name='pony_pink_btree_idx', fillfactor=70) - operation = AddIndexConcurrently('Pony', index) - self.assertIndexNotExists(table_name, ['pink']) + index = BTreeIndex(fields=["pink"], name="pony_pink_btree_idx", fillfactor=70) + operation = AddIndexConcurrently("Pony", index) + self.assertIndexNotExists(table_name, ["pink"]) # Add index. with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) - self.assertIndexExists(table_name, ['pink'], index_type='btree') + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) + self.assertIndexExists(table_name, ["pink"], index_type="btree") # Reversal. with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) - self.assertIndexNotExists(table_name, ['pink']) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) + self.assertIndexNotExists(table_name, ["pink"]) -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') -@modify_settings(INSTALLED_APPS={'append': 'migrations'}) +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") +@modify_settings(INSTALLED_APPS={"append": "migrations"}) class RemoveIndexConcurrentlyTests(OperationTestBase): - app_label = 'test_rm_concurrently' + app_label = "test_rm_concurrently" def test_requires_atomic_false(self): project_state = self.set_up_test_model(self.app_label, index=True) new_state = project_state.clone() - operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx') + operation = RemoveIndexConcurrently("Pony", "pony_pink_idx") msg = ( - 'The RemoveIndexConcurrently operation cannot be executed inside ' - 'a transaction (set atomic = False on the migration).' + "The RemoveIndexConcurrently operation cannot be executed inside " + "a transaction (set atomic = False on the migration)." ) with self.assertRaisesMessage(NotSupportedError, msg): with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) def test_remove(self): project_state = self.set_up_test_model(self.app_label, index=True) - table_name = '%s_pony' % self.app_label + table_name = "%s_pony" % self.app_label self.assertTableExists(table_name) new_state = project_state.clone() - operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx') + operation = RemoveIndexConcurrently("Pony", "pony_pink_idx") self.assertEqual( operation.describe(), - 'Concurrently remove index pony_pink_idx from Pony', + "Concurrently remove index pony_pink_idx from Pony", ) operation.state_forwards(self.app_label, new_state) - self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 0) - self.assertIndexExists(table_name, ['pink']) + self.assertEqual( + len(new_state.models[self.app_label, "pony"].options["indexes"]), 0 + ) + self.assertIndexExists(table_name, ["pink"]) # Remove index. with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) - self.assertIndexNotExists(table_name, ['pink']) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) + self.assertIndexNotExists(table_name, ["pink"]) # Reversal. with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) - self.assertIndexExists(table_name, ['pink']) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) + self.assertIndexExists(table_name, ["pink"]) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'RemoveIndexConcurrently') + self.assertEqual(name, "RemoveIndexConcurrently") self.assertEqual(args, []) - self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'}) + self.assertEqual(kwargs, {"model_name": "Pony", "name": "pony_pink_idx"}) -class NoMigrationRouter(): +class NoMigrationRouter: def allow_migrate(self, db, app_label, **hints): return False -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") class CreateExtensionTests(PostgreSQLTestCase): - app_label = 'test_allow_create_extention' + app_label = "test_allow_create_extention" @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()]) def test_no_allow_migrate(self): - operation = CreateExtension('tablefunc') + operation = CreateExtension("tablefunc") project_state = ProjectState() new_state = project_state.clone() # Don't create an extension. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 0) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 0) def test_allow_migrate(self): - operation = CreateExtension('tablefunc') - self.assertEqual(operation.migration_name_fragment, 'create_extension_tablefunc') + operation = CreateExtension("tablefunc") + self.assertEqual( + operation.migration_name_fragment, "create_extension_tablefunc" + ) project_state = ProjectState() new_state = project_state.clone() # Create an extension. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 4) - self.assertIn('CREATE EXTENSION IF NOT EXISTS', captured_queries[1]['sql']) + self.assertIn("CREATE EXTENSION IF NOT EXISTS", captured_queries[1]["sql"]) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 2) - self.assertIn('DROP EXTENSION IF EXISTS', captured_queries[1]['sql']) + self.assertIn("DROP EXTENSION IF EXISTS", captured_queries[1]["sql"]) def test_create_existing_extension(self): operation = BloomExtension() - self.assertEqual(operation.migration_name_fragment, 'create_extension_bloom') + self.assertEqual(operation.migration_name_fragment, "create_extension_bloom") project_state = ProjectState() new_state = project_state.clone() # Don't create an existing extension. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 3) - self.assertIn('SELECT', captured_queries[0]['sql']) + self.assertIn("SELECT", captured_queries[0]["sql"]) def test_drop_nonexistent_extension(self): - operation = CreateExtension('tablefunc') + operation = CreateExtension("tablefunc") project_state = ProjectState() new_state = project_state.clone() # Don't drop a nonexistent extension. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, project_state, new_state) + operation.database_backwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('SELECT', captured_queries[0]['sql']) + self.assertIn("SELECT", captured_queries[0]["sql"]) -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") class CreateCollationTests(PostgreSQLTestCase): - app_label = 'test_allow_create_collation' + app_label = "test_allow_create_collation" @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()]) def test_no_allow_migrate(self): - operation = CreateCollation('C_test', locale='C') + operation = CreateCollation("C_test", locale="C") project_state = ProjectState() new_state = project_state.clone() # Don't create a collation. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 0) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 0) def test_create(self): - operation = CreateCollation('C_test', locale='C') - self.assertEqual(operation.migration_name_fragment, 'create_collation_c_test') - self.assertEqual(operation.describe(), 'Create collation C_test') + operation = CreateCollation("C_test", locale="C") + self.assertEqual(operation.migration_name_fragment, "create_collation_c_test") + self.assertEqual(operation.describe(), "Create collation C_test") project_state = ProjectState() new_state = project_state.clone() # Create a collation. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('CREATE COLLATION', captured_queries[0]['sql']) + self.assertIn("CREATE COLLATION", captured_queries[0]["sql"]) # Creating the same collation raises an exception. - with self.assertRaisesMessage(ProgrammingError, 'already exists'): + with self.assertRaisesMessage(ProgrammingError, "already exists"): with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('DROP COLLATION', captured_queries[0]['sql']) + self.assertIn("DROP COLLATION", captured_queries[0]["sql"]) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'CreateCollation') + self.assertEqual(name, "CreateCollation") self.assertEqual(args, []) - self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'}) + self.assertEqual(kwargs, {"name": "C_test", "locale": "C"}) - @skipUnlessDBFeature('supports_non_deterministic_collations') + @skipUnlessDBFeature("supports_non_deterministic_collations") def test_create_non_deterministic_collation(self): operation = CreateCollation( - 'case_insensitive_test', - 'und-u-ks-level2', - provider='icu', + "case_insensitive_test", + "und-u-ks-level2", + provider="icu", deterministic=False, ) project_state = ProjectState() @@ -280,216 +331,253 @@ class CreateCollationTests(PostgreSQLTestCase): # Create a collation. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('CREATE COLLATION', captured_queries[0]['sql']) + self.assertIn("CREATE COLLATION", captured_queries[0]["sql"]) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('DROP COLLATION', captured_queries[0]['sql']) + self.assertIn("DROP COLLATION", captured_queries[0]["sql"]) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'CreateCollation') + self.assertEqual(name, "CreateCollation") self.assertEqual(args, []) - self.assertEqual(kwargs, { - 'name': 'case_insensitive_test', - 'locale': 'und-u-ks-level2', - 'provider': 'icu', - 'deterministic': False, - }) + self.assertEqual( + kwargs, + { + "name": "case_insensitive_test", + "locale": "und-u-ks-level2", + "provider": "icu", + "deterministic": False, + }, + ) def test_create_collation_alternate_provider(self): operation = CreateCollation( - 'german_phonebook_test', - provider='icu', - locale='de-u-co-phonebk', + "german_phonebook_test", + provider="icu", + locale="de-u-co-phonebk", ) project_state = ProjectState() new_state = project_state.clone() # Create an collation. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('CREATE COLLATION', captured_queries[0]['sql']) + self.assertIn("CREATE COLLATION", captured_queries[0]["sql"]) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('DROP COLLATION', captured_queries[0]['sql']) + self.assertIn("DROP COLLATION", captured_queries[0]["sql"]) def test_nondeterministic_collation_not_supported(self): operation = CreateCollation( - 'case_insensitive_test', - provider='icu', - locale='und-u-ks-level2', + "case_insensitive_test", + provider="icu", + locale="und-u-ks-level2", deterministic=False, ) project_state = ProjectState() new_state = project_state.clone() - msg = 'Non-deterministic collations require PostgreSQL 12+.' + msg = "Non-deterministic collations require PostgreSQL 12+." with connection.schema_editor(atomic=False) as editor: with mock.patch( - 'django.db.backends.postgresql.features.DatabaseFeatures.' - 'supports_non_deterministic_collations', + "django.db.backends.postgresql.features.DatabaseFeatures." + "supports_non_deterministic_collations", False, ): with self.assertRaisesMessage(NotSupportedError, msg): - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") class RemoveCollationTests(PostgreSQLTestCase): - app_label = 'test_allow_remove_collation' + app_label = "test_allow_remove_collation" @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()]) def test_no_allow_migrate(self): - operation = RemoveCollation('C_test', locale='C') + operation = RemoveCollation("C_test", locale="C") project_state = ProjectState() new_state = project_state.clone() # Don't create a collation. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 0) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 0) def test_remove(self): - operation = CreateCollation('C_test', locale='C') + operation = CreateCollation("C_test", locale="C") project_state = ProjectState() new_state = project_state.clone() with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) - operation = RemoveCollation('C_test', locale='C') - self.assertEqual(operation.migration_name_fragment, 'remove_collation_c_test') - self.assertEqual(operation.describe(), 'Remove collation C_test') + operation = RemoveCollation("C_test", locale="C") + self.assertEqual(operation.migration_name_fragment, "remove_collation_c_test") + self.assertEqual(operation.describe(), "Remove collation C_test") project_state = ProjectState() new_state = project_state.clone() # Remove a collation. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('DROP COLLATION', captured_queries[0]['sql']) + self.assertIn("DROP COLLATION", captured_queries[0]["sql"]) # Removing a nonexistent collation raises an exception. - with self.assertRaisesMessage(ProgrammingError, 'does not exist'): + with self.assertRaisesMessage(ProgrammingError, "does not exist"): with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) # Reversal. with CaptureQueriesContext(connection) as captured_queries: with connection.schema_editor(atomic=False) as editor: - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) self.assertEqual(len(captured_queries), 1) - self.assertIn('CREATE COLLATION', captured_queries[0]['sql']) + self.assertIn("CREATE COLLATION", captured_queries[0]["sql"]) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'RemoveCollation') + self.assertEqual(name, "RemoveCollation") self.assertEqual(args, []) - self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'}) + self.assertEqual(kwargs, {"name": "C_test", "locale": "C"}) -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') -@modify_settings(INSTALLED_APPS={'append': 'migrations'}) +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") +@modify_settings(INSTALLED_APPS={"append": "migrations"}) class AddConstraintNotValidTests(OperationTestBase): - app_label = 'test_add_constraint_not_valid' + app_label = "test_add_constraint_not_valid" def test_non_check_constraint_not_supported(self): - constraint = UniqueConstraint(fields=['pink'], name='pony_pink_uniq') - msg = 'AddConstraintNotValid.constraint must be a check constraint.' + constraint = UniqueConstraint(fields=["pink"], name="pony_pink_uniq") + msg = "AddConstraintNotValid.constraint must be a check constraint." with self.assertRaisesMessage(TypeError, msg): - AddConstraintNotValid(model_name='pony', constraint=constraint) + AddConstraintNotValid(model_name="pony", constraint=constraint) def test_add(self): - table_name = f'{self.app_label}_pony' - constraint_name = 'pony_pink_gte_check' + table_name = f"{self.app_label}_pony" + constraint_name = "pony_pink_gte_check" constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name) - operation = AddConstraintNotValid('Pony', constraint=constraint) + operation = AddConstraintNotValid("Pony", constraint=constraint) project_state, new_state = self.make_test_state(self.app_label, operation) self.assertEqual( operation.describe(), - f'Create not valid constraint {constraint_name} on model Pony', + f"Create not valid constraint {constraint_name} on model Pony", ) self.assertEqual( operation.migration_name_fragment, - f'pony_{constraint_name}_not_valid', + f"pony_{constraint_name}_not_valid", ) self.assertEqual( - len(new_state.models[self.app_label, 'pony'].options['constraints']), + len(new_state.models[self.app_label, "pony"].options["constraints"]), 1, ) self.assertConstraintNotExists(table_name, constraint_name) - Pony = new_state.apps.get_model(self.app_label, 'Pony') + Pony = new_state.apps.get_model(self.app_label, "Pony") self.assertEqual(len(Pony._meta.constraints), 1) Pony.objects.create(pink=2, weight=1.0) # Add constraint. with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) msg = f'check constraint "{constraint_name}"' with self.assertRaisesMessage(IntegrityError, msg), transaction.atomic(): Pony.objects.create(pink=3, weight=1.0) self.assertConstraintExists(table_name, constraint_name) # Reversal. with connection.schema_editor(atomic=True) as editor: - operation.database_backwards(self.app_label, editor, project_state, new_state) + operation.database_backwards( + self.app_label, editor, project_state, new_state + ) self.assertConstraintNotExists(table_name, constraint_name) Pony.objects.create(pink=3, weight=1.0) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'AddConstraintNotValid') + self.assertEqual(name, "AddConstraintNotValid") self.assertEqual(args, []) - self.assertEqual(kwargs, {'model_name': 'Pony', 'constraint': constraint}) + self.assertEqual(kwargs, {"model_name": "Pony", "constraint": constraint}) -@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') -@modify_settings(INSTALLED_APPS={'append': 'migrations'}) +@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") +@modify_settings(INSTALLED_APPS={"append": "migrations"}) class ValidateConstraintTests(OperationTestBase): - app_label = 'test_validate_constraint' + app_label = "test_validate_constraint" def test_validate(self): - constraint_name = 'pony_pink_gte_check' + constraint_name = "pony_pink_gte_check" constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name) - operation = AddConstraintNotValid('Pony', constraint=constraint) + operation = AddConstraintNotValid("Pony", constraint=constraint) project_state, new_state = self.make_test_state(self.app_label, operation) - Pony = new_state.apps.get_model(self.app_label, 'Pony') + Pony = new_state.apps.get_model(self.app_label, "Pony") obj = Pony.objects.create(pink=2, weight=1.0) # Add constraint. with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) project_state = new_state new_state = new_state.clone() - operation = ValidateConstraint('Pony', name=constraint_name) + operation = ValidateConstraint("Pony", name=constraint_name) operation.state_forwards(self.app_label, new_state) self.assertEqual( operation.describe(), - f'Validate constraint {constraint_name} on model Pony', + f"Validate constraint {constraint_name} on model Pony", ) self.assertEqual( operation.migration_name_fragment, - f'pony_validate_{constraint_name}', + f"pony_validate_{constraint_name}", ) # Validate constraint. with connection.schema_editor(atomic=True) as editor: msg = f'check constraint "{constraint_name}"' with self.assertRaisesMessage(IntegrityError, msg): - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) obj.pink = 5 obj.save() with connection.schema_editor(atomic=True) as editor: - operation.database_forwards(self.app_label, editor, project_state, new_state) + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) # Reversal is a noop. with connection.schema_editor() as editor: with self.assertNumQueries(0): - operation.database_backwards(self.app_label, editor, new_state, project_state) + operation.database_backwards( + self.app_label, editor, new_state, project_state + ) # Deconstruction. name, args, kwargs = operation.deconstruct() - self.assertEqual(name, 'ValidateConstraint') + self.assertEqual(name, "ValidateConstraint") self.assertEqual(args, []) - self.assertEqual(kwargs, {'model_name': 'Pony', 'name': constraint_name}) + self.assertEqual(kwargs, {"model_name": "Pony", "name": constraint_name}) diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py index c2f9d443dd..7563b4ffff 100644 --- a/tests/postgres_tests/test_ranges.py +++ b/tests/postgres_tests/test_ranges.py @@ -12,38 +12,43 @@ from django.utils import timezone from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase from .models import ( - BigAutoFieldModel, PostgreSQLModel, RangeLookupsModel, RangesModel, + BigAutoFieldModel, + PostgreSQLModel, + RangeLookupsModel, + RangesModel, SmallAutoFieldModel, ) try: from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange - from django.contrib.postgres import fields as pg_fields, forms as pg_forms + from django.contrib.postgres import fields as pg_fields + from django.contrib.postgres import forms as pg_forms from django.contrib.postgres.validators import ( - RangeMaxValueValidator, RangeMinValueValidator, + RangeMaxValueValidator, + RangeMinValueValidator, ) except ImportError: pass -@isolate_apps('postgres_tests') +@isolate_apps("postgres_tests") class BasicTests(PostgreSQLSimpleTestCase): def test_get_field_display(self): class Model(PostgreSQLModel): field = pg_fields.IntegerRangeField( choices=[ - ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]], - ((51, 100), '51-100'), + ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]], + ((51, 100), "51-100"), ], ) tests = ( - ((1, 25), '1-25'), - ([26, 50], '26-50'), - ((51, 100), '51-100'), - ((1, 2), '(1, 2)'), - ([1, 2], '[1, 2]'), + ((1, 25), "1-25"), + ([26, 50], "26-50"), + ((51, 100), "51-100"), + ((1, 2), "(1, 2)"), + ([1, 2], "[1, 2]"), ) for value, display in tests: with self.subTest(value=value, display=display): @@ -59,7 +64,7 @@ class BasicTests(PostgreSQLSimpleTestCase): for field_type in discrete_range_types: msg = f"Cannot use 'default_bounds' with {field_type.__name__}." with self.assertRaisesMessage(TypeError, msg): - field_type(choices=[((51, 100), '51-100')], default_bounds='[]') + field_type(choices=[((51, 100), "51-100")], default_bounds="[]") def test_continuous_range_fields_default_bounds(self): continuous_range_types = [ @@ -67,11 +72,11 @@ class BasicTests(PostgreSQLSimpleTestCase): pg_fields.DateTimeRangeField, ] for field_type in continuous_range_types: - field = field_type(choices=[((51, 100), '51-100')], default_bounds='[]') - self.assertEqual(field.default_bounds, '[]') + field = field_type(choices=[((51, 100), "51-100")], default_bounds="[]") + self.assertEqual(field.default_bounds, "[]") def test_invalid_default_bounds(self): - tests = [')]', ')[', '](', '])', '([', '[(', 'x', '', None] + tests = [")]", ")[", "](", "])", "([", "[(", "x", "", None] msg = "default_bounds must be one of '[)', '(]', '()', or '[]'." for invalid_bounds in tests: with self.assertRaisesMessage(ValueError, msg): @@ -81,13 +86,12 @@ class BasicTests(PostgreSQLSimpleTestCase): field = pg_fields.DecimalRangeField() *_, kwargs = field.deconstruct() self.assertEqual(kwargs, {}) - field = pg_fields.DecimalRangeField(default_bounds='[]') + field = pg_fields.DecimalRangeField(default_bounds="[]") *_, kwargs = field.deconstruct() - self.assertEqual(kwargs, {'default_bounds': '[]'}) + self.assertEqual(kwargs, {"default_bounds": "[]"}) class TestSaveLoad(PostgreSQLTestCase): - def test_all_fields(self): now = timezone.now() instance = RangesModel( @@ -124,15 +128,15 @@ class TestSaveLoad(PostgreSQLTestCase): loaded = RangesModel.objects.get() self.assertEqual( loaded.timestamps_closed_bounds, - DateTimeTZRange(range_[0], range_[1], '[]'), + DateTimeTZRange(range_[0], range_[1], "[]"), ) self.assertEqual( loaded.timestamps, - DateTimeTZRange(range_[0], range_[1], '[)'), + DateTimeTZRange(range_[0], range_[1], "[)"), ) def test_range_object_boundaries(self): - r = NumericRange(0, 10, '[]') + r = NumericRange(0, 10, "[]") instance = RangesModel(decimals=r) instance.save() loaded = RangesModel.objects.get() @@ -143,14 +147,14 @@ class TestSaveLoad(PostgreSQLTestCase): range_ = DateTimeTZRange( timezone.now(), timezone.now() + datetime.timedelta(hours=1), - bounds='()', + bounds="()", ) RangesModel.objects.create(timestamps_closed_bounds=range_) loaded = RangesModel.objects.get() self.assertEqual(loaded.timestamps_closed_bounds, range_) def test_unbounded(self): - r = NumericRange(None, None, '()') + r = NumericRange(None, None, "()") instance = RangesModel(decimals=r) instance.save() loaded = RangesModel.objects.get() @@ -171,13 +175,12 @@ class TestSaveLoad(PostgreSQLTestCase): def test_model_set_on_base_field(self): instance = RangesModel() - field = instance._meta.get_field('ints') + field = instance._meta.get_field("ints") self.assertEqual(field.model, RangesModel) self.assertEqual(field.base_field.model, RangesModel) class TestRangeContainsLookup(PostgreSQLTestCase): - @classmethod def setUpTestData(cls): cls.timestamps = [ @@ -189,8 +192,7 @@ class TestRangeContainsLookup(PostgreSQLTestCase): datetime.datetime(year=2016, month=2, day=2), ] cls.aware_timestamps = [ - timezone.make_aware(timestamp) - for timestamp in cls.timestamps + timezone.make_aware(timestamp) for timestamp in cls.timestamps ] cls.dates = [ datetime.date(year=2016, month=1, day=1), @@ -230,13 +232,13 @@ class TestRangeContainsLookup(PostgreSQLTestCase): (self.timestamps[1], self.timestamps[2]), (self.aware_timestamps[1], self.aware_timestamps[2]), Value(self.dates[0]), - Func(F('dates'), function='lower', output_field=DateTimeField()), - F('timestamps_inner'), + Func(F("dates"), function="lower", output_field=DateTimeField()), + F("timestamps_inner"), ) for filter_arg in filter_args: with self.subTest(filter_arg=filter_arg): self.assertCountEqual( - RangesModel.objects.filter(**{'timestamps__contains': filter_arg}), + RangesModel.objects.filter(**{"timestamps__contains": filter_arg}), [self.obj, self.aware_obj], ) @@ -245,28 +247,29 @@ class TestRangeContainsLookup(PostgreSQLTestCase): self.timestamps[1], (self.dates[1], self.dates[2]), Value(self.dates[0], output_field=DateField()), - Func(F('timestamps'), function='lower', output_field=DateField()), - F('dates_inner'), + Func(F("timestamps"), function="lower", output_field=DateField()), + F("dates_inner"), ) for filter_arg in filter_args: with self.subTest(filter_arg=filter_arg): self.assertCountEqual( - RangesModel.objects.filter(**{'dates__contains': filter_arg}), + RangesModel.objects.filter(**{"dates__contains": filter_arg}), [self.obj, self.aware_obj], ) class TestQuerying(PostgreSQLTestCase): - @classmethod def setUpTestData(cls): - cls.objs = RangesModel.objects.bulk_create([ - RangesModel(ints=NumericRange(0, 10)), - RangesModel(ints=NumericRange(5, 15)), - RangesModel(ints=NumericRange(None, 0)), - RangesModel(ints=NumericRange(empty=True)), - RangesModel(ints=None), - ]) + cls.objs = RangesModel.objects.bulk_create( + [ + RangesModel(ints=NumericRange(0, 10)), + RangesModel(ints=NumericRange(5, 15)), + RangesModel(ints=NumericRange(None, 0)), + RangesModel(ints=NumericRange(empty=True)), + RangesModel(ints=None), + ] + ) def test_exact(self): self.assertSequenceEqual( @@ -359,26 +362,28 @@ class TestQuerying(PostgreSQLTestCase): ) def test_bound_type(self): - decimals = RangesModel.objects.bulk_create([ - RangesModel(decimals=NumericRange(None, 10)), - RangesModel(decimals=NumericRange(10, None)), - RangesModel(decimals=NumericRange(5, 15)), - RangesModel(decimals=NumericRange(5, 15, '(]')), - ]) + decimals = RangesModel.objects.bulk_create( + [ + RangesModel(decimals=NumericRange(None, 10)), + RangesModel(decimals=NumericRange(10, None)), + RangesModel(decimals=NumericRange(5, 15)), + RangesModel(decimals=NumericRange(5, 15, "(]")), + ] + ) tests = [ - ('lower_inc', True, [decimals[1], decimals[2]]), - ('lower_inc', False, [decimals[0], decimals[3]]), - ('lower_inf', True, [decimals[0]]), - ('lower_inf', False, [decimals[1], decimals[2], decimals[3]]), - ('upper_inc', True, [decimals[3]]), - ('upper_inc', False, [decimals[0], decimals[1], decimals[2]]), - ('upper_inf', True, [decimals[1]]), - ('upper_inf', False, [decimals[0], decimals[2], decimals[3]]), + ("lower_inc", True, [decimals[1], decimals[2]]), + ("lower_inc", False, [decimals[0], decimals[3]]), + ("lower_inf", True, [decimals[0]]), + ("lower_inf", False, [decimals[1], decimals[2], decimals[3]]), + ("upper_inc", True, [decimals[3]]), + ("upper_inc", False, [decimals[0], decimals[1], decimals[2]]), + ("upper_inf", True, [decimals[1]]), + ("upper_inf", False, [decimals[0], decimals[2], decimals[3]]), ] for lookup, filter_arg, excepted_result in tests: with self.subTest(lookup=lookup, filter_arg=filter_arg): self.assertSequenceEqual( - RangesModel.objects.filter(**{'decimals__%s' % lookup: filter_arg}), + RangesModel.objects.filter(**{"decimals__%s" % lookup: filter_arg}), excepted_result, ) @@ -386,32 +391,38 @@ class TestQuerying(PostgreSQLTestCase): class TestQueryingWithRanges(PostgreSQLTestCase): def test_date_range(self): objs = [ - RangeLookupsModel.objects.create(date='2015-01-01'), - RangeLookupsModel.objects.create(date='2015-05-05'), + RangeLookupsModel.objects.create(date="2015-01-01"), + RangeLookupsModel.objects.create(date="2015-05-05"), ] self.assertSequenceEqual( - RangeLookupsModel.objects.filter(date__contained_by=DateRange('2015-01-01', '2015-05-04')), + RangeLookupsModel.objects.filter( + date__contained_by=DateRange("2015-01-01", "2015-05-04") + ), [objs[0]], ) def test_date_range_datetime_field(self): objs = [ - RangeLookupsModel.objects.create(timestamp='2015-01-01'), - RangeLookupsModel.objects.create(timestamp='2015-05-05'), + RangeLookupsModel.objects.create(timestamp="2015-01-01"), + RangeLookupsModel.objects.create(timestamp="2015-05-05"), ] self.assertSequenceEqual( - RangeLookupsModel.objects.filter(timestamp__date__contained_by=DateRange('2015-01-01', '2015-05-04')), + RangeLookupsModel.objects.filter( + timestamp__date__contained_by=DateRange("2015-01-01", "2015-05-04") + ), [objs[0]], ) def test_datetime_range(self): objs = [ - RangeLookupsModel.objects.create(timestamp='2015-01-01T09:00:00'), - RangeLookupsModel.objects.create(timestamp='2015-05-05T17:00:00'), + RangeLookupsModel.objects.create(timestamp="2015-01-01T09:00:00"), + RangeLookupsModel.objects.create(timestamp="2015-05-05T17:00:00"), ] self.assertSequenceEqual( RangeLookupsModel.objects.filter( - timestamp__contained_by=DateTimeTZRange('2015-01-01T09:00', '2015-05-04T23:55') + timestamp__contained_by=DateTimeTZRange( + "2015-01-01T09:00", "2015-05-04T23:55" + ) ), [objs[0]], ) @@ -423,7 +434,9 @@ class TestQueryingWithRanges(PostgreSQLTestCase): RangeLookupsModel.objects.create(small_integer=-1), ] self.assertSequenceEqual( - RangeLookupsModel.objects.filter(small_integer__contained_by=NumericRange(4, 6)), + RangeLookupsModel.objects.filter( + small_integer__contained_by=NumericRange(4, 6) + ), [objs[1]], ) @@ -435,7 +448,7 @@ class TestQueryingWithRanges(PostgreSQLTestCase): ] self.assertSequenceEqual( RangeLookupsModel.objects.filter(integer__contained_by=NumericRange(1, 98)), - [objs[0]] + [objs[0]], ) def test_biginteger_range(self): @@ -445,19 +458,23 @@ class TestQueryingWithRanges(PostgreSQLTestCase): RangeLookupsModel.objects.create(big_integer=-1), ] self.assertSequenceEqual( - RangeLookupsModel.objects.filter(big_integer__contained_by=NumericRange(1, 98)), - [objs[0]] + RangeLookupsModel.objects.filter( + big_integer__contained_by=NumericRange(1, 98) + ), + [objs[0]], ) def test_decimal_field_contained_by(self): objs = [ - RangeLookupsModel.objects.create(decimal_field=Decimal('1.33')), - RangeLookupsModel.objects.create(decimal_field=Decimal('2.88')), - RangeLookupsModel.objects.create(decimal_field=Decimal('99.17')), + RangeLookupsModel.objects.create(decimal_field=Decimal("1.33")), + RangeLookupsModel.objects.create(decimal_field=Decimal("2.88")), + RangeLookupsModel.objects.create(decimal_field=Decimal("99.17")), ] self.assertSequenceEqual( RangeLookupsModel.objects.filter( - decimal_field__contained_by=NumericRange(Decimal('1.89'), Decimal('7.91')), + decimal_field__contained_by=NumericRange( + Decimal("1.89"), Decimal("7.91") + ), ), [objs[1]], ) @@ -470,13 +487,13 @@ class TestQueryingWithRanges(PostgreSQLTestCase): ] self.assertSequenceEqual( RangeLookupsModel.objects.filter(float__contained_by=NumericRange(1, 98)), - [objs[0]] + [objs[0]], ) def test_small_auto_field_contained_by(self): - objs = SmallAutoFieldModel.objects.bulk_create([ - SmallAutoFieldModel() for i in range(1, 5) - ]) + objs = SmallAutoFieldModel.objects.bulk_create( + [SmallAutoFieldModel() for i in range(1, 5)] + ) self.assertSequenceEqual( SmallAutoFieldModel.objects.filter( id__contained_by=NumericRange(objs[1].pk, objs[3].pk), @@ -485,9 +502,9 @@ class TestQueryingWithRanges(PostgreSQLTestCase): ) def test_auto_field_contained_by(self): - objs = RangeLookupsModel.objects.bulk_create([ - RangeLookupsModel() for i in range(1, 5) - ]) + objs = RangeLookupsModel.objects.bulk_create( + [RangeLookupsModel() for i in range(1, 5)] + ) self.assertSequenceEqual( RangeLookupsModel.objects.filter( id__contained_by=NumericRange(objs[1].pk, objs[3].pk), @@ -496,9 +513,9 @@ class TestQueryingWithRanges(PostgreSQLTestCase): ) def test_big_auto_field_contained_by(self): - objs = BigAutoFieldModel.objects.bulk_create([ - BigAutoFieldModel() for i in range(1, 5) - ]) + objs = BigAutoFieldModel.objects.bulk_create( + [BigAutoFieldModel() for i in range(1, 5)] + ) self.assertSequenceEqual( BigAutoFieldModel.objects.filter( id__contained_by=NumericRange(objs[1].pk, objs[3].pk), @@ -513,8 +530,8 @@ class TestQueryingWithRanges(PostgreSQLTestCase): RangeLookupsModel.objects.create(float=99, parent=parent), ] self.assertSequenceEqual( - RangeLookupsModel.objects.filter(float__contained_by=F('parent__decimals')), - [objs[0]] + RangeLookupsModel.objects.filter(float__contained_by=F("parent__decimals")), + [objs[0]], ) def test_exclude(self): @@ -525,7 +542,7 @@ class TestQueryingWithRanges(PostgreSQLTestCase): ] self.assertSequenceEqual( RangeLookupsModel.objects.exclude(float__contained_by=NumericRange(0, 100)), - [objs[2]] + [objs[2]], ) @@ -550,44 +567,49 @@ class TestSerialization(PostgreSQLSimpleTestCase): def test_dumping(self): instance = RangesModel( - ints=NumericRange(0, 10), decimals=NumericRange(empty=True), + ints=NumericRange(0, 10), + decimals=NumericRange(empty=True), timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt), timestamps_closed_bounds=DateTimeTZRange( - self.lower_dt, self.upper_dt, bounds='()', + self.lower_dt, + self.upper_dt, + bounds="()", ), dates=DateRange(self.lower_date, self.upper_date), ) - data = serializers.serialize('json', [instance]) + data = serializers.serialize("json", [instance]) dumped = json.loads(data) - for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'): - dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field]) + for field in ("ints", "dates", "timestamps", "timestamps_closed_bounds"): + dumped[0]["fields"][field] = json.loads(dumped[0]["fields"][field]) check = json.loads(self.test_data) - for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'): - check[0]['fields'][field] = json.loads(check[0]['fields'][field]) + for field in ("ints", "dates", "timestamps", "timestamps_closed_bounds"): + check[0]["fields"][field] = json.loads(check[0]["fields"][field]) self.assertEqual(dumped, check) def test_loading(self): - instance = list(serializers.deserialize('json', self.test_data))[0].object + instance = list(serializers.deserialize("json", self.test_data))[0].object self.assertEqual(instance.ints, NumericRange(0, 10)) self.assertEqual(instance.decimals, NumericRange(empty=True)) self.assertIsNone(instance.bigints) self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date)) - self.assertEqual(instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt)) + self.assertEqual( + instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt) + ) self.assertEqual( instance.timestamps_closed_bounds, - DateTimeTZRange(self.lower_dt, self.upper_dt, bounds='()'), + DateTimeTZRange(self.lower_dt, self.upper_dt, bounds="()"), ) def test_serialize_range_with_null(self): instance = RangesModel(ints=NumericRange(None, 10)) - data = serializers.serialize('json', [instance]) - new_instance = list(serializers.deserialize('json', data))[0].object + data = serializers.serialize("json", [instance]) + new_instance = list(serializers.deserialize("json", data))[0].object self.assertEqual(new_instance.ints, NumericRange(None, 10)) instance = RangesModel(ints=NumericRange(10, None)) - data = serializers.serialize('json', [instance]) - new_instance = list(serializers.deserialize('json', data))[0].object + data = serializers.serialize("json", [instance]) + new_instance = list(serializers.deserialize("json", data))[0].object self.assertEqual(new_instance.ints, NumericRange(10, None)) @@ -596,60 +618,59 @@ class TestChecks(PostgreSQLSimpleTestCase): class Model(PostgreSQLModel): field = pg_fields.IntegerRangeField( choices=[ - ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]], - ((51, 100), '51-100'), + ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]], + ((51, 100), "51-100"), ], ) - self.assertEqual(Model._meta.get_field('field').check(), []) + self.assertEqual(Model._meta.get_field("field").check(), []) -class TestValidators(PostgreSQLSimpleTestCase): +class TestValidators(PostgreSQLSimpleTestCase): def test_max(self): validator = RangeMaxValueValidator(5) validator(NumericRange(0, 5)) - msg = 'Ensure that this range is completely less than or equal to 5.' + msg = "Ensure that this range is completely less than or equal to 5." with self.assertRaises(exceptions.ValidationError) as cm: validator(NumericRange(0, 10)) self.assertEqual(cm.exception.messages[0], msg) - self.assertEqual(cm.exception.code, 'max_value') + self.assertEqual(cm.exception.code, "max_value") with self.assertRaisesMessage(exceptions.ValidationError, msg): validator(NumericRange(0, None)) # an unbound range def test_min(self): validator = RangeMinValueValidator(5) validator(NumericRange(10, 15)) - msg = 'Ensure that this range is completely greater than or equal to 5.' + msg = "Ensure that this range is completely greater than or equal to 5." with self.assertRaises(exceptions.ValidationError) as cm: validator(NumericRange(0, 10)) self.assertEqual(cm.exception.messages[0], msg) - self.assertEqual(cm.exception.code, 'min_value') + self.assertEqual(cm.exception.code, "min_value") with self.assertRaisesMessage(exceptions.ValidationError, msg): validator(NumericRange(None, 10)) # an unbound range class TestFormField(PostgreSQLSimpleTestCase): - def test_valid_integer(self): field = pg_forms.IntegerRangeField() - value = field.clean(['1', '2']) + value = field.clean(["1", "2"]) self.assertEqual(value, NumericRange(1, 2)) def test_valid_decimal(self): field = pg_forms.DecimalRangeField() - value = field.clean(['1.12345', '2.001']) - self.assertEqual(value, NumericRange(Decimal('1.12345'), Decimal('2.001'))) + value = field.clean(["1.12345", "2.001"]) + self.assertEqual(value, NumericRange(Decimal("1.12345"), Decimal("2.001"))) def test_valid_timestamps(self): field = pg_forms.DateTimeRangeField() - value = field.clean(['01/01/2014 00:00:00', '02/02/2014 12:12:12']) + value = field.clean(["01/01/2014 00:00:00", "02/02/2014 12:12:12"]) lower = datetime.datetime(2014, 1, 1, 0, 0, 0) upper = datetime.datetime(2014, 2, 2, 12, 12, 12) self.assertEqual(value, DateTimeTZRange(lower, upper)) def test_valid_dates(self): field = pg_forms.DateRangeField() - value = field.clean(['01/01/2014', '02/02/2014']) + value = field.clean(["01/01/2014", "02/02/2014"]) lower = datetime.date(2014, 1, 1) upper = datetime.date(2014, 2, 2) self.assertEqual(value, DateRange(lower, upper)) @@ -662,7 +683,9 @@ class TestFormField(PostgreSQLSimpleTestCase): field = SplitDateTimeRangeField() form = SplitForm() - self.assertHTMLEqual(str(form), ''' + self.assertHTMLEqual( + str(form), + """ <tr> <th> <label>Field:</label> @@ -674,21 +697,24 @@ class TestFormField(PostgreSQLSimpleTestCase): <input id="id_field_1_1" name="field_1_1" type="text"> </td> </tr> - ''') - form = SplitForm({ - 'field_0_0': '01/01/2014', - 'field_0_1': '00:00:00', - 'field_1_0': '02/02/2014', - 'field_1_1': '12:12:12', - }) + """, + ) + form = SplitForm( + { + "field_0_0": "01/01/2014", + "field_0_1": "00:00:00", + "field_1_0": "02/02/2014", + "field_1_1": "12:12:12", + } + ) self.assertTrue(form.is_valid()) lower = datetime.datetime(2014, 1, 1, 0, 0, 0) upper = datetime.datetime(2014, 2, 2, 12, 12, 12) - self.assertEqual(form.cleaned_data['field'], DateTimeTZRange(lower, upper)) + self.assertEqual(form.cleaned_data["field"], DateTimeTZRange(lower, upper)) def test_none(self): field = pg_forms.IntegerRangeField(required=False) - value = field.clean(['', '']) + value = field.clean(["", ""]) self.assertIsNone(value) def test_datetime_form_as_table(self): @@ -707,12 +733,14 @@ class TestFormField(PostgreSQLSimpleTestCase): <input type="hidden" name="initial-datetime_field_0" id="initial-id_datetime_field_0"> <input type="hidden" name="initial-datetime_field_1" id="initial-id_datetime_field_1"> </td></tr> - """ + """, + ) + form = DateTimeRangeForm( + { + "datetime_field_0": "2010-01-01 11:13:00", + "datetime_field_1": "2020-12-12 16:59:00", + } ) - form = DateTimeRangeForm({ - 'datetime_field_0': '2010-01-01 11:13:00', - 'datetime_field_1': '2020-12-12 16:59:00', - }) self.assertHTMLEqual( form.as_table(), """ @@ -727,7 +755,7 @@ class TestFormField(PostgreSQLSimpleTestCase): id="initial-id_datetime_field_0"> <input type="hidden" name="initial-datetime_field_1" value="2020-12-12 16:59:00" id="initial-id_datetime_field_1"></td></tr> - """ + """, ) def test_datetime_form_initial_data(self): @@ -735,16 +763,18 @@ class TestFormField(PostgreSQLSimpleTestCase): datetime_field = pg_forms.DateTimeRangeField(show_hidden_initial=True) data = QueryDict(mutable=True) - data.update({ - 'datetime_field_0': '2010-01-01 11:13:00', - 'datetime_field_1': '', - 'initial-datetime_field_0': '2010-01-01 10:12:00', - 'initial-datetime_field_1': '', - }) + data.update( + { + "datetime_field_0": "2010-01-01 11:13:00", + "datetime_field_1": "", + "initial-datetime_field_0": "2010-01-01 10:12:00", + "initial-datetime_field_1": "", + } + ) form = DateTimeRangeForm(data=data) self.assertTrue(form.has_changed()) - data['initial-datetime_field_0'] = '2010-01-01 11:13:00' + data["initial-datetime_field_0"] = "2010-01-01 11:13:00" form = DateTimeRangeForm(data=data) self.assertFalse(form.has_changed()) @@ -752,7 +782,9 @@ class TestFormField(PostgreSQLSimpleTestCase): class RangeForm(forms.Form): ints = pg_forms.IntegerRangeField() - self.assertHTMLEqual(str(RangeForm()), ''' + self.assertHTMLEqual( + str(RangeForm()), + """ <tr> <th><label>Ints:</label></th> <td> @@ -760,195 +792,222 @@ class TestFormField(PostgreSQLSimpleTestCase): <input id="id_ints_1" name="ints_1" type="number"> </td> </tr> - ''') + """, + ) def test_integer_lower_bound_higher(self): field = pg_forms.IntegerRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['10', '2']) - self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.') - self.assertEqual(cm.exception.code, 'bound_ordering') + field.clean(["10", "2"]) + self.assertEqual( + cm.exception.messages[0], + "The start of the range must not exceed the end of the range.", + ) + self.assertEqual(cm.exception.code, "bound_ordering") def test_integer_open(self): field = pg_forms.IntegerRangeField() - value = field.clean(['', '0']) + value = field.clean(["", "0"]) self.assertEqual(value, NumericRange(None, 0)) def test_integer_incorrect_data_type(self): field = pg_forms.IntegerRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('1') - self.assertEqual(cm.exception.messages[0], 'Enter two whole numbers.') - self.assertEqual(cm.exception.code, 'invalid') + field.clean("1") + self.assertEqual(cm.exception.messages[0], "Enter two whole numbers.") + self.assertEqual(cm.exception.code, "invalid") def test_integer_invalid_lower(self): field = pg_forms.IntegerRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['a', '2']) - self.assertEqual(cm.exception.messages[0], 'Enter a whole number.') + field.clean(["a", "2"]) + self.assertEqual(cm.exception.messages[0], "Enter a whole number.") def test_integer_invalid_upper(self): field = pg_forms.IntegerRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['1', 'b']) - self.assertEqual(cm.exception.messages[0], 'Enter a whole number.') + field.clean(["1", "b"]) + self.assertEqual(cm.exception.messages[0], "Enter a whole number.") def test_integer_required(self): field = pg_forms.IntegerRangeField(required=True) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['', '']) - self.assertEqual(cm.exception.messages[0], 'This field is required.') - value = field.clean([1, '']) + field.clean(["", ""]) + self.assertEqual(cm.exception.messages[0], "This field is required.") + value = field.clean([1, ""]) self.assertEqual(value, NumericRange(1, None)) def test_decimal_lower_bound_higher(self): field = pg_forms.DecimalRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['1.8', '1.6']) - self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.') - self.assertEqual(cm.exception.code, 'bound_ordering') + field.clean(["1.8", "1.6"]) + self.assertEqual( + cm.exception.messages[0], + "The start of the range must not exceed the end of the range.", + ) + self.assertEqual(cm.exception.code, "bound_ordering") def test_decimal_open(self): field = pg_forms.DecimalRangeField() - value = field.clean(['', '3.1415926']) - self.assertEqual(value, NumericRange(None, Decimal('3.1415926'))) + value = field.clean(["", "3.1415926"]) + self.assertEqual(value, NumericRange(None, Decimal("3.1415926"))) def test_decimal_incorrect_data_type(self): field = pg_forms.DecimalRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('1.6') - self.assertEqual(cm.exception.messages[0], 'Enter two numbers.') - self.assertEqual(cm.exception.code, 'invalid') + field.clean("1.6") + self.assertEqual(cm.exception.messages[0], "Enter two numbers.") + self.assertEqual(cm.exception.code, "invalid") def test_decimal_invalid_lower(self): field = pg_forms.DecimalRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['a', '3.1415926']) - self.assertEqual(cm.exception.messages[0], 'Enter a number.') + field.clean(["a", "3.1415926"]) + self.assertEqual(cm.exception.messages[0], "Enter a number.") def test_decimal_invalid_upper(self): field = pg_forms.DecimalRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['1.61803399', 'b']) - self.assertEqual(cm.exception.messages[0], 'Enter a number.') + field.clean(["1.61803399", "b"]) + self.assertEqual(cm.exception.messages[0], "Enter a number.") def test_decimal_required(self): field = pg_forms.DecimalRangeField(required=True) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['', '']) - self.assertEqual(cm.exception.messages[0], 'This field is required.') - value = field.clean(['1.61803399', '']) - self.assertEqual(value, NumericRange(Decimal('1.61803399'), None)) + field.clean(["", ""]) + self.assertEqual(cm.exception.messages[0], "This field is required.") + value = field.clean(["1.61803399", ""]) + self.assertEqual(value, NumericRange(Decimal("1.61803399"), None)) def test_date_lower_bound_higher(self): field = pg_forms.DateRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['2013-04-09', '1976-04-16']) - self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.') - self.assertEqual(cm.exception.code, 'bound_ordering') + field.clean(["2013-04-09", "1976-04-16"]) + self.assertEqual( + cm.exception.messages[0], + "The start of the range must not exceed the end of the range.", + ) + self.assertEqual(cm.exception.code, "bound_ordering") def test_date_open(self): field = pg_forms.DateRangeField() - value = field.clean(['', '2013-04-09']) + value = field.clean(["", "2013-04-09"]) self.assertEqual(value, DateRange(None, datetime.date(2013, 4, 9))) def test_date_incorrect_data_type(self): field = pg_forms.DateRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('1') - self.assertEqual(cm.exception.messages[0], 'Enter two valid dates.') - self.assertEqual(cm.exception.code, 'invalid') + field.clean("1") + self.assertEqual(cm.exception.messages[0], "Enter two valid dates.") + self.assertEqual(cm.exception.code, "invalid") def test_date_invalid_lower(self): field = pg_forms.DateRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['a', '2013-04-09']) - self.assertEqual(cm.exception.messages[0], 'Enter a valid date.') + field.clean(["a", "2013-04-09"]) + self.assertEqual(cm.exception.messages[0], "Enter a valid date.") def test_date_invalid_upper(self): field = pg_forms.DateRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['2013-04-09', 'b']) - self.assertEqual(cm.exception.messages[0], 'Enter a valid date.') + field.clean(["2013-04-09", "b"]) + self.assertEqual(cm.exception.messages[0], "Enter a valid date.") def test_date_required(self): field = pg_forms.DateRangeField(required=True) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['', '']) - self.assertEqual(cm.exception.messages[0], 'This field is required.') - value = field.clean(['1976-04-16', '']) + field.clean(["", ""]) + self.assertEqual(cm.exception.messages[0], "This field is required.") + value = field.clean(["1976-04-16", ""]) self.assertEqual(value, DateRange(datetime.date(1976, 4, 16), None)) def test_date_has_changed_first(self): - self.assertTrue(pg_forms.DateRangeField().has_changed( - ['2010-01-01', '2020-12-12'], - ['2010-01-31', '2020-12-12'], - )) + self.assertTrue( + pg_forms.DateRangeField().has_changed( + ["2010-01-01", "2020-12-12"], + ["2010-01-31", "2020-12-12"], + ) + ) def test_date_has_changed_last(self): - self.assertTrue(pg_forms.DateRangeField().has_changed( - ['2010-01-01', '2020-12-12'], - ['2010-01-01', '2020-12-31'], - )) + self.assertTrue( + pg_forms.DateRangeField().has_changed( + ["2010-01-01", "2020-12-12"], + ["2010-01-01", "2020-12-31"], + ) + ) def test_datetime_lower_bound_higher(self): field = pg_forms.DateTimeRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['2006-10-25 14:59', '2006-10-25 14:58']) - self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.') - self.assertEqual(cm.exception.code, 'bound_ordering') + field.clean(["2006-10-25 14:59", "2006-10-25 14:58"]) + self.assertEqual( + cm.exception.messages[0], + "The start of the range must not exceed the end of the range.", + ) + self.assertEqual(cm.exception.code, "bound_ordering") def test_datetime_open(self): field = pg_forms.DateTimeRangeField() - value = field.clean(['', '2013-04-09 11:45']) - self.assertEqual(value, DateTimeTZRange(None, datetime.datetime(2013, 4, 9, 11, 45))) + value = field.clean(["", "2013-04-09 11:45"]) + self.assertEqual( + value, DateTimeTZRange(None, datetime.datetime(2013, 4, 9, 11, 45)) + ) def test_datetime_incorrect_data_type(self): field = pg_forms.DateTimeRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('2013-04-09 11:45') - self.assertEqual(cm.exception.messages[0], 'Enter two valid date/times.') - self.assertEqual(cm.exception.code, 'invalid') + field.clean("2013-04-09 11:45") + self.assertEqual(cm.exception.messages[0], "Enter two valid date/times.") + self.assertEqual(cm.exception.code, "invalid") def test_datetime_invalid_lower(self): field = pg_forms.DateTimeRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['45', '2013-04-09 11:45']) - self.assertEqual(cm.exception.messages[0], 'Enter a valid date/time.') + field.clean(["45", "2013-04-09 11:45"]) + self.assertEqual(cm.exception.messages[0], "Enter a valid date/time.") def test_datetime_invalid_upper(self): field = pg_forms.DateTimeRangeField() with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['2013-04-09 11:45', 'sweet pickles']) - self.assertEqual(cm.exception.messages[0], 'Enter a valid date/time.') + field.clean(["2013-04-09 11:45", "sweet pickles"]) + self.assertEqual(cm.exception.messages[0], "Enter a valid date/time.") def test_datetime_required(self): field = pg_forms.DateTimeRangeField(required=True) with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(['', '']) - self.assertEqual(cm.exception.messages[0], 'This field is required.') - value = field.clean(['2013-04-09 11:45', '']) - self.assertEqual(value, DateTimeTZRange(datetime.datetime(2013, 4, 9, 11, 45), None)) + field.clean(["", ""]) + self.assertEqual(cm.exception.messages[0], "This field is required.") + value = field.clean(["2013-04-09 11:45", ""]) + self.assertEqual( + value, DateTimeTZRange(datetime.datetime(2013, 4, 9, 11, 45), None) + ) - @override_settings(USE_TZ=True, TIME_ZONE='Africa/Johannesburg') + @override_settings(USE_TZ=True, TIME_ZONE="Africa/Johannesburg") def test_datetime_prepare_value(self): field = pg_forms.DateTimeRangeField() value = field.prepare_value( - DateTimeTZRange(datetime.datetime(2015, 5, 22, 16, 6, 33, tzinfo=timezone.utc), None) + DateTimeTZRange( + datetime.datetime(2015, 5, 22, 16, 6, 33, tzinfo=timezone.utc), None + ) ) self.assertEqual(value, [datetime.datetime(2015, 5, 22, 18, 6, 33), None]) def test_datetime_has_changed_first(self): - self.assertTrue(pg_forms.DateTimeRangeField().has_changed( - ['2010-01-01 00:00', '2020-12-12 00:00'], - ['2010-01-31 23:00', '2020-12-12 00:00'], - )) + self.assertTrue( + pg_forms.DateTimeRangeField().has_changed( + ["2010-01-01 00:00", "2020-12-12 00:00"], + ["2010-01-31 23:00", "2020-12-12 00:00"], + ) + ) def test_datetime_has_changed_last(self): - self.assertTrue(pg_forms.DateTimeRangeField().has_changed( - ['2010-01-01 00:00', '2020-12-12 00:00'], - ['2010-01-01 00:00', '2020-12-31 23:00'], - )) + self.assertTrue( + pg_forms.DateTimeRangeField().has_changed( + ["2010-01-01 00:00", "2020-12-12 00:00"], + ["2010-01-01 00:00", "2020-12-31 23:00"], + ) + ) def test_model_field_formfield_integer(self): model_field = pg_fields.IntegerRangeField() @@ -963,10 +1022,10 @@ class TestFormField(PostgreSQLSimpleTestCase): self.assertEqual(form_field.range_kwargs, {}) def test_model_field_formfield_float(self): - model_field = pg_fields.DecimalRangeField(default_bounds='()') + model_field = pg_fields.DecimalRangeField(default_bounds="()") form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.DecimalRangeField) - self.assertEqual(form_field.range_kwargs, {'bounds': '()'}) + self.assertEqual(form_field.range_kwargs, {"bounds": "()"}) def test_model_field_formfield_date(self): model_field = pg_fields.DateRangeField() @@ -980,33 +1039,33 @@ class TestFormField(PostgreSQLSimpleTestCase): self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) self.assertEqual( form_field.range_kwargs, - {'bounds': pg_fields.ranges.CANONICAL_RANGE_BOUNDS}, + {"bounds": pg_fields.ranges.CANONICAL_RANGE_BOUNDS}, ) def test_model_field_formfield_datetime_default_bounds(self): - model_field = pg_fields.DateTimeRangeField(default_bounds='[]') + model_field = pg_fields.DateTimeRangeField(default_bounds="[]") form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) - self.assertEqual(form_field.range_kwargs, {'bounds': '[]'}) + self.assertEqual(form_field.range_kwargs, {"bounds": "[]"}) def test_model_field_with_default_bounds(self): - field = pg_forms.DateTimeRangeField(default_bounds='[]') - value = field.clean(['2014-01-01 00:00:00', '2014-02-03 12:13:14']) + field = pg_forms.DateTimeRangeField(default_bounds="[]") + value = field.clean(["2014-01-01 00:00:00", "2014-02-03 12:13:14"]) lower = datetime.datetime(2014, 1, 1, 0, 0, 0) upper = datetime.datetime(2014, 2, 3, 12, 13, 14) - self.assertEqual(value, DateTimeTZRange(lower, upper, '[]')) + self.assertEqual(value, DateTimeTZRange(lower, upper, "[]")) def test_has_changed(self): for field, value in ( - (pg_forms.DateRangeField(), ['2010-01-01', '2020-12-12']), - (pg_forms.DateTimeRangeField(), ['2010-01-01 11:13', '2020-12-12 14:52']), + (pg_forms.DateRangeField(), ["2010-01-01", "2020-12-12"]), + (pg_forms.DateTimeRangeField(), ["2010-01-01 11:13", "2020-12-12 14:52"]), (pg_forms.IntegerRangeField(), [1, 2]), - (pg_forms.DecimalRangeField(), ['1.12345', '2.001']), + (pg_forms.DecimalRangeField(), ["1.12345", "2.001"]), ): with self.subTest(field=field.__class__.__name__): self.assertTrue(field.has_changed(None, value)) - self.assertTrue(field.has_changed([value[0], ''], value)) - self.assertTrue(field.has_changed(['', value[1]], value)) + self.assertTrue(field.has_changed([value[0], ""], value)) + self.assertTrue(field.has_changed(["", value[1]], value)) self.assertFalse(field.has_changed(value, value)) @@ -1014,19 +1073,18 @@ class TestWidget(PostgreSQLSimpleTestCase): def test_range_widget(self): f = pg_forms.ranges.DateTimeRangeField() self.assertHTMLEqual( - f.widget.render('datetimerange', ''), - '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">' + f.widget.render("datetimerange", ""), + '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">', ) self.assertHTMLEqual( - f.widget.render('datetimerange', None), - '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">' + f.widget.render("datetimerange", None), + '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">', ) dt_range = DateTimeTZRange( - datetime.datetime(2006, 1, 10, 7, 30), - datetime.datetime(2006, 2, 12, 9, 50) + datetime.datetime(2006, 1, 10, 7, 30), datetime.datetime(2006, 2, 12, 9, 50) ) self.assertHTMLEqual( - f.widget.render('datetimerange', dt_range), + f.widget.render("datetimerange", dt_range), '<input type="text" name="datetimerange_0" value="2006-01-10 07:30:00">' - '<input type="text" name="datetimerange_1" value="2006-02-12 09:50:00">' + '<input type="text" name="datetimerange_1" value="2006-02-12 09:50:00">', ) diff --git a/tests/postgres_tests/test_search.py b/tests/postgres_tests/test_search.py index bf6f53ddb2..670b0aaaa9 100644 --- a/tests/postgres_tests/test_search.py +++ b/tests/postgres_tests/test_search.py @@ -14,108 +14,125 @@ from .models import Character, Line, LineSavedSearch, Scene try: from django.contrib.postgres.search import ( - SearchConfig, SearchHeadline, SearchQuery, SearchRank, SearchVector, + SearchConfig, + SearchHeadline, + SearchQuery, + SearchRank, + SearchVector, ) except ImportError: pass class GrailTestData: - @classmethod def setUpTestData(cls): - cls.robin = Scene.objects.create(scene='Scene 10', setting='The dark forest of Ewing') - cls.minstrel = Character.objects.create(name='Minstrel') + cls.robin = Scene.objects.create( + scene="Scene 10", setting="The dark forest of Ewing" + ) + cls.minstrel = Character.objects.create(name="Minstrel") verses = [ ( - 'Bravely bold Sir Robin, rode forth from Camelot. ' - 'He was not afraid to die, o Brave Sir Robin. ' - 'He was not at all afraid to be killed in nasty ways. ' - 'Brave, brave, brave, brave Sir Robin' + "Bravely bold Sir Robin, rode forth from Camelot. " + "He was not afraid to die, o Brave Sir Robin. " + "He was not at all afraid to be killed in nasty ways. " + "Brave, brave, brave, brave Sir Robin" ), ( - 'He was not in the least bit scared to be mashed into a pulp, ' - 'Or to have his eyes gouged out, and his elbows broken. ' - 'To have his kneecaps split, and his body burned away, ' - 'And his limbs all hacked and mangled, brave Sir Robin!' + "He was not in the least bit scared to be mashed into a pulp, " + "Or to have his eyes gouged out, and his elbows broken. " + "To have his kneecaps split, and his body burned away, " + "And his limbs all hacked and mangled, brave Sir Robin!" ), ( - 'His head smashed in and his heart cut out, ' - 'And his liver removed and his bowels unplugged, ' - 'And his nostrils ripped and his bottom burned off,' - 'And his --' + "His head smashed in and his heart cut out, " + "And his liver removed and his bowels unplugged, " + "And his nostrils ripped and his bottom burned off," + "And his --" ), ] - cls.verses = [Line.objects.create( - scene=cls.robin, - character=cls.minstrel, - dialogue=verse, - ) for verse in verses] + cls.verses = [ + Line.objects.create( + scene=cls.robin, + character=cls.minstrel, + dialogue=verse, + ) + for verse in verses + ] cls.verse0, cls.verse1, cls.verse2 = cls.verses - cls.witch_scene = Scene.objects.create(scene='Scene 5', setting="Sir Bedemir's Castle") - bedemir = Character.objects.create(name='Bedemir') - crowd = Character.objects.create(name='Crowd') - witch = Character.objects.create(name='Witch') - duck = Character.objects.create(name='Duck') + cls.witch_scene = Scene.objects.create( + scene="Scene 5", setting="Sir Bedemir's Castle" + ) + bedemir = Character.objects.create(name="Bedemir") + crowd = Character.objects.create(name="Crowd") + witch = Character.objects.create(name="Witch") + duck = Character.objects.create(name="Duck") cls.bedemir0 = Line.objects.create( scene=cls.witch_scene, character=bedemir, - dialogue='We shall use my larger scales!', - dialogue_config='english', + dialogue="We shall use my larger scales!", + dialogue_config="english", ) cls.bedemir1 = Line.objects.create( scene=cls.witch_scene, character=bedemir, - dialogue='Right, remove the supports!', - dialogue_config='english', + dialogue="Right, remove the supports!", + dialogue_config="english", + ) + cls.duck = Line.objects.create( + scene=cls.witch_scene, character=duck, dialogue=None + ) + cls.crowd = Line.objects.create( + scene=cls.witch_scene, character=crowd, dialogue="A witch! A witch!" + ) + cls.witch = Line.objects.create( + scene=cls.witch_scene, character=witch, dialogue="It's a fair cop." ) - cls.duck = Line.objects.create(scene=cls.witch_scene, character=duck, dialogue=None) - cls.crowd = Line.objects.create(scene=cls.witch_scene, character=crowd, dialogue='A witch! A witch!') - cls.witch = Line.objects.create(scene=cls.witch_scene, character=witch, dialogue="It's a fair cop.") - trojan_rabbit = Scene.objects.create(scene='Scene 8', setting="The castle of Our Master Ruiz' de lu la Ramper") - guards = Character.objects.create(name='French Guards') + trojan_rabbit = Scene.objects.create( + scene="Scene 8", setting="The castle of Our Master Ruiz' de lu la Ramper" + ) + guards = Character.objects.create(name="French Guards") cls.french = Line.objects.create( scene=trojan_rabbit, character=guards, - dialogue='Oh. Un beau cadeau. Oui oui.', - dialogue_config='french', + dialogue="Oh. Un beau cadeau. Oui oui.", + dialogue_config="french", ) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class SimpleSearchTest(GrailTestData, PostgreSQLTestCase): - def test_simple(self): - searched = Line.objects.filter(dialogue__search='elbows') + searched = Line.objects.filter(dialogue__search="elbows") self.assertSequenceEqual(searched, [self.verse1]) def test_non_exact_match(self): - searched = Line.objects.filter(dialogue__search='hearts') + searched = Line.objects.filter(dialogue__search="hearts") self.assertSequenceEqual(searched, [self.verse2]) def test_search_two_terms(self): - searched = Line.objects.filter(dialogue__search='heart bowel') + searched = Line.objects.filter(dialogue__search="heart bowel") self.assertSequenceEqual(searched, [self.verse2]) def test_search_two_terms_with_partial_match(self): - searched = Line.objects.filter(dialogue__search='Robin killed') + searched = Line.objects.filter(dialogue__search="Robin killed") self.assertSequenceEqual(searched, [self.verse0]) def test_search_query_config(self): searched = Line.objects.filter( - dialogue__search=SearchQuery('nostrils', config='simple'), + dialogue__search=SearchQuery("nostrils", config="simple"), ) self.assertSequenceEqual(searched, [self.verse2]) def test_search_with_F_expression(self): # Non-matching query. - LineSavedSearch.objects.create(line=self.verse1, query='hearts') + LineSavedSearch.objects.create(line=self.verse1, query="hearts") # Matching query. - match = LineSavedSearch.objects.create(line=self.verse1, query='elbows') - for query_expression in [F('query'), SearchQuery(F('query'))]: + match = LineSavedSearch.objects.create(line=self.verse1, query="elbows") + for query_expression in [F("query"), SearchQuery(F("query"))]: with self.subTest(query_expression): searched = LineSavedSearch.objects.filter( line__dialogue__search=query_expression, @@ -123,254 +140,296 @@ class SimpleSearchTest(GrailTestData, PostgreSQLTestCase): self.assertSequenceEqual(searched, [match]) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class SearchVectorFieldTest(GrailTestData, PostgreSQLTestCase): def test_existing_vector(self): - Line.objects.update(dialogue_search_vector=SearchVector('dialogue')) - searched = Line.objects.filter(dialogue_search_vector=SearchQuery('Robin killed')) + Line.objects.update(dialogue_search_vector=SearchVector("dialogue")) + searched = Line.objects.filter( + dialogue_search_vector=SearchQuery("Robin killed") + ) self.assertSequenceEqual(searched, [self.verse0]) def test_existing_vector_config_explicit(self): - Line.objects.update(dialogue_search_vector=SearchVector('dialogue')) - searched = Line.objects.filter(dialogue_search_vector=SearchQuery('cadeaux', config='french')) + Line.objects.update(dialogue_search_vector=SearchVector("dialogue")) + searched = Line.objects.filter( + dialogue_search_vector=SearchQuery("cadeaux", config="french") + ) self.assertSequenceEqual(searched, [self.french]) def test_single_coalesce_expression(self): - searched = Line.objects.annotate(search=SearchVector('dialogue')).filter(search='cadeaux') - self.assertNotIn('COALESCE(COALESCE', str(searched.query)) + searched = Line.objects.annotate(search=SearchVector("dialogue")).filter( + search="cadeaux" + ) + self.assertNotIn("COALESCE(COALESCE", str(searched.query)) class SearchConfigTests(PostgreSQLSimpleTestCase): def test_from_parameter(self): self.assertIsNone(SearchConfig.from_parameter(None)) - self.assertEqual(SearchConfig.from_parameter('foo'), SearchConfig('foo')) - self.assertEqual(SearchConfig.from_parameter(SearchConfig('bar')), SearchConfig('bar')) + self.assertEqual(SearchConfig.from_parameter("foo"), SearchConfig("foo")) + self.assertEqual( + SearchConfig.from_parameter(SearchConfig("bar")), SearchConfig("bar") + ) class MultipleFieldsTest(GrailTestData, PostgreSQLTestCase): - def test_simple_on_dialogue(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search='elbows') + search=SearchVector("scene__setting", "dialogue"), + ).filter(search="elbows") self.assertSequenceEqual(searched, [self.verse1]) def test_simple_on_scene(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search='Forest') + search=SearchVector("scene__setting", "dialogue"), + ).filter(search="Forest") self.assertCountEqual(searched, self.verses) def test_non_exact_match(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search='heart') + search=SearchVector("scene__setting", "dialogue"), + ).filter(search="heart") self.assertSequenceEqual(searched, [self.verse2]) def test_search_two_terms(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search='heart forest') + search=SearchVector("scene__setting", "dialogue"), + ).filter(search="heart forest") self.assertSequenceEqual(searched, [self.verse2]) def test_terms_adjacent(self): searched = Line.objects.annotate( - search=SearchVector('character__name', 'dialogue'), - ).filter(search='minstrel') + search=SearchVector("character__name", "dialogue"), + ).filter(search="minstrel") self.assertCountEqual(searched, self.verses) searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search='minstrelbravely') + search=SearchVector("scene__setting", "dialogue"), + ).filter(search="minstrelbravely") self.assertSequenceEqual(searched, []) def test_search_with_null(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search='bedemir') - self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]) + search=SearchVector("scene__setting", "dialogue"), + ).filter(search="bedemir") + self.assertCountEqual( + searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck] + ) def test_search_with_non_text(self): searched = Line.objects.annotate( - search=SearchVector('id'), + search=SearchVector("id"), ).filter(search=str(self.crowd.id)) self.assertSequenceEqual(searched, [self.crowd]) def test_phrase_search(self): - line_qs = Line.objects.annotate(search=SearchVector('dialogue')) - searched = line_qs.filter(search=SearchQuery('burned body his away', search_type='phrase')) + line_qs = Line.objects.annotate(search=SearchVector("dialogue")) + searched = line_qs.filter( + search=SearchQuery("burned body his away", search_type="phrase") + ) self.assertSequenceEqual(searched, []) - searched = line_qs.filter(search=SearchQuery('his body burned away', search_type='phrase')) + searched = line_qs.filter( + search=SearchQuery("his body burned away", search_type="phrase") + ) self.assertSequenceEqual(searched, [self.verse1]) def test_phrase_search_with_config(self): line_qs = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue', config='french'), + search=SearchVector("scene__setting", "dialogue", config="french"), ) searched = line_qs.filter( - search=SearchQuery('cadeau beau un', search_type='phrase', config='french'), + search=SearchQuery("cadeau beau un", search_type="phrase", config="french"), ) self.assertSequenceEqual(searched, []) searched = line_qs.filter( - search=SearchQuery('un beau cadeau', search_type='phrase', config='french'), + search=SearchQuery("un beau cadeau", search_type="phrase", config="french"), ) self.assertSequenceEqual(searched, [self.french]) def test_raw_search(self): - line_qs = Line.objects.annotate(search=SearchVector('dialogue')) - searched = line_qs.filter(search=SearchQuery('Robin', search_type='raw')) + line_qs = Line.objects.annotate(search=SearchVector("dialogue")) + searched = line_qs.filter(search=SearchQuery("Robin", search_type="raw")) self.assertCountEqual(searched, [self.verse0, self.verse1]) - searched = line_qs.filter(search=SearchQuery("Robin & !'Camelot'", search_type='raw')) + searched = line_qs.filter( + search=SearchQuery("Robin & !'Camelot'", search_type="raw") + ) self.assertSequenceEqual(searched, [self.verse1]) def test_raw_search_with_config(self): - line_qs = Line.objects.annotate(search=SearchVector('dialogue', config='french')) + line_qs = Line.objects.annotate( + search=SearchVector("dialogue", config="french") + ) searched = line_qs.filter( - search=SearchQuery("'cadeaux' & 'beaux'", search_type='raw', config='french'), + search=SearchQuery( + "'cadeaux' & 'beaux'", search_type="raw", config="french" + ), ) self.assertSequenceEqual(searched, [self.french]) - @skipUnlessDBFeature('has_websearch_to_tsquery') + @skipUnlessDBFeature("has_websearch_to_tsquery") def test_web_search(self): - line_qs = Line.objects.annotate(search=SearchVector('dialogue')) + line_qs = Line.objects.annotate(search=SearchVector("dialogue")) searched = line_qs.filter( search=SearchQuery( '"burned body" "split kneecaps"', - search_type='websearch', + search_type="websearch", ), ) self.assertSequenceEqual(searched, []) searched = line_qs.filter( search=SearchQuery( '"body burned" "kneecaps split" -"nostrils"', - search_type='websearch', + search_type="websearch", ), ) self.assertSequenceEqual(searched, [self.verse1]) searched = line_qs.filter( search=SearchQuery( '"Sir Robin" ("kneecaps" OR "Camelot")', - search_type='websearch', + search_type="websearch", ), ) self.assertSequenceEqual(searched, [self.verse0, self.verse1]) - @skipUnlessDBFeature('has_websearch_to_tsquery') + @skipUnlessDBFeature("has_websearch_to_tsquery") def test_web_search_with_config(self): line_qs = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue', config='french'), + search=SearchVector("scene__setting", "dialogue", config="french"), ) searched = line_qs.filter( - search=SearchQuery('cadeau -beau', search_type='websearch', config='french'), + search=SearchQuery( + "cadeau -beau", search_type="websearch", config="french" + ), ) self.assertSequenceEqual(searched, []) searched = line_qs.filter( - search=SearchQuery('beau cadeau', search_type='websearch', config='french'), + search=SearchQuery("beau cadeau", search_type="websearch", config="french"), ) self.assertSequenceEqual(searched, [self.french]) def test_bad_search_type(self): - with self.assertRaisesMessage(ValueError, "Unknown search_type argument 'foo'."): - SearchQuery('kneecaps', search_type='foo') + with self.assertRaisesMessage( + ValueError, "Unknown search_type argument 'foo'." + ): + SearchQuery("kneecaps", search_type="foo") def test_config_query_explicit(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue', config='french'), - ).filter(search=SearchQuery('cadeaux', config='french')) + search=SearchVector("scene__setting", "dialogue", config="french"), + ).filter(search=SearchQuery("cadeaux", config="french")) self.assertSequenceEqual(searched, [self.french]) def test_config_query_implicit(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue', config='french'), - ).filter(search='cadeaux') + search=SearchVector("scene__setting", "dialogue", config="french"), + ).filter(search="cadeaux") self.assertSequenceEqual(searched, [self.french]) def test_config_from_field_explicit(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue', config=F('dialogue_config')), - ).filter(search=SearchQuery('cadeaux', config=F('dialogue_config'))) + search=SearchVector( + "scene__setting", "dialogue", config=F("dialogue_config") + ), + ).filter(search=SearchQuery("cadeaux", config=F("dialogue_config"))) self.assertSequenceEqual(searched, [self.french]) def test_config_from_field_implicit(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue', config=F('dialogue_config')), - ).filter(search='cadeaux') + search=SearchVector( + "scene__setting", "dialogue", config=F("dialogue_config") + ), + ).filter(search="cadeaux") self.assertSequenceEqual(searched, [self.french]) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class TestCombinations(GrailTestData, PostgreSQLTestCase): - def test_vector_add(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting') + SearchVector('character__name'), - ).filter(search='bedemir') - self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]) + search=SearchVector("scene__setting") + SearchVector("character__name"), + ).filter(search="bedemir") + self.assertCountEqual( + searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck] + ) def test_vector_add_multi(self): searched = Line.objects.annotate( search=( - SearchVector('scene__setting') + - SearchVector('character__name') + - SearchVector('dialogue') + SearchVector("scene__setting") + + SearchVector("character__name") + + SearchVector("dialogue") ), - ).filter(search='bedemir') - self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]) + ).filter(search="bedemir") + self.assertCountEqual( + searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck] + ) def test_vector_combined_mismatch(self): msg = ( - 'SearchVector can only be combined with other SearchVector ' - 'instances, got NoneType.' + "SearchVector can only be combined with other SearchVector " + "instances, got NoneType." ) with self.assertRaisesMessage(TypeError, msg): - Line.objects.filter(dialogue__search=None + SearchVector('character__name')) + Line.objects.filter(dialogue__search=None + SearchVector("character__name")) def test_combine_different_vector_configs(self): searched = Line.objects.annotate( search=( - SearchVector('dialogue', config='english') + - SearchVector('dialogue', config='french') + SearchVector("dialogue", config="english") + + SearchVector("dialogue", config="french") ), ).filter( - search=SearchQuery('cadeaux', config='french') | SearchQuery('nostrils') + search=SearchQuery("cadeaux", config="french") | SearchQuery("nostrils") ) self.assertCountEqual(searched, [self.french, self.verse2]) def test_query_and(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search=SearchQuery('bedemir') & SearchQuery('scales')) + search=SearchVector("scene__setting", "dialogue"), + ).filter(search=SearchQuery("bedemir") & SearchQuery("scales")) self.assertSequenceEqual(searched, [self.bedemir0]) def test_query_multiple_and(self): searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search=SearchQuery('bedemir') & SearchQuery('scales') & SearchQuery('nostrils')) + search=SearchVector("scene__setting", "dialogue"), + ).filter( + search=SearchQuery("bedemir") + & SearchQuery("scales") + & SearchQuery("nostrils") + ) self.assertSequenceEqual(searched, []) searched = Line.objects.annotate( - search=SearchVector('scene__setting', 'dialogue'), - ).filter(search=SearchQuery('shall') & SearchQuery('use') & SearchQuery('larger')) + search=SearchVector("scene__setting", "dialogue"), + ).filter( + search=SearchQuery("shall") & SearchQuery("use") & SearchQuery("larger") + ) self.assertSequenceEqual(searched, [self.bedemir0]) def test_query_or(self): - searched = Line.objects.filter(dialogue__search=SearchQuery('kneecaps') | SearchQuery('nostrils')) + searched = Line.objects.filter( + dialogue__search=SearchQuery("kneecaps") | SearchQuery("nostrils") + ) self.assertCountEqual(searched, [self.verse1, self.verse2]) def test_query_multiple_or(self): searched = Line.objects.filter( - dialogue__search=SearchQuery('kneecaps') | SearchQuery('nostrils') | SearchQuery('Sir Robin') + dialogue__search=SearchQuery("kneecaps") + | SearchQuery("nostrils") + | SearchQuery("Sir Robin") ) self.assertCountEqual(searched, [self.verse1, self.verse2, self.verse0]) def test_query_invert(self): - searched = Line.objects.filter(character=self.minstrel, dialogue__search=~SearchQuery('kneecaps')) + searched = Line.objects.filter( + character=self.minstrel, dialogue__search=~SearchQuery("kneecaps") + ) self.assertCountEqual(searched, [self.verse0, self.verse2]) def test_combine_different_configs(self): searched = Line.objects.filter( dialogue__search=( - SearchQuery('cadeau', config='french') | - SearchQuery('nostrils', config='english') + SearchQuery("cadeau", config="french") + | SearchQuery("nostrils", config="english") ) ) self.assertCountEqual(searched, [self.french, self.verse2]) @@ -378,8 +437,8 @@ class TestCombinations(GrailTestData, PostgreSQLTestCase): def test_combined_configs(self): searched = Line.objects.filter( dialogue__search=( - SearchQuery('nostrils', config='simple') & - SearchQuery('bowels', config='simple') + SearchQuery("nostrils", config="simple") + & SearchQuery("bowels", config="simple") ), ) self.assertSequenceEqual(searched, [self.verse2]) @@ -387,63 +446,96 @@ class TestCombinations(GrailTestData, PostgreSQLTestCase): def test_combine_raw_phrase(self): searched = Line.objects.filter( dialogue__search=( - SearchQuery('burn:*', search_type='raw', config='simple') | - SearchQuery('rode forth from Camelot', search_type='phrase') + SearchQuery("burn:*", search_type="raw", config="simple") + | SearchQuery("rode forth from Camelot", search_type="phrase") ) ) self.assertCountEqual(searched, [self.verse0, self.verse1, self.verse2]) def test_query_combined_mismatch(self): msg = ( - 'SearchQuery can only be combined with other SearchQuery ' - 'instances, got NoneType.' + "SearchQuery can only be combined with other SearchQuery " + "instances, got NoneType." ) with self.assertRaisesMessage(TypeError, msg): - Line.objects.filter(dialogue__search=None | SearchQuery('kneecaps')) + Line.objects.filter(dialogue__search=None | SearchQuery("kneecaps")) with self.assertRaisesMessage(TypeError, msg): - Line.objects.filter(dialogue__search=None & SearchQuery('kneecaps')) + Line.objects.filter(dialogue__search=None & SearchQuery("kneecaps")) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase): - def test_ranking(self): - searched = Line.objects.filter(character=self.minstrel).annotate( - rank=SearchRank(SearchVector('dialogue'), SearchQuery('brave sir robin')), - ).order_by('rank') + searched = ( + Line.objects.filter(character=self.minstrel) + .annotate( + rank=SearchRank( + SearchVector("dialogue"), SearchQuery("brave sir robin") + ), + ) + .order_by("rank") + ) self.assertSequenceEqual(searched, [self.verse2, self.verse1, self.verse0]) def test_rank_passing_untyped_args(self): - searched = Line.objects.filter(character=self.minstrel).annotate( - rank=SearchRank('dialogue', 'brave sir robin'), - ).order_by('rank') + searched = ( + Line.objects.filter(character=self.minstrel) + .annotate( + rank=SearchRank("dialogue", "brave sir robin"), + ) + .order_by("rank") + ) self.assertSequenceEqual(searched, [self.verse2, self.verse1, self.verse0]) def test_weights_in_vector(self): - vector = SearchVector('dialogue', weight='A') + SearchVector('character__name', weight='D') - searched = Line.objects.filter(scene=self.witch_scene).annotate( - rank=SearchRank(vector, SearchQuery('witch')), - ).order_by('-rank')[:2] + vector = SearchVector("dialogue", weight="A") + SearchVector( + "character__name", weight="D" + ) + searched = ( + Line.objects.filter(scene=self.witch_scene) + .annotate( + rank=SearchRank(vector, SearchQuery("witch")), + ) + .order_by("-rank")[:2] + ) self.assertSequenceEqual(searched, [self.crowd, self.witch]) - vector = SearchVector('dialogue', weight='D') + SearchVector('character__name', weight='A') - searched = Line.objects.filter(scene=self.witch_scene).annotate( - rank=SearchRank(vector, SearchQuery('witch')), - ).order_by('-rank')[:2] + vector = SearchVector("dialogue", weight="D") + SearchVector( + "character__name", weight="A" + ) + searched = ( + Line.objects.filter(scene=self.witch_scene) + .annotate( + rank=SearchRank(vector, SearchQuery("witch")), + ) + .order_by("-rank")[:2] + ) self.assertSequenceEqual(searched, [self.witch, self.crowd]) def test_ranked_custom_weights(self): - vector = SearchVector('dialogue', weight='D') + SearchVector('character__name', weight='A') - searched = Line.objects.filter(scene=self.witch_scene).annotate( - rank=SearchRank(vector, SearchQuery('witch'), weights=[1, 0, 0, 0.5]), - ).order_by('-rank')[:2] + vector = SearchVector("dialogue", weight="D") + SearchVector( + "character__name", weight="A" + ) + searched = ( + Line.objects.filter(scene=self.witch_scene) + .annotate( + rank=SearchRank(vector, SearchQuery("witch"), weights=[1, 0, 0, 0.5]), + ) + .order_by("-rank")[:2] + ) self.assertSequenceEqual(searched, [self.crowd, self.witch]) def test_ranking_chaining(self): - searched = Line.objects.filter(character=self.minstrel).annotate( - rank=SearchRank(SearchVector('dialogue'), SearchQuery('brave sir robin')), - ).filter(rank__gt=0.3) + searched = ( + Line.objects.filter(character=self.minstrel) + .annotate( + rank=SearchRank( + SearchVector("dialogue"), SearchQuery("brave sir robin") + ), + ) + .filter(rank__gt=0.3) + ) self.assertSequenceEqual(searched, [self.verse0]) def test_cover_density_ranking(self): @@ -451,17 +543,21 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase): scene=self.robin, character=self.minstrel, dialogue=( - 'Bravely taking to his feet, he beat a very brave retreat. ' - 'A brave retreat brave Sir Robin.' + "Bravely taking to his feet, he beat a very brave retreat. " + "A brave retreat brave Sir Robin." + ), + ) + searched = ( + Line.objects.filter(character=self.minstrel) + .annotate( + rank=SearchRank( + SearchVector("dialogue"), + SearchQuery("brave robin"), + cover_density=True, + ), ) + .order_by("rank", "-pk") ) - searched = Line.objects.filter(character=self.minstrel).annotate( - rank=SearchRank( - SearchVector('dialogue'), - SearchQuery('brave robin'), - cover_density=True, - ), - ).order_by('rank', '-pk') self.assertSequenceEqual( searched, [self.verse2, not_dense_verse, self.verse1, self.verse0], @@ -471,16 +567,20 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase): short_verse = Line.objects.create( scene=self.robin, character=self.minstrel, - dialogue='A brave retreat brave Sir Robin.', + dialogue="A brave retreat brave Sir Robin.", + ) + searched = ( + Line.objects.filter(character=self.minstrel) + .annotate( + rank=SearchRank( + SearchVector("dialogue"), + SearchQuery("brave sir robin"), + # Divide the rank by the document length. + normalization=2, + ), + ) + .order_by("rank") ) - searched = Line.objects.filter(character=self.minstrel).annotate( - rank=SearchRank( - SearchVector('dialogue'), - SearchQuery('brave sir robin'), - # Divide the rank by the document length. - normalization=2, - ), - ).order_by('rank') self.assertSequenceEqual( searched, [self.verse2, self.verse1, self.verse0, short_verse], @@ -490,17 +590,21 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase): short_verse = Line.objects.create( scene=self.robin, character=self.minstrel, - dialogue='A brave retreat brave Sir Robin.', + dialogue="A brave retreat brave Sir Robin.", + ) + searched = ( + Line.objects.filter(character=self.minstrel) + .annotate( + rank=SearchRank( + SearchVector("dialogue"), + SearchQuery("brave sir robin"), + # Divide the rank by the document length and by the number of + # unique words in document. + normalization=Value(2).bitor(Value(8)), + ), + ) + .order_by("rank") ) - searched = Line.objects.filter(character=self.minstrel).annotate( - rank=SearchRank( - SearchVector('dialogue'), - SearchQuery('brave sir robin'), - # Divide the rank by the document length and by the number of - # unique words in document. - normalization=Value(2).bitor(Value(8)), - ), - ).order_by('rank') self.assertSequenceEqual( searched, [self.verse2, self.verse1, self.verse0, short_verse], @@ -513,13 +617,16 @@ class SearchVectorIndexTests(PostgreSQLTestCase): # This test should be moved to test_indexes and use a functional # index instead once support lands (see #26167). query = Line.objects.all().query - resolved = SearchVector('id', 'dialogue', config='english').resolve_expression(query) + resolved = SearchVector("id", "dialogue", config="english").resolve_expression( + query + ) compiler = query.get_compiler(connection.alias) sql, params = resolved.as_sql(compiler, connection) # Indexed function must be IMMUTABLE. with connection.cursor() as cursor: cursor.execute( - 'CREATE INDEX search_vector_index ON %s USING GIN (%s)' % (Line._meta.db_table, sql), + "CREATE INDEX search_vector_index ON %s USING GIN (%s)" + % (Line._meta.db_table, sql), params, ) @@ -527,24 +634,26 @@ class SearchVectorIndexTests(PostgreSQLTestCase): class SearchQueryTests(PostgreSQLSimpleTestCase): def test_str(self): tests = ( - (~SearchQuery('a'), "~SearchQuery(Value('a'))"), + (~SearchQuery("a"), "~SearchQuery(Value('a'))"), ( - (SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')), + (SearchQuery("a") | SearchQuery("b")) + & (SearchQuery("c") | SearchQuery("d")), "((SearchQuery(Value('a')) || SearchQuery(Value('b'))) && " "(SearchQuery(Value('c')) || SearchQuery(Value('d'))))", ), ( - SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')), + SearchQuery("a") & (SearchQuery("b") | SearchQuery("c")), "(SearchQuery(Value('a')) && (SearchQuery(Value('b')) || " "SearchQuery(Value('c'))))", ), ( - (SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'), + (SearchQuery("a") | SearchQuery("b")) & SearchQuery("c"), "((SearchQuery(Value('a')) || SearchQuery(Value('b'))) && " - "SearchQuery(Value('c')))" + "SearchQuery(Value('c')))", ), ( - SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))), + SearchQuery("a") + & (SearchQuery("b") & (SearchQuery("c") | SearchQuery("d"))), "(SearchQuery(Value('a')) && (SearchQuery(Value('b')) && " "(SearchQuery(Value('c')) || SearchQuery(Value('d')))))", ), @@ -554,109 +663,112 @@ class SearchQueryTests(PostgreSQLSimpleTestCase): self.assertEqual(str(query), expected_str) -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class SearchHeadlineTests(GrailTestData, PostgreSQLTestCase): def test_headline(self): searched = Line.objects.annotate( headline=SearchHeadline( - F('dialogue'), - SearchQuery('brave sir robin'), - config=SearchConfig('english'), + F("dialogue"), + SearchQuery("brave sir robin"), + config=SearchConfig("english"), ), ).get(pk=self.verse0.pk) self.assertEqual( searched.headline, - '<b>Robin</b>. He was not at all afraid to be killed in nasty ' - 'ways. <b>Brave</b>, <b>brave</b>, <b>brave</b>, <b>brave</b> ' - '<b>Sir</b> <b>Robin</b>', + "<b>Robin</b>. He was not at all afraid to be killed in nasty " + "ways. <b>Brave</b>, <b>brave</b>, <b>brave</b>, <b>brave</b> " + "<b>Sir</b> <b>Robin</b>", ) def test_headline_untyped_args(self): searched = Line.objects.annotate( - headline=SearchHeadline('dialogue', 'killed', config='english'), + headline=SearchHeadline("dialogue", "killed", config="english"), ).get(pk=self.verse0.pk) self.assertEqual( searched.headline, - 'Robin. He was not at all afraid to be <b>killed</b> in nasty ' - 'ways. Brave, brave, brave, brave Sir Robin', + "Robin. He was not at all afraid to be <b>killed</b> in nasty " + "ways. Brave, brave, brave, brave Sir Robin", ) def test_headline_with_config(self): searched = Line.objects.annotate( headline=SearchHeadline( - 'dialogue', - SearchQuery('cadeaux', config='french'), - config='french', + "dialogue", + SearchQuery("cadeaux", config="french"), + config="french", ), ).get(pk=self.french.pk) self.assertEqual( searched.headline, - 'Oh. Un beau <b>cadeau</b>. Oui oui.', + "Oh. Un beau <b>cadeau</b>. Oui oui.", ) def test_headline_with_config_from_field(self): searched = Line.objects.annotate( headline=SearchHeadline( - 'dialogue', - SearchQuery('cadeaux', config=F('dialogue_config')), - config=F('dialogue_config'), + "dialogue", + SearchQuery("cadeaux", config=F("dialogue_config")), + config=F("dialogue_config"), ), ).get(pk=self.french.pk) self.assertEqual( searched.headline, - 'Oh. Un beau <b>cadeau</b>. Oui oui.', + "Oh. Un beau <b>cadeau</b>. Oui oui.", ) def test_headline_separator_options(self): searched = Line.objects.annotate( headline=SearchHeadline( - 'dialogue', - 'brave sir robin', - start_sel='<span>', - stop_sel='</span>', + "dialogue", + "brave sir robin", + start_sel="<span>", + stop_sel="</span>", ), ).get(pk=self.verse0.pk) self.assertEqual( searched.headline, - '<span>Robin</span>. He was not at all afraid to be killed in ' - 'nasty ways. <span>Brave</span>, <span>brave</span>, <span>brave' - '</span>, <span>brave</span> <span>Sir</span> <span>Robin</span>', + "<span>Robin</span>. He was not at all afraid to be killed in " + "nasty ways. <span>Brave</span>, <span>brave</span>, <span>brave" + "</span>, <span>brave</span> <span>Sir</span> <span>Robin</span>", ) def test_headline_highlight_all_option(self): searched = Line.objects.annotate( headline=SearchHeadline( - 'dialogue', - SearchQuery('brave sir robin', config='english'), + "dialogue", + SearchQuery("brave sir robin", config="english"), highlight_all=True, ), ).get(pk=self.verse0.pk) self.assertIn( - '<b>Bravely</b> bold <b>Sir</b> <b>Robin</b>, rode forth from ' - 'Camelot. He was not afraid to die, o ', + "<b>Bravely</b> bold <b>Sir</b> <b>Robin</b>, rode forth from " + "Camelot. He was not afraid to die, o ", searched.headline, ) def test_headline_short_word_option(self): searched = Line.objects.annotate( headline=SearchHeadline( - 'dialogue', - SearchQuery('Camelot', config='english'), + "dialogue", + SearchQuery("Camelot", config="english"), short_word=5, min_words=8, ), ).get(pk=self.verse0.pk) - self.assertEqual(searched.headline, ( - '<b>Camelot</b>. He was not afraid to die, o Brave Sir Robin. He ' - 'was not at all afraid' - )) + self.assertEqual( + searched.headline, + ( + "<b>Camelot</b>. He was not afraid to die, o Brave Sir Robin. He " + "was not at all afraid" + ), + ) def test_headline_fragments_words_options(self): searched = Line.objects.annotate( headline=SearchHeadline( - 'dialogue', - SearchQuery('brave sir robin', config='english'), - fragment_delimiter='...<br>', + "dialogue", + SearchQuery("brave sir robin", config="english"), + fragment_delimiter="...<br>", max_fragments=4, max_words=3, min_words=1, @@ -664,8 +776,8 @@ class SearchHeadlineTests(GrailTestData, PostgreSQLTestCase): ).get(pk=self.verse0.pk) self.assertEqual( searched.headline, - '<b>Sir</b> <b>Robin</b>, rode...<br>' - '<b>Brave</b> <b>Sir</b> <b>Robin</b>...<br>' - '<b>Brave</b>, <b>brave</b>, <b>brave</b>...<br>' - '<b>brave</b> <b>Sir</b> <b>Robin</b>', + "<b>Sir</b> <b>Robin</b>, rode...<br>" + "<b>Brave</b> <b>Sir</b> <b>Robin</b>...<br>" + "<b>Brave</b>, <b>brave</b>, <b>brave</b>...<br>" + "<b>brave</b> <b>Sir</b> <b>Robin</b>", ) diff --git a/tests/postgres_tests/test_signals.py b/tests/postgres_tests/test_signals.py index f1569c361c..764524d8e6 100644 --- a/tests/postgres_tests/test_signals.py +++ b/tests/postgres_tests/test_signals.py @@ -4,14 +4,15 @@ from . import PostgreSQLTestCase try: from django.contrib.postgres.signals import ( - get_citext_oids, get_hstore_oids, register_type_handlers, + get_citext_oids, + get_hstore_oids, + register_type_handlers, ) except ImportError: pass # pyscogp2 isn't installed. class OIDTests(PostgreSQLTestCase): - def assertOIDs(self, oids): self.assertIsInstance(oids, tuple) self.assertGreater(len(oids), 0) diff --git a/tests/postgres_tests/test_trigram.py b/tests/postgres_tests/test_trigram.py index 079a32a19b..a0502e0b9b 100644 --- a/tests/postgres_tests/test_trigram.py +++ b/tests/postgres_tests/test_trigram.py @@ -5,65 +5,74 @@ from .models import CharFieldModel, TextFieldModel try: from django.contrib.postgres.search import ( - TrigramDistance, TrigramSimilarity, TrigramWordDistance, + TrigramDistance, + TrigramSimilarity, + TrigramWordDistance, TrigramWordSimilarity, ) except ImportError: pass -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class TrigramTest(PostgreSQLTestCase): Model = CharFieldModel @classmethod def setUpTestData(cls): - cls.Model.objects.bulk_create([ - cls.Model(field='Matthew'), - cls.Model(field='Cat sat on mat.'), - cls.Model(field='Dog sat on rug.'), - ]) + cls.Model.objects.bulk_create( + [ + cls.Model(field="Matthew"), + cls.Model(field="Cat sat on mat."), + cls.Model(field="Dog sat on rug."), + ] + ) def test_trigram_search(self): self.assertQuerysetEqual( - self.Model.objects.filter(field__trigram_similar='Mathew'), - ['Matthew'], + self.Model.objects.filter(field__trigram_similar="Mathew"), + ["Matthew"], transform=lambda instance: instance.field, ) def test_trigram_word_search(self): obj = self.Model.objects.create( - field='Gumby rides on the path of Middlesbrough', + field="Gumby rides on the path of Middlesbrough", ) self.assertSequenceEqual( - self.Model.objects.filter(field__trigram_word_similar='Middlesborough'), + self.Model.objects.filter(field__trigram_word_similar="Middlesborough"), [obj], ) def test_trigram_similarity(self): - search = 'Bat sat on cat.' + search = "Bat sat on cat." # Round result of similarity because PostgreSQL 12+ uses greater # precision. self.assertQuerysetEqual( self.Model.objects.filter( field__trigram_similar=search, - ).annotate(similarity=TrigramSimilarity('field', search)).order_by('-similarity'), - [('Cat sat on mat.', 0.625), ('Dog sat on rug.', 0.333333)], + ) + .annotate(similarity=TrigramSimilarity("field", search)) + .order_by("-similarity"), + [("Cat sat on mat.", 0.625), ("Dog sat on rug.", 0.333333)], transform=lambda instance: (instance.field, round(instance.similarity, 6)), ordered=True, ) def test_trigram_word_similarity(self): - search = 'mat' + search = "mat" self.assertSequenceEqual( self.Model.objects.filter( field__trigram_word_similar=search, - ).annotate( - word_similarity=TrigramWordSimilarity(search, 'field'), - ).values('field', 'word_similarity').order_by('-word_similarity'), + ) + .annotate( + word_similarity=TrigramWordSimilarity(search, "field"), + ) + .values("field", "word_similarity") + .order_by("-word_similarity"), [ - {'field': 'Cat sat on mat.', 'word_similarity': 1.0}, - {'field': 'Matthew', 'word_similarity': 0.75}, + {"field": "Cat sat on mat.", "word_similarity": 1.0}, + {"field": "Matthew", "word_similarity": 0.75}, ], ) @@ -72,9 +81,11 @@ class TrigramTest(PostgreSQLTestCase): # precision. self.assertQuerysetEqual( self.Model.objects.annotate( - distance=TrigramDistance('field', 'Bat sat on cat.'), - ).filter(distance__lte=0.7).order_by('distance'), - [('Cat sat on mat.', 0.375), ('Dog sat on rug.', 0.666667)], + distance=TrigramDistance("field", "Bat sat on cat."), + ) + .filter(distance__lte=0.7) + .order_by("distance"), + [("Cat sat on mat.", 0.375), ("Dog sat on rug.", 0.666667)], transform=lambda instance: (instance.field, round(instance.distance, 6)), ordered=True, ) @@ -82,13 +93,16 @@ class TrigramTest(PostgreSQLTestCase): def test_trigram_word_similarity_alternate(self): self.assertSequenceEqual( self.Model.objects.annotate( - word_distance=TrigramWordDistance('mat', 'field'), - ).filter( + word_distance=TrigramWordDistance("mat", "field"), + ) + .filter( word_distance__lte=0.7, - ).values('field', 'word_distance').order_by('word_distance'), + ) + .values("field", "word_distance") + .order_by("word_distance"), [ - {'field': 'Cat sat on mat.', 'word_distance': 0}, - {'field': 'Matthew', 'word_distance': 0.25}, + {"field": "Cat sat on mat.", "word_distance": 0}, + {"field": "Matthew", "word_distance": 0.25}, ], ) @@ -97,4 +111,5 @@ class TrigramTextFieldTest(TrigramTest): """ TextField has the same behavior as CharField regarding trigram lookups. """ + Model = TextFieldModel diff --git a/tests/postgres_tests/test_unaccent.py b/tests/postgres_tests/test_unaccent.py index 6d52f1d7dd..4188d90794 100644 --- a/tests/postgres_tests/test_unaccent.py +++ b/tests/postgres_tests/test_unaccent.py @@ -5,25 +5,27 @@ from . import PostgreSQLTestCase from .models import CharFieldModel, TextFieldModel -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) +@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) class UnaccentTest(PostgreSQLTestCase): Model = CharFieldModel @classmethod def setUpTestData(cls): - cls.Model.objects.bulk_create([ - cls.Model(field="àéÖ"), - cls.Model(field="aeO"), - cls.Model(field="aeo"), - ]) + cls.Model.objects.bulk_create( + [ + cls.Model(field="àéÖ"), + cls.Model(field="aeO"), + cls.Model(field="aeo"), + ] + ) def test_unaccent(self): self.assertQuerysetEqual( self.Model.objects.filter(field__unaccent="aeO"), ["àéÖ", "aeO"], transform=lambda instance: instance.field, - ordered=False + ordered=False, ) def test_unaccent_chained(self): @@ -35,39 +37,39 @@ class UnaccentTest(PostgreSQLTestCase): self.Model.objects.filter(field__unaccent__iexact="aeO"), ["àéÖ", "aeO", "aeo"], transform=lambda instance: instance.field, - ordered=False + ordered=False, ) self.assertQuerysetEqual( self.Model.objects.filter(field__unaccent__endswith="éÖ"), ["àéÖ", "aeO"], transform=lambda instance: instance.field, - ordered=False + ordered=False, ) def test_unaccent_with_conforming_strings_off(self): """SQL is valid when standard_conforming_strings is off.""" with connection.cursor() as cursor: - cursor.execute('SHOW standard_conforming_strings') - disable_conforming_strings = cursor.fetchall()[0][0] == 'on' + cursor.execute("SHOW standard_conforming_strings") + disable_conforming_strings = cursor.fetchall()[0][0] == "on" if disable_conforming_strings: - cursor.execute('SET standard_conforming_strings TO off') + cursor.execute("SET standard_conforming_strings TO off") try: self.assertQuerysetEqual( - self.Model.objects.filter(field__unaccent__endswith='éÖ'), - ['àéÖ', 'aeO'], + self.Model.objects.filter(field__unaccent__endswith="éÖ"), + ["àéÖ", "aeO"], transform=lambda instance: instance.field, ordered=False, ) finally: if disable_conforming_strings: - cursor.execute('SET standard_conforming_strings TO on') + cursor.execute("SET standard_conforming_strings TO on") def test_unaccent_accentuated_needle(self): self.assertQuerysetEqual( self.Model.objects.filter(field__unaccent="aéÖ"), ["àéÖ", "aeO"], transform=lambda instance: instance.field, - ordered=False + ordered=False, ) @@ -76,4 +78,5 @@ class UnaccentTextFieldTest(UnaccentTest): TextField should have the exact same behavior as CharField regarding unaccent lookups. """ + Model = TextFieldModel |
