summaryrefslogtreecommitdiff
path: root/django/test/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/test/utils.py')
-rw-r--r--django/test/utils.py221
1 files changed, 133 insertions, 88 deletions
diff --git a/django/test/utils.py b/django/test/utils.py
index 6c2f566909..ac0fc34b08 100644
--- a/django/test/utils.py
+++ b/django/test/utils.py
@@ -35,15 +35,24 @@ except ImportError:
__all__ = (
- 'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
- 'CaptureQueriesContext',
- 'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings',
- 'override_system_checks', 'tag',
- 'requires_tz_support',
- 'setup_databases', 'setup_test_environment', 'teardown_test_environment',
+ "Approximate",
+ "ContextList",
+ "isolate_lru_cache",
+ "get_runner",
+ "CaptureQueriesContext",
+ "ignore_warnings",
+ "isolate_apps",
+ "modify_settings",
+ "override_settings",
+ "override_system_checks",
+ "tag",
+ "requires_tz_support",
+ "setup_databases",
+ "setup_test_environment",
+ "teardown_test_environment",
)
-TZ_SUPPORT = hasattr(time, 'tzset')
+TZ_SUPPORT = hasattr(time, "tzset")
class Approximate:
@@ -63,6 +72,7 @@ class ContextList(list):
A wrapper that provides direct key access to context items contained
in a list of context objects.
"""
+
def __getitem__(self, key):
if isinstance(key, str):
for subcontext in self:
@@ -110,7 +120,7 @@ def setup_test_environment(debug=None):
Perform global pre-test setup, such as installing the instrumented template
renderer and setting the email backend to the locmem email backend.
"""
- if hasattr(_TestState, 'saved_data'):
+ if hasattr(_TestState, "saved_data"):
# Executing this function twice would overwrite the saved values.
raise RuntimeError(
"setup_test_environment() was already called and can't be called "
@@ -125,13 +135,13 @@ def setup_test_environment(debug=None):
saved_data.allowed_hosts = settings.ALLOWED_HOSTS
# Add the default host of the test client.
- settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']
+ settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"]
saved_data.debug = settings.DEBUG
settings.DEBUG = debug
saved_data.email_backend = settings.EMAIL_BACKEND
- settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
+ settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend"
saved_data.template_render = Template._render
Template._render = instrumented_test_render
@@ -191,18 +201,17 @@ def setup_databases(
# replace with:
# serialize_alias = serialized_aliases is None or alias in serialized_aliases
try:
- serialize_alias = connection.settings_dict['TEST']['SERIALIZE']
+ serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"]
except KeyError:
serialize_alias = (
- serialized_aliases is None or
- alias in serialized_aliases
+ serialized_aliases is None or alias in serialized_aliases
)
else:
warnings.warn(
- 'The SERIALIZE test database setting is '
- 'deprecated as it can be inferred from the '
- 'TestCase/TransactionTestCase.databases that '
- 'enable the serialized_rollback feature.',
+ "The SERIALIZE test database setting is "
+ "deprecated as it can be inferred from the "
+ "TestCase/TransactionTestCase.databases that "
+ "enable the serialized_rollback feature.",
category=RemovedInDjango50Warning,
)
connection.creation.create_test_db(
@@ -221,12 +230,15 @@ def setup_databases(
)
# Configure all other connections as mirrors of the first one
else:
- connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)
+ connections[alias].creation.set_as_test_mirror(
+ connections[first_alias].settings_dict
+ )
# Configure the test mirrors.
for alias, mirror_alias in mirrored_aliases.items():
connections[alias].creation.set_as_test_mirror(
- connections[mirror_alias].settings_dict)
+ connections[mirror_alias].settings_dict
+ )
if debug_sql:
for alias in connections:
@@ -246,8 +258,8 @@ def iter_test_cases(tests):
# Prevent an unfriendly RecursionError that can happen with
# strings.
raise TypeError(
- f'Test {test!r} must be a test case or test suite not string '
- f'(was found in {tests!r}).'
+ f"Test {test!r} must be a test case or test suite not string "
+ f"(was found in {tests!r})."
)
if isinstance(test, TestCase):
yield test
@@ -319,18 +331,18 @@ def get_unique_databases_and_mirrors(aliases=None):
for alias in connections:
connection = connections[alias]
- test_settings = connection.settings_dict['TEST']
+ test_settings = connection.settings_dict["TEST"]
- if test_settings['MIRROR']:
+ if test_settings["MIRROR"]:
# If the database is marked as a test mirror, save the alias.
- mirrored_aliases[alias] = test_settings['MIRROR']
+ mirrored_aliases[alias] = test_settings["MIRROR"]
elif alias in aliases:
# Store a tuple with DB parameters that uniquely identify it.
# If we have two aliases with the same values for that tuple,
# we only need to create the test database once.
item = test_databases.setdefault(
connection.creation.test_db_signature(),
- (connection.settings_dict['NAME'], []),
+ (connection.settings_dict["NAME"], []),
)
# The default database must be the first because data migrations
# use the default alias by default.
@@ -339,11 +351,16 @@ def get_unique_databases_and_mirrors(aliases=None):
else:
item[1].append(alias)
- if 'DEPENDENCIES' in test_settings:
- dependencies[alias] = test_settings['DEPENDENCIES']
+ if "DEPENDENCIES" in test_settings:
+ dependencies[alias] = test_settings["DEPENDENCIES"]
else:
- if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
- dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])
+ if (
+ alias != DEFAULT_DB_ALIAS
+ and connection.creation.test_db_signature() != default_sig
+ ):
+ dependencies[alias] = test_settings.get(
+ "DEPENDENCIES", [DEFAULT_DB_ALIAS]
+ )
test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
return test_databases, mirrored_aliases
@@ -365,12 +382,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
def get_runner(settings, test_runner_class=None):
test_runner_class = test_runner_class or settings.TEST_RUNNER
- test_path = test_runner_class.split('.')
+ test_path = test_runner_class.split(".")
# Allow for relative paths
if len(test_path) > 1:
- test_module_name = '.'.join(test_path[:-1])
+ test_module_name = ".".join(test_path[:-1])
else:
- test_module_name = '.'
+ test_module_name = "."
test_module = __import__(test_module_name, {}, {}, test_path[-1])
return getattr(test_module, test_path[-1])
@@ -387,6 +404,7 @@ class TestContextDecorator:
`kwarg_name`: keyword argument passing the return value of enable() if
used as a function decorator.
"""
+
def __init__(self, attr_name=None, kwarg_name=None):
self.attr_name = attr_name
self.kwarg_name = kwarg_name
@@ -416,7 +434,7 @@ class TestContextDecorator:
cls.setUp = setUp
return cls
- raise TypeError('Can only decorate subclasses of unittest.TestCase')
+ raise TypeError("Can only decorate subclasses of unittest.TestCase")
def decorate_callable(self, func):
if asyncio.iscoroutinefunction(func):
@@ -428,13 +446,16 @@ class TestContextDecorator:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return await func(*args, **kwargs)
+
else:
+
@wraps(func)
def inner(*args, **kwargs):
with self as context:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return func(*args, **kwargs)
+
return inner
def __call__(self, decorated):
@@ -442,7 +463,7 @@ class TestContextDecorator:
return self.decorate_class(decorated)
elif callable(decorated):
return self.decorate_callable(decorated)
- raise TypeError('Cannot decorate object of type %s' % type(decorated))
+ raise TypeError("Cannot decorate object of type %s" % type(decorated))
class override_settings(TestContextDecorator):
@@ -452,6 +473,7 @@ class override_settings(TestContextDecorator):
with the ``with`` statement. In either event, entering/exiting are called
before and after, respectively, the function/block is executed.
"""
+
enable_exception = None
def __init__(self, **kwargs):
@@ -461,9 +483,9 @@ class override_settings(TestContextDecorator):
def enable(self):
# Keep this code at the beginning to leave the settings unchanged
# in case it raises an exception because INSTALLED_APPS is invalid.
- if 'INSTALLED_APPS' in self.options:
+ if "INSTALLED_APPS" in self.options:
try:
- apps.set_installed_apps(self.options['INSTALLED_APPS'])
+ apps.set_installed_apps(self.options["INSTALLED_APPS"])
except Exception:
apps.unset_installed_apps()
raise
@@ -476,14 +498,16 @@ class override_settings(TestContextDecorator):
try:
setting_changed.send(
sender=settings._wrapped.__class__,
- setting=key, value=new_value, enter=True,
+ setting=key,
+ value=new_value,
+ enter=True,
)
except Exception as exc:
self.enable_exception = exc
self.disable()
def disable(self):
- if 'INSTALLED_APPS' in self.options:
+ if "INSTALLED_APPS" in self.options:
apps.unset_installed_apps()
settings._wrapped = self.wrapped
del self.wrapped
@@ -492,7 +516,9 @@ class override_settings(TestContextDecorator):
new_value = getattr(settings, key, None)
responses_for_setting = setting_changed.send_robust(
sender=settings._wrapped.__class__,
- setting=key, value=new_value, enter=False,
+ setting=key,
+ value=new_value,
+ enter=False,
)
responses.extend(responses_for_setting)
if self.enable_exception is not None:
@@ -515,10 +541,12 @@ class override_settings(TestContextDecorator):
def decorate_class(self, cls):
from django.test import SimpleTestCase
+
if not issubclass(cls, SimpleTestCase):
raise ValueError(
"Only subclasses of Django SimpleTestCase can be decorated "
- "with override_settings")
+ "with override_settings"
+ )
self.save_options(cls)
return cls
@@ -528,6 +556,7 @@ class modify_settings(override_settings):
Like override_settings, but makes it possible to append, prepend, or remove
items instead of redefining the entire list.
"""
+
def __init__(self, *args, **kwargs):
if args:
# Hack used when instantiating from SimpleTestCase.setUpClass.
@@ -543,8 +572,9 @@ class modify_settings(override_settings):
test_func._modified_settings = self.operations
else:
# Duplicate list to prevent subclasses from altering their parent.
- test_func._modified_settings = list(
- test_func._modified_settings) + self.operations
+ test_func._modified_settings = (
+ list(test_func._modified_settings) + self.operations
+ )
def enable(self):
self.options = {}
@@ -559,11 +589,11 @@ class modify_settings(override_settings):
# items my be a single value or an iterable.
if isinstance(items, str):
items = [items]
- if action == 'append':
+ if action == "append":
value = value + [item for item in items if item not in value]
- elif action == 'prepend':
+ elif action == "prepend":
value = [item for item in items if item not in value] + value
- elif action == 'remove':
+ elif action == "remove":
value = [item for item in value if item not in items]
else:
raise ValueError("Unsupported action: %s" % action)
@@ -577,8 +607,10 @@ class override_system_checks(TestContextDecorator):
Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
you also need to exclude its system checks.
"""
+
def __init__(self, new_checks, deployment_checks=None):
from django.core.checks.registry import registry
+
self.registry = registry
self.new_checks = new_checks
self.deployment_checks = deployment_checks
@@ -588,12 +620,12 @@ class override_system_checks(TestContextDecorator):
self.old_checks = self.registry.registered_checks
self.registry.registered_checks = set()
for check in self.new_checks:
- self.registry.register(check, *getattr(check, 'tags', ()))
+ self.registry.register(check, *getattr(check, "tags", ()))
self.old_deployment_checks = self.registry.deployment_checks
if self.deployment_checks is not None:
self.registry.deployment_checks = set()
for check in self.deployment_checks:
- self.registry.register(check, *getattr(check, 'tags', ()), deploy=True)
+ self.registry.register(check, *getattr(check, "tags", ()), deploy=True)
def disable(self):
self.registry.registered_checks = self.old_checks
@@ -609,18 +641,18 @@ def compare_xml(want, got):
Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
"""
- _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
+ _norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+")
def norm_whitespace(v):
- return _norm_whitespace_re.sub(' ', v)
+ return _norm_whitespace_re.sub(" ", v)
def child_text(element):
- return ''.join(c.data for c in element.childNodes
- if c.nodeType == Node.TEXT_NODE)
+ return "".join(
+ c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE
+ )
def children(element):
- return [c for c in element.childNodes
- if c.nodeType == Node.ELEMENT_NODE]
+ return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE]
def norm_child_text(element):
return norm_whitespace(child_text(element))
@@ -639,7 +671,9 @@ def compare_xml(want, got):
got_children = children(got_element)
if len(want_children) != len(got_children):
return False
- return all(check_element(want, got) for want, got in zip(want_children, got_children))
+ return all(
+ check_element(want, got) for want, got in zip(want_children, got_children)
+ )
def first_node(document):
for node in document.childNodes:
@@ -650,13 +684,13 @@ def compare_xml(want, got):
):
return node
- want = want.strip().replace('\\n', '\n')
- got = got.strip().replace('\\n', '\n')
+ want = want.strip().replace("\\n", "\n")
+ got = got.strip().replace("\\n", "\n")
# If the string is not a complete xml document, we may need to add a
# root element. This allow us to compare fragments, like "<foo/><bar/>"
- if not want.startswith('<?xml'):
- wrapper = '<root>%s</root>'
+ if not want.startswith("<?xml"):
+ wrapper = "<root>%s</root>"
want = wrapper % want
got = wrapper % got
@@ -671,6 +705,7 @@ class CaptureQueriesContext:
"""
Context manager that captures queries executed by the specified connection.
"""
+
def __init__(self, connection):
self.connection = connection
@@ -685,7 +720,7 @@ class CaptureQueriesContext:
@property
def captured_queries(self):
- return self.connection.queries[self.initial_queries:self.final_queries]
+ return self.connection.queries[self.initial_queries : self.final_queries]
def __enter__(self):
self.force_debug_cursor = self.connection.force_debug_cursor
@@ -709,7 +744,7 @@ class CaptureQueriesContext:
class ignore_warnings(TestContextDecorator):
def __init__(self, **kwargs):
self.ignore_kwargs = kwargs
- if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs:
+ if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs:
self.filter_func = warnings.filterwarnings
else:
self.filter_func = warnings.simplefilter
@@ -718,7 +753,7 @@ class ignore_warnings(TestContextDecorator):
def enable(self):
self.catch_warnings = warnings.catch_warnings()
self.catch_warnings.__enter__()
- self.filter_func('ignore', **self.ignore_kwargs)
+ self.filter_func("ignore", **self.ignore_kwargs)
def disable(self):
self.catch_warnings.__exit__(*sys.exc_info())
@@ -732,7 +767,7 @@ class ignore_warnings(TestContextDecorator):
requires_tz_support = skipUnless(
TZ_SUPPORT,
"This test relies on the ability to run a program in an arbitrary "
- "time zone, but your operating system isn't able to do that."
+ "time zone, but your operating system isn't able to do that.",
)
@@ -775,9 +810,9 @@ def captured_output(stream_name):
def captured_stdout():
"""Capture the output of sys.stdout:
- with captured_stdout() as stdout:
- print("hello")
- self.assertEqual(stdout.getvalue(), "hello\n")
+ with captured_stdout() as stdout:
+ print("hello")
+ self.assertEqual(stdout.getvalue(), "hello\n")
"""
return captured_output("stdout")
@@ -785,9 +820,9 @@ def captured_stdout():
def captured_stderr():
"""Capture the output of sys.stderr:
- with captured_stderr() as stderr:
- print("hello", file=sys.stderr)
- self.assertEqual(stderr.getvalue(), "hello\n")
+ with captured_stderr() as stderr:
+ print("hello", file=sys.stderr)
+ self.assertEqual(stderr.getvalue(), "hello\n")
"""
return captured_output("stderr")
@@ -795,12 +830,12 @@ def captured_stderr():
def captured_stdin():
"""Capture the input to sys.stdin:
- with captured_stdin() as stdin:
- stdin.write('hello\n')
- stdin.seek(0)
- # call test code that consumes from sys.stdin
- captured = input()
- self.assertEqual(captured, "hello")
+ with captured_stdin() as stdin:
+ stdin.write('hello\n')
+ stdin.seek(0)
+ # call test code that consumes from sys.stdin
+ captured = input()
+ self.assertEqual(captured, "hello")
"""
return captured_output("stdin")
@@ -828,18 +863,24 @@ def require_jinja2(test_func):
Django template engine for a test or skip it if Jinja2 isn't available.
"""
test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
- return override_settings(TEMPLATES=[{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'APP_DIRS': True,
- }, {
- 'BACKEND': 'django.template.backends.jinja2.Jinja2',
- 'APP_DIRS': True,
- 'OPTIONS': {'keep_trailing_newline': True},
- }])(test_func)
+ return override_settings(
+ TEMPLATES=[
+ {
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "APP_DIRS": True,
+ },
+ {
+ "BACKEND": "django.template.backends.jinja2.Jinja2",
+ "APP_DIRS": True,
+ "OPTIONS": {"keep_trailing_newline": True},
+ },
+ ]
+ )(test_func)
class override_script_prefix(TestContextDecorator):
"""Decorator or context manager to temporary override the script prefix."""
+
def __init__(self, prefix):
self.prefix = prefix
super().__init__()
@@ -857,8 +898,9 @@ class LoggingCaptureMixin:
Capture the output from the 'django' logger and store it on the class's
logger_output attribute.
"""
+
def setUp(self):
- self.logger = logging.getLogger('django')
+ self.logger = logging.getLogger("django")
self.old_stream = self.logger.handlers[0].stream
self.logger_output = StringIO()
self.logger.handlers[0].stream = self.logger_output
@@ -883,6 +925,7 @@ class isolate_apps(TestContextDecorator):
`kwarg_name`: keyword argument passing the isolated registry if used as a
function decorator.
"""
+
def __init__(self, *installed_apps, **kwargs):
self.installed_apps = installed_apps
super().__init__(**kwargs)
@@ -890,11 +933,11 @@ class isolate_apps(TestContextDecorator):
def enable(self):
self.old_apps = Options.default_apps
apps = Apps(self.installed_apps)
- setattr(Options, 'default_apps', apps)
+ setattr(Options, "default_apps", apps)
return apps
def disable(self):
- setattr(Options, 'default_apps', self.old_apps)
+ setattr(Options, "default_apps", self.old_apps)
class TimeKeeper:
@@ -914,7 +957,7 @@ class TimeKeeper:
def print_results(self):
for name, end_times in self.records.items():
for record_time in end_times:
- record = '%s took %.3fs' % (name, record_time)
+ record = "%s took %.3fs" % (name, record_time)
sys.stderr.write(record + os.linesep)
@@ -929,12 +972,14 @@ class NullTimeKeeper:
def tag(*tags):
"""Decorator to add tags to a test class or method."""
+
def decorator(obj):
- if hasattr(obj, 'tags'):
+ if hasattr(obj, "tags"):
obj.tags = obj.tags.union(tags)
else:
- setattr(obj, 'tags', set(tags))
+ setattr(obj, "tags", set(tags))
return obj
+
return decorator