summaryrefslogtreecommitdiff
path: root/tests/postgres_tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests/postgres_tests')
-rw-r--r--tests/postgres_tests/test_operations.py49
1 files changed, 47 insertions, 2 deletions
diff --git a/tests/postgres_tests/test_operations.py b/tests/postgres_tests/test_operations.py
index 95c88d5fe0..7bcf6b2300 100644
--- a/tests/postgres_tests/test_operations.py
+++ b/tests/postgres_tests/test_operations.py
@@ -3,12 +3,16 @@ import unittest
from migrations.test_base import OperationTestBase
from django.db import NotSupportedError, connection
+from django.db.migrations.state import ProjectState
from django.db.models import Index
-from django.test import modify_settings
+from django.test import modify_settings, override_settings
+from django.test.utils import CaptureQueriesContext
+
+from . import PostgreSQLTestCase
try:
from django.contrib.postgres.operations import (
- AddIndexConcurrently, RemoveIndexConcurrently,
+ AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently,
)
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
except ImportError:
@@ -141,3 +145,44 @@ class RemoveIndexConcurrentlyTests(OperationTestBase):
self.assertEqual(name, 'RemoveIndexConcurrently')
self.assertEqual(args, [])
self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
+
+
+class NoExtensionRouter():
+ def allow_migrate(self, db, app_label, **hints):
+ return False
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+class CreateExtensionTests(PostgreSQLTestCase):
+ app_label = 'test_allow_create_extention'
+
+ @override_settings(DATABASE_ROUTERS=[NoExtensionRouter()])
+ def test_no_allow_migrate(self):
+ operation = CreateExtension('uuid-ossp')
+ project_state = ProjectState()
+ new_state = project_state.clone()
+ # Don't create an extension.
+ with CaptureQueriesContext(connection) as captured_queries:
+ with connection.schema_editor(atomic=False) as editor:
+ operation.database_forwards(self.app_label, editor, project_state, new_state)
+ self.assertEqual(len(captured_queries), 0)
+ # Reversal.
+ with CaptureQueriesContext(connection) as captured_queries:
+ with connection.schema_editor(atomic=False) as editor:
+ operation.database_backwards(self.app_label, editor, new_state, project_state)
+ self.assertEqual(len(captured_queries), 0)
+
+ def test_allow_migrate(self):
+ operation = CreateExtension('uuid-ossp')
+ project_state = ProjectState()
+ new_state = project_state.clone()
+ # Create an extension.
+ with CaptureQueriesContext(connection) as captured_queries:
+ with connection.schema_editor(atomic=False) as editor:
+ operation.database_forwards(self.app_label, editor, project_state, new_state)
+ self.assertIn('CREATE EXTENSION', captured_queries[0]['sql'])
+ # Reversal.
+ with CaptureQueriesContext(connection) as captured_queries:
+ with connection.schema_editor(atomic=False) as editor:
+ operation.database_backwards(self.app_label, editor, new_state, project_state)
+ self.assertIn('DROP EXTENSION', captured_queries[0]['sql'])