diff options
Diffstat (limited to 'tests/postgres_tests/test_operations.py')
| -rw-r--r-- | tests/postgres_tests/test_operations.py | 110 |
1 files changed, 106 insertions, 4 deletions
diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py index 9faf938c55..1464f3177e 100644 --- a/tests/postgres_tests/test_operations.py +++ b/tests/postgres_tests/test_operations.py @@ -3,9 +3,11 @@ from unittest import mock from migrations.test_base import OperationTestBase -from django.db import NotSupportedError, connection +from django.db import ( + IntegrityError, NotSupportedError, connection, transaction, +) from django.db.migrations.state import ProjectState -from django.db.models import Index +from django.db.models import CheckConstraint, Index, Q, UniqueConstraint from django.db.utils import ProgrammingError from django.test import modify_settings, override_settings, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext @@ -15,8 +17,9 @@ from . import PostgreSQLTestCase try: from django.contrib.postgres.indexes import BrinIndex, BTreeIndex from django.contrib.postgres.operations import ( - AddIndexConcurrently, BloomExtension, CreateCollation, CreateExtension, - RemoveCollation, RemoveIndexConcurrently, + AddConstraintNotValid, AddIndexConcurrently, BloomExtension, + CreateCollation, CreateExtension, RemoveCollation, + RemoveIndexConcurrently, ValidateConstraint, ) except ImportError: pass @@ -392,3 +395,102 @@ class RemoveCollationTests(PostgreSQLTestCase): self.assertEqual(name, 'RemoveCollation') self.assertEqual(args, []) self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'}) + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') +@modify_settings(INSTALLED_APPS={'append': 'migrations'}) +class AddConstraintNotValidTests(OperationTestBase): + app_label = 'test_add_constraint_not_valid' + + def test_non_check_constraint_not_supported(self): + constraint = UniqueConstraint(fields=['pink'], name='pony_pink_uniq') + msg = 'AddConstraintNotValid.constraint must be a check constraint.' + with self.assertRaisesMessage(TypeError, msg): + AddConstraintNotValid(model_name='pony', constraint=constraint) + + def test_add(self): + table_name = f'{self.app_label}_pony' + constraint_name = 'pony_pink_gte_check' + constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name) + operation = AddConstraintNotValid('Pony', constraint=constraint) + project_state, new_state = self.make_test_state(self.app_label, operation) + self.assertEqual( + operation.describe(), + f'Create not valid constraint {constraint_name} on model Pony', + ) + self.assertEqual( + operation.migration_name_fragment, + f'pony_{constraint_name}_not_valid', + ) + self.assertEqual( + len(new_state.models[self.app_label, 'pony'].options['constraints']), + 1, + ) + self.assertConstraintNotExists(table_name, constraint_name) + Pony = new_state.apps.get_model(self.app_label, 'Pony') + self.assertEqual(len(Pony._meta.constraints), 1) + Pony.objects.create(pink=2, weight=1.0) + # Add constraint. + with connection.schema_editor(atomic=True) as editor: + operation.database_forwards(self.app_label, editor, project_state, new_state) + msg = f'check constraint "{constraint_name}"' + with self.assertRaisesMessage(IntegrityError, msg), transaction.atomic(): + Pony.objects.create(pink=3, weight=1.0) + self.assertConstraintExists(table_name, constraint_name) + # Reversal. + with connection.schema_editor(atomic=True) as editor: + operation.database_backwards(self.app_label, editor, project_state, new_state) + self.assertConstraintNotExists(table_name, constraint_name) + Pony.objects.create(pink=3, weight=1.0) + # Deconstruction. + name, args, kwargs = operation.deconstruct() + self.assertEqual(name, 'AddConstraintNotValid') + self.assertEqual(args, []) + self.assertEqual(kwargs, {'model_name': 'Pony', 'constraint': constraint}) + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') +@modify_settings(INSTALLED_APPS={'append': 'migrations'}) +class ValidateConstraintTests(OperationTestBase): + app_label = 'test_validate_constraint' + + def test_validate(self): + constraint_name = 'pony_pink_gte_check' + constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name) + operation = AddConstraintNotValid('Pony', constraint=constraint) + project_state, new_state = self.make_test_state(self.app_label, operation) + Pony = new_state.apps.get_model(self.app_label, 'Pony') + obj = Pony.objects.create(pink=2, weight=1.0) + # Add constraint. + with connection.schema_editor(atomic=True) as editor: + operation.database_forwards(self.app_label, editor, project_state, new_state) + project_state = new_state + new_state = new_state.clone() + operation = ValidateConstraint('Pony', name=constraint_name) + operation.state_forwards(self.app_label, new_state) + self.assertEqual( + operation.describe(), + f'Validate constraint {constraint_name} on model Pony', + ) + self.assertEqual( + operation.migration_name_fragment, + f'pony_validate_{constraint_name}', + ) + # Validate constraint. + with connection.schema_editor(atomic=True) as editor: + msg = f'check constraint "{constraint_name}"' + with self.assertRaisesMessage(IntegrityError, msg): + operation.database_forwards(self.app_label, editor, project_state, new_state) + obj.pink = 5 + obj.save() + with connection.schema_editor(atomic=True) as editor: + operation.database_forwards(self.app_label, editor, project_state, new_state) + # Reversal is a noop. + with connection.schema_editor() as editor: + with self.assertNumQueries(0): + operation.database_backwards(self.app_label, editor, new_state, project_state) + # Deconstruction. + name, args, kwargs = operation.deconstruct() + self.assertEqual(name, 'ValidateConstraint') + self.assertEqual(args, []) + self.assertEqual(kwargs, {'model_name': 'Pony', 'name': constraint_name}) |
