summaryrefslogtreecommitdiff
path: root/tests/postgres_tests/test_operations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/postgres_tests/test_operations.py')
-rw-r--r--tests/postgres_tests/test_operations.py110
1 files changed, 106 insertions, 4 deletions
diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py
index 9faf938c55..1464f3177e 100644
--- a/tests/postgres_tests/test_operations.py
+++ b/tests/postgres_tests/test_operations.py
@@ -3,9 +3,11 @@ from unittest import mock
from migrations.test_base import OperationTestBase
-from django.db import NotSupportedError, connection
+from django.db import (
+ IntegrityError, NotSupportedError, connection, transaction,
+)
from django.db.migrations.state import ProjectState
-from django.db.models import Index
+from django.db.models import CheckConstraint, Index, Q, UniqueConstraint
from django.db.utils import ProgrammingError
from django.test import modify_settings, override_settings, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
@@ -15,8 +17,9 @@ from . import PostgreSQLTestCase
try:
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
from django.contrib.postgres.operations import (
- AddIndexConcurrently, BloomExtension, CreateCollation, CreateExtension,
- RemoveCollation, RemoveIndexConcurrently,
+ AddConstraintNotValid, AddIndexConcurrently, BloomExtension,
+ CreateCollation, CreateExtension, RemoveCollation,
+ RemoveIndexConcurrently, ValidateConstraint,
)
except ImportError:
pass
@@ -392,3 +395,102 @@ class RemoveCollationTests(PostgreSQLTestCase):
self.assertEqual(name, 'RemoveCollation')
self.assertEqual(args, [])
self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+@modify_settings(INSTALLED_APPS={'append': 'migrations'})
+class AddConstraintNotValidTests(OperationTestBase):
+ app_label = 'test_add_constraint_not_valid'
+
+ def test_non_check_constraint_not_supported(self):
+ constraint = UniqueConstraint(fields=['pink'], name='pony_pink_uniq')
+ msg = 'AddConstraintNotValid.constraint must be a check constraint.'
+ with self.assertRaisesMessage(TypeError, msg):
+ AddConstraintNotValid(model_name='pony', constraint=constraint)
+
+ def test_add(self):
+ table_name = f'{self.app_label}_pony'
+ constraint_name = 'pony_pink_gte_check'
+ constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
+ operation = AddConstraintNotValid('Pony', constraint=constraint)
+ project_state, new_state = self.make_test_state(self.app_label, operation)
+ self.assertEqual(
+ operation.describe(),
+ f'Create not valid constraint {constraint_name} on model Pony',
+ )
+ self.assertEqual(
+ operation.migration_name_fragment,
+ f'pony_{constraint_name}_not_valid',
+ )
+ self.assertEqual(
+ len(new_state.models[self.app_label, 'pony'].options['constraints']),
+ 1,
+ )
+ self.assertConstraintNotExists(table_name, constraint_name)
+ Pony = new_state.apps.get_model(self.app_label, 'Pony')
+ self.assertEqual(len(Pony._meta.constraints), 1)
+ Pony.objects.create(pink=2, weight=1.0)
+ # Add constraint.
+ with connection.schema_editor(atomic=True) as editor:
+ operation.database_forwards(self.app_label, editor, project_state, new_state)
+ msg = f'check constraint "{constraint_name}"'
+ with self.assertRaisesMessage(IntegrityError, msg), transaction.atomic():
+ Pony.objects.create(pink=3, weight=1.0)
+ self.assertConstraintExists(table_name, constraint_name)
+ # Reversal.
+ with connection.schema_editor(atomic=True) as editor:
+ operation.database_backwards(self.app_label, editor, project_state, new_state)
+ self.assertConstraintNotExists(table_name, constraint_name)
+ Pony.objects.create(pink=3, weight=1.0)
+ # Deconstruction.
+ name, args, kwargs = operation.deconstruct()
+ self.assertEqual(name, 'AddConstraintNotValid')
+ self.assertEqual(args, [])
+ self.assertEqual(kwargs, {'model_name': 'Pony', 'constraint': constraint})
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+@modify_settings(INSTALLED_APPS={'append': 'migrations'})
+class ValidateConstraintTests(OperationTestBase):
+ app_label = 'test_validate_constraint'
+
+ def test_validate(self):
+ constraint_name = 'pony_pink_gte_check'
+ constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
+ operation = AddConstraintNotValid('Pony', constraint=constraint)
+ project_state, new_state = self.make_test_state(self.app_label, operation)
+ Pony = new_state.apps.get_model(self.app_label, 'Pony')
+ obj = Pony.objects.create(pink=2, weight=1.0)
+ # Add constraint.
+ with connection.schema_editor(atomic=True) as editor:
+ operation.database_forwards(self.app_label, editor, project_state, new_state)
+ project_state = new_state
+ new_state = new_state.clone()
+ operation = ValidateConstraint('Pony', name=constraint_name)
+ operation.state_forwards(self.app_label, new_state)
+ self.assertEqual(
+ operation.describe(),
+ f'Validate constraint {constraint_name} on model Pony',
+ )
+ self.assertEqual(
+ operation.migration_name_fragment,
+ f'pony_validate_{constraint_name}',
+ )
+ # Validate constraint.
+ with connection.schema_editor(atomic=True) as editor:
+ msg = f'check constraint "{constraint_name}"'
+ with self.assertRaisesMessage(IntegrityError, msg):
+ operation.database_forwards(self.app_label, editor, project_state, new_state)
+ obj.pink = 5
+ obj.save()
+ with connection.schema_editor(atomic=True) as editor:
+ operation.database_forwards(self.app_label, editor, project_state, new_state)
+ # Reversal is a noop.
+ with connection.schema_editor() as editor:
+ with self.assertNumQueries(0):
+ operation.database_backwards(self.app_label, editor, new_state, project_state)
+ # Deconstruction.
+ name, args, kwargs = operation.deconstruct()
+ self.assertEqual(name, 'ValidateConstraint')
+ self.assertEqual(args, [])
+ self.assertEqual(kwargs, {'model_name': 'Pony', 'name': constraint_name})