diff options
Diffstat (limited to 'django/test/utils.py')
| -rw-r--r-- | django/test/utils.py | 221 |
1 files changed, 133 insertions, 88 deletions
diff --git a/django/test/utils.py b/django/test/utils.py index 6c2f566909..ac0fc34b08 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -35,15 +35,24 @@ except ImportError: __all__ = ( - 'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner', - 'CaptureQueriesContext', - 'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings', - 'override_system_checks', 'tag', - 'requires_tz_support', - 'setup_databases', 'setup_test_environment', 'teardown_test_environment', + "Approximate", + "ContextList", + "isolate_lru_cache", + "get_runner", + "CaptureQueriesContext", + "ignore_warnings", + "isolate_apps", + "modify_settings", + "override_settings", + "override_system_checks", + "tag", + "requires_tz_support", + "setup_databases", + "setup_test_environment", + "teardown_test_environment", ) -TZ_SUPPORT = hasattr(time, 'tzset') +TZ_SUPPORT = hasattr(time, "tzset") class Approximate: @@ -63,6 +72,7 @@ class ContextList(list): A wrapper that provides direct key access to context items contained in a list of context objects. """ + def __getitem__(self, key): if isinstance(key, str): for subcontext in self: @@ -110,7 +120,7 @@ def setup_test_environment(debug=None): Perform global pre-test setup, such as installing the instrumented template renderer and setting the email backend to the locmem email backend. """ - if hasattr(_TestState, 'saved_data'): + if hasattr(_TestState, "saved_data"): # Executing this function twice would overwrite the saved values. raise RuntimeError( "setup_test_environment() was already called and can't be called " @@ -125,13 +135,13 @@ def setup_test_environment(debug=None): saved_data.allowed_hosts = settings.ALLOWED_HOSTS # Add the default host of the test client. - settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver'] + settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"] saved_data.debug = settings.DEBUG settings.DEBUG = debug saved_data.email_backend = settings.EMAIL_BACKEND - settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend' + settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend" saved_data.template_render = Template._render Template._render = instrumented_test_render @@ -191,18 +201,17 @@ def setup_databases( # replace with: # serialize_alias = serialized_aliases is None or alias in serialized_aliases try: - serialize_alias = connection.settings_dict['TEST']['SERIALIZE'] + serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"] except KeyError: serialize_alias = ( - serialized_aliases is None or - alias in serialized_aliases + serialized_aliases is None or alias in serialized_aliases ) else: warnings.warn( - 'The SERIALIZE test database setting is ' - 'deprecated as it can be inferred from the ' - 'TestCase/TransactionTestCase.databases that ' - 'enable the serialized_rollback feature.', + "The SERIALIZE test database setting is " + "deprecated as it can be inferred from the " + "TestCase/TransactionTestCase.databases that " + "enable the serialized_rollback feature.", category=RemovedInDjango50Warning, ) connection.creation.create_test_db( @@ -221,12 +230,15 @@ def setup_databases( ) # Configure all other connections as mirrors of the first one else: - connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict) + connections[alias].creation.set_as_test_mirror( + connections[first_alias].settings_dict + ) # Configure the test mirrors. for alias, mirror_alias in mirrored_aliases.items(): connections[alias].creation.set_as_test_mirror( - connections[mirror_alias].settings_dict) + connections[mirror_alias].settings_dict + ) if debug_sql: for alias in connections: @@ -246,8 +258,8 @@ def iter_test_cases(tests): # Prevent an unfriendly RecursionError that can happen with # strings. raise TypeError( - f'Test {test!r} must be a test case or test suite not string ' - f'(was found in {tests!r}).' + f"Test {test!r} must be a test case or test suite not string " + f"(was found in {tests!r})." ) if isinstance(test, TestCase): yield test @@ -319,18 +331,18 @@ def get_unique_databases_and_mirrors(aliases=None): for alias in connections: connection = connections[alias] - test_settings = connection.settings_dict['TEST'] + test_settings = connection.settings_dict["TEST"] - if test_settings['MIRROR']: + if test_settings["MIRROR"]: # If the database is marked as a test mirror, save the alias. - mirrored_aliases[alias] = test_settings['MIRROR'] + mirrored_aliases[alias] = test_settings["MIRROR"] elif alias in aliases: # Store a tuple with DB parameters that uniquely identify it. # If we have two aliases with the same values for that tuple, # we only need to create the test database once. item = test_databases.setdefault( connection.creation.test_db_signature(), - (connection.settings_dict['NAME'], []), + (connection.settings_dict["NAME"], []), ) # The default database must be the first because data migrations # use the default alias by default. @@ -339,11 +351,16 @@ def get_unique_databases_and_mirrors(aliases=None): else: item[1].append(alias) - if 'DEPENDENCIES' in test_settings: - dependencies[alias] = test_settings['DEPENDENCIES'] + if "DEPENDENCIES" in test_settings: + dependencies[alias] = test_settings["DEPENDENCIES"] else: - if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig: - dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS]) + if ( + alias != DEFAULT_DB_ALIAS + and connection.creation.test_db_signature() != default_sig + ): + dependencies[alias] = test_settings.get( + "DEPENDENCIES", [DEFAULT_DB_ALIAS] + ) test_databases = dict(dependency_ordered(test_databases.items(), dependencies)) return test_databases, mirrored_aliases @@ -365,12 +382,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False): def get_runner(settings, test_runner_class=None): test_runner_class = test_runner_class or settings.TEST_RUNNER - test_path = test_runner_class.split('.') + test_path = test_runner_class.split(".") # Allow for relative paths if len(test_path) > 1: - test_module_name = '.'.join(test_path[:-1]) + test_module_name = ".".join(test_path[:-1]) else: - test_module_name = '.' + test_module_name = "." test_module = __import__(test_module_name, {}, {}, test_path[-1]) return getattr(test_module, test_path[-1]) @@ -387,6 +404,7 @@ class TestContextDecorator: `kwarg_name`: keyword argument passing the return value of enable() if used as a function decorator. """ + def __init__(self, attr_name=None, kwarg_name=None): self.attr_name = attr_name self.kwarg_name = kwarg_name @@ -416,7 +434,7 @@ class TestContextDecorator: cls.setUp = setUp return cls - raise TypeError('Can only decorate subclasses of unittest.TestCase') + raise TypeError("Can only decorate subclasses of unittest.TestCase") def decorate_callable(self, func): if asyncio.iscoroutinefunction(func): @@ -428,13 +446,16 @@ class TestContextDecorator: if self.kwarg_name: kwargs[self.kwarg_name] = context return await func(*args, **kwargs) + else: + @wraps(func) def inner(*args, **kwargs): with self as context: if self.kwarg_name: kwargs[self.kwarg_name] = context return func(*args, **kwargs) + return inner def __call__(self, decorated): @@ -442,7 +463,7 @@ class TestContextDecorator: return self.decorate_class(decorated) elif callable(decorated): return self.decorate_callable(decorated) - raise TypeError('Cannot decorate object of type %s' % type(decorated)) + raise TypeError("Cannot decorate object of type %s" % type(decorated)) class override_settings(TestContextDecorator): @@ -452,6 +473,7 @@ class override_settings(TestContextDecorator): with the ``with`` statement. In either event, entering/exiting are called before and after, respectively, the function/block is executed. """ + enable_exception = None def __init__(self, **kwargs): @@ -461,9 +483,9 @@ class override_settings(TestContextDecorator): def enable(self): # Keep this code at the beginning to leave the settings unchanged # in case it raises an exception because INSTALLED_APPS is invalid. - if 'INSTALLED_APPS' in self.options: + if "INSTALLED_APPS" in self.options: try: - apps.set_installed_apps(self.options['INSTALLED_APPS']) + apps.set_installed_apps(self.options["INSTALLED_APPS"]) except Exception: apps.unset_installed_apps() raise @@ -476,14 +498,16 @@ class override_settings(TestContextDecorator): try: setting_changed.send( sender=settings._wrapped.__class__, - setting=key, value=new_value, enter=True, + setting=key, + value=new_value, + enter=True, ) except Exception as exc: self.enable_exception = exc self.disable() def disable(self): - if 'INSTALLED_APPS' in self.options: + if "INSTALLED_APPS" in self.options: apps.unset_installed_apps() settings._wrapped = self.wrapped del self.wrapped @@ -492,7 +516,9 @@ class override_settings(TestContextDecorator): new_value = getattr(settings, key, None) responses_for_setting = setting_changed.send_robust( sender=settings._wrapped.__class__, - setting=key, value=new_value, enter=False, + setting=key, + value=new_value, + enter=False, ) responses.extend(responses_for_setting) if self.enable_exception is not None: @@ -515,10 +541,12 @@ class override_settings(TestContextDecorator): def decorate_class(self, cls): from django.test import SimpleTestCase + if not issubclass(cls, SimpleTestCase): raise ValueError( "Only subclasses of Django SimpleTestCase can be decorated " - "with override_settings") + "with override_settings" + ) self.save_options(cls) return cls @@ -528,6 +556,7 @@ class modify_settings(override_settings): Like override_settings, but makes it possible to append, prepend, or remove items instead of redefining the entire list. """ + def __init__(self, *args, **kwargs): if args: # Hack used when instantiating from SimpleTestCase.setUpClass. @@ -543,8 +572,9 @@ class modify_settings(override_settings): test_func._modified_settings = self.operations else: # Duplicate list to prevent subclasses from altering their parent. - test_func._modified_settings = list( - test_func._modified_settings) + self.operations + test_func._modified_settings = ( + list(test_func._modified_settings) + self.operations + ) def enable(self): self.options = {} @@ -559,11 +589,11 @@ class modify_settings(override_settings): # items my be a single value or an iterable. if isinstance(items, str): items = [items] - if action == 'append': + if action == "append": value = value + [item for item in items if item not in value] - elif action == 'prepend': + elif action == "prepend": value = [item for item in items if item not in value] + value - elif action == 'remove': + elif action == "remove": value = [item for item in value if item not in items] else: raise ValueError("Unsupported action: %s" % action) @@ -577,8 +607,10 @@ class override_system_checks(TestContextDecorator): Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app, you also need to exclude its system checks. """ + def __init__(self, new_checks, deployment_checks=None): from django.core.checks.registry import registry + self.registry = registry self.new_checks = new_checks self.deployment_checks = deployment_checks @@ -588,12 +620,12 @@ class override_system_checks(TestContextDecorator): self.old_checks = self.registry.registered_checks self.registry.registered_checks = set() for check in self.new_checks: - self.registry.register(check, *getattr(check, 'tags', ())) + self.registry.register(check, *getattr(check, "tags", ())) self.old_deployment_checks = self.registry.deployment_checks if self.deployment_checks is not None: self.registry.deployment_checks = set() for check in self.deployment_checks: - self.registry.register(check, *getattr(check, 'tags', ()), deploy=True) + self.registry.register(check, *getattr(check, "tags", ()), deploy=True) def disable(self): self.registry.registered_checks = self.old_checks @@ -609,18 +641,18 @@ def compare_xml(want, got): Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py """ - _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+') + _norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+") def norm_whitespace(v): - return _norm_whitespace_re.sub(' ', v) + return _norm_whitespace_re.sub(" ", v) def child_text(element): - return ''.join(c.data for c in element.childNodes - if c.nodeType == Node.TEXT_NODE) + return "".join( + c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE + ) def children(element): - return [c for c in element.childNodes - if c.nodeType == Node.ELEMENT_NODE] + return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE] def norm_child_text(element): return norm_whitespace(child_text(element)) @@ -639,7 +671,9 @@ def compare_xml(want, got): got_children = children(got_element) if len(want_children) != len(got_children): return False - return all(check_element(want, got) for want, got in zip(want_children, got_children)) + return all( + check_element(want, got) for want, got in zip(want_children, got_children) + ) def first_node(document): for node in document.childNodes: @@ -650,13 +684,13 @@ def compare_xml(want, got): ): return node - want = want.strip().replace('\\n', '\n') - got = got.strip().replace('\\n', '\n') + want = want.strip().replace("\\n", "\n") + got = got.strip().replace("\\n", "\n") # If the string is not a complete xml document, we may need to add a # root element. This allow us to compare fragments, like "<foo/><bar/>" - if not want.startswith('<?xml'): - wrapper = '<root>%s</root>' + if not want.startswith("<?xml"): + wrapper = "<root>%s</root>" want = wrapper % want got = wrapper % got @@ -671,6 +705,7 @@ class CaptureQueriesContext: """ Context manager that captures queries executed by the specified connection. """ + def __init__(self, connection): self.connection = connection @@ -685,7 +720,7 @@ class CaptureQueriesContext: @property def captured_queries(self): - return self.connection.queries[self.initial_queries:self.final_queries] + return self.connection.queries[self.initial_queries : self.final_queries] def __enter__(self): self.force_debug_cursor = self.connection.force_debug_cursor @@ -709,7 +744,7 @@ class CaptureQueriesContext: class ignore_warnings(TestContextDecorator): def __init__(self, **kwargs): self.ignore_kwargs = kwargs - if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs: + if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs: self.filter_func = warnings.filterwarnings else: self.filter_func = warnings.simplefilter @@ -718,7 +753,7 @@ class ignore_warnings(TestContextDecorator): def enable(self): self.catch_warnings = warnings.catch_warnings() self.catch_warnings.__enter__() - self.filter_func('ignore', **self.ignore_kwargs) + self.filter_func("ignore", **self.ignore_kwargs) def disable(self): self.catch_warnings.__exit__(*sys.exc_info()) @@ -732,7 +767,7 @@ class ignore_warnings(TestContextDecorator): requires_tz_support = skipUnless( TZ_SUPPORT, "This test relies on the ability to run a program in an arbitrary " - "time zone, but your operating system isn't able to do that." + "time zone, but your operating system isn't able to do that.", ) @@ -775,9 +810,9 @@ def captured_output(stream_name): def captured_stdout(): """Capture the output of sys.stdout: - with captured_stdout() as stdout: - print("hello") - self.assertEqual(stdout.getvalue(), "hello\n") + with captured_stdout() as stdout: + print("hello") + self.assertEqual(stdout.getvalue(), "hello\n") """ return captured_output("stdout") @@ -785,9 +820,9 @@ def captured_stdout(): def captured_stderr(): """Capture the output of sys.stderr: - with captured_stderr() as stderr: - print("hello", file=sys.stderr) - self.assertEqual(stderr.getvalue(), "hello\n") + with captured_stderr() as stderr: + print("hello", file=sys.stderr) + self.assertEqual(stderr.getvalue(), "hello\n") """ return captured_output("stderr") @@ -795,12 +830,12 @@ def captured_stderr(): def captured_stdin(): """Capture the input to sys.stdin: - with captured_stdin() as stdin: - stdin.write('hello\n') - stdin.seek(0) - # call test code that consumes from sys.stdin - captured = input() - self.assertEqual(captured, "hello") + with captured_stdin() as stdin: + stdin.write('hello\n') + stdin.seek(0) + # call test code that consumes from sys.stdin + captured = input() + self.assertEqual(captured, "hello") """ return captured_output("stdin") @@ -828,18 +863,24 @@ def require_jinja2(test_func): Django template engine for a test or skip it if Jinja2 isn't available. """ test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func) - return override_settings(TEMPLATES=[{ - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'APP_DIRS': True, - }, { - 'BACKEND': 'django.template.backends.jinja2.Jinja2', - 'APP_DIRS': True, - 'OPTIONS': {'keep_trailing_newline': True}, - }])(test_func) + return override_settings( + TEMPLATES=[ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + }, + { + "BACKEND": "django.template.backends.jinja2.Jinja2", + "APP_DIRS": True, + "OPTIONS": {"keep_trailing_newline": True}, + }, + ] + )(test_func) class override_script_prefix(TestContextDecorator): """Decorator or context manager to temporary override the script prefix.""" + def __init__(self, prefix): self.prefix = prefix super().__init__() @@ -857,8 +898,9 @@ class LoggingCaptureMixin: Capture the output from the 'django' logger and store it on the class's logger_output attribute. """ + def setUp(self): - self.logger = logging.getLogger('django') + self.logger = logging.getLogger("django") self.old_stream = self.logger.handlers[0].stream self.logger_output = StringIO() self.logger.handlers[0].stream = self.logger_output @@ -883,6 +925,7 @@ class isolate_apps(TestContextDecorator): `kwarg_name`: keyword argument passing the isolated registry if used as a function decorator. """ + def __init__(self, *installed_apps, **kwargs): self.installed_apps = installed_apps super().__init__(**kwargs) @@ -890,11 +933,11 @@ class isolate_apps(TestContextDecorator): def enable(self): self.old_apps = Options.default_apps apps = Apps(self.installed_apps) - setattr(Options, 'default_apps', apps) + setattr(Options, "default_apps", apps) return apps def disable(self): - setattr(Options, 'default_apps', self.old_apps) + setattr(Options, "default_apps", self.old_apps) class TimeKeeper: @@ -914,7 +957,7 @@ class TimeKeeper: def print_results(self): for name, end_times in self.records.items(): for record_time in end_times: - record = '%s took %.3fs' % (name, record_time) + record = "%s took %.3fs" % (name, record_time) sys.stderr.write(record + os.linesep) @@ -929,12 +972,14 @@ class NullTimeKeeper: def tag(*tags): """Decorator to add tags to a test class or method.""" + def decorator(obj): - if hasattr(obj, 'tags'): + if hasattr(obj, "tags"): obj.tags = obj.tags.union(tags) else: - setattr(obj, 'tags', set(tags)) + setattr(obj, "tags", set(tags)) return obj + return decorator |
