summaryrefslogtreecommitdiff
path: root/tests/postgres_tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests/postgres_tests')
-rw-r--r--tests/postgres_tests/__init__.py8
-rw-r--r--tests/postgres_tests/array_default_migrations/0001_initial.py25
-rw-r--r--tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py10
-rw-r--r--tests/postgres_tests/array_index_migrations/0001_initial.py36
-rw-r--r--tests/postgres_tests/fields.py28
-rw-r--r--tests/postgres_tests/integration_settings.py4
-rw-r--r--tests/postgres_tests/migrations/0001_setup_extensions.py15
-rw-r--r--tests/postgres_tests/migrations/0002_create_test_models.py542
-rw-r--r--tests/postgres_tests/models.py40
-rw-r--r--tests/postgres_tests/test_aggregates.py769
-rw-r--r--tests/postgres_tests/test_apps.py49
-rw-r--r--tests/postgres_tests/test_array.py802
-rw-r--r--tests/postgres_tests/test_bulk_update.py37
-rw-r--r--tests/postgres_tests/test_citext.py59
-rw-r--r--tests/postgres_tests/test_constraints.py802
-rw-r--r--tests/postgres_tests/test_functions.py2
-rw-r--r--tests/postgres_tests/test_hstore.py315
-rw-r--r--tests/postgres_tests/test_indexes.py526
-rw-r--r--tests/postgres_tests/test_integration.py17
-rw-r--r--tests/postgres_tests/test_introspection.py16
-rw-r--r--tests/postgres_tests/test_operations.py434
-rw-r--r--tests/postgres_tests/test_ranges.py534
-rw-r--r--tests/postgres_tests/test_search.py590
-rw-r--r--tests/postgres_tests/test_signals.py5
-rw-r--r--tests/postgres_tests/test_trigram.py71
-rw-r--r--tests/postgres_tests/test_unaccent.py35
26 files changed, 3412 insertions, 2359 deletions
diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py
index 2b84fc25db..6f02531ed0 100644
--- a/tests/postgres_tests/__init__.py
+++ b/tests/postgres_tests/__init__.py
@@ -6,18 +6,18 @@ from django.db import connection
from django.test import SimpleTestCase, TestCase, modify_settings
-@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLSimpleTestCase(SimpleTestCase):
pass
-@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLTestCase(TestCase):
pass
-@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
# To locate the widget's template.
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
pass
diff --git a/tests/postgres_tests/array_default_migrations/0001_initial.py b/tests/postgres_tests/array_default_migrations/0001_initial.py
index eb523218ef..10eaef2aab 100644
--- a/tests/postgres_tests/array_default_migrations/0001_initial.py
+++ b/tests/postgres_tests/array_default_migrations/0001_initial.py
@@ -4,18 +4,29 @@ from django.db import migrations, models
class Migration(migrations.Migration):
- dependencies = [
- ]
+ dependencies = []
operations = [
migrations.CreateModel(
- name='IntegerArrayDefaultModel',
+ name="IntegerArrayDefaultModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "field",
+ django.contrib.postgres.fields.ArrayField(
+ models.IntegerField(), size=None
+ ),
+ ),
],
- options={
- },
+ options={},
bases=(models.Model,),
),
]
diff --git a/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py b/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py
index 679c9bb0d3..b15b575e54 100644
--- a/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py
+++ b/tests/postgres_tests/array_default_migrations/0002_integerarraymodel_field_2.py
@@ -5,14 +5,16 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
- ('postgres_tests', '0001_initial'),
+ ("postgres_tests", "0001_initial"),
]
operations = [
migrations.AddField(
- model_name='integerarraydefaultmodel',
- name='field_2',
- field=django.contrib.postgres.fields.ArrayField(models.IntegerField(), default=[], size=None),
+ model_name="integerarraydefaultmodel",
+ name="field_2",
+ field=django.contrib.postgres.fields.ArrayField(
+ models.IntegerField(), default=[], size=None
+ ),
preserve_default=False,
),
]
diff --git a/tests/postgres_tests/array_index_migrations/0001_initial.py b/tests/postgres_tests/array_index_migrations/0001_initial.py
index 505e53e4e8..5c74be326a 100644
--- a/tests/postgres_tests/array_index_migrations/0001_initial.py
+++ b/tests/postgres_tests/array_index_migrations/0001_initial.py
@@ -4,22 +4,36 @@ from django.db import migrations, models
class Migration(migrations.Migration):
- dependencies = [
- ]
+ dependencies = []
operations = [
migrations.CreateModel(
- name='CharTextArrayIndexModel',
+ name="CharTextArrayIndexModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('char', django.contrib.postgres.fields.ArrayField(
- models.CharField(max_length=10), db_index=True, size=100)
- ),
- ('char2', models.CharField(max_length=11, db_index=True)),
- ('text', django.contrib.postgres.fields.ArrayField(models.TextField(), db_index=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "char",
+ django.contrib.postgres.fields.ArrayField(
+ models.CharField(max_length=10), db_index=True, size=100
+ ),
+ ),
+ ("char2", models.CharField(max_length=11, db_index=True)),
+ (
+ "text",
+ django.contrib.postgres.fields.ArrayField(
+ models.TextField(), db_index=True
+ ),
+ ),
],
- options={
- },
+ options={},
bases=(models.Model,),
),
]
diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py
index 1c0cb05d46..2d62e26a92 100644
--- a/tests/postgres_tests/fields.py
+++ b/tests/postgres_tests/fields.py
@@ -8,31 +8,41 @@ from django.db import models
try:
from django.contrib.postgres.fields import (
- ArrayField, BigIntegerRangeField, CICharField, CIEmailField,
- CITextField, DateRangeField, DateTimeRangeField, DecimalRangeField,
- HStoreField, IntegerRangeField,
+ ArrayField,
+ BigIntegerRangeField,
+ CICharField,
+ CIEmailField,
+ CITextField,
+ DateRangeField,
+ DateTimeRangeField,
+ DecimalRangeField,
+ HStoreField,
+ IntegerRangeField,
)
from django.contrib.postgres.search import SearchVector, SearchVectorField
except ImportError:
+
class DummyArrayField(models.Field):
def __init__(self, base_field, size=None, **kwargs):
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- kwargs.update({
- 'base_field': '',
- 'size': 1,
- })
+ kwargs.update(
+ {
+ "base_field": "",
+ "size": 1,
+ }
+ )
return name, path, args, kwargs
class DummyContinuousRangeField(models.Field):
- def __init__(self, *args, default_bounds='[)', **kwargs):
+ def __init__(self, *args, default_bounds="[)", **kwargs):
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- kwargs['default_bounds'] = '[)'
+ kwargs["default_bounds"] = "[)"
return name, path, args, kwargs
ArrayField = DummyArrayField
diff --git a/tests/postgres_tests/integration_settings.py b/tests/postgres_tests/integration_settings.py
index c4ec0d1157..7e2ea9c8d0 100644
--- a/tests/postgres_tests/integration_settings.py
+++ b/tests/postgres_tests/integration_settings.py
@@ -1,5 +1,5 @@
-SECRET_KEY = 'abcdefg'
+SECRET_KEY = "abcdefg"
INSTALLED_APPS = [
- 'django.contrib.postgres',
+ "django.contrib.postgres",
]
diff --git a/tests/postgres_tests/migrations/0001_setup_extensions.py b/tests/postgres_tests/migrations/0001_setup_extensions.py
index bd5da83d15..090abf9649 100644
--- a/tests/postgres_tests/migrations/0001_setup_extensions.py
+++ b/tests/postgres_tests/migrations/0001_setup_extensions.py
@@ -4,8 +4,14 @@ from django.db import connection, migrations
try:
from django.contrib.postgres.operations import (
- BloomExtension, BtreeGinExtension, BtreeGistExtension, CITextExtension,
- CreateExtension, CryptoExtension, HStoreExtension, TrigramExtension,
+ BloomExtension,
+ BtreeGinExtension,
+ BtreeGistExtension,
+ CITextExtension,
+ CreateExtension,
+ CryptoExtension,
+ HStoreExtension,
+ TrigramExtension,
UnaccentExtension,
)
except ImportError:
@@ -20,8 +26,7 @@ except ImportError:
needs_crypto_extension = False
else:
needs_crypto_extension = (
- connection.vendor == 'postgresql' and
- not connection.features.is_postgresql_13
+ connection.vendor == "postgresql" and not connection.features.is_postgresql_13
)
@@ -34,7 +39,7 @@ class Migration(migrations.Migration):
CITextExtension(),
# Ensure CreateExtension quotes extension names by creating one with a
# dash in its name.
- CreateExtension('uuid-ossp'),
+ CreateExtension("uuid-ossp"),
# CryptoExtension is required for RandomUUID() on PostgreSQL < 13.
CryptoExtension() if needs_crypto_extension else mock.Mock(),
HStoreExtension(),
diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py
index 7ede593dae..a78563f80d 100644
--- a/tests/postgres_tests/migrations/0002_create_test_models.py
+++ b/tests/postgres_tests/migrations/0002_create_test_models.py
@@ -1,9 +1,18 @@
from django.db import migrations, models
from ..fields import (
- ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
- DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
- HStoreField, IntegerRangeField, SearchVectorField,
+ ArrayField,
+ BigIntegerRangeField,
+ CICharField,
+ CIEmailField,
+ CITextField,
+ DateRangeField,
+ DateTimeRangeField,
+ DecimalRangeField,
+ EnumField,
+ HStoreField,
+ IntegerRangeField,
+ SearchVectorField,
)
from ..models import TagField
@@ -11,305 +20,538 @@ from ..models import TagField
class Migration(migrations.Migration):
dependencies = [
- ('postgres_tests', '0001_setup_extensions'),
+ ("postgres_tests", "0001_setup_extensions"),
]
operations = [
migrations.CreateModel(
- name='CharArrayModel',
+ name="CharArrayModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', ArrayField(models.CharField(max_length=10), size=None)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("field", ArrayField(models.CharField(max_length=10), size=None)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='DateTimeArrayModel',
+ name="DateTimeArrayModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('datetimes', ArrayField(models.DateTimeField(), size=None)),
- ('dates', ArrayField(models.DateField(), size=None)),
- ('times', ArrayField(models.TimeField(), size=None)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("datetimes", ArrayField(models.DateTimeField(), size=None)),
+ ("dates", ArrayField(models.DateField(), size=None)),
+ ("times", ArrayField(models.TimeField(), size=None)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='HStoreModel',
+ name="HStoreModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', HStoreField(blank=True, null=True)),
- ('array_field', ArrayField(HStoreField(), null=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("field", HStoreField(blank=True, null=True)),
+ ("array_field", ArrayField(HStoreField(), null=True)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='OtherTypesArrayModel',
+ name="OtherTypesArrayModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('ips', ArrayField(models.GenericIPAddressField(), size=None, default=list)),
- ('uuids', ArrayField(models.UUIDField(), size=None, default=list)),
- ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)),
- ('tags', ArrayField(TagField(), blank=True, null=True, size=None)),
- ('json', ArrayField(models.JSONField(default={}), default=[])),
- ('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)),
- ('bigint_ranges', ArrayField(BigIntegerRangeField(), null=True, blank=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "ips",
+ ArrayField(models.GenericIPAddressField(), size=None, default=list),
+ ),
+ ("uuids", ArrayField(models.UUIDField(), size=None, default=list)),
+ (
+ "decimals",
+ ArrayField(
+ models.DecimalField(max_digits=5, decimal_places=2),
+ size=None,
+ default=list,
+ ),
+ ),
+ ("tags", ArrayField(TagField(), blank=True, null=True, size=None)),
+ ("json", ArrayField(models.JSONField(default={}), default=[])),
+ ("int_ranges", ArrayField(IntegerRangeField(), null=True, blank=True)),
+ (
+ "bigint_ranges",
+ ArrayField(BigIntegerRangeField(), null=True, blank=True),
+ ),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='IntegerArrayModel',
+ name="IntegerArrayModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', ArrayField(models.IntegerField(), size=None)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("field", ArrayField(models.IntegerField(), size=None)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='NestedIntegerArrayModel',
+ name="NestedIntegerArrayModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', ArrayField(ArrayField(models.IntegerField(), size=None), size=None)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "field",
+ ArrayField(ArrayField(models.IntegerField(), size=None), size=None),
+ ),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='NullableIntegerArrayModel',
+ name="NullableIntegerArrayModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', ArrayField(models.IntegerField(), size=None, null=True, blank=True)),
(
- 'field_nested',
- ArrayField(ArrayField(models.IntegerField(), size=None, null=True), size=None, null=True),
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
),
- ('order', models.IntegerField(null=True)),
+ (
+ "field",
+ ArrayField(models.IntegerField(), size=None, null=True, blank=True),
+ ),
+ (
+ "field_nested",
+ ArrayField(
+ ArrayField(models.IntegerField(), size=None, null=True),
+ size=None,
+ null=True,
+ ),
+ ),
+ ("order", models.IntegerField(null=True)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='CharFieldModel',
+ name="CharFieldModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', models.CharField(max_length=64)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("field", models.CharField(max_length=64)),
],
options=None,
bases=None,
),
migrations.CreateModel(
- name='TextFieldModel',
+ name="TextFieldModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', models.TextField()),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("field", models.TextField()),
],
options=None,
bases=None,
),
migrations.CreateModel(
- name='SmallAutoFieldModel',
+ name="SmallAutoFieldModel",
fields=[
- ('id', models.SmallAutoField(verbose_name='ID', serialize=False, primary_key=True)),
+ (
+ "id",
+ models.SmallAutoField(
+ verbose_name="ID", serialize=False, primary_key=True
+ ),
+ ),
],
options=None,
),
migrations.CreateModel(
- name='BigAutoFieldModel',
+ name="BigAutoFieldModel",
fields=[
- ('id', models.BigAutoField(verbose_name='ID', serialize=False, primary_key=True)),
+ (
+ "id",
+ models.BigAutoField(
+ verbose_name="ID", serialize=False, primary_key=True
+ ),
+ ),
],
options=None,
),
migrations.CreateModel(
- name='Scene',
+ name="Scene",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('scene', models.TextField()),
- ('setting', models.CharField(max_length=255)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("scene", models.TextField()),
+ ("setting", models.CharField(max_length=255)),
],
options=None,
bases=None,
),
migrations.CreateModel(
- name='Character',
+ name="Character",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('name', models.CharField(max_length=255)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("name", models.CharField(max_length=255)),
],
options=None,
bases=None,
),
migrations.CreateModel(
- name='CITestModel',
+ name="CITestModel",
fields=[
- ('name', CICharField(primary_key=True, max_length=255)),
- ('email', CIEmailField()),
- ('description', CITextField()),
- ('array_field', ArrayField(CITextField(), null=True)),
+ ("name", CICharField(primary_key=True, max_length=255)),
+ ("email", CIEmailField()),
+ ("description", CITextField()),
+ ("array_field", ArrayField(CITextField(), null=True)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=None,
),
migrations.CreateModel(
- name='Line',
+ name="Line",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('scene', models.ForeignKey('postgres_tests.Scene', on_delete=models.SET_NULL)),
- ('character', models.ForeignKey('postgres_tests.Character', on_delete=models.SET_NULL)),
- ('dialogue', models.TextField(blank=True, null=True)),
- ('dialogue_search_vector', SearchVectorField(blank=True, null=True)),
- ('dialogue_config', models.CharField(max_length=100, blank=True, null=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "scene",
+ models.ForeignKey(
+ "postgres_tests.Scene", on_delete=models.SET_NULL
+ ),
+ ),
+ (
+ "character",
+ models.ForeignKey(
+ "postgres_tests.Character", on_delete=models.SET_NULL
+ ),
+ ),
+ ("dialogue", models.TextField(blank=True, null=True)),
+ ("dialogue_search_vector", SearchVectorField(blank=True, null=True)),
+ (
+ "dialogue_config",
+ models.CharField(max_length=100, blank=True, null=True),
+ ),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=None,
),
migrations.CreateModel(
- name='LineSavedSearch',
+ name="LineSavedSearch",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('line', models.ForeignKey('postgres_tests.Line', on_delete=models.CASCADE)),
- ('query', models.CharField(max_length=100)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "line",
+ models.ForeignKey("postgres_tests.Line", on_delete=models.CASCADE),
+ ),
+ ("query", models.CharField(max_length=100)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
),
migrations.CreateModel(
- name='AggregateTestModel',
+ name="AggregateTestModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('boolean_field', models.BooleanField(null=True)),
- ('char_field', models.CharField(max_length=30, blank=True)),
- ('text_field', models.TextField(blank=True)),
- ('integer_field', models.IntegerField(null=True)),
- ('json_field', models.JSONField(null=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("boolean_field", models.BooleanField(null=True)),
+ ("char_field", models.CharField(max_length=30, blank=True)),
+ ("text_field", models.TextField(blank=True)),
+ ("integer_field", models.IntegerField(null=True)),
+ ("json_field", models.JSONField(null=True)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
),
migrations.CreateModel(
- name='StatTestModel',
+ name="StatTestModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('int1', models.IntegerField()),
- ('int2', models.IntegerField()),
- ('related_field', models.ForeignKey(
- 'postgres_tests.AggregateTestModel',
- models.SET_NULL,
- null=True,
- )),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("int1", models.IntegerField()),
+ ("int2", models.IntegerField()),
+ (
+ "related_field",
+ models.ForeignKey(
+ "postgres_tests.AggregateTestModel",
+ models.SET_NULL,
+ null=True,
+ ),
+ ),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
),
migrations.CreateModel(
- name='NowTestModel',
+ name="NowTestModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('when', models.DateTimeField(null=True, default=None)),
- ]
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("when", models.DateTimeField(null=True, default=None)),
+ ],
),
migrations.CreateModel(
- name='UUIDTestModel',
+ name="UUIDTestModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('uuid', models.UUIDField(default=None, null=True)),
- ]
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("uuid", models.UUIDField(default=None, null=True)),
+ ],
),
migrations.CreateModel(
- name='RangesModel',
+ name="RangesModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('ints', IntegerRangeField(null=True, blank=True)),
- ('bigints', BigIntegerRangeField(null=True, blank=True)),
- ('decimals', DecimalRangeField(null=True, blank=True)),
- ('timestamps', DateTimeRangeField(null=True, blank=True)),
- ('timestamps_inner', DateTimeRangeField(null=True, blank=True)),
- ('timestamps_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')),
- ('dates', DateRangeField(null=True, blank=True)),
- ('dates_inner', DateRangeField(null=True, blank=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("ints", IntegerRangeField(null=True, blank=True)),
+ ("bigints", BigIntegerRangeField(null=True, blank=True)),
+ ("decimals", DecimalRangeField(null=True, blank=True)),
+ ("timestamps", DateTimeRangeField(null=True, blank=True)),
+ ("timestamps_inner", DateTimeRangeField(null=True, blank=True)),
+ (
+ "timestamps_closed_bounds",
+ DateTimeRangeField(null=True, blank=True, default_bounds="[]"),
+ ),
+ ("dates", DateRangeField(null=True, blank=True)),
+ ("dates_inner", DateRangeField(null=True, blank=True)),
],
- options={
- 'required_db_vendor': 'postgresql'
- },
- bases=(models.Model,)
+ options={"required_db_vendor": "postgresql"},
+ bases=(models.Model,),
),
migrations.CreateModel(
- name='RangeLookupsModel',
+ name="RangeLookupsModel",
fields=[
- ('parent', models.ForeignKey(
- 'postgres_tests.RangesModel',
- models.SET_NULL,
- blank=True, null=True,
- )),
- ('integer', models.IntegerField(blank=True, null=True)),
- ('big_integer', models.BigIntegerField(blank=True, null=True)),
- ('float', models.FloatField(blank=True, null=True)),
- ('timestamp', models.DateTimeField(blank=True, null=True)),
- ('date', models.DateField(blank=True, null=True)),
- ('small_integer', models.SmallIntegerField(blank=True, null=True)),
- ('decimal_field', models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True)),
+ (
+ "parent",
+ models.ForeignKey(
+ "postgres_tests.RangesModel",
+ models.SET_NULL,
+ blank=True,
+ null=True,
+ ),
+ ),
+ ("integer", models.IntegerField(blank=True, null=True)),
+ ("big_integer", models.BigIntegerField(blank=True, null=True)),
+ ("float", models.FloatField(blank=True, null=True)),
+ ("timestamp", models.DateTimeField(blank=True, null=True)),
+ ("date", models.DateField(blank=True, null=True)),
+ ("small_integer", models.SmallIntegerField(blank=True, null=True)),
+ (
+ "decimal_field",
+ models.DecimalField(
+ max_digits=5, decimal_places=2, blank=True, null=True
+ ),
+ ),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='ArrayEnumModel',
+ name="ArrayEnumModel",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ (
+ "array_of_enums",
+ ArrayField(EnumField(max_length=20), null=True, blank=True),
+ ),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
- name='Room',
+ name="Room",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('number', models.IntegerField(unique=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("number", models.IntegerField(unique=True)),
],
),
migrations.CreateModel(
- name='HotelReservation',
+ name="HotelReservation",
fields=[
- ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('room', models.ForeignKey('postgres_tests.Room', models.CASCADE)),
- ('datespan', DateRangeField()),
- ('start', models.DateTimeField()),
- ('end', models.DateTimeField()),
- ('cancelled', models.BooleanField(default=False)),
- ('requirements', models.JSONField(blank=True, null=True)),
+ (
+ "id",
+ models.AutoField(
+ verbose_name="ID",
+ serialize=False,
+ auto_created=True,
+ primary_key=True,
+ ),
+ ),
+ ("room", models.ForeignKey("postgres_tests.Room", models.CASCADE)),
+ ("datespan", DateRangeField()),
+ ("start", models.DateTimeField()),
+ ("end", models.DateTimeField()),
+ ("cancelled", models.BooleanField(default=False)),
+ ("requirements", models.JSONField(blank=True, null=True)),
],
options={
- 'required_db_vendor': 'postgresql',
+ "required_db_vendor": "postgresql",
},
),
]
diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py
index 444b039840..8f27838ad5 100644
--- a/tests/postgres_tests/models.py
+++ b/tests/postgres_tests/models.py
@@ -1,9 +1,18 @@
from django.db import models
from .fields import (
- ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
- DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
- HStoreField, IntegerRangeField, SearchVectorField,
+ ArrayField,
+ BigIntegerRangeField,
+ CICharField,
+ CIEmailField,
+ CITextField,
+ DateRangeField,
+ DateTimeRangeField,
+ DecimalRangeField,
+ EnumField,
+ HStoreField,
+ IntegerRangeField,
+ SearchVectorField,
)
@@ -16,7 +25,6 @@ class Tag:
class TagField(models.SmallIntegerField):
-
def from_db_value(self, value, expression, connection):
if value is None:
return value
@@ -36,7 +44,7 @@ class TagField(models.SmallIntegerField):
class PostgreSQLModel(models.Model):
class Meta:
abstract = True
- required_db_vendor = 'postgresql'
+ required_db_vendor = "postgresql"
class IntegerArrayModel(PostgreSQLModel):
@@ -66,7 +74,9 @@ class NestedIntegerArrayModel(PostgreSQLModel):
class OtherTypesArrayModel(PostgreSQLModel):
ips = ArrayField(models.GenericIPAddressField(), default=list)
uuids = ArrayField(models.UUIDField(), default=list)
- decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list)
+ decimals = ArrayField(
+ models.DecimalField(max_digits=5, decimal_places=2), default=list
+ )
tags = ArrayField(TagField(), blank=True, null=True)
json = ArrayField(models.JSONField(default=dict), default=list)
int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
@@ -117,15 +127,15 @@ class CITestModel(PostgreSQLModel):
class Line(PostgreSQLModel):
- scene = models.ForeignKey('Scene', models.CASCADE)
- character = models.ForeignKey('Character', models.CASCADE)
+ scene = models.ForeignKey("Scene", models.CASCADE)
+ character = models.ForeignKey("Character", models.CASCADE)
dialogue = models.TextField(blank=True, null=True)
dialogue_search_vector = SearchVectorField(blank=True, null=True)
dialogue_config = models.CharField(max_length=100, blank=True, null=True)
class LineSavedSearch(PostgreSQLModel):
- line = models.ForeignKey('Line', models.CASCADE)
+ line = models.ForeignKey("Line", models.CASCADE)
query = models.CharField(max_length=100)
@@ -136,7 +146,9 @@ class RangesModel(PostgreSQLModel):
timestamps = DateTimeRangeField(blank=True, null=True)
timestamps_inner = DateTimeRangeField(blank=True, null=True)
timestamps_closed_bounds = DateTimeRangeField(
- blank=True, null=True, default_bounds='[]',
+ blank=True,
+ null=True,
+ default_bounds="[]",
)
dates = DateRangeField(blank=True, null=True)
dates_inner = DateRangeField(blank=True, null=True)
@@ -150,7 +162,9 @@ class RangeLookupsModel(PostgreSQLModel):
timestamp = models.DateTimeField(blank=True, null=True)
date = models.DateField(blank=True, null=True)
small_integer = models.SmallIntegerField(blank=True, null=True)
- decimal_field = models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True)
+ decimal_field = models.DecimalField(
+ max_digits=5, decimal_places=2, blank=True, null=True
+ )
class ArrayFieldSubclass(ArrayField):
@@ -162,6 +176,7 @@ class AggregateTestModel(PostgreSQLModel):
"""
To test postgres-specific general aggregation functions
"""
+
char_field = models.CharField(max_length=30, blank=True)
text_field = models.TextField(blank=True)
integer_field = models.IntegerField(null=True)
@@ -173,6 +188,7 @@ class StatTestModel(PostgreSQLModel):
"""
To test postgres-specific aggregation functions for statistics
"""
+
int1 = models.IntegerField()
int2 = models.IntegerField()
related_field = models.ForeignKey(AggregateTestModel, models.SET_NULL, null=True)
@@ -191,7 +207,7 @@ class Room(models.Model):
class HotelReservation(PostgreSQLModel):
- room = models.ForeignKey('Room', on_delete=models.CASCADE)
+ room = models.ForeignKey("Room", on_delete=models.CASCADE)
datespan = DateRangeField()
start = models.DateTimeField()
end = models.DateTimeField()
diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py
index 399c4fa8a9..c3df490fcf 100644
--- a/tests/postgres_tests/test_aggregates.py
+++ b/tests/postgres_tests/test_aggregates.py
@@ -1,6 +1,13 @@
from django.db import connection
from django.db.models import (
- CharField, F, Func, IntegerField, OuterRef, Q, Subquery, Value,
+ CharField,
+ F,
+ Func,
+ IntegerField,
+ OuterRef,
+ Q,
+ Subquery,
+ Value,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import Cast, Concat, Substr
@@ -14,9 +21,26 @@ from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
try:
from django.contrib.postgres.aggregates import (
- ArrayAgg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Corr, CovarPop,
- JSONBAgg, RegrAvgX, RegrAvgY, RegrCount, RegrIntercept, RegrR2,
- RegrSlope, RegrSXX, RegrSXY, RegrSYY, StatAggregate, StringAgg,
+ ArrayAgg,
+ BitAnd,
+ BitOr,
+ BitXor,
+ BoolAnd,
+ BoolOr,
+ Corr,
+ CovarPop,
+ JSONBAgg,
+ RegrAvgX,
+ RegrAvgY,
+ RegrCount,
+ RegrIntercept,
+ RegrR2,
+ RegrSlope,
+ RegrSXX,
+ RegrSXY,
+ RegrSYY,
+ StatAggregate,
+ StringAgg,
)
from django.contrib.postgres.fields import ArrayField
except ImportError:
@@ -26,52 +50,54 @@ except ImportError:
class TestGeneralAggregate(PostgreSQLTestCase):
@classmethod
def setUpTestData(cls):
- cls.aggs = AggregateTestModel.objects.bulk_create([
- AggregateTestModel(
- boolean_field=True,
- char_field='Foo1',
- text_field='Text1',
- integer_field=0,
- ),
- AggregateTestModel(
- boolean_field=False,
- char_field='Foo2',
- text_field='Text2',
- integer_field=1,
- json_field={'lang': 'pl'},
- ),
- AggregateTestModel(
- boolean_field=False,
- char_field='Foo4',
- text_field='Text4',
- integer_field=2,
- json_field={'lang': 'en'},
- ),
- AggregateTestModel(
- boolean_field=True,
- char_field='Foo3',
- text_field='Text3',
- integer_field=0,
- json_field={'breed': 'collie'},
- ),
- ])
+ cls.aggs = AggregateTestModel.objects.bulk_create(
+ [
+ AggregateTestModel(
+ boolean_field=True,
+ char_field="Foo1",
+ text_field="Text1",
+ integer_field=0,
+ ),
+ AggregateTestModel(
+ boolean_field=False,
+ char_field="Foo2",
+ text_field="Text2",
+ integer_field=1,
+ json_field={"lang": "pl"},
+ ),
+ AggregateTestModel(
+ boolean_field=False,
+ char_field="Foo4",
+ text_field="Text4",
+ integer_field=2,
+ json_field={"lang": "en"},
+ ),
+ AggregateTestModel(
+ boolean_field=True,
+ char_field="Foo3",
+ text_field="Text3",
+ integer_field=0,
+ json_field={"breed": "collie"},
+ ),
+ ]
+ )
@ignore_warnings(category=RemovedInDjango50Warning)
def test_empty_result_set(self):
AggregateTestModel.objects.all().delete()
tests = [
- (ArrayAgg('char_field'), []),
- (ArrayAgg('integer_field'), []),
- (ArrayAgg('boolean_field'), []),
- (BitAnd('integer_field'), None),
- (BitOr('integer_field'), None),
- (BoolAnd('boolean_field'), None),
- (BoolOr('boolean_field'), None),
- (JSONBAgg('integer_field'), []),
- (StringAgg('char_field', delimiter=';'), ''),
+ (ArrayAgg("char_field"), []),
+ (ArrayAgg("integer_field"), []),
+ (ArrayAgg("boolean_field"), []),
+ (BitAnd("integer_field"), None),
+ (BitOr("integer_field"), None),
+ (BoolAnd("boolean_field"), None),
+ (BoolOr("boolean_field"), None),
+ (JSONBAgg("integer_field"), []),
+ (StringAgg("char_field", delimiter=";"), ""),
]
if connection.features.has_bit_xor:
- tests.append((BitXor('integer_field'), None))
+ tests.append((BitXor("integer_field"), None))
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
# Empty result with non-execution optimization.
@@ -79,29 +105,32 @@ class TestGeneralAggregate(PostgreSQLTestCase):
values = AggregateTestModel.objects.none().aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = AggregateTestModel.objects.aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
def test_default_argument(self):
AggregateTestModel.objects.all().delete()
tests = [
- (ArrayAgg('char_field', default=['<empty>']), ['<empty>']),
- (ArrayAgg('integer_field', default=[0]), [0]),
- (ArrayAgg('boolean_field', default=[False]), [False]),
- (BitAnd('integer_field', default=0), 0),
- (BitOr('integer_field', default=0), 0),
- (BoolAnd('boolean_field', default=False), False),
- (BoolOr('boolean_field', default=False), False),
- (JSONBAgg('integer_field', default=Value('["<empty>"]')), ['<empty>']),
- (StringAgg('char_field', delimiter=';', default=Value('<empty>')), '<empty>'),
+ (ArrayAgg("char_field", default=["<empty>"]), ["<empty>"]),
+ (ArrayAgg("integer_field", default=[0]), [0]),
+ (ArrayAgg("boolean_field", default=[False]), [False]),
+ (BitAnd("integer_field", default=0), 0),
+ (BitOr("integer_field", default=0), 0),
+ (BoolAnd("boolean_field", default=False), False),
+ (BoolOr("boolean_field", default=False), False),
+ (JSONBAgg("integer_field", default=Value('["<empty>"]')), ["<empty>"]),
+ (
+ StringAgg("char_field", delimiter=";", default=Value("<empty>")),
+ "<empty>",
+ ),
]
if connection.features.has_bit_xor:
- tests.append((BitXor('integer_field', default=0), 0))
+ tests.append((BitXor("integer_field", default=0), 0))
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
# Empty result with non-execution optimization.
@@ -109,135 +138,159 @@ class TestGeneralAggregate(PostgreSQLTestCase):
values = AggregateTestModel.objects.none().aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = AggregateTestModel.objects.aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
def test_convert_value_deprecation(self):
AggregateTestModel.objects.all().delete()
queryset = AggregateTestModel.objects.all()
- with self.assertWarnsMessage(RemovedInDjango50Warning, ArrayAgg.deprecation_msg):
- queryset.aggregate(aggregation=ArrayAgg('boolean_field'))
+ with self.assertWarnsMessage(
+ RemovedInDjango50Warning, ArrayAgg.deprecation_msg
+ ):
+ queryset.aggregate(aggregation=ArrayAgg("boolean_field"))
- with self.assertWarnsMessage(RemovedInDjango50Warning, JSONBAgg.deprecation_msg):
- queryset.aggregate(aggregation=JSONBAgg('integer_field'))
+ with self.assertWarnsMessage(
+ RemovedInDjango50Warning, JSONBAgg.deprecation_msg
+ ):
+ queryset.aggregate(aggregation=JSONBAgg("integer_field"))
- with self.assertWarnsMessage(RemovedInDjango50Warning, StringAgg.deprecation_msg):
- queryset.aggregate(aggregation=StringAgg('char_field', delimiter=';'))
+ with self.assertWarnsMessage(
+ RemovedInDjango50Warning, StringAgg.deprecation_msg
+ ):
+ queryset.aggregate(aggregation=StringAgg("char_field", delimiter=";"))
# No warnings raised if default argument provided.
self.assertEqual(
- queryset.aggregate(aggregation=ArrayAgg('boolean_field', default=None)),
- {'aggregation': None},
+ queryset.aggregate(aggregation=ArrayAgg("boolean_field", default=None)),
+ {"aggregation": None},
)
self.assertEqual(
- queryset.aggregate(aggregation=JSONBAgg('integer_field', default=None)),
- {'aggregation': None},
+ queryset.aggregate(aggregation=JSONBAgg("integer_field", default=None)),
+ {"aggregation": None},
)
self.assertEqual(
queryset.aggregate(
- aggregation=StringAgg('char_field', delimiter=';', default=None),
+ aggregation=StringAgg("char_field", delimiter=";", default=None),
),
- {'aggregation': None},
+ {"aggregation": None},
)
self.assertEqual(
- queryset.aggregate(aggregation=ArrayAgg('boolean_field', default=Value([]))),
- {'aggregation': []},
+ queryset.aggregate(
+ aggregation=ArrayAgg("boolean_field", default=Value([]))
+ ),
+ {"aggregation": []},
)
self.assertEqual(
- queryset.aggregate(aggregation=JSONBAgg('integer_field', default=Value('[]'))),
- {'aggregation': []},
+ queryset.aggregate(
+ aggregation=JSONBAgg("integer_field", default=Value("[]"))
+ ),
+ {"aggregation": []},
)
self.assertEqual(
queryset.aggregate(
- aggregation=StringAgg('char_field', delimiter=';', default=Value('')),
+ aggregation=StringAgg("char_field", delimiter=";", default=Value("")),
),
- {'aggregation': ''},
+ {"aggregation": ""},
)
def test_array_agg_charfield(self):
- values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
- self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
+ values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
+ self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
def test_array_agg_charfield_ordering(self):
ordering_test_cases = (
- (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
- (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- ([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
- ((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
- ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
- (Concat('char_field', Value('@')), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- (Concat('char_field', Value('@')).desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
+ (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
+ (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ (
+ [F("boolean_field"), F("char_field").desc()],
+ ["Foo4", "Foo2", "Foo3", "Foo1"],
+ ),
(
- (Substr('char_field', 1, 1), F('integer_field'), Substr('char_field', 4, 1).desc()),
- ['Foo3', 'Foo1', 'Foo2', 'Foo4'],
+ (F("boolean_field"), F("char_field").desc()),
+ ["Foo4", "Foo2", "Foo3", "Foo1"],
+ ),
+ ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]),
+ (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
+ (
+ (
+ Substr("char_field", 1, 1),
+ F("integer_field"),
+ Substr("char_field", 4, 1).desc(),
+ ),
+ ["Foo3", "Foo1", "Foo2", "Foo4"],
),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
- arrayagg=ArrayAgg('char_field', ordering=ordering)
+ arrayagg=ArrayAgg("char_field", ordering=ordering)
)
- self.assertEqual(values, {'arrayagg': expected_output})
+ self.assertEqual(values, {"arrayagg": expected_output})
def test_array_agg_integerfield(self):
- values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field'))
- self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]})
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg("integer_field")
+ )
+ self.assertEqual(values, {"arrayagg": [0, 1, 2, 0]})
def test_array_agg_integerfield_ordering(self):
values = AggregateTestModel.objects.aggregate(
- arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc())
+ arrayagg=ArrayAgg("integer_field", ordering=F("integer_field").desc())
)
- self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]})
+ self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
def test_array_agg_booleanfield(self):
- values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
- self.assertEqual(values, {'arrayagg': [True, False, False, True]})
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg("boolean_field")
+ )
+ self.assertEqual(values, {"arrayagg": [True, False, False, True]})
def test_array_agg_booleanfield_ordering(self):
ordering_test_cases = (
- (F('boolean_field').asc(), [False, False, True, True]),
- (F('boolean_field').desc(), [True, True, False, False]),
- (F('boolean_field'), [False, False, True, True]),
+ (F("boolean_field").asc(), [False, False, True, True]),
+ (F("boolean_field").desc(), [True, True, False, False]),
+ (F("boolean_field"), [False, False, True, True]),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
- arrayagg=ArrayAgg('boolean_field', ordering=ordering)
+ arrayagg=ArrayAgg("boolean_field", ordering=ordering)
)
- self.assertEqual(values, {'arrayagg': expected_output})
+ self.assertEqual(values, {"arrayagg": expected_output})
def test_array_agg_jsonfield(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg(
- KeyTransform('lang', 'json_field'),
+ KeyTransform("lang", "json_field"),
filter=Q(json_field__lang__isnull=False),
),
)
- self.assertEqual(values, {'arrayagg': ['pl', 'en']})
+ self.assertEqual(values, {"arrayagg": ["pl", "en"]})
def test_array_agg_jsonfield_ordering(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg(
- KeyTransform('lang', 'json_field'),
+ KeyTransform("lang", "json_field"),
filter=Q(json_field__lang__isnull=False),
- ordering=KeyTransform('lang', 'json_field'),
+ ordering=KeyTransform("lang", "json_field"),
),
)
- self.assertEqual(values, {'arrayagg': ['en', 'pl']})
+ self.assertEqual(values, {"arrayagg": ["en", "pl"]})
def test_array_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
- arrayagg=ArrayAgg('integer_field', filter=Q(integer_field__gt=0)),
+ arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)),
)
- self.assertEqual(values, {'arrayagg': [1, 2]})
+ self.assertEqual(values, {"arrayagg": [1, 2]})
def test_array_agg_lookups(self):
aggr1 = AggregateTestModel.objects.create()
@@ -246,194 +299,207 @@ class TestGeneralAggregate(PostgreSQLTestCase):
StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0)
StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0)
StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0)
- qs = StatTestModel.objects.values('related_field').annotate(
- array=ArrayAgg('int1')
- ).filter(array__overlap=[2]).values_list('array', flat=True)
+ qs = (
+ StatTestModel.objects.values("related_field")
+ .annotate(array=ArrayAgg("int1"))
+ .filter(array__overlap=[2])
+ .values_list("array", flat=True)
+ )
self.assertCountEqual(qs.get(), [1, 2])
def test_bit_and_general(self):
- values = AggregateTestModel.objects.filter(
- integer_field__in=[0, 1]).aggregate(bitand=BitAnd('integer_field'))
- self.assertEqual(values, {'bitand': 0})
+ values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
+ bitand=BitAnd("integer_field")
+ )
+ self.assertEqual(values, {"bitand": 0})
def test_bit_and_on_only_true_values(self):
- values = AggregateTestModel.objects.filter(
- integer_field=1).aggregate(bitand=BitAnd('integer_field'))
- self.assertEqual(values, {'bitand': 1})
+ values = AggregateTestModel.objects.filter(integer_field=1).aggregate(
+ bitand=BitAnd("integer_field")
+ )
+ self.assertEqual(values, {"bitand": 1})
def test_bit_and_on_only_false_values(self):
- values = AggregateTestModel.objects.filter(
- integer_field=0).aggregate(bitand=BitAnd('integer_field'))
- self.assertEqual(values, {'bitand': 0})
+ values = AggregateTestModel.objects.filter(integer_field=0).aggregate(
+ bitand=BitAnd("integer_field")
+ )
+ self.assertEqual(values, {"bitand": 0})
def test_bit_or_general(self):
- values = AggregateTestModel.objects.filter(
- integer_field__in=[0, 1]).aggregate(bitor=BitOr('integer_field'))
- self.assertEqual(values, {'bitor': 1})
+ values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
+ bitor=BitOr("integer_field")
+ )
+ self.assertEqual(values, {"bitor": 1})
def test_bit_or_on_only_true_values(self):
- values = AggregateTestModel.objects.filter(
- integer_field=1).aggregate(bitor=BitOr('integer_field'))
- self.assertEqual(values, {'bitor': 1})
+ values = AggregateTestModel.objects.filter(integer_field=1).aggregate(
+ bitor=BitOr("integer_field")
+ )
+ self.assertEqual(values, {"bitor": 1})
def test_bit_or_on_only_false_values(self):
- values = AggregateTestModel.objects.filter(
- integer_field=0).aggregate(bitor=BitOr('integer_field'))
- self.assertEqual(values, {'bitor': 0})
+ values = AggregateTestModel.objects.filter(integer_field=0).aggregate(
+ bitor=BitOr("integer_field")
+ )
+ self.assertEqual(values, {"bitor": 0})
- @skipUnlessDBFeature('has_bit_xor')
+ @skipUnlessDBFeature("has_bit_xor")
def test_bit_xor_general(self):
AggregateTestModel.objects.create(integer_field=3)
values = AggregateTestModel.objects.filter(
integer_field__in=[1, 3],
- ).aggregate(bitxor=BitXor('integer_field'))
- self.assertEqual(values, {'bitxor': 2})
+ ).aggregate(bitxor=BitXor("integer_field"))
+ self.assertEqual(values, {"bitxor": 2})
- @skipUnlessDBFeature('has_bit_xor')
+ @skipUnlessDBFeature("has_bit_xor")
def test_bit_xor_on_only_true_values(self):
values = AggregateTestModel.objects.filter(
integer_field=1,
- ).aggregate(bitxor=BitXor('integer_field'))
- self.assertEqual(values, {'bitxor': 1})
+ ).aggregate(bitxor=BitXor("integer_field"))
+ self.assertEqual(values, {"bitxor": 1})
- @skipUnlessDBFeature('has_bit_xor')
+ @skipUnlessDBFeature("has_bit_xor")
def test_bit_xor_on_only_false_values(self):
values = AggregateTestModel.objects.filter(
integer_field=0,
- ).aggregate(bitxor=BitXor('integer_field'))
- self.assertEqual(values, {'bitxor': 0})
+ ).aggregate(bitxor=BitXor("integer_field"))
+ self.assertEqual(values, {"bitxor": 0})
def test_bool_and_general(self):
- values = AggregateTestModel.objects.aggregate(booland=BoolAnd('boolean_field'))
- self.assertEqual(values, {'booland': False})
+ values = AggregateTestModel.objects.aggregate(booland=BoolAnd("boolean_field"))
+ self.assertEqual(values, {"booland": False})
def test_bool_and_q_object(self):
values = AggregateTestModel.objects.aggregate(
booland=BoolAnd(Q(integer_field__gt=2)),
)
- self.assertEqual(values, {'booland': False})
+ self.assertEqual(values, {"booland": False})
def test_bool_or_general(self):
- values = AggregateTestModel.objects.aggregate(boolor=BoolOr('boolean_field'))
- self.assertEqual(values, {'boolor': True})
+ values = AggregateTestModel.objects.aggregate(boolor=BoolOr("boolean_field"))
+ self.assertEqual(values, {"boolor": True})
def test_bool_or_q_object(self):
values = AggregateTestModel.objects.aggregate(
boolor=BoolOr(Q(integer_field__gt=2)),
)
- self.assertEqual(values, {'boolor': False})
+ self.assertEqual(values, {"boolor": False})
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
- AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field'))
+ AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
def test_string_agg_delimiter_escaping(self):
- values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter="'"))
- self.assertEqual(values, {'stringagg': "Foo1'Foo2'Foo4'Foo3"})
+ values = AggregateTestModel.objects.aggregate(
+ stringagg=StringAgg("char_field", delimiter="'")
+ )
+ self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
def test_string_agg_charfield(self):
- values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
- self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'})
+ values = AggregateTestModel.objects.aggregate(
+ stringagg=StringAgg("char_field", delimiter=";")
+ )
+ self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
def test_string_agg_default_output_field(self):
values = AggregateTestModel.objects.aggregate(
- stringagg=StringAgg('text_field', delimiter=';'),
+ stringagg=StringAgg("text_field", delimiter=";"),
)
- self.assertEqual(values, {'stringagg': 'Text1;Text2;Text4;Text3'})
+ self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
def test_string_agg_charfield_ordering(self):
ordering_test_cases = (
- (F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'),
- (F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'),
- (F('char_field'), 'Foo1;Foo2;Foo3;Foo4'),
- ('char_field', 'Foo1;Foo2;Foo3;Foo4'),
- ('-char_field', 'Foo4;Foo3;Foo2;Foo1'),
- (Concat('char_field', Value('@')), 'Foo1;Foo2;Foo3;Foo4'),
- (Concat('char_field', Value('@')).desc(), 'Foo4;Foo3;Foo2;Foo1'),
+ (F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
+ (F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
+ (F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
+ ("char_field", "Foo1;Foo2;Foo3;Foo4"),
+ ("-char_field", "Foo4;Foo3;Foo2;Foo1"),
+ (Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
+ (Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
- stringagg=StringAgg('char_field', delimiter=';', ordering=ordering)
+ stringagg=StringAgg("char_field", delimiter=";", ordering=ordering)
)
- self.assertEqual(values, {'stringagg': expected_output})
+ self.assertEqual(values, {"stringagg": expected_output})
def test_string_agg_jsonfield_ordering(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
- KeyTextTransform('lang', 'json_field'),
- delimiter=';',
- ordering=KeyTextTransform('lang', 'json_field'),
+ KeyTextTransform("lang", "json_field"),
+ delimiter=";",
+ ordering=KeyTextTransform("lang", "json_field"),
output_field=CharField(),
),
)
- self.assertEqual(values, {'stringagg': 'en;pl'})
+ self.assertEqual(values, {"stringagg": "en;pl"})
def test_string_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
- 'char_field',
- delimiter=';',
- filter=Q(char_field__endswith='3') | Q(char_field__endswith='1'),
+ "char_field",
+ delimiter=";",
+ filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
)
)
- self.assertEqual(values, {'stringagg': 'Foo1;Foo3'})
+ self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
def test_orderable_agg_alternative_fields(self):
values = AggregateTestModel.objects.aggregate(
- arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc())
+ arrayagg=ArrayAgg("integer_field", ordering=F("char_field").asc())
)
- self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]})
+ self.assertEqual(values, {"arrayagg": [0, 1, 0, 2]})
def test_jsonb_agg(self):
- values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg('char_field'))
- self.assertEqual(values, {'jsonbagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
+ values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg("char_field"))
+ self.assertEqual(values, {"jsonbagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
def test_jsonb_agg_charfield_ordering(self):
ordering_test_cases = (
- (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
- (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
- (Concat('char_field', Value('@')), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
- (Concat('char_field', Value('@')).desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
+ (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
+ (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]),
+ (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]),
+ (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
- jsonbagg=JSONBAgg('char_field', ordering=ordering),
+ jsonbagg=JSONBAgg("char_field", ordering=ordering),
)
- self.assertEqual(values, {'jsonbagg': expected_output})
+ self.assertEqual(values, {"jsonbagg": expected_output})
def test_jsonb_agg_integerfield_ordering(self):
values = AggregateTestModel.objects.aggregate(
- jsonbagg=JSONBAgg('integer_field', ordering=F('integer_field').desc()),
+ jsonbagg=JSONBAgg("integer_field", ordering=F("integer_field").desc()),
)
- self.assertEqual(values, {'jsonbagg': [2, 1, 0, 0]})
+ self.assertEqual(values, {"jsonbagg": [2, 1, 0, 0]})
def test_jsonb_agg_booleanfield_ordering(self):
ordering_test_cases = (
- (F('boolean_field').asc(), [False, False, True, True]),
- (F('boolean_field').desc(), [True, True, False, False]),
- (F('boolean_field'), [False, False, True, True]),
+ (F("boolean_field").asc(), [False, False, True, True]),
+ (F("boolean_field").desc(), [True, True, False, False]),
+ (F("boolean_field"), [False, False, True, True]),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
- jsonbagg=JSONBAgg('boolean_field', ordering=ordering),
+ jsonbagg=JSONBAgg("boolean_field", ordering=ordering),
)
- self.assertEqual(values, {'jsonbagg': expected_output})
+ self.assertEqual(values, {"jsonbagg": expected_output})
def test_jsonb_agg_jsonfield_ordering(self):
values = AggregateTestModel.objects.aggregate(
jsonbagg=JSONBAgg(
- KeyTransform('lang', 'json_field'),
+ KeyTransform("lang", "json_field"),
filter=Q(json_field__lang__isnull=False),
- ordering=KeyTransform('lang', 'json_field'),
+ ordering=KeyTransform("lang", "json_field"),
),
)
- self.assertEqual(values, {'jsonbagg': ['en', 'pl']})
+ self.assertEqual(values, {"jsonbagg": ["en", "pl"]})
def test_jsonb_agg_key_index_transforms(self):
room101 = Room.objects.create(number=101)
@@ -448,160 +514,206 @@ class TestGeneralAggregate(PostgreSQLTestCase):
start=datetimes[0],
end=datetimes[1],
room=room102,
- requirements={'double_bed': True, 'parking': True},
+ requirements={"double_bed": True, "parking": True},
)
HotelReservation.objects.create(
datespan=(datetimes[1].date(), datetimes[2].date()),
start=datetimes[1],
end=datetimes[2],
room=room102,
- requirements={'double_bed': False, 'sea_view': True, 'parking': False},
+ requirements={"double_bed": False, "sea_view": True, "parking": False},
)
HotelReservation.objects.create(
datespan=(datetimes[0].date(), datetimes[2].date()),
start=datetimes[0],
end=datetimes[2],
room=room101,
- requirements={'sea_view': False},
+ requirements={"sea_view": False},
)
- values = Room.objects.annotate(
- requirements=JSONBAgg(
- 'hotelreservation__requirements',
- ordering='-hotelreservation__start',
+ values = (
+ Room.objects.annotate(
+ requirements=JSONBAgg(
+ "hotelreservation__requirements",
+ ordering="-hotelreservation__start",
+ )
)
- ).filter(requirements__0__sea_view=True).values('number', 'requirements')
- self.assertSequenceEqual(values, [
- {'number': 102, 'requirements': [
- {'double_bed': False, 'sea_view': True, 'parking': False},
- {'double_bed': True, 'parking': True},
- ]},
- ])
+ .filter(requirements__0__sea_view=True)
+ .values("number", "requirements")
+ )
+ self.assertSequenceEqual(
+ values,
+ [
+ {
+ "number": 102,
+ "requirements": [
+ {"double_bed": False, "sea_view": True, "parking": False},
+ {"double_bed": True, "parking": True},
+ ],
+ },
+ ],
+ )
def test_string_agg_array_agg_ordering_in_subquery(self):
stats = []
- for i, agg in enumerate(AggregateTestModel.objects.order_by('char_field')):
+ for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
StatTestModel.objects.bulk_create(stats)
for aggregate, expected_result in (
(
- ArrayAgg('stattestmodel__int1', ordering='-stattestmodel__int2'),
- [('Foo1', [0, 1]), ('Foo2', [1, 2]), ('Foo3', [2, 3]), ('Foo4', [3, 4])],
+ ArrayAgg("stattestmodel__int1", ordering="-stattestmodel__int2"),
+ [
+ ("Foo1", [0, 1]),
+ ("Foo2", [1, 2]),
+ ("Foo3", [2, 3]),
+ ("Foo4", [3, 4]),
+ ],
),
(
StringAgg(
- Cast('stattestmodel__int1', CharField()),
- delimiter=';',
- ordering='-stattestmodel__int2',
+ Cast("stattestmodel__int1", CharField()),
+ delimiter=";",
+ ordering="-stattestmodel__int2",
),
- [('Foo1', '0;1'), ('Foo2', '1;2'), ('Foo3', '2;3'), ('Foo4', '3;4')],
+ [("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
- subquery = AggregateTestModel.objects.filter(
- pk=OuterRef('pk'),
- ).annotate(agg=aggregate).values('agg')
- values = AggregateTestModel.objects.annotate(
- agg=Subquery(subquery),
- ).order_by('char_field').values_list('char_field', 'agg')
+ subquery = (
+ AggregateTestModel.objects.filter(
+ pk=OuterRef("pk"),
+ )
+ .annotate(agg=aggregate)
+ .values("agg")
+ )
+ values = (
+ AggregateTestModel.objects.annotate(
+ agg=Subquery(subquery),
+ )
+ .order_by("char_field")
+ .values_list("char_field", "agg")
+ )
self.assertEqual(list(values), expected_result)
def test_string_agg_array_agg_filter_in_subquery(self):
- StatTestModel.objects.bulk_create([
- StatTestModel(related_field=self.aggs[0], int1=0, int2=5),
- StatTestModel(related_field=self.aggs[0], int1=1, int2=4),
- StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
- ])
+ StatTestModel.objects.bulk_create(
+ [
+ StatTestModel(related_field=self.aggs[0], int1=0, int2=5),
+ StatTestModel(related_field=self.aggs[0], int1=1, int2=4),
+ StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
+ ]
+ )
for aggregate, expected_result in (
(
- ArrayAgg('stattestmodel__int1', filter=Q(stattestmodel__int2__gt=3)),
- [('Foo1', [0, 1]), ('Foo2', None)],
+ ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
+ [("Foo1", [0, 1]), ("Foo2", None)],
),
(
StringAgg(
- Cast('stattestmodel__int2', CharField()),
- delimiter=';',
+ Cast("stattestmodel__int2", CharField()),
+ delimiter=";",
filter=Q(stattestmodel__int1__lt=2),
),
- [('Foo1', '5;4'), ('Foo2', None)],
+ [("Foo1", "5;4"), ("Foo2", None)],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
- subquery = AggregateTestModel.objects.filter(
- pk=OuterRef('pk'),
- ).annotate(agg=aggregate).values('agg')
- values = AggregateTestModel.objects.annotate(
- agg=Subquery(subquery),
- ).filter(
- char_field__in=['Foo1', 'Foo2'],
- ).order_by('char_field').values_list('char_field', 'agg')
+ subquery = (
+ AggregateTestModel.objects.filter(
+ pk=OuterRef("pk"),
+ )
+ .annotate(agg=aggregate)
+ .values("agg")
+ )
+ values = (
+ AggregateTestModel.objects.annotate(
+ agg=Subquery(subquery),
+ )
+ .filter(
+ char_field__in=["Foo1", "Foo2"],
+ )
+ .order_by("char_field")
+ .values_list("char_field", "agg")
+ )
self.assertEqual(list(values), expected_result)
def test_string_agg_filter_in_subquery_with_exclude(self):
- subquery = AggregateTestModel.objects.annotate(
- stringagg=StringAgg(
- 'char_field',
- delimiter=';',
- filter=Q(char_field__endswith='1'),
+ subquery = (
+ AggregateTestModel.objects.annotate(
+ stringagg=StringAgg(
+ "char_field",
+ delimiter=";",
+ filter=Q(char_field__endswith="1"),
+ )
)
- ).exclude(stringagg='').values('id')
+ .exclude(stringagg="")
+ .values("id")
+ )
self.assertSequenceEqual(
AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
[self.aggs[0]],
)
def test_ordering_isnt_cleared_for_array_subquery(self):
- inner_qs = AggregateTestModel.objects.order_by('-integer_field')
+ inner_qs = AggregateTestModel.objects.order_by("-integer_field")
qs = AggregateTestModel.objects.annotate(
integers=Func(
- Subquery(inner_qs.values('integer_field')),
- function='ARRAY',
+ Subquery(inner_qs.values("integer_field")),
+ function="ARRAY",
output_field=ArrayField(base_field=IntegerField()),
),
)
self.assertSequenceEqual(
qs.first().integers,
- inner_qs.values_list('integer_field', flat=True),
+ inner_qs.values_list("integer_field", flat=True),
)
class TestAggregateDistinct(PostgreSQLTestCase):
@classmethod
def setUpTestData(cls):
- AggregateTestModel.objects.create(char_field='Foo')
- AggregateTestModel.objects.create(char_field='Foo')
- AggregateTestModel.objects.create(char_field='Bar')
+ AggregateTestModel.objects.create(char_field="Foo")
+ AggregateTestModel.objects.create(char_field="Foo")
+ AggregateTestModel.objects.create(char_field="Bar")
def test_string_agg_distinct_false(self):
- values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=False))
- self.assertEqual(values['stringagg'].count('Foo'), 2)
- self.assertEqual(values['stringagg'].count('Bar'), 1)
+ values = AggregateTestModel.objects.aggregate(
+ stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
+ )
+ self.assertEqual(values["stringagg"].count("Foo"), 2)
+ self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_string_agg_distinct_true(self):
- values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=True))
- self.assertEqual(values['stringagg'].count('Foo'), 1)
- self.assertEqual(values['stringagg'].count('Bar'), 1)
+ values = AggregateTestModel.objects.aggregate(
+ stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
+ )
+ self.assertEqual(values["stringagg"].count("Foo"), 1)
+ self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_array_agg_distinct_false(self):
- values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field', distinct=False))
- self.assertEqual(sorted(values['arrayagg']), ['Bar', 'Foo', 'Foo'])
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg("char_field", distinct=False)
+ )
+ self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo", "Foo"])
def test_array_agg_distinct_true(self):
- values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field', distinct=True))
- self.assertEqual(sorted(values['arrayagg']), ['Bar', 'Foo'])
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg("char_field", distinct=True)
+ )
+ self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo"])
def test_jsonb_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate(
- jsonbagg=JSONBAgg('char_field', distinct=False),
+ jsonbagg=JSONBAgg("char_field", distinct=False),
)
- self.assertEqual(sorted(values['jsonbagg']), ['Bar', 'Foo', 'Foo'])
+ self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo", "Foo"])
def test_jsonb_agg_distinct_true(self):
values = AggregateTestModel.objects.aggregate(
- jsonbagg=JSONBAgg('char_field', distinct=True),
+ jsonbagg=JSONBAgg("char_field", distinct=True),
)
- self.assertEqual(sorted(values['jsonbagg']), ['Bar', 'Foo'])
+ self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo"])
class TestStatisticsAggregate(PostgreSQLTestCase):
@@ -626,37 +738,38 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
# Tests for base class (StatAggregate)
def test_missing_arguments_raises_exception(self):
- with self.assertRaisesMessage(ValueError, 'Both y and x must be provided.'):
+ with self.assertRaisesMessage(ValueError, "Both y and x must be provided."):
StatAggregate(x=None, y=None)
def test_correct_source_expressions(self):
- func = StatAggregate(x='test', y=13)
+ func = StatAggregate(x="test", y=13)
self.assertIsInstance(func.source_expressions[0], Value)
self.assertIsInstance(func.source_expressions[1], F)
def test_alias_is_required(self):
class SomeFunc(StatAggregate):
- function = 'TEST'
- with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'):
- StatTestModel.objects.aggregate(SomeFunc(y='int2', x='int1'))
+ function = "TEST"
+
+ with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
+ StatTestModel.objects.aggregate(SomeFunc(y="int2", x="int1"))
# Test aggregates
def test_empty_result_set(self):
StatTestModel.objects.all().delete()
tests = [
- (Corr(y='int2', x='int1'), None),
- (CovarPop(y='int2', x='int1'), None),
- (CovarPop(y='int2', x='int1', sample=True), None),
- (RegrAvgX(y='int2', x='int1'), None),
- (RegrAvgY(y='int2', x='int1'), None),
- (RegrCount(y='int2', x='int1'), 0),
- (RegrIntercept(y='int2', x='int1'), None),
- (RegrR2(y='int2', x='int1'), None),
- (RegrSlope(y='int2', x='int1'), None),
- (RegrSXX(y='int2', x='int1'), None),
- (RegrSXY(y='int2', x='int1'), None),
- (RegrSYY(y='int2', x='int1'), None),
+ (Corr(y="int2", x="int1"), None),
+ (CovarPop(y="int2", x="int1"), None),
+ (CovarPop(y="int2", x="int1", sample=True), None),
+ (RegrAvgX(y="int2", x="int1"), None),
+ (RegrAvgY(y="int2", x="int1"), None),
+ (RegrCount(y="int2", x="int1"), 0),
+ (RegrIntercept(y="int2", x="int1"), None),
+ (RegrR2(y="int2", x="int1"), None),
+ (RegrSlope(y="int2", x="int1"), None),
+ (RegrSXX(y="int2", x="int1"), None),
+ (RegrSXY(y="int2", x="int1"), None),
+ (RegrSYY(y="int2", x="int1"), None),
]
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
@@ -665,29 +778,29 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
values = StatTestModel.objects.none().aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = StatTestModel.objects.aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
def test_default_argument(self):
StatTestModel.objects.all().delete()
tests = [
- (Corr(y='int2', x='int1', default=0), 0),
- (CovarPop(y='int2', x='int1', default=0), 0),
- (CovarPop(y='int2', x='int1', sample=True, default=0), 0),
- (RegrAvgX(y='int2', x='int1', default=0), 0),
- (RegrAvgY(y='int2', x='int1', default=0), 0),
+ (Corr(y="int2", x="int1", default=0), 0),
+ (CovarPop(y="int2", x="int1", default=0), 0),
+ (CovarPop(y="int2", x="int1", sample=True, default=0), 0),
+ (RegrAvgX(y="int2", x="int1", default=0), 0),
+ (RegrAvgY(y="int2", x="int1", default=0), 0),
# RegrCount() doesn't support the default argument.
- (RegrIntercept(y='int2', x='int1', default=0), 0),
- (RegrR2(y='int2', x='int1', default=0), 0),
- (RegrSlope(y='int2', x='int1', default=0), 0),
- (RegrSXX(y='int2', x='int1', default=0), 0),
- (RegrSXY(y='int2', x='int1', default=0), 0),
- (RegrSYY(y='int2', x='int1', default=0), 0),
+ (RegrIntercept(y="int2", x="int1", default=0), 0),
+ (RegrR2(y="int2", x="int1", default=0), 0),
+ (RegrSlope(y="int2", x="int1", default=0), 0),
+ (RegrSXX(y="int2", x="int1", default=0), 0),
+ (RegrSXY(y="int2", x="int1", default=0), 0),
+ (RegrSYY(y="int2", x="int1", default=0), 0),
]
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
@@ -696,71 +809,81 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
values = StatTestModel.objects.none().aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = StatTestModel.objects.aggregate(
aggregation=aggregation,
)
- self.assertEqual(values, {'aggregation': expected_result})
+ self.assertEqual(values, {"aggregation": expected_result})
def test_corr_general(self):
- values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1'))
- self.assertEqual(values, {'corr': -1.0})
+ values = StatTestModel.objects.aggregate(corr=Corr(y="int2", x="int1"))
+ self.assertEqual(values, {"corr": -1.0})
def test_covar_pop_general(self):
- values = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1'))
- self.assertEqual(values, {'covarpop': Approximate(-0.66, places=1)})
+ values = StatTestModel.objects.aggregate(covarpop=CovarPop(y="int2", x="int1"))
+ self.assertEqual(values, {"covarpop": Approximate(-0.66, places=1)})
def test_covar_pop_sample(self):
- values = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1', sample=True))
- self.assertEqual(values, {'covarpop': -1.0})
+ values = StatTestModel.objects.aggregate(
+ covarpop=CovarPop(y="int2", x="int1", sample=True)
+ )
+ self.assertEqual(values, {"covarpop": -1.0})
def test_regr_avgx_general(self):
- values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y='int2', x='int1'))
- self.assertEqual(values, {'regravgx': 2.0})
+ values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y="int2", x="int1"))
+ self.assertEqual(values, {"regravgx": 2.0})
def test_regr_avgy_general(self):
- values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y='int2', x='int1'))
- self.assertEqual(values, {'regravgy': 2.0})
+ values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y="int2", x="int1"))
+ self.assertEqual(values, {"regravgy": 2.0})
def test_regr_count_general(self):
- values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1'))
- self.assertEqual(values, {'regrcount': 3})
+ values = StatTestModel.objects.aggregate(
+ regrcount=RegrCount(y="int2", x="int1")
+ )
+ self.assertEqual(values, {"regrcount": 3})
def test_regr_count_default(self):
- msg = 'RegrCount does not allow default.'
+ msg = "RegrCount does not allow default."
with self.assertRaisesMessage(TypeError, msg):
- RegrCount(y='int2', x='int1', default=0)
+ RegrCount(y="int2", x="int1", default=0)
def test_regr_intercept_general(self):
- values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1'))
- self.assertEqual(values, {'regrintercept': 4})
+ values = StatTestModel.objects.aggregate(
+ regrintercept=RegrIntercept(y="int2", x="int1")
+ )
+ self.assertEqual(values, {"regrintercept": 4})
def test_regr_r2_general(self):
- values = StatTestModel.objects.aggregate(regrr2=RegrR2(y='int2', x='int1'))
- self.assertEqual(values, {'regrr2': 1})
+ values = StatTestModel.objects.aggregate(regrr2=RegrR2(y="int2", x="int1"))
+ self.assertEqual(values, {"regrr2": 1})
def test_regr_slope_general(self):
- values = StatTestModel.objects.aggregate(regrslope=RegrSlope(y='int2', x='int1'))
- self.assertEqual(values, {'regrslope': -1})
+ values = StatTestModel.objects.aggregate(
+ regrslope=RegrSlope(y="int2", x="int1")
+ )
+ self.assertEqual(values, {"regrslope": -1})
def test_regr_sxx_general(self):
- values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y='int2', x='int1'))
- self.assertEqual(values, {'regrsxx': 2.0})
+ values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y="int2", x="int1"))
+ self.assertEqual(values, {"regrsxx": 2.0})
def test_regr_sxy_general(self):
- values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y='int2', x='int1'))
- self.assertEqual(values, {'regrsxy': -2.0})
+ values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y="int2", x="int1"))
+ self.assertEqual(values, {"regrsxy": -2.0})
def test_regr_syy_general(self):
- values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y='int2', x='int1'))
- self.assertEqual(values, {'regrsyy': 2.0})
+ values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y="int2", x="int1"))
+ self.assertEqual(values, {"regrsyy": 2.0})
def test_regr_avgx_with_related_obj_and_number_as_argument(self):
"""
This is more complex test to check if JOIN on field and
number as argument works as expected.
"""
- values = StatTestModel.objects.aggregate(complex_regravgx=RegrAvgX(y=5, x='related_field__integer_field'))
- self.assertEqual(values, {'complex_regravgx': 1.0})
+ values = StatTestModel.objects.aggregate(
+ complex_regravgx=RegrAvgX(y=5, x="related_field__integer_field")
+ )
+ self.assertEqual(values, {"complex_regravgx": 1.0})
diff --git a/tests/postgres_tests/test_apps.py b/tests/postgres_tests/test_apps.py
index 94001822c2..340e555609 100644
--- a/tests/postgres_tests/test_apps.py
+++ b/tests/postgres_tests/test_apps.py
@@ -7,12 +7,12 @@ from django.test.utils import modify_settings
from . import PostgreSQLTestCase
try:
- from psycopg2.extras import (
- DateRange, DateTimeRange, DateTimeTZRange, NumericRange,
- )
+ from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, NumericRange
from django.contrib.postgres.fields import (
- DateRangeField, DateTimeRangeField, DecimalRangeField,
+ DateRangeField,
+ DateTimeRangeField,
+ DecimalRangeField,
IntegerRangeField,
)
except ImportError:
@@ -22,17 +22,24 @@ except ImportError:
class PostgresConfigTests(PostgreSQLTestCase):
def test_register_type_handlers_connection(self):
from django.contrib.postgres.signals import register_type_handlers
- self.assertNotIn(register_type_handlers, connection_created._live_receivers(None))
- with modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}):
- self.assertIn(register_type_handlers, connection_created._live_receivers(None))
- self.assertNotIn(register_type_handlers, connection_created._live_receivers(None))
+
+ self.assertNotIn(
+ register_type_handlers, connection_created._live_receivers(None)
+ )
+ with modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
+ self.assertIn(
+ register_type_handlers, connection_created._live_receivers(None)
+ )
+ self.assertNotIn(
+ register_type_handlers, connection_created._live_receivers(None)
+ )
def test_register_serializer_for_migrations(self):
tests = (
(DateRange(empty=True), DateRangeField),
(DateTimeRange(empty=True), DateRangeField),
- (DateTimeTZRange(None, None, '[]'), DateTimeRangeField),
- (NumericRange(Decimal('1.0'), Decimal('5.0'), '()'), DecimalRangeField),
+ (DateTimeTZRange(None, None, "[]"), DateTimeRangeField),
+ (NumericRange(Decimal("1.0"), Decimal("5.0"), "()"), DecimalRangeField),
(NumericRange(1, 10), IntegerRangeField),
)
@@ -40,25 +47,31 @@ class PostgresConfigTests(PostgreSQLTestCase):
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
- with self.assertRaisesMessage(ValueError, 'Cannot serialize: %s' % default.__class__.__name__):
+ with self.assertRaisesMessage(
+ ValueError, "Cannot serialize: %s" % default.__class__.__name__
+ ):
MigrationWriter.serialize(field)
assertNotSerializable()
- with self.modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}):
+ with self.modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
serialized_field, imports = MigrationWriter.serialize(field)
- self.assertEqual(imports, {
- 'import django.contrib.postgres.fields.ranges',
- 'import psycopg2.extras',
- })
+ self.assertEqual(
+ imports,
+ {
+ "import django.contrib.postgres.fields.ranges",
+ "import psycopg2.extras",
+ },
+ )
self.assertIn(
- '%s.%s(default=psycopg2.extras.%r)' % (
+ "%s.%s(default=psycopg2.extras.%r)"
+ % (
field.__module__,
field.__class__.__name__,
default,
),
- serialized_field
+ serialized_field,
)
assertNotSerializable()
diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py
index ba04c15e24..512972b8e6 100644
--- a/tests/postgres_tests/test_array.py
+++ b/tests/postgres_tests/test_array.py
@@ -15,13 +15,18 @@ from django.test import TransactionTestCase, modify_settings, override_settings
from django.test.utils import isolate_apps
from django.utils import timezone
-from . import (
- PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
-)
+from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase
from .models import (
- ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
- IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
- OtherTypesArrayModel, PostgreSQLModel, Tag,
+ ArrayEnumModel,
+ ArrayFieldSubclass,
+ CharArrayModel,
+ DateTimeArrayModel,
+ IntegerArrayModel,
+ NestedIntegerArrayModel,
+ NullableIntegerArrayModel,
+ OtherTypesArrayModel,
+ PostgreSQLModel,
+ Tag,
)
try:
@@ -30,33 +35,33 @@ try:
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.expressions import ArraySubquery
from django.contrib.postgres.fields import ArrayField
- from django.contrib.postgres.fields.array import (
- IndexTransform, SliceTransform,
- )
+ from django.contrib.postgres.fields.array import IndexTransform, SliceTransform
from django.contrib.postgres.forms import (
- SimpleArrayField, SplitArrayField, SplitArrayWidget,
+ SimpleArrayField,
+ SplitArrayField,
+ SplitArrayWidget,
)
except ImportError:
pass
-@isolate_apps('postgres_tests')
+@isolate_apps("postgres_tests")
class BasicTests(PostgreSQLSimpleTestCase):
def test_get_field_display(self):
class MyModel(PostgreSQLModel):
field = ArrayField(
models.CharField(max_length=16),
choices=[
- ['Media', [(['vinyl', 'cd'], 'Audio')]],
- (('mp3', 'mp4'), 'Digital'),
+ ["Media", [(["vinyl", "cd"], "Audio")]],
+ (("mp3", "mp4"), "Digital"),
],
)
tests = (
- (['vinyl', 'cd'], 'Audio'),
- (('mp3', 'mp4'), 'Digital'),
- (('a', 'b'), "('a', 'b')"),
- (['c', 'd'], "['c', 'd']"),
+ (["vinyl", "cd"], "Audio"),
+ (("mp3", "mp4"), "Digital"),
+ (("a", "b"), "('a', 'b')"),
+ (["c", "d"], "['c', 'd']"),
)
for value, display in tests:
with self.subTest(value=value, display=display):
@@ -69,17 +74,18 @@ class BasicTests(PostgreSQLSimpleTestCase):
ArrayField(models.CharField(max_length=16)),
choices=[
[
- 'Media',
- [([['vinyl', 'cd'], ('x',)], 'Audio')],
+ "Media",
+ [([["vinyl", "cd"], ("x",)], "Audio")],
],
- ((['mp3'], ('mp4',)), 'Digital'),
+ ((["mp3"], ("mp4",)), "Digital"),
],
)
+
tests = (
- ([['vinyl', 'cd'], ('x',)], 'Audio'),
- ((['mp3'], ('mp4',)), 'Digital'),
- ((('a', 'b'), ('c',)), "(('a', 'b'), ('c',))"),
- ([['a', 'b'], ['c']], "[['a', 'b'], ['c']]"),
+ ([["vinyl", "cd"], ("x",)], "Audio"),
+ ((["mp3"], ("mp4",)), "Digital"),
+ ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"),
+ ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"),
)
for value, display in tests:
with self.subTest(value=value, display=display):
@@ -88,7 +94,6 @@ class BasicTests(PostgreSQLSimpleTestCase):
class TestSaveLoad(PostgreSQLTestCase):
-
def test_integer(self):
instance = IntegerArrayModel(field=[1, 2, 3])
instance.save()
@@ -96,7 +101,7 @@ class TestSaveLoad(PostgreSQLTestCase):
self.assertEqual(instance.field, loaded.field)
def test_char(self):
- instance = CharArrayModel(field=['hello', 'goodbye'])
+ instance = CharArrayModel(field=["hello", "goodbye"])
instance.save()
loaded = CharArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
@@ -121,7 +126,7 @@ class TestSaveLoad(PostgreSQLTestCase):
def test_integers_passed_as_strings(self):
# This checks that get_prep_value is deferred properly
- instance = IntegerArrayModel(field=['1'])
+ instance = IntegerArrayModel(field=["1"])
instance.save()
loaded = IntegerArrayModel.objects.get()
self.assertEqual(loaded.field, [1])
@@ -151,16 +156,16 @@ class TestSaveLoad(PostgreSQLTestCase):
def test_other_array_types(self):
instance = OtherTypesArrayModel(
- ips=['192.168.0.1', '::1'],
+ ips=["192.168.0.1", "::1"],
uuids=[uuid.uuid4()],
decimals=[decimal.Decimal(1.25), 1.75],
tags=[Tag(1), Tag(2), Tag(3)],
- json=[{'a': 1}, {'b': 2}],
+ json=[{"a": 1}, {"b": 2}],
int_ranges=[NumericRange(10, 20), NumericRange(30, 40)],
bigint_ranges=[
NumericRange(7000000000, 10000000000),
NumericRange(50000000000, 70000000000),
- ]
+ ],
)
instance.save()
loaded = OtherTypesArrayModel.objects.get()
@@ -174,7 +179,7 @@ class TestSaveLoad(PostgreSQLTestCase):
def test_null_from_db_value_handling(self):
instance = OtherTypesArrayModel.objects.create(
- ips=['192.168.0.1', '::1'],
+ ips=["192.168.0.1", "::1"],
uuids=[uuid.uuid4()],
decimals=[decimal.Decimal(1.25), 1.75],
tags=None,
@@ -187,7 +192,7 @@ class TestSaveLoad(PostgreSQLTestCase):
def test_model_set_on_base_field(self):
instance = IntegerArrayModel()
- field = instance._meta.get_field('field')
+ field = instance._meta.get_field("field")
self.assertEqual(field.model, IntegerArrayModel)
self.assertEqual(field.base_field.model, IntegerArrayModel)
@@ -199,29 +204,35 @@ class TestSaveLoad(PostgreSQLTestCase):
class TestQuerying(PostgreSQLTestCase):
-
@classmethod
def setUpTestData(cls):
- cls.objs = NullableIntegerArrayModel.objects.bulk_create([
- NullableIntegerArrayModel(order=1, field=[1]),
- NullableIntegerArrayModel(order=2, field=[2]),
- NullableIntegerArrayModel(order=3, field=[2, 3]),
- NullableIntegerArrayModel(order=4, field=[20, 30, 40]),
- NullableIntegerArrayModel(order=5, field=None),
- ])
+ cls.objs = NullableIntegerArrayModel.objects.bulk_create(
+ [
+ NullableIntegerArrayModel(order=1, field=[1]),
+ NullableIntegerArrayModel(order=2, field=[2]),
+ NullableIntegerArrayModel(order=3, field=[2, 3]),
+ NullableIntegerArrayModel(order=4, field=[20, 30, 40]),
+ NullableIntegerArrayModel(order=5, field=None),
+ ]
+ )
def test_empty_list(self):
NullableIntegerArrayModel.objects.create(field=[])
- obj = NullableIntegerArrayModel.objects.annotate(
- empty_array=models.Value([], output_field=ArrayField(models.IntegerField())),
- ).filter(field=models.F('empty_array')).get()
+ obj = (
+ NullableIntegerArrayModel.objects.annotate(
+ empty_array=models.Value(
+ [], output_field=ArrayField(models.IntegerField())
+ ),
+ )
+ .filter(field=models.F("empty_array"))
+ .get()
+ )
self.assertEqual(obj.field, [])
self.assertEqual(obj.empty_array, [])
def test_exact(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__exact=[1]),
- self.objs[:1]
+ NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1]
)
def test_exact_with_expression(self):
@@ -231,50 +242,47 @@ class TestQuerying(PostgreSQLTestCase):
)
def test_exact_charfield(self):
- instance = CharArrayModel.objects.create(field=['text'])
+ instance = CharArrayModel.objects.create(field=["text"])
self.assertSequenceEqual(
- CharArrayModel.objects.filter(field=['text']),
- [instance]
+ CharArrayModel.objects.filter(field=["text"]), [instance]
)
def test_exact_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
- NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]),
- [instance]
+ NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance]
)
def test_isnull(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__isnull=True),
- self.objs[-1:]
+ NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:]
)
def test_gt(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__gt=[0]),
- self.objs[:4]
+ NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4]
)
def test_lt(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__lt=[2]),
- self.objs[:1]
+ NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1]
)
def test_in(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
- self.objs[:2]
+ self.objs[:2],
)
def test_in_subquery(self):
IntegerArrayModel.objects.create(field=[2, 3])
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(
- field__in=IntegerArrayModel.objects.all().values_list('field', flat=True)
+ field__in=IntegerArrayModel.objects.all().values_list(
+ "field", flat=True
+ )
),
- self.objs[2:3]
+ self.objs[2:3],
)
@unittest.expectedFailure
@@ -284,42 +292,44 @@ class TestQuerying(PostgreSQLTestCase):
# psycopg2 mogrify method that generates the ARRAY() syntax is
# expecting literals, not column references (#27095).
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]),
- self.objs[:2]
+ NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]),
+ self.objs[:2],
)
def test_in_as_F_object(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]),
- self.objs[:4]
+ NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]),
+ self.objs[:4],
)
def test_contained_by(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
- self.objs[:2]
+ self.objs[:2],
)
def test_contained_by_including_F_object(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('order'), 2]),
+ NullableIntegerArrayModel.objects.filter(
+ field__contained_by=[models.F("order"), 2]
+ ),
self.objs[:3],
)
def test_contains(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contains=[2]),
- self.objs[1:3]
+ self.objs[1:3],
)
def test_contains_subquery(self):
IntegerArrayModel.objects.create(field=[2, 3])
- inner_qs = IntegerArrayModel.objects.values_list('field', flat=True)
+ inner_qs = IntegerArrayModel.objects.values_list("field", flat=True)
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contains=inner_qs[:1]),
self.objs[2:3],
)
- inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef('field'))
+ inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef("field"))
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(Exists(inner_qs)),
self.objs[1:3],
@@ -335,93 +345,92 @@ class TestQuerying(PostgreSQLTestCase):
def test_icontains(self):
# Using the __icontains lookup with ArrayField is inefficient.
- instance = CharArrayModel.objects.create(field=['FoO'])
+ instance = CharArrayModel.objects.create(field=["FoO"])
self.assertSequenceEqual(
- CharArrayModel.objects.filter(field__icontains='foo'),
- [instance]
+ CharArrayModel.objects.filter(field__icontains="foo"), [instance]
)
def test_contains_charfield(self):
# Regression for #22907
self.assertSequenceEqual(
- CharArrayModel.objects.filter(field__contains=['text']),
- []
+ CharArrayModel.objects.filter(field__contains=["text"]), []
)
def test_contained_by_charfield(self):
self.assertSequenceEqual(
- CharArrayModel.objects.filter(field__contained_by=['text']),
- []
+ CharArrayModel.objects.filter(field__contained_by=["text"]), []
)
def test_overlap_charfield(self):
self.assertSequenceEqual(
- CharArrayModel.objects.filter(field__overlap=['text']),
- []
+ CharArrayModel.objects.filter(field__overlap=["text"]), []
)
def test_overlap_charfield_including_expression(self):
- obj_1 = CharArrayModel.objects.create(field=['TEXT', 'lower text'])
- obj_2 = CharArrayModel.objects.create(field=['lower text', 'TEXT'])
- CharArrayModel.objects.create(field=['lower text', 'text'])
+ obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
+ obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
+ CharArrayModel.objects.create(field=["lower text", "text"])
self.assertSequenceEqual(
- CharArrayModel.objects.filter(field__overlap=[
- Upper(Value('text')),
- 'other',
- ]),
+ CharArrayModel.objects.filter(
+ field__overlap=[
+ Upper(Value("text")),
+ "other",
+ ]
+ ),
[obj_1, obj_2],
)
def test_lookups_autofield_array(self):
- qs = NullableIntegerArrayModel.objects.filter(
- field__0__isnull=False,
- ).values('field__0').annotate(
- arrayagg=ArrayAgg('id'),
- ).order_by('field__0')
+ qs = (
+ NullableIntegerArrayModel.objects.filter(
+ field__0__isnull=False,
+ )
+ .values("field__0")
+ .annotate(
+ arrayagg=ArrayAgg("id"),
+ )
+ .order_by("field__0")
+ )
tests = (
- ('contained_by', [self.objs[1].pk, self.objs[2].pk, 0], [2]),
- ('contains', [self.objs[2].pk], [2]),
- ('exact', [self.objs[3].pk], [20]),
- ('overlap', [self.objs[1].pk, self.objs[3].pk], [2, 20]),
+ ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]),
+ ("contains", [self.objs[2].pk], [2]),
+ ("exact", [self.objs[3].pk], [20]),
+ ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]),
)
for lookup, value, expected in tests:
with self.subTest(lookup=lookup):
self.assertSequenceEqual(
qs.filter(
- **{'arrayagg__' + lookup: value},
- ).values_list('field__0', flat=True),
+ **{"arrayagg__" + lookup: value},
+ ).values_list("field__0", flat=True),
expected,
)
def test_index(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__0=2),
- self.objs[1:3]
+ NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
)
def test_index_chained(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__0__lt=3),
- self.objs[0:3]
+ NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3]
)
def test_index_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
- NestedIntegerArrayModel.objects.filter(field__0__0=1),
- [instance]
+ NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]
)
@unittest.expectedFailure
def test_index_used_on_nested_data(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
- NestedIntegerArrayModel.objects.filter(field__0=[1, 2]),
- [instance]
+ NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
)
def test_index_transform_expression(self):
- expr = RawSQL("string_to_array(%s, ';')", ['1;2'])
+ expr = RawSQL("string_to_array(%s, ';')", ["1;2"])
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(
field__0=Cast(
@@ -433,40 +442,36 @@ class TestQuerying(PostgreSQLTestCase):
)
def test_index_annotation(self):
- qs = NullableIntegerArrayModel.objects.annotate(second=models.F('field__1'))
+ qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1"))
self.assertCountEqual(
- qs.values_list('second', flat=True),
+ qs.values_list("second", flat=True),
[None, None, None, 3, 30],
)
def test_overlap(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
- self.objs[0:3]
+ self.objs[0:3],
)
def test_len(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__len__lte=2),
- self.objs[0:3]
+ NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3]
)
def test_len_empty_array(self):
obj = NullableIntegerArrayModel.objects.create(field=[])
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__len=0),
- [obj]
+ NullableIntegerArrayModel.objects.filter(field__len=0), [obj]
)
def test_slice(self):
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__0_1=[2]),
- self.objs[1:3]
+ NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3]
)
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]),
- self.objs[2:3]
+ NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3]
)
def test_order_by_slice(self):
@@ -477,35 +482,42 @@ class TestQuerying(PostgreSQLTestCase):
NullableIntegerArrayModel.objects.create(field=[4, 2]),
)
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.order_by('field__1'),
+ NullableIntegerArrayModel.objects.order_by("field__1"),
[
- more_objs[2], more_objs[1], more_objs[3], self.objs[2],
- self.objs[3], more_objs[0], self.objs[4], self.objs[1],
+ more_objs[2],
+ more_objs[1],
+ more_objs[3],
+ self.objs[2],
+ self.objs[3],
+ more_objs[0],
+ self.objs[4],
+ self.objs[1],
self.objs[0],
- ]
+ ],
)
@unittest.expectedFailure
def test_slice_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
- NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]),
- [instance]
+ NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance]
)
def test_slice_transform_expression(self):
- expr = RawSQL("string_to_array(%s, ';')", ['9;2;3'])
+ expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"])
self.assertSequenceEqual(
- NullableIntegerArrayModel.objects.filter(field__0_2=SliceTransform(2, 3, expr)),
+ NullableIntegerArrayModel.objects.filter(
+ field__0_2=SliceTransform(2, 3, expr)
+ ),
self.objs[2:3],
)
def test_slice_annotation(self):
qs = NullableIntegerArrayModel.objects.annotate(
- first_two=models.F('field__0_2'),
+ first_two=models.F("field__0_2"),
)
self.assertCountEqual(
- qs.values_list('first_two', flat=True),
+ qs.values_list("first_two", flat=True),
[None, [1], [2], [2, 3], [20, 30]],
)
@@ -514,17 +526,17 @@ class TestQuerying(PostgreSQLTestCase):
NullableIntegerArrayModel.objects.filter(
id__in=NullableIntegerArrayModel.objects.filter(field__len=3)
),
- [self.objs[3]]
+ [self.objs[3]],
)
def test_enum_lookup(self):
class TestEnum(enum.Enum):
- VALUE_1 = 'value_1'
+ VALUE_1 = "value_1"
instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
self.assertSequenceEqual(
ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
- [instance]
+ [instance],
)
def test_unsupported_lookup(self):
@@ -541,18 +553,24 @@ class TestQuerying(PostgreSQLTestCase):
self.assertEqual(
NullableIntegerArrayModel.objects.annotate(
array_length=models.Func(
- value, 1, function='ARRAY_LENGTH', output_field=models.IntegerField(),
+ value,
+ 1,
+ function="ARRAY_LENGTH",
+ output_field=models.IntegerField(),
),
- ).values('array_length').annotate(
- count=models.Count('pk'),
- ).get()['array_length'],
+ )
+ .values("array_length")
+ .annotate(
+ count=models.Count("pk"),
+ )
+ .get()["array_length"],
1,
)
def test_filter_by_array_subquery(self):
inner_qs = NullableIntegerArrayModel.objects.filter(
- field__len=models.OuterRef('field__len'),
- ).values('field')
+ field__len=models.OuterRef("field__len"),
+ ).values("field")
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.alias(
same_sized_fields=ArraySubquery(inner_qs),
@@ -562,56 +580,63 @@ class TestQuerying(PostgreSQLTestCase):
def test_annotated_array_subquery(self):
inner_qs = NullableIntegerArrayModel.objects.exclude(
- pk=models.OuterRef('pk')
- ).values('order')
+ pk=models.OuterRef("pk")
+ ).values("order")
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.annotate(
sibling_ids=ArraySubquery(inner_qs),
- ).get(order=1).sibling_ids,
+ )
+ .get(order=1)
+ .sibling_ids,
[2, 3, 4, 5],
)
def test_group_by_with_annotated_array_subquery(self):
inner_qs = NullableIntegerArrayModel.objects.exclude(
- pk=models.OuterRef('pk')
- ).values('order')
+ pk=models.OuterRef("pk")
+ ).values("order")
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.annotate(
sibling_ids=ArraySubquery(inner_qs),
- sibling_count=models.Max('sibling_ids__len'),
- ).values_list('sibling_count', flat=True),
+ sibling_count=models.Max("sibling_ids__len"),
+ ).values_list("sibling_count", flat=True),
[len(self.objs) - 1] * len(self.objs),
)
def test_annotated_ordered_array_subquery(self):
- inner_qs = NullableIntegerArrayModel.objects.order_by('-order').values('order')
+ inner_qs = NullableIntegerArrayModel.objects.order_by("-order").values("order")
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.annotate(
ids=ArraySubquery(inner_qs),
- ).first().ids,
+ )
+ .first()
+ .ids,
[5, 4, 3, 2, 1],
)
def test_annotated_array_subquery_with_json_objects(self):
inner_qs = NullableIntegerArrayModel.objects.exclude(
- pk=models.OuterRef('pk')
- ).values(json=JSONObject(order='order', field='field'))
- siblings_json = NullableIntegerArrayModel.objects.annotate(
- siblings_json=ArraySubquery(inner_qs),
- ).values_list('siblings_json', flat=True).get(order=1)
+ pk=models.OuterRef("pk")
+ ).values(json=JSONObject(order="order", field="field"))
+ siblings_json = (
+ NullableIntegerArrayModel.objects.annotate(
+ siblings_json=ArraySubquery(inner_qs),
+ )
+ .values_list("siblings_json", flat=True)
+ .get(order=1)
+ )
self.assertSequenceEqual(
siblings_json,
[
- {'field': [2], 'order': 2},
- {'field': [2, 3], 'order': 3},
- {'field': [20, 30, 40], 'order': 4},
- {'field': None, 'order': 5},
+ {"field": [2], "order": 2},
+ {"field": [2, 3], "order": 3},
+ {"field": [20, 30, 40], "order": 4},
+ {"field": None, "order": 5},
],
)
class TestDateTimeExactQuerying(PostgreSQLTestCase):
-
@classmethod
def setUpTestData(cls):
now = timezone.now()
@@ -619,33 +644,31 @@ class TestDateTimeExactQuerying(PostgreSQLTestCase):
cls.dates = [now.date()]
cls.times = [now.time()]
cls.objs = [
- DateTimeArrayModel.objects.create(datetimes=cls.datetimes, dates=cls.dates, times=cls.times),
+ DateTimeArrayModel.objects.create(
+ datetimes=cls.datetimes, dates=cls.dates, times=cls.times
+ ),
]
def test_exact_datetimes(self):
self.assertSequenceEqual(
- DateTimeArrayModel.objects.filter(datetimes=self.datetimes),
- self.objs
+ DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs
)
def test_exact_dates(self):
self.assertSequenceEqual(
- DateTimeArrayModel.objects.filter(dates=self.dates),
- self.objs
+ DateTimeArrayModel.objects.filter(dates=self.dates), self.objs
)
def test_exact_times(self):
self.assertSequenceEqual(
- DateTimeArrayModel.objects.filter(times=self.times),
- self.objs
+ DateTimeArrayModel.objects.filter(times=self.times), self.objs
)
class TestOtherTypesExactQuerying(PostgreSQLTestCase):
-
@classmethod
def setUpTestData(cls):
- cls.ips = ['192.168.0.1', '::1']
+ cls.ips = ["192.168.0.1", "::1"]
cls.uuids = [uuid.uuid4()]
cls.decimals = [decimal.Decimal(1.25), 1.75]
cls.tags = [Tag(1), Tag(2), Tag(3)]
@@ -660,32 +683,27 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase):
def test_exact_ip_addresses(self):
self.assertSequenceEqual(
- OtherTypesArrayModel.objects.filter(ips=self.ips),
- self.objs
+ OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs
)
def test_exact_uuids(self):
self.assertSequenceEqual(
- OtherTypesArrayModel.objects.filter(uuids=self.uuids),
- self.objs
+ OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs
)
def test_exact_decimals(self):
self.assertSequenceEqual(
- OtherTypesArrayModel.objects.filter(decimals=self.decimals),
- self.objs
+ OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs
)
def test_exact_tags(self):
self.assertSequenceEqual(
- OtherTypesArrayModel.objects.filter(tags=self.tags),
- self.objs
+ OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs
)
-@isolate_apps('postgres_tests')
+@isolate_apps("postgres_tests")
class TestChecks(PostgreSQLSimpleTestCase):
-
def test_field_checks(self):
class MyModel(PostgreSQLModel):
field = ArrayField(models.CharField())
@@ -694,35 +712,40 @@ class TestChecks(PostgreSQLSimpleTestCase):
errors = model.check()
self.assertEqual(len(errors), 1)
# The inner CharField is missing a max_length.
- self.assertEqual(errors[0].id, 'postgres.E001')
- self.assertIn('max_length', errors[0].msg)
+ self.assertEqual(errors[0].id, "postgres.E001")
+ self.assertIn("max_length", errors[0].msg)
def test_invalid_base_fields(self):
class MyModel(PostgreSQLModel):
- field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel'))
+ field = ArrayField(
+ models.ManyToManyField("postgres_tests.IntegerArrayModel")
+ )
model = MyModel()
errors = model.check()
self.assertEqual(len(errors), 1)
- self.assertEqual(errors[0].id, 'postgres.E002')
+ self.assertEqual(errors[0].id, "postgres.E002")
def test_invalid_default(self):
class MyModel(PostgreSQLModel):
field = ArrayField(models.IntegerField(), default=[])
model = MyModel()
- self.assertEqual(model.check(), [
- checks.Warning(
- msg=(
- "ArrayField default should be a callable instead of an "
- "instance so that it's not shared between all field "
- "instances."
- ),
- hint='Use a callable instead, e.g., use `list` instead of `[]`.',
- obj=MyModel._meta.get_field('field'),
- id='fields.E010',
- )
- ])
+ self.assertEqual(
+ model.check(),
+ [
+ checks.Warning(
+ msg=(
+ "ArrayField default should be a callable instead of an "
+ "instance so that it's not shared between all field "
+ "instances."
+ ),
+ hint="Use a callable instead, e.g., use `list` instead of `[]`.",
+ obj=MyModel._meta.get_field("field"),
+ id="fields.E010",
+ )
+ ],
+ )
def test_valid_default(self):
class MyModel(PostgreSQLModel):
@@ -742,6 +765,7 @@ class TestChecks(PostgreSQLSimpleTestCase):
"""
Nested ArrayFields are permitted.
"""
+
class MyModel(PostgreSQLModel):
field = ArrayField(ArrayField(models.CharField()))
@@ -749,8 +773,8 @@ class TestChecks(PostgreSQLSimpleTestCase):
errors = model.check()
self.assertEqual(len(errors), 1)
# The inner CharField is missing a max_length.
- self.assertEqual(errors[0].id, 'postgres.E001')
- self.assertIn('max_length', errors[0].msg)
+ self.assertEqual(errors[0].id, "postgres.E001")
+ self.assertIn("max_length", errors[0].msg)
def test_choices_tuple_list(self):
class MyModel(PostgreSQLModel):
@@ -758,19 +782,20 @@ class TestChecks(PostgreSQLSimpleTestCase):
models.CharField(max_length=16),
choices=[
[
- 'Media',
- [(['vinyl', 'cd'], 'Audio'), (('vhs', 'dvd'), 'Video')],
+ "Media",
+ [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")],
],
- (['mp3', 'mp4'], 'Digital'),
+ (["mp3", "mp4"], "Digital"),
],
)
- self.assertEqual(MyModel._meta.get_field('field').check(), [])
+
+ self.assertEqual(MyModel._meta.get_field("field").check(), [])
-@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class TestMigrations(TransactionTestCase):
- available_apps = ['postgres_tests']
+ available_apps = ["postgres_tests"]
def test_deconstruct(self):
field = ArrayField(models.IntegerField())
@@ -794,84 +819,89 @@ class TestMigrations(TransactionTestCase):
def test_subclass_deconstruct(self):
field = ArrayField(models.IntegerField())
name, path, args, kwargs = field.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.fields.ArrayField')
+ self.assertEqual(path, "django.contrib.postgres.fields.ArrayField")
field = ArrayFieldSubclass()
name, path, args, kwargs = field.deconstruct()
- self.assertEqual(path, 'postgres_tests.models.ArrayFieldSubclass')
+ self.assertEqual(path, "postgres_tests.models.ArrayFieldSubclass")
- @override_settings(MIGRATION_MODULES={
- "postgres_tests": "postgres_tests.array_default_migrations",
- })
+ @override_settings(
+ MIGRATION_MODULES={
+ "postgres_tests": "postgres_tests.array_default_migrations",
+ }
+ )
def test_adding_field_with_default(self):
# See #22962
- table_name = 'postgres_tests_integerarraydefaultmodel'
+ table_name = "postgres_tests_integerarraydefaultmodel"
with connection.cursor() as cursor:
self.assertNotIn(table_name, connection.introspection.table_names(cursor))
- call_command('migrate', 'postgres_tests', verbosity=0)
+ call_command("migrate", "postgres_tests", verbosity=0)
with connection.cursor() as cursor:
self.assertIn(table_name, connection.introspection.table_names(cursor))
- call_command('migrate', 'postgres_tests', 'zero', verbosity=0)
+ call_command("migrate", "postgres_tests", "zero", verbosity=0)
with connection.cursor() as cursor:
self.assertNotIn(table_name, connection.introspection.table_names(cursor))
- @override_settings(MIGRATION_MODULES={
- "postgres_tests": "postgres_tests.array_index_migrations",
- })
+ @override_settings(
+ MIGRATION_MODULES={
+ "postgres_tests": "postgres_tests.array_index_migrations",
+ }
+ )
def test_adding_arrayfield_with_index(self):
"""
ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes.
"""
- table_name = 'postgres_tests_chartextarrayindexmodel'
- call_command('migrate', 'postgres_tests', verbosity=0)
+ table_name = "postgres_tests_chartextarrayindexmodel"
+ call_command("migrate", "postgres_tests", verbosity=0)
with connection.cursor() as cursor:
like_constraint_columns_list = [
- v['columns']
- for k, v in list(connection.introspection.get_constraints(cursor, table_name).items())
- if k.endswith('_like')
+ v["columns"]
+ for k, v in list(
+ connection.introspection.get_constraints(cursor, table_name).items()
+ )
+ if k.endswith("_like")
]
# Only the CharField should have a LIKE index.
- self.assertEqual(like_constraint_columns_list, [['char2']])
+ self.assertEqual(like_constraint_columns_list, [["char2"]])
# All fields should have regular indexes.
with connection.cursor() as cursor:
indexes = [
- c['columns'][0]
- for c in connection.introspection.get_constraints(cursor, table_name).values()
- if c['index'] and len(c['columns']) == 1
+ c["columns"][0]
+ for c in connection.introspection.get_constraints(
+ cursor, table_name
+ ).values()
+ if c["index"] and len(c["columns"]) == 1
]
- self.assertIn('char', indexes)
- self.assertIn('char2', indexes)
- self.assertIn('text', indexes)
- call_command('migrate', 'postgres_tests', 'zero', verbosity=0)
+ self.assertIn("char", indexes)
+ self.assertIn("char2", indexes)
+ self.assertIn("text", indexes)
+ call_command("migrate", "postgres_tests", "zero", verbosity=0)
with connection.cursor() as cursor:
self.assertNotIn(table_name, connection.introspection.table_names(cursor))
class TestSerialization(PostgreSQLSimpleTestCase):
- test_data = (
- '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'
- )
+ test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'
def test_dumping(self):
instance = IntegerArrayModel(field=[1, 2, None])
- data = serializers.serialize('json', [instance])
+ data = serializers.serialize("json", [instance])
self.assertEqual(json.loads(data), json.loads(self.test_data))
def test_loading(self):
- instance = list(serializers.deserialize('json', self.test_data))[0].object
+ instance = list(serializers.deserialize("json", self.test_data))[0].object
self.assertEqual(instance.field, [1, 2, None])
class TestValidation(PostgreSQLSimpleTestCase):
-
def test_unbounded(self):
field = ArrayField(models.IntegerField())
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([1, None], None)
- self.assertEqual(cm.exception.code, 'item_invalid')
+ self.assertEqual(cm.exception.code, "item_invalid")
self.assertEqual(
cm.exception.message % cm.exception.params,
- 'Item 2 in the array did not validate: This field cannot be null.'
+ "Item 2 in the array did not validate: This field cannot be null.",
)
def test_blank_true(self):
@@ -884,31 +914,41 @@ class TestValidation(PostgreSQLSimpleTestCase):
field.clean([1, 2, 3], None)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([1, 2, 3, 4], None)
- self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.')
+ self.assertEqual(
+ cm.exception.messages[0],
+ "List contains 4 items, it should contain no more than 3.",
+ )
def test_nested_array_mismatch(self):
field = ArrayField(ArrayField(models.IntegerField()))
field.clean([[1, 2], [3, 4]], None)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([[1, 2], [3, 4, 5]], None)
- self.assertEqual(cm.exception.code, 'nested_array_mismatch')
- self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.')
+ self.assertEqual(cm.exception.code, "nested_array_mismatch")
+ self.assertEqual(
+ cm.exception.messages[0], "Nested arrays must have the same length."
+ )
def test_with_base_field_error_params(self):
field = ArrayField(models.CharField(max_length=2))
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['abc'], None)
+ field.clean(["abc"], None)
self.assertEqual(len(cm.exception.error_list), 1)
exception = cm.exception.error_list[0]
self.assertEqual(
exception.message,
- 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).'
+ "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
+ )
+ self.assertEqual(exception.code, "item_invalid")
+ self.assertEqual(
+ exception.params,
+ {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
)
- self.assertEqual(exception.code, 'item_invalid')
- self.assertEqual(exception.params, {'nth': 1, 'value': 'abc', 'limit_value': 2, 'show_value': 3})
def test_with_validators(self):
- field = ArrayField(models.IntegerField(validators=[validators.MinValueValidator(1)]))
+ field = ArrayField(
+ models.IntegerField(validators=[validators.MinValueValidator(1)])
+ )
field.clean([1, 2], None)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([0], None)
@@ -916,90 +956,112 @@ class TestValidation(PostgreSQLSimpleTestCase):
exception = cm.exception.error_list[0]
self.assertEqual(
exception.message,
- 'Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.'
+ "Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.",
+ )
+ self.assertEqual(exception.code, "item_invalid")
+ self.assertEqual(
+ exception.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0}
)
- self.assertEqual(exception.code, 'item_invalid')
- self.assertEqual(exception.params, {'nth': 1, 'value': 0, 'limit_value': 1, 'show_value': 0})
class TestSimpleFormField(PostgreSQLSimpleTestCase):
-
def test_valid(self):
field = SimpleArrayField(forms.CharField())
- value = field.clean('a,b,c')
- self.assertEqual(value, ['a', 'b', 'c'])
+ value = field.clean("a,b,c")
+ self.assertEqual(value, ["a", "b", "c"])
def test_to_python_fail(self):
field = SimpleArrayField(forms.IntegerField())
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('a,b,9')
- self.assertEqual(cm.exception.messages[0], 'Item 1 in the array did not validate: Enter a whole number.')
+ field.clean("a,b,9")
+ self.assertEqual(
+ cm.exception.messages[0],
+ "Item 1 in the array did not validate: Enter a whole number.",
+ )
def test_validate_fail(self):
field = SimpleArrayField(forms.CharField(required=True))
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('a,b,')
- self.assertEqual(cm.exception.messages[0], 'Item 3 in the array did not validate: This field is required.')
+ field.clean("a,b,")
+ self.assertEqual(
+ cm.exception.messages[0],
+ "Item 3 in the array did not validate: This field is required.",
+ )
def test_validate_fail_base_field_error_params(self):
field = SimpleArrayField(forms.CharField(max_length=2))
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('abc,c,defg')
+ field.clean("abc,c,defg")
errors = cm.exception.error_list
self.assertEqual(len(errors), 2)
first_error = errors[0]
self.assertEqual(
first_error.message,
- 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).'
+ "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
+ )
+ self.assertEqual(first_error.code, "item_invalid")
+ self.assertEqual(
+ first_error.params,
+ {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
)
- self.assertEqual(first_error.code, 'item_invalid')
- self.assertEqual(first_error.params, {'nth': 1, 'value': 'abc', 'limit_value': 2, 'show_value': 3})
second_error = errors[1]
self.assertEqual(
second_error.message,
- 'Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).'
+ "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).",
+ )
+ self.assertEqual(second_error.code, "item_invalid")
+ self.assertEqual(
+ second_error.params,
+ {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4},
)
- self.assertEqual(second_error.code, 'item_invalid')
- self.assertEqual(second_error.params, {'nth': 3, 'value': 'defg', 'limit_value': 2, 'show_value': 4})
def test_validators_fail(self):
- field = SimpleArrayField(forms.RegexField('[a-e]{2}'))
+ field = SimpleArrayField(forms.RegexField("[a-e]{2}"))
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('a,bc,de')
- self.assertEqual(cm.exception.messages[0], 'Item 1 in the array did not validate: Enter a valid value.')
+ field.clean("a,bc,de")
+ self.assertEqual(
+ cm.exception.messages[0],
+ "Item 1 in the array did not validate: Enter a valid value.",
+ )
def test_delimiter(self):
- field = SimpleArrayField(forms.CharField(), delimiter='|')
- value = field.clean('a|b|c')
- self.assertEqual(value, ['a', 'b', 'c'])
+ field = SimpleArrayField(forms.CharField(), delimiter="|")
+ value = field.clean("a|b|c")
+ self.assertEqual(value, ["a", "b", "c"])
def test_delimiter_with_nesting(self):
- field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|')
- value = field.clean('a,b|c,d')
- self.assertEqual(value, [['a', 'b'], ['c', 'd']])
+ field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter="|")
+ value = field.clean("a,b|c,d")
+ self.assertEqual(value, [["a", "b"], ["c", "d"]])
def test_prepare_value(self):
field = SimpleArrayField(forms.CharField())
- value = field.prepare_value(['a', 'b', 'c'])
- self.assertEqual(value, 'a,b,c')
+ value = field.prepare_value(["a", "b", "c"])
+ self.assertEqual(value, "a,b,c")
def test_max_length(self):
field = SimpleArrayField(forms.CharField(), max_length=2)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('a,b,c')
- self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.')
+ field.clean("a,b,c")
+ self.assertEqual(
+ cm.exception.messages[0],
+ "List contains 3 items, it should contain no more than 2.",
+ )
def test_min_length(self):
field = SimpleArrayField(forms.CharField(), min_length=4)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('a,b,c')
- self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.')
+ field.clean("a,b,c")
+ self.assertEqual(
+ cm.exception.messages[0],
+ "List contains 3 items, it should contain no fewer than 4.",
+ )
def test_required(self):
field = SimpleArrayField(forms.CharField(), required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('')
- self.assertEqual(cm.exception.messages[0], 'This field is required.')
+ field.clean("")
+ self.assertEqual(cm.exception.messages[0], "This field is required.")
def test_model_field_formfield(self):
model_field = ArrayField(models.CharField(max_length=27))
@@ -1015,59 +1077,66 @@ class TestSimpleFormField(PostgreSQLSimpleTestCase):
self.assertEqual(form_field.max_length, 4)
def test_model_field_choices(self):
- model_field = ArrayField(models.IntegerField(choices=((1, 'A'), (2, 'B'))))
+ model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B"))))
form_field = model_field.formfield()
- self.assertEqual(form_field.clean('1,2'), [1, 2])
+ self.assertEqual(form_field.clean("1,2"), [1, 2])
def test_already_converted_value(self):
field = SimpleArrayField(forms.CharField())
- vals = ['a', 'b', 'c']
+ vals = ["a", "b", "c"]
self.assertEqual(field.clean(vals), vals)
def test_has_changed(self):
field = SimpleArrayField(forms.IntegerField())
self.assertIs(field.has_changed([1, 2], [1, 2]), False)
- self.assertIs(field.has_changed([1, 2], '1,2'), False)
- self.assertIs(field.has_changed([1, 2], '1,2,3'), True)
- self.assertIs(field.has_changed([1, 2], 'a,b'), True)
+ self.assertIs(field.has_changed([1, 2], "1,2"), False)
+ self.assertIs(field.has_changed([1, 2], "1,2,3"), True)
+ self.assertIs(field.has_changed([1, 2], "a,b"), True)
def test_has_changed_empty(self):
field = SimpleArrayField(forms.CharField())
self.assertIs(field.has_changed(None, None), False)
- self.assertIs(field.has_changed(None, ''), False)
+ self.assertIs(field.has_changed(None, ""), False)
self.assertIs(field.has_changed(None, []), False)
self.assertIs(field.has_changed([], None), False)
- self.assertIs(field.has_changed([], ''), False)
+ self.assertIs(field.has_changed([], ""), False)
class TestSplitFormField(PostgreSQLSimpleTestCase):
-
def test_valid(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), size=3)
- data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'}
+ data = {"array_0": "a", "array_1": "b", "array_2": "c"}
form = SplitForm(data)
self.assertTrue(form.is_valid())
- self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']})
+ self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]})
def test_required(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), required=True, size=3)
- data = {'array_0': '', 'array_1': '', 'array_2': ''}
+ data = {"array_0": "", "array_1": "", "array_2": ""}
form = SplitForm(data)
self.assertFalse(form.is_valid())
- self.assertEqual(form.errors, {'array': ['This field is required.']})
+ self.assertEqual(form.errors, {"array": ["This field is required."]})
def test_remove_trailing_nulls(self):
class SplitForm(forms.Form):
- array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True)
+ array = SplitArrayField(
+ forms.CharField(required=False), size=5, remove_trailing_nulls=True
+ )
- data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''}
+ data = {
+ "array_0": "a",
+ "array_1": "",
+ "array_2": "b",
+ "array_3": "",
+ "array_4": "",
+ }
form = SplitForm(data)
self.assertTrue(form.is_valid(), form.errors)
- self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']})
+ self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]})
def test_remove_trailing_nulls_not_required(self):
class SplitForm(forms.Form):
@@ -1078,32 +1147,41 @@ class TestSplitFormField(PostgreSQLSimpleTestCase):
required=False,
)
- data = {'array_0': '', 'array_1': ''}
+ data = {"array_0": "", "array_1": ""}
form = SplitForm(data)
self.assertTrue(form.is_valid())
- self.assertEqual(form.cleaned_data, {'array': []})
+ self.assertEqual(form.cleaned_data, {"array": []})
def test_required_field(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), size=3)
- data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''}
+ data = {"array_0": "a", "array_1": "b", "array_2": ""}
form = SplitForm(data)
self.assertFalse(form.is_valid())
- self.assertEqual(form.errors, {'array': ['Item 3 in the array did not validate: This field is required.']})
+ self.assertEqual(
+ form.errors,
+ {
+ "array": [
+ "Item 3 in the array did not validate: This field is required."
+ ]
+ },
+ )
def test_invalid_integer(self):
- msg = 'Item 2 in the array did not validate: Ensure this value is less than or equal to 100.'
+ msg = "Item 2 in the array did not validate: Ensure this value is less than or equal to 100."
with self.assertRaisesMessage(exceptions.ValidationError, msg):
SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101])
# To locate the widget's template.
- @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+ @modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
def test_rendering(self):
class SplitForm(forms.Form):
array = SplitArrayField(forms.CharField(), size=3)
- self.assertHTMLEqual(str(SplitForm()), '''
+ self.assertHTMLEqual(
+ str(SplitForm()),
+ """
<tr>
<th><label for="id_array_0">Array:</label></th>
<td>
@@ -1112,16 +1190,20 @@ class TestSplitFormField(PostgreSQLSimpleTestCase):
<input id="id_array_2" name="array_2" type="text" required>
</td>
</tr>
- ''')
+ """,
+ )
def test_invalid_char_length(self):
field = SplitArrayField(forms.CharField(max_length=2), size=3)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['abc', 'c', 'defg'])
- self.assertEqual(cm.exception.messages, [
- 'Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).',
- 'Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).',
- ])
+ field.clean(["abc", "c", "defg"])
+ self.assertEqual(
+ cm.exception.messages,
+ [
+ "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
+ "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).",
+ ],
+ )
def test_splitarraywidget_value_omitted_from_data(self):
class Form(forms.ModelForm):
@@ -1129,9 +1211,9 @@ class TestSplitFormField(PostgreSQLSimpleTestCase):
class Meta:
model = IntegerArrayModel
- fields = ('field',)
+ fields = ("field",)
- form = Form({'field_0': '1', 'field_1': '2'})
+ form = Form({"field_0": "1", "field_1": "2"})
self.assertEqual(form.errors, {})
obj = form.save(commit=False)
self.assertEqual(obj.field, [1, 2])
@@ -1142,15 +1224,15 @@ class TestSplitFormField(PostgreSQLSimpleTestCase):
class Meta:
model = IntegerArrayModel
- fields = ('field',)
+ fields = ("field",)
tests = [
- ({}, {'field_0': '', 'field_1': ''}, True),
- ({'field': None}, {'field_0': '', 'field_1': ''}, True),
- ({'field': [1]}, {'field_0': '', 'field_1': ''}, True),
- ({'field': [1]}, {'field_0': '1', 'field_1': '0'}, True),
- ({'field': [1, 2]}, {'field_0': '1', 'field_1': '2'}, False),
- ({'field': [1, 2]}, {'field_0': 'a', 'field_1': 'b'}, True),
+ ({}, {"field_0": "", "field_1": ""}, True),
+ ({"field": None}, {"field_0": "", "field_1": ""}, True),
+ ({"field": [1]}, {"field_0": "", "field_1": ""}, True),
+ ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True),
+ ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False),
+ ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True),
]
for initial, data, expected_result in tests:
with self.subTest(initial=initial, data=data):
@@ -1160,17 +1242,19 @@ class TestSplitFormField(PostgreSQLSimpleTestCase):
def test_splitarrayfield_remove_trailing_nulls_has_changed(self):
class Form(forms.ModelForm):
- field = SplitArrayField(forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True)
+ field = SplitArrayField(
+ forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True
+ )
class Meta:
model = IntegerArrayModel
- fields = ('field',)
+ fields = ("field",)
tests = [
- ({}, {'field_0': '', 'field_1': ''}, False),
- ({'field': None}, {'field_0': '', 'field_1': ''}, False),
- ({'field': []}, {'field_0': '', 'field_1': ''}, False),
- ({'field': [1]}, {'field_0': '1', 'field_1': ''}, False),
+ ({}, {"field_0": "", "field_1": ""}, False),
+ ({"field": None}, {"field_0": "", "field_1": ""}, False),
+ ({"field": []}, {"field_0": "", "field_1": ""}, False),
+ ({"field": [1]}, {"field_0": "1", "field_1": ""}, False),
]
for initial, data, expected_result in tests:
with self.subTest(initial=initial, data=data):
@@ -1180,77 +1264,91 @@ class TestSplitFormField(PostgreSQLSimpleTestCase):
class TestSplitFormWidget(PostgreSQLWidgetTestCase):
-
def test_get_context(self):
self.assertEqual(
- SplitArrayWidget(forms.TextInput(), size=2).get_context('name', ['val1', 'val2']),
+ SplitArrayWidget(forms.TextInput(), size=2).get_context(
+ "name", ["val1", "val2"]
+ ),
{
- 'widget': {
- 'name': 'name',
- 'is_hidden': False,
- 'required': False,
- 'value': "['val1', 'val2']",
- 'attrs': {},
- 'template_name': 'postgres/widgets/split_array.html',
- 'subwidgets': [
+ "widget": {
+ "name": "name",
+ "is_hidden": False,
+ "required": False,
+ "value": "['val1', 'val2']",
+ "attrs": {},
+ "template_name": "postgres/widgets/split_array.html",
+ "subwidgets": [
{
- 'name': 'name_0',
- 'is_hidden': False,
- 'required': False,
- 'value': 'val1',
- 'attrs': {},
- 'template_name': 'django/forms/widgets/text.html',
- 'type': 'text',
+ "name": "name_0",
+ "is_hidden": False,
+ "required": False,
+ "value": "val1",
+ "attrs": {},
+ "template_name": "django/forms/widgets/text.html",
+ "type": "text",
},
{
- 'name': 'name_1',
- 'is_hidden': False,
- 'required': False,
- 'value': 'val2',
- 'attrs': {},
- 'template_name': 'django/forms/widgets/text.html',
- 'type': 'text',
+ "name": "name_1",
+ "is_hidden": False,
+ "required": False,
+ "value": "val2",
+ "attrs": {},
+ "template_name": "django/forms/widgets/text.html",
+ "type": "text",
},
- ]
+ ],
}
- }
+ },
)
def test_checkbox_get_context_attrs(self):
context = SplitArrayWidget(
forms.CheckboxInput(),
size=2,
- ).get_context('name', [True, False])
- self.assertEqual(context['widget']['value'], '[True, False]')
+ ).get_context("name", [True, False])
+ self.assertEqual(context["widget"]["value"], "[True, False]")
self.assertEqual(
- [subwidget['attrs'] for subwidget in context['widget']['subwidgets']],
- [{'checked': True}, {}]
+ [subwidget["attrs"] for subwidget in context["widget"]["subwidgets"]],
+ [{"checked": True}, {}],
)
def test_render(self):
self.check_html(
- SplitArrayWidget(forms.TextInput(), size=2), 'array', None,
+ SplitArrayWidget(forms.TextInput(), size=2),
+ "array",
+ None,
"""
<input name="array_0" type="text">
<input name="array_1" type="text">
- """
+ """,
)
def test_render_attrs(self):
self.check_html(
SplitArrayWidget(forms.TextInput(), size=2),
- 'array', ['val1', 'val2'], attrs={'id': 'foo'},
+ "array",
+ ["val1", "val2"],
+ attrs={"id": "foo"},
html=(
"""
<input id="foo_0" name="array_0" type="text" value="val1">
<input id="foo_1" name="array_1" type="text" value="val2">
"""
- )
+ ),
)
def test_value_omitted_from_data(self):
widget = SplitArrayWidget(forms.TextInput(), size=2)
- self.assertIs(widget.value_omitted_from_data({}, {}, 'field'), True)
- self.assertIs(widget.value_omitted_from_data({'field_0': 'value'}, {}, 'field'), False)
- self.assertIs(widget.value_omitted_from_data({'field_1': 'value'}, {}, 'field'), False)
- self.assertIs(widget.value_omitted_from_data({'field_0': 'value', 'field_1': 'value'}, {}, 'field'), False)
+ self.assertIs(widget.value_omitted_from_data({}, {}, "field"), True)
+ self.assertIs(
+ widget.value_omitted_from_data({"field_0": "value"}, {}, "field"), False
+ )
+ self.assertIs(
+ widget.value_omitted_from_data({"field_1": "value"}, {}, "field"), False
+ )
+ self.assertIs(
+ widget.value_omitted_from_data(
+ {"field_0": "value", "field_1": "value"}, {}, "field"
+ ),
+ False,
+ )
diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py
index da5aee0f70..f0b473efa7 100644
--- a/tests/postgres_tests/test_bulk_update.py
+++ b/tests/postgres_tests/test_bulk_update.py
@@ -2,8 +2,12 @@ from datetime import date
from . import PostgreSQLTestCase
from .models import (
- HStoreModel, IntegerArrayModel, NestedIntegerArrayModel,
- NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel,
+ HStoreModel,
+ IntegerArrayModel,
+ NestedIntegerArrayModel,
+ NullableIntegerArrayModel,
+ OtherTypesArrayModel,
+ RangesModel,
)
try:
@@ -15,19 +19,28 @@ except ImportError:
class BulkSaveTests(PostgreSQLTestCase):
def test_bulk_update(self):
test_data = [
- (IntegerArrayModel, 'field', [], [1, 2, 3]),
- (NullableIntegerArrayModel, 'field', [1, 2, 3], None),
- (NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]),
- (HStoreModel, 'field', {}, {1: 2}),
- (RangesModel, 'ints', None, NumericRange(lower=1, upper=10)),
- (RangesModel, 'dates', None, DateRange(lower=date.today(), upper=date.today())),
- (OtherTypesArrayModel, 'ips', [], ['1.2.3.4']),
- (OtherTypesArrayModel, 'json', [], [{'a': 'b'}])
+ (IntegerArrayModel, "field", [], [1, 2, 3]),
+ (NullableIntegerArrayModel, "field", [1, 2, 3], None),
+ (NestedIntegerArrayModel, "field", [], [[1, 2, 3]]),
+ (HStoreModel, "field", {}, {1: 2}),
+ (RangesModel, "ints", None, NumericRange(lower=1, upper=10)),
+ (
+ RangesModel,
+ "dates",
+ None,
+ DateRange(lower=date.today(), upper=date.today()),
+ ),
+ (OtherTypesArrayModel, "ips", [], ["1.2.3.4"]),
+ (OtherTypesArrayModel, "json", [], [{"a": "b"}]),
]
for Model, field, initial, new in test_data:
with self.subTest(model=Model, field=field):
- instances = Model.objects.bulk_create(Model(**{field: initial}) for _ in range(20))
+ instances = Model.objects.bulk_create(
+ Model(**{field: initial}) for _ in range(20)
+ )
for instance in instances:
setattr(instance, field, new)
Model.objects.bulk_update(instances, [field])
- self.assertSequenceEqual(Model.objects.filter(**{field: new}), instances)
+ self.assertSequenceEqual(
+ Model.objects.filter(**{field: new}), instances
+ )
diff --git a/tests/postgres_tests/test_citext.py b/tests/postgres_tests/test_citext.py
index 350f34ba33..f1c13184b6 100644
--- a/tests/postgres_tests/test_citext.py
+++ b/tests/postgres_tests/test_citext.py
@@ -10,26 +10,35 @@ from . import PostgreSQLTestCase
from .models import CITestModel
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class CITextTestCase(PostgreSQLTestCase):
- case_sensitive_lookups = ('contains', 'startswith', 'endswith', 'regex')
+ case_sensitive_lookups = ("contains", "startswith", "endswith", "regex")
@classmethod
def setUpTestData(cls):
cls.john = CITestModel.objects.create(
- name='JoHn',
- email='joHn@johN.com',
- description='Average Joe named JoHn',
- array_field=['JoE', 'jOhn'],
+ name="JoHn",
+ email="joHn@johN.com",
+ description="Average Joe named JoHn",
+ array_field=["JoE", "jOhn"],
)
def test_equal_lowercase(self):
"""
citext removes the need for iexact as the index is case-insensitive.
"""
- self.assertEqual(CITestModel.objects.filter(name=self.john.name.lower()).count(), 1)
- self.assertEqual(CITestModel.objects.filter(email=self.john.email.lower()).count(), 1)
- self.assertEqual(CITestModel.objects.filter(description=self.john.description.lower()).count(), 1)
+ self.assertEqual(
+ CITestModel.objects.filter(name=self.john.name.lower()).count(), 1
+ )
+ self.assertEqual(
+ CITestModel.objects.filter(email=self.john.email.lower()).count(), 1
+ )
+ self.assertEqual(
+ CITestModel.objects.filter(
+ description=self.john.description.lower()
+ ).count(),
+ 1,
+ )
def test_fail_citext_primary_key(self):
"""
@@ -37,27 +46,39 @@ class CITextTestCase(PostgreSQLTestCase):
clashes with an existing value isn't allowed.
"""
with self.assertRaises(IntegrityError):
- CITestModel.objects.create(name='John')
+ CITestModel.objects.create(name="John")
def test_array_field(self):
instance = CITestModel.objects.get()
self.assertEqual(instance.array_field, self.john.array_field)
- self.assertTrue(CITestModel.objects.filter(array_field__contains=['joe']).exists())
+ self.assertTrue(
+ CITestModel.objects.filter(array_field__contains=["joe"]).exists()
+ )
def test_lookups_name_char(self):
for lookup in self.case_sensitive_lookups:
with self.subTest(lookup=lookup):
- query = {'name__{}'.format(lookup): 'john'}
- self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john])
+ query = {"name__{}".format(lookup): "john"}
+ self.assertSequenceEqual(
+ CITestModel.objects.filter(**query), [self.john]
+ )
def test_lookups_description_text(self):
- for lookup, string in zip(self.case_sensitive_lookups, ('average', 'average joe', 'john', 'Joe.named')):
+ for lookup, string in zip(
+ self.case_sensitive_lookups, ("average", "average joe", "john", "Joe.named")
+ ):
with self.subTest(lookup=lookup, string=string):
- query = {'description__{}'.format(lookup): string}
- self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john])
+ query = {"description__{}".format(lookup): string}
+ self.assertSequenceEqual(
+ CITestModel.objects.filter(**query), [self.john]
+ )
def test_lookups_email(self):
- for lookup, string in zip(self.case_sensitive_lookups, ('john', 'john', 'john.com', 'john.com')):
+ for lookup, string in zip(
+ self.case_sensitive_lookups, ("john", "john", "john.com", "john.com")
+ ):
with self.subTest(lookup=lookup, string=string):
- query = {'email__{}'.format(lookup): string}
- self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john])
+ query = {"email__{}".format(lookup): string}
+ self.assertSequenceEqual(
+ CITestModel.objects.filter(**query), [self.john]
+ )
diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py
index 22506ed62d..14b45f9b7f 100644
--- a/tests/postgres_tests/test_constraints.py
+++ b/tests/postgres_tests/test_constraints.py
@@ -2,11 +2,15 @@ import datetime
from unittest import mock
from django.contrib.postgres.indexes import OpClass
-from django.db import (
- IntegrityError, NotSupportedError, connection, transaction,
-)
+from django.db import IntegrityError, NotSupportedError, connection, transaction
from django.db.models import (
- CheckConstraint, Deferrable, F, Func, IntegerField, Q, UniqueConstraint,
+ CheckConstraint,
+ Deferrable,
+ F,
+ Func,
+ IntegerField,
+ Q,
+ UniqueConstraint,
)
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import Cast, Left, Lower
@@ -15,29 +19,29 @@ from django.utils import timezone
from django.utils.deprecation import RemovedInDjango50Warning
from . import PostgreSQLTestCase
-from .models import (
- HotelReservation, IntegerArrayModel, RangesModel, Room, Scene,
-)
+from .models import HotelReservation, IntegerArrayModel, RangesModel, Room, Scene
try:
from psycopg2.extras import DateRange, NumericRange
from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import (
- DateTimeRangeField, RangeBoundary, RangeOperators,
+ DateTimeRangeField,
+ RangeBoundary,
+ RangeOperators,
)
except ImportError:
pass
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SchemaTests(PostgreSQLTestCase):
- get_opclass_query = '''
+ get_opclass_query = """
SELECT opcname, c.relname FROM pg_opclass AS oc
JOIN pg_index as i on oc.oid = ANY(i.indclass)
JOIN pg_class as c on c.oid = i.indexrelid
WHERE c.relname = %s
- '''
+ """
def get_constraints(self, table):
"""Get the constraints on the table using a new cursor."""
@@ -45,8 +49,10 @@ class SchemaTests(PostgreSQLTestCase):
return connection.introspection.get_constraints(cursor, table)
def test_check_constraint_range_value(self):
- constraint_name = 'ints_between'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_between"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = CheckConstraint(
check=Q(ints__contained_by=NumericRange(10, 30)),
name=constraint_name,
@@ -59,10 +65,12 @@ class SchemaTests(PostgreSQLTestCase):
RangesModel.objects.create(ints=(10, 30))
def test_check_constraint_daterange_contains(self):
- constraint_name = 'dates_contains'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "dates_contains"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = CheckConstraint(
- check=Q(dates__contains=F('dates_inner')),
+ check=Q(dates__contains=F("dates_inner")),
name=constraint_name,
)
with connection.schema_editor() as editor:
@@ -81,10 +89,12 @@ class SchemaTests(PostgreSQLTestCase):
)
def test_check_constraint_datetimerange_contains(self):
- constraint_name = 'timestamps_contains'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "timestamps_contains"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = CheckConstraint(
- check=Q(timestamps__contains=F('timestamps_inner')),
+ check=Q(timestamps__contains=F("timestamps_inner")),
name=constraint_name,
)
with connection.schema_editor() as editor:
@@ -104,9 +114,9 @@ class SchemaTests(PostgreSQLTestCase):
def test_opclass(self):
constraint = UniqueConstraint(
- name='test_opclass',
- fields=['scene'],
- opclasses=['varchar_pattern_ops'],
+ name="test_opclass",
+ fields=["scene"],
+ opclasses=["varchar_pattern_ops"],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
@@ -115,7 +125,7 @@ class SchemaTests(PostgreSQLTestCase):
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertEqual(
cursor.fetchall(),
- [('varchar_pattern_ops', constraint.name)],
+ [("varchar_pattern_ops", constraint.name)],
)
# Drop the constraint.
with connection.schema_editor() as editor:
@@ -124,25 +134,25 @@ class SchemaTests(PostgreSQLTestCase):
def test_opclass_multiple_columns(self):
constraint = UniqueConstraint(
- name='test_opclass_multiple',
- fields=['scene', 'setting'],
- opclasses=['varchar_pattern_ops', 'text_pattern_ops'],
+ name="test_opclass_multiple",
+ fields=["scene", "setting"],
+ opclasses=["varchar_pattern_ops", "text_pattern_ops"],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
expected_opclasses = (
- ('varchar_pattern_ops', constraint.name),
- ('text_pattern_ops', constraint.name),
+ ("varchar_pattern_ops", constraint.name),
+ ("text_pattern_ops", constraint.name),
)
self.assertCountEqual(cursor.fetchall(), expected_opclasses)
def test_opclass_partial(self):
constraint = UniqueConstraint(
- name='test_opclass_partial',
- fields=['scene'],
- opclasses=['varchar_pattern_ops'],
+ name="test_opclass_partial",
+ fields=["scene"],
+ opclasses=["varchar_pattern_ops"],
condition=Q(setting__contains="Sir Bedemir's Castle"),
)
with connection.schema_editor() as editor:
@@ -151,16 +161,16 @@ class SchemaTests(PostgreSQLTestCase):
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertCountEqual(
cursor.fetchall(),
- [('varchar_pattern_ops', constraint.name)],
+ [("varchar_pattern_ops", constraint.name)],
)
- @skipUnlessDBFeature('supports_covering_indexes')
+ @skipUnlessDBFeature("supports_covering_indexes")
def test_opclass_include(self):
constraint = UniqueConstraint(
- name='test_opclass_include',
- fields=['scene'],
- opclasses=['varchar_pattern_ops'],
- include=['setting'],
+ name="test_opclass_include",
+ fields=["scene"],
+ opclasses=["varchar_pattern_ops"],
+ include=["setting"],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
@@ -168,38 +178,38 @@ class SchemaTests(PostgreSQLTestCase):
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertCountEqual(
cursor.fetchall(),
- [('varchar_pattern_ops', constraint.name)],
+ [("varchar_pattern_ops", constraint.name)],
)
- @skipUnlessDBFeature('supports_expression_indexes')
+ @skipUnlessDBFeature("supports_expression_indexes")
def test_opclass_func(self):
constraint = UniqueConstraint(
- OpClass(Lower('scene'), name='text_pattern_ops'),
- name='test_opclass_func',
+ OpClass(Lower("scene"), name="text_pattern_ops"),
+ name="test_opclass_func",
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
constraints = self.get_constraints(Scene._meta.db_table)
- self.assertIs(constraints[constraint.name]['unique'], True)
+ self.assertIs(constraints[constraint.name]["unique"], True)
self.assertIn(constraint.name, constraints)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertEqual(
cursor.fetchall(),
- [('text_pattern_ops', constraint.name)],
+ [("text_pattern_ops", constraint.name)],
)
- Scene.objects.create(scene='Scene 10', setting='The dark forest of Ewing')
+ Scene.objects.create(scene="Scene 10", setting="The dark forest of Ewing")
with self.assertRaises(IntegrityError), transaction.atomic():
- Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle")
- Scene.objects.create(scene='Scene 5', setting="Sir Bedemir's Castle")
+ Scene.objects.create(scene="ScEnE 10", setting="Sir Bedemir's Castle")
+ Scene.objects.create(scene="Scene 5", setting="Sir Bedemir's Castle")
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Scene, constraint)
self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table))
- Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle")
+ Scene.objects.create(scene="ScEnE 10", setting="Sir Bedemir's Castle")
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class ExclusionConstraintTests(PostgreSQLTestCase):
def get_constraints(self, table):
"""Get the constraints on the table using a new cursor."""
@@ -207,102 +217,104 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
return connection.introspection.get_constraints(cursor, table)
def test_invalid_condition(self):
- msg = 'ExclusionConstraint.condition must be a Q instance.'
+ msg = "ExclusionConstraint.condition must be a Q instance."
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- index_type='GIST',
- name='exclude_invalid_condition',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
- condition=F('invalid'),
+ index_type="GIST",
+ name="exclude_invalid_condition",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
+ condition=F("invalid"),
)
def test_invalid_index_type(self):
- msg = 'Exclusion constraints only support GiST or SP-GiST indexes.'
+ msg = "Exclusion constraints only support GiST or SP-GiST indexes."
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- index_type='gin',
- name='exclude_invalid_index_type',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
+ index_type="gin",
+ name="exclude_invalid_index_type",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
)
def test_invalid_expressions(self):
- msg = 'The expressions must be a list of 2-tuples.'
- for expressions in (['foo'], [('foo')], [('foo_1', 'foo_2', 'foo_3')]):
+ msg = "The expressions must be a list of 2-tuples."
+ for expressions in (["foo"], [("foo")], [("foo_1", "foo_2", "foo_3")]):
with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- index_type='GIST',
- name='exclude_invalid_expressions',
+ index_type="GIST",
+ name="exclude_invalid_expressions",
expressions=expressions,
)
def test_empty_expressions(self):
- msg = 'At least one expression is required to define an exclusion constraint.'
+ msg = "At least one expression is required to define an exclusion constraint."
for empty_expressions in (None, []):
- with self.subTest(empty_expressions), self.assertRaisesMessage(ValueError, msg):
+ with self.subTest(empty_expressions), self.assertRaisesMessage(
+ ValueError, msg
+ ):
ExclusionConstraint(
- index_type='GIST',
- name='exclude_empty_expressions',
+ index_type="GIST",
+ name="exclude_empty_expressions",
expressions=empty_expressions,
)
def test_invalid_deferrable(self):
- msg = 'ExclusionConstraint.deferrable must be a Deferrable instance.'
+ msg = "ExclusionConstraint.deferrable must be a Deferrable instance."
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- name='exclude_invalid_deferrable',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
- deferrable='invalid',
+ name="exclude_invalid_deferrable",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
+ deferrable="invalid",
)
def test_deferrable_with_condition(self):
- msg = 'ExclusionConstraint with conditions cannot be deferred.'
+ msg = "ExclusionConstraint with conditions cannot be deferred."
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- name='exclude_invalid_condition',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
+ name="exclude_invalid_condition",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
condition=Q(cancelled=False),
deferrable=Deferrable.DEFERRED,
)
def test_invalid_include_type(self):
- msg = 'ExclusionConstraint.include must be a list or tuple.'
+ msg = "ExclusionConstraint.include must be a list or tuple."
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- name='exclude_invalid_include',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
- include='invalid',
+ name="exclude_invalid_include",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
+ include="invalid",
)
@ignore_warnings(category=RemovedInDjango50Warning)
def test_invalid_opclasses_type(self):
- msg = 'ExclusionConstraint.opclasses must be a list or tuple.'
+ msg = "ExclusionConstraint.opclasses must be a list or tuple."
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- name='exclude_invalid_opclasses',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
- opclasses='invalid',
+ name="exclude_invalid_opclasses",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
+ opclasses="invalid",
)
@ignore_warnings(category=RemovedInDjango50Warning)
def test_opclasses_and_expressions_same_length(self):
msg = (
- 'ExclusionConstraint.expressions and '
- 'ExclusionConstraint.opclasses must have the same number of '
- 'elements.'
+ "ExclusionConstraint.expressions and "
+ "ExclusionConstraint.opclasses must have the same number of "
+ "elements."
)
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
- name='exclude_invalid_expressions_opclasses_length',
- expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
- opclasses=['foo', 'bar'],
+ name="exclude_invalid_expressions_opclasses_length",
+ expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
+ opclasses=["foo", "bar"],
)
def test_repr(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- (F('datespan'), RangeOperators.OVERLAPS),
- (F('room'), RangeOperators.EQUAL),
+ (F("datespan"), RangeOperators.OVERLAPS),
+ (F("room"), RangeOperators.EQUAL),
],
)
self.assertEqual(
@@ -311,10 +323,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"(F(datespan), '&&'), (F(room), '=')] name='exclude_overlapping'>",
)
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
+ name="exclude_overlapping",
+ expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
condition=Q(cancelled=False),
- index_type='SPGiST',
+ index_type="SPGiST",
)
self.assertEqual(
repr(constraint),
@@ -323,8 +335,8 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"condition=(AND: ('cancelled', False))>",
)
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
+ name="exclude_overlapping",
+ expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
deferrable=Deferrable.IMMEDIATE,
)
self.assertEqual(
@@ -334,9 +346,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"deferrable=Deferrable.IMMEDIATE>",
)
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
- include=['cancelled', 'room'],
+ name="exclude_overlapping",
+ expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
+ include=["cancelled", "room"],
)
self.assertEqual(
repr(constraint),
@@ -345,9 +357,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"include=('cancelled', 'room')>",
)
constraint = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- (OpClass('datespan', name='range_ops'), RangeOperators.ADJACENT_TO),
+ (OpClass("datespan", name="range_ops"), RangeOperators.ADJACENT_TO),
],
)
self.assertEqual(
@@ -359,75 +371,75 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
def test_eq(self):
constraint_1 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- (F('datespan'), RangeOperators.OVERLAPS),
- (F('room'), RangeOperators.EQUAL),
+ (F("datespan"), RangeOperators.OVERLAPS),
+ (F("room"), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
)
constraint_2 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
)
constraint_3 = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[('datespan', RangeOperators.OVERLAPS)],
+ name="exclude_overlapping",
+ expressions=[("datespan", RangeOperators.OVERLAPS)],
condition=Q(cancelled=False),
)
constraint_4 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
deferrable=Deferrable.DEFERRED,
)
constraint_5 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
deferrable=Deferrable.IMMEDIATE,
)
constraint_6 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
deferrable=Deferrable.IMMEDIATE,
- include=['cancelled'],
+ include=["cancelled"],
)
constraint_7 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
- include=['cancelled'],
+ include=["cancelled"],
)
with ignore_warnings(category=RemovedInDjango50Warning):
constraint_8 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
- include=['cancelled'],
- opclasses=['range_ops', 'range_ops']
+ include=["cancelled"],
+ opclasses=["range_ops", "range_ops"],
)
constraint_9 = ExclusionConstraint(
- name='exclude_overlapping',
+ name="exclude_overlapping",
expressions=[
- ('datespan', RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL),
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
- opclasses=['range_ops', 'range_ops']
+ opclasses=["range_ops", "range_ops"],
)
self.assertNotEqual(constraint_2, constraint_9)
self.assertNotEqual(constraint_7, constraint_8)
@@ -445,99 +457,151 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
def test_deconstruct(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
+ name="exclude_overlapping",
+ expressions=[
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
+ ],
)
path, args, kwargs = constraint.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+ self.assertEqual(
+ path, "django.contrib.postgres.constraints.ExclusionConstraint"
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'name': 'exclude_overlapping',
- 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "exclude_overlapping",
+ "expressions": [
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
+ ],
+ },
+ )
def test_deconstruct_index_type(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- index_type='SPGIST',
- expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
+ name="exclude_overlapping",
+ index_type="SPGIST",
+ expressions=[
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
+ ],
)
path, args, kwargs = constraint.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+ self.assertEqual(
+ path, "django.contrib.postgres.constraints.ExclusionConstraint"
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'name': 'exclude_overlapping',
- 'index_type': 'SPGIST',
- 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "exclude_overlapping",
+ "index_type": "SPGIST",
+ "expressions": [
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
+ ],
+ },
+ )
def test_deconstruct_condition(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
+ name="exclude_overlapping",
+ expressions=[
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
+ ],
condition=Q(cancelled=False),
)
path, args, kwargs = constraint.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+ self.assertEqual(
+ path, "django.contrib.postgres.constraints.ExclusionConstraint"
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'name': 'exclude_overlapping',
- 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
- 'condition': Q(cancelled=False),
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "exclude_overlapping",
+ "expressions": [
+ ("datespan", RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
+ ],
+ "condition": Q(cancelled=False),
+ },
+ )
def test_deconstruct_deferrable(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[('datespan', RangeOperators.OVERLAPS)],
+ name="exclude_overlapping",
+ expressions=[("datespan", RangeOperators.OVERLAPS)],
deferrable=Deferrable.DEFERRED,
)
path, args, kwargs = constraint.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+ self.assertEqual(
+ path, "django.contrib.postgres.constraints.ExclusionConstraint"
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'name': 'exclude_overlapping',
- 'expressions': [('datespan', RangeOperators.OVERLAPS)],
- 'deferrable': Deferrable.DEFERRED,
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "exclude_overlapping",
+ "expressions": [("datespan", RangeOperators.OVERLAPS)],
+ "deferrable": Deferrable.DEFERRED,
+ },
+ )
def test_deconstruct_include(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[('datespan', RangeOperators.OVERLAPS)],
- include=['cancelled', 'room'],
+ name="exclude_overlapping",
+ expressions=[("datespan", RangeOperators.OVERLAPS)],
+ include=["cancelled", "room"],
)
path, args, kwargs = constraint.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+ self.assertEqual(
+ path, "django.contrib.postgres.constraints.ExclusionConstraint"
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'name': 'exclude_overlapping',
- 'expressions': [('datespan', RangeOperators.OVERLAPS)],
- 'include': ('cancelled', 'room'),
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "exclude_overlapping",
+ "expressions": [("datespan", RangeOperators.OVERLAPS)],
+ "include": ("cancelled", "room"),
+ },
+ )
@ignore_warnings(category=RemovedInDjango50Warning)
def test_deconstruct_opclasses(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[('datespan', RangeOperators.OVERLAPS)],
- opclasses=['range_ops'],
+ name="exclude_overlapping",
+ expressions=[("datespan", RangeOperators.OVERLAPS)],
+ opclasses=["range_ops"],
)
path, args, kwargs = constraint.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+ self.assertEqual(
+ path, "django.contrib.postgres.constraints.ExclusionConstraint"
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'name': 'exclude_overlapping',
- 'expressions': [('datespan', RangeOperators.OVERLAPS)],
- 'opclasses': ['range_ops'],
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "exclude_overlapping",
+ "expressions": [("datespan", RangeOperators.OVERLAPS)],
+ "opclasses": ["range_ops"],
+ },
+ )
def _test_range_overlaps(self, constraint):
# Create exclusion constraint.
- self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
+ self.assertNotIn(
+ constraint.name, self.get_constraints(HotelReservation._meta.db_table)
+ )
with connection.schema_editor() as editor:
editor.add_constraint(HotelReservation, constraint)
- self.assertIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
+ self.assertIn(
+ constraint.name, self.get_constraints(HotelReservation._meta.db_table)
+ )
# Add initial reservations.
room101 = Room.objects.create(number=101)
room102 = Room.objects.create(number=102)
@@ -570,61 +634,63 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
)
reservation.save()
# Valid range.
- HotelReservation.objects.bulk_create([
- # Other room.
- HotelReservation(
- datespan=(datetimes[1].date(), datetimes[2].date()),
- start=datetimes[1],
- end=datetimes[2],
- room=room101,
- ),
- # Cancelled reservation.
- HotelReservation(
- datespan=(datetimes[1].date(), datetimes[1].date()),
- start=datetimes[1],
- end=datetimes[2],
- room=room102,
- cancelled=True,
- ),
- # Other adjacent dates.
- HotelReservation(
- datespan=(datetimes[3].date(), datetimes[4].date()),
- start=datetimes[3],
- end=datetimes[4],
- room=room102,
- ),
- ])
+ HotelReservation.objects.bulk_create(
+ [
+ # Other room.
+ HotelReservation(
+ datespan=(datetimes[1].date(), datetimes[2].date()),
+ start=datetimes[1],
+ end=datetimes[2],
+ room=room101,
+ ),
+ # Cancelled reservation.
+ HotelReservation(
+ datespan=(datetimes[1].date(), datetimes[1].date()),
+ start=datetimes[1],
+ end=datetimes[2],
+ room=room102,
+ cancelled=True,
+ ),
+ # Other adjacent dates.
+ HotelReservation(
+ datespan=(datetimes[3].date(), datetimes[4].date()),
+ start=datetimes[3],
+ end=datetimes[4],
+ room=room102,
+ ),
+ ]
+ )
@ignore_warnings(category=RemovedInDjango50Warning)
def test_range_overlaps_custom_opclasses(self):
class TsTzRange(Func):
- function = 'TSTZRANGE'
+ function = "TSTZRANGE"
output_field = DateTimeRangeField()
constraint = ExclusionConstraint(
- name='exclude_overlapping_reservations_custom',
+ name="exclude_overlapping_reservations_custom",
expressions=[
- (TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL)
+ (TsTzRange("start", "end", RangeBoundary()), RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
- opclasses=['range_ops', 'gist_int4_ops'],
+ opclasses=["range_ops", "gist_int4_ops"],
)
self._test_range_overlaps(constraint)
def test_range_overlaps_custom(self):
class TsTzRange(Func):
- function = 'TSTZRANGE'
+ function = "TSTZRANGE"
output_field = DateTimeRangeField()
constraint = ExclusionConstraint(
- name='exclude_overlapping_reservations_custom_opclass',
+ name="exclude_overlapping_reservations_custom_opclass",
expressions=[
(
- OpClass(TsTzRange('start', 'end', RangeBoundary()), 'range_ops'),
+ OpClass(TsTzRange("start", "end", RangeBoundary()), "range_ops"),
RangeOperators.OVERLAPS,
),
- (OpClass('room', 'gist_int4_ops'), RangeOperators.EQUAL),
+ (OpClass("room", "gist_int4_ops"), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
)
@@ -632,21 +698,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
def test_range_overlaps(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping_reservations',
+ name="exclude_overlapping_reservations",
expressions=[
- (F('datespan'), RangeOperators.OVERLAPS),
- ('room', RangeOperators.EQUAL)
+ (F("datespan"), RangeOperators.OVERLAPS),
+ ("room", RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
)
self._test_range_overlaps(constraint)
def test_range_adjacent(self):
- constraint_name = 'ints_adjacent'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
@@ -659,26 +727,28 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(RangesModel, constraint)
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
def test_expressions_with_params(self):
- constraint_name = 'scene_left_equal'
+ constraint_name = "scene_left_equal"
self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[(Left('scene', 4), RangeOperators.EQUAL)],
+ expressions=[(Left("scene", 4), RangeOperators.EQUAL)],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
self.assertIn(constraint_name, self.get_constraints(Scene._meta.db_table))
def test_expressions_with_key_transform(self):
- constraint_name = 'exclude_overlapping_reservations_smoking'
+ constraint_name = "exclude_overlapping_reservations_smoking"
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[
- (F('datespan'), RangeOperators.OVERLAPS),
- (KeyTextTransform('smoking', 'requirements'), RangeOperators.EQUAL),
+ (F("datespan"), RangeOperators.OVERLAPS),
+ (KeyTextTransform("smoking", "requirements"), RangeOperators.EQUAL),
],
)
with connection.schema_editor() as editor:
@@ -689,10 +759,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
)
def test_index_transform(self):
- constraint_name = 'first_index_equal'
+ constraint_name = "first_index_equal"
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('field__0', RangeOperators.EQUAL)],
+ expressions=[("field__0", RangeOperators.EQUAL)],
)
with connection.schema_editor() as editor:
editor.add_constraint(IntegerArrayModel, constraint)
@@ -702,11 +772,13 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
)
def test_range_adjacent_initially_deferred(self):
- constraint_name = 'ints_adjacent_deferred'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_deferred"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
deferrable=Deferrable.DEFERRED,
)
with connection.schema_editor() as editor:
@@ -718,21 +790,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
with self.assertRaises(IntegrityError):
with transaction.atomic(), connection.cursor() as cursor:
quoted_name = connection.ops.quote_name(constraint_name)
- cursor.execute('SET CONSTRAINTS %s IMMEDIATE' % quoted_name)
+ cursor.execute("SET CONSTRAINTS %s IMMEDIATE" % quoted_name)
# Remove adjacent range before the end of transaction.
adjacent_range.delete()
RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60))
- @skipUnlessDBFeature('supports_covering_gist_indexes')
+ @skipUnlessDBFeature("supports_covering_gist_indexes")
def test_range_adjacent_gist_include(self):
- constraint_name = 'ints_adjacent_gist_include'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_gist_include"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='gist',
- include=['decimals', 'ints'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="gist",
+ include=["decimals", "ints"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
@@ -743,15 +817,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60))
- @skipUnlessDBFeature('supports_covering_spgist_indexes')
+ @skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_range_adjacent_spgist_include(self):
- constraint_name = 'ints_adjacent_spgist_include'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_spgist_include"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='spgist',
- include=['decimals', 'ints'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="spgist",
+ include=["decimals", "ints"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
@@ -762,60 +838,68 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60))
- @skipUnlessDBFeature('supports_covering_gist_indexes')
+ @skipUnlessDBFeature("supports_covering_gist_indexes")
def test_range_adjacent_gist_include_condition(self):
- constraint_name = 'ints_adjacent_gist_include_condition'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_gist_include_condition"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='gist',
- include=['decimals'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="gist",
+ include=["decimals"],
condition=Q(id__gte=100),
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
- @skipUnlessDBFeature('supports_covering_spgist_indexes')
+ @skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_range_adjacent_spgist_include_condition(self):
- constraint_name = 'ints_adjacent_spgist_include_condition'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_spgist_include_condition"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='spgist',
- include=['decimals'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="spgist",
+ include=["decimals"],
condition=Q(id__gte=100),
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
- @skipUnlessDBFeature('supports_covering_gist_indexes')
+ @skipUnlessDBFeature("supports_covering_gist_indexes")
def test_range_adjacent_gist_include_deferrable(self):
- constraint_name = 'ints_adjacent_gist_include_deferrable'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_gist_include_deferrable"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='gist',
- include=['decimals'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="gist",
+ include=["decimals"],
deferrable=Deferrable.DEFERRED,
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
- @skipUnlessDBFeature('supports_covering_spgist_indexes')
+ @skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_range_adjacent_spgist_include_deferrable(self):
- constraint_name = 'ints_adjacent_spgist_include_deferrable'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_spgist_include_deferrable"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='spgist',
- include=['decimals'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="spgist",
+ include=["decimals"],
deferrable=Deferrable.DEFERRED,
)
with connection.schema_editor() as editor:
@@ -823,53 +907,55 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_gist_include_not_supported(self):
- constraint_name = 'ints_adjacent_gist_include_not_supported'
+ constraint_name = "ints_adjacent_gist_include_not_supported"
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='gist',
- include=['id'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="gist",
+ include=["id"],
)
msg = (
- 'Covering exclusion constraints using a GiST index require '
- 'PostgreSQL 12+.'
+ "Covering exclusion constraints using a GiST index require "
+ "PostgreSQL 12+."
)
with connection.schema_editor() as editor:
with mock.patch(
- 'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes',
+ "django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes",
False,
):
with self.assertRaisesMessage(NotSupportedError, msg):
editor.add_constraint(RangesModel, constraint)
def test_spgist_include_not_supported(self):
- constraint_name = 'ints_adjacent_spgist_include_not_supported'
+ constraint_name = "ints_adjacent_spgist_include_not_supported"
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='spgist',
- include=['id'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="spgist",
+ include=["id"],
)
msg = (
- 'Covering exclusion constraints using an SP-GiST index require '
- 'PostgreSQL 14+.'
+ "Covering exclusion constraints using an SP-GiST index require "
+ "PostgreSQL 14+."
)
with connection.schema_editor() as editor:
with mock.patch(
- 'django.db.backends.postgresql.features.DatabaseFeatures.'
- 'supports_covering_spgist_indexes',
+ "django.db.backends.postgresql.features.DatabaseFeatures."
+ "supports_covering_spgist_indexes",
False,
):
with self.assertRaisesMessage(NotSupportedError, msg):
editor.add_constraint(RangesModel, constraint)
def test_range_adjacent_opclass(self):
- constraint_name = 'ints_adjacent_opclass'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_opclass"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[
- (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO),
+ (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO),
],
)
with connection.schema_editor() as editor:
@@ -880,7 +966,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
cursor.execute(SchemaTests.get_opclass_query, [constraint_name])
self.assertEqual(
cursor.fetchall(),
- [('range_ops', constraint_name)],
+ [("range_ops", constraint_name)],
)
RangesModel.objects.create(ints=(20, 50))
with self.assertRaises(IntegrityError), transaction.atomic():
@@ -890,15 +976,19 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(RangesModel, constraint)
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
def test_range_adjacent_opclass_condition(self):
- constraint_name = 'ints_adjacent_opclass_condition'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_opclass_condition"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[
- (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO),
+ (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO),
],
condition=Q(id__gte=100),
)
@@ -907,12 +997,14 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_range_adjacent_opclass_deferrable(self):
- constraint_name = 'ints_adjacent_opclass_deferrable'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_opclass_deferrable"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[
- (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO),
+ (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO),
],
deferrable=Deferrable.DEFERRED,
)
@@ -920,51 +1012,55 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
- @skipUnlessDBFeature('supports_covering_gist_indexes')
+ @skipUnlessDBFeature("supports_covering_gist_indexes")
def test_range_adjacent_gist_opclass_include(self):
- constraint_name = 'ints_adjacent_gist_opclass_include'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_gist_opclass_include"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[
- (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO),
+ (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO),
],
- index_type='gist',
- include=['decimals'],
+ index_type="gist",
+ include=["decimals"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
- @skipUnlessDBFeature('supports_covering_spgist_indexes')
+ @skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_range_adjacent_spgist_opclass_include(self):
- constraint_name = 'ints_adjacent_spgist_opclass_include'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_spgist_opclass_include"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[
- (OpClass('ints', name='range_ops'), RangeOperators.ADJACENT_TO),
+ (OpClass("ints", name="range_ops"), RangeOperators.ADJACENT_TO),
],
- index_type='spgist',
- include=['decimals'],
+ index_type="spgist",
+ include=["decimals"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_range_equal_cast(self):
- constraint_name = 'exclusion_equal_room_cast'
+ constraint_name = "exclusion_equal_room_cast"
self.assertNotIn(constraint_name, self.get_constraints(Room._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[(Cast('number', IntegerField()), RangeOperators.EQUAL)],
+ expressions=[(Cast("number", IntegerField()), RangeOperators.EQUAL)],
)
with connection.schema_editor() as editor:
editor.add_constraint(Room, constraint)
self.assertIn(constraint_name, self.get_constraints(Room._meta.db_table))
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
def get_constraints(self, table):
"""Get the constraints on the table using a new cursor."""
@@ -973,23 +1069,23 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
def test_warning(self):
msg = (
- 'The opclasses argument is deprecated in favor of using '
- 'django.contrib.postgres.indexes.OpClass in '
- 'ExclusionConstraint.expressions.'
+ "The opclasses argument is deprecated in favor of using "
+ "django.contrib.postgres.indexes.OpClass in "
+ "ExclusionConstraint.expressions."
)
with self.assertWarnsMessage(RemovedInDjango50Warning, msg):
ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
- opclasses=['range_ops'],
+ name="exclude_overlapping",
+ expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
+ opclasses=["range_ops"],
)
@ignore_warnings(category=RemovedInDjango50Warning)
def test_repr(self):
constraint = ExclusionConstraint(
- name='exclude_overlapping',
- expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
- opclasses=['range_ops'],
+ name="exclude_overlapping",
+ expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
+ opclasses=["range_ops"],
)
self.assertEqual(
repr(constraint),
@@ -1000,12 +1096,14 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
@ignore_warnings(category=RemovedInDjango50Warning)
def test_range_adjacent_opclasses(self):
- constraint_name = 'ints_adjacent_opclasses'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_opclasses"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- opclasses=['range_ops'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ opclasses=["range_ops"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
@@ -1015,7 +1113,7 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
cursor.execute(SchemaTests.get_opclass_query, [constraint.name])
self.assertEqual(
cursor.fetchall(),
- [('range_ops', constraint.name)],
+ [("range_ops", constraint.name)],
)
RangesModel.objects.create(ints=(20, 50))
with self.assertRaises(IntegrityError), transaction.atomic():
@@ -1025,16 +1123,20 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(RangesModel, constraint)
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
@ignore_warnings(category=RemovedInDjango50Warning)
def test_range_adjacent_opclasses_condition(self):
- constraint_name = 'ints_adjacent_opclasses_condition'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_opclasses_condition"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- opclasses=['range_ops'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ opclasses=["range_ops"],
condition=Q(id__gte=100),
)
with connection.schema_editor() as editor:
@@ -1043,12 +1145,14 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
@ignore_warnings(category=RemovedInDjango50Warning)
def test_range_adjacent_opclasses_deferrable(self):
- constraint_name = 'ints_adjacent_opclasses_deferrable'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_opclasses_deferrable"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- opclasses=['range_ops'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ opclasses=["range_ops"],
deferrable=Deferrable.DEFERRED,
)
with connection.schema_editor() as editor:
@@ -1056,32 +1160,36 @@ class ExclusionConstraintOpclassesDepracationTests(PostgreSQLTestCase):
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
@ignore_warnings(category=RemovedInDjango50Warning)
- @skipUnlessDBFeature('supports_covering_gist_indexes')
+ @skipUnlessDBFeature("supports_covering_gist_indexes")
def test_range_adjacent_gist_opclasses_include(self):
- constraint_name = 'ints_adjacent_gist_opclasses_include'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_gist_opclasses_include"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='gist',
- opclasses=['range_ops'],
- include=['decimals'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="gist",
+ opclasses=["range_ops"],
+ include=["decimals"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
@ignore_warnings(category=RemovedInDjango50Warning)
- @skipUnlessDBFeature('supports_covering_spgist_indexes')
+ @skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_range_adjacent_spgist_opclasses_include(self):
- constraint_name = 'ints_adjacent_spgist_opclasses_include'
- self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+ constraint_name = "ints_adjacent_spgist_opclasses_include"
+ self.assertNotIn(
+ constraint_name, self.get_constraints(RangesModel._meta.db_table)
+ )
constraint = ExclusionConstraint(
name=constraint_name,
- expressions=[('ints', RangeOperators.ADJACENT_TO)],
- index_type='spgist',
- opclasses=['range_ops'],
- include=['decimals'],
+ expressions=[("ints", RangeOperators.ADJACENT_TO)],
+ index_type="spgist",
+ opclasses=["range_ops"],
+ include=["decimals"],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
diff --git a/tests/postgres_tests/test_functions.py b/tests/postgres_tests/test_functions.py
index 875a4b9520..e500580645 100644
--- a/tests/postgres_tests/test_functions.py
+++ b/tests/postgres_tests/test_functions.py
@@ -9,7 +9,6 @@ from .models import NowTestModel, UUIDTestModel
class TestTransactionNow(PostgreSQLTestCase):
-
def test_transaction_now(self):
"""
The test case puts everything under a transaction, so two models
@@ -30,7 +29,6 @@ class TestTransactionNow(PostgreSQLTestCase):
class TestRandomUUID(PostgreSQLTestCase):
-
def test_random_uuid(self):
m1 = UUIDTestModel.objects.create()
m2 = UUIDTestModel.objects.create()
diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py
index b6fdd2da60..2aaad637c6 100644
--- a/tests/postgres_tests/test_hstore.py
+++ b/tests/postgres_tests/test_hstore.py
@@ -21,7 +21,7 @@ except ImportError:
class SimpleTests(PostgreSQLTestCase):
def test_save_load_success(self):
- value = {'a': 'b'}
+ value = {"a": "b"}
instance = HStoreModel(field=value)
instance.save()
reloaded = HStoreModel.objects.get()
@@ -34,15 +34,15 @@ class SimpleTests(PostgreSQLTestCase):
self.assertIsNone(reloaded.field)
def test_value_null(self):
- value = {'a': None}
+ value = {"a": None}
instance = HStoreModel(field=value)
instance.save()
reloaded = HStoreModel.objects.get()
self.assertEqual(reloaded.field, value)
def test_key_val_cast_to_string(self):
- value = {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'}
- expected_value = {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'}
+ value = {"a": 1, "b": "B", 2: "c", "ï": "ê"}
+ expected_value = {"a": "1", "b": "B", "2": "c", "ï": "ê"}
instance = HStoreModel.objects.create(field=value)
instance = HStoreModel.objects.get()
@@ -51,17 +51,17 @@ class SimpleTests(PostgreSQLTestCase):
instance = HStoreModel.objects.get(field__a=1)
self.assertEqual(instance.field, expected_value)
- instance = HStoreModel.objects.get(field__has_keys=[2, 'a', 'ï'])
+ instance = HStoreModel.objects.get(field__has_keys=[2, "a", "ï"])
self.assertEqual(instance.field, expected_value)
def test_array_field(self):
value = [
- {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'},
- {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'},
+ {"a": 1, "b": "B", 2: "c", "ï": "ê"},
+ {"a": 1, "b": "B", 2: "c", "ï": "ê"},
]
expected_value = [
- {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'},
- {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'},
+ {"a": "1", "b": "B", "2": "c", "ï": "ê"},
+ {"a": "1", "b": "B", "2": "c", "ï": "ê"},
]
instance = HStoreModel.objects.create(array_field=value)
instance.refresh_from_db()
@@ -69,231 +69,225 @@ class SimpleTests(PostgreSQLTestCase):
class TestQuerying(PostgreSQLTestCase):
-
@classmethod
def setUpTestData(cls):
- cls.objs = HStoreModel.objects.bulk_create([
- HStoreModel(field={'a': 'b'}),
- HStoreModel(field={'a': 'b', 'c': 'd'}),
- HStoreModel(field={'c': 'd'}),
- HStoreModel(field={}),
- HStoreModel(field=None),
- HStoreModel(field={'cat': 'TigrOu', 'breed': 'birman'}),
- HStoreModel(field={'cat': 'minou', 'breed': 'ragdoll'}),
- HStoreModel(field={'cat': 'kitty', 'breed': 'Persian'}),
- HStoreModel(field={'cat': 'Kit Kat', 'breed': 'persian'}),
- ])
+ cls.objs = HStoreModel.objects.bulk_create(
+ [
+ HStoreModel(field={"a": "b"}),
+ HStoreModel(field={"a": "b", "c": "d"}),
+ HStoreModel(field={"c": "d"}),
+ HStoreModel(field={}),
+ HStoreModel(field=None),
+ HStoreModel(field={"cat": "TigrOu", "breed": "birman"}),
+ HStoreModel(field={"cat": "minou", "breed": "ragdoll"}),
+ HStoreModel(field={"cat": "kitty", "breed": "Persian"}),
+ HStoreModel(field={"cat": "Kit Kat", "breed": "persian"}),
+ ]
+ )
def test_exact(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__exact={'a': 'b'}),
- self.objs[:1]
+ HStoreModel.objects.filter(field__exact={"a": "b"}), self.objs[:1]
)
def test_contained_by(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__contained_by={'a': 'b', 'c': 'd'}),
- self.objs[:4]
+ HStoreModel.objects.filter(field__contained_by={"a": "b", "c": "d"}),
+ self.objs[:4],
)
def test_contains(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__contains={'a': 'b'}),
- self.objs[:2]
+ HStoreModel.objects.filter(field__contains={"a": "b"}), self.objs[:2]
)
def test_in_generator(self):
def search():
- yield {'a': 'b'}
+ yield {"a": "b"}
+
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__in=search()),
- self.objs[:1]
+ HStoreModel.objects.filter(field__in=search()), self.objs[:1]
)
def test_has_key(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__has_key='c'),
- self.objs[1:3]
+ HStoreModel.objects.filter(field__has_key="c"), self.objs[1:3]
)
def test_has_keys(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__has_keys=['a', 'c']),
- self.objs[1:2]
+ HStoreModel.objects.filter(field__has_keys=["a", "c"]), self.objs[1:2]
)
def test_has_any_keys(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__has_any_keys=['a', 'c']),
- self.objs[:3]
+ HStoreModel.objects.filter(field__has_any_keys=["a", "c"]), self.objs[:3]
)
def test_key_transform(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__a='b'),
- self.objs[:2]
+ HStoreModel.objects.filter(field__a="b"), self.objs[:2]
)
def test_key_transform_raw_expression(self):
- expr = RawSQL('%s::hstore', ['x => b, y => c'])
+ expr = RawSQL("%s::hstore", ["x => b, y => c"])
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__a=KeyTransform('x', expr)),
- self.objs[:2]
+ HStoreModel.objects.filter(field__a=KeyTransform("x", expr)), self.objs[:2]
)
def test_key_transform_annotation(self):
- qs = HStoreModel.objects.annotate(a=F('field__a'))
+ qs = HStoreModel.objects.annotate(a=F("field__a"))
self.assertCountEqual(
- qs.values_list('a', flat=True),
- ['b', 'b', None, None, None, None, None, None, None],
+ qs.values_list("a", flat=True),
+ ["b", "b", None, None, None, None, None, None, None],
)
def test_keys(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__keys=['a']),
- self.objs[:1]
+ HStoreModel.objects.filter(field__keys=["a"]), self.objs[:1]
)
def test_values(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__values=['b']),
- self.objs[:1]
+ HStoreModel.objects.filter(field__values=["b"]), self.objs[:1]
)
def test_field_chaining_contains(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__a__contains='b'),
- self.objs[:2]
+ HStoreModel.objects.filter(field__a__contains="b"), self.objs[:2]
)
def test_field_chaining_icontains(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__icontains='INo'),
+ HStoreModel.objects.filter(field__cat__icontains="INo"),
[self.objs[6]],
)
def test_field_chaining_startswith(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__startswith='kit'),
+ HStoreModel.objects.filter(field__cat__startswith="kit"),
[self.objs[7]],
)
def test_field_chaining_istartswith(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__istartswith='kit'),
+ HStoreModel.objects.filter(field__cat__istartswith="kit"),
self.objs[7:],
)
def test_field_chaining_endswith(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__endswith='ou'),
+ HStoreModel.objects.filter(field__cat__endswith="ou"),
[self.objs[6]],
)
def test_field_chaining_iendswith(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__iendswith='ou'),
+ HStoreModel.objects.filter(field__cat__iendswith="ou"),
self.objs[5:7],
)
def test_field_chaining_iexact(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__breed__iexact='persian'),
+ HStoreModel.objects.filter(field__breed__iexact="persian"),
self.objs[7:],
)
def test_field_chaining_regex(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__regex=r'ou$'),
+ HStoreModel.objects.filter(field__cat__regex=r"ou$"),
[self.objs[6]],
)
def test_field_chaining_iregex(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__cat__iregex=r'oU$'),
+ HStoreModel.objects.filter(field__cat__iregex=r"oU$"),
self.objs[5:7],
)
def test_order_by_field(self):
more_objs = (
- HStoreModel.objects.create(field={'g': '637'}),
- HStoreModel.objects.create(field={'g': '002'}),
- HStoreModel.objects.create(field={'g': '042'}),
- HStoreModel.objects.create(field={'g': '981'}),
+ HStoreModel.objects.create(field={"g": "637"}),
+ HStoreModel.objects.create(field={"g": "002"}),
+ HStoreModel.objects.create(field={"g": "042"}),
+ HStoreModel.objects.create(field={"g": "981"}),
)
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__has_key='g').order_by('field__g'),
- [more_objs[1], more_objs[2], more_objs[0], more_objs[3]]
+ HStoreModel.objects.filter(field__has_key="g").order_by("field__g"),
+ [more_objs[1], more_objs[2], more_objs[0], more_objs[3]],
)
def test_keys_contains(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__keys__contains=['a']),
- self.objs[:2]
+ HStoreModel.objects.filter(field__keys__contains=["a"]), self.objs[:2]
)
def test_values_overlap(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__values__overlap=['b', 'd']),
- self.objs[:3]
+ HStoreModel.objects.filter(field__values__overlap=["b", "d"]), self.objs[:3]
)
def test_key_isnull(self):
- obj = HStoreModel.objects.create(field={'a': None})
+ obj = HStoreModel.objects.create(field={"a": None})
self.assertSequenceEqual(
HStoreModel.objects.filter(field__a__isnull=True),
self.objs[2:9] + [obj],
)
self.assertSequenceEqual(
- HStoreModel.objects.filter(field__a__isnull=False),
- self.objs[:2]
+ HStoreModel.objects.filter(field__a__isnull=False), self.objs[:2]
)
def test_usage_in_subquery(self):
self.assertSequenceEqual(
- HStoreModel.objects.filter(id__in=HStoreModel.objects.filter(field__a='b')),
- self.objs[:2]
+ HStoreModel.objects.filter(id__in=HStoreModel.objects.filter(field__a="b")),
+ self.objs[:2],
)
def test_key_sql_injection(self):
with CaptureQueriesContext(connection) as queries:
self.assertFalse(
- HStoreModel.objects.filter(**{
- "field__test' = 'a') OR 1 = 1 OR ('d": 'x',
- }).exists()
+ HStoreModel.objects.filter(
+ **{
+ "field__test' = 'a') OR 1 = 1 OR ('d": "x",
+ }
+ ).exists()
)
self.assertIn(
"""."field" -> 'test'' = ''a'') OR 1 = 1 OR (''d') = 'x' """,
- queries[0]['sql'],
+ queries[0]["sql"],
)
def test_obj_subquery_lookup(self):
qs = HStoreModel.objects.annotate(
- value=Subquery(HStoreModel.objects.filter(pk=OuterRef('pk')).values('field')),
- ).filter(value__a='b')
+ value=Subquery(
+ HStoreModel.objects.filter(pk=OuterRef("pk")).values("field")
+ ),
+ ).filter(value__a="b")
self.assertSequenceEqual(qs, self.objs[:2])
-@isolate_apps('postgres_tests')
+@isolate_apps("postgres_tests")
class TestChecks(PostgreSQLSimpleTestCase):
-
def test_invalid_default(self):
class MyModel(PostgreSQLModel):
field = HStoreField(default={})
model = MyModel()
- self.assertEqual(model.check(), [
- checks.Warning(
- msg=(
- "HStoreField default should be a callable instead of an "
- "instance so that it's not shared between all field "
- "instances."
- ),
- hint='Use a callable instead, e.g., use `dict` instead of `{}`.',
- obj=MyModel._meta.get_field('field'),
- id='fields.E010',
- )
- ])
+ self.assertEqual(
+ model.check(),
+ [
+ checks.Warning(
+ msg=(
+ "HStoreField default should be a callable instead of an "
+ "instance so that it's not shared between all field "
+ "instances."
+ ),
+ hint="Use a callable instead, e.g., use `dict` instead of `{}`.",
+ obj=MyModel._meta.get_field("field"),
+ id="fields.E010",
+ )
+ ],
+ )
def test_valid_default(self):
class MyModel(PostgreSQLModel):
@@ -303,83 +297,90 @@ class TestChecks(PostgreSQLSimpleTestCase):
class TestSerialization(PostgreSQLSimpleTestCase):
- test_data = json.dumps([{
- 'model': 'postgres_tests.hstoremodel',
- 'pk': None,
- 'fields': {
- 'field': json.dumps({'a': 'b'}),
- 'array_field': json.dumps([
- json.dumps({'a': 'b'}),
- json.dumps({'b': 'a'}),
- ]),
- },
- }])
+ test_data = json.dumps(
+ [
+ {
+ "model": "postgres_tests.hstoremodel",
+ "pk": None,
+ "fields": {
+ "field": json.dumps({"a": "b"}),
+ "array_field": json.dumps(
+ [
+ json.dumps({"a": "b"}),
+ json.dumps({"b": "a"}),
+ ]
+ ),
+ },
+ }
+ ]
+ )
def test_dumping(self):
- instance = HStoreModel(field={'a': 'b'}, array_field=[{'a': 'b'}, {'b': 'a'}])
- data = serializers.serialize('json', [instance])
+ instance = HStoreModel(field={"a": "b"}, array_field=[{"a": "b"}, {"b": "a"}])
+ data = serializers.serialize("json", [instance])
self.assertEqual(json.loads(data), json.loads(self.test_data))
def test_loading(self):
- instance = list(serializers.deserialize('json', self.test_data))[0].object
- self.assertEqual(instance.field, {'a': 'b'})
- self.assertEqual(instance.array_field, [{'a': 'b'}, {'b': 'a'}])
+ instance = list(serializers.deserialize("json", self.test_data))[0].object
+ self.assertEqual(instance.field, {"a": "b"})
+ self.assertEqual(instance.array_field, [{"a": "b"}, {"b": "a"}])
def test_roundtrip_with_null(self):
- instance = HStoreModel(field={'a': 'b', 'c': None})
- data = serializers.serialize('json', [instance])
- new_instance = list(serializers.deserialize('json', data))[0].object
+ instance = HStoreModel(field={"a": "b", "c": None})
+ data = serializers.serialize("json", [instance])
+ new_instance = list(serializers.deserialize("json", data))[0].object
self.assertEqual(instance.field, new_instance.field)
class TestValidation(PostgreSQLSimpleTestCase):
-
def test_not_a_string(self):
field = HStoreField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean({'a': 1}, None)
- self.assertEqual(cm.exception.code, 'not_a_string')
- self.assertEqual(cm.exception.message % cm.exception.params, 'The value of “a” is not a string or null.')
+ field.clean({"a": 1}, None)
+ self.assertEqual(cm.exception.code, "not_a_string")
+ self.assertEqual(
+ cm.exception.message % cm.exception.params,
+ "The value of “a” is not a string or null.",
+ )
def test_none_allowed_as_value(self):
field = HStoreField()
- self.assertEqual(field.clean({'a': None}, None), {'a': None})
+ self.assertEqual(field.clean({"a": None}, None), {"a": None})
class TestFormField(PostgreSQLSimpleTestCase):
-
def test_valid(self):
field = forms.HStoreField()
value = field.clean('{"a": "b"}')
- self.assertEqual(value, {'a': 'b'})
+ self.assertEqual(value, {"a": "b"})
def test_invalid_json(self):
field = forms.HStoreField()
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('{"a": "b"')
- self.assertEqual(cm.exception.messages[0], 'Could not load JSON data.')
- self.assertEqual(cm.exception.code, 'invalid_json')
+ self.assertEqual(cm.exception.messages[0], "Could not load JSON data.")
+ self.assertEqual(cm.exception.code, "invalid_json")
def test_non_dict_json(self):
field = forms.HStoreField()
- msg = 'Input must be a JSON dictionary.'
+ msg = "Input must be a JSON dictionary."
with self.assertRaisesMessage(exceptions.ValidationError, msg) as cm:
field.clean('["a", "b", 1]')
- self.assertEqual(cm.exception.code, 'invalid_format')
+ self.assertEqual(cm.exception.code, "invalid_format")
def test_not_string_values(self):
field = forms.HStoreField()
value = field.clean('{"a": 1}')
- self.assertEqual(value, {'a': '1'})
+ self.assertEqual(value, {"a": "1"})
def test_none_value(self):
field = forms.HStoreField()
value = field.clean('{"a": null}')
- self.assertEqual(value, {'a': None})
+ self.assertEqual(value, {"a": None})
def test_empty(self):
field = forms.HStoreField(required=False)
- value = field.clean('')
+ value = field.clean("")
self.assertEqual(value, {})
def test_model_field_formfield(self):
@@ -390,69 +391,71 @@ class TestFormField(PostgreSQLSimpleTestCase):
def test_field_has_changed(self):
class HStoreFormTest(Form):
f1 = forms.HStoreField()
+
form_w_hstore = HStoreFormTest()
self.assertFalse(form_w_hstore.has_changed())
- form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'})
+ form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'})
self.assertTrue(form_w_hstore.has_changed())
- form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}, initial={'f1': '{"a": 1}'})
+ form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}, initial={"f1": '{"a": 1}'})
self.assertFalse(form_w_hstore.has_changed())
- form_w_hstore = HStoreFormTest({'f1': '{"a": 2}'}, initial={'f1': '{"a": 1}'})
+ form_w_hstore = HStoreFormTest({"f1": '{"a": 2}'}, initial={"f1": '{"a": 1}'})
self.assertTrue(form_w_hstore.has_changed())
- form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}, initial={'f1': {"a": 1}})
+ form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}, initial={"f1": {"a": 1}})
self.assertFalse(form_w_hstore.has_changed())
- form_w_hstore = HStoreFormTest({'f1': '{"a": 2}'}, initial={'f1': {"a": 1}})
+ form_w_hstore = HStoreFormTest({"f1": '{"a": 2}'}, initial={"f1": {"a": 1}})
self.assertTrue(form_w_hstore.has_changed())
class TestValidator(PostgreSQLSimpleTestCase):
-
def test_simple_valid(self):
- validator = KeysValidator(keys=['a', 'b'])
- validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
+ validator = KeysValidator(keys=["a", "b"])
+ validator({"a": "foo", "b": "bar", "c": "baz"})
def test_missing_keys(self):
- validator = KeysValidator(keys=['a', 'b'])
+ validator = KeysValidator(keys=["a", "b"])
with self.assertRaises(exceptions.ValidationError) as cm:
- validator({'a': 'foo', 'c': 'baz'})
- self.assertEqual(cm.exception.messages[0], 'Some keys were missing: b')
- self.assertEqual(cm.exception.code, 'missing_keys')
+ validator({"a": "foo", "c": "baz"})
+ self.assertEqual(cm.exception.messages[0], "Some keys were missing: b")
+ self.assertEqual(cm.exception.code, "missing_keys")
def test_strict_valid(self):
- validator = KeysValidator(keys=['a', 'b'], strict=True)
- validator({'a': 'foo', 'b': 'bar'})
+ validator = KeysValidator(keys=["a", "b"], strict=True)
+ validator({"a": "foo", "b": "bar"})
def test_extra_keys(self):
- validator = KeysValidator(keys=['a', 'b'], strict=True)
+ validator = KeysValidator(keys=["a", "b"], strict=True)
with self.assertRaises(exceptions.ValidationError) as cm:
- validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
- self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c')
- self.assertEqual(cm.exception.code, 'extra_keys')
+ validator({"a": "foo", "b": "bar", "c": "baz"})
+ self.assertEqual(cm.exception.messages[0], "Some unknown keys were provided: c")
+ self.assertEqual(cm.exception.code, "extra_keys")
def test_custom_messages(self):
messages = {
- 'missing_keys': 'Foobar',
+ "missing_keys": "Foobar",
}
- validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages)
+ validator = KeysValidator(keys=["a", "b"], strict=True, messages=messages)
with self.assertRaises(exceptions.ValidationError) as cm:
- validator({'a': 'foo', 'c': 'baz'})
- self.assertEqual(cm.exception.messages[0], 'Foobar')
- self.assertEqual(cm.exception.code, 'missing_keys')
+ validator({"a": "foo", "c": "baz"})
+ self.assertEqual(cm.exception.messages[0], "Foobar")
+ self.assertEqual(cm.exception.code, "missing_keys")
with self.assertRaises(exceptions.ValidationError) as cm:
- validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
- self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c')
- self.assertEqual(cm.exception.code, 'extra_keys')
+ validator({"a": "foo", "b": "bar", "c": "baz"})
+ self.assertEqual(cm.exception.messages[0], "Some unknown keys were provided: c")
+ self.assertEqual(cm.exception.code, "extra_keys")
def test_deconstruct(self):
messages = {
- 'missing_keys': 'Foobar',
+ "missing_keys": "Foobar",
}
- validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages)
+ validator = KeysValidator(keys=["a", "b"], strict=True, messages=messages)
path, args, kwargs = validator.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.validators.KeysValidator')
+ self.assertEqual(path, "django.contrib.postgres.validators.KeysValidator")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {'keys': ['a', 'b'], 'strict': True, 'messages': messages})
+ self.assertEqual(
+ kwargs, {"keys": ["a", "b"], "strict": True, "messages": messages}
+ )
diff --git a/tests/postgres_tests/test_indexes.py b/tests/postgres_tests/test_indexes.py
index 75d0640a08..f57e7a7c80 100644
--- a/tests/postgres_tests/test_indexes.py
+++ b/tests/postgres_tests/test_indexes.py
@@ -1,7 +1,13 @@
from unittest import mock
from django.contrib.postgres.indexes import (
- BloomIndex, BrinIndex, BTreeIndex, GinIndex, GistIndex, HashIndex, OpClass,
+ BloomIndex,
+ BrinIndex,
+ BTreeIndex,
+ GinIndex,
+ GistIndex,
+ HashIndex,
+ OpClass,
SpGistIndex,
)
from django.db import NotSupportedError, connection
@@ -16,191 +22,226 @@ from .models import CharFieldModel, IntegerArrayModel, Scene, TextFieldModel
class IndexTestMixin:
-
def test_name_auto_generation(self):
- index = self.index_class(fields=['field'])
+ index = self.index_class(fields=["field"])
index.set_name_with_model(CharFieldModel)
- self.assertRegex(index.name, r'postgres_te_field_[0-9a-f]{6}_%s' % self.index_class.suffix)
+ self.assertRegex(
+ index.name, r"postgres_te_field_[0-9a-f]{6}_%s" % self.index_class.suffix
+ )
def test_deconstruction_no_customization(self):
- index = self.index_class(fields=['title'], name='test_title_%s' % self.index_class.suffix)
+ index = self.index_class(
+ fields=["title"], name="test_title_%s" % self.index_class.suffix
+ )
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.%s' % self.index_class.__name__)
+ self.assertEqual(
+ path, "django.contrib.postgres.indexes.%s" % self.index_class.__name__
+ )
self.assertEqual(args, ())
- self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_%s' % self.index_class.suffix})
+ self.assertEqual(
+ kwargs,
+ {"fields": ["title"], "name": "test_title_%s" % self.index_class.suffix},
+ )
def test_deconstruction_with_expressions_no_customization(self):
- name = f'test_title_{self.index_class.suffix}'
- index = self.index_class(Lower('title'), name=name)
+ name = f"test_title_{self.index_class.suffix}"
+ index = self.index_class(Lower("title"), name=name)
path, args, kwargs = index.deconstruct()
self.assertEqual(
path,
- f'django.contrib.postgres.indexes.{self.index_class.__name__}',
+ f"django.contrib.postgres.indexes.{self.index_class.__name__}",
)
- self.assertEqual(args, (Lower('title'),))
- self.assertEqual(kwargs, {'name': name})
+ self.assertEqual(args, (Lower("title"),))
+ self.assertEqual(kwargs, {"name": name})
class BloomIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = BloomIndex
def test_suffix(self):
- self.assertEqual(BloomIndex.suffix, 'bloom')
+ self.assertEqual(BloomIndex.suffix, "bloom")
def test_deconstruction(self):
- index = BloomIndex(fields=['title'], name='test_bloom', length=80, columns=[4])
+ index = BloomIndex(fields=["title"], name="test_bloom", length=80, columns=[4])
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.BloomIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.BloomIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'fields': ['title'],
- 'name': 'test_bloom',
- 'length': 80,
- 'columns': [4],
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "fields": ["title"],
+ "name": "test_bloom",
+ "length": 80,
+ "columns": [4],
+ },
+ )
def test_invalid_fields(self):
- msg = 'Bloom indexes support a maximum of 32 fields.'
+ msg = "Bloom indexes support a maximum of 32 fields."
with self.assertRaisesMessage(ValueError, msg):
- BloomIndex(fields=['title'] * 33, name='test_bloom')
+ BloomIndex(fields=["title"] * 33, name="test_bloom")
def test_invalid_columns(self):
- msg = 'BloomIndex.columns must be a list or tuple.'
+ msg = "BloomIndex.columns must be a list or tuple."
with self.assertRaisesMessage(ValueError, msg):
- BloomIndex(fields=['title'], name='test_bloom', columns='x')
- msg = 'BloomIndex.columns cannot have more values than fields.'
+ BloomIndex(fields=["title"], name="test_bloom", columns="x")
+ msg = "BloomIndex.columns cannot have more values than fields."
with self.assertRaisesMessage(ValueError, msg):
- BloomIndex(fields=['title'], name='test_bloom', columns=[4, 3])
+ BloomIndex(fields=["title"], name="test_bloom", columns=[4, 3])
def test_invalid_columns_value(self):
- msg = 'BloomIndex.columns must contain integers from 1 to 4095.'
+ msg = "BloomIndex.columns must contain integers from 1 to 4095."
for length in (0, 4096):
with self.subTest(length), self.assertRaisesMessage(ValueError, msg):
- BloomIndex(fields=['title'], name='test_bloom', columns=[length])
+ BloomIndex(fields=["title"], name="test_bloom", columns=[length])
def test_invalid_length(self):
- msg = 'BloomIndex.length must be None or an integer from 1 to 4096.'
+ msg = "BloomIndex.length must be None or an integer from 1 to 4096."
for length in (0, 4097):
with self.subTest(length), self.assertRaisesMessage(ValueError, msg):
- BloomIndex(fields=['title'], name='test_bloom', length=length)
+ BloomIndex(fields=["title"], name="test_bloom", length=length)
class BrinIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = BrinIndex
def test_suffix(self):
- self.assertEqual(BrinIndex.suffix, 'brin')
+ self.assertEqual(BrinIndex.suffix, "brin")
def test_deconstruction(self):
- index = BrinIndex(fields=['title'], name='test_title_brin', autosummarize=True, pages_per_range=16)
+ index = BrinIndex(
+ fields=["title"],
+ name="test_title_brin",
+ autosummarize=True,
+ pages_per_range=16,
+ )
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.BrinIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.BrinIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'fields': ['title'],
- 'name': 'test_title_brin',
- 'autosummarize': True,
- 'pages_per_range': 16,
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "fields": ["title"],
+ "name": "test_title_brin",
+ "autosummarize": True,
+ "pages_per_range": 16,
+ },
+ )
def test_invalid_pages_per_range(self):
- with self.assertRaisesMessage(ValueError, 'pages_per_range must be None or a positive integer'):
- BrinIndex(fields=['title'], name='test_title_brin', pages_per_range=0)
+ with self.assertRaisesMessage(
+ ValueError, "pages_per_range must be None or a positive integer"
+ ):
+ BrinIndex(fields=["title"], name="test_title_brin", pages_per_range=0)
class BTreeIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = BTreeIndex
def test_suffix(self):
- self.assertEqual(BTreeIndex.suffix, 'btree')
+ self.assertEqual(BTreeIndex.suffix, "btree")
def test_deconstruction(self):
- index = BTreeIndex(fields=['title'], name='test_title_btree', fillfactor=80)
+ index = BTreeIndex(fields=["title"], name="test_title_btree", fillfactor=80)
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.BTreeIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.BTreeIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_btree', 'fillfactor': 80})
+ self.assertEqual(
+ kwargs, {"fields": ["title"], "name": "test_title_btree", "fillfactor": 80}
+ )
class GinIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = GinIndex
def test_suffix(self):
- self.assertEqual(GinIndex.suffix, 'gin')
+ self.assertEqual(GinIndex.suffix, "gin")
def test_deconstruction(self):
index = GinIndex(
- fields=['title'],
- name='test_title_gin',
+ fields=["title"],
+ name="test_title_gin",
fastupdate=True,
gin_pending_list_limit=128,
)
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.GinIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.GinIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'fields': ['title'],
- 'name': 'test_title_gin',
- 'fastupdate': True,
- 'gin_pending_list_limit': 128,
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "fields": ["title"],
+ "name": "test_title_gin",
+ "fastupdate": True,
+ "gin_pending_list_limit": 128,
+ },
+ )
class GistIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = GistIndex
def test_suffix(self):
- self.assertEqual(GistIndex.suffix, 'gist')
+ self.assertEqual(GistIndex.suffix, "gist")
def test_deconstruction(self):
- index = GistIndex(fields=['title'], name='test_title_gist', buffering=False, fillfactor=80)
+ index = GistIndex(
+ fields=["title"], name="test_title_gist", buffering=False, fillfactor=80
+ )
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.GistIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.GistIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {
- 'fields': ['title'],
- 'name': 'test_title_gist',
- 'buffering': False,
- 'fillfactor': 80,
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "fields": ["title"],
+ "name": "test_title_gist",
+ "buffering": False,
+ "fillfactor": 80,
+ },
+ )
class HashIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = HashIndex
def test_suffix(self):
- self.assertEqual(HashIndex.suffix, 'hash')
+ self.assertEqual(HashIndex.suffix, "hash")
def test_deconstruction(self):
- index = HashIndex(fields=['title'], name='test_title_hash', fillfactor=80)
+ index = HashIndex(fields=["title"], name="test_title_hash", fillfactor=80)
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.HashIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.HashIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_hash', 'fillfactor': 80})
+ self.assertEqual(
+ kwargs, {"fields": ["title"], "name": "test_title_hash", "fillfactor": 80}
+ )
class SpGistIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = SpGistIndex
def test_suffix(self):
- self.assertEqual(SpGistIndex.suffix, 'spgist')
+ self.assertEqual(SpGistIndex.suffix, "spgist")
def test_deconstruction(self):
- index = SpGistIndex(fields=['title'], name='test_title_spgist', fillfactor=80)
+ index = SpGistIndex(fields=["title"], name="test_title_spgist", fillfactor=80)
path, args, kwargs = index.deconstruct()
- self.assertEqual(path, 'django.contrib.postgres.indexes.SpGistIndex')
+ self.assertEqual(path, "django.contrib.postgres.indexes.SpGistIndex")
self.assertEqual(args, ())
- self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_spgist', 'fillfactor': 80})
+ self.assertEqual(
+ kwargs, {"fields": ["title"], "name": "test_title_spgist", "fillfactor": 80}
+ )
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SchemaTests(PostgreSQLTestCase):
- get_opclass_query = '''
+ get_opclass_query = """
SELECT opcname, c.relname FROM pg_opclass AS oc
JOIN pg_index as i on oc.oid = ANY(i.indclass)
JOIN pg_class as c on c.oid = i.indexrelid
WHERE c.relname = %s
- '''
+ """
def get_constraints(self, table):
"""
@@ -211,229 +252,274 @@ class SchemaTests(PostgreSQLTestCase):
def test_gin_index(self):
# Ensure the table is there and doesn't have an index.
- self.assertNotIn('field', self.get_constraints(IntegerArrayModel._meta.db_table))
+ self.assertNotIn(
+ "field", self.get_constraints(IntegerArrayModel._meta.db_table)
+ )
# Add the index
- index_name = 'integer_array_model_field_gin'
- index = GinIndex(fields=['field'], name=index_name)
+ index_name = "integer_array_model_field_gin"
+ index = GinIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(IntegerArrayModel, index)
constraints = self.get_constraints(IntegerArrayModel._meta.db_table)
# Check gin index was added
- self.assertEqual(constraints[index_name]['type'], GinIndex.suffix)
+ self.assertEqual(constraints[index_name]["type"], GinIndex.suffix)
# Drop the index
with connection.schema_editor() as editor:
editor.remove_index(IntegerArrayModel, index)
- self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(IntegerArrayModel._meta.db_table)
+ )
def test_gin_fastupdate(self):
- index_name = 'integer_array_gin_fastupdate'
- index = GinIndex(fields=['field'], name=index_name, fastupdate=False)
+ index_name = "integer_array_gin_fastupdate"
+ index = GinIndex(fields=["field"], name=index_name, fastupdate=False)
with connection.schema_editor() as editor:
editor.add_index(IntegerArrayModel, index)
constraints = self.get_constraints(IntegerArrayModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], 'gin')
- self.assertEqual(constraints[index_name]['options'], ['fastupdate=off'])
+ self.assertEqual(constraints[index_name]["type"], "gin")
+ self.assertEqual(constraints[index_name]["options"], ["fastupdate=off"])
with connection.schema_editor() as editor:
editor.remove_index(IntegerArrayModel, index)
- self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(IntegerArrayModel._meta.db_table)
+ )
def test_partial_gin_index(self):
with register_lookup(CharField, Length):
- index_name = 'char_field_gin_partial_idx'
- index = GinIndex(fields=['field'], name=index_name, condition=Q(field__length=40))
+ index_name = "char_field_gin_partial_idx"
+ index = GinIndex(
+ fields=["field"], name=index_name, condition=Q(field__length=40)
+ )
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], 'gin')
+ self.assertEqual(constraints[index_name]["type"], "gin")
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_partial_gin_index_with_tablespace(self):
with register_lookup(CharField, Length):
- index_name = 'char_field_gin_partial_idx'
+ index_name = "char_field_gin_partial_idx"
index = GinIndex(
- fields=['field'],
+ fields=["field"],
name=index_name,
condition=Q(field__length=40),
- db_tablespace='pg_default',
+ db_tablespace="pg_default",
)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
- self.assertIn('TABLESPACE "pg_default" ', str(index.create_sql(CharFieldModel, editor)))
+ self.assertIn(
+ 'TABLESPACE "pg_default" ',
+ str(index.create_sql(CharFieldModel, editor)),
+ )
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], 'gin')
+ self.assertEqual(constraints[index_name]["type"], "gin")
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_gin_parameters(self):
- index_name = 'integer_array_gin_params'
- index = GinIndex(fields=['field'], name=index_name, fastupdate=True, gin_pending_list_limit=64)
+ index_name = "integer_array_gin_params"
+ index = GinIndex(
+ fields=["field"],
+ name=index_name,
+ fastupdate=True,
+ gin_pending_list_limit=64,
+ )
with connection.schema_editor() as editor:
editor.add_index(IntegerArrayModel, index)
constraints = self.get_constraints(IntegerArrayModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], 'gin')
- self.assertEqual(constraints[index_name]['options'], ['gin_pending_list_limit=64', 'fastupdate=on'])
+ self.assertEqual(constraints[index_name]["type"], "gin")
+ self.assertEqual(
+ constraints[index_name]["options"],
+ ["gin_pending_list_limit=64", "fastupdate=on"],
+ )
with connection.schema_editor() as editor:
editor.remove_index(IntegerArrayModel, index)
- self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(IntegerArrayModel._meta.db_table)
+ )
def test_trigram_op_class_gin_index(self):
- index_name = 'trigram_op_class_gin'
- index = GinIndex(OpClass(F('scene'), name='gin_trgm_ops'), name=index_name)
+ index_name = "trigram_op_class_gin"
+ index = GinIndex(OpClass(F("scene"), name="gin_trgm_ops"), name=index_name)
with connection.schema_editor() as editor:
editor.add_index(Scene, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
- self.assertCountEqual(cursor.fetchall(), [('gin_trgm_ops', index_name)])
+ self.assertCountEqual(cursor.fetchall(), [("gin_trgm_ops", index_name)])
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIn(index_name, constraints)
- self.assertIn(constraints[index_name]['type'], GinIndex.suffix)
+ self.assertIn(constraints[index_name]["type"], GinIndex.suffix)
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_cast_search_vector_gin_index(self):
- index_name = 'cast_search_vector_gin'
- index = GinIndex(Cast('field', SearchVectorField()), name=index_name)
+ index_name = "cast_search_vector_gin"
+ index = GinIndex(Cast("field", SearchVectorField()), name=index_name)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
sql = index.create_sql(TextFieldModel, editor)
table = TextFieldModel._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(index_name, constraints)
- self.assertIn(constraints[index_name]['type'], GinIndex.suffix)
- self.assertIs(sql.references_column(table, 'field'), True)
- self.assertIn('::tsvector', str(sql))
+ self.assertIn(constraints[index_name]["type"], GinIndex.suffix)
+ self.assertIs(sql.references_column(table, "field"), True)
+ self.assertIn("::tsvector", str(sql))
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(table))
def test_bloom_index(self):
- index_name = 'char_field_model_field_bloom'
- index = BloomIndex(fields=['field'], name=index_name)
+ index_name = "char_field_model_field_bloom"
+ index = BloomIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], BloomIndex.suffix)
+ self.assertEqual(constraints[index_name]["type"], BloomIndex.suffix)
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_bloom_parameters(self):
- index_name = 'char_field_model_field_bloom_params'
- index = BloomIndex(fields=['field'], name=index_name, length=512, columns=[3])
+ index_name = "char_field_model_field_bloom_params"
+ index = BloomIndex(fields=["field"], name=index_name, length=512, columns=[3])
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], BloomIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['length=512', 'col1=3'])
+ self.assertEqual(constraints[index_name]["type"], BloomIndex.suffix)
+ self.assertEqual(constraints[index_name]["options"], ["length=512", "col1=3"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_brin_index(self):
- index_name = 'char_field_model_field_brin'
- index = BrinIndex(fields=['field'], name=index_name, pages_per_range=4)
+ index_name = "char_field_model_field_brin"
+ index = BrinIndex(fields=["field"], name=index_name, pages_per_range=4)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], BrinIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['pages_per_range=4'])
+ self.assertEqual(constraints[index_name]["type"], BrinIndex.suffix)
+ self.assertEqual(constraints[index_name]["options"], ["pages_per_range=4"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_brin_parameters(self):
- index_name = 'char_field_brin_params'
- index = BrinIndex(fields=['field'], name=index_name, autosummarize=True)
+ index_name = "char_field_brin_params"
+ index = BrinIndex(fields=["field"], name=index_name, autosummarize=True)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], BrinIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['autosummarize=on'])
+ self.assertEqual(constraints[index_name]["type"], BrinIndex.suffix)
+ self.assertEqual(constraints[index_name]["options"], ["autosummarize=on"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_btree_index(self):
# Ensure the table is there and doesn't have an index.
- self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table))
# Add the index.
- index_name = 'char_field_model_field_btree'
- index = BTreeIndex(fields=['field'], name=index_name)
+ index_name = "char_field_model_field_btree"
+ index = BTreeIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
# The index was added.
- self.assertEqual(constraints[index_name]['type'], BTreeIndex.suffix)
+ self.assertEqual(constraints[index_name]["type"], BTreeIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_btree_parameters(self):
- index_name = 'integer_array_btree_fillfactor'
- index = BTreeIndex(fields=['field'], name=index_name, fillfactor=80)
+ index_name = "integer_array_btree_fillfactor"
+ index = BTreeIndex(fields=["field"], name=index_name, fillfactor=80)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], BTreeIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['fillfactor=80'])
+ self.assertEqual(constraints[index_name]["type"], BTreeIndex.suffix)
+ self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_gist_index(self):
# Ensure the table is there and doesn't have an index.
- self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table))
# Add the index.
- index_name = 'char_field_model_field_gist'
- index = GistIndex(fields=['field'], name=index_name)
+ index_name = "char_field_model_field_gist"
+ index = GistIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
# The index was added.
- self.assertEqual(constraints[index_name]['type'], GistIndex.suffix)
+ self.assertEqual(constraints[index_name]["type"], GistIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_gist_parameters(self):
- index_name = 'integer_array_gist_buffering'
- index = GistIndex(fields=['field'], name=index_name, buffering=True, fillfactor=80)
+ index_name = "integer_array_gist_buffering"
+ index = GistIndex(
+ fields=["field"], name=index_name, buffering=True, fillfactor=80
+ )
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], GistIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['buffering=on', 'fillfactor=80'])
+ self.assertEqual(constraints[index_name]["type"], GistIndex.suffix)
+ self.assertEqual(
+ constraints[index_name]["options"], ["buffering=on", "fillfactor=80"]
+ )
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
- @skipUnlessDBFeature('supports_covering_gist_indexes')
+ @skipUnlessDBFeature("supports_covering_gist_indexes")
def test_gist_include(self):
- index_name = 'scene_gist_include_setting'
- index = GistIndex(name=index_name, fields=['scene'], include=['setting'])
+ index_name = "scene_gist_include_setting"
+ index = GistIndex(name=index_name, fields=["scene"], include=["setting"])
with connection.schema_editor() as editor:
editor.add_index(Scene, index)
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIn(index_name, constraints)
- self.assertEqual(constraints[index_name]['type'], GistIndex.suffix)
- self.assertEqual(constraints[index_name]['columns'], ['scene', 'setting'])
+ self.assertEqual(constraints[index_name]["type"], GistIndex.suffix)
+ self.assertEqual(constraints[index_name]["columns"], ["scene", "setting"])
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_gist_include_not_supported(self):
- index_name = 'gist_include_exception'
- index = GistIndex(fields=['scene'], name=index_name, include=['setting'])
- msg = 'Covering GiST indexes require PostgreSQL 12+.'
+ index_name = "gist_include_exception"
+ index = GistIndex(fields=["scene"], name=index_name, include=["setting"])
+ msg = "Covering GiST indexes require PostgreSQL 12+."
with self.assertRaisesMessage(NotSupportedError, msg):
with mock.patch(
- 'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes',
+ "django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes",
False,
):
with connection.schema_editor() as editor:
@@ -441,11 +527,11 @@ class SchemaTests(PostgreSQLTestCase):
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_tsvector_op_class_gist_index(self):
- index_name = 'tsvector_op_class_gist'
+ index_name = "tsvector_op_class_gist"
index = GistIndex(
OpClass(
- SearchVector('scene', 'setting', config='english'),
- name='tsvector_ops',
+ SearchVector("scene", "setting", config="english"),
+ name="tsvector_ops",
),
name=index_name,
)
@@ -455,90 +541,98 @@ class SchemaTests(PostgreSQLTestCase):
table = Scene._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(index_name, constraints)
- self.assertIn(constraints[index_name]['type'], GistIndex.suffix)
- self.assertIs(sql.references_column(table, 'scene'), True)
- self.assertIs(sql.references_column(table, 'setting'), True)
+ self.assertIn(constraints[index_name]["type"], GistIndex.suffix)
+ self.assertIs(sql.references_column(table, "scene"), True)
+ self.assertIs(sql.references_column(table, "setting"), True)
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(table))
def test_hash_index(self):
# Ensure the table is there and doesn't have an index.
- self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table))
# Add the index.
- index_name = 'char_field_model_field_hash'
- index = HashIndex(fields=['field'], name=index_name)
+ index_name = "char_field_model_field_hash"
+ index = HashIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
# The index was added.
- self.assertEqual(constraints[index_name]['type'], HashIndex.suffix)
+ self.assertEqual(constraints[index_name]["type"], HashIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_hash_parameters(self):
- index_name = 'integer_array_hash_fillfactor'
- index = HashIndex(fields=['field'], name=index_name, fillfactor=80)
+ index_name = "integer_array_hash_fillfactor"
+ index = HashIndex(fields=["field"], name=index_name, fillfactor=80)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], HashIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['fillfactor=80'])
+ self.assertEqual(constraints[index_name]["type"], HashIndex.suffix)
+ self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(CharFieldModel._meta.db_table)
+ )
def test_spgist_index(self):
# Ensure the table is there and doesn't have an index.
- self.assertNotIn('field', self.get_constraints(TextFieldModel._meta.db_table))
+ self.assertNotIn("field", self.get_constraints(TextFieldModel._meta.db_table))
# Add the index.
- index_name = 'text_field_model_field_spgist'
- index = SpGistIndex(fields=['field'], name=index_name)
+ index_name = "text_field_model_field_spgist"
+ index = SpGistIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
constraints = self.get_constraints(TextFieldModel._meta.db_table)
# The index was added.
- self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix)
+ self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(TextFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(TextFieldModel._meta.db_table)
+ )
def test_spgist_parameters(self):
- index_name = 'text_field_model_spgist_fillfactor'
- index = SpGistIndex(fields=['field'], name=index_name, fillfactor=80)
+ index_name = "text_field_model_spgist_fillfactor"
+ index = SpGistIndex(fields=["field"], name=index_name, fillfactor=80)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
constraints = self.get_constraints(TextFieldModel._meta.db_table)
- self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix)
- self.assertEqual(constraints[index_name]['options'], ['fillfactor=80'])
+ self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix)
+ self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"])
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
- self.assertNotIn(index_name, self.get_constraints(TextFieldModel._meta.db_table))
+ self.assertNotIn(
+ index_name, self.get_constraints(TextFieldModel._meta.db_table)
+ )
- @skipUnlessDBFeature('supports_covering_spgist_indexes')
+ @skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_spgist_include(self):
- index_name = 'scene_spgist_include_setting'
- index = SpGistIndex(name=index_name, fields=['scene'], include=['setting'])
+ index_name = "scene_spgist_include_setting"
+ index = SpGistIndex(name=index_name, fields=["scene"], include=["setting"])
with connection.schema_editor() as editor:
editor.add_index(Scene, index)
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIn(index_name, constraints)
- self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix)
- self.assertEqual(constraints[index_name]['columns'], ['scene', 'setting'])
+ self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix)
+ self.assertEqual(constraints[index_name]["columns"], ["scene", "setting"])
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_spgist_include_not_supported(self):
- index_name = 'spgist_include_exception'
- index = SpGistIndex(fields=['scene'], name=index_name, include=['setting'])
- msg = 'Covering SP-GiST indexes require PostgreSQL 14+.'
+ index_name = "spgist_include_exception"
+ index = SpGistIndex(fields=["scene"], name=index_name, include=["setting"])
+ msg = "Covering SP-GiST indexes require PostgreSQL 14+."
with self.assertRaisesMessage(NotSupportedError, msg):
with mock.patch(
- 'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_spgist_indexes',
+ "django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_spgist_indexes",
False,
):
with connection.schema_editor() as editor:
@@ -546,27 +640,25 @@ class SchemaTests(PostgreSQLTestCase):
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_op_class(self):
- index_name = 'test_op_class'
+ index_name = "test_op_class"
index = Index(
- OpClass(Lower('field'), name='text_pattern_ops'),
+ OpClass(Lower("field"), name="text_pattern_ops"),
name=index_name,
)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
- self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
+ self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
def test_op_class_descending_collation(self):
- collation = connection.features.test_collations.get('non_default')
+ collation = connection.features.test_collations.get("non_default")
if not collation:
- self.skipTest(
- 'This backend does not support case-insensitive collations.'
- )
- index_name = 'test_op_class_descending_collation'
+ self.skipTest("This backend does not support case-insensitive collations.")
+ index_name = "test_op_class_descending_collation"
index = Index(
Collate(
- OpClass(Lower('field'), name='text_pattern_ops').desc(nulls_last=True),
+ OpClass(Lower("field"), name="text_pattern_ops").desc(nulls_last=True),
collation=collation,
),
name=index_name,
@@ -574,53 +666,53 @@ class SchemaTests(PostgreSQLTestCase):
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
self.assertIn(
- 'COLLATE %s' % editor.quote_name(collation),
+ "COLLATE %s" % editor.quote_name(collation),
str(index.create_sql(TextFieldModel, editor)),
)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
- self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
+ self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
table = TextFieldModel._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(index_name, constraints)
- self.assertEqual(constraints[index_name]['orders'], ['DESC'])
+ self.assertEqual(constraints[index_name]["orders"], ["DESC"])
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(table))
def test_op_class_descending_partial(self):
- index_name = 'test_op_class_descending_partial'
+ index_name = "test_op_class_descending_partial"
index = Index(
- OpClass(Lower('field'), name='text_pattern_ops').desc(),
+ OpClass(Lower("field"), name="text_pattern_ops").desc(),
name=index_name,
- condition=Q(field__contains='China'),
+ condition=Q(field__contains="China"),
)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
- self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
+ self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
constraints = self.get_constraints(TextFieldModel._meta.db_table)
self.assertIn(index_name, constraints)
- self.assertEqual(constraints[index_name]['orders'], ['DESC'])
+ self.assertEqual(constraints[index_name]["orders"], ["DESC"])
def test_op_class_descending_partial_tablespace(self):
- index_name = 'test_op_class_descending_partial_tablespace'
+ index_name = "test_op_class_descending_partial_tablespace"
index = Index(
- OpClass(Lower('field').desc(), name='text_pattern_ops'),
+ OpClass(Lower("field").desc(), name="text_pattern_ops"),
name=index_name,
- condition=Q(field__contains='China'),
- db_tablespace='pg_default',
+ condition=Q(field__contains="China"),
+ db_tablespace="pg_default",
)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
self.assertIn(
'TABLESPACE "pg_default" ',
- str(index.create_sql(TextFieldModel, editor))
+ str(index.create_sql(TextFieldModel, editor)),
)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
- self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
+ self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
constraints = self.get_constraints(TextFieldModel._meta.db_table)
self.assertIn(index_name, constraints)
- self.assertEqual(constraints[index_name]['orders'], ['DESC'])
+ self.assertEqual(constraints[index_name]["orders"], ["DESC"])
diff --git a/tests/postgres_tests/test_integration.py b/tests/postgres_tests/test_integration.py
index db082a6495..e3c62a93e9 100644
--- a/tests/postgres_tests/test_integration.py
+++ b/tests/postgres_tests/test_integration.py
@@ -8,15 +8,22 @@ from . import PostgreSQLSimpleTestCase
class PostgresIntegrationTests(PostgreSQLSimpleTestCase):
def test_check(self):
test_environ = os.environ.copy()
- if 'DJANGO_SETTINGS_MODULE' in test_environ:
- del test_environ['DJANGO_SETTINGS_MODULE']
- test_environ['PYTHONPATH'] = os.path.join(os.path.dirname(__file__), '../../')
+ if "DJANGO_SETTINGS_MODULE" in test_environ:
+ del test_environ["DJANGO_SETTINGS_MODULE"]
+ test_environ["PYTHONPATH"] = os.path.join(os.path.dirname(__file__), "../../")
result = subprocess.run(
- [sys.executable, '-m', 'django', 'check', '--settings', 'integration_settings'],
+ [
+ sys.executable,
+ "-m",
+ "django",
+ "check",
+ "--settings",
+ "integration_settings",
+ ],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
cwd=os.path.dirname(__file__),
env=test_environ,
- encoding='utf-8',
+ encoding="utf-8",
)
self.assertEqual(result.returncode, 0, msg=result.stderr)
diff --git a/tests/postgres_tests/test_introspection.py b/tests/postgres_tests/test_introspection.py
index 50cb9b2828..670be46536 100644
--- a/tests/postgres_tests/test_introspection.py
+++ b/tests/postgres_tests/test_introspection.py
@@ -6,12 +6,12 @@ from django.test.utils import modify_settings
from . import PostgreSQLTestCase
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class InspectDBTests(PostgreSQLTestCase):
def assertFieldsInModel(self, model, field_outputs):
out = StringIO()
call_command(
- 'inspectdb',
+ "inspectdb",
table_name_filter=lambda tn: tn.startswith(model),
stdout=out,
)
@@ -21,12 +21,12 @@ class InspectDBTests(PostgreSQLTestCase):
def test_range_fields(self):
self.assertFieldsInModel(
- 'postgres_tests_rangesmodel',
+ "postgres_tests_rangesmodel",
[
- 'ints = django.contrib.postgres.fields.IntegerRangeField(blank=True, null=True)',
- 'bigints = django.contrib.postgres.fields.BigIntegerRangeField(blank=True, null=True)',
- 'decimals = django.contrib.postgres.fields.DecimalRangeField(blank=True, null=True)',
- 'timestamps = django.contrib.postgres.fields.DateTimeRangeField(blank=True, null=True)',
- 'dates = django.contrib.postgres.fields.DateRangeField(blank=True, null=True)',
+ "ints = django.contrib.postgres.fields.IntegerRangeField(blank=True, null=True)",
+ "bigints = django.contrib.postgres.fields.BigIntegerRangeField(blank=True, null=True)",
+ "decimals = django.contrib.postgres.fields.DecimalRangeField(blank=True, null=True)",
+ "timestamps = django.contrib.postgres.fields.DateTimeRangeField(blank=True, null=True)",
+ "dates = django.contrib.postgres.fields.DateRangeField(blank=True, null=True)",
],
)
diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py
index 790fef8332..a54f8811ba 100644
--- a/tests/postgres_tests/test_operations.py
+++ b/tests/postgres_tests/test_operations.py
@@ -3,9 +3,7 @@ from unittest import mock
from migrations.test_base import OperationTestBase
-from django.db import (
- IntegrityError, NotSupportedError, connection, transaction,
-)
+from django.db import IntegrityError, NotSupportedError, connection, transaction
from django.db.migrations.state import ProjectState
from django.db.models import CheckConstraint, Index, Q, UniqueConstraint
from django.db.utils import ProgrammingError
@@ -17,262 +15,315 @@ from . import PostgreSQLTestCase
try:
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
from django.contrib.postgres.operations import (
- AddConstraintNotValid, AddIndexConcurrently, BloomExtension,
- CreateCollation, CreateExtension, RemoveCollation,
- RemoveIndexConcurrently, ValidateConstraint,
+ AddConstraintNotValid,
+ AddIndexConcurrently,
+ BloomExtension,
+ CreateCollation,
+ CreateExtension,
+ RemoveCollation,
+ RemoveIndexConcurrently,
+ ValidateConstraint,
)
except ImportError:
pass
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
-@modify_settings(INSTALLED_APPS={'append': 'migrations'})
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
+@modify_settings(INSTALLED_APPS={"append": "migrations"})
class AddIndexConcurrentlyTests(OperationTestBase):
- app_label = 'test_add_concurrently'
+ app_label = "test_add_concurrently"
def test_requires_atomic_false(self):
project_state = self.set_up_test_model(self.app_label)
new_state = project_state.clone()
operation = AddIndexConcurrently(
- 'Pony',
- Index(fields=['pink'], name='pony_pink_idx'),
+ "Pony",
+ Index(fields=["pink"], name="pony_pink_idx"),
)
msg = (
- 'The AddIndexConcurrently operation cannot be executed inside '
- 'a transaction (set atomic = False on the migration).'
+ "The AddIndexConcurrently operation cannot be executed inside "
+ "a transaction (set atomic = False on the migration)."
)
with self.assertRaisesMessage(NotSupportedError, msg):
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
def test_add(self):
project_state = self.set_up_test_model(self.app_label, index=False)
- table_name = '%s_pony' % self.app_label
- index = Index(fields=['pink'], name='pony_pink_idx')
+ table_name = "%s_pony" % self.app_label
+ index = Index(fields=["pink"], name="pony_pink_idx")
new_state = project_state.clone()
- operation = AddIndexConcurrently('Pony', index)
+ operation = AddIndexConcurrently("Pony", index)
self.assertEqual(
operation.describe(),
- 'Concurrently create index pony_pink_idx on field(s) pink of model Pony',
+ "Concurrently create index pony_pink_idx on field(s) pink of model Pony",
)
operation.state_forwards(self.app_label, new_state)
- self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 1)
- self.assertIndexNotExists(table_name, ['pink'])
+ self.assertEqual(
+ len(new_state.models[self.app_label, "pony"].options["indexes"]), 1
+ )
+ self.assertIndexNotExists(table_name, ["pink"])
# Add index.
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
- self.assertIndexExists(table_name, ['pink'])
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
+ self.assertIndexExists(table_name, ["pink"])
# Reversal.
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
- self.assertIndexNotExists(table_name, ['pink'])
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
+ self.assertIndexNotExists(table_name, ["pink"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'AddIndexConcurrently')
+ self.assertEqual(name, "AddIndexConcurrently")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {'model_name': 'Pony', 'index': index})
+ self.assertEqual(kwargs, {"model_name": "Pony", "index": index})
def test_add_other_index_type(self):
project_state = self.set_up_test_model(self.app_label, index=False)
- table_name = '%s_pony' % self.app_label
+ table_name = "%s_pony" % self.app_label
new_state = project_state.clone()
operation = AddIndexConcurrently(
- 'Pony',
- BrinIndex(fields=['pink'], name='pony_pink_brin_idx'),
+ "Pony",
+ BrinIndex(fields=["pink"], name="pony_pink_brin_idx"),
)
- self.assertIndexNotExists(table_name, ['pink'])
+ self.assertIndexNotExists(table_name, ["pink"])
# Add index.
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
- self.assertIndexExists(table_name, ['pink'], index_type='brin')
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
+ self.assertIndexExists(table_name, ["pink"], index_type="brin")
# Reversal.
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
- self.assertIndexNotExists(table_name, ['pink'])
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
+ self.assertIndexNotExists(table_name, ["pink"])
def test_add_with_options(self):
project_state = self.set_up_test_model(self.app_label, index=False)
- table_name = '%s_pony' % self.app_label
+ table_name = "%s_pony" % self.app_label
new_state = project_state.clone()
- index = BTreeIndex(fields=['pink'], name='pony_pink_btree_idx', fillfactor=70)
- operation = AddIndexConcurrently('Pony', index)
- self.assertIndexNotExists(table_name, ['pink'])
+ index = BTreeIndex(fields=["pink"], name="pony_pink_btree_idx", fillfactor=70)
+ operation = AddIndexConcurrently("Pony", index)
+ self.assertIndexNotExists(table_name, ["pink"])
# Add index.
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
- self.assertIndexExists(table_name, ['pink'], index_type='btree')
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
+ self.assertIndexExists(table_name, ["pink"], index_type="btree")
# Reversal.
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
- self.assertIndexNotExists(table_name, ['pink'])
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
+ self.assertIndexNotExists(table_name, ["pink"])
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
-@modify_settings(INSTALLED_APPS={'append': 'migrations'})
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
+@modify_settings(INSTALLED_APPS={"append": "migrations"})
class RemoveIndexConcurrentlyTests(OperationTestBase):
- app_label = 'test_rm_concurrently'
+ app_label = "test_rm_concurrently"
def test_requires_atomic_false(self):
project_state = self.set_up_test_model(self.app_label, index=True)
new_state = project_state.clone()
- operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
+ operation = RemoveIndexConcurrently("Pony", "pony_pink_idx")
msg = (
- 'The RemoveIndexConcurrently operation cannot be executed inside '
- 'a transaction (set atomic = False on the migration).'
+ "The RemoveIndexConcurrently operation cannot be executed inside "
+ "a transaction (set atomic = False on the migration)."
)
with self.assertRaisesMessage(NotSupportedError, msg):
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
def test_remove(self):
project_state = self.set_up_test_model(self.app_label, index=True)
- table_name = '%s_pony' % self.app_label
+ table_name = "%s_pony" % self.app_label
self.assertTableExists(table_name)
new_state = project_state.clone()
- operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
+ operation = RemoveIndexConcurrently("Pony", "pony_pink_idx")
self.assertEqual(
operation.describe(),
- 'Concurrently remove index pony_pink_idx from Pony',
+ "Concurrently remove index pony_pink_idx from Pony",
)
operation.state_forwards(self.app_label, new_state)
- self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 0)
- self.assertIndexExists(table_name, ['pink'])
+ self.assertEqual(
+ len(new_state.models[self.app_label, "pony"].options["indexes"]), 0
+ )
+ self.assertIndexExists(table_name, ["pink"])
# Remove index.
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
- self.assertIndexNotExists(table_name, ['pink'])
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
+ self.assertIndexNotExists(table_name, ["pink"])
# Reversal.
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
- self.assertIndexExists(table_name, ['pink'])
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
+ self.assertIndexExists(table_name, ["pink"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'RemoveIndexConcurrently')
+ self.assertEqual(name, "RemoveIndexConcurrently")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
+ self.assertEqual(kwargs, {"model_name": "Pony", "name": "pony_pink_idx"})
-class NoMigrationRouter():
+class NoMigrationRouter:
def allow_migrate(self, db, app_label, **hints):
return False
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
class CreateExtensionTests(PostgreSQLTestCase):
- app_label = 'test_allow_create_extention'
+ app_label = "test_allow_create_extention"
@override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
def test_no_allow_migrate(self):
- operation = CreateExtension('tablefunc')
+ operation = CreateExtension("tablefunc")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create an extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 0)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 0)
def test_allow_migrate(self):
- operation = CreateExtension('tablefunc')
- self.assertEqual(operation.migration_name_fragment, 'create_extension_tablefunc')
+ operation = CreateExtension("tablefunc")
+ self.assertEqual(
+ operation.migration_name_fragment, "create_extension_tablefunc"
+ )
project_state = ProjectState()
new_state = project_state.clone()
# Create an extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 4)
- self.assertIn('CREATE EXTENSION IF NOT EXISTS', captured_queries[1]['sql'])
+ self.assertIn("CREATE EXTENSION IF NOT EXISTS", captured_queries[1]["sql"])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 2)
- self.assertIn('DROP EXTENSION IF EXISTS', captured_queries[1]['sql'])
+ self.assertIn("DROP EXTENSION IF EXISTS", captured_queries[1]["sql"])
def test_create_existing_extension(self):
operation = BloomExtension()
- self.assertEqual(operation.migration_name_fragment, 'create_extension_bloom')
+ self.assertEqual(operation.migration_name_fragment, "create_extension_bloom")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create an existing extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 3)
- self.assertIn('SELECT', captured_queries[0]['sql'])
+ self.assertIn("SELECT", captured_queries[0]["sql"])
def test_drop_nonexistent_extension(self):
- operation = CreateExtension('tablefunc')
+ operation = CreateExtension("tablefunc")
project_state = ProjectState()
new_state = project_state.clone()
# Don't drop a nonexistent extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, project_state, new_state)
+ operation.database_backwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('SELECT', captured_queries[0]['sql'])
+ self.assertIn("SELECT", captured_queries[0]["sql"])
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
class CreateCollationTests(PostgreSQLTestCase):
- app_label = 'test_allow_create_collation'
+ app_label = "test_allow_create_collation"
@override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
def test_no_allow_migrate(self):
- operation = CreateCollation('C_test', locale='C')
+ operation = CreateCollation("C_test", locale="C")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 0)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 0)
def test_create(self):
- operation = CreateCollation('C_test', locale='C')
- self.assertEqual(operation.migration_name_fragment, 'create_collation_c_test')
- self.assertEqual(operation.describe(), 'Create collation C_test')
+ operation = CreateCollation("C_test", locale="C")
+ self.assertEqual(operation.migration_name_fragment, "create_collation_c_test")
+ self.assertEqual(operation.describe(), "Create collation C_test")
project_state = ProjectState()
new_state = project_state.clone()
# Create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+ self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Creating the same collation raises an exception.
- with self.assertRaisesMessage(ProgrammingError, 'already exists'):
+ with self.assertRaisesMessage(ProgrammingError, "already exists"):
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+ self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'CreateCollation')
+ self.assertEqual(name, "CreateCollation")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
+ self.assertEqual(kwargs, {"name": "C_test", "locale": "C"})
- @skipUnlessDBFeature('supports_non_deterministic_collations')
+ @skipUnlessDBFeature("supports_non_deterministic_collations")
def test_create_non_deterministic_collation(self):
operation = CreateCollation(
- 'case_insensitive_test',
- 'und-u-ks-level2',
- provider='icu',
+ "case_insensitive_test",
+ "und-u-ks-level2",
+ provider="icu",
deterministic=False,
)
project_state = ProjectState()
@@ -280,216 +331,253 @@ class CreateCollationTests(PostgreSQLTestCase):
# Create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+ self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+ self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'CreateCollation')
+ self.assertEqual(name, "CreateCollation")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {
- 'name': 'case_insensitive_test',
- 'locale': 'und-u-ks-level2',
- 'provider': 'icu',
- 'deterministic': False,
- })
+ self.assertEqual(
+ kwargs,
+ {
+ "name": "case_insensitive_test",
+ "locale": "und-u-ks-level2",
+ "provider": "icu",
+ "deterministic": False,
+ },
+ )
def test_create_collation_alternate_provider(self):
operation = CreateCollation(
- 'german_phonebook_test',
- provider='icu',
- locale='de-u-co-phonebk',
+ "german_phonebook_test",
+ provider="icu",
+ locale="de-u-co-phonebk",
)
project_state = ProjectState()
new_state = project_state.clone()
# Create an collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+ self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+ self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
def test_nondeterministic_collation_not_supported(self):
operation = CreateCollation(
- 'case_insensitive_test',
- provider='icu',
- locale='und-u-ks-level2',
+ "case_insensitive_test",
+ provider="icu",
+ locale="und-u-ks-level2",
deterministic=False,
)
project_state = ProjectState()
new_state = project_state.clone()
- msg = 'Non-deterministic collations require PostgreSQL 12+.'
+ msg = "Non-deterministic collations require PostgreSQL 12+."
with connection.schema_editor(atomic=False) as editor:
with mock.patch(
- 'django.db.backends.postgresql.features.DatabaseFeatures.'
- 'supports_non_deterministic_collations',
+ "django.db.backends.postgresql.features.DatabaseFeatures."
+ "supports_non_deterministic_collations",
False,
):
with self.assertRaisesMessage(NotSupportedError, msg):
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
class RemoveCollationTests(PostgreSQLTestCase):
- app_label = 'test_allow_remove_collation'
+ app_label = "test_allow_remove_collation"
@override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
def test_no_allow_migrate(self):
- operation = RemoveCollation('C_test', locale='C')
+ operation = RemoveCollation("C_test", locale="C")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 0)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 0)
def test_remove(self):
- operation = CreateCollation('C_test', locale='C')
+ operation = CreateCollation("C_test", locale="C")
project_state = ProjectState()
new_state = project_state.clone()
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
- operation = RemoveCollation('C_test', locale='C')
- self.assertEqual(operation.migration_name_fragment, 'remove_collation_c_test')
- self.assertEqual(operation.describe(), 'Remove collation C_test')
+ operation = RemoveCollation("C_test", locale="C")
+ self.assertEqual(operation.migration_name_fragment, "remove_collation_c_test")
+ self.assertEqual(operation.describe(), "Remove collation C_test")
project_state = ProjectState()
new_state = project_state.clone()
# Remove a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+ self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
# Removing a nonexistent collation raises an exception.
- with self.assertRaisesMessage(ProgrammingError, 'does not exist'):
+ with self.assertRaisesMessage(ProgrammingError, "does not exist"):
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
self.assertEqual(len(captured_queries), 1)
- self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+ self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'RemoveCollation')
+ self.assertEqual(name, "RemoveCollation")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
+ self.assertEqual(kwargs, {"name": "C_test", "locale": "C"})
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
-@modify_settings(INSTALLED_APPS={'append': 'migrations'})
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
+@modify_settings(INSTALLED_APPS={"append": "migrations"})
class AddConstraintNotValidTests(OperationTestBase):
- app_label = 'test_add_constraint_not_valid'
+ app_label = "test_add_constraint_not_valid"
def test_non_check_constraint_not_supported(self):
- constraint = UniqueConstraint(fields=['pink'], name='pony_pink_uniq')
- msg = 'AddConstraintNotValid.constraint must be a check constraint.'
+ constraint = UniqueConstraint(fields=["pink"], name="pony_pink_uniq")
+ msg = "AddConstraintNotValid.constraint must be a check constraint."
with self.assertRaisesMessage(TypeError, msg):
- AddConstraintNotValid(model_name='pony', constraint=constraint)
+ AddConstraintNotValid(model_name="pony", constraint=constraint)
def test_add(self):
- table_name = f'{self.app_label}_pony'
- constraint_name = 'pony_pink_gte_check'
+ table_name = f"{self.app_label}_pony"
+ constraint_name = "pony_pink_gte_check"
constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
- operation = AddConstraintNotValid('Pony', constraint=constraint)
+ operation = AddConstraintNotValid("Pony", constraint=constraint)
project_state, new_state = self.make_test_state(self.app_label, operation)
self.assertEqual(
operation.describe(),
- f'Create not valid constraint {constraint_name} on model Pony',
+ f"Create not valid constraint {constraint_name} on model Pony",
)
self.assertEqual(
operation.migration_name_fragment,
- f'pony_{constraint_name}_not_valid',
+ f"pony_{constraint_name}_not_valid",
)
self.assertEqual(
- len(new_state.models[self.app_label, 'pony'].options['constraints']),
+ len(new_state.models[self.app_label, "pony"].options["constraints"]),
1,
)
self.assertConstraintNotExists(table_name, constraint_name)
- Pony = new_state.apps.get_model(self.app_label, 'Pony')
+ Pony = new_state.apps.get_model(self.app_label, "Pony")
self.assertEqual(len(Pony._meta.constraints), 1)
Pony.objects.create(pink=2, weight=1.0)
# Add constraint.
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
msg = f'check constraint "{constraint_name}"'
with self.assertRaisesMessage(IntegrityError, msg), transaction.atomic():
Pony.objects.create(pink=3, weight=1.0)
self.assertConstraintExists(table_name, constraint_name)
# Reversal.
with connection.schema_editor(atomic=True) as editor:
- operation.database_backwards(self.app_label, editor, project_state, new_state)
+ operation.database_backwards(
+ self.app_label, editor, project_state, new_state
+ )
self.assertConstraintNotExists(table_name, constraint_name)
Pony.objects.create(pink=3, weight=1.0)
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'AddConstraintNotValid')
+ self.assertEqual(name, "AddConstraintNotValid")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {'model_name': 'Pony', 'constraint': constraint})
+ self.assertEqual(kwargs, {"model_name": "Pony", "constraint": constraint})
-@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
-@modify_settings(INSTALLED_APPS={'append': 'migrations'})
+@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
+@modify_settings(INSTALLED_APPS={"append": "migrations"})
class ValidateConstraintTests(OperationTestBase):
- app_label = 'test_validate_constraint'
+ app_label = "test_validate_constraint"
def test_validate(self):
- constraint_name = 'pony_pink_gte_check'
+ constraint_name = "pony_pink_gte_check"
constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
- operation = AddConstraintNotValid('Pony', constraint=constraint)
+ operation = AddConstraintNotValid("Pony", constraint=constraint)
project_state, new_state = self.make_test_state(self.app_label, operation)
- Pony = new_state.apps.get_model(self.app_label, 'Pony')
+ Pony = new_state.apps.get_model(self.app_label, "Pony")
obj = Pony.objects.create(pink=2, weight=1.0)
# Add constraint.
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
project_state = new_state
new_state = new_state.clone()
- operation = ValidateConstraint('Pony', name=constraint_name)
+ operation = ValidateConstraint("Pony", name=constraint_name)
operation.state_forwards(self.app_label, new_state)
self.assertEqual(
operation.describe(),
- f'Validate constraint {constraint_name} on model Pony',
+ f"Validate constraint {constraint_name} on model Pony",
)
self.assertEqual(
operation.migration_name_fragment,
- f'pony_validate_{constraint_name}',
+ f"pony_validate_{constraint_name}",
)
# Validate constraint.
with connection.schema_editor(atomic=True) as editor:
msg = f'check constraint "{constraint_name}"'
with self.assertRaisesMessage(IntegrityError, msg):
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
obj.pink = 5
obj.save()
with connection.schema_editor(atomic=True) as editor:
- operation.database_forwards(self.app_label, editor, project_state, new_state)
+ operation.database_forwards(
+ self.app_label, editor, project_state, new_state
+ )
# Reversal is a noop.
with connection.schema_editor() as editor:
with self.assertNumQueries(0):
- operation.database_backwards(self.app_label, editor, new_state, project_state)
+ operation.database_backwards(
+ self.app_label, editor, new_state, project_state
+ )
# Deconstruction.
name, args, kwargs = operation.deconstruct()
- self.assertEqual(name, 'ValidateConstraint')
+ self.assertEqual(name, "ValidateConstraint")
self.assertEqual(args, [])
- self.assertEqual(kwargs, {'model_name': 'Pony', 'name': constraint_name})
+ self.assertEqual(kwargs, {"model_name": "Pony", "name": constraint_name})
diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py
index c2f9d443dd..7563b4ffff 100644
--- a/tests/postgres_tests/test_ranges.py
+++ b/tests/postgres_tests/test_ranges.py
@@ -12,38 +12,43 @@ from django.utils import timezone
from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase
from .models import (
- BigAutoFieldModel, PostgreSQLModel, RangeLookupsModel, RangesModel,
+ BigAutoFieldModel,
+ PostgreSQLModel,
+ RangeLookupsModel,
+ RangesModel,
SmallAutoFieldModel,
)
try:
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
- from django.contrib.postgres import fields as pg_fields, forms as pg_forms
+ from django.contrib.postgres import fields as pg_fields
+ from django.contrib.postgres import forms as pg_forms
from django.contrib.postgres.validators import (
- RangeMaxValueValidator, RangeMinValueValidator,
+ RangeMaxValueValidator,
+ RangeMinValueValidator,
)
except ImportError:
pass
-@isolate_apps('postgres_tests')
+@isolate_apps("postgres_tests")
class BasicTests(PostgreSQLSimpleTestCase):
def test_get_field_display(self):
class Model(PostgreSQLModel):
field = pg_fields.IntegerRangeField(
choices=[
- ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]],
- ((51, 100), '51-100'),
+ ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]],
+ ((51, 100), "51-100"),
],
)
tests = (
- ((1, 25), '1-25'),
- ([26, 50], '26-50'),
- ((51, 100), '51-100'),
- ((1, 2), '(1, 2)'),
- ([1, 2], '[1, 2]'),
+ ((1, 25), "1-25"),
+ ([26, 50], "26-50"),
+ ((51, 100), "51-100"),
+ ((1, 2), "(1, 2)"),
+ ([1, 2], "[1, 2]"),
)
for value, display in tests:
with self.subTest(value=value, display=display):
@@ -59,7 +64,7 @@ class BasicTests(PostgreSQLSimpleTestCase):
for field_type in discrete_range_types:
msg = f"Cannot use 'default_bounds' with {field_type.__name__}."
with self.assertRaisesMessage(TypeError, msg):
- field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
+ field_type(choices=[((51, 100), "51-100")], default_bounds="[]")
def test_continuous_range_fields_default_bounds(self):
continuous_range_types = [
@@ -67,11 +72,11 @@ class BasicTests(PostgreSQLSimpleTestCase):
pg_fields.DateTimeRangeField,
]
for field_type in continuous_range_types:
- field = field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
- self.assertEqual(field.default_bounds, '[]')
+ field = field_type(choices=[((51, 100), "51-100")], default_bounds="[]")
+ self.assertEqual(field.default_bounds, "[]")
def test_invalid_default_bounds(self):
- tests = [')]', ')[', '](', '])', '([', '[(', 'x', '', None]
+ tests = [")]", ")[", "](", "])", "([", "[(", "x", "", None]
msg = "default_bounds must be one of '[)', '(]', '()', or '[]'."
for invalid_bounds in tests:
with self.assertRaisesMessage(ValueError, msg):
@@ -81,13 +86,12 @@ class BasicTests(PostgreSQLSimpleTestCase):
field = pg_fields.DecimalRangeField()
*_, kwargs = field.deconstruct()
self.assertEqual(kwargs, {})
- field = pg_fields.DecimalRangeField(default_bounds='[]')
+ field = pg_fields.DecimalRangeField(default_bounds="[]")
*_, kwargs = field.deconstruct()
- self.assertEqual(kwargs, {'default_bounds': '[]'})
+ self.assertEqual(kwargs, {"default_bounds": "[]"})
class TestSaveLoad(PostgreSQLTestCase):
-
def test_all_fields(self):
now = timezone.now()
instance = RangesModel(
@@ -124,15 +128,15 @@ class TestSaveLoad(PostgreSQLTestCase):
loaded = RangesModel.objects.get()
self.assertEqual(
loaded.timestamps_closed_bounds,
- DateTimeTZRange(range_[0], range_[1], '[]'),
+ DateTimeTZRange(range_[0], range_[1], "[]"),
)
self.assertEqual(
loaded.timestamps,
- DateTimeTZRange(range_[0], range_[1], '[)'),
+ DateTimeTZRange(range_[0], range_[1], "[)"),
)
def test_range_object_boundaries(self):
- r = NumericRange(0, 10, '[]')
+ r = NumericRange(0, 10, "[]")
instance = RangesModel(decimals=r)
instance.save()
loaded = RangesModel.objects.get()
@@ -143,14 +147,14 @@ class TestSaveLoad(PostgreSQLTestCase):
range_ = DateTimeTZRange(
timezone.now(),
timezone.now() + datetime.timedelta(hours=1),
- bounds='()',
+ bounds="()",
)
RangesModel.objects.create(timestamps_closed_bounds=range_)
loaded = RangesModel.objects.get()
self.assertEqual(loaded.timestamps_closed_bounds, range_)
def test_unbounded(self):
- r = NumericRange(None, None, '()')
+ r = NumericRange(None, None, "()")
instance = RangesModel(decimals=r)
instance.save()
loaded = RangesModel.objects.get()
@@ -171,13 +175,12 @@ class TestSaveLoad(PostgreSQLTestCase):
def test_model_set_on_base_field(self):
instance = RangesModel()
- field = instance._meta.get_field('ints')
+ field = instance._meta.get_field("ints")
self.assertEqual(field.model, RangesModel)
self.assertEqual(field.base_field.model, RangesModel)
class TestRangeContainsLookup(PostgreSQLTestCase):
-
@classmethod
def setUpTestData(cls):
cls.timestamps = [
@@ -189,8 +192,7 @@ class TestRangeContainsLookup(PostgreSQLTestCase):
datetime.datetime(year=2016, month=2, day=2),
]
cls.aware_timestamps = [
- timezone.make_aware(timestamp)
- for timestamp in cls.timestamps
+ timezone.make_aware(timestamp) for timestamp in cls.timestamps
]
cls.dates = [
datetime.date(year=2016, month=1, day=1),
@@ -230,13 +232,13 @@ class TestRangeContainsLookup(PostgreSQLTestCase):
(self.timestamps[1], self.timestamps[2]),
(self.aware_timestamps[1], self.aware_timestamps[2]),
Value(self.dates[0]),
- Func(F('dates'), function='lower', output_field=DateTimeField()),
- F('timestamps_inner'),
+ Func(F("dates"), function="lower", output_field=DateTimeField()),
+ F("timestamps_inner"),
)
for filter_arg in filter_args:
with self.subTest(filter_arg=filter_arg):
self.assertCountEqual(
- RangesModel.objects.filter(**{'timestamps__contains': filter_arg}),
+ RangesModel.objects.filter(**{"timestamps__contains": filter_arg}),
[self.obj, self.aware_obj],
)
@@ -245,28 +247,29 @@ class TestRangeContainsLookup(PostgreSQLTestCase):
self.timestamps[1],
(self.dates[1], self.dates[2]),
Value(self.dates[0], output_field=DateField()),
- Func(F('timestamps'), function='lower', output_field=DateField()),
- F('dates_inner'),
+ Func(F("timestamps"), function="lower", output_field=DateField()),
+ F("dates_inner"),
)
for filter_arg in filter_args:
with self.subTest(filter_arg=filter_arg):
self.assertCountEqual(
- RangesModel.objects.filter(**{'dates__contains': filter_arg}),
+ RangesModel.objects.filter(**{"dates__contains": filter_arg}),
[self.obj, self.aware_obj],
)
class TestQuerying(PostgreSQLTestCase):
-
@classmethod
def setUpTestData(cls):
- cls.objs = RangesModel.objects.bulk_create([
- RangesModel(ints=NumericRange(0, 10)),
- RangesModel(ints=NumericRange(5, 15)),
- RangesModel(ints=NumericRange(None, 0)),
- RangesModel(ints=NumericRange(empty=True)),
- RangesModel(ints=None),
- ])
+ cls.objs = RangesModel.objects.bulk_create(
+ [
+ RangesModel(ints=NumericRange(0, 10)),
+ RangesModel(ints=NumericRange(5, 15)),
+ RangesModel(ints=NumericRange(None, 0)),
+ RangesModel(ints=NumericRange(empty=True)),
+ RangesModel(ints=None),
+ ]
+ )
def test_exact(self):
self.assertSequenceEqual(
@@ -359,26 +362,28 @@ class TestQuerying(PostgreSQLTestCase):
)
def test_bound_type(self):
- decimals = RangesModel.objects.bulk_create([
- RangesModel(decimals=NumericRange(None, 10)),
- RangesModel(decimals=NumericRange(10, None)),
- RangesModel(decimals=NumericRange(5, 15)),
- RangesModel(decimals=NumericRange(5, 15, '(]')),
- ])
+ decimals = RangesModel.objects.bulk_create(
+ [
+ RangesModel(decimals=NumericRange(None, 10)),
+ RangesModel(decimals=NumericRange(10, None)),
+ RangesModel(decimals=NumericRange(5, 15)),
+ RangesModel(decimals=NumericRange(5, 15, "(]")),
+ ]
+ )
tests = [
- ('lower_inc', True, [decimals[1], decimals[2]]),
- ('lower_inc', False, [decimals[0], decimals[3]]),
- ('lower_inf', True, [decimals[0]]),
- ('lower_inf', False, [decimals[1], decimals[2], decimals[3]]),
- ('upper_inc', True, [decimals[3]]),
- ('upper_inc', False, [decimals[0], decimals[1], decimals[2]]),
- ('upper_inf', True, [decimals[1]]),
- ('upper_inf', False, [decimals[0], decimals[2], decimals[3]]),
+ ("lower_inc", True, [decimals[1], decimals[2]]),
+ ("lower_inc", False, [decimals[0], decimals[3]]),
+ ("lower_inf", True, [decimals[0]]),
+ ("lower_inf", False, [decimals[1], decimals[2], decimals[3]]),
+ ("upper_inc", True, [decimals[3]]),
+ ("upper_inc", False, [decimals[0], decimals[1], decimals[2]]),
+ ("upper_inf", True, [decimals[1]]),
+ ("upper_inf", False, [decimals[0], decimals[2], decimals[3]]),
]
for lookup, filter_arg, excepted_result in tests:
with self.subTest(lookup=lookup, filter_arg=filter_arg):
self.assertSequenceEqual(
- RangesModel.objects.filter(**{'decimals__%s' % lookup: filter_arg}),
+ RangesModel.objects.filter(**{"decimals__%s" % lookup: filter_arg}),
excepted_result,
)
@@ -386,32 +391,38 @@ class TestQuerying(PostgreSQLTestCase):
class TestQueryingWithRanges(PostgreSQLTestCase):
def test_date_range(self):
objs = [
- RangeLookupsModel.objects.create(date='2015-01-01'),
- RangeLookupsModel.objects.create(date='2015-05-05'),
+ RangeLookupsModel.objects.create(date="2015-01-01"),
+ RangeLookupsModel.objects.create(date="2015-05-05"),
]
self.assertSequenceEqual(
- RangeLookupsModel.objects.filter(date__contained_by=DateRange('2015-01-01', '2015-05-04')),
+ RangeLookupsModel.objects.filter(
+ date__contained_by=DateRange("2015-01-01", "2015-05-04")
+ ),
[objs[0]],
)
def test_date_range_datetime_field(self):
objs = [
- RangeLookupsModel.objects.create(timestamp='2015-01-01'),
- RangeLookupsModel.objects.create(timestamp='2015-05-05'),
+ RangeLookupsModel.objects.create(timestamp="2015-01-01"),
+ RangeLookupsModel.objects.create(timestamp="2015-05-05"),
]
self.assertSequenceEqual(
- RangeLookupsModel.objects.filter(timestamp__date__contained_by=DateRange('2015-01-01', '2015-05-04')),
+ RangeLookupsModel.objects.filter(
+ timestamp__date__contained_by=DateRange("2015-01-01", "2015-05-04")
+ ),
[objs[0]],
)
def test_datetime_range(self):
objs = [
- RangeLookupsModel.objects.create(timestamp='2015-01-01T09:00:00'),
- RangeLookupsModel.objects.create(timestamp='2015-05-05T17:00:00'),
+ RangeLookupsModel.objects.create(timestamp="2015-01-01T09:00:00"),
+ RangeLookupsModel.objects.create(timestamp="2015-05-05T17:00:00"),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(
- timestamp__contained_by=DateTimeTZRange('2015-01-01T09:00', '2015-05-04T23:55')
+ timestamp__contained_by=DateTimeTZRange(
+ "2015-01-01T09:00", "2015-05-04T23:55"
+ )
),
[objs[0]],
)
@@ -423,7 +434,9 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
RangeLookupsModel.objects.create(small_integer=-1),
]
self.assertSequenceEqual(
- RangeLookupsModel.objects.filter(small_integer__contained_by=NumericRange(4, 6)),
+ RangeLookupsModel.objects.filter(
+ small_integer__contained_by=NumericRange(4, 6)
+ ),
[objs[1]],
)
@@ -435,7 +448,7 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(integer__contained_by=NumericRange(1, 98)),
- [objs[0]]
+ [objs[0]],
)
def test_biginteger_range(self):
@@ -445,19 +458,23 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
RangeLookupsModel.objects.create(big_integer=-1),
]
self.assertSequenceEqual(
- RangeLookupsModel.objects.filter(big_integer__contained_by=NumericRange(1, 98)),
- [objs[0]]
+ RangeLookupsModel.objects.filter(
+ big_integer__contained_by=NumericRange(1, 98)
+ ),
+ [objs[0]],
)
def test_decimal_field_contained_by(self):
objs = [
- RangeLookupsModel.objects.create(decimal_field=Decimal('1.33')),
- RangeLookupsModel.objects.create(decimal_field=Decimal('2.88')),
- RangeLookupsModel.objects.create(decimal_field=Decimal('99.17')),
+ RangeLookupsModel.objects.create(decimal_field=Decimal("1.33")),
+ RangeLookupsModel.objects.create(decimal_field=Decimal("2.88")),
+ RangeLookupsModel.objects.create(decimal_field=Decimal("99.17")),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(
- decimal_field__contained_by=NumericRange(Decimal('1.89'), Decimal('7.91')),
+ decimal_field__contained_by=NumericRange(
+ Decimal("1.89"), Decimal("7.91")
+ ),
),
[objs[1]],
)
@@ -470,13 +487,13 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(float__contained_by=NumericRange(1, 98)),
- [objs[0]]
+ [objs[0]],
)
def test_small_auto_field_contained_by(self):
- objs = SmallAutoFieldModel.objects.bulk_create([
- SmallAutoFieldModel() for i in range(1, 5)
- ])
+ objs = SmallAutoFieldModel.objects.bulk_create(
+ [SmallAutoFieldModel() for i in range(1, 5)]
+ )
self.assertSequenceEqual(
SmallAutoFieldModel.objects.filter(
id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
@@ -485,9 +502,9 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
)
def test_auto_field_contained_by(self):
- objs = RangeLookupsModel.objects.bulk_create([
- RangeLookupsModel() for i in range(1, 5)
- ])
+ objs = RangeLookupsModel.objects.bulk_create(
+ [RangeLookupsModel() for i in range(1, 5)]
+ )
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(
id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
@@ -496,9 +513,9 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
)
def test_big_auto_field_contained_by(self):
- objs = BigAutoFieldModel.objects.bulk_create([
- BigAutoFieldModel() for i in range(1, 5)
- ])
+ objs = BigAutoFieldModel.objects.bulk_create(
+ [BigAutoFieldModel() for i in range(1, 5)]
+ )
self.assertSequenceEqual(
BigAutoFieldModel.objects.filter(
id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
@@ -513,8 +530,8 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
RangeLookupsModel.objects.create(float=99, parent=parent),
]
self.assertSequenceEqual(
- RangeLookupsModel.objects.filter(float__contained_by=F('parent__decimals')),
- [objs[0]]
+ RangeLookupsModel.objects.filter(float__contained_by=F("parent__decimals")),
+ [objs[0]],
)
def test_exclude(self):
@@ -525,7 +542,7 @@ class TestQueryingWithRanges(PostgreSQLTestCase):
]
self.assertSequenceEqual(
RangeLookupsModel.objects.exclude(float__contained_by=NumericRange(0, 100)),
- [objs[2]]
+ [objs[2]],
)
@@ -550,44 +567,49 @@ class TestSerialization(PostgreSQLSimpleTestCase):
def test_dumping(self):
instance = RangesModel(
- ints=NumericRange(0, 10), decimals=NumericRange(empty=True),
+ ints=NumericRange(0, 10),
+ decimals=NumericRange(empty=True),
timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt),
timestamps_closed_bounds=DateTimeTZRange(
- self.lower_dt, self.upper_dt, bounds='()',
+ self.lower_dt,
+ self.upper_dt,
+ bounds="()",
),
dates=DateRange(self.lower_date, self.upper_date),
)
- data = serializers.serialize('json', [instance])
+ data = serializers.serialize("json", [instance])
dumped = json.loads(data)
- for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
- dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field])
+ for field in ("ints", "dates", "timestamps", "timestamps_closed_bounds"):
+ dumped[0]["fields"][field] = json.loads(dumped[0]["fields"][field])
check = json.loads(self.test_data)
- for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
- check[0]['fields'][field] = json.loads(check[0]['fields'][field])
+ for field in ("ints", "dates", "timestamps", "timestamps_closed_bounds"):
+ check[0]["fields"][field] = json.loads(check[0]["fields"][field])
self.assertEqual(dumped, check)
def test_loading(self):
- instance = list(serializers.deserialize('json', self.test_data))[0].object
+ instance = list(serializers.deserialize("json", self.test_data))[0].object
self.assertEqual(instance.ints, NumericRange(0, 10))
self.assertEqual(instance.decimals, NumericRange(empty=True))
self.assertIsNone(instance.bigints)
self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date))
- self.assertEqual(instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt))
+ self.assertEqual(
+ instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt)
+ )
self.assertEqual(
instance.timestamps_closed_bounds,
- DateTimeTZRange(self.lower_dt, self.upper_dt, bounds='()'),
+ DateTimeTZRange(self.lower_dt, self.upper_dt, bounds="()"),
)
def test_serialize_range_with_null(self):
instance = RangesModel(ints=NumericRange(None, 10))
- data = serializers.serialize('json', [instance])
- new_instance = list(serializers.deserialize('json', data))[0].object
+ data = serializers.serialize("json", [instance])
+ new_instance = list(serializers.deserialize("json", data))[0].object
self.assertEqual(new_instance.ints, NumericRange(None, 10))
instance = RangesModel(ints=NumericRange(10, None))
- data = serializers.serialize('json', [instance])
- new_instance = list(serializers.deserialize('json', data))[0].object
+ data = serializers.serialize("json", [instance])
+ new_instance = list(serializers.deserialize("json", data))[0].object
self.assertEqual(new_instance.ints, NumericRange(10, None))
@@ -596,60 +618,59 @@ class TestChecks(PostgreSQLSimpleTestCase):
class Model(PostgreSQLModel):
field = pg_fields.IntegerRangeField(
choices=[
- ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]],
- ((51, 100), '51-100'),
+ ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]],
+ ((51, 100), "51-100"),
],
)
- self.assertEqual(Model._meta.get_field('field').check(), [])
+ self.assertEqual(Model._meta.get_field("field").check(), [])
-class TestValidators(PostgreSQLSimpleTestCase):
+class TestValidators(PostgreSQLSimpleTestCase):
def test_max(self):
validator = RangeMaxValueValidator(5)
validator(NumericRange(0, 5))
- msg = 'Ensure that this range is completely less than or equal to 5.'
+ msg = "Ensure that this range is completely less than or equal to 5."
with self.assertRaises(exceptions.ValidationError) as cm:
validator(NumericRange(0, 10))
self.assertEqual(cm.exception.messages[0], msg)
- self.assertEqual(cm.exception.code, 'max_value')
+ self.assertEqual(cm.exception.code, "max_value")
with self.assertRaisesMessage(exceptions.ValidationError, msg):
validator(NumericRange(0, None)) # an unbound range
def test_min(self):
validator = RangeMinValueValidator(5)
validator(NumericRange(10, 15))
- msg = 'Ensure that this range is completely greater than or equal to 5.'
+ msg = "Ensure that this range is completely greater than or equal to 5."
with self.assertRaises(exceptions.ValidationError) as cm:
validator(NumericRange(0, 10))
self.assertEqual(cm.exception.messages[0], msg)
- self.assertEqual(cm.exception.code, 'min_value')
+ self.assertEqual(cm.exception.code, "min_value")
with self.assertRaisesMessage(exceptions.ValidationError, msg):
validator(NumericRange(None, 10)) # an unbound range
class TestFormField(PostgreSQLSimpleTestCase):
-
def test_valid_integer(self):
field = pg_forms.IntegerRangeField()
- value = field.clean(['1', '2'])
+ value = field.clean(["1", "2"])
self.assertEqual(value, NumericRange(1, 2))
def test_valid_decimal(self):
field = pg_forms.DecimalRangeField()
- value = field.clean(['1.12345', '2.001'])
- self.assertEqual(value, NumericRange(Decimal('1.12345'), Decimal('2.001')))
+ value = field.clean(["1.12345", "2.001"])
+ self.assertEqual(value, NumericRange(Decimal("1.12345"), Decimal("2.001")))
def test_valid_timestamps(self):
field = pg_forms.DateTimeRangeField()
- value = field.clean(['01/01/2014 00:00:00', '02/02/2014 12:12:12'])
+ value = field.clean(["01/01/2014 00:00:00", "02/02/2014 12:12:12"])
lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
self.assertEqual(value, DateTimeTZRange(lower, upper))
def test_valid_dates(self):
field = pg_forms.DateRangeField()
- value = field.clean(['01/01/2014', '02/02/2014'])
+ value = field.clean(["01/01/2014", "02/02/2014"])
lower = datetime.date(2014, 1, 1)
upper = datetime.date(2014, 2, 2)
self.assertEqual(value, DateRange(lower, upper))
@@ -662,7 +683,9 @@ class TestFormField(PostgreSQLSimpleTestCase):
field = SplitDateTimeRangeField()
form = SplitForm()
- self.assertHTMLEqual(str(form), '''
+ self.assertHTMLEqual(
+ str(form),
+ """
<tr>
<th>
<label>Field:</label>
@@ -674,21 +697,24 @@ class TestFormField(PostgreSQLSimpleTestCase):
<input id="id_field_1_1" name="field_1_1" type="text">
</td>
</tr>
- ''')
- form = SplitForm({
- 'field_0_0': '01/01/2014',
- 'field_0_1': '00:00:00',
- 'field_1_0': '02/02/2014',
- 'field_1_1': '12:12:12',
- })
+ """,
+ )
+ form = SplitForm(
+ {
+ "field_0_0": "01/01/2014",
+ "field_0_1": "00:00:00",
+ "field_1_0": "02/02/2014",
+ "field_1_1": "12:12:12",
+ }
+ )
self.assertTrue(form.is_valid())
lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
- self.assertEqual(form.cleaned_data['field'], DateTimeTZRange(lower, upper))
+ self.assertEqual(form.cleaned_data["field"], DateTimeTZRange(lower, upper))
def test_none(self):
field = pg_forms.IntegerRangeField(required=False)
- value = field.clean(['', ''])
+ value = field.clean(["", ""])
self.assertIsNone(value)
def test_datetime_form_as_table(self):
@@ -707,12 +733,14 @@ class TestFormField(PostgreSQLSimpleTestCase):
<input type="hidden" name="initial-datetime_field_0" id="initial-id_datetime_field_0">
<input type="hidden" name="initial-datetime_field_1" id="initial-id_datetime_field_1">
</td></tr>
- """
+ """,
+ )
+ form = DateTimeRangeForm(
+ {
+ "datetime_field_0": "2010-01-01 11:13:00",
+ "datetime_field_1": "2020-12-12 16:59:00",
+ }
)
- form = DateTimeRangeForm({
- 'datetime_field_0': '2010-01-01 11:13:00',
- 'datetime_field_1': '2020-12-12 16:59:00',
- })
self.assertHTMLEqual(
form.as_table(),
"""
@@ -727,7 +755,7 @@ class TestFormField(PostgreSQLSimpleTestCase):
id="initial-id_datetime_field_0">
<input type="hidden" name="initial-datetime_field_1" value="2020-12-12 16:59:00"
id="initial-id_datetime_field_1"></td></tr>
- """
+ """,
)
def test_datetime_form_initial_data(self):
@@ -735,16 +763,18 @@ class TestFormField(PostgreSQLSimpleTestCase):
datetime_field = pg_forms.DateTimeRangeField(show_hidden_initial=True)
data = QueryDict(mutable=True)
- data.update({
- 'datetime_field_0': '2010-01-01 11:13:00',
- 'datetime_field_1': '',
- 'initial-datetime_field_0': '2010-01-01 10:12:00',
- 'initial-datetime_field_1': '',
- })
+ data.update(
+ {
+ "datetime_field_0": "2010-01-01 11:13:00",
+ "datetime_field_1": "",
+ "initial-datetime_field_0": "2010-01-01 10:12:00",
+ "initial-datetime_field_1": "",
+ }
+ )
form = DateTimeRangeForm(data=data)
self.assertTrue(form.has_changed())
- data['initial-datetime_field_0'] = '2010-01-01 11:13:00'
+ data["initial-datetime_field_0"] = "2010-01-01 11:13:00"
form = DateTimeRangeForm(data=data)
self.assertFalse(form.has_changed())
@@ -752,7 +782,9 @@ class TestFormField(PostgreSQLSimpleTestCase):
class RangeForm(forms.Form):
ints = pg_forms.IntegerRangeField()
- self.assertHTMLEqual(str(RangeForm()), '''
+ self.assertHTMLEqual(
+ str(RangeForm()),
+ """
<tr>
<th><label>Ints:</label></th>
<td>
@@ -760,195 +792,222 @@ class TestFormField(PostgreSQLSimpleTestCase):
<input id="id_ints_1" name="ints_1" type="number">
</td>
</tr>
- ''')
+ """,
+ )
def test_integer_lower_bound_higher(self):
field = pg_forms.IntegerRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['10', '2'])
- self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.')
- self.assertEqual(cm.exception.code, 'bound_ordering')
+ field.clean(["10", "2"])
+ self.assertEqual(
+ cm.exception.messages[0],
+ "The start of the range must not exceed the end of the range.",
+ )
+ self.assertEqual(cm.exception.code, "bound_ordering")
def test_integer_open(self):
field = pg_forms.IntegerRangeField()
- value = field.clean(['', '0'])
+ value = field.clean(["", "0"])
self.assertEqual(value, NumericRange(None, 0))
def test_integer_incorrect_data_type(self):
field = pg_forms.IntegerRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('1')
- self.assertEqual(cm.exception.messages[0], 'Enter two whole numbers.')
- self.assertEqual(cm.exception.code, 'invalid')
+ field.clean("1")
+ self.assertEqual(cm.exception.messages[0], "Enter two whole numbers.")
+ self.assertEqual(cm.exception.code, "invalid")
def test_integer_invalid_lower(self):
field = pg_forms.IntegerRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['a', '2'])
- self.assertEqual(cm.exception.messages[0], 'Enter a whole number.')
+ field.clean(["a", "2"])
+ self.assertEqual(cm.exception.messages[0], "Enter a whole number.")
def test_integer_invalid_upper(self):
field = pg_forms.IntegerRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['1', 'b'])
- self.assertEqual(cm.exception.messages[0], 'Enter a whole number.')
+ field.clean(["1", "b"])
+ self.assertEqual(cm.exception.messages[0], "Enter a whole number.")
def test_integer_required(self):
field = pg_forms.IntegerRangeField(required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['', ''])
- self.assertEqual(cm.exception.messages[0], 'This field is required.')
- value = field.clean([1, ''])
+ field.clean(["", ""])
+ self.assertEqual(cm.exception.messages[0], "This field is required.")
+ value = field.clean([1, ""])
self.assertEqual(value, NumericRange(1, None))
def test_decimal_lower_bound_higher(self):
field = pg_forms.DecimalRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['1.8', '1.6'])
- self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.')
- self.assertEqual(cm.exception.code, 'bound_ordering')
+ field.clean(["1.8", "1.6"])
+ self.assertEqual(
+ cm.exception.messages[0],
+ "The start of the range must not exceed the end of the range.",
+ )
+ self.assertEqual(cm.exception.code, "bound_ordering")
def test_decimal_open(self):
field = pg_forms.DecimalRangeField()
- value = field.clean(['', '3.1415926'])
- self.assertEqual(value, NumericRange(None, Decimal('3.1415926')))
+ value = field.clean(["", "3.1415926"])
+ self.assertEqual(value, NumericRange(None, Decimal("3.1415926")))
def test_decimal_incorrect_data_type(self):
field = pg_forms.DecimalRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('1.6')
- self.assertEqual(cm.exception.messages[0], 'Enter two numbers.')
- self.assertEqual(cm.exception.code, 'invalid')
+ field.clean("1.6")
+ self.assertEqual(cm.exception.messages[0], "Enter two numbers.")
+ self.assertEqual(cm.exception.code, "invalid")
def test_decimal_invalid_lower(self):
field = pg_forms.DecimalRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['a', '3.1415926'])
- self.assertEqual(cm.exception.messages[0], 'Enter a number.')
+ field.clean(["a", "3.1415926"])
+ self.assertEqual(cm.exception.messages[0], "Enter a number.")
def test_decimal_invalid_upper(self):
field = pg_forms.DecimalRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['1.61803399', 'b'])
- self.assertEqual(cm.exception.messages[0], 'Enter a number.')
+ field.clean(["1.61803399", "b"])
+ self.assertEqual(cm.exception.messages[0], "Enter a number.")
def test_decimal_required(self):
field = pg_forms.DecimalRangeField(required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['', ''])
- self.assertEqual(cm.exception.messages[0], 'This field is required.')
- value = field.clean(['1.61803399', ''])
- self.assertEqual(value, NumericRange(Decimal('1.61803399'), None))
+ field.clean(["", ""])
+ self.assertEqual(cm.exception.messages[0], "This field is required.")
+ value = field.clean(["1.61803399", ""])
+ self.assertEqual(value, NumericRange(Decimal("1.61803399"), None))
def test_date_lower_bound_higher(self):
field = pg_forms.DateRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['2013-04-09', '1976-04-16'])
- self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.')
- self.assertEqual(cm.exception.code, 'bound_ordering')
+ field.clean(["2013-04-09", "1976-04-16"])
+ self.assertEqual(
+ cm.exception.messages[0],
+ "The start of the range must not exceed the end of the range.",
+ )
+ self.assertEqual(cm.exception.code, "bound_ordering")
def test_date_open(self):
field = pg_forms.DateRangeField()
- value = field.clean(['', '2013-04-09'])
+ value = field.clean(["", "2013-04-09"])
self.assertEqual(value, DateRange(None, datetime.date(2013, 4, 9)))
def test_date_incorrect_data_type(self):
field = pg_forms.DateRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('1')
- self.assertEqual(cm.exception.messages[0], 'Enter two valid dates.')
- self.assertEqual(cm.exception.code, 'invalid')
+ field.clean("1")
+ self.assertEqual(cm.exception.messages[0], "Enter two valid dates.")
+ self.assertEqual(cm.exception.code, "invalid")
def test_date_invalid_lower(self):
field = pg_forms.DateRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['a', '2013-04-09'])
- self.assertEqual(cm.exception.messages[0], 'Enter a valid date.')
+ field.clean(["a", "2013-04-09"])
+ self.assertEqual(cm.exception.messages[0], "Enter a valid date.")
def test_date_invalid_upper(self):
field = pg_forms.DateRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['2013-04-09', 'b'])
- self.assertEqual(cm.exception.messages[0], 'Enter a valid date.')
+ field.clean(["2013-04-09", "b"])
+ self.assertEqual(cm.exception.messages[0], "Enter a valid date.")
def test_date_required(self):
field = pg_forms.DateRangeField(required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['', ''])
- self.assertEqual(cm.exception.messages[0], 'This field is required.')
- value = field.clean(['1976-04-16', ''])
+ field.clean(["", ""])
+ self.assertEqual(cm.exception.messages[0], "This field is required.")
+ value = field.clean(["1976-04-16", ""])
self.assertEqual(value, DateRange(datetime.date(1976, 4, 16), None))
def test_date_has_changed_first(self):
- self.assertTrue(pg_forms.DateRangeField().has_changed(
- ['2010-01-01', '2020-12-12'],
- ['2010-01-31', '2020-12-12'],
- ))
+ self.assertTrue(
+ pg_forms.DateRangeField().has_changed(
+ ["2010-01-01", "2020-12-12"],
+ ["2010-01-31", "2020-12-12"],
+ )
+ )
def test_date_has_changed_last(self):
- self.assertTrue(pg_forms.DateRangeField().has_changed(
- ['2010-01-01', '2020-12-12'],
- ['2010-01-01', '2020-12-31'],
- ))
+ self.assertTrue(
+ pg_forms.DateRangeField().has_changed(
+ ["2010-01-01", "2020-12-12"],
+ ["2010-01-01", "2020-12-31"],
+ )
+ )
def test_datetime_lower_bound_higher(self):
field = pg_forms.DateTimeRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['2006-10-25 14:59', '2006-10-25 14:58'])
- self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.')
- self.assertEqual(cm.exception.code, 'bound_ordering')
+ field.clean(["2006-10-25 14:59", "2006-10-25 14:58"])
+ self.assertEqual(
+ cm.exception.messages[0],
+ "The start of the range must not exceed the end of the range.",
+ )
+ self.assertEqual(cm.exception.code, "bound_ordering")
def test_datetime_open(self):
field = pg_forms.DateTimeRangeField()
- value = field.clean(['', '2013-04-09 11:45'])
- self.assertEqual(value, DateTimeTZRange(None, datetime.datetime(2013, 4, 9, 11, 45)))
+ value = field.clean(["", "2013-04-09 11:45"])
+ self.assertEqual(
+ value, DateTimeTZRange(None, datetime.datetime(2013, 4, 9, 11, 45))
+ )
def test_datetime_incorrect_data_type(self):
field = pg_forms.DateTimeRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean('2013-04-09 11:45')
- self.assertEqual(cm.exception.messages[0], 'Enter two valid date/times.')
- self.assertEqual(cm.exception.code, 'invalid')
+ field.clean("2013-04-09 11:45")
+ self.assertEqual(cm.exception.messages[0], "Enter two valid date/times.")
+ self.assertEqual(cm.exception.code, "invalid")
def test_datetime_invalid_lower(self):
field = pg_forms.DateTimeRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['45', '2013-04-09 11:45'])
- self.assertEqual(cm.exception.messages[0], 'Enter a valid date/time.')
+ field.clean(["45", "2013-04-09 11:45"])
+ self.assertEqual(cm.exception.messages[0], "Enter a valid date/time.")
def test_datetime_invalid_upper(self):
field = pg_forms.DateTimeRangeField()
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['2013-04-09 11:45', 'sweet pickles'])
- self.assertEqual(cm.exception.messages[0], 'Enter a valid date/time.')
+ field.clean(["2013-04-09 11:45", "sweet pickles"])
+ self.assertEqual(cm.exception.messages[0], "Enter a valid date/time.")
def test_datetime_required(self):
field = pg_forms.DateTimeRangeField(required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
- field.clean(['', ''])
- self.assertEqual(cm.exception.messages[0], 'This field is required.')
- value = field.clean(['2013-04-09 11:45', ''])
- self.assertEqual(value, DateTimeTZRange(datetime.datetime(2013, 4, 9, 11, 45), None))
+ field.clean(["", ""])
+ self.assertEqual(cm.exception.messages[0], "This field is required.")
+ value = field.clean(["2013-04-09 11:45", ""])
+ self.assertEqual(
+ value, DateTimeTZRange(datetime.datetime(2013, 4, 9, 11, 45), None)
+ )
- @override_settings(USE_TZ=True, TIME_ZONE='Africa/Johannesburg')
+ @override_settings(USE_TZ=True, TIME_ZONE="Africa/Johannesburg")
def test_datetime_prepare_value(self):
field = pg_forms.DateTimeRangeField()
value = field.prepare_value(
- DateTimeTZRange(datetime.datetime(2015, 5, 22, 16, 6, 33, tzinfo=timezone.utc), None)
+ DateTimeTZRange(
+ datetime.datetime(2015, 5, 22, 16, 6, 33, tzinfo=timezone.utc), None
+ )
)
self.assertEqual(value, [datetime.datetime(2015, 5, 22, 18, 6, 33), None])
def test_datetime_has_changed_first(self):
- self.assertTrue(pg_forms.DateTimeRangeField().has_changed(
- ['2010-01-01 00:00', '2020-12-12 00:00'],
- ['2010-01-31 23:00', '2020-12-12 00:00'],
- ))
+ self.assertTrue(
+ pg_forms.DateTimeRangeField().has_changed(
+ ["2010-01-01 00:00", "2020-12-12 00:00"],
+ ["2010-01-31 23:00", "2020-12-12 00:00"],
+ )
+ )
def test_datetime_has_changed_last(self):
- self.assertTrue(pg_forms.DateTimeRangeField().has_changed(
- ['2010-01-01 00:00', '2020-12-12 00:00'],
- ['2010-01-01 00:00', '2020-12-31 23:00'],
- ))
+ self.assertTrue(
+ pg_forms.DateTimeRangeField().has_changed(
+ ["2010-01-01 00:00", "2020-12-12 00:00"],
+ ["2010-01-01 00:00", "2020-12-31 23:00"],
+ )
+ )
def test_model_field_formfield_integer(self):
model_field = pg_fields.IntegerRangeField()
@@ -963,10 +1022,10 @@ class TestFormField(PostgreSQLSimpleTestCase):
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_float(self):
- model_field = pg_fields.DecimalRangeField(default_bounds='()')
+ model_field = pg_fields.DecimalRangeField(default_bounds="()")
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DecimalRangeField)
- self.assertEqual(form_field.range_kwargs, {'bounds': '()'})
+ self.assertEqual(form_field.range_kwargs, {"bounds": "()"})
def test_model_field_formfield_date(self):
model_field = pg_fields.DateRangeField()
@@ -980,33 +1039,33 @@ class TestFormField(PostgreSQLSimpleTestCase):
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
self.assertEqual(
form_field.range_kwargs,
- {'bounds': pg_fields.ranges.CANONICAL_RANGE_BOUNDS},
+ {"bounds": pg_fields.ranges.CANONICAL_RANGE_BOUNDS},
)
def test_model_field_formfield_datetime_default_bounds(self):
- model_field = pg_fields.DateTimeRangeField(default_bounds='[]')
+ model_field = pg_fields.DateTimeRangeField(default_bounds="[]")
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
- self.assertEqual(form_field.range_kwargs, {'bounds': '[]'})
+ self.assertEqual(form_field.range_kwargs, {"bounds": "[]"})
def test_model_field_with_default_bounds(self):
- field = pg_forms.DateTimeRangeField(default_bounds='[]')
- value = field.clean(['2014-01-01 00:00:00', '2014-02-03 12:13:14'])
+ field = pg_forms.DateTimeRangeField(default_bounds="[]")
+ value = field.clean(["2014-01-01 00:00:00", "2014-02-03 12:13:14"])
lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
upper = datetime.datetime(2014, 2, 3, 12, 13, 14)
- self.assertEqual(value, DateTimeTZRange(lower, upper, '[]'))
+ self.assertEqual(value, DateTimeTZRange(lower, upper, "[]"))
def test_has_changed(self):
for field, value in (
- (pg_forms.DateRangeField(), ['2010-01-01', '2020-12-12']),
- (pg_forms.DateTimeRangeField(), ['2010-01-01 11:13', '2020-12-12 14:52']),
+ (pg_forms.DateRangeField(), ["2010-01-01", "2020-12-12"]),
+ (pg_forms.DateTimeRangeField(), ["2010-01-01 11:13", "2020-12-12 14:52"]),
(pg_forms.IntegerRangeField(), [1, 2]),
- (pg_forms.DecimalRangeField(), ['1.12345', '2.001']),
+ (pg_forms.DecimalRangeField(), ["1.12345", "2.001"]),
):
with self.subTest(field=field.__class__.__name__):
self.assertTrue(field.has_changed(None, value))
- self.assertTrue(field.has_changed([value[0], ''], value))
- self.assertTrue(field.has_changed(['', value[1]], value))
+ self.assertTrue(field.has_changed([value[0], ""], value))
+ self.assertTrue(field.has_changed(["", value[1]], value))
self.assertFalse(field.has_changed(value, value))
@@ -1014,19 +1073,18 @@ class TestWidget(PostgreSQLSimpleTestCase):
def test_range_widget(self):
f = pg_forms.ranges.DateTimeRangeField()
self.assertHTMLEqual(
- f.widget.render('datetimerange', ''),
- '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">'
+ f.widget.render("datetimerange", ""),
+ '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">',
)
self.assertHTMLEqual(
- f.widget.render('datetimerange', None),
- '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">'
+ f.widget.render("datetimerange", None),
+ '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">',
)
dt_range = DateTimeTZRange(
- datetime.datetime(2006, 1, 10, 7, 30),
- datetime.datetime(2006, 2, 12, 9, 50)
+ datetime.datetime(2006, 1, 10, 7, 30), datetime.datetime(2006, 2, 12, 9, 50)
)
self.assertHTMLEqual(
- f.widget.render('datetimerange', dt_range),
+ f.widget.render("datetimerange", dt_range),
'<input type="text" name="datetimerange_0" value="2006-01-10 07:30:00">'
- '<input type="text" name="datetimerange_1" value="2006-02-12 09:50:00">'
+ '<input type="text" name="datetimerange_1" value="2006-02-12 09:50:00">',
)
diff --git a/tests/postgres_tests/test_search.py b/tests/postgres_tests/test_search.py
index bf6f53ddb2..670b0aaaa9 100644
--- a/tests/postgres_tests/test_search.py
+++ b/tests/postgres_tests/test_search.py
@@ -14,108 +14,125 @@ from .models import Character, Line, LineSavedSearch, Scene
try:
from django.contrib.postgres.search import (
- SearchConfig, SearchHeadline, SearchQuery, SearchRank, SearchVector,
+ SearchConfig,
+ SearchHeadline,
+ SearchQuery,
+ SearchRank,
+ SearchVector,
)
except ImportError:
pass
class GrailTestData:
-
@classmethod
def setUpTestData(cls):
- cls.robin = Scene.objects.create(scene='Scene 10', setting='The dark forest of Ewing')
- cls.minstrel = Character.objects.create(name='Minstrel')
+ cls.robin = Scene.objects.create(
+ scene="Scene 10", setting="The dark forest of Ewing"
+ )
+ cls.minstrel = Character.objects.create(name="Minstrel")
verses = [
(
- 'Bravely bold Sir Robin, rode forth from Camelot. '
- 'He was not afraid to die, o Brave Sir Robin. '
- 'He was not at all afraid to be killed in nasty ways. '
- 'Brave, brave, brave, brave Sir Robin'
+ "Bravely bold Sir Robin, rode forth from Camelot. "
+ "He was not afraid to die, o Brave Sir Robin. "
+ "He was not at all afraid to be killed in nasty ways. "
+ "Brave, brave, brave, brave Sir Robin"
),
(
- 'He was not in the least bit scared to be mashed into a pulp, '
- 'Or to have his eyes gouged out, and his elbows broken. '
- 'To have his kneecaps split, and his body burned away, '
- 'And his limbs all hacked and mangled, brave Sir Robin!'
+ "He was not in the least bit scared to be mashed into a pulp, "
+ "Or to have his eyes gouged out, and his elbows broken. "
+ "To have his kneecaps split, and his body burned away, "
+ "And his limbs all hacked and mangled, brave Sir Robin!"
),
(
- 'His head smashed in and his heart cut out, '
- 'And his liver removed and his bowels unplugged, '
- 'And his nostrils ripped and his bottom burned off,'
- 'And his --'
+ "His head smashed in and his heart cut out, "
+ "And his liver removed and his bowels unplugged, "
+ "And his nostrils ripped and his bottom burned off,"
+ "And his --"
),
]
- cls.verses = [Line.objects.create(
- scene=cls.robin,
- character=cls.minstrel,
- dialogue=verse,
- ) for verse in verses]
+ cls.verses = [
+ Line.objects.create(
+ scene=cls.robin,
+ character=cls.minstrel,
+ dialogue=verse,
+ )
+ for verse in verses
+ ]
cls.verse0, cls.verse1, cls.verse2 = cls.verses
- cls.witch_scene = Scene.objects.create(scene='Scene 5', setting="Sir Bedemir's Castle")
- bedemir = Character.objects.create(name='Bedemir')
- crowd = Character.objects.create(name='Crowd')
- witch = Character.objects.create(name='Witch')
- duck = Character.objects.create(name='Duck')
+ cls.witch_scene = Scene.objects.create(
+ scene="Scene 5", setting="Sir Bedemir's Castle"
+ )
+ bedemir = Character.objects.create(name="Bedemir")
+ crowd = Character.objects.create(name="Crowd")
+ witch = Character.objects.create(name="Witch")
+ duck = Character.objects.create(name="Duck")
cls.bedemir0 = Line.objects.create(
scene=cls.witch_scene,
character=bedemir,
- dialogue='We shall use my larger scales!',
- dialogue_config='english',
+ dialogue="We shall use my larger scales!",
+ dialogue_config="english",
)
cls.bedemir1 = Line.objects.create(
scene=cls.witch_scene,
character=bedemir,
- dialogue='Right, remove the supports!',
- dialogue_config='english',
+ dialogue="Right, remove the supports!",
+ dialogue_config="english",
+ )
+ cls.duck = Line.objects.create(
+ scene=cls.witch_scene, character=duck, dialogue=None
+ )
+ cls.crowd = Line.objects.create(
+ scene=cls.witch_scene, character=crowd, dialogue="A witch! A witch!"
+ )
+ cls.witch = Line.objects.create(
+ scene=cls.witch_scene, character=witch, dialogue="It's a fair cop."
)
- cls.duck = Line.objects.create(scene=cls.witch_scene, character=duck, dialogue=None)
- cls.crowd = Line.objects.create(scene=cls.witch_scene, character=crowd, dialogue='A witch! A witch!')
- cls.witch = Line.objects.create(scene=cls.witch_scene, character=witch, dialogue="It's a fair cop.")
- trojan_rabbit = Scene.objects.create(scene='Scene 8', setting="The castle of Our Master Ruiz' de lu la Ramper")
- guards = Character.objects.create(name='French Guards')
+ trojan_rabbit = Scene.objects.create(
+ scene="Scene 8", setting="The castle of Our Master Ruiz' de lu la Ramper"
+ )
+ guards = Character.objects.create(name="French Guards")
cls.french = Line.objects.create(
scene=trojan_rabbit,
character=guards,
- dialogue='Oh. Un beau cadeau. Oui oui.',
- dialogue_config='french',
+ dialogue="Oh. Un beau cadeau. Oui oui.",
+ dialogue_config="french",
)
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SimpleSearchTest(GrailTestData, PostgreSQLTestCase):
-
def test_simple(self):
- searched = Line.objects.filter(dialogue__search='elbows')
+ searched = Line.objects.filter(dialogue__search="elbows")
self.assertSequenceEqual(searched, [self.verse1])
def test_non_exact_match(self):
- searched = Line.objects.filter(dialogue__search='hearts')
+ searched = Line.objects.filter(dialogue__search="hearts")
self.assertSequenceEqual(searched, [self.verse2])
def test_search_two_terms(self):
- searched = Line.objects.filter(dialogue__search='heart bowel')
+ searched = Line.objects.filter(dialogue__search="heart bowel")
self.assertSequenceEqual(searched, [self.verse2])
def test_search_two_terms_with_partial_match(self):
- searched = Line.objects.filter(dialogue__search='Robin killed')
+ searched = Line.objects.filter(dialogue__search="Robin killed")
self.assertSequenceEqual(searched, [self.verse0])
def test_search_query_config(self):
searched = Line.objects.filter(
- dialogue__search=SearchQuery('nostrils', config='simple'),
+ dialogue__search=SearchQuery("nostrils", config="simple"),
)
self.assertSequenceEqual(searched, [self.verse2])
def test_search_with_F_expression(self):
# Non-matching query.
- LineSavedSearch.objects.create(line=self.verse1, query='hearts')
+ LineSavedSearch.objects.create(line=self.verse1, query="hearts")
# Matching query.
- match = LineSavedSearch.objects.create(line=self.verse1, query='elbows')
- for query_expression in [F('query'), SearchQuery(F('query'))]:
+ match = LineSavedSearch.objects.create(line=self.verse1, query="elbows")
+ for query_expression in [F("query"), SearchQuery(F("query"))]:
with self.subTest(query_expression):
searched = LineSavedSearch.objects.filter(
line__dialogue__search=query_expression,
@@ -123,254 +140,296 @@ class SimpleSearchTest(GrailTestData, PostgreSQLTestCase):
self.assertSequenceEqual(searched, [match])
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SearchVectorFieldTest(GrailTestData, PostgreSQLTestCase):
def test_existing_vector(self):
- Line.objects.update(dialogue_search_vector=SearchVector('dialogue'))
- searched = Line.objects.filter(dialogue_search_vector=SearchQuery('Robin killed'))
+ Line.objects.update(dialogue_search_vector=SearchVector("dialogue"))
+ searched = Line.objects.filter(
+ dialogue_search_vector=SearchQuery("Robin killed")
+ )
self.assertSequenceEqual(searched, [self.verse0])
def test_existing_vector_config_explicit(self):
- Line.objects.update(dialogue_search_vector=SearchVector('dialogue'))
- searched = Line.objects.filter(dialogue_search_vector=SearchQuery('cadeaux', config='french'))
+ Line.objects.update(dialogue_search_vector=SearchVector("dialogue"))
+ searched = Line.objects.filter(
+ dialogue_search_vector=SearchQuery("cadeaux", config="french")
+ )
self.assertSequenceEqual(searched, [self.french])
def test_single_coalesce_expression(self):
- searched = Line.objects.annotate(search=SearchVector('dialogue')).filter(search='cadeaux')
- self.assertNotIn('COALESCE(COALESCE', str(searched.query))
+ searched = Line.objects.annotate(search=SearchVector("dialogue")).filter(
+ search="cadeaux"
+ )
+ self.assertNotIn("COALESCE(COALESCE", str(searched.query))
class SearchConfigTests(PostgreSQLSimpleTestCase):
def test_from_parameter(self):
self.assertIsNone(SearchConfig.from_parameter(None))
- self.assertEqual(SearchConfig.from_parameter('foo'), SearchConfig('foo'))
- self.assertEqual(SearchConfig.from_parameter(SearchConfig('bar')), SearchConfig('bar'))
+ self.assertEqual(SearchConfig.from_parameter("foo"), SearchConfig("foo"))
+ self.assertEqual(
+ SearchConfig.from_parameter(SearchConfig("bar")), SearchConfig("bar")
+ )
class MultipleFieldsTest(GrailTestData, PostgreSQLTestCase):
-
def test_simple_on_dialogue(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search='elbows')
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search="elbows")
self.assertSequenceEqual(searched, [self.verse1])
def test_simple_on_scene(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search='Forest')
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search="Forest")
self.assertCountEqual(searched, self.verses)
def test_non_exact_match(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search='heart')
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search="heart")
self.assertSequenceEqual(searched, [self.verse2])
def test_search_two_terms(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search='heart forest')
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search="heart forest")
self.assertSequenceEqual(searched, [self.verse2])
def test_terms_adjacent(self):
searched = Line.objects.annotate(
- search=SearchVector('character__name', 'dialogue'),
- ).filter(search='minstrel')
+ search=SearchVector("character__name", "dialogue"),
+ ).filter(search="minstrel")
self.assertCountEqual(searched, self.verses)
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search='minstrelbravely')
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search="minstrelbravely")
self.assertSequenceEqual(searched, [])
def test_search_with_null(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search='bedemir')
- self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck])
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search="bedemir")
+ self.assertCountEqual(
+ searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]
+ )
def test_search_with_non_text(self):
searched = Line.objects.annotate(
- search=SearchVector('id'),
+ search=SearchVector("id"),
).filter(search=str(self.crowd.id))
self.assertSequenceEqual(searched, [self.crowd])
def test_phrase_search(self):
- line_qs = Line.objects.annotate(search=SearchVector('dialogue'))
- searched = line_qs.filter(search=SearchQuery('burned body his away', search_type='phrase'))
+ line_qs = Line.objects.annotate(search=SearchVector("dialogue"))
+ searched = line_qs.filter(
+ search=SearchQuery("burned body his away", search_type="phrase")
+ )
self.assertSequenceEqual(searched, [])
- searched = line_qs.filter(search=SearchQuery('his body burned away', search_type='phrase'))
+ searched = line_qs.filter(
+ search=SearchQuery("his body burned away", search_type="phrase")
+ )
self.assertSequenceEqual(searched, [self.verse1])
def test_phrase_search_with_config(self):
line_qs = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue', config='french'),
+ search=SearchVector("scene__setting", "dialogue", config="french"),
)
searched = line_qs.filter(
- search=SearchQuery('cadeau beau un', search_type='phrase', config='french'),
+ search=SearchQuery("cadeau beau un", search_type="phrase", config="french"),
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(
- search=SearchQuery('un beau cadeau', search_type='phrase', config='french'),
+ search=SearchQuery("un beau cadeau", search_type="phrase", config="french"),
)
self.assertSequenceEqual(searched, [self.french])
def test_raw_search(self):
- line_qs = Line.objects.annotate(search=SearchVector('dialogue'))
- searched = line_qs.filter(search=SearchQuery('Robin', search_type='raw'))
+ line_qs = Line.objects.annotate(search=SearchVector("dialogue"))
+ searched = line_qs.filter(search=SearchQuery("Robin", search_type="raw"))
self.assertCountEqual(searched, [self.verse0, self.verse1])
- searched = line_qs.filter(search=SearchQuery("Robin & !'Camelot'", search_type='raw'))
+ searched = line_qs.filter(
+ search=SearchQuery("Robin & !'Camelot'", search_type="raw")
+ )
self.assertSequenceEqual(searched, [self.verse1])
def test_raw_search_with_config(self):
- line_qs = Line.objects.annotate(search=SearchVector('dialogue', config='french'))
+ line_qs = Line.objects.annotate(
+ search=SearchVector("dialogue", config="french")
+ )
searched = line_qs.filter(
- search=SearchQuery("'cadeaux' & 'beaux'", search_type='raw', config='french'),
+ search=SearchQuery(
+ "'cadeaux' & 'beaux'", search_type="raw", config="french"
+ ),
)
self.assertSequenceEqual(searched, [self.french])
- @skipUnlessDBFeature('has_websearch_to_tsquery')
+ @skipUnlessDBFeature("has_websearch_to_tsquery")
def test_web_search(self):
- line_qs = Line.objects.annotate(search=SearchVector('dialogue'))
+ line_qs = Line.objects.annotate(search=SearchVector("dialogue"))
searched = line_qs.filter(
search=SearchQuery(
'"burned body" "split kneecaps"',
- search_type='websearch',
+ search_type="websearch",
),
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(
search=SearchQuery(
'"body burned" "kneecaps split" -"nostrils"',
- search_type='websearch',
+ search_type="websearch",
),
)
self.assertSequenceEqual(searched, [self.verse1])
searched = line_qs.filter(
search=SearchQuery(
'"Sir Robin" ("kneecaps" OR "Camelot")',
- search_type='websearch',
+ search_type="websearch",
),
)
self.assertSequenceEqual(searched, [self.verse0, self.verse1])
- @skipUnlessDBFeature('has_websearch_to_tsquery')
+ @skipUnlessDBFeature("has_websearch_to_tsquery")
def test_web_search_with_config(self):
line_qs = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue', config='french'),
+ search=SearchVector("scene__setting", "dialogue", config="french"),
)
searched = line_qs.filter(
- search=SearchQuery('cadeau -beau', search_type='websearch', config='french'),
+ search=SearchQuery(
+ "cadeau -beau", search_type="websearch", config="french"
+ ),
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(
- search=SearchQuery('beau cadeau', search_type='websearch', config='french'),
+ search=SearchQuery("beau cadeau", search_type="websearch", config="french"),
)
self.assertSequenceEqual(searched, [self.french])
def test_bad_search_type(self):
- with self.assertRaisesMessage(ValueError, "Unknown search_type argument 'foo'."):
- SearchQuery('kneecaps', search_type='foo')
+ with self.assertRaisesMessage(
+ ValueError, "Unknown search_type argument 'foo'."
+ ):
+ SearchQuery("kneecaps", search_type="foo")
def test_config_query_explicit(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue', config='french'),
- ).filter(search=SearchQuery('cadeaux', config='french'))
+ search=SearchVector("scene__setting", "dialogue", config="french"),
+ ).filter(search=SearchQuery("cadeaux", config="french"))
self.assertSequenceEqual(searched, [self.french])
def test_config_query_implicit(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue', config='french'),
- ).filter(search='cadeaux')
+ search=SearchVector("scene__setting", "dialogue", config="french"),
+ ).filter(search="cadeaux")
self.assertSequenceEqual(searched, [self.french])
def test_config_from_field_explicit(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue', config=F('dialogue_config')),
- ).filter(search=SearchQuery('cadeaux', config=F('dialogue_config')))
+ search=SearchVector(
+ "scene__setting", "dialogue", config=F("dialogue_config")
+ ),
+ ).filter(search=SearchQuery("cadeaux", config=F("dialogue_config")))
self.assertSequenceEqual(searched, [self.french])
def test_config_from_field_implicit(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue', config=F('dialogue_config')),
- ).filter(search='cadeaux')
+ search=SearchVector(
+ "scene__setting", "dialogue", config=F("dialogue_config")
+ ),
+ ).filter(search="cadeaux")
self.assertSequenceEqual(searched, [self.french])
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class TestCombinations(GrailTestData, PostgreSQLTestCase):
-
def test_vector_add(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting') + SearchVector('character__name'),
- ).filter(search='bedemir')
- self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck])
+ search=SearchVector("scene__setting") + SearchVector("character__name"),
+ ).filter(search="bedemir")
+ self.assertCountEqual(
+ searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]
+ )
def test_vector_add_multi(self):
searched = Line.objects.annotate(
search=(
- SearchVector('scene__setting') +
- SearchVector('character__name') +
- SearchVector('dialogue')
+ SearchVector("scene__setting")
+ + SearchVector("character__name")
+ + SearchVector("dialogue")
),
- ).filter(search='bedemir')
- self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck])
+ ).filter(search="bedemir")
+ self.assertCountEqual(
+ searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]
+ )
def test_vector_combined_mismatch(self):
msg = (
- 'SearchVector can only be combined with other SearchVector '
- 'instances, got NoneType.'
+ "SearchVector can only be combined with other SearchVector "
+ "instances, got NoneType."
)
with self.assertRaisesMessage(TypeError, msg):
- Line.objects.filter(dialogue__search=None + SearchVector('character__name'))
+ Line.objects.filter(dialogue__search=None + SearchVector("character__name"))
def test_combine_different_vector_configs(self):
searched = Line.objects.annotate(
search=(
- SearchVector('dialogue', config='english') +
- SearchVector('dialogue', config='french')
+ SearchVector("dialogue", config="english")
+ + SearchVector("dialogue", config="french")
),
).filter(
- search=SearchQuery('cadeaux', config='french') | SearchQuery('nostrils')
+ search=SearchQuery("cadeaux", config="french") | SearchQuery("nostrils")
)
self.assertCountEqual(searched, [self.french, self.verse2])
def test_query_and(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search=SearchQuery('bedemir') & SearchQuery('scales'))
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(search=SearchQuery("bedemir") & SearchQuery("scales"))
self.assertSequenceEqual(searched, [self.bedemir0])
def test_query_multiple_and(self):
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search=SearchQuery('bedemir') & SearchQuery('scales') & SearchQuery('nostrils'))
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(
+ search=SearchQuery("bedemir")
+ & SearchQuery("scales")
+ & SearchQuery("nostrils")
+ )
self.assertSequenceEqual(searched, [])
searched = Line.objects.annotate(
- search=SearchVector('scene__setting', 'dialogue'),
- ).filter(search=SearchQuery('shall') & SearchQuery('use') & SearchQuery('larger'))
+ search=SearchVector("scene__setting", "dialogue"),
+ ).filter(
+ search=SearchQuery("shall") & SearchQuery("use") & SearchQuery("larger")
+ )
self.assertSequenceEqual(searched, [self.bedemir0])
def test_query_or(self):
- searched = Line.objects.filter(dialogue__search=SearchQuery('kneecaps') | SearchQuery('nostrils'))
+ searched = Line.objects.filter(
+ dialogue__search=SearchQuery("kneecaps") | SearchQuery("nostrils")
+ )
self.assertCountEqual(searched, [self.verse1, self.verse2])
def test_query_multiple_or(self):
searched = Line.objects.filter(
- dialogue__search=SearchQuery('kneecaps') | SearchQuery('nostrils') | SearchQuery('Sir Robin')
+ dialogue__search=SearchQuery("kneecaps")
+ | SearchQuery("nostrils")
+ | SearchQuery("Sir Robin")
)
self.assertCountEqual(searched, [self.verse1, self.verse2, self.verse0])
def test_query_invert(self):
- searched = Line.objects.filter(character=self.minstrel, dialogue__search=~SearchQuery('kneecaps'))
+ searched = Line.objects.filter(
+ character=self.minstrel, dialogue__search=~SearchQuery("kneecaps")
+ )
self.assertCountEqual(searched, [self.verse0, self.verse2])
def test_combine_different_configs(self):
searched = Line.objects.filter(
dialogue__search=(
- SearchQuery('cadeau', config='french') |
- SearchQuery('nostrils', config='english')
+ SearchQuery("cadeau", config="french")
+ | SearchQuery("nostrils", config="english")
)
)
self.assertCountEqual(searched, [self.french, self.verse2])
@@ -378,8 +437,8 @@ class TestCombinations(GrailTestData, PostgreSQLTestCase):
def test_combined_configs(self):
searched = Line.objects.filter(
dialogue__search=(
- SearchQuery('nostrils', config='simple') &
- SearchQuery('bowels', config='simple')
+ SearchQuery("nostrils", config="simple")
+ & SearchQuery("bowels", config="simple")
),
)
self.assertSequenceEqual(searched, [self.verse2])
@@ -387,63 +446,96 @@ class TestCombinations(GrailTestData, PostgreSQLTestCase):
def test_combine_raw_phrase(self):
searched = Line.objects.filter(
dialogue__search=(
- SearchQuery('burn:*', search_type='raw', config='simple') |
- SearchQuery('rode forth from Camelot', search_type='phrase')
+ SearchQuery("burn:*", search_type="raw", config="simple")
+ | SearchQuery("rode forth from Camelot", search_type="phrase")
)
)
self.assertCountEqual(searched, [self.verse0, self.verse1, self.verse2])
def test_query_combined_mismatch(self):
msg = (
- 'SearchQuery can only be combined with other SearchQuery '
- 'instances, got NoneType.'
+ "SearchQuery can only be combined with other SearchQuery "
+ "instances, got NoneType."
)
with self.assertRaisesMessage(TypeError, msg):
- Line.objects.filter(dialogue__search=None | SearchQuery('kneecaps'))
+ Line.objects.filter(dialogue__search=None | SearchQuery("kneecaps"))
with self.assertRaisesMessage(TypeError, msg):
- Line.objects.filter(dialogue__search=None & SearchQuery('kneecaps'))
+ Line.objects.filter(dialogue__search=None & SearchQuery("kneecaps"))
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
-
def test_ranking(self):
- searched = Line.objects.filter(character=self.minstrel).annotate(
- rank=SearchRank(SearchVector('dialogue'), SearchQuery('brave sir robin')),
- ).order_by('rank')
+ searched = (
+ Line.objects.filter(character=self.minstrel)
+ .annotate(
+ rank=SearchRank(
+ SearchVector("dialogue"), SearchQuery("brave sir robin")
+ ),
+ )
+ .order_by("rank")
+ )
self.assertSequenceEqual(searched, [self.verse2, self.verse1, self.verse0])
def test_rank_passing_untyped_args(self):
- searched = Line.objects.filter(character=self.minstrel).annotate(
- rank=SearchRank('dialogue', 'brave sir robin'),
- ).order_by('rank')
+ searched = (
+ Line.objects.filter(character=self.minstrel)
+ .annotate(
+ rank=SearchRank("dialogue", "brave sir robin"),
+ )
+ .order_by("rank")
+ )
self.assertSequenceEqual(searched, [self.verse2, self.verse1, self.verse0])
def test_weights_in_vector(self):
- vector = SearchVector('dialogue', weight='A') + SearchVector('character__name', weight='D')
- searched = Line.objects.filter(scene=self.witch_scene).annotate(
- rank=SearchRank(vector, SearchQuery('witch')),
- ).order_by('-rank')[:2]
+ vector = SearchVector("dialogue", weight="A") + SearchVector(
+ "character__name", weight="D"
+ )
+ searched = (
+ Line.objects.filter(scene=self.witch_scene)
+ .annotate(
+ rank=SearchRank(vector, SearchQuery("witch")),
+ )
+ .order_by("-rank")[:2]
+ )
self.assertSequenceEqual(searched, [self.crowd, self.witch])
- vector = SearchVector('dialogue', weight='D') + SearchVector('character__name', weight='A')
- searched = Line.objects.filter(scene=self.witch_scene).annotate(
- rank=SearchRank(vector, SearchQuery('witch')),
- ).order_by('-rank')[:2]
+ vector = SearchVector("dialogue", weight="D") + SearchVector(
+ "character__name", weight="A"
+ )
+ searched = (
+ Line.objects.filter(scene=self.witch_scene)
+ .annotate(
+ rank=SearchRank(vector, SearchQuery("witch")),
+ )
+ .order_by("-rank")[:2]
+ )
self.assertSequenceEqual(searched, [self.witch, self.crowd])
def test_ranked_custom_weights(self):
- vector = SearchVector('dialogue', weight='D') + SearchVector('character__name', weight='A')
- searched = Line.objects.filter(scene=self.witch_scene).annotate(
- rank=SearchRank(vector, SearchQuery('witch'), weights=[1, 0, 0, 0.5]),
- ).order_by('-rank')[:2]
+ vector = SearchVector("dialogue", weight="D") + SearchVector(
+ "character__name", weight="A"
+ )
+ searched = (
+ Line.objects.filter(scene=self.witch_scene)
+ .annotate(
+ rank=SearchRank(vector, SearchQuery("witch"), weights=[1, 0, 0, 0.5]),
+ )
+ .order_by("-rank")[:2]
+ )
self.assertSequenceEqual(searched, [self.crowd, self.witch])
def test_ranking_chaining(self):
- searched = Line.objects.filter(character=self.minstrel).annotate(
- rank=SearchRank(SearchVector('dialogue'), SearchQuery('brave sir robin')),
- ).filter(rank__gt=0.3)
+ searched = (
+ Line.objects.filter(character=self.minstrel)
+ .annotate(
+ rank=SearchRank(
+ SearchVector("dialogue"), SearchQuery("brave sir robin")
+ ),
+ )
+ .filter(rank__gt=0.3)
+ )
self.assertSequenceEqual(searched, [self.verse0])
def test_cover_density_ranking(self):
@@ -451,17 +543,21 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
scene=self.robin,
character=self.minstrel,
dialogue=(
- 'Bravely taking to his feet, he beat a very brave retreat. '
- 'A brave retreat brave Sir Robin.'
+ "Bravely taking to his feet, he beat a very brave retreat. "
+ "A brave retreat brave Sir Robin."
+ ),
+ )
+ searched = (
+ Line.objects.filter(character=self.minstrel)
+ .annotate(
+ rank=SearchRank(
+ SearchVector("dialogue"),
+ SearchQuery("brave robin"),
+ cover_density=True,
+ ),
)
+ .order_by("rank", "-pk")
)
- searched = Line.objects.filter(character=self.minstrel).annotate(
- rank=SearchRank(
- SearchVector('dialogue'),
- SearchQuery('brave robin'),
- cover_density=True,
- ),
- ).order_by('rank', '-pk')
self.assertSequenceEqual(
searched,
[self.verse2, not_dense_verse, self.verse1, self.verse0],
@@ -471,16 +567,20 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
short_verse = Line.objects.create(
scene=self.robin,
character=self.minstrel,
- dialogue='A brave retreat brave Sir Robin.',
+ dialogue="A brave retreat brave Sir Robin.",
+ )
+ searched = (
+ Line.objects.filter(character=self.minstrel)
+ .annotate(
+ rank=SearchRank(
+ SearchVector("dialogue"),
+ SearchQuery("brave sir robin"),
+ # Divide the rank by the document length.
+ normalization=2,
+ ),
+ )
+ .order_by("rank")
)
- searched = Line.objects.filter(character=self.minstrel).annotate(
- rank=SearchRank(
- SearchVector('dialogue'),
- SearchQuery('brave sir robin'),
- # Divide the rank by the document length.
- normalization=2,
- ),
- ).order_by('rank')
self.assertSequenceEqual(
searched,
[self.verse2, self.verse1, self.verse0, short_verse],
@@ -490,17 +590,21 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
short_verse = Line.objects.create(
scene=self.robin,
character=self.minstrel,
- dialogue='A brave retreat brave Sir Robin.',
+ dialogue="A brave retreat brave Sir Robin.",
+ )
+ searched = (
+ Line.objects.filter(character=self.minstrel)
+ .annotate(
+ rank=SearchRank(
+ SearchVector("dialogue"),
+ SearchQuery("brave sir robin"),
+ # Divide the rank by the document length and by the number of
+ # unique words in document.
+ normalization=Value(2).bitor(Value(8)),
+ ),
+ )
+ .order_by("rank")
)
- searched = Line.objects.filter(character=self.minstrel).annotate(
- rank=SearchRank(
- SearchVector('dialogue'),
- SearchQuery('brave sir robin'),
- # Divide the rank by the document length and by the number of
- # unique words in document.
- normalization=Value(2).bitor(Value(8)),
- ),
- ).order_by('rank')
self.assertSequenceEqual(
searched,
[self.verse2, self.verse1, self.verse0, short_verse],
@@ -513,13 +617,16 @@ class SearchVectorIndexTests(PostgreSQLTestCase):
# This test should be moved to test_indexes and use a functional
# index instead once support lands (see #26167).
query = Line.objects.all().query
- resolved = SearchVector('id', 'dialogue', config='english').resolve_expression(query)
+ resolved = SearchVector("id", "dialogue", config="english").resolve_expression(
+ query
+ )
compiler = query.get_compiler(connection.alias)
sql, params = resolved.as_sql(compiler, connection)
# Indexed function must be IMMUTABLE.
with connection.cursor() as cursor:
cursor.execute(
- 'CREATE INDEX search_vector_index ON %s USING GIN (%s)' % (Line._meta.db_table, sql),
+ "CREATE INDEX search_vector_index ON %s USING GIN (%s)"
+ % (Line._meta.db_table, sql),
params,
)
@@ -527,24 +634,26 @@ class SearchVectorIndexTests(PostgreSQLTestCase):
class SearchQueryTests(PostgreSQLSimpleTestCase):
def test_str(self):
tests = (
- (~SearchQuery('a'), "~SearchQuery(Value('a'))"),
+ (~SearchQuery("a"), "~SearchQuery(Value('a'))"),
(
- (SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')),
+ (SearchQuery("a") | SearchQuery("b"))
+ & (SearchQuery("c") | SearchQuery("d")),
"((SearchQuery(Value('a')) || SearchQuery(Value('b'))) && "
"(SearchQuery(Value('c')) || SearchQuery(Value('d'))))",
),
(
- SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')),
+ SearchQuery("a") & (SearchQuery("b") | SearchQuery("c")),
"(SearchQuery(Value('a')) && (SearchQuery(Value('b')) || "
"SearchQuery(Value('c'))))",
),
(
- (SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'),
+ (SearchQuery("a") | SearchQuery("b")) & SearchQuery("c"),
"((SearchQuery(Value('a')) || SearchQuery(Value('b'))) && "
- "SearchQuery(Value('c')))"
+ "SearchQuery(Value('c')))",
),
(
- SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))),
+ SearchQuery("a")
+ & (SearchQuery("b") & (SearchQuery("c") | SearchQuery("d"))),
"(SearchQuery(Value('a')) && (SearchQuery(Value('b')) && "
"(SearchQuery(Value('c')) || SearchQuery(Value('d')))))",
),
@@ -554,109 +663,112 @@ class SearchQueryTests(PostgreSQLSimpleTestCase):
self.assertEqual(str(query), expected_str)
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SearchHeadlineTests(GrailTestData, PostgreSQLTestCase):
def test_headline(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- F('dialogue'),
- SearchQuery('brave sir robin'),
- config=SearchConfig('english'),
+ F("dialogue"),
+ SearchQuery("brave sir robin"),
+ config=SearchConfig("english"),
),
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
- '<b>Robin</b>. He was not at all afraid to be killed in nasty '
- 'ways. <b>Brave</b>, <b>brave</b>, <b>brave</b>, <b>brave</b> '
- '<b>Sir</b> <b>Robin</b>',
+ "<b>Robin</b>. He was not at all afraid to be killed in nasty "
+ "ways. <b>Brave</b>, <b>brave</b>, <b>brave</b>, <b>brave</b> "
+ "<b>Sir</b> <b>Robin</b>",
)
def test_headline_untyped_args(self):
searched = Line.objects.annotate(
- headline=SearchHeadline('dialogue', 'killed', config='english'),
+ headline=SearchHeadline("dialogue", "killed", config="english"),
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
- 'Robin. He was not at all afraid to be <b>killed</b> in nasty '
- 'ways. Brave, brave, brave, brave Sir Robin',
+ "Robin. He was not at all afraid to be <b>killed</b> in nasty "
+ "ways. Brave, brave, brave, brave Sir Robin",
)
def test_headline_with_config(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- 'dialogue',
- SearchQuery('cadeaux', config='french'),
- config='french',
+ "dialogue",
+ SearchQuery("cadeaux", config="french"),
+ config="french",
),
).get(pk=self.french.pk)
self.assertEqual(
searched.headline,
- 'Oh. Un beau <b>cadeau</b>. Oui oui.',
+ "Oh. Un beau <b>cadeau</b>. Oui oui.",
)
def test_headline_with_config_from_field(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- 'dialogue',
- SearchQuery('cadeaux', config=F('dialogue_config')),
- config=F('dialogue_config'),
+ "dialogue",
+ SearchQuery("cadeaux", config=F("dialogue_config")),
+ config=F("dialogue_config"),
),
).get(pk=self.french.pk)
self.assertEqual(
searched.headline,
- 'Oh. Un beau <b>cadeau</b>. Oui oui.',
+ "Oh. Un beau <b>cadeau</b>. Oui oui.",
)
def test_headline_separator_options(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- 'dialogue',
- 'brave sir robin',
- start_sel='<span>',
- stop_sel='</span>',
+ "dialogue",
+ "brave sir robin",
+ start_sel="<span>",
+ stop_sel="</span>",
),
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
- '<span>Robin</span>. He was not at all afraid to be killed in '
- 'nasty ways. <span>Brave</span>, <span>brave</span>, <span>brave'
- '</span>, <span>brave</span> <span>Sir</span> <span>Robin</span>',
+ "<span>Robin</span>. He was not at all afraid to be killed in "
+ "nasty ways. <span>Brave</span>, <span>brave</span>, <span>brave"
+ "</span>, <span>brave</span> <span>Sir</span> <span>Robin</span>",
)
def test_headline_highlight_all_option(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- 'dialogue',
- SearchQuery('brave sir robin', config='english'),
+ "dialogue",
+ SearchQuery("brave sir robin", config="english"),
highlight_all=True,
),
).get(pk=self.verse0.pk)
self.assertIn(
- '<b>Bravely</b> bold <b>Sir</b> <b>Robin</b>, rode forth from '
- 'Camelot. He was not afraid to die, o ',
+ "<b>Bravely</b> bold <b>Sir</b> <b>Robin</b>, rode forth from "
+ "Camelot. He was not afraid to die, o ",
searched.headline,
)
def test_headline_short_word_option(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- 'dialogue',
- SearchQuery('Camelot', config='english'),
+ "dialogue",
+ SearchQuery("Camelot", config="english"),
short_word=5,
min_words=8,
),
).get(pk=self.verse0.pk)
- self.assertEqual(searched.headline, (
- '<b>Camelot</b>. He was not afraid to die, o Brave Sir Robin. He '
- 'was not at all afraid'
- ))
+ self.assertEqual(
+ searched.headline,
+ (
+ "<b>Camelot</b>. He was not afraid to die, o Brave Sir Robin. He "
+ "was not at all afraid"
+ ),
+ )
def test_headline_fragments_words_options(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
- 'dialogue',
- SearchQuery('brave sir robin', config='english'),
- fragment_delimiter='...<br>',
+ "dialogue",
+ SearchQuery("brave sir robin", config="english"),
+ fragment_delimiter="...<br>",
max_fragments=4,
max_words=3,
min_words=1,
@@ -664,8 +776,8 @@ class SearchHeadlineTests(GrailTestData, PostgreSQLTestCase):
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
- '<b>Sir</b> <b>Robin</b>, rode...<br>'
- '<b>Brave</b> <b>Sir</b> <b>Robin</b>...<br>'
- '<b>Brave</b>, <b>brave</b>, <b>brave</b>...<br>'
- '<b>brave</b> <b>Sir</b> <b>Robin</b>',
+ "<b>Sir</b> <b>Robin</b>, rode...<br>"
+ "<b>Brave</b> <b>Sir</b> <b>Robin</b>...<br>"
+ "<b>Brave</b>, <b>brave</b>, <b>brave</b>...<br>"
+ "<b>brave</b> <b>Sir</b> <b>Robin</b>",
)
diff --git a/tests/postgres_tests/test_signals.py b/tests/postgres_tests/test_signals.py
index f1569c361c..764524d8e6 100644
--- a/tests/postgres_tests/test_signals.py
+++ b/tests/postgres_tests/test_signals.py
@@ -4,14 +4,15 @@ from . import PostgreSQLTestCase
try:
from django.contrib.postgres.signals import (
- get_citext_oids, get_hstore_oids, register_type_handlers,
+ get_citext_oids,
+ get_hstore_oids,
+ register_type_handlers,
)
except ImportError:
pass # psycopg2 isn't installed.
class OIDTests(PostgreSQLTestCase):
-
def assertOIDs(self, oids):
self.assertIsInstance(oids, tuple)
self.assertGreater(len(oids), 0)
diff --git a/tests/postgres_tests/test_trigram.py b/tests/postgres_tests/test_trigram.py
index 079a32a19b..a0502e0b9b 100644
--- a/tests/postgres_tests/test_trigram.py
+++ b/tests/postgres_tests/test_trigram.py
@@ -5,65 +5,74 @@ from .models import CharFieldModel, TextFieldModel
try:
from django.contrib.postgres.search import (
- TrigramDistance, TrigramSimilarity, TrigramWordDistance,
+ TrigramDistance,
+ TrigramSimilarity,
+ TrigramWordDistance,
TrigramWordSimilarity,
)
except ImportError:
pass
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class TrigramTest(PostgreSQLTestCase):
Model = CharFieldModel
@classmethod
def setUpTestData(cls):
- cls.Model.objects.bulk_create([
- cls.Model(field='Matthew'),
- cls.Model(field='Cat sat on mat.'),
- cls.Model(field='Dog sat on rug.'),
- ])
+ cls.Model.objects.bulk_create(
+ [
+ cls.Model(field="Matthew"),
+ cls.Model(field="Cat sat on mat."),
+ cls.Model(field="Dog sat on rug."),
+ ]
+ )
def test_trigram_search(self):
self.assertQuerysetEqual(
- self.Model.objects.filter(field__trigram_similar='Mathew'),
- ['Matthew'],
+ self.Model.objects.filter(field__trigram_similar="Mathew"),
+ ["Matthew"],
transform=lambda instance: instance.field,
)
def test_trigram_word_search(self):
obj = self.Model.objects.create(
- field='Gumby rides on the path of Middlesbrough',
+ field="Gumby rides on the path of Middlesbrough",
)
self.assertSequenceEqual(
- self.Model.objects.filter(field__trigram_word_similar='Middlesborough'),
+ self.Model.objects.filter(field__trigram_word_similar="Middlesborough"),
[obj],
)
def test_trigram_similarity(self):
- search = 'Bat sat on cat.'
+ search = "Bat sat on cat."
# Round result of similarity because PostgreSQL 12+ uses greater
# precision.
self.assertQuerysetEqual(
self.Model.objects.filter(
field__trigram_similar=search,
- ).annotate(similarity=TrigramSimilarity('field', search)).order_by('-similarity'),
- [('Cat sat on mat.', 0.625), ('Dog sat on rug.', 0.333333)],
+ )
+ .annotate(similarity=TrigramSimilarity("field", search))
+ .order_by("-similarity"),
+ [("Cat sat on mat.", 0.625), ("Dog sat on rug.", 0.333333)],
transform=lambda instance: (instance.field, round(instance.similarity, 6)),
ordered=True,
)
def test_trigram_word_similarity(self):
- search = 'mat'
+ search = "mat"
self.assertSequenceEqual(
self.Model.objects.filter(
field__trigram_word_similar=search,
- ).annotate(
- word_similarity=TrigramWordSimilarity(search, 'field'),
- ).values('field', 'word_similarity').order_by('-word_similarity'),
+ )
+ .annotate(
+ word_similarity=TrigramWordSimilarity(search, "field"),
+ )
+ .values("field", "word_similarity")
+ .order_by("-word_similarity"),
[
- {'field': 'Cat sat on mat.', 'word_similarity': 1.0},
- {'field': 'Matthew', 'word_similarity': 0.75},
+ {"field": "Cat sat on mat.", "word_similarity": 1.0},
+ {"field": "Matthew", "word_similarity": 0.75},
],
)
@@ -72,9 +81,11 @@ class TrigramTest(PostgreSQLTestCase):
# precision.
self.assertQuerysetEqual(
self.Model.objects.annotate(
- distance=TrigramDistance('field', 'Bat sat on cat.'),
- ).filter(distance__lte=0.7).order_by('distance'),
- [('Cat sat on mat.', 0.375), ('Dog sat on rug.', 0.666667)],
+ distance=TrigramDistance("field", "Bat sat on cat."),
+ )
+ .filter(distance__lte=0.7)
+ .order_by("distance"),
+ [("Cat sat on mat.", 0.375), ("Dog sat on rug.", 0.666667)],
transform=lambda instance: (instance.field, round(instance.distance, 6)),
ordered=True,
)
@@ -82,13 +93,16 @@ class TrigramTest(PostgreSQLTestCase):
def test_trigram_word_similarity_alternate(self):
self.assertSequenceEqual(
self.Model.objects.annotate(
- word_distance=TrigramWordDistance('mat', 'field'),
- ).filter(
+ word_distance=TrigramWordDistance("mat", "field"),
+ )
+ .filter(
word_distance__lte=0.7,
- ).values('field', 'word_distance').order_by('word_distance'),
+ )
+ .values("field", "word_distance")
+ .order_by("word_distance"),
[
- {'field': 'Cat sat on mat.', 'word_distance': 0},
- {'field': 'Matthew', 'word_distance': 0.25},
+ {"field": "Cat sat on mat.", "word_distance": 0},
+ {"field": "Matthew", "word_distance": 0.25},
],
)
@@ -97,4 +111,5 @@ class TrigramTextFieldTest(TrigramTest):
"""
TextField has the same behavior as CharField regarding trigram lookups.
"""
+
Model = TextFieldModel
diff --git a/tests/postgres_tests/test_unaccent.py b/tests/postgres_tests/test_unaccent.py
index 6d52f1d7dd..4188d90794 100644
--- a/tests/postgres_tests/test_unaccent.py
+++ b/tests/postgres_tests/test_unaccent.py
@@ -5,25 +5,27 @@ from . import PostgreSQLTestCase
from .models import CharFieldModel, TextFieldModel
-@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
+@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class UnaccentTest(PostgreSQLTestCase):
Model = CharFieldModel
@classmethod
def setUpTestData(cls):
- cls.Model.objects.bulk_create([
- cls.Model(field="àéÖ"),
- cls.Model(field="aeO"),
- cls.Model(field="aeo"),
- ])
+ cls.Model.objects.bulk_create(
+ [
+ cls.Model(field="àéÖ"),
+ cls.Model(field="aeO"),
+ cls.Model(field="aeo"),
+ ]
+ )
def test_unaccent(self):
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent="aeO"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
- ordered=False
+ ordered=False,
)
def test_unaccent_chained(self):
@@ -35,39 +37,39 @@ class UnaccentTest(PostgreSQLTestCase):
self.Model.objects.filter(field__unaccent__iexact="aeO"),
["àéÖ", "aeO", "aeo"],
transform=lambda instance: instance.field,
- ordered=False
+ ordered=False,
)
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent__endswith="éÖ"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
- ordered=False
+ ordered=False,
)
def test_unaccent_with_conforming_strings_off(self):
"""SQL is valid when standard_conforming_strings is off."""
with connection.cursor() as cursor:
- cursor.execute('SHOW standard_conforming_strings')
- disable_conforming_strings = cursor.fetchall()[0][0] == 'on'
+ cursor.execute("SHOW standard_conforming_strings")
+ disable_conforming_strings = cursor.fetchall()[0][0] == "on"
if disable_conforming_strings:
- cursor.execute('SET standard_conforming_strings TO off')
+ cursor.execute("SET standard_conforming_strings TO off")
try:
self.assertQuerysetEqual(
- self.Model.objects.filter(field__unaccent__endswith='éÖ'),
- ['àéÖ', 'aeO'],
+ self.Model.objects.filter(field__unaccent__endswith="éÖ"),
+ ["àéÖ", "aeO"],
transform=lambda instance: instance.field,
ordered=False,
)
finally:
if disable_conforming_strings:
- cursor.execute('SET standard_conforming_strings TO on')
+ cursor.execute("SET standard_conforming_strings TO on")
def test_unaccent_accentuated_needle(self):
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent="aéÖ"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
- ordered=False
+ ordered=False,
)
@@ -76,4 +78,5 @@ class UnaccentTextFieldTest(UnaccentTest):
TextField should have the exact same behavior as CharField
regarding unaccent lookups.
"""
+
Model = TextFieldModel