diff options
| author | django-bot <ops@djangoproject.com> | 2022-02-03 20:24:19 +0100 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-02-07 20:37:05 +0100 |
| commit | 9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch) | |
| tree | f0506b668a013d0063e5fba3dbf4863b466713ba /django/test | |
| parent | f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff) | |
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/test')
| -rw-r--r-- | django/test/__init__.py | 41 | ||||
| -rw-r--r-- | django/test/client.py | 578 | ||||
| -rw-r--r-- | django/test/html.py | 106 | ||||
| -rw-r--r-- | django/test/runner.py | 343 | ||||
| -rw-r--r-- | django/test/selenium.py | 32 | ||||
| -rw-r--r-- | django/test/signals.py | 76 | ||||
| -rw-r--r-- | django/test/testcases.py | 632 | ||||
| -rw-r--r-- | django/test/utils.py | 221 |
8 files changed, 1286 insertions, 743 deletions
diff --git a/django/test/__init__.py b/django/test/__init__.py index d1f953a8dd..485298e8e7 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -1,21 +1,38 @@ """Django Unit Test framework.""" -from django.test.client import ( - AsyncClient, AsyncRequestFactory, Client, RequestFactory, -) +from django.test.client import AsyncClient, AsyncRequestFactory, Client, RequestFactory from django.test.testcases import ( - LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase, - skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature, + LiveServerTestCase, + SimpleTestCase, + TestCase, + TransactionTestCase, + skipIfDBFeature, + skipUnlessAnyDBFeature, + skipUnlessDBFeature, ) from django.test.utils import ( - ignore_warnings, modify_settings, override_settings, - override_system_checks, tag, + ignore_warnings, + modify_settings, + override_settings, + override_system_checks, + tag, ) __all__ = [ - 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory', - 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase', - 'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', - 'ignore_warnings', 'modify_settings', 'override_settings', - 'override_system_checks', 'tag', + "AsyncClient", + "AsyncRequestFactory", + "Client", + "RequestFactory", + "TestCase", + "TransactionTestCase", + "SimpleTestCase", + "LiveServerTestCase", + "skipIfDBFeature", + "skipUnlessAnyDBFeature", + "skipUnlessDBFeature", + "ignore_warnings", + "modify_settings", + "override_settings", + "override_system_checks", + "tag", ] diff --git a/django/test/client.py b/django/test/client.py index af1090a740..a38e7dae13 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -16,9 +16,7 @@ from django.core.handlers.asgi import ASGIRequest from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.core.serializers.json import DjangoJSONEncoder -from django.core.signals import ( - got_request_exception, request_finished, request_started, -) +from django.core.signals import got_request_exception, request_finished, request_started from django.db import close_old_connections from django.http import HttpRequest, QueryDict, SimpleCookie from django.test import signals @@ -31,20 +29,26 @@ from django.utils.itercompat import is_iterable from django.utils.regex_helper import _lazy_re_compile __all__ = ( - 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RedirectCycleError', - 'RequestFactory', 'encode_file', 'encode_multipart', + "AsyncClient", + "AsyncRequestFactory", + "Client", + "RedirectCycleError", + "RequestFactory", + "encode_file", + "encode_multipart", ) -BOUNDARY = 'BoUnDaRyStRiNg' -MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY -CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w-]+);?') +BOUNDARY = "BoUnDaRyStRiNg" +MULTIPART_CONTENT = "multipart/form-data; boundary=%s" % BOUNDARY +CONTENT_TYPE_RE = _lazy_re_compile(r".*; charset=([\w-]+);?") # Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8 -JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json') +JSON_CONTENT_TYPE_RE = _lazy_re_compile(r"^application\/(.+\+)?json") class RedirectCycleError(Exception): """The test client has been asked to follow a redirect loop.""" + def __init__(self, message, last_response): super().__init__(message) self.last_response = last_response @@ -58,6 +62,7 @@ class FakePayload: length. This makes sure that views can't do anything under the test client that wouldn't work in real life. """ + def __init__(self, content=None): self.__content = BytesIO() self.__len = 0 @@ -74,7 +79,9 @@ class FakePayload: self.read_started = True if num_bytes is None: num_bytes = self.__len or 0 - assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data." + assert ( + self.__len >= num_bytes + ), "Cannot read more than the available bytes from the HTTP incoming data." content = self.__content.read(num_bytes) self.__len -= num_bytes return content @@ -92,7 +99,7 @@ def closing_iterator_wrapper(iterable, close): yield from iterable finally: request_finished.disconnect(close_old_connections) - close() # will fire request_finished + close() # will fire request_finished request_finished.connect(close_old_connections) @@ -106,12 +113,12 @@ def conditional_content_removal(request, response): if response.streaming: response.streaming_content = [] else: - response.content = b'' - if request.method == 'HEAD': + response.content = b"" + if request.method == "HEAD": if response.streaming: response.streaming_content = [] else: - response.content = b'' + response.content = b"" return response @@ -121,6 +128,7 @@ class ClientHandler(BaseHandler): interface to compose requests, but return the raw HttpResponse object with the originating WSGIRequest attached to its ``wsgi_request`` attribute. """ + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): self.enforce_csrf_checks = enforce_csrf_checks super().__init__(*args, **kwargs) @@ -154,10 +162,11 @@ class ClientHandler(BaseHandler): # Emulate a WSGI server by calling the close method on completion. if response.streaming: response.streaming_content = closing_iterator_wrapper( - response.streaming_content, response.close) + response.streaming_content, response.close + ) else: request_finished.disconnect(close_old_connections) - response.close() # will fire request_finished + response.close() # will fire request_finished request_finished.connect(close_old_connections) return response @@ -165,6 +174,7 @@ class ClientHandler(BaseHandler): class AsyncClientHandler(BaseHandler): """An async version of ClientHandler.""" + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): self.enforce_csrf_checks = enforce_csrf_checks super().__init__(*args, **kwargs) @@ -175,13 +185,15 @@ class AsyncClientHandler(BaseHandler): if self._middleware_chain is None: self.load_middleware(is_async=True) # Extract body file from the scope, if provided. - if '_body_file' in scope: - body_file = scope.pop('_body_file') + if "_body_file" in scope: + body_file = scope.pop("_body_file") else: - body_file = FakePayload('') + body_file = FakePayload("") request_started.disconnect(close_old_connections) - await sync_to_async(request_started.send, thread_sensitive=False)(sender=self.__class__, scope=scope) + await sync_to_async(request_started.send, thread_sensitive=False)( + sender=self.__class__, scope=scope + ) request_started.connect(close_old_connections) request = ASGIRequest(scope, body_file) # Sneaky little hack so that we can easily get round @@ -197,7 +209,9 @@ class AsyncClientHandler(BaseHandler): response.asgi_request = request # Emulate a server by calling the close method on completion. if response.streaming: - response.streaming_content = await sync_to_async(closing_iterator_wrapper, thread_sensitive=False)( + response.streaming_content = await sync_to_async( + closing_iterator_wrapper, thread_sensitive=False + )( response.streaming_content, response.close, ) @@ -216,10 +230,10 @@ def store_rendered_templates(store, signal, sender, template, context, **kwargs) The context is copied so that it is an accurate representation at the time of rendering. """ - store.setdefault('templates', []).append(template) - if 'context' not in store: - store['context'] = ContextList() - store['context'].append(copy(context)) + store.setdefault("templates", []).append(template) + if "context" not in store: + store["context"] = ContextList() + store["context"].append(copy(context)) def encode_multipart(boundary, data): @@ -255,25 +269,33 @@ def encode_multipart(boundary, data): if is_file(item): lines.extend(encode_file(boundary, key, item)) else: - lines.extend(to_bytes(val) for val in [ - '--%s' % boundary, - 'Content-Disposition: form-data; name="%s"' % key, - '', - item - ]) + lines.extend( + to_bytes(val) + for val in [ + "--%s" % boundary, + 'Content-Disposition: form-data; name="%s"' % key, + "", + item, + ] + ) else: - lines.extend(to_bytes(val) for val in [ - '--%s' % boundary, - 'Content-Disposition: form-data; name="%s"' % key, - '', - value - ]) + lines.extend( + to_bytes(val) + for val in [ + "--%s" % boundary, + 'Content-Disposition: form-data; name="%s"' % key, + "", + value, + ] + ) - lines.extend([ - to_bytes('--%s--' % boundary), - b'', - ]) - return b'\r\n'.join(lines) + lines.extend( + [ + to_bytes("--%s--" % boundary), + b"", + ] + ) + return b"\r\n".join(lines) def encode_file(boundary, key, file): @@ -282,10 +304,10 @@ def encode_file(boundary, key, file): # file.name might not be a string. For example, it's an int for # tempfile.TemporaryFile(). - file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str) - filename = os.path.basename(file.name) if file_has_string_name else '' + file_has_string_name = hasattr(file, "name") and isinstance(file.name, str) + filename = os.path.basename(file.name) if file_has_string_name else "" - if hasattr(file, 'content_type'): + if hasattr(file, "content_type"): content_type = file.content_type elif filename: content_type = mimetypes.guess_type(filename)[0] @@ -293,15 +315,16 @@ def encode_file(boundary, key, file): content_type = None if content_type is None: - content_type = 'application/octet-stream' + content_type = "application/octet-stream" filename = filename or key return [ - to_bytes('--%s' % boundary), - to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"' - % (key, filename)), - to_bytes('Content-Type: %s' % content_type), - b'', - to_bytes(file.read()) + to_bytes("--%s" % boundary), + to_bytes( + 'Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename) + ), + to_bytes("Content-Type: %s" % content_type), + b"", + to_bytes(file.read()), ] @@ -318,6 +341,7 @@ class RequestFactory: Once you have a request object you can pass it to any view function, just as if that view had been hooked up using a URLconf. """ + def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults): self.json_encoder = json_encoder self.defaults = defaults @@ -333,24 +357,26 @@ class RequestFactory: # - REMOTE_ADDR: often useful, see #8551. # See https://www.python.org/dev/peps/pep-3333/#environ-variables return { - 'HTTP_COOKIE': '; '.join(sorted( - '%s=%s' % (morsel.key, morsel.coded_value) - for morsel in self.cookies.values() - )), - 'PATH_INFO': '/', - 'REMOTE_ADDR': '127.0.0.1', - 'REQUEST_METHOD': 'GET', - 'SCRIPT_NAME': '', - 'SERVER_NAME': 'testserver', - 'SERVER_PORT': '80', - 'SERVER_PROTOCOL': 'HTTP/1.1', - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': 'http', - 'wsgi.input': FakePayload(b''), - 'wsgi.errors': self.errors, - 'wsgi.multiprocess': True, - 'wsgi.multithread': False, - 'wsgi.run_once': False, + "HTTP_COOKIE": "; ".join( + sorted( + "%s=%s" % (morsel.key, morsel.coded_value) + for morsel in self.cookies.values() + ) + ), + "PATH_INFO": "/", + "REMOTE_ADDR": "127.0.0.1", + "REQUEST_METHOD": "GET", + "SCRIPT_NAME": "", + "SERVER_NAME": "testserver", + "SERVER_PORT": "80", + "SERVER_PROTOCOL": "HTTP/1.1", + "wsgi.version": (1, 0), + "wsgi.url_scheme": "http", + "wsgi.input": FakePayload(b""), + "wsgi.errors": self.errors, + "wsgi.multiprocess": True, + "wsgi.multithread": False, + "wsgi.run_once": False, **self.defaults, **request, } @@ -376,7 +402,9 @@ class RequestFactory: Return encoded JSON if data is a dict, list, or tuple and content_type is application/json. """ - should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple)) + should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance( + data, (dict, list, tuple) + ) return json.dumps(data, cls=self.json_encoder) if should_encode else data def _get_path(self, parsed): @@ -388,88 +416,128 @@ class RequestFactory: # Replace the behavior where non-ASCII values in the WSGI environ are # arbitrarily decoded with ISO-8859-1. # Refs comment in `get_bytes_from_wsgi()`. - return path.decode('iso-8859-1') + return path.decode("iso-8859-1") def get(self, path, data=None, secure=False, **extra): """Construct a GET request.""" data = {} if data is None else data - return self.generic('GET', path, secure=secure, **{ - 'QUERY_STRING': urlencode(data, doseq=True), - **extra, - }) + return self.generic( + "GET", + path, + secure=secure, + **{ + "QUERY_STRING": urlencode(data, doseq=True), + **extra, + }, + ) - def post(self, path, data=None, content_type=MULTIPART_CONTENT, - secure=False, **extra): + def post( + self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra + ): """Construct a POST request.""" data = self._encode_json({} if data is None else data, content_type) post_data = self._encode_data(data, content_type) - return self.generic('POST', path, post_data, content_type, - secure=secure, **extra) + return self.generic( + "POST", path, post_data, content_type, secure=secure, **extra + ) def head(self, path, data=None, secure=False, **extra): """Construct a HEAD request.""" data = {} if data is None else data - return self.generic('HEAD', path, secure=secure, **{ - 'QUERY_STRING': urlencode(data, doseq=True), - **extra, - }) + return self.generic( + "HEAD", + path, + secure=secure, + **{ + "QUERY_STRING": urlencode(data, doseq=True), + **extra, + }, + ) def trace(self, path, secure=False, **extra): """Construct a TRACE request.""" - return self.generic('TRACE', path, secure=secure, **extra) + return self.generic("TRACE", path, secure=secure, **extra) - def options(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def options( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): "Construct an OPTIONS request." - return self.generic('OPTIONS', path, data, content_type, - secure=secure, **extra) + return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra) - def put(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def put( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct a PUT request.""" data = self._encode_json(data, content_type) - return self.generic('PUT', path, data, content_type, - secure=secure, **extra) + return self.generic("PUT", path, data, content_type, secure=secure, **extra) - def patch(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def patch( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct a PATCH request.""" data = self._encode_json(data, content_type) - return self.generic('PATCH', path, data, content_type, - secure=secure, **extra) + return self.generic("PATCH", path, data, content_type, secure=secure, **extra) - def delete(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def delete( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct a DELETE request.""" data = self._encode_json(data, content_type) - return self.generic('DELETE', path, data, content_type, - secure=secure, **extra) + return self.generic("DELETE", path, data, content_type, secure=secure, **extra) - def generic(self, method, path, data='', - content_type='application/octet-stream', secure=False, - **extra): + def generic( + self, + method, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct an arbitrary HTTP request.""" parsed = urlparse(str(path)) # path can be lazy data = force_bytes(data, settings.DEFAULT_CHARSET) r = { - 'PATH_INFO': self._get_path(parsed), - 'REQUEST_METHOD': method, - 'SERVER_PORT': '443' if secure else '80', - 'wsgi.url_scheme': 'https' if secure else 'http', + "PATH_INFO": self._get_path(parsed), + "REQUEST_METHOD": method, + "SERVER_PORT": "443" if secure else "80", + "wsgi.url_scheme": "https" if secure else "http", } if data: - r.update({ - 'CONTENT_LENGTH': str(len(data)), - 'CONTENT_TYPE': content_type, - 'wsgi.input': FakePayload(data), - }) + r.update( + { + "CONTENT_LENGTH": str(len(data)), + "CONTENT_TYPE": content_type, + "wsgi.input": FakePayload(data), + } + ) r.update(extra) # If QUERY_STRING is absent or empty, we want to extract it from the URL. - if not r.get('QUERY_STRING'): + if not r.get("QUERY_STRING"): # WSGI requires latin-1 encoded strings. See get_path_info(). - query_string = parsed[4].encode().decode('iso-8859-1') - r['QUERY_STRING'] = query_string + query_string = parsed[4].encode().decode("iso-8859-1") + r["QUERY_STRING"] = query_string return self.request(**r) @@ -487,30 +555,35 @@ class AsyncRequestFactory(RequestFactory): a) this makes ASGIRequest subclasses, and b) AsyncTestClient can subclass it. """ + def _base_scope(self, **request): """The base scope for a request.""" # This is a minimal valid ASGI scope, plus: # - headers['cookie'] for cookie support, # - 'client' often useful, see #8551. scope = { - 'asgi': {'version': '3.0'}, - 'type': 'http', - 'http_version': '1.1', - 'client': ['127.0.0.1', 0], - 'server': ('testserver', '80'), - 'scheme': 'http', - 'method': 'GET', - 'headers': [], + "asgi": {"version": "3.0"}, + "type": "http", + "http_version": "1.1", + "client": ["127.0.0.1", 0], + "server": ("testserver", "80"), + "scheme": "http", + "method": "GET", + "headers": [], **self.defaults, **request, } - scope['headers'].append(( - b'cookie', - b'; '.join(sorted( - ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii') - for morsel in self.cookies.values() - )), - )) + scope["headers"].append( + ( + b"cookie", + b"; ".join( + sorted( + ("%s=%s" % (morsel.key, morsel.coded_value)).encode("ascii") + for morsel in self.cookies.values() + ) + ), + ) + ) return scope def request(self, **request): @@ -518,45 +591,52 @@ class AsyncRequestFactory(RequestFactory): # This is synchronous, which means all methods on this class are. # AsyncClient, however, has an async request function, which makes all # its methods async. - if '_body_file' in request: - body_file = request.pop('_body_file') + if "_body_file" in request: + body_file = request.pop("_body_file") else: - body_file = FakePayload('') + body_file = FakePayload("") return ASGIRequest(self._base_scope(**request), body_file) def generic( - self, method, path, data='', content_type='application/octet-stream', - secure=False, **extra, + self, + method, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, ): """Construct an arbitrary HTTP request.""" parsed = urlparse(str(path)) # path can be lazy. data = force_bytes(data, settings.DEFAULT_CHARSET) s = { - 'method': method, - 'path': self._get_path(parsed), - 'server': ('127.0.0.1', '443' if secure else '80'), - 'scheme': 'https' if secure else 'http', - 'headers': [(b'host', b'testserver')], + "method": method, + "path": self._get_path(parsed), + "server": ("127.0.0.1", "443" if secure else "80"), + "scheme": "https" if secure else "http", + "headers": [(b"host", b"testserver")], } if data: - s['headers'].extend([ - (b'content-length', str(len(data)).encode('ascii')), - (b'content-type', content_type.encode('ascii')), - ]) - s['_body_file'] = FakePayload(data) - follow = extra.pop('follow', None) + s["headers"].extend( + [ + (b"content-length", str(len(data)).encode("ascii")), + (b"content-type", content_type.encode("ascii")), + ] + ) + s["_body_file"] = FakePayload(data) + follow = extra.pop("follow", None) if follow is not None: - s['follow'] = follow - if query_string := extra.pop('QUERY_STRING', None): - s['query_string'] = query_string - s['headers'] += [ - (key.lower().encode('ascii'), value.encode('latin1')) + s["follow"] = follow + if query_string := extra.pop("QUERY_STRING", None): + s["query_string"] = query_string + s["headers"] += [ + (key.lower().encode("ascii"), value.encode("latin1")) for key, value in extra.items() ] # If QUERY_STRING is absent or empty, we want to extract it from the # URL. - if not s.get('query_string'): - s['query_string'] = parsed[4] + if not s.get("query_string"): + s["query_string"] = parsed[4] return self.request(**s) @@ -564,6 +644,7 @@ class ClientMixin: """ Mixin with common methods between Client and AsyncClient. """ + def store_exc_info(self, **kwargs): """Store exceptions when they are generated by a view.""" self.exc_info = sys.exc_info() @@ -601,6 +682,7 @@ class ClientMixin: are incorrect. """ from django.contrib.auth import authenticate + user = authenticate(**credentials) if user: self._login(user) @@ -610,9 +692,10 @@ class ClientMixin: def force_login(self, user, backend=None): def get_backend(): from django.contrib.auth import load_backend + for backend_path in settings.AUTHENTICATION_BACKENDS: backend = load_backend(backend_path) - if hasattr(backend, 'get_user'): + if hasattr(backend, "get_user"): return backend_path if backend is None: @@ -637,17 +720,18 @@ class ClientMixin: session_cookie = settings.SESSION_COOKIE_NAME self.cookies[session_cookie] = request.session.session_key cookie_data = { - 'max-age': None, - 'path': '/', - 'domain': settings.SESSION_COOKIE_DOMAIN, - 'secure': settings.SESSION_COOKIE_SECURE or None, - 'expires': None, + "max-age": None, + "path": "/", + "domain": settings.SESSION_COOKIE_DOMAIN, + "secure": settings.SESSION_COOKIE_SECURE or None, + "expires": None, } self.cookies[session_cookie].update(cookie_data) def logout(self): """Log out the user by removing the cookies and session object.""" from django.contrib.auth import get_user, logout + request = HttpRequest() if self.session: request.session = self.session @@ -659,13 +743,15 @@ class ClientMixin: self.cookies = SimpleCookie() def _parse_json(self, response, **extra): - if not hasattr(response, '_json'): - if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): + if not hasattr(response, "_json"): + if not JSON_CONTENT_TYPE_RE.match(response.get("Content-Type")): raise ValueError( 'Content-Type header is "%s", not "application/json"' - % response.get('Content-Type') + % response.get("Content-Type") ) - response._json = json.loads(response.content.decode(response.charset), **extra) + response._json = json.loads( + response.content.decode(response.charset), **extra + ) return response._json @@ -687,7 +773,10 @@ class Client(ClientMixin, RequestFactory): contexts and templates produced by a view, rather than the HTML rendered to the end-user. """ - def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + + def __init__( + self, enforce_csrf_checks=False, raise_request_exception=True, **defaults + ): super().__init__(**defaults) self.handler = ClientHandler(enforce_csrf_checks) self.raise_request_exception = raise_request_exception @@ -723,13 +812,13 @@ class Client(ClientMixin, RequestFactory): response.client = self response.request = request # Add any rendered template detail to the response. - response.templates = data.get('templates', []) - response.context = data.get('context') + response.templates = data.get("templates", []) + response.context = data.get("context") response.json = partial(self._parse_json, response) # Attach the ResolverMatch instance to the response. - urlconf = getattr(response.wsgi_request, 'urlconf', None) + urlconf = getattr(response.wsgi_request, "urlconf", None) response.resolver_match = SimpleLazyObject( - lambda: resolve(request['PATH_INFO'], urlconf=urlconf), + lambda: resolve(request["PATH_INFO"], urlconf=urlconf), ) # Flatten a single context. Not really necessary anymore thanks to the # __getattr__ flattening in ContextList, but has some edge case @@ -749,13 +838,24 @@ class Client(ClientMixin, RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def post(self, path, data=None, content_type=MULTIPART_CONTENT, - follow=False, secure=False, **extra): + def post( + self, + path, + data=None, + content_type=MULTIPART_CONTENT, + follow=False, + secure=False, + **extra, + ): """Request a response from the server using POST.""" self.extra = extra - response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().post( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response def head(self, path, data=None, follow=False, secure=False, **extra): @@ -766,43 +866,87 @@ class Client(ClientMixin, RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def options(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def options( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Request a response from the server using OPTIONS.""" self.extra = extra - response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().options( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def put(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def put( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Send a resource to the server using PUT.""" self.extra = extra - response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().put( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def patch(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def patch( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Send a resource to the server using PATCH.""" self.extra = extra - response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().patch( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def delete(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def delete( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Send a DELETE request to the server.""" self.extra = extra - response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().delete( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def trace(self, path, data='', follow=False, secure=False, **extra): + def trace(self, path, data="", follow=False, secure=False, **extra): """Send a TRACE request to the server.""" self.extra = extra response = super().trace(path, data=data, secure=secure, **extra) @@ -810,7 +954,7 @@ class Client(ClientMixin, RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def _handle_redirects(self, response, data='', content_type='', **extra): + def _handle_redirects(self, response, data="", content_type="", **extra): """ Follow any redirects by requesting responses from the server using GET. """ @@ -829,39 +973,46 @@ class Client(ClientMixin, RequestFactory): url = urlsplit(response_url) if url.scheme: - extra['wsgi.url_scheme'] = url.scheme + extra["wsgi.url_scheme"] = url.scheme if url.hostname: - extra['SERVER_NAME'] = url.hostname + extra["SERVER_NAME"] = url.hostname if url.port: - extra['SERVER_PORT'] = str(url.port) + extra["SERVER_PORT"] = str(url.port) path = url.path # RFC 2616: bare domains without path are treated as the root. if not path and url.netloc: - path = '/' + path = "/" # Prepend the request path to handle relative path redirects - if not path.startswith('/'): - path = urljoin(response.request['PATH_INFO'], path) + if not path.startswith("/"): + path = urljoin(response.request["PATH_INFO"], path) - if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT): + if response.status_code in ( + HTTPStatus.TEMPORARY_REDIRECT, + HTTPStatus.PERMANENT_REDIRECT, + ): # Preserve request method and query string (if needed) # post-redirect for 307/308 responses. - request_method = response.request['REQUEST_METHOD'].lower() - if request_method not in ('get', 'head'): - extra['QUERY_STRING'] = url.query + request_method = response.request["REQUEST_METHOD"].lower() + if request_method not in ("get", "head"): + extra["QUERY_STRING"] = url.query request_method = getattr(self, request_method) else: request_method = self.get data = QueryDict(url.query) content_type = None - response = request_method(path, data=data, content_type=content_type, follow=False, **extra) + response = request_method( + path, data=data, content_type=content_type, follow=False, **extra + ) response.redirect_chain = redirect_chain if redirect_chain[-1] in redirect_chain[:-1]: # Check that we're not redirecting to somewhere we've already # been to, to prevent loops. - raise RedirectCycleError("Redirect loop detected.", last_response=response) + raise RedirectCycleError( + "Redirect loop detected.", last_response=response + ) if len(redirect_chain) > 20: # Such a lengthy chain likely also means a loop, but one with # a growing path, changing view, or changing query argument; @@ -878,7 +1029,10 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): Does not currently support "follow" on its methods. """ - def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + + def __init__( + self, enforce_csrf_checks=False, raise_request_exception=True, **defaults + ): super().__init__(**defaults) self.handler = AsyncClientHandler(enforce_csrf_checks) self.raise_request_exception = raise_request_exception @@ -892,19 +1046,19 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): query environment, which can be overridden using the arguments to the request. """ - if 'follow' in request: + if "follow" in request: raise NotImplementedError( - 'AsyncClient request methods do not accept the follow parameter.' + "AsyncClient request methods do not accept the follow parameter." ) scope = self._base_scope(**request) # Curry a data dictionary into an instance of the template renderer # callback function. data = {} on_template_render = partial(store_rendered_templates, data) - signal_uid = 'template-render-%s' % id(request) + signal_uid = "template-render-%s" % id(request) signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) # Capture exceptions created by the handler. - exception_uid = 'request-exception-%s' % id(request) + exception_uid = "request-exception-%s" % id(request) got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid) try: response = await self.handler(scope) @@ -917,13 +1071,13 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): response.client = self response.request = request # Add any rendered template detail to the response. - response.templates = data.get('templates', []) - response.context = data.get('context') + response.templates = data.get("templates", []) + response.context = data.get("context") response.json = partial(self._parse_json, response) # Attach the ResolverMatch instance to the response. - urlconf = getattr(response.asgi_request, 'urlconf', None) + urlconf = getattr(response.asgi_request, "urlconf", None) response.resolver_match = SimpleLazyObject( - lambda: resolve(request['path'], urlconf=urlconf), + lambda: resolve(request["path"], urlconf=urlconf), ) # Flatten a single context. Not really necessary anymore thanks to the # __getattr__ flattening in ContextList, but has some edge case diff --git a/django/test/html.py b/django/test/html.py index 07e986439b..87e213d651 100644 --- a/django/test/html.py +++ b/django/test/html.py @@ -7,32 +7,52 @@ from django.utils.regex_helper import _lazy_re_compile # ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020 # SPACE. # https://infra.spec.whatwg.org/#ascii-whitespace -ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+') +ASCII_WHITESPACE = _lazy_re_compile(r"[\t\n\f\r ]+") # https://html.spec.whatwg.org/#attributes-3 BOOLEAN_ATTRIBUTES = { - 'allowfullscreen', 'async', 'autofocus', 'autoplay', 'checked', 'controls', - 'default', 'defer ', 'disabled', 'formnovalidate', 'hidden', 'ismap', - 'itemscope', 'loop', 'multiple', 'muted', 'nomodule', 'novalidate', 'open', - 'playsinline', 'readonly', 'required', 'reversed', 'selected', + "allowfullscreen", + "async", + "autofocus", + "autoplay", + "checked", + "controls", + "default", + "defer ", + "disabled", + "formnovalidate", + "hidden", + "ismap", + "itemscope", + "loop", + "multiple", + "muted", + "nomodule", + "novalidate", + "open", + "playsinline", + "readonly", + "required", + "reversed", + "selected", # Attributes for deprecated tags. - 'truespeed', + "truespeed", } def normalize_whitespace(string): - return ASCII_WHITESPACE.sub(' ', string) + return ASCII_WHITESPACE.sub(" ", string) def normalize_attributes(attributes): normalized = [] for name, value in attributes: - if name == 'class' and value: + if name == "class" and value: # Special case handling of 'class' attribute, so that comparisons # of DOM instances are not sensitive to ordering of classes. - value = ' '.join(sorted( - value for value in ASCII_WHITESPACE.split(value) if value - )) + value = " ".join( + sorted(value for value in ASCII_WHITESPACE.split(value) if value) + ) # Boolean attributes without a value is same as attribute with value # that equals the attributes name. For example: # <input checked> == <input checked="checked"> @@ -40,7 +60,7 @@ def normalize_attributes(attributes): if not value or value == name: value = None elif value is None: - value = '' + value = "" normalized.append((name, value)) return normalized @@ -80,11 +100,11 @@ class Element: for i, child in enumerate(self.children): if isinstance(child, str): self.children[i] = child.strip() - elif hasattr(child, 'finalize'): + elif hasattr(child, "finalize"): child.finalize() def __eq__(self, element): - if not hasattr(element, 'name') or self.name != element.name: + if not hasattr(element, "name") or self.name != element.name: return False if self.attributes != element.attributes: return False @@ -142,21 +162,23 @@ class Element: return self.children[key] def __str__(self): - output = '<%s' % self.name + output = "<%s" % self.name for key, value in self.attributes: if value is not None: output += ' %s="%s"' % (key, value) else: - output += ' %s' % key + output += " %s" % key if self.children: - output += '>\n' - output += ''.join([ - html.escape(c) if isinstance(c, str) else str(c) - for c in self.children - ]) - output += '\n</%s>' % self.name + output += ">\n" + output += "".join( + [ + html.escape(c) if isinstance(c, str) else str(c) + for c in self.children + ] + ) + output += "\n</%s>" % self.name else: - output += '>' + output += ">" return output def __repr__(self): @@ -168,10 +190,9 @@ class RootElement(Element): super().__init__(None, ()) def __str__(self): - return ''.join([ - html.escape(c) if isinstance(c, str) else str(c) - for c in self.children - ]) + return "".join( + [html.escape(c) if isinstance(c, str) else str(c) for c in self.children] + ) class HTMLParseError(Exception): @@ -181,10 +202,23 @@ class HTMLParseError(Exception): class Parser(HTMLParser): # https://html.spec.whatwg.org/#void-elements SELF_CLOSING_TAGS = { - 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta', - 'param', 'source', 'track', 'wbr', + "area", + "base", + "br", + "col", + "embed", + "hr", + "img", + "input", + "link", + "meta", + "param", + "source", + "track", + "wbr", # Deprecated tags - 'frame', 'spacer', + "frame", + "spacer", } def __init__(self): @@ -201,9 +235,9 @@ class Parser(HTMLParser): position = self.element_positions[element] if position is None: position = self.getpos() - if hasattr(position, 'lineno'): + if hasattr(position, "lineno"): position = position.lineno, position.offset - return 'Line %d, Column %d' % position + return "Line %d, Column %d" % position @property def current(self): @@ -227,13 +261,13 @@ class Parser(HTMLParser): def handle_endtag(self, tag): if not self.open_tags: - self.error("Unexpected end tag `%s` (%s)" % ( - tag, self.format_position())) + self.error("Unexpected end tag `%s` (%s)" % (tag, self.format_position())) element = self.open_tags.pop() while element.name != tag: if not self.open_tags: - self.error("Unexpected end tag `%s` (%s)" % ( - tag, self.format_position())) + self.error( + "Unexpected end tag `%s` (%s)" % (tag, self.format_position()) + ) element = self.open_tags.pop() def handle_data(self, data): diff --git a/django/test/runner.py b/django/test/runner.py index 2e36514922..113d5216a6 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -20,11 +20,11 @@ from io import StringIO from django.core.management import call_command from django.db import connections from django.test import SimpleTestCase, TestCase -from django.test.utils import ( - NullTimeKeeper, TimeKeeper, iter_test_cases, - setup_databases as _setup_databases, setup_test_environment, - teardown_databases as _teardown_databases, teardown_test_environment, -) +from django.test.utils import NullTimeKeeper, TimeKeeper, iter_test_cases +from django.test.utils import setup_databases as _setup_databases +from django.test.utils import setup_test_environment +from django.test.utils import teardown_databases as _teardown_databases +from django.test.utils import teardown_test_environment from django.utils.crypto import new_hash from django.utils.datastructures import OrderedSet from django.utils.deprecation import RemovedInDjango50Warning @@ -42,7 +42,7 @@ except ImportError: class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): - self.logger = logging.getLogger('django.db.backends') + self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) self.debug_sql_stream = None super().__init__(stream, descriptions, verbosity) @@ -65,7 +65,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().addError(test, err) if self.debug_sql_stream is None: # Error before tests e.g. in setUpTestData(). - sql = '' + sql = "" else: self.debug_sql_stream.seek(0) sql = self.debug_sql_stream.read() @@ -80,7 +80,11 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().addSubTest(test, subtest, err) if err is not None: self.debug_sql_stream.seek(0) - errors = self.failures if issubclass(err[0], test.failureException) else self.errors + errors = ( + self.failures + if issubclass(err[0], test.failureException) + else self.errors + ) errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) def printErrorList(self, flavour, errors): @@ -124,6 +128,7 @@ class DummyList: """ Dummy list class for faking storage of results in unittest.TestResult. """ + __slots__ = () def append(self, item): @@ -157,10 +162,10 @@ class RemoteTestResult(unittest.TestResult): # attributes. This is possible since they aren't used after unpickling # after being sent to ParallelTestSuite. state = self.__dict__.copy() - state.pop('_stdout_buffer', None) - state.pop('_stderr_buffer', None) - state.pop('_original_stdout', None) - state.pop('_original_stderr', None) + state.pop("_stdout_buffer", None) + state.pop("_stderr_buffer", None) + state.pop("_original_stdout", None) + state.pop("_original_stderr", None) return state @property @@ -176,7 +181,8 @@ class RemoteTestResult(unittest.TestResult): pickle.loads(pickle.dumps(obj)) def _print_unpicklable_subtest(self, test, subtest, pickle_exc): - print(""" + print( + """ Subtest failed: test: {} @@ -189,7 +195,10 @@ test runner cannot handle it cleanly. Here is the pickling error: You should re-run this test with --parallel=1 to reproduce the failure with a cleaner failure message. -""".format(test, subtest, pickle_exc)) +""".format( + test, subtest, pickle_exc + ) + ) def check_picklable(self, test, err): # Ensure that sys.exc_info() tuples are picklable. This displays a @@ -202,11 +211,16 @@ with a cleaner failure message. self._confirm_picklable(err) except Exception as exc: original_exc_txt = repr(err[1]) - original_exc_txt = textwrap.fill(original_exc_txt, 75, initial_indent=' ', subsequent_indent=' ') + original_exc_txt = textwrap.fill( + original_exc_txt, 75, initial_indent=" ", subsequent_indent=" " + ) pickle_exc_txt = repr(exc) - pickle_exc_txt = textwrap.fill(pickle_exc_txt, 75, initial_indent=' ', subsequent_indent=' ') + pickle_exc_txt = textwrap.fill( + pickle_exc_txt, 75, initial_indent=" ", subsequent_indent=" " + ) if tblib is None: - print(""" + print( + """ {} failed: @@ -218,9 +232,13 @@ parallel test runner to handle this exception cleanly. In order to see the traceback, you should install tblib: python -m pip install tblib -""".format(test, original_exc_txt)) +""".format( + test, original_exc_txt + ) + ) else: - print(""" + print( + """ {} failed: @@ -235,7 +253,10 @@ Here's the error encountered while trying to pickle the exception: You should re-run this test with the --parallel=1 option to reproduce the failure and get a correct traceback. -""".format(test, original_exc_txt, pickle_exc_txt)) +""".format( + test, original_exc_txt, pickle_exc_txt + ) + ) raise def check_subtest_picklable(self, test, subtest): @@ -247,28 +268,28 @@ failure and get a correct traceback. def startTestRun(self): super().startTestRun() - self.events.append(('startTestRun',)) + self.events.append(("startTestRun",)) def stopTestRun(self): super().stopTestRun() - self.events.append(('stopTestRun',)) + self.events.append(("stopTestRun",)) def startTest(self, test): super().startTest(test) - self.events.append(('startTest', self.test_index)) + self.events.append(("startTest", self.test_index)) def stopTest(self, test): super().stopTest(test) - self.events.append(('stopTest', self.test_index)) + self.events.append(("stopTest", self.test_index)) def addError(self, test, err): self.check_picklable(test, err) - self.events.append(('addError', self.test_index, err)) + self.events.append(("addError", self.test_index, err)) super().addError(test, err) def addFailure(self, test, err): self.check_picklable(test, err) - self.events.append(('addFailure', self.test_index, err)) + self.events.append(("addFailure", self.test_index, err)) super().addFailure(test, err) def addSubTest(self, test, subtest, err): @@ -279,15 +300,15 @@ failure and get a correct traceback. # check_picklable() performs the tblib check. self.check_picklable(test, err) self.check_subtest_picklable(test, subtest) - self.events.append(('addSubTest', self.test_index, subtest, err)) + self.events.append(("addSubTest", self.test_index, subtest, err)) super().addSubTest(test, subtest, err) def addSuccess(self, test): - self.events.append(('addSuccess', self.test_index)) + self.events.append(("addSuccess", self.test_index)) super().addSuccess(test) def addSkip(self, test, reason): - self.events.append(('addSkip', self.test_index, reason)) + self.events.append(("addSkip", self.test_index, reason)) super().addSkip(test, reason) def addExpectedFailure(self, test, err): @@ -298,23 +319,23 @@ failure and get a correct traceback. if tblib is None: err = err[0], err[1], None self.check_picklable(test, err) - self.events.append(('addExpectedFailure', self.test_index, err)) + self.events.append(("addExpectedFailure", self.test_index, err)) super().addExpectedFailure(test, err) def addUnexpectedSuccess(self, test): - self.events.append(('addUnexpectedSuccess', self.test_index)) + self.events.append(("addUnexpectedSuccess", self.test_index)) super().addUnexpectedSuccess(test) def wasSuccessful(self): """Tells whether or not this result was a success.""" - failure_types = {'addError', 'addFailure', 'addSubTest', 'addUnexpectedSuccess'} + failure_types = {"addError", "addFailure", "addSubTest", "addUnexpectedSuccess"} return all(e[0] not in failure_types for e in self.events) def _exc_info_to_string(self, err, test): # Make this method no-op. It only powers the default unittest behavior # for recording errors, but this class pickles errors into 'events' # instead. - return '' + return "" class RemoteTestRunner: @@ -347,17 +368,17 @@ def get_max_test_processes(): """ # The current implementation of the parallel test runner requires # multiprocessing to start subprocesses with fork(). - if multiprocessing.get_start_method() != 'fork': + if multiprocessing.get_start_method() != "fork": return 1 try: - return int(os.environ['DJANGO_TEST_PROCESSES']) + return int(os.environ["DJANGO_TEST_PROCESSES"]) except KeyError: return multiprocessing.cpu_count() def parallel_type(value): """Parse value passed to the --parallel option.""" - if value == 'auto': + if value == "auto": return value try: return int(value) @@ -505,30 +526,30 @@ class Shuffler: """ # This doesn't need to be cryptographically strong, so use what's fastest. - hash_algorithm = 'md5' + hash_algorithm = "md5" @classmethod def _hash_text(cls, text): h = new_hash(cls.hash_algorithm, usedforsecurity=False) - h.update(text.encode('utf-8')) + h.update(text.encode("utf-8")) return h.hexdigest() def __init__(self, seed=None): if seed is None: # Limit seeds to 10 digits for simpler output. seed = random.randint(0, 10**10 - 1) - seed_source = 'generated' + seed_source = "generated" else: - seed_source = 'given' + seed_source = "given" self.seed = seed self.seed_source = seed_source @property def seed_display(self): - return f'{self.seed!r} ({self.seed_source})' + return f"{self.seed!r} ({self.seed_source})" def _hash_item(self, item, key): - text = '{}{}'.format(self.seed, key(item)) + text = "{}{}".format(self.seed, key(item)) return self._hash_text(text) def shuffle(self, items, key): @@ -544,8 +565,10 @@ class Shuffler: for item in items: hashed = self._hash_item(item, key) if hashed in hashes: - msg = 'item {!r} has same hash {!r} as item {!r}'.format( - item, hashed, hashes[hashed], + msg = "item {!r} has same hash {!r} as item {!r}".format( + item, + hashed, + hashes[hashed], ) raise RuntimeError(msg) hashes[hashed] = item @@ -561,12 +584,29 @@ class DiscoverRunner: test_loader = unittest.defaultTestLoader reorder_by = (TestCase, SimpleTestCase) - def __init__(self, pattern=None, top_level=None, verbosity=1, - interactive=True, failfast=False, keepdb=False, - reverse=False, debug_mode=False, debug_sql=False, parallel=0, - tags=None, exclude_tags=None, test_name_patterns=None, - pdb=False, buffer=False, enable_faulthandler=True, - timing=False, shuffle=False, logger=None, **kwargs): + def __init__( + self, + pattern=None, + top_level=None, + verbosity=1, + interactive=True, + failfast=False, + keepdb=False, + reverse=False, + debug_mode=False, + debug_sql=False, + parallel=0, + tags=None, + exclude_tags=None, + test_name_patterns=None, + pdb=False, + buffer=False, + enable_faulthandler=True, + timing=False, + shuffle=False, + logger=None, + **kwargs, + ): self.pattern = pattern self.top_level = top_level @@ -587,7 +627,9 @@ class DiscoverRunner: faulthandler.enable(file=sys.__stderr__.fileno()) self.pdb = pdb if self.pdb and self.parallel > 1: - raise ValueError('You cannot use --pdb with parallel tests; pass --parallel=1 to use it.') + raise ValueError( + "You cannot use --pdb with parallel tests; pass --parallel=1 to use it." + ) self.buffer = buffer self.test_name_patterns = None self.time_keeper = TimeKeeper() if timing else NullTimeKeeper() @@ -595,7 +637,7 @@ class DiscoverRunner: # unittest does not export the _convert_select_pattern function # that converts command-line arguments to patterns. self.test_name_patterns = { - pattern if '*' in pattern else '*%s*' % pattern + pattern if "*" in pattern else "*%s*" % pattern for pattern in test_name_patterns } self.shuffle = shuffle @@ -605,73 +647,99 @@ class DiscoverRunner: @classmethod def add_arguments(cls, parser): parser.add_argument( - '-t', '--top-level-directory', dest='top_level', - help='Top level of project for unittest discovery.', + "-t", + "--top-level-directory", + dest="top_level", + help="Top level of project for unittest discovery.", ) parser.add_argument( - '-p', '--pattern', default="test*.py", - help='The test matching pattern. Defaults to test*.py.', + "-p", + "--pattern", + default="test*.py", + help="The test matching pattern. Defaults to test*.py.", ) parser.add_argument( - '--keepdb', action='store_true', - help='Preserves the test DB between runs.' + "--keepdb", action="store_true", help="Preserves the test DB between runs." ) parser.add_argument( - '--shuffle', nargs='?', default=False, type=int, metavar='SEED', - help='Shuffles test case order.', + "--shuffle", + nargs="?", + default=False, + type=int, + metavar="SEED", + help="Shuffles test case order.", ) parser.add_argument( - '-r', '--reverse', action='store_true', - help='Reverses test case order.', + "-r", + "--reverse", + action="store_true", + help="Reverses test case order.", ) parser.add_argument( - '--debug-mode', action='store_true', - help='Sets settings.DEBUG to True.', + "--debug-mode", + action="store_true", + help="Sets settings.DEBUG to True.", ) parser.add_argument( - '-d', '--debug-sql', action='store_true', - help='Prints logged SQL queries on failure.', + "-d", + "--debug-sql", + action="store_true", + help="Prints logged SQL queries on failure.", ) parser.add_argument( - '--parallel', nargs='?', const='auto', default=0, - type=parallel_type, metavar='N', + "--parallel", + nargs="?", + const="auto", + default=0, + type=parallel_type, + metavar="N", help=( - 'Run tests using up to N parallel processes. Use the value ' + "Run tests using up to N parallel processes. Use the value " '"auto" to run one test process for each processor core.' ), ) parser.add_argument( - '--tag', action='append', dest='tags', - help='Run only tests with the specified tag. Can be used multiple times.', + "--tag", + action="append", + dest="tags", + help="Run only tests with the specified tag. Can be used multiple times.", ) parser.add_argument( - '--exclude-tag', action='append', dest='exclude_tags', - help='Do not run tests with the specified tag. Can be used multiple times.', + "--exclude-tag", + action="append", + dest="exclude_tags", + help="Do not run tests with the specified tag. Can be used multiple times.", ) parser.add_argument( - '--pdb', action='store_true', - help='Runs a debugger (pdb, or ipdb if installed) on error or failure.' + "--pdb", + action="store_true", + help="Runs a debugger (pdb, or ipdb if installed) on error or failure.", ) parser.add_argument( - '-b', '--buffer', action='store_true', - help='Discard output from passing tests.', + "-b", + "--buffer", + action="store_true", + help="Discard output from passing tests.", ) parser.add_argument( - '--no-faulthandler', action='store_false', dest='enable_faulthandler', - help='Disables the Python faulthandler module during tests.', + "--no-faulthandler", + action="store_false", + dest="enable_faulthandler", + help="Disables the Python faulthandler module during tests.", ) parser.add_argument( - '--timing', action='store_true', - help=( - 'Output timings, including database set up and total run time.' - ), + "--timing", + action="store_true", + help=("Output timings, including database set up and total run time."), ) parser.add_argument( - '-k', action='append', dest='test_name_patterns', + "-k", + action="append", + dest="test_name_patterns", help=( - 'Only run test methods and classes that match the pattern ' - 'or substring. Can be used multiple times. Same as ' - 'unittest -k option.' + "Only run test methods and classes that match the pattern " + "or substring. Can be used multiple times. Same as " + "unittest -k option." ), ) @@ -693,9 +761,7 @@ class DiscoverRunner: if level is None: level = logging.INFO if self.logger is None: - if self.verbosity <= 0 or ( - self.verbosity == 1 and level < logging.INFO - ): + if self.verbosity <= 0 or (self.verbosity == 1 and level < logging.INFO): return print(msg) else: @@ -709,7 +775,7 @@ class DiscoverRunner: if self.shuffle is False: return shuffler = Shuffler(seed=self.shuffle) - self.log(f'Using shuffle seed: {shuffler.seed_display}') + self.log(f"Using shuffle seed: {shuffler.seed_display}") self._shuffler = shuffler @contextmanager @@ -741,15 +807,15 @@ class DiscoverRunner: if os.path.exists(label_as_path): assert tests is None raise RuntimeError( - f'One of the test labels is a path to a file: {label!r}, ' - f'which is not supported. Use a dotted module name or ' - f'path to a directory instead.' + f"One of the test labels is a path to a file: {label!r}, " + f"which is not supported. Use a dotted module name or " + f"path to a directory instead." ) return tests kwargs = discover_kwargs.copy() if os.path.isdir(label_as_path) and not self.top_level: - kwargs['top_level_dir'] = find_top_level(label_as_path) + kwargs["top_level_dir"] = find_top_level(label_as_path) with self.load_with_patterns(): tests = self.test_loader.discover(start_dir=label, **kwargs) @@ -762,18 +828,18 @@ class DiscoverRunner: def build_suite(self, test_labels=None, extra_tests=None, **kwargs): if extra_tests is not None: warnings.warn( - 'The extra_tests argument is deprecated.', + "The extra_tests argument is deprecated.", RemovedInDjango50Warning, stacklevel=2, ) - test_labels = test_labels or ['.'] + test_labels = test_labels or ["."] extra_tests = extra_tests or [] discover_kwargs = {} if self.pattern is not None: - discover_kwargs['pattern'] = self.pattern + discover_kwargs["pattern"] = self.pattern if self.top_level is not None: - discover_kwargs['top_level_dir'] = self.top_level + discover_kwargs["top_level_dir"] = self.top_level self.setup_shuffler() all_tests = [] @@ -786,12 +852,12 @@ class DiscoverRunner: if self.tags or self.exclude_tags: if self.tags: self.log( - 'Including test tag(s): %s.' % ', '.join(sorted(self.tags)), + "Including test tag(s): %s." % ", ".join(sorted(self.tags)), level=logging.DEBUG, ) if self.exclude_tags: self.log( - 'Excluding test tag(s): %s.' % ', '.join(sorted(self.exclude_tags)), + "Excluding test tag(s): %s." % ", ".join(sorted(self.exclude_tags)), level=logging.DEBUG, ) all_tests = filter_tests_by_tags(all_tests, self.tags, self.exclude_tags) @@ -800,13 +866,15 @@ class DiscoverRunner: # _FailedTest objects include things like test modules that couldn't be # found or that couldn't be loaded due to syntax errors. test_types = (unittest.loader._FailedTest, *self.reorder_by) - all_tests = list(reorder_tests( - all_tests, - test_types, - shuffler=self._shuffler, - reverse=self.reverse, - )) - self.log('Found %d test(s).' % len(all_tests)) + all_tests = list( + reorder_tests( + all_tests, + test_types, + shuffler=self._shuffler, + reverse=self.reverse, + ) + ) + self.log("Found %d test(s)." % len(all_tests)) suite = self.test_suite(all_tests) if self.parallel > 1: @@ -828,8 +896,13 @@ class DiscoverRunner: def setup_databases(self, **kwargs): return _setup_databases( - self.verbosity, self.interactive, time_keeper=self.time_keeper, keepdb=self.keepdb, - debug_sql=self.debug_sql, parallel=self.parallel, **kwargs + self.verbosity, + self.interactive, + time_keeper=self.time_keeper, + keepdb=self.keepdb, + debug_sql=self.debug_sql, + parallel=self.parallel, + **kwargs, ) def get_resultclass(self): @@ -840,16 +913,16 @@ class DiscoverRunner: def get_test_runner_kwargs(self): return { - 'failfast': self.failfast, - 'resultclass': self.get_resultclass(), - 'verbosity': self.verbosity, - 'buffer': self.buffer, + "failfast": self.failfast, + "resultclass": self.get_resultclass(), + "verbosity": self.verbosity, + "buffer": self.buffer, } def run_checks(self, databases): # Checks are run after database creation since some checks require # database access. - call_command('check', verbosity=self.verbosity, databases=databases) + call_command("check", verbosity=self.verbosity, databases=databases) def run_suite(self, suite, **kwargs): kwargs = self.get_test_runner_kwargs() @@ -859,7 +932,7 @@ class DiscoverRunner: finally: if self._shuffler is not None: seed_display = self._shuffler.seed_display - self.log(f'Used shuffle seed: {seed_display}') + self.log(f"Used shuffle seed: {seed_display}") def teardown_databases(self, old_config, **kwargs): """Destroy all the non-mirror databases.""" @@ -875,16 +948,18 @@ class DiscoverRunner: teardown_test_environment() def suite_result(self, suite, result, **kwargs): - return len(result.failures) + len(result.errors) + len(result.unexpectedSuccesses) + return ( + len(result.failures) + len(result.errors) + len(result.unexpectedSuccesses) + ) def _get_databases(self, suite): databases = {} for test in iter_test_cases(suite): - test_databases = getattr(test, 'databases', None) - if test_databases == '__all__': + test_databases = getattr(test, "databases", None) + if test_databases == "__all__": test_databases = connections if test_databases: - serialized_rollback = getattr(test, 'serialized_rollback', False) + serialized_rollback = getattr(test, "serialized_rollback", False) databases.update( (alias, serialized_rollback or databases.get(alias, False)) for alias in test_databases @@ -896,7 +971,8 @@ class DiscoverRunner: unused_databases = [alias for alias in connections if alias not in databases] if unused_databases: self.log( - 'Skipping setup of unused database(s): %s.' % ', '.join(sorted(unused_databases)), + "Skipping setup of unused database(s): %s." + % ", ".join(sorted(unused_databases)), level=logging.DEBUG, ) return databases @@ -912,7 +988,7 @@ class DiscoverRunner: """ if extra_tests is not None: warnings.warn( - 'The extra_tests argument is deprecated.', + "The extra_tests argument is deprecated.", RemovedInDjango50Warning, stacklevel=2, ) @@ -920,10 +996,9 @@ class DiscoverRunner: suite = self.build_suite(test_labels, extra_tests) databases = self.get_databases(suite) serialized_aliases = set( - alias - for alias, serialize in databases.items() if serialize + alias for alias, serialize in databases.items() if serialize ) - with self.time_keeper.timed('Total database setup'): + with self.time_keeper.timed("Total database setup"): old_config = self.setup_databases( aliases=databases, serialized_aliases=serialized_aliases, @@ -937,7 +1012,7 @@ class DiscoverRunner: raise finally: try: - with self.time_keeper.timed('Total database teardown'): + with self.time_keeper.timed("Total database teardown"): self.teardown_databases(old_config) self.teardown_test_environment() except Exception: @@ -960,7 +1035,7 @@ def try_importing(label): except (ImportError, TypeError): return (False, False) - return (True, hasattr(mod, '__path__')) + return (True, hasattr(mod, "__path__")) def find_top_level(top_level): @@ -976,7 +1051,7 @@ def find_top_level(top_level): # top-level module or as a directory path, unittest unfortunately prefers # the latter. while True: - init_py = os.path.join(top_level, '__init__.py') + init_py = os.path.join(top_level, "__init__.py") if not os.path.exists(init_py): break try_next = os.path.dirname(top_level) @@ -988,7 +1063,7 @@ def find_top_level(top_level): def _class_shuffle_key(cls): - return f'{cls.__module__}.{cls.__qualname__}' + return f"{cls.__module__}.{cls.__qualname__}" def shuffle_tests(tests, shuffler): @@ -1073,9 +1148,7 @@ def partition_suite_by_case(suite): """Partition a test suite by test case, preserving the order of tests.""" suite_class = type(suite) all_tests = iter_test_cases(suite) - return [ - suite_class(tests) for _, tests in itertools.groupby(all_tests, type) - ] + return [suite_class(tests) for _, tests in itertools.groupby(all_tests, type)] def test_match_tags(test, tags, exclude_tags): @@ -1083,11 +1156,11 @@ def test_match_tags(test, tags, exclude_tags): # Tests that couldn't load always match to prevent tests from falsely # passing due e.g. to syntax errors. return True - test_tags = set(getattr(test, 'tags', [])) - test_fn_name = getattr(test, '_testMethodName', str(test)) + test_tags = set(getattr(test, "tags", [])) + test_fn_name = getattr(test, "_testMethodName", str(test)) if hasattr(test, test_fn_name): test_fn = getattr(test, test_fn_name) - test_fn_tags = list(getattr(test_fn, 'tags', [])) + test_fn_tags = list(getattr(test_fn, "tags", [])) test_tags = test_tags.union(test_fn_tags) if tags and test_tags.isdisjoint(tags): return False diff --git a/django/test/selenium.py b/django/test/selenium.py index 97a7840fea..aa714ad365 100644 --- a/django/test/selenium.py +++ b/django/test/selenium.py @@ -27,7 +27,9 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): """ test_class = super().__new__(cls, name, bases, attrs) # If the test class is either browser-specific or a test base, return it. - if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()): + if test_class.browser or not any( + name.startswith("test") and callable(value) for name, value in attrs.items() + ): return test_class elif test_class.browsers: # Reuse the created test class to make it browser-specific. @@ -37,7 +39,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): first_browser = test_class.browsers[0] test_class.browser = first_browser # Listen on an external interface if using a selenium hub. - host = test_class.host if not test_class.selenium_hub else '0.0.0.0' + host = test_class.host if not test_class.selenium_hub else "0.0.0.0" test_class.host = host test_class.external_host = cls.external_host # Create subclasses for each of the remaining browsers and expose @@ -49,16 +51,16 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): "%s%s" % (capfirst(browser), name), (test_class,), { - 'browser': browser, - 'host': host, - 'external_host': cls.external_host, - '__module__': test_class.__module__, - } + "browser": browser, + "host": host, + "external_host": cls.external_host, + "__module__": test_class.__module__, + }, ) setattr(module, browser_test_class.__name__, browser_test_class) return test_class # If no browsers were specified, skip this class (it'll still be discovered). - return unittest.skip('No browsers specified.')(test_class) + return unittest.skip("No browsers specified.")(test_class) @classmethod def import_webdriver(cls, browser): @@ -66,13 +68,12 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): @classmethod def import_options(cls, browser): - return import_string('selenium.webdriver.%s.options.Options' % browser) + return import_string("selenium.webdriver.%s.options.Options" % browser) @classmethod def get_capability(cls, browser): - from selenium.webdriver.common.desired_capabilities import ( - DesiredCapabilities, - ) + from selenium.webdriver.common.desired_capabilities import DesiredCapabilities + return getattr(DesiredCapabilities, browser.upper()) def create_options(self): @@ -87,6 +88,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): def create_webdriver(self): if self.selenium_hub: from selenium import webdriver + return webdriver.Remote( command_executor=self.selenium_hub, desired_capabilities=self.get_capability(self.browser), @@ -94,14 +96,14 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): return self.import_webdriver(self.browser)(options=self.create_options()) -@tag('selenium') +@tag("selenium") class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase): implicit_wait = 10 external_host = None @classproperty def live_server_url(cls): - return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port) + return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port) @classproperty def allowed_host(cls): @@ -118,7 +120,7 @@ class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase): # quit() the WebDriver before attempting to terminate and join the # single-threaded LiveServerThread to avoid a dead lock if the browser # kept a connection alive. - if hasattr(cls, 'selenium'): + if hasattr(cls, "selenium"): cls.selenium.quit() super()._tearDownClassInternal() diff --git a/django/test/signals.py b/django/test/signals.py index c82b95013d..c874f220df 100644 --- a/django/test/signals.py +++ b/django/test/signals.py @@ -20,13 +20,14 @@ template_rendered = Signal() # except for cases where the receiver is related to a contrib app. # Settings that may not work well when using 'override_settings' (#19031) -COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'} +COMPLEX_OVERRIDE_SETTINGS = {"DATABASES"} @receiver(setting_changed) def clear_cache_handlers(*, setting, **kwargs): - if setting == 'CACHES': + if setting == "CACHES": from django.core.cache import caches, close_caches + close_caches() caches._settings = caches.settings = caches.configure_settings(None) caches._connections = Local() @@ -34,37 +35,41 @@ def clear_cache_handlers(*, setting, **kwargs): @receiver(setting_changed) def update_installed_apps(*, setting, **kwargs): - if setting == 'INSTALLED_APPS': + if setting == "INSTALLED_APPS": # Rebuild any AppDirectoriesFinder instance. from django.contrib.staticfiles.finders import get_finder + get_finder.cache_clear() # Rebuild management commands cache from django.core.management import get_commands + get_commands.cache_clear() # Rebuild get_app_template_dirs cache. from django.template.utils import get_app_template_dirs + get_app_template_dirs.cache_clear() # Rebuild translations cache. from django.utils.translation import trans_real + trans_real._translations = {} @receiver(setting_changed) def update_connections_time_zone(*, setting, **kwargs): - if setting == 'TIME_ZONE': + if setting == "TIME_ZONE": # Reset process time zone - if hasattr(time, 'tzset'): - if kwargs['value']: - os.environ['TZ'] = kwargs['value'] + if hasattr(time, "tzset"): + if kwargs["value"]: + os.environ["TZ"] = kwargs["value"] else: - os.environ.pop('TZ', None) + os.environ.pop("TZ", None) time.tzset() # Reset local time zone cache timezone.get_default_timezone.cache_clear() # Reset the database connections' time zone - if setting in {'TIME_ZONE', 'USE_TZ'}: + if setting in {"TIME_ZONE", "USE_TZ"}: for conn in connections.all(): try: del conn.timezone @@ -79,18 +84,19 @@ def update_connections_time_zone(*, setting, **kwargs): @receiver(setting_changed) def clear_routers_cache(*, setting, **kwargs): - if setting == 'DATABASE_ROUTERS': + if setting == "DATABASE_ROUTERS": router.routers = ConnectionRouter().routers @receiver(setting_changed) def reset_template_engines(*, setting, **kwargs): if setting in { - 'TEMPLATES', - 'DEBUG', - 'INSTALLED_APPS', + "TEMPLATES", + "DEBUG", + "INSTALLED_APPS", }: from django.template import engines + try: del engines.templates except AttributeError: @@ -98,40 +104,46 @@ def reset_template_engines(*, setting, **kwargs): engines._templates = None engines._engines = {} from django.template.engine import Engine + Engine.get_default.cache_clear() from django.forms.renderers import get_default_renderer + get_default_renderer.cache_clear() @receiver(setting_changed) def clear_serializers_cache(*, setting, **kwargs): - if setting == 'SERIALIZATION_MODULES': + if setting == "SERIALIZATION_MODULES": from django.core import serializers + serializers._serializers = {} @receiver(setting_changed) def language_changed(*, setting, **kwargs): - if setting in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}: + if setting in {"LANGUAGES", "LANGUAGE_CODE", "LOCALE_PATHS"}: from django.utils.translation import trans_real + trans_real._default = None trans_real._active = Local() - if setting in {'LANGUAGES', 'LOCALE_PATHS'}: + if setting in {"LANGUAGES", "LOCALE_PATHS"}: from django.utils.translation import trans_real + trans_real._translations = {} trans_real.check_for_language.cache_clear() @receiver(setting_changed) def localize_settings_changed(*, setting, **kwargs): - if setting in FORMAT_SETTINGS or setting == 'USE_THOUSAND_SEPARATOR': + if setting in FORMAT_SETTINGS or setting == "USE_THOUSAND_SEPARATOR": reset_format_cache() @receiver(setting_changed) def file_storage_changed(*, setting, **kwargs): - if setting == 'DEFAULT_FILE_STORAGE': + if setting == "DEFAULT_FILE_STORAGE": from django.core.files.storage import default_storage + default_storage._wrapped = empty @@ -141,15 +153,16 @@ def complex_setting_changed(*, enter, setting, **kwargs): # Considering the current implementation of the signals framework, # this stacklevel shows the line containing the override_settings call. warnings.warn( - f'Overriding setting {setting} can lead to unexpected behavior.', + f"Overriding setting {setting} can lead to unexpected behavior.", stacklevel=6, ) @receiver(setting_changed) def root_urlconf_changed(*, setting, **kwargs): - if setting == 'ROOT_URLCONF': + if setting == "ROOT_URLCONF": from django.urls import clear_url_caches, set_urlconf + clear_url_caches() set_urlconf(None) @@ -157,55 +170,64 @@ def root_urlconf_changed(*, setting, **kwargs): @receiver(setting_changed) def static_storage_changed(*, setting, **kwargs): if setting in { - 'STATICFILES_STORAGE', - 'STATIC_ROOT', - 'STATIC_URL', + "STATICFILES_STORAGE", + "STATIC_ROOT", + "STATIC_URL", }: from django.contrib.staticfiles.storage import staticfiles_storage + staticfiles_storage._wrapped = empty @receiver(setting_changed) def static_finders_changed(*, setting, **kwargs): if setting in { - 'STATICFILES_DIRS', - 'STATIC_ROOT', + "STATICFILES_DIRS", + "STATIC_ROOT", }: from django.contrib.staticfiles.finders import get_finder + get_finder.cache_clear() @receiver(setting_changed) def auth_password_validators_changed(*, setting, **kwargs): - if setting == 'AUTH_PASSWORD_VALIDATORS': + if setting == "AUTH_PASSWORD_VALIDATORS": from django.contrib.auth.password_validation import ( get_default_password_validators, ) + get_default_password_validators.cache_clear() @receiver(setting_changed) def user_model_swapped(*, setting, **kwargs): - if setting == 'AUTH_USER_MODEL': + if setting == "AUTH_USER_MODEL": apps.clear_cache() try: from django.contrib.auth import get_user_model + UserModel = get_user_model() except ImproperlyConfigured: # Some tests set an invalid AUTH_USER_MODEL. pass else: from django.contrib.auth import backends + backends.UserModel = UserModel from django.contrib.auth import forms + forms.UserModel = UserModel from django.contrib.auth.handlers import modwsgi + modwsgi.UserModel = UserModel from django.contrib.auth.management.commands import changepassword + changepassword.UserModel = UserModel from django.contrib.auth import views + views.UserModel = UserModel diff --git a/django/test/testcases.py b/django/test/testcases.py index d24a065790..d514d06c7a 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -15,7 +15,13 @@ from functools import wraps from unittest.suite import _DebugResult from unittest.util import safe_repr from urllib.parse import ( - parse_qsl, unquote, urlencode, urljoin, urlparse, urlsplit, urlunparse, + parse_qsl, + unquote, + urlencode, + urljoin, + urlparse, + urlsplit, + urlunparse, ) from urllib.request import url2pathname @@ -40,7 +46,10 @@ from django.test.client import AsyncClient, Client from django.test.html import HTMLParseError, parse_html from django.test.signals import template_rendered from django.test.utils import ( - CaptureQueriesContext, ContextList, compare_xml, modify_settings, + CaptureQueriesContext, + ContextList, + compare_xml, + modify_settings, override_settings, ) from django.utils.deprecation import RemovedInDjango50Warning @@ -48,8 +57,13 @@ from django.utils.functional import classproperty from django.utils.version import PY310 from django.views.static import serve -__all__ = ('TestCase', 'TransactionTestCase', - 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') +__all__ = ( + "TestCase", + "TransactionTestCase", + "SimpleTestCase", + "skipIfDBFeature", + "skipUnlessDBFeature", +) def to_list(value): @@ -63,7 +77,7 @@ def assert_and_parse_html(self, html, user_msg, msg): try: dom = parse_html(html) except HTMLParseError as e: - standardMsg = '%s\n%s' % (msg, e) + standardMsg = "%s\n%s" % (msg, e) self.fail(self._formatMessage(user_msg, standardMsg)) return dom @@ -80,18 +94,22 @@ class _AssertNumQueriesContext(CaptureQueriesContext): return executed = len(self) self.test_case.assertEqual( - executed, self.num, - "%d queries executed, %d expected\nCaptured queries were:\n%s" % ( - executed, self.num, - '\n'.join( - '%d. %s' % (i, query['sql']) for i, query in enumerate(self.captured_queries, start=1) - ) - ) + executed, + self.num, + "%d queries executed, %d expected\nCaptured queries were:\n%s" + % ( + executed, + self.num, + "\n".join( + "%d. %s" % (i, query["sql"]) + for i, query in enumerate(self.captured_queries, start=1) + ), + ), ) class _AssertTemplateUsedContext: - def __init__(self, test_case, template_name, msg_prefix='', count=None): + def __init__(self, test_case, template_name, msg_prefix="", count=None): self.test_case = test_case self.template_name = template_name self.msg_prefix = msg_prefix @@ -108,7 +126,9 @@ class _AssertTemplateUsedContext: def test(self): self.test_case._assert_template_used( - self.template_name, self.rendered_template_names, self.msg_prefix, + self.template_name, + self.rendered_template_names, + self.msg_prefix, self.count, ) @@ -128,7 +148,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext): self.test_case.assertFalse( self.template_name in self.rendered_template_names, f"{self.msg_prefix}Template '{self.template_name}' was used " - f"unexpectedly in rendering the response" + f"unexpectedly in rendering the response", ) @@ -156,16 +176,16 @@ class SimpleTestCase(unittest.TestCase): databases = set() _disallowed_database_msg = ( - 'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase ' - 'subclasses. Either subclass TestCase or TransactionTestCase to ensure ' - 'proper test isolation or add %(alias)r to %(test)s.databases to silence ' - 'this failure.' + "Database %(operation)s to %(alias)r are not allowed in SimpleTestCase " + "subclasses. Either subclass TestCase or TransactionTestCase to ensure " + "proper test isolation or add %(alias)r to %(test)s.databases to silence " + "this failure." ) _disallowed_connection_methods = [ - ('connect', 'connections'), - ('temporary_connection', 'connections'), - ('cursor', 'queries'), - ('chunked_cursor', 'queries'), + ("connect", "connections"), + ("temporary_connection", "connections"), + ("cursor", "queries"), + ("chunked_cursor", "queries"), ] @classmethod @@ -184,18 +204,21 @@ class SimpleTestCase(unittest.TestCase): @classmethod def _validate_databases(cls): - if cls.databases == '__all__': + if cls.databases == "__all__": return frozenset(connections) for alias in cls.databases: if alias not in connections: - message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % ( - cls.__module__, - cls.__qualname__, - alias, + message = ( + "%s.%s.databases refers to %r which is not defined in settings.DATABASES." + % ( + cls.__module__, + cls.__qualname__, + alias, + ) ) close_matches = get_close_matches(alias, list(connections)) if close_matches: - message += ' Did you mean %r?' % close_matches[0] + message += " Did you mean %r?" % close_matches[0] raise ImproperlyConfigured(message) return frozenset(cls.databases) @@ -208,9 +231,9 @@ class SimpleTestCase(unittest.TestCase): connection = connections[alias] for name, operation in cls._disallowed_connection_methods: message = cls._disallowed_database_msg % { - 'test': '%s.%s' % (cls.__module__, cls.__qualname__), - 'alias': alias, - 'operation': operation, + "test": "%s.%s" % (cls.__module__, cls.__qualname__), + "alias": alias, + "operation": operation, } method = getattr(connection, name) setattr(connection, name, _DatabaseFailure(method, message)) @@ -247,9 +270,8 @@ class SimpleTestCase(unittest.TestCase): instead of __call__() to run the test. """ testMethod = getattr(self, self._testMethodName) - skipped = ( - getattr(self.__class__, "__unittest_skip__", False) or - getattr(testMethod, "__unittest_skip__", False) + skipped = getattr(self.__class__, "__unittest_skip__", False) or getattr( + testMethod, "__unittest_skip__", False ) # Convert async test methods. @@ -305,9 +327,15 @@ class SimpleTestCase(unittest.TestCase): """ return modify_settings(**kwargs) - def assertRedirects(self, response, expected_url, status_code=302, - target_status_code=200, msg_prefix='', - fetch_redirect_response=True): + def assertRedirects( + self, + response, + expected_url, + status_code=302, + target_status_code=200, + msg_prefix="", + fetch_redirect_response=True, + ): """ Assert that a response redirected to a specific URL and that the redirect URL can be loaded. @@ -319,43 +347,50 @@ class SimpleTestCase(unittest.TestCase): if msg_prefix: msg_prefix += ": " - if hasattr(response, 'redirect_chain'): + if hasattr(response, "redirect_chain"): # The request was a followed redirect self.assertTrue( response.redirect_chain, - msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)" - % (response.status_code, status_code) + msg_prefix + + "Response didn't redirect as expected: Response code was %d (expected %d)" + % (response.status_code, status_code), ) self.assertEqual( - response.redirect_chain[0][1], status_code, - msg_prefix + "Initial response didn't redirect as expected: Response code was %d (expected %d)" - % (response.redirect_chain[0][1], status_code) + response.redirect_chain[0][1], + status_code, + msg_prefix + + "Initial response didn't redirect as expected: Response code was %d (expected %d)" + % (response.redirect_chain[0][1], status_code), ) url, status_code = response.redirect_chain[-1] self.assertEqual( - response.status_code, target_status_code, - msg_prefix + "Response didn't redirect as expected: Final Response code was %d (expected %d)" - % (response.status_code, target_status_code) + response.status_code, + target_status_code, + msg_prefix + + "Response didn't redirect as expected: Final Response code was %d (expected %d)" + % (response.status_code, target_status_code), ) else: # Not a followed redirect self.assertEqual( - response.status_code, status_code, - msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)" - % (response.status_code, status_code) + response.status_code, + status_code, + msg_prefix + + "Response didn't redirect as expected: Response code was %d (expected %d)" + % (response.status_code, status_code), ) url = response.url scheme, netloc, path, query, fragment = urlsplit(url) # Prepend the request path to handle relative path redirects. - if not path.startswith('/'): - url = urljoin(response.request['PATH_INFO'], url) - path = urljoin(response.request['PATH_INFO'], path) + if not path.startswith("/"): + url = urljoin(response.request["PATH_INFO"], url) + path = urljoin(response.request["PATH_INFO"], path) if fetch_redirect_response: # netloc might be empty, or in cases where Django tests the @@ -375,21 +410,25 @@ class SimpleTestCase(unittest.TestCase): redirect_response = response.client.get( path, QueryDict(query), - secure=(scheme == 'https'), + secure=(scheme == "https"), **extra, ) self.assertEqual( - redirect_response.status_code, target_status_code, - msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" - % (path, redirect_response.status_code, target_status_code) + redirect_response.status_code, + target_status_code, + msg_prefix + + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" + % (path, redirect_response.status_code, target_status_code), ) self.assertURLEqual( - url, expected_url, - msg_prefix + "Response redirected to '%s', expected '%s'" % (url, expected_url) + url, + expected_url, + msg_prefix + + "Response redirected to '%s', expected '%s'" % (url, expected_url), ) - def assertURLEqual(self, url1, url2, msg_prefix=''): + def assertURLEqual(self, url1, url2, msg_prefix=""): """ Assert that two URLs are the same, ignoring the order of query string parameters except for parameters with the same name. @@ -397,35 +436,44 @@ class SimpleTestCase(unittest.TestCase): For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but /path/?a=1&a=2 isn't equal to /path/?a=2&a=1. """ + def normalize(url): """Sort the URL's query string parameters.""" url = str(url) # Coerce reverse_lazy() URLs. scheme, netloc, path, params, query, fragment = urlparse(url) query_parts = sorted(parse_qsl(query)) - return urlunparse((scheme, netloc, path, params, urlencode(query_parts), fragment)) + return urlunparse( + (scheme, netloc, path, params, urlencode(query_parts), fragment) + ) self.assertEqual( - normalize(url1), normalize(url2), - msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2) + normalize(url1), + normalize(url2), + msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2), ) def _assert_contains(self, response, text, status_code, msg_prefix, html): # If the response supports deferred rendering and hasn't been rendered # yet, then ensure that it does get rendered before proceeding further. - if hasattr(response, 'render') and callable(response.render) and not response.is_rendered: + if ( + hasattr(response, "render") + and callable(response.render) + and not response.is_rendered + ): response.render() if msg_prefix: msg_prefix += ": " self.assertEqual( - response.status_code, status_code, + response.status_code, + status_code, msg_prefix + "Couldn't retrieve content: Response code was %d" - " (expected %d)" % (response.status_code, status_code) + " (expected %d)" % (response.status_code, status_code), ) if response.streaming: - content = b''.join(response.streaming_content) + content = b"".join(response.streaming_content) else: content = response.content if not isinstance(text, bytes) or html: @@ -435,12 +483,18 @@ class SimpleTestCase(unittest.TestCase): else: text_repr = repr(text) if html: - content = assert_and_parse_html(self, content, None, "Response's content is not valid HTML:") - text = assert_and_parse_html(self, text, None, "Second argument is not valid HTML:") + content = assert_and_parse_html( + self, content, None, "Response's content is not valid HTML:" + ) + text = assert_and_parse_html( + self, text, None, "Second argument is not valid HTML:" + ) real_count = content.count(text) return (text_repr, real_count, msg_prefix) - def assertContains(self, response, text, count=None, status_code=200, msg_prefix='', html=False): + def assertContains( + self, response, text, count=None, status_code=200, msg_prefix="", html=False + ): """ Assert that a response indicates that some content was retrieved successfully, (i.e., the HTTP status code was as expected) and that @@ -449,26 +503,37 @@ class SimpleTestCase(unittest.TestCase): if the text occurs at least once in the response. """ text_repr, real_count, msg_prefix = self._assert_contains( - response, text, status_code, msg_prefix, html) + response, text, status_code, msg_prefix, html + ) if count is not None: self.assertEqual( - real_count, count, - msg_prefix + "Found %d instances of %s in response (expected %d)" % (real_count, text_repr, count) + real_count, + count, + msg_prefix + + "Found %d instances of %s in response (expected %d)" + % (real_count, text_repr, count), ) else: - self.assertTrue(real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr) + self.assertTrue( + real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr + ) - def assertNotContains(self, response, text, status_code=200, msg_prefix='', html=False): + def assertNotContains( + self, response, text, status_code=200, msg_prefix="", html=False + ): """ Assert that a response indicates that some content was retrieved successfully, (i.e., the HTTP status code was as expected) and that ``text`` doesn't occur in the content of the response. """ text_repr, real_count, msg_prefix = self._assert_contains( - response, text, status_code, msg_prefix, html) + response, text, status_code, msg_prefix, html + ) - self.assertEqual(real_count, 0, msg_prefix + "Response should not contain %s" % text_repr) + self.assertEqual( + real_count, 0, msg_prefix + "Response should not contain %s" % text_repr + ) def _check_test_client_response(self, response, attribute, method_name): """ @@ -481,24 +546,26 @@ class SimpleTestCase(unittest.TestCase): "the Django test Client." ) - def assertFormError(self, response, form, field, errors, msg_prefix=''): + def assertFormError(self, response, form, field, errors, msg_prefix=""): """ Assert that a form used to render the response has a specific field error. """ - self._check_test_client_response(response, 'context', 'assertFormError') + self._check_test_client_response(response, "context", "assertFormError") if msg_prefix: msg_prefix += ": " # Put context(s) into a list to simplify processing. contexts = [] if response.context is None else to_list(response.context) if not contexts: - self.fail(msg_prefix + "Response did not use any contexts to render the response") + self.fail( + msg_prefix + "Response did not use any contexts to render the response" + ) if errors is None: warnings.warn( - 'Passing errors=None to assertFormError() is deprecated, use ' - 'errors=[] instead.', + "Passing errors=None to assertFormError() is deprecated, use " + "errors=[] instead.", RemovedInDjango50Warning, stacklevel=2, ) @@ -520,18 +587,20 @@ class SimpleTestCase(unittest.TestCase): err in field_errors, msg_prefix + "The field '%s' on form '%s' in" " context %d does not contain the error '%s'" - " (actual errors: %s)" % - (field, form, i, err, repr(field_errors)) + " (actual errors: %s)" + % (field, form, i, err, repr(field_errors)), ) elif field in context[form].fields: self.fail( - msg_prefix + "The field '%s' on form '%s' in context %d contains no errors" % - (field, form, i) + msg_prefix + + "The field '%s' on form '%s' in context %d contains no errors" + % (field, form, i) ) else: self.fail( - msg_prefix + "The form '%s' in context %d does not contain the field '%s'" % - (form, i, field) + msg_prefix + + "The form '%s' in context %d does not contain the field '%s'" + % (form, i, field) ) else: non_field_errors = context[form].non_field_errors() @@ -539,14 +608,17 @@ class SimpleTestCase(unittest.TestCase): err in non_field_errors, msg_prefix + "The form '%s' in context %d does not" " contain the non-field error '%s'" - " (actual errors: %s)" % - (form, i, err, non_field_errors or 'none') + " (actual errors: %s)" + % (form, i, err, non_field_errors or "none"), ) if not found_form: - self.fail(msg_prefix + "The form '%s' was not used to render the response" % form) + self.fail( + msg_prefix + "The form '%s' was not used to render the response" % form + ) - def assertFormsetError(self, response, formset, form_index, field, errors, - msg_prefix=''): + def assertFormsetError( + self, response, formset, form_index, field, errors, msg_prefix="" + ): """ Assert that a formset used to render the response has a specific error. @@ -556,7 +628,7 @@ class SimpleTestCase(unittest.TestCase): For non-form errors, specify ``form_index`` as None and the ``field`` as None. """ - self._check_test_client_response(response, 'context', 'assertFormsetError') + self._check_test_client_response(response, "context", "assertFormsetError") # Add punctuation to msg_prefix if msg_prefix: msg_prefix += ": " @@ -564,13 +636,15 @@ class SimpleTestCase(unittest.TestCase): # Put context(s) into a list to simplify processing. contexts = [] if response.context is None else to_list(response.context) if not contexts: - self.fail(msg_prefix + 'Response did not use any contexts to ' - 'render the response') + self.fail( + msg_prefix + "Response did not use any contexts to " + "render the response" + ) if errors is None: warnings.warn( - 'Passing errors=None to assertFormsetError() is deprecated, ' - 'use errors=[] instead.', + "Passing errors=None to assertFormsetError() is deprecated, " + "use errors=[] instead.", RemovedInDjango50Warning, stacklevel=2, ) @@ -581,7 +655,7 @@ class SimpleTestCase(unittest.TestCase): # Search all contexts for the error. found_formset = False for i, context in enumerate(contexts): - if formset not in context or not hasattr(context[formset], 'forms'): + if formset not in context or not hasattr(context[formset], "forms"): continue found_formset = True for err in errors: @@ -592,60 +666,68 @@ class SimpleTestCase(unittest.TestCase): err in field_errors, msg_prefix + "The field '%s' on formset '%s', " "form %d in context %d does not contain the " - "error '%s' (actual errors: %s)" % - (field, formset, form_index, i, err, repr(field_errors)) + "error '%s' (actual errors: %s)" + % (field, formset, form_index, i, err, repr(field_errors)), ) elif field in context[formset].forms[form_index].fields: self.fail( - msg_prefix + "The field '%s' on formset '%s', form %d in context %d contains no errors" + msg_prefix + + "The field '%s' on formset '%s', form %d in context %d contains no errors" % (field, formset, form_index, i) ) else: self.fail( - msg_prefix + "The formset '%s', form %d in context %d does not contain the field '%s'" + msg_prefix + + "The formset '%s', form %d in context %d does not contain the field '%s'" % (formset, form_index, i, field) ) elif form_index is not None: - non_field_errors = context[formset].forms[form_index].non_field_errors() + non_field_errors = ( + context[formset].forms[form_index].non_field_errors() + ) self.assertFalse( not non_field_errors, msg_prefix + "The formset '%s', form %d in context %d " - "does not contain any non-field errors." % (formset, form_index, i) + "does not contain any non-field errors." + % (formset, form_index, i), ) self.assertTrue( err in non_field_errors, msg_prefix + "The formset '%s', form %d in context %d " "does not contain the non-field error '%s' (actual errors: %s)" - % (formset, form_index, i, err, repr(non_field_errors)) + % (formset, form_index, i, err, repr(non_field_errors)), ) else: non_form_errors = context[formset].non_form_errors() self.assertFalse( not non_form_errors, msg_prefix + "The formset '%s' in context %d does not " - "contain any non-form errors." % (formset, i) + "contain any non-form errors." % (formset, i), ) self.assertTrue( err in non_form_errors, msg_prefix + "The formset '%s' in context %d does not " "contain the non-form error '%s' (actual errors: %s)" - % (formset, i, err, repr(non_form_errors)) + % (formset, i, err, repr(non_form_errors)), ) if not found_formset: - self.fail(msg_prefix + "The formset '%s' was not used to render the response" % formset) + self.fail( + msg_prefix + + "The formset '%s' was not used to render the response" % formset + ) def _get_template_used(self, response, template_name, msg_prefix, method_name): if response is None and template_name is None: - raise TypeError('response and/or template_name argument must be provided') + raise TypeError("response and/or template_name argument must be provided") if msg_prefix: msg_prefix += ": " if template_name is not None and response is not None: - self._check_test_client_response(response, 'templates', method_name) + self._check_test_client_response(response, "templates", method_name) - if not hasattr(response, 'templates') or (response is None and template_name): + if not hasattr(response, "templates") or (response is None and template_name): if response: template_name = response response = None @@ -662,38 +744,49 @@ class SimpleTestCase(unittest.TestCase): template_name in template_names, msg_prefix + "Template '%s' was not a template used to render" " the response. Actual template(s) used: %s" - % (template_name, ', '.join(template_names)) + % (template_name, ", ".join(template_names)), ) if count is not None: self.assertEqual( - template_names.count(template_name), count, + template_names.count(template_name), + count, msg_prefix + "Template '%s' was expected to be rendered %d " "time(s) but was actually rendered %d time(s)." - % (template_name, count, template_names.count(template_name)) + % (template_name, count, template_names.count(template_name)), ) - def assertTemplateUsed(self, response=None, template_name=None, msg_prefix='', count=None): + def assertTemplateUsed( + self, response=None, template_name=None, msg_prefix="", count=None + ): """ Assert that the template with the provided name was used in rendering the response. Also usable as context manager. """ context_mgr_template, template_names, msg_prefix = self._get_template_used( - response, template_name, msg_prefix, 'assertTemplateUsed', + response, + template_name, + msg_prefix, + "assertTemplateUsed", ) if context_mgr_template: # Use assertTemplateUsed as context manager. - return _AssertTemplateUsedContext(self, context_mgr_template, msg_prefix, count) + return _AssertTemplateUsedContext( + self, context_mgr_template, msg_prefix, count + ) self._assert_template_used(template_name, template_names, msg_prefix, count) - def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=''): + def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=""): """ Assert that the template with the provided name was NOT used in rendering the response. Also usable as context manager. """ context_mgr_template, template_names, msg_prefix = self._get_template_used( - response, template_name, msg_prefix, 'assertTemplateNotUsed', + response, + template_name, + msg_prefix, + "assertTemplateNotUsed", ) if context_mgr_template: # Use assertTemplateNotUsed as context manager. @@ -701,20 +794,28 @@ class SimpleTestCase(unittest.TestCase): self.assertFalse( template_name in template_names, - msg_prefix + "Template '%s' was used unexpectedly in rendering the response" % template_name + msg_prefix + + "Template '%s' was used unexpectedly in rendering the response" + % template_name, ) @contextmanager - def _assert_raises_or_warns_cm(self, func, cm_attr, expected_exception, expected_message): + def _assert_raises_or_warns_cm( + self, func, cm_attr, expected_exception, expected_message + ): with func(expected_exception) as cm: yield cm self.assertIn(expected_message, str(getattr(cm, cm_attr))) - def _assertFooMessage(self, func, cm_attr, expected_exception, expected_message, *args, **kwargs): + def _assertFooMessage( + self, func, cm_attr, expected_exception, expected_message, *args, **kwargs + ): callable_obj = None if args: callable_obj, *args = args - cm = self._assert_raises_or_warns_cm(func, cm_attr, expected_exception, expected_message) + cm = self._assert_raises_or_warns_cm( + func, cm_attr, expected_exception, expected_message + ) # Assertion used in context manager fashion. if callable_obj is None: return cm @@ -722,7 +823,9 @@ class SimpleTestCase(unittest.TestCase): with cm: callable_obj(*args, **kwargs) - def assertRaisesMessage(self, expected_exception, expected_message, *args, **kwargs): + def assertRaisesMessage( + self, expected_exception, expected_message, *args, **kwargs + ): """ Assert that expected_message is found in the message of a raised exception. @@ -734,8 +837,12 @@ class SimpleTestCase(unittest.TestCase): kwargs: Extra kwargs. """ return self._assertFooMessage( - self.assertRaises, 'exception', expected_exception, expected_message, - *args, **kwargs + self.assertRaises, + "exception", + expected_exception, + expected_message, + *args, + **kwargs, ) def assertWarnsMessage(self, expected_warning, expected_message, *args, **kwargs): @@ -744,12 +851,17 @@ class SimpleTestCase(unittest.TestCase): assertRaises(). """ return self._assertFooMessage( - self.assertWarns, 'warning', expected_warning, expected_message, - *args, **kwargs + self.assertWarns, + "warning", + expected_warning, + expected_message, + *args, + **kwargs, ) # A similar method is available in Python 3.10+. if not PY310: + @contextmanager def assertNoLogs(self, logger, level=None): """ @@ -759,20 +871,29 @@ class SimpleTestCase(unittest.TestCase): if isinstance(level, int): level = logging.getLevelName(level) elif level is None: - level = 'INFO' + level = "INFO" try: with self.assertLogs(logger, level) as cm: yield except AssertionError as e: msg = e.args[0] - expected_msg = f'no logs of level {level} or higher triggered on {logger}' + expected_msg = ( + f"no logs of level {level} or higher triggered on {logger}" + ) if msg != expected_msg: raise e else: - self.fail(f'Unexpected logs found: {cm.output!r}') + self.fail(f"Unexpected logs found: {cm.output!r}") - def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None, - field_kwargs=None, empty_value=''): + def assertFieldOutput( + self, + fieldclass, + valid, + invalid, + field_args=None, + field_kwargs=None, + empty_value="", + ): """ Assert that a form field behaves correctly with various inputs. @@ -791,7 +912,7 @@ class SimpleTestCase(unittest.TestCase): if field_kwargs is None: field_kwargs = {} required = fieldclass(*field_args, **field_kwargs) - optional = fieldclass(*field_args, **{**field_kwargs, 'required': False}) + optional = fieldclass(*field_args, **{**field_kwargs, "required": False}) # test valid inputs for input, output in valid.items(): self.assertEqual(required.clean(input), output) @@ -806,7 +927,7 @@ class SimpleTestCase(unittest.TestCase): optional.clean(input) self.assertEqual(context_manager.exception.messages, errors) # test required inputs - error_required = [required.error_messages['required']] + error_required = [required.error_messages["required"]] for e in required.empty_values: with self.assertRaises(ValidationError) as context_manager: required.clean(e) @@ -814,7 +935,7 @@ class SimpleTestCase(unittest.TestCase): self.assertEqual(optional.clean(e), empty_value) # test that max_length and min_length are always accepted if issubclass(fieldclass, CharField): - field_kwargs.update({'min_length': 2, 'max_length': 20}) + field_kwargs.update({"min_length": 2, "max_length": 20}) self.assertIsInstance(fieldclass(*field_args, **field_kwargs), fieldclass) def assertHTMLEqual(self, html1, html2, msg=None): @@ -823,39 +944,57 @@ class SimpleTestCase(unittest.TestCase): Whitespace in most cases is ignored, and attribute ordering is not significant. The arguments must be valid HTML. """ - dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:') - dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:') + dom1 = assert_and_parse_html( + self, html1, msg, "First argument is not valid HTML:" + ) + dom2 = assert_and_parse_html( + self, html2, msg, "Second argument is not valid HTML:" + ) if dom1 != dom2: - standardMsg = '%s != %s' % ( - safe_repr(dom1, True), safe_repr(dom2, True)) - diff = ('\n' + '\n'.join(difflib.ndiff( - str(dom1).splitlines(), str(dom2).splitlines(), - ))) + standardMsg = "%s != %s" % (safe_repr(dom1, True), safe_repr(dom2, True)) + diff = "\n" + "\n".join( + difflib.ndiff( + str(dom1).splitlines(), + str(dom2).splitlines(), + ) + ) standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) def assertHTMLNotEqual(self, html1, html2, msg=None): """Assert that two HTML snippets are not semantically equivalent.""" - dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:') - dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:') + dom1 = assert_and_parse_html( + self, html1, msg, "First argument is not valid HTML:" + ) + dom2 = assert_and_parse_html( + self, html2, msg, "Second argument is not valid HTML:" + ) if dom1 == dom2: - standardMsg = '%s == %s' % ( - safe_repr(dom1, True), safe_repr(dom2, True)) + standardMsg = "%s == %s" % (safe_repr(dom1, True), safe_repr(dom2, True)) self.fail(self._formatMessage(msg, standardMsg)) - def assertInHTML(self, needle, haystack, count=None, msg_prefix=''): - needle = assert_and_parse_html(self, needle, None, 'First argument is not valid HTML:') - haystack = assert_and_parse_html(self, haystack, None, 'Second argument is not valid HTML:') + def assertInHTML(self, needle, haystack, count=None, msg_prefix=""): + needle = assert_and_parse_html( + self, needle, None, "First argument is not valid HTML:" + ) + haystack = assert_and_parse_html( + self, haystack, None, "Second argument is not valid HTML:" + ) real_count = haystack.count(needle) if count is not None: self.assertEqual( - real_count, count, - msg_prefix + "Found %d instances of '%s' in response (expected %d)" % (real_count, needle, count) + real_count, + count, + msg_prefix + + "Found %d instances of '%s' in response (expected %d)" + % (real_count, needle, count), ) else: - self.assertTrue(real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle) + self.assertTrue( + real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle + ) def assertJSONEqual(self, raw, expected_data, msg=None): """ @@ -900,14 +1039,17 @@ class SimpleTestCase(unittest.TestCase): try: result = compare_xml(xml1, xml2) except Exception as e: - standardMsg = 'First or second argument is not valid XML\n%s' % e + standardMsg = "First or second argument is not valid XML\n%s" % e self.fail(self._formatMessage(msg, standardMsg)) else: if not result: - standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True)) - diff = ('\n' + '\n'.join( + standardMsg = "%s != %s" % ( + safe_repr(xml1, True), + safe_repr(xml2, True), + ) + diff = "\n" + "\n".join( difflib.ndiff(xml1.splitlines(), xml2.splitlines()) - )) + ) standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) @@ -920,11 +1062,14 @@ class SimpleTestCase(unittest.TestCase): try: result = compare_xml(xml1, xml2) except Exception as e: - standardMsg = 'First or second argument is not valid XML\n%s' % e + standardMsg = "First or second argument is not valid XML\n%s" % e self.fail(self._formatMessage(msg, standardMsg)) else: if result: - standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True)) + standardMsg = "%s == %s" % ( + safe_repr(xml1, True), + safe_repr(xml2, True), + ) self.fail(self._formatMessage(msg, standardMsg)) @@ -942,9 +1087,9 @@ class TransactionTestCase(SimpleTestCase): databases = {DEFAULT_DB_ALIAS} _disallowed_database_msg = ( - 'Database %(operation)s to %(alias)r are not allowed in this test. ' - 'Add %(alias)r to %(test)s.databases to ensure proper test isolation ' - 'and silence this failure.' + "Database %(operation)s to %(alias)r are not allowed in this test. " + "Add %(alias)r to %(test)s.databases to ensure proper test isolation " + "and silence this failure." ) # If transactions aren't available, Django will serialize the database @@ -966,7 +1111,7 @@ class TransactionTestCase(SimpleTestCase): apps.set_available_apps(self.available_apps) setting_changed.send( sender=settings._wrapped.__class__, - setting='INSTALLED_APPS', + setting="INSTALLED_APPS", value=self.available_apps, enter=True, ) @@ -979,7 +1124,7 @@ class TransactionTestCase(SimpleTestCase): apps.unset_available_apps() setting_changed.send( sender=settings._wrapped.__class__, - setting='INSTALLED_APPS', + setting="INSTALLED_APPS", value=settings.INSTALLED_APPS, enter=False, ) @@ -994,9 +1139,12 @@ class TransactionTestCase(SimpleTestCase): def _databases_names(cls, include_mirrors=True): # Only consider allowed database aliases, including mirrors or not. return [ - alias for alias in connections - if alias in cls.databases and ( - include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR'] + alias + for alias in connections + if alias in cls.databases + and ( + include_mirrors + or not connections[alias].settings_dict["TEST"]["MIRROR"] ) ] @@ -1004,7 +1152,8 @@ class TransactionTestCase(SimpleTestCase): conn = connections[db_name] if conn.features.supports_sequence_reset: sql_list = conn.ops.sequence_reset_by_name_sql( - no_style(), conn.introspection.sequence_list()) + no_style(), conn.introspection.sequence_list() + ) if sql_list: with transaction.atomic(using=db_name): with conn.cursor() as cursor: @@ -1018,7 +1167,9 @@ class TransactionTestCase(SimpleTestCase): self._reset_sequences(db_name) # Provide replica initial data from migrated apps, if needed. - if self.serialized_rollback and hasattr(connections[db_name], "_test_serialized_contents"): + if self.serialized_rollback and hasattr( + connections[db_name], "_test_serialized_contents" + ): if self.available_apps is not None: apps.unset_available_apps() connections[db_name].creation.deserialize_db_from_string( @@ -1030,8 +1181,9 @@ class TransactionTestCase(SimpleTestCase): if self.fixtures: # We have to use this slightly awkward syntax due to the fact # that we're using *args and **kwargs together. - call_command('loaddata', *self.fixtures, - **{'verbosity': 0, 'database': db_name}) + call_command( + "loaddata", *self.fixtures, **{"verbosity": 0, "database": db_name} + ) def _should_reload_connections(self): return True @@ -1058,10 +1210,12 @@ class TransactionTestCase(SimpleTestCase): finally: if self.available_apps is not None: apps.unset_available_apps() - setting_changed.send(sender=settings._wrapped.__class__, - setting='INSTALLED_APPS', - value=settings.INSTALLED_APPS, - enter=False) + setting_changed.send( + sender=settings._wrapped.__class__, + setting="INSTALLED_APPS", + value=settings.INSTALLED_APPS, + enter=False, + ) def _fixture_teardown(self): # Allow TRUNCATE ... CASCADE and don't emit the post_migrate signal @@ -1069,17 +1223,22 @@ class TransactionTestCase(SimpleTestCase): for db_name in self._databases_names(include_mirrors=False): # Flush the database inhibit_post_migrate = ( - self.available_apps is not None or - ( # Inhibit the post_migrate signal when using serialized + self.available_apps is not None + or ( # Inhibit the post_migrate signal when using serialized # rollback to avoid trying to recreate the serialized data. - self.serialized_rollback and - hasattr(connections[db_name], '_test_serialized_contents') + self.serialized_rollback + and hasattr(connections[db_name], "_test_serialized_contents") ) ) - call_command('flush', verbosity=0, interactive=False, - database=db_name, reset_sequences=False, - allow_cascade=self.available_apps is not None, - inhibit_post_migrate=inhibit_post_migrate) + call_command( + "flush", + verbosity=0, + interactive=False, + database=db_name, + reset_sequences=False, + allow_cascade=self.available_apps is not None, + inhibit_post_migrate=inhibit_post_migrate, + ) def assertQuerysetEqual(self, qs, values, transform=None, ordered=True, msg=None): values = list(values) @@ -1090,10 +1249,10 @@ class TransactionTestCase(SimpleTestCase): return self.assertDictEqual(Counter(items), Counter(values), msg=msg) # For example qs.iterator() could be passed as qs, but it does not # have 'ordered' attribute. - if len(values) > 1 and hasattr(qs, 'ordered') and not qs.ordered: + if len(values) > 1 and hasattr(qs, "ordered") and not qs.ordered: raise ValueError( - 'Trying to compare non-ordered queryset against more than one ' - 'ordered value.' + "Trying to compare non-ordered queryset against more than one " + "ordered value." ) return self.assertEqual(list(items), values, msg=msg) @@ -1113,7 +1272,11 @@ def connections_support_transactions(aliases=None): Return whether or not all (or specified) connections support transactions. """ - conns = connections.all() if aliases is None else (connections[alias] for alias in aliases) + conns = ( + connections.all() + if aliases is None + else (connections[alias] for alias in aliases) + ) return all(conn.features.supports_transactions for conn in conns) @@ -1128,7 +1291,8 @@ class TestData: Objects are deep copied using a memo kept on the test case instance in order to maintain their original relationships. """ - memo_attr = '_testdata_memo' + + memo_attr = "_testdata_memo" def __init__(self, name, data): self.name = name @@ -1151,7 +1315,7 @@ class TestData: return data def __repr__(self): - return '<TestData: name=%r, data=%r>' % (self.name, self.data) + return "<TestData: name=%r, data=%r>" % (self.name, self.data) class TestCase(TransactionTestCase): @@ -1167,6 +1331,7 @@ class TestCase(TransactionTestCase): On database backends with no transaction support, TestCase behaves as TransactionTestCase. """ + @classmethod def _enter_atomics(cls): """Open atomic blocks for multiple databases.""" @@ -1199,7 +1364,11 @@ class TestCase(TransactionTestCase): if cls.fixtures: for db_name in cls._databases_names(include_mirrors=False): try: - call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) + call_command( + "loaddata", + *cls.fixtures, + **{"verbosity": 0, "database": db_name}, + ) except Exception: cls._rollback_atomics(cls.cls_atomics) raise @@ -1239,7 +1408,7 @@ class TestCase(TransactionTestCase): return super()._fixture_setup() if self.reset_sequences: - raise TypeError('reset_sequences cannot be used on TestCase instances') + raise TypeError("reset_sequences cannot be used on TestCase instances") self.atomics = self._enter_atomics() def _fixture_teardown(self): @@ -1254,8 +1423,9 @@ class TestCase(TransactionTestCase): def _should_check_constraints(self, connection): return ( - connection.features.can_defer_constraint_checks and - not connection.needs_rollback and connection.is_usable() + connection.features.can_defer_constraint_checks + and not connection.needs_rollback + and connection.is_usable() ) @classmethod @@ -1281,6 +1451,7 @@ class TestCase(TransactionTestCase): class CheckCondition: """Descriptor class for deferred condition checking.""" + def __init__(self, *conditions): self.conditions = conditions @@ -1289,7 +1460,7 @@ class CheckCondition: def __get__(self, instance, cls=None): # Trigger access for all bases. - if any(getattr(base, '__unittest_skip__', False) for base in cls.__bases__): + if any(getattr(base, "__unittest_skip__", False) for base in cls.__bases__): return True for condition, reason in self.conditions: if condition(): @@ -1303,15 +1474,21 @@ class CheckCondition: def _deferredSkip(condition, reason, name): def decorator(test_func): nonlocal condition - if not (isinstance(test_func, type) and - issubclass(test_func, unittest.TestCase)): + if not ( + isinstance(test_func, type) and issubclass(test_func, unittest.TestCase) + ): + @wraps(test_func) def skip_wrapper(*args, **kwargs): - if (args and isinstance(args[0], unittest.TestCase) and - connection.alias not in getattr(args[0], 'databases', {})): + if ( + args + and isinstance(args[0], unittest.TestCase) + and connection.alias not in getattr(args[0], "databases", {}) + ): raise ValueError( "%s cannot be used on %s as %s doesn't allow queries " - "against the %r database." % ( + "against the %r database." + % ( name, args[0], args[0].__class__.__qualname__, @@ -1321,55 +1498,67 @@ def _deferredSkip(condition, reason, name): if condition(): raise unittest.SkipTest(reason) return test_func(*args, **kwargs) + test_item = skip_wrapper else: # Assume a class is decorated test_item = test_func - databases = getattr(test_item, 'databases', None) + databases = getattr(test_item, "databases", None) if not databases or connection.alias not in databases: # Defer raising to allow importing test class's module. def condition(): raise ValueError( "%s cannot be used on %s as it doesn't allow queries " - "against the '%s' database." % ( - name, test_item, connection.alias, + "against the '%s' database." + % ( + name, + test_item, + connection.alias, ) ) + # Retrieve the possibly existing value from the class's dict to # avoid triggering the descriptor. - skip = test_func.__dict__.get('__unittest_skip__') + skip = test_func.__dict__.get("__unittest_skip__") if isinstance(skip, CheckCondition): test_item.__unittest_skip__ = skip.add_condition(condition, reason) elif skip is not True: test_item.__unittest_skip__ = CheckCondition((condition, reason)) return test_item + return decorator def skipIfDBFeature(*features): """Skip a test if a database has at least one of the named features.""" return _deferredSkip( - lambda: any(getattr(connection.features, feature, False) for feature in features), + lambda: any( + getattr(connection.features, feature, False) for feature in features + ), "Database has feature(s) %s" % ", ".join(features), - 'skipIfDBFeature', + "skipIfDBFeature", ) def skipUnlessDBFeature(*features): """Skip a test unless a database has all the named features.""" return _deferredSkip( - lambda: not all(getattr(connection.features, feature, False) for feature in features), + lambda: not all( + getattr(connection.features, feature, False) for feature in features + ), "Database doesn't support feature(s): %s" % ", ".join(features), - 'skipUnlessDBFeature', + "skipUnlessDBFeature", ) def skipUnlessAnyDBFeature(*features): """Skip a test unless a database has any of the named features.""" return _deferredSkip( - lambda: not any(getattr(connection.features, feature, False) for feature in features), + lambda: not any( + getattr(connection.features, feature, False) for feature in features + ), "Database doesn't support any of the feature(s): %s" % ", ".join(features), - 'skipUnlessAnyDBFeature', + "skipUnlessAnyDBFeature", ) @@ -1378,6 +1567,7 @@ class QuietWSGIRequestHandler(WSGIRequestHandler): A WSGIRequestHandler that doesn't log to standard output any of the requests received, so as to not clutter the test result output. """ + def log_message(*args): pass @@ -1387,6 +1577,7 @@ class FSFilesHandler(WSGIHandler): WSGI middleware that intercepts calls to a directory, as defined by one of the *_ROOT settings, and serves those files, publishing them under *_URL. """ + def __init__(self, application): self.application = application self.base_url = urlparse(self.get_base_url()) @@ -1402,7 +1593,7 @@ class FSFilesHandler(WSGIHandler): def file_path(self, url): """Return the relative path to the file on disk for the given URL.""" - relative_url = url[len(self.base_url[2]):] + relative_url = url[len(self.base_url[2]) :] return url2pathname(relative_url) def get_response(self, request): @@ -1421,7 +1612,7 @@ class FSFilesHandler(WSGIHandler): # Emulate behavior of django.contrib.staticfiles.views.serve() when it # invokes staticfiles' finders functionality. # TODO: Modify if/when that internal API is refactored - final_rel_path = os_rel_path.replace('\\', '/').lstrip('/') + final_rel_path = os_rel_path.replace("\\", "/").lstrip("/") return serve(request, final_rel_path, document_root=self.get_base_dir()) def __call__(self, environ, start_response): @@ -1435,6 +1626,7 @@ class _StaticFilesHandler(FSFilesHandler): Handler for serving static files. A private class that is meant to be used solely as a convenience by LiveServerThread. """ + def get_base_dir(self): return settings.STATIC_ROOT @@ -1447,6 +1639,7 @@ class _MediaFilesHandler(FSFilesHandler): Handler for serving the media files. A private class that is meant to be used solely as a convenience by LiveServerThread. """ + def get_base_dir(self): return settings.MEDIA_ROOT @@ -1503,7 +1696,7 @@ class LiveServerThread(threading.Thread): ) def terminate(self): - if hasattr(self, 'httpd'): + if hasattr(self, "httpd"): # Stop the WSGI server self.httpd.shutdown() self.httpd.server_close() @@ -1521,14 +1714,15 @@ class LiveServerTestCase(TransactionTestCase): and each thread needs to commit all their transactions so that the other thread can see the changes. """ - host = 'localhost' + + host = "localhost" port = 0 server_thread_class = LiveServerThread static_handler = _StaticFilesHandler @classproperty def live_server_url(cls): - return 'http://%s:%s' % (cls.host, cls.server_thread.port) + return "http://%s:%s" % (cls.host, cls.server_thread.port) @classproperty def allowed_host(cls): @@ -1540,7 +1734,7 @@ class LiveServerTestCase(TransactionTestCase): for conn in connections.all(): # If using in-memory sqlite databases, pass the connections to # the server thread. - if conn.vendor == 'sqlite' and conn.is_in_memory_db(): + if conn.vendor == "sqlite" and conn.is_in_memory_db(): connections_override[conn.alias] = conn return connections_override @@ -1548,7 +1742,7 @@ class LiveServerTestCase(TransactionTestCase): def setUpClass(cls): super().setUpClass() cls._live_server_modified_settings = modify_settings( - ALLOWED_HOSTS={'append': cls.allowed_host}, + ALLOWED_HOSTS={"append": cls.allowed_host}, ) cls._live_server_modified_settings.enable() cls.addClassCleanup(cls._live_server_modified_settings.disable) @@ -1598,6 +1792,7 @@ class SerializeMixin: Place it early in the MRO in order to isolate setUpClass()/tearDownClass(). """ + lockfile = None def __init_subclass__(cls, /, **kwargs): @@ -1605,7 +1800,8 @@ class SerializeMixin: if cls.lockfile is None: raise ValueError( "{}.lockfile isn't set. Set it to a unique value " - "in the base class.".format(cls.__name__)) + "in the base class.".format(cls.__name__) + ) @classmethod def setUpClass(cls): diff --git a/django/test/utils.py b/django/test/utils.py index 6c2f566909..ac0fc34b08 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -35,15 +35,24 @@ except ImportError: __all__ = ( - 'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner', - 'CaptureQueriesContext', - 'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings', - 'override_system_checks', 'tag', - 'requires_tz_support', - 'setup_databases', 'setup_test_environment', 'teardown_test_environment', + "Approximate", + "ContextList", + "isolate_lru_cache", + "get_runner", + "CaptureQueriesContext", + "ignore_warnings", + "isolate_apps", + "modify_settings", + "override_settings", + "override_system_checks", + "tag", + "requires_tz_support", + "setup_databases", + "setup_test_environment", + "teardown_test_environment", ) -TZ_SUPPORT = hasattr(time, 'tzset') +TZ_SUPPORT = hasattr(time, "tzset") class Approximate: @@ -63,6 +72,7 @@ class ContextList(list): A wrapper that provides direct key access to context items contained in a list of context objects. """ + def __getitem__(self, key): if isinstance(key, str): for subcontext in self: @@ -110,7 +120,7 @@ def setup_test_environment(debug=None): Perform global pre-test setup, such as installing the instrumented template renderer and setting the email backend to the locmem email backend. """ - if hasattr(_TestState, 'saved_data'): + if hasattr(_TestState, "saved_data"): # Executing this function twice would overwrite the saved values. raise RuntimeError( "setup_test_environment() was already called and can't be called " @@ -125,13 +135,13 @@ def setup_test_environment(debug=None): saved_data.allowed_hosts = settings.ALLOWED_HOSTS # Add the default host of the test client. - settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver'] + settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"] saved_data.debug = settings.DEBUG settings.DEBUG = debug saved_data.email_backend = settings.EMAIL_BACKEND - settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend' + settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend" saved_data.template_render = Template._render Template._render = instrumented_test_render @@ -191,18 +201,17 @@ def setup_databases( # replace with: # serialize_alias = serialized_aliases is None or alias in serialized_aliases try: - serialize_alias = connection.settings_dict['TEST']['SERIALIZE'] + serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"] except KeyError: serialize_alias = ( - serialized_aliases is None or - alias in serialized_aliases + serialized_aliases is None or alias in serialized_aliases ) else: warnings.warn( - 'The SERIALIZE test database setting is ' - 'deprecated as it can be inferred from the ' - 'TestCase/TransactionTestCase.databases that ' - 'enable the serialized_rollback feature.', + "The SERIALIZE test database setting is " + "deprecated as it can be inferred from the " + "TestCase/TransactionTestCase.databases that " + "enable the serialized_rollback feature.", category=RemovedInDjango50Warning, ) connection.creation.create_test_db( @@ -221,12 +230,15 @@ def setup_databases( ) # Configure all other connections as mirrors of the first one else: - connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict) + connections[alias].creation.set_as_test_mirror( + connections[first_alias].settings_dict + ) # Configure the test mirrors. for alias, mirror_alias in mirrored_aliases.items(): connections[alias].creation.set_as_test_mirror( - connections[mirror_alias].settings_dict) + connections[mirror_alias].settings_dict + ) if debug_sql: for alias in connections: @@ -246,8 +258,8 @@ def iter_test_cases(tests): # Prevent an unfriendly RecursionError that can happen with # strings. raise TypeError( - f'Test {test!r} must be a test case or test suite not string ' - f'(was found in {tests!r}).' + f"Test {test!r} must be a test case or test suite not string " + f"(was found in {tests!r})." ) if isinstance(test, TestCase): yield test @@ -319,18 +331,18 @@ def get_unique_databases_and_mirrors(aliases=None): for alias in connections: connection = connections[alias] - test_settings = connection.settings_dict['TEST'] + test_settings = connection.settings_dict["TEST"] - if test_settings['MIRROR']: + if test_settings["MIRROR"]: # If the database is marked as a test mirror, save the alias. - mirrored_aliases[alias] = test_settings['MIRROR'] + mirrored_aliases[alias] = test_settings["MIRROR"] elif alias in aliases: # Store a tuple with DB parameters that uniquely identify it. # If we have two aliases with the same values for that tuple, # we only need to create the test database once. item = test_databases.setdefault( connection.creation.test_db_signature(), - (connection.settings_dict['NAME'], []), + (connection.settings_dict["NAME"], []), ) # The default database must be the first because data migrations # use the default alias by default. @@ -339,11 +351,16 @@ def get_unique_databases_and_mirrors(aliases=None): else: item[1].append(alias) - if 'DEPENDENCIES' in test_settings: - dependencies[alias] = test_settings['DEPENDENCIES'] + if "DEPENDENCIES" in test_settings: + dependencies[alias] = test_settings["DEPENDENCIES"] else: - if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig: - dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS]) + if ( + alias != DEFAULT_DB_ALIAS + and connection.creation.test_db_signature() != default_sig + ): + dependencies[alias] = test_settings.get( + "DEPENDENCIES", [DEFAULT_DB_ALIAS] + ) test_databases = dict(dependency_ordered(test_databases.items(), dependencies)) return test_databases, mirrored_aliases @@ -365,12 +382,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False): def get_runner(settings, test_runner_class=None): test_runner_class = test_runner_class or settings.TEST_RUNNER - test_path = test_runner_class.split('.') + test_path = test_runner_class.split(".") # Allow for relative paths if len(test_path) > 1: - test_module_name = '.'.join(test_path[:-1]) + test_module_name = ".".join(test_path[:-1]) else: - test_module_name = '.' + test_module_name = "." test_module = __import__(test_module_name, {}, {}, test_path[-1]) return getattr(test_module, test_path[-1]) @@ -387,6 +404,7 @@ class TestContextDecorator: `kwarg_name`: keyword argument passing the return value of enable() if used as a function decorator. """ + def __init__(self, attr_name=None, kwarg_name=None): self.attr_name = attr_name self.kwarg_name = kwarg_name @@ -416,7 +434,7 @@ class TestContextDecorator: cls.setUp = setUp return cls - raise TypeError('Can only decorate subclasses of unittest.TestCase') + raise TypeError("Can only decorate subclasses of unittest.TestCase") def decorate_callable(self, func): if asyncio.iscoroutinefunction(func): @@ -428,13 +446,16 @@ class TestContextDecorator: if self.kwarg_name: kwargs[self.kwarg_name] = context return await func(*args, **kwargs) + else: + @wraps(func) def inner(*args, **kwargs): with self as context: if self.kwarg_name: kwargs[self.kwarg_name] = context return func(*args, **kwargs) + return inner def __call__(self, decorated): @@ -442,7 +463,7 @@ class TestContextDecorator: return self.decorate_class(decorated) elif callable(decorated): return self.decorate_callable(decorated) - raise TypeError('Cannot decorate object of type %s' % type(decorated)) + raise TypeError("Cannot decorate object of type %s" % type(decorated)) class override_settings(TestContextDecorator): @@ -452,6 +473,7 @@ class override_settings(TestContextDecorator): with the ``with`` statement. In either event, entering/exiting are called before and after, respectively, the function/block is executed. """ + enable_exception = None def __init__(self, **kwargs): @@ -461,9 +483,9 @@ class override_settings(TestContextDecorator): def enable(self): # Keep this code at the beginning to leave the settings unchanged # in case it raises an exception because INSTALLED_APPS is invalid. - if 'INSTALLED_APPS' in self.options: + if "INSTALLED_APPS" in self.options: try: - apps.set_installed_apps(self.options['INSTALLED_APPS']) + apps.set_installed_apps(self.options["INSTALLED_APPS"]) except Exception: apps.unset_installed_apps() raise @@ -476,14 +498,16 @@ class override_settings(TestContextDecorator): try: setting_changed.send( sender=settings._wrapped.__class__, - setting=key, value=new_value, enter=True, + setting=key, + value=new_value, + enter=True, ) except Exception as exc: self.enable_exception = exc self.disable() def disable(self): - if 'INSTALLED_APPS' in self.options: + if "INSTALLED_APPS" in self.options: apps.unset_installed_apps() settings._wrapped = self.wrapped del self.wrapped @@ -492,7 +516,9 @@ class override_settings(TestContextDecorator): new_value = getattr(settings, key, None) responses_for_setting = setting_changed.send_robust( sender=settings._wrapped.__class__, - setting=key, value=new_value, enter=False, + setting=key, + value=new_value, + enter=False, ) responses.extend(responses_for_setting) if self.enable_exception is not None: @@ -515,10 +541,12 @@ class override_settings(TestContextDecorator): def decorate_class(self, cls): from django.test import SimpleTestCase + if not issubclass(cls, SimpleTestCase): raise ValueError( "Only subclasses of Django SimpleTestCase can be decorated " - "with override_settings") + "with override_settings" + ) self.save_options(cls) return cls @@ -528,6 +556,7 @@ class modify_settings(override_settings): Like override_settings, but makes it possible to append, prepend, or remove items instead of redefining the entire list. """ + def __init__(self, *args, **kwargs): if args: # Hack used when instantiating from SimpleTestCase.setUpClass. @@ -543,8 +572,9 @@ class modify_settings(override_settings): test_func._modified_settings = self.operations else: # Duplicate list to prevent subclasses from altering their parent. - test_func._modified_settings = list( - test_func._modified_settings) + self.operations + test_func._modified_settings = ( + list(test_func._modified_settings) + self.operations + ) def enable(self): self.options = {} @@ -559,11 +589,11 @@ class modify_settings(override_settings): # items my be a single value or an iterable. if isinstance(items, str): items = [items] - if action == 'append': + if action == "append": value = value + [item for item in items if item not in value] - elif action == 'prepend': + elif action == "prepend": value = [item for item in items if item not in value] + value - elif action == 'remove': + elif action == "remove": value = [item for item in value if item not in items] else: raise ValueError("Unsupported action: %s" % action) @@ -577,8 +607,10 @@ class override_system_checks(TestContextDecorator): Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app, you also need to exclude its system checks. """ + def __init__(self, new_checks, deployment_checks=None): from django.core.checks.registry import registry + self.registry = registry self.new_checks = new_checks self.deployment_checks = deployment_checks @@ -588,12 +620,12 @@ class override_system_checks(TestContextDecorator): self.old_checks = self.registry.registered_checks self.registry.registered_checks = set() for check in self.new_checks: - self.registry.register(check, *getattr(check, 'tags', ())) + self.registry.register(check, *getattr(check, "tags", ())) self.old_deployment_checks = self.registry.deployment_checks if self.deployment_checks is not None: self.registry.deployment_checks = set() for check in self.deployment_checks: - self.registry.register(check, *getattr(check, 'tags', ()), deploy=True) + self.registry.register(check, *getattr(check, "tags", ()), deploy=True) def disable(self): self.registry.registered_checks = self.old_checks @@ -609,18 +641,18 @@ def compare_xml(want, got): Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py """ - _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+') + _norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+") def norm_whitespace(v): - return _norm_whitespace_re.sub(' ', v) + return _norm_whitespace_re.sub(" ", v) def child_text(element): - return ''.join(c.data for c in element.childNodes - if c.nodeType == Node.TEXT_NODE) + return "".join( + c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE + ) def children(element): - return [c for c in element.childNodes - if c.nodeType == Node.ELEMENT_NODE] + return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE] def norm_child_text(element): return norm_whitespace(child_text(element)) @@ -639,7 +671,9 @@ def compare_xml(want, got): got_children = children(got_element) if len(want_children) != len(got_children): return False - return all(check_element(want, got) for want, got in zip(want_children, got_children)) + return all( + check_element(want, got) for want, got in zip(want_children, got_children) + ) def first_node(document): for node in document.childNodes: @@ -650,13 +684,13 @@ def compare_xml(want, got): ): return node - want = want.strip().replace('\\n', '\n') - got = got.strip().replace('\\n', '\n') + want = want.strip().replace("\\n", "\n") + got = got.strip().replace("\\n", "\n") # If the string is not a complete xml document, we may need to add a # root element. This allow us to compare fragments, like "<foo/><bar/>" - if not want.startswith('<?xml'): - wrapper = '<root>%s</root>' + if not want.startswith("<?xml"): + wrapper = "<root>%s</root>" want = wrapper % want got = wrapper % got @@ -671,6 +705,7 @@ class CaptureQueriesContext: """ Context manager that captures queries executed by the specified connection. """ + def __init__(self, connection): self.connection = connection @@ -685,7 +720,7 @@ class CaptureQueriesContext: @property def captured_queries(self): - return self.connection.queries[self.initial_queries:self.final_queries] + return self.connection.queries[self.initial_queries : self.final_queries] def __enter__(self): self.force_debug_cursor = self.connection.force_debug_cursor @@ -709,7 +744,7 @@ class CaptureQueriesContext: class ignore_warnings(TestContextDecorator): def __init__(self, **kwargs): self.ignore_kwargs = kwargs - if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs: + if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs: self.filter_func = warnings.filterwarnings else: self.filter_func = warnings.simplefilter @@ -718,7 +753,7 @@ class ignore_warnings(TestContextDecorator): def enable(self): self.catch_warnings = warnings.catch_warnings() self.catch_warnings.__enter__() - self.filter_func('ignore', **self.ignore_kwargs) + self.filter_func("ignore", **self.ignore_kwargs) def disable(self): self.catch_warnings.__exit__(*sys.exc_info()) @@ -732,7 +767,7 @@ class ignore_warnings(TestContextDecorator): requires_tz_support = skipUnless( TZ_SUPPORT, "This test relies on the ability to run a program in an arbitrary " - "time zone, but your operating system isn't able to do that." + "time zone, but your operating system isn't able to do that.", ) @@ -775,9 +810,9 @@ def captured_output(stream_name): def captured_stdout(): """Capture the output of sys.stdout: - with captured_stdout() as stdout: - print("hello") - self.assertEqual(stdout.getvalue(), "hello\n") + with captured_stdout() as stdout: + print("hello") + self.assertEqual(stdout.getvalue(), "hello\n") """ return captured_output("stdout") @@ -785,9 +820,9 @@ def captured_stdout(): def captured_stderr(): """Capture the output of sys.stderr: - with captured_stderr() as stderr: - print("hello", file=sys.stderr) - self.assertEqual(stderr.getvalue(), "hello\n") + with captured_stderr() as stderr: + print("hello", file=sys.stderr) + self.assertEqual(stderr.getvalue(), "hello\n") """ return captured_output("stderr") @@ -795,12 +830,12 @@ def captured_stderr(): def captured_stdin(): """Capture the input to sys.stdin: - with captured_stdin() as stdin: - stdin.write('hello\n') - stdin.seek(0) - # call test code that consumes from sys.stdin - captured = input() - self.assertEqual(captured, "hello") + with captured_stdin() as stdin: + stdin.write('hello\n') + stdin.seek(0) + # call test code that consumes from sys.stdin + captured = input() + self.assertEqual(captured, "hello") """ return captured_output("stdin") @@ -828,18 +863,24 @@ def require_jinja2(test_func): Django template engine for a test or skip it if Jinja2 isn't available. """ test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func) - return override_settings(TEMPLATES=[{ - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'APP_DIRS': True, - }, { - 'BACKEND': 'django.template.backends.jinja2.Jinja2', - 'APP_DIRS': True, - 'OPTIONS': {'keep_trailing_newline': True}, - }])(test_func) + return override_settings( + TEMPLATES=[ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + }, + { + "BACKEND": "django.template.backends.jinja2.Jinja2", + "APP_DIRS": True, + "OPTIONS": {"keep_trailing_newline": True}, + }, + ] + )(test_func) class override_script_prefix(TestContextDecorator): """Decorator or context manager to temporary override the script prefix.""" + def __init__(self, prefix): self.prefix = prefix super().__init__() @@ -857,8 +898,9 @@ class LoggingCaptureMixin: Capture the output from the 'django' logger and store it on the class's logger_output attribute. """ + def setUp(self): - self.logger = logging.getLogger('django') + self.logger = logging.getLogger("django") self.old_stream = self.logger.handlers[0].stream self.logger_output = StringIO() self.logger.handlers[0].stream = self.logger_output @@ -883,6 +925,7 @@ class isolate_apps(TestContextDecorator): `kwarg_name`: keyword argument passing the isolated registry if used as a function decorator. """ + def __init__(self, *installed_apps, **kwargs): self.installed_apps = installed_apps super().__init__(**kwargs) @@ -890,11 +933,11 @@ class isolate_apps(TestContextDecorator): def enable(self): self.old_apps = Options.default_apps apps = Apps(self.installed_apps) - setattr(Options, 'default_apps', apps) + setattr(Options, "default_apps", apps) return apps def disable(self): - setattr(Options, 'default_apps', self.old_apps) + setattr(Options, "default_apps", self.old_apps) class TimeKeeper: @@ -914,7 +957,7 @@ class TimeKeeper: def print_results(self): for name, end_times in self.records.items(): for record_time in end_times: - record = '%s took %.3fs' % (name, record_time) + record = "%s took %.3fs" % (name, record_time) sys.stderr.write(record + os.linesep) @@ -929,12 +972,14 @@ class NullTimeKeeper: def tag(*tags): """Decorator to add tags to a test class or method.""" + def decorator(obj): - if hasattr(obj, 'tags'): + if hasattr(obj, "tags"): obj.tags = obj.tags.union(tags) else: - setattr(obj, 'tags', set(tags)) + setattr(obj, "tags", set(tags)) return obj + return decorator |
