summaryrefslogtreecommitdiff
path: root/django/test/testcases.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/test/testcases.py')
-rw-r--r--django/test/testcases.py180
1 files changed, 136 insertions, 44 deletions
diff --git a/django/test/testcases.py b/django/test/testcases.py
index 36986185ce..f820684c87 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -4,9 +4,11 @@ import posixpath
import sys
import threading
import unittest
+import warnings
from collections import Counter
from contextlib import contextmanager
from copy import copy
+from difflib import get_close_matches
from functools import wraps
from unittest.util import safe_repr
from urllib.parse import (
@@ -17,7 +19,7 @@ from urllib.request import url2pathname
from django.apps import apps
from django.conf import settings
from django.core import mail
-from django.core.exceptions import ValidationError
+from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.core.files import locks
from django.core.handlers.wsgi import WSGIHandler, get_path_info
from django.core.management import call_command
@@ -36,6 +38,7 @@ from django.test.utils import (
override_settings,
)
from django.utils.decorators import classproperty
+from django.utils.deprecation import RemovedInDjango31Warning
from django.views.static import serve
__all__ = ('TestCase', 'TransactionTestCase',
@@ -133,16 +136,31 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
class _CursorFailure:
- def __init__(self, cls_name, wrapped):
- self.cls_name = cls_name
+ def __init__(self, wrapped, message):
self.wrapped = wrapped
+ self.message = message
def __call__(self):
- raise AssertionError(
- "Database queries aren't allowed in SimpleTestCase. "
- "Either use TestCase or TransactionTestCase to ensure proper test isolation or "
- "set %s.allow_database_queries to True to silence this failure." % self.cls_name
- )
+ raise AssertionError(self.message)
+
+
+class _SimpleTestCaseDatabasesDescriptor:
+ """Descriptor for SimpleTestCase.allow_database_queries deprecation."""
+ def __get__(self, instance, cls=None):
+ try:
+ allow_database_queries = cls.allow_database_queries
+ except AttributeError:
+ pass
+ else:
+ msg = (
+ '`SimpleTestCase.allow_database_queries` is deprecated. '
+ 'Restrict the databases available during the execution of '
+ '%s.%s with the `databases` attribute instead.'
+ ) % (cls.__module__, cls.__qualname__)
+ warnings.warn(msg, RemovedInDjango31Warning)
+ if allow_database_queries:
+ return {DEFAULT_DB_ALIAS}
+ return set()
class SimpleTestCase(unittest.TestCase):
@@ -153,9 +171,13 @@ class SimpleTestCase(unittest.TestCase):
_overridden_settings = None
_modified_settings = None
- # Tests shouldn't be allowed to query the database since
- # this base class doesn't enforce any isolation.
- allow_database_queries = False
+ databases = _SimpleTestCaseDatabasesDescriptor()
+ _disallowed_database_msg = (
+ 'Database queries are not allowed in SimpleTestCase subclasses. '
+ 'Either subclass TestCase or TransactionTestCase to ensure proper '
+ 'test isolation or add %(alias)r to %(test)s.databases to silence '
+ 'this failure.'
+ )
@classmethod
def setUpClass(cls):
@@ -166,19 +188,51 @@ class SimpleTestCase(unittest.TestCase):
if cls._modified_settings:
cls._cls_modified_context = modify_settings(cls._modified_settings)
cls._cls_modified_context.enable()
- if not cls.allow_database_queries:
- for alias in connections:
- connection = connections[alias]
- connection.cursor = _CursorFailure(cls.__name__, connection.cursor)
- connection.chunked_cursor = _CursorFailure(cls.__name__, connection.chunked_cursor)
+ cls._add_cursor_failures()
+
+ @classmethod
+ def _validate_databases(cls):
+ if cls.databases == '__all__':
+ return frozenset(connections)
+ for alias in cls.databases:
+ if alias not in connections:
+ message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % (
+ cls.__module__,
+ cls.__qualname__,
+ alias,
+ )
+ close_matches = get_close_matches(alias, list(connections))
+ if close_matches:
+ message += ' Did you mean %r?' % close_matches[0]
+ raise ImproperlyConfigured(message)
+ return frozenset(cls.databases)
+
+ @classmethod
+ def _add_cursor_failures(cls):
+ cls.databases = cls._validate_databases()
+ for alias in connections:
+ if alias in cls.databases:
+ continue
+ connection = connections[alias]
+ message = cls._disallowed_database_msg % {
+ 'test': '%s.%s' % (cls.__module__, cls.__qualname__),
+ 'alias': alias,
+ }
+ connection.cursor = _CursorFailure(connection.cursor, message)
+ connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message)
+
+ @classmethod
+ def _remove_cursor_failures(cls):
+ for alias in connections:
+ if alias in cls.databases:
+ continue
+ connection = connections[alias]
+ connection.cursor = connection.cursor.wrapped
+ connection.chunked_cursor = connection.chunked_cursor.wrapped
@classmethod
def tearDownClass(cls):
- if not cls.allow_database_queries:
- for alias in connections:
- connection = connections[alias]
- connection.cursor = connection.cursor.wrapped
- connection.chunked_cursor = connection.chunked_cursor.wrapped
+ cls._remove_cursor_failures()
if hasattr(cls, '_cls_modified_context'):
cls._cls_modified_context.disable()
delattr(cls, '_cls_modified_context')
@@ -806,6 +860,26 @@ class SimpleTestCase(unittest.TestCase):
self.fail(self._formatMessage(msg, standardMsg))
+class _TransactionTestCaseDatabasesDescriptor:
+ """Descriptor for TransactionTestCase.multi_db deprecation."""
+ msg = (
+ '`TransactionTestCase.multi_db` is deprecated. Databases available '
+ 'during this test can be defined using %s.%s.databases.'
+ )
+
+ def __get__(self, instance, cls=None):
+ try:
+ multi_db = cls.multi_db
+ except AttributeError:
+ pass
+ else:
+ msg = self.msg % (cls.__module__, cls.__qualname__)
+ warnings.warn(msg, RemovedInDjango31Warning)
+ if multi_db:
+ return set(connections)
+ return {DEFAULT_DB_ALIAS}
+
+
class TransactionTestCase(SimpleTestCase):
# Subclasses can ask for resetting of auto increment sequence before each
@@ -818,8 +892,12 @@ class TransactionTestCase(SimpleTestCase):
# Subclasses can define fixtures which will be automatically installed.
fixtures = None
- # Do the tests in this class query non-default databases?
- multi_db = False
+ databases = _TransactionTestCaseDatabasesDescriptor()
+ _disallowed_database_msg = (
+ 'Database queries to %(alias)r are not allowed in this test. Add '
+ '%(alias)r to %(test)s.databases to ensure proper test isolation '
+ 'and silence this failure.'
+ )
# If transactions aren't available, Django will serialize the database
# contents into a fixture during setup and flush and reload them
@@ -827,10 +905,6 @@ class TransactionTestCase(SimpleTestCase):
# This can be slow; this flag allows enabling on a per-case basis.
serialized_rollback = False
- # Since tests will be wrapped in a transaction, or serialized if they
- # are not available, we allow queries to be run.
- allow_database_queries = True
-
def _pre_setup(self):
"""
Perform pre-test setup:
@@ -870,15 +944,13 @@ class TransactionTestCase(SimpleTestCase):
@classmethod
def _databases_names(cls, include_mirrors=True):
- # If the test case has a multi_db=True flag, act on all databases,
- # including mirrors or not. Otherwise, just on the default DB.
- if cls.multi_db:
- return [
- alias for alias in connections
- if include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
- ]
- else:
- return [DEFAULT_DB_ALIAS]
+ # Only consider allowed database aliases, including mirrors or not.
+ return [
+ alias for alias in connections
+ if alias in cls.databases and (
+ include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
+ )
+ ]
def _reset_sequences(self, db_name):
conn = connections[db_name]
@@ -984,9 +1056,21 @@ class TransactionTestCase(SimpleTestCase):
func(*args, **kwargs)
-def connections_support_transactions():
- """Return True if all connections support transactions."""
- return all(conn.features.supports_transactions for conn in connections.all())
+def connections_support_transactions(aliases=None):
+ """
+ Return whether or not all (or specified) connections support
+ transactions.
+ """
+ conns = connections.all() if aliases is None else (connections[alias] for alias in aliases)
+ return all(conn.features.supports_transactions for conn in conns)
+
+
+class _TestCaseDatabasesDescriptor(_TransactionTestCaseDatabasesDescriptor):
+ """Descriptor for TestCase.multi_db deprecation."""
+ msg = (
+ '`TestCase.multi_db` is deprecated. Databases available during this '
+ 'test can be defined using %s.%s.databases.'
+ )
class TestCase(TransactionTestCase):
@@ -1002,6 +1086,8 @@ class TestCase(TransactionTestCase):
On database backends with no transaction support, TestCase behaves as
TransactionTestCase.
"""
+ databases = _TestCaseDatabasesDescriptor()
+
@classmethod
def _enter_atomics(cls):
"""Open atomic blocks for multiple databases."""
@@ -1019,9 +1105,13 @@ class TestCase(TransactionTestCase):
atomics[db_name].__exit__(None, None, None)
@classmethod
+ def _databases_support_transactions(cls):
+ return connections_support_transactions(cls.databases)
+
+ @classmethod
def setUpClass(cls):
super().setUpClass()
- if not connections_support_transactions():
+ if not cls._databases_support_transactions():
return
cls.cls_atomics = cls._enter_atomics()
@@ -1031,16 +1121,18 @@ class TestCase(TransactionTestCase):
call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
except Exception:
cls._rollback_atomics(cls.cls_atomics)
+ cls._remove_cursor_failures()
raise
try:
cls.setUpTestData()
except Exception:
cls._rollback_atomics(cls.cls_atomics)
+ cls._remove_cursor_failures()
raise
@classmethod
def tearDownClass(cls):
- if connections_support_transactions():
+ if cls._databases_support_transactions():
cls._rollback_atomics(cls.cls_atomics)
for conn in connections.all():
conn.close()
@@ -1052,12 +1144,12 @@ class TestCase(TransactionTestCase):
pass
def _should_reload_connections(self):
- if connections_support_transactions():
+ if self._databases_support_transactions():
return False
return super()._should_reload_connections()
def _fixture_setup(self):
- if not connections_support_transactions():
+ if not self._databases_support_transactions():
# If the backend does not support transactions, we should reload
# class data before each test
self.setUpTestData()
@@ -1067,7 +1159,7 @@ class TestCase(TransactionTestCase):
self.atomics = self._enter_atomics()
def _fixture_teardown(self):
- if not connections_support_transactions():
+ if not self._databases_support_transactions():
return super()._fixture_teardown()
try:
for db_name in reversed(self._databases_names()):