diff options
Diffstat (limited to 'django/test/testcases.py')
| -rw-r--r-- | django/test/testcases.py | 180 |
1 files changed, 136 insertions, 44 deletions
diff --git a/django/test/testcases.py b/django/test/testcases.py index 36986185ce..f820684c87 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -4,9 +4,11 @@ import posixpath import sys import threading import unittest +import warnings from collections import Counter from contextlib import contextmanager from copy import copy +from difflib import get_close_matches from functools import wraps from unittest.util import safe_repr from urllib.parse import ( @@ -17,7 +19,7 @@ from urllib.request import url2pathname from django.apps import apps from django.conf import settings from django.core import mail -from django.core.exceptions import ValidationError +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.core.files import locks from django.core.handlers.wsgi import WSGIHandler, get_path_info from django.core.management import call_command @@ -36,6 +38,7 @@ from django.test.utils import ( override_settings, ) from django.utils.decorators import classproperty +from django.utils.deprecation import RemovedInDjango31Warning from django.views.static import serve __all__ = ('TestCase', 'TransactionTestCase', @@ -133,16 +136,31 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext): class _CursorFailure: - def __init__(self, cls_name, wrapped): - self.cls_name = cls_name + def __init__(self, wrapped, message): self.wrapped = wrapped + self.message = message def __call__(self): - raise AssertionError( - "Database queries aren't allowed in SimpleTestCase. " - "Either use TestCase or TransactionTestCase to ensure proper test isolation or " - "set %s.allow_database_queries to True to silence this failure." % self.cls_name - ) + raise AssertionError(self.message) + + +class _SimpleTestCaseDatabasesDescriptor: + """Descriptor for SimpleTestCase.allow_database_queries deprecation.""" + def __get__(self, instance, cls=None): + try: + allow_database_queries = cls.allow_database_queries + except AttributeError: + pass + else: + msg = ( + '`SimpleTestCase.allow_database_queries` is deprecated. ' + 'Restrict the databases available during the execution of ' + '%s.%s with the `databases` attribute instead.' + ) % (cls.__module__, cls.__qualname__) + warnings.warn(msg, RemovedInDjango31Warning) + if allow_database_queries: + return {DEFAULT_DB_ALIAS} + return set() class SimpleTestCase(unittest.TestCase): @@ -153,9 +171,13 @@ class SimpleTestCase(unittest.TestCase): _overridden_settings = None _modified_settings = None - # Tests shouldn't be allowed to query the database since - # this base class doesn't enforce any isolation. - allow_database_queries = False + databases = _SimpleTestCaseDatabasesDescriptor() + _disallowed_database_msg = ( + 'Database queries are not allowed in SimpleTestCase subclasses. ' + 'Either subclass TestCase or TransactionTestCase to ensure proper ' + 'test isolation or add %(alias)r to %(test)s.databases to silence ' + 'this failure.' + ) @classmethod def setUpClass(cls): @@ -166,19 +188,51 @@ class SimpleTestCase(unittest.TestCase): if cls._modified_settings: cls._cls_modified_context = modify_settings(cls._modified_settings) cls._cls_modified_context.enable() - if not cls.allow_database_queries: - for alias in connections: - connection = connections[alias] - connection.cursor = _CursorFailure(cls.__name__, connection.cursor) - connection.chunked_cursor = _CursorFailure(cls.__name__, connection.chunked_cursor) + cls._add_cursor_failures() + + @classmethod + def _validate_databases(cls): + if cls.databases == '__all__': + return frozenset(connections) + for alias in cls.databases: + if alias not in connections: + message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % ( + cls.__module__, + cls.__qualname__, + alias, + ) + close_matches = get_close_matches(alias, list(connections)) + if close_matches: + message += ' Did you mean %r?' % close_matches[0] + raise ImproperlyConfigured(message) + return frozenset(cls.databases) + + @classmethod + def _add_cursor_failures(cls): + cls.databases = cls._validate_databases() + for alias in connections: + if alias in cls.databases: + continue + connection = connections[alias] + message = cls._disallowed_database_msg % { + 'test': '%s.%s' % (cls.__module__, cls.__qualname__), + 'alias': alias, + } + connection.cursor = _CursorFailure(connection.cursor, message) + connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message) + + @classmethod + def _remove_cursor_failures(cls): + for alias in connections: + if alias in cls.databases: + continue + connection = connections[alias] + connection.cursor = connection.cursor.wrapped + connection.chunked_cursor = connection.chunked_cursor.wrapped @classmethod def tearDownClass(cls): - if not cls.allow_database_queries: - for alias in connections: - connection = connections[alias] - connection.cursor = connection.cursor.wrapped - connection.chunked_cursor = connection.chunked_cursor.wrapped + cls._remove_cursor_failures() if hasattr(cls, '_cls_modified_context'): cls._cls_modified_context.disable() delattr(cls, '_cls_modified_context') @@ -806,6 +860,26 @@ class SimpleTestCase(unittest.TestCase): self.fail(self._formatMessage(msg, standardMsg)) +class _TransactionTestCaseDatabasesDescriptor: + """Descriptor for TransactionTestCase.multi_db deprecation.""" + msg = ( + '`TransactionTestCase.multi_db` is deprecated. Databases available ' + 'during this test can be defined using %s.%s.databases.' + ) + + def __get__(self, instance, cls=None): + try: + multi_db = cls.multi_db + except AttributeError: + pass + else: + msg = self.msg % (cls.__module__, cls.__qualname__) + warnings.warn(msg, RemovedInDjango31Warning) + if multi_db: + return set(connections) + return {DEFAULT_DB_ALIAS} + + class TransactionTestCase(SimpleTestCase): # Subclasses can ask for resetting of auto increment sequence before each @@ -818,8 +892,12 @@ class TransactionTestCase(SimpleTestCase): # Subclasses can define fixtures which will be automatically installed. fixtures = None - # Do the tests in this class query non-default databases? - multi_db = False + databases = _TransactionTestCaseDatabasesDescriptor() + _disallowed_database_msg = ( + 'Database queries to %(alias)r are not allowed in this test. Add ' + '%(alias)r to %(test)s.databases to ensure proper test isolation ' + 'and silence this failure.' + ) # If transactions aren't available, Django will serialize the database # contents into a fixture during setup and flush and reload them @@ -827,10 +905,6 @@ class TransactionTestCase(SimpleTestCase): # This can be slow; this flag allows enabling on a per-case basis. serialized_rollback = False - # Since tests will be wrapped in a transaction, or serialized if they - # are not available, we allow queries to be run. - allow_database_queries = True - def _pre_setup(self): """ Perform pre-test setup: @@ -870,15 +944,13 @@ class TransactionTestCase(SimpleTestCase): @classmethod def _databases_names(cls, include_mirrors=True): - # If the test case has a multi_db=True flag, act on all databases, - # including mirrors or not. Otherwise, just on the default DB. - if cls.multi_db: - return [ - alias for alias in connections - if include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR'] - ] - else: - return [DEFAULT_DB_ALIAS] + # Only consider allowed database aliases, including mirrors or not. + return [ + alias for alias in connections + if alias in cls.databases and ( + include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR'] + ) + ] def _reset_sequences(self, db_name): conn = connections[db_name] @@ -984,9 +1056,21 @@ class TransactionTestCase(SimpleTestCase): func(*args, **kwargs) -def connections_support_transactions(): - """Return True if all connections support transactions.""" - return all(conn.features.supports_transactions for conn in connections.all()) +def connections_support_transactions(aliases=None): + """ + Return whether or not all (or specified) connections support + transactions. + """ + conns = connections.all() if aliases is None else (connections[alias] for alias in aliases) + return all(conn.features.supports_transactions for conn in conns) + + +class _TestCaseDatabasesDescriptor(_TransactionTestCaseDatabasesDescriptor): + """Descriptor for TestCase.multi_db deprecation.""" + msg = ( + '`TestCase.multi_db` is deprecated. Databases available during this ' + 'test can be defined using %s.%s.databases.' + ) class TestCase(TransactionTestCase): @@ -1002,6 +1086,8 @@ class TestCase(TransactionTestCase): On database backends with no transaction support, TestCase behaves as TransactionTestCase. """ + databases = _TestCaseDatabasesDescriptor() + @classmethod def _enter_atomics(cls): """Open atomic blocks for multiple databases.""" @@ -1019,9 +1105,13 @@ class TestCase(TransactionTestCase): atomics[db_name].__exit__(None, None, None) @classmethod + def _databases_support_transactions(cls): + return connections_support_transactions(cls.databases) + + @classmethod def setUpClass(cls): super().setUpClass() - if not connections_support_transactions(): + if not cls._databases_support_transactions(): return cls.cls_atomics = cls._enter_atomics() @@ -1031,16 +1121,18 @@ class TestCase(TransactionTestCase): call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) except Exception: cls._rollback_atomics(cls.cls_atomics) + cls._remove_cursor_failures() raise try: cls.setUpTestData() except Exception: cls._rollback_atomics(cls.cls_atomics) + cls._remove_cursor_failures() raise @classmethod def tearDownClass(cls): - if connections_support_transactions(): + if cls._databases_support_transactions(): cls._rollback_atomics(cls.cls_atomics) for conn in connections.all(): conn.close() @@ -1052,12 +1144,12 @@ class TestCase(TransactionTestCase): pass def _should_reload_connections(self): - if connections_support_transactions(): + if self._databases_support_transactions(): return False return super()._should_reload_connections() def _fixture_setup(self): - if not connections_support_transactions(): + if not self._databases_support_transactions(): # If the backend does not support transactions, we should reload # class data before each test self.setUpTestData() @@ -1067,7 +1159,7 @@ class TestCase(TransactionTestCase): self.atomics = self._enter_atomics() def _fixture_teardown(self): - if not connections_support_transactions(): + if not self._databases_support_transactions(): return super()._fixture_teardown() try: for db_name in reversed(self._databases_names()): |
