summaryrefslogtreecommitdiff
path: root/django/test
diff options
context:
space:
mode:
authordjango-bot <ops@djangoproject.com>2022-02-03 20:24:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-02-07 20:37:05 +0100
commit9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
treef0506b668a013d0063e5fba3dbf4863b466713ba /django/test
parentf68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/test')
-rw-r--r--django/test/__init__.py41
-rw-r--r--django/test/client.py578
-rw-r--r--django/test/html.py106
-rw-r--r--django/test/runner.py343
-rw-r--r--django/test/selenium.py32
-rw-r--r--django/test/signals.py76
-rw-r--r--django/test/testcases.py632
-rw-r--r--django/test/utils.py221
8 files changed, 1286 insertions, 743 deletions
diff --git a/django/test/__init__.py b/django/test/__init__.py
index d1f953a8dd..485298e8e7 100644
--- a/django/test/__init__.py
+++ b/django/test/__init__.py
@@ -1,21 +1,38 @@
"""Django Unit Test framework."""
-from django.test.client import (
- AsyncClient, AsyncRequestFactory, Client, RequestFactory,
-)
+from django.test.client import AsyncClient, AsyncRequestFactory, Client, RequestFactory
from django.test.testcases import (
- LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
- skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
+ LiveServerTestCase,
+ SimpleTestCase,
+ TestCase,
+ TransactionTestCase,
+ skipIfDBFeature,
+ skipUnlessAnyDBFeature,
+ skipUnlessDBFeature,
)
from django.test.utils import (
- ignore_warnings, modify_settings, override_settings,
- override_system_checks, tag,
+ ignore_warnings,
+ modify_settings,
+ override_settings,
+ override_system_checks,
+ tag,
)
__all__ = [
- 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
- 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
- 'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
- 'ignore_warnings', 'modify_settings', 'override_settings',
- 'override_system_checks', 'tag',
+ "AsyncClient",
+ "AsyncRequestFactory",
+ "Client",
+ "RequestFactory",
+ "TestCase",
+ "TransactionTestCase",
+ "SimpleTestCase",
+ "LiveServerTestCase",
+ "skipIfDBFeature",
+ "skipUnlessAnyDBFeature",
+ "skipUnlessDBFeature",
+ "ignore_warnings",
+ "modify_settings",
+ "override_settings",
+ "override_system_checks",
+ "tag",
]
diff --git a/django/test/client.py b/django/test/client.py
index af1090a740..a38e7dae13 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -16,9 +16,7 @@ from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
-from django.core.signals import (
- got_request_exception, request_finished, request_started,
-)
+from django.core.signals import got_request_exception, request_finished, request_started
from django.db import close_old_connections
from django.http import HttpRequest, QueryDict, SimpleCookie
from django.test import signals
@@ -31,20 +29,26 @@ from django.utils.itercompat import is_iterable
from django.utils.regex_helper import _lazy_re_compile
__all__ = (
- 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RedirectCycleError',
- 'RequestFactory', 'encode_file', 'encode_multipart',
+ "AsyncClient",
+ "AsyncRequestFactory",
+ "Client",
+ "RedirectCycleError",
+ "RequestFactory",
+ "encode_file",
+ "encode_multipart",
)
-BOUNDARY = 'BoUnDaRyStRiNg'
-MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
-CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w-]+);?')
+BOUNDARY = "BoUnDaRyStRiNg"
+MULTIPART_CONTENT = "multipart/form-data; boundary=%s" % BOUNDARY
+CONTENT_TYPE_RE = _lazy_re_compile(r".*; charset=([\w-]+);?")
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
-JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json')
+JSON_CONTENT_TYPE_RE = _lazy_re_compile(r"^application\/(.+\+)?json")
class RedirectCycleError(Exception):
"""The test client has been asked to follow a redirect loop."""
+
def __init__(self, message, last_response):
super().__init__(message)
self.last_response = last_response
@@ -58,6 +62,7 @@ class FakePayload:
length. This makes sure that views can't do anything under the test client
that wouldn't work in real life.
"""
+
def __init__(self, content=None):
self.__content = BytesIO()
self.__len = 0
@@ -74,7 +79,9 @@ class FakePayload:
self.read_started = True
if num_bytes is None:
num_bytes = self.__len or 0
- assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
+ assert (
+ self.__len >= num_bytes
+ ), "Cannot read more than the available bytes from the HTTP incoming data."
content = self.__content.read(num_bytes)
self.__len -= num_bytes
return content
@@ -92,7 +99,7 @@ def closing_iterator_wrapper(iterable, close):
yield from iterable
finally:
request_finished.disconnect(close_old_connections)
- close() # will fire request_finished
+ close() # will fire request_finished
request_finished.connect(close_old_connections)
@@ -106,12 +113,12 @@ def conditional_content_removal(request, response):
if response.streaming:
response.streaming_content = []
else:
- response.content = b''
- if request.method == 'HEAD':
+ response.content = b""
+ if request.method == "HEAD":
if response.streaming:
response.streaming_content = []
else:
- response.content = b''
+ response.content = b""
return response
@@ -121,6 +128,7 @@ class ClientHandler(BaseHandler):
interface to compose requests, but return the raw HttpResponse object with
the originating WSGIRequest attached to its ``wsgi_request`` attribute.
"""
+
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
self.enforce_csrf_checks = enforce_csrf_checks
super().__init__(*args, **kwargs)
@@ -154,10 +162,11 @@ class ClientHandler(BaseHandler):
# Emulate a WSGI server by calling the close method on completion.
if response.streaming:
response.streaming_content = closing_iterator_wrapper(
- response.streaming_content, response.close)
+ response.streaming_content, response.close
+ )
else:
request_finished.disconnect(close_old_connections)
- response.close() # will fire request_finished
+ response.close() # will fire request_finished
request_finished.connect(close_old_connections)
return response
@@ -165,6 +174,7 @@ class ClientHandler(BaseHandler):
class AsyncClientHandler(BaseHandler):
"""An async version of ClientHandler."""
+
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
self.enforce_csrf_checks = enforce_csrf_checks
super().__init__(*args, **kwargs)
@@ -175,13 +185,15 @@ class AsyncClientHandler(BaseHandler):
if self._middleware_chain is None:
self.load_middleware(is_async=True)
# Extract body file from the scope, if provided.
- if '_body_file' in scope:
- body_file = scope.pop('_body_file')
+ if "_body_file" in scope:
+ body_file = scope.pop("_body_file")
else:
- body_file = FakePayload('')
+ body_file = FakePayload("")
request_started.disconnect(close_old_connections)
- await sync_to_async(request_started.send, thread_sensitive=False)(sender=self.__class__, scope=scope)
+ await sync_to_async(request_started.send, thread_sensitive=False)(
+ sender=self.__class__, scope=scope
+ )
request_started.connect(close_old_connections)
request = ASGIRequest(scope, body_file)
# Sneaky little hack so that we can easily get round
@@ -197,7 +209,9 @@ class AsyncClientHandler(BaseHandler):
response.asgi_request = request
# Emulate a server by calling the close method on completion.
if response.streaming:
- response.streaming_content = await sync_to_async(closing_iterator_wrapper, thread_sensitive=False)(
+ response.streaming_content = await sync_to_async(
+ closing_iterator_wrapper, thread_sensitive=False
+ )(
response.streaming_content,
response.close,
)
@@ -216,10 +230,10 @@ def store_rendered_templates(store, signal, sender, template, context, **kwargs)
The context is copied so that it is an accurate representation at the time
of rendering.
"""
- store.setdefault('templates', []).append(template)
- if 'context' not in store:
- store['context'] = ContextList()
- store['context'].append(copy(context))
+ store.setdefault("templates", []).append(template)
+ if "context" not in store:
+ store["context"] = ContextList()
+ store["context"].append(copy(context))
def encode_multipart(boundary, data):
@@ -255,25 +269,33 @@ def encode_multipart(boundary, data):
if is_file(item):
lines.extend(encode_file(boundary, key, item))
else:
- lines.extend(to_bytes(val) for val in [
- '--%s' % boundary,
- 'Content-Disposition: form-data; name="%s"' % key,
- '',
- item
- ])
+ lines.extend(
+ to_bytes(val)
+ for val in [
+ "--%s" % boundary,
+ 'Content-Disposition: form-data; name="%s"' % key,
+ "",
+ item,
+ ]
+ )
else:
- lines.extend(to_bytes(val) for val in [
- '--%s' % boundary,
- 'Content-Disposition: form-data; name="%s"' % key,
- '',
- value
- ])
+ lines.extend(
+ to_bytes(val)
+ for val in [
+ "--%s" % boundary,
+ 'Content-Disposition: form-data; name="%s"' % key,
+ "",
+ value,
+ ]
+ )
- lines.extend([
- to_bytes('--%s--' % boundary),
- b'',
- ])
- return b'\r\n'.join(lines)
+ lines.extend(
+ [
+ to_bytes("--%s--" % boundary),
+ b"",
+ ]
+ )
+ return b"\r\n".join(lines)
def encode_file(boundary, key, file):
@@ -282,10 +304,10 @@ def encode_file(boundary, key, file):
# file.name might not be a string. For example, it's an int for
# tempfile.TemporaryFile().
- file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str)
- filename = os.path.basename(file.name) if file_has_string_name else ''
+ file_has_string_name = hasattr(file, "name") and isinstance(file.name, str)
+ filename = os.path.basename(file.name) if file_has_string_name else ""
- if hasattr(file, 'content_type'):
+ if hasattr(file, "content_type"):
content_type = file.content_type
elif filename:
content_type = mimetypes.guess_type(filename)[0]
@@ -293,15 +315,16 @@ def encode_file(boundary, key, file):
content_type = None
if content_type is None:
- content_type = 'application/octet-stream'
+ content_type = "application/octet-stream"
filename = filename or key
return [
- to_bytes('--%s' % boundary),
- to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"'
- % (key, filename)),
- to_bytes('Content-Type: %s' % content_type),
- b'',
- to_bytes(file.read())
+ to_bytes("--%s" % boundary),
+ to_bytes(
+ 'Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)
+ ),
+ to_bytes("Content-Type: %s" % content_type),
+ b"",
+ to_bytes(file.read()),
]
@@ -318,6 +341,7 @@ class RequestFactory:
Once you have a request object you can pass it to any view function,
just as if that view had been hooked up using a URLconf.
"""
+
def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
self.json_encoder = json_encoder
self.defaults = defaults
@@ -333,24 +357,26 @@ class RequestFactory:
# - REMOTE_ADDR: often useful, see #8551.
# See https://www.python.org/dev/peps/pep-3333/#environ-variables
return {
- 'HTTP_COOKIE': '; '.join(sorted(
- '%s=%s' % (morsel.key, morsel.coded_value)
- for morsel in self.cookies.values()
- )),
- 'PATH_INFO': '/',
- 'REMOTE_ADDR': '127.0.0.1',
- 'REQUEST_METHOD': 'GET',
- 'SCRIPT_NAME': '',
- 'SERVER_NAME': 'testserver',
- 'SERVER_PORT': '80',
- 'SERVER_PROTOCOL': 'HTTP/1.1',
- 'wsgi.version': (1, 0),
- 'wsgi.url_scheme': 'http',
- 'wsgi.input': FakePayload(b''),
- 'wsgi.errors': self.errors,
- 'wsgi.multiprocess': True,
- 'wsgi.multithread': False,
- 'wsgi.run_once': False,
+ "HTTP_COOKIE": "; ".join(
+ sorted(
+ "%s=%s" % (morsel.key, morsel.coded_value)
+ for morsel in self.cookies.values()
+ )
+ ),
+ "PATH_INFO": "/",
+ "REMOTE_ADDR": "127.0.0.1",
+ "REQUEST_METHOD": "GET",
+ "SCRIPT_NAME": "",
+ "SERVER_NAME": "testserver",
+ "SERVER_PORT": "80",
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "wsgi.version": (1, 0),
+ "wsgi.url_scheme": "http",
+ "wsgi.input": FakePayload(b""),
+ "wsgi.errors": self.errors,
+ "wsgi.multiprocess": True,
+ "wsgi.multithread": False,
+ "wsgi.run_once": False,
**self.defaults,
**request,
}
@@ -376,7 +402,9 @@ class RequestFactory:
Return encoded JSON if data is a dict, list, or tuple and content_type
is application/json.
"""
- should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple))
+ should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(
+ data, (dict, list, tuple)
+ )
return json.dumps(data, cls=self.json_encoder) if should_encode else data
def _get_path(self, parsed):
@@ -388,88 +416,128 @@ class RequestFactory:
# Replace the behavior where non-ASCII values in the WSGI environ are
# arbitrarily decoded with ISO-8859-1.
# Refs comment in `get_bytes_from_wsgi()`.
- return path.decode('iso-8859-1')
+ return path.decode("iso-8859-1")
def get(self, path, data=None, secure=False, **extra):
"""Construct a GET request."""
data = {} if data is None else data
- return self.generic('GET', path, secure=secure, **{
- 'QUERY_STRING': urlencode(data, doseq=True),
- **extra,
- })
+ return self.generic(
+ "GET",
+ path,
+ secure=secure,
+ **{
+ "QUERY_STRING": urlencode(data, doseq=True),
+ **extra,
+ },
+ )
- def post(self, path, data=None, content_type=MULTIPART_CONTENT,
- secure=False, **extra):
+ def post(
+ self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra
+ ):
"""Construct a POST request."""
data = self._encode_json({} if data is None else data, content_type)
post_data = self._encode_data(data, content_type)
- return self.generic('POST', path, post_data, content_type,
- secure=secure, **extra)
+ return self.generic(
+ "POST", path, post_data, content_type, secure=secure, **extra
+ )
def head(self, path, data=None, secure=False, **extra):
"""Construct a HEAD request."""
data = {} if data is None else data
- return self.generic('HEAD', path, secure=secure, **{
- 'QUERY_STRING': urlencode(data, doseq=True),
- **extra,
- })
+ return self.generic(
+ "HEAD",
+ path,
+ secure=secure,
+ **{
+ "QUERY_STRING": urlencode(data, doseq=True),
+ **extra,
+ },
+ )
def trace(self, path, secure=False, **extra):
"""Construct a TRACE request."""
- return self.generic('TRACE', path, secure=secure, **extra)
+ return self.generic("TRACE", path, secure=secure, **extra)
- def options(self, path, data='', content_type='application/octet-stream',
- secure=False, **extra):
+ def options(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra,
+ ):
"Construct an OPTIONS request."
- return self.generic('OPTIONS', path, data, content_type,
- secure=secure, **extra)
+ return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra)
- def put(self, path, data='', content_type='application/octet-stream',
- secure=False, **extra):
+ def put(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra,
+ ):
"""Construct a PUT request."""
data = self._encode_json(data, content_type)
- return self.generic('PUT', path, data, content_type,
- secure=secure, **extra)
+ return self.generic("PUT", path, data, content_type, secure=secure, **extra)
- def patch(self, path, data='', content_type='application/octet-stream',
- secure=False, **extra):
+ def patch(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra,
+ ):
"""Construct a PATCH request."""
data = self._encode_json(data, content_type)
- return self.generic('PATCH', path, data, content_type,
- secure=secure, **extra)
+ return self.generic("PATCH", path, data, content_type, secure=secure, **extra)
- def delete(self, path, data='', content_type='application/octet-stream',
- secure=False, **extra):
+ def delete(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra,
+ ):
"""Construct a DELETE request."""
data = self._encode_json(data, content_type)
- return self.generic('DELETE', path, data, content_type,
- secure=secure, **extra)
+ return self.generic("DELETE", path, data, content_type, secure=secure, **extra)
- def generic(self, method, path, data='',
- content_type='application/octet-stream', secure=False,
- **extra):
+ def generic(
+ self,
+ method,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra,
+ ):
"""Construct an arbitrary HTTP request."""
parsed = urlparse(str(path)) # path can be lazy
data = force_bytes(data, settings.DEFAULT_CHARSET)
r = {
- 'PATH_INFO': self._get_path(parsed),
- 'REQUEST_METHOD': method,
- 'SERVER_PORT': '443' if secure else '80',
- 'wsgi.url_scheme': 'https' if secure else 'http',
+ "PATH_INFO": self._get_path(parsed),
+ "REQUEST_METHOD": method,
+ "SERVER_PORT": "443" if secure else "80",
+ "wsgi.url_scheme": "https" if secure else "http",
}
if data:
- r.update({
- 'CONTENT_LENGTH': str(len(data)),
- 'CONTENT_TYPE': content_type,
- 'wsgi.input': FakePayload(data),
- })
+ r.update(
+ {
+ "CONTENT_LENGTH": str(len(data)),
+ "CONTENT_TYPE": content_type,
+ "wsgi.input": FakePayload(data),
+ }
+ )
r.update(extra)
# If QUERY_STRING is absent or empty, we want to extract it from the URL.
- if not r.get('QUERY_STRING'):
+ if not r.get("QUERY_STRING"):
# WSGI requires latin-1 encoded strings. See get_path_info().
- query_string = parsed[4].encode().decode('iso-8859-1')
- r['QUERY_STRING'] = query_string
+ query_string = parsed[4].encode().decode("iso-8859-1")
+ r["QUERY_STRING"] = query_string
return self.request(**r)
@@ -487,30 +555,35 @@ class AsyncRequestFactory(RequestFactory):
a) this makes ASGIRequest subclasses, and
b) AsyncTestClient can subclass it.
"""
+
def _base_scope(self, **request):
"""The base scope for a request."""
# This is a minimal valid ASGI scope, plus:
# - headers['cookie'] for cookie support,
# - 'client' often useful, see #8551.
scope = {
- 'asgi': {'version': '3.0'},
- 'type': 'http',
- 'http_version': '1.1',
- 'client': ['127.0.0.1', 0],
- 'server': ('testserver', '80'),
- 'scheme': 'http',
- 'method': 'GET',
- 'headers': [],
+ "asgi": {"version": "3.0"},
+ "type": "http",
+ "http_version": "1.1",
+ "client": ["127.0.0.1", 0],
+ "server": ("testserver", "80"),
+ "scheme": "http",
+ "method": "GET",
+ "headers": [],
**self.defaults,
**request,
}
- scope['headers'].append((
- b'cookie',
- b'; '.join(sorted(
- ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
- for morsel in self.cookies.values()
- )),
- ))
+ scope["headers"].append(
+ (
+ b"cookie",
+ b"; ".join(
+ sorted(
+ ("%s=%s" % (morsel.key, morsel.coded_value)).encode("ascii")
+ for morsel in self.cookies.values()
+ )
+ ),
+ )
+ )
return scope
def request(self, **request):
@@ -518,45 +591,52 @@ class AsyncRequestFactory(RequestFactory):
# This is synchronous, which means all methods on this class are.
# AsyncClient, however, has an async request function, which makes all
# its methods async.
- if '_body_file' in request:
- body_file = request.pop('_body_file')
+ if "_body_file" in request:
+ body_file = request.pop("_body_file")
else:
- body_file = FakePayload('')
+ body_file = FakePayload("")
return ASGIRequest(self._base_scope(**request), body_file)
def generic(
- self, method, path, data='', content_type='application/octet-stream',
- secure=False, **extra,
+ self,
+ method,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra,
):
"""Construct an arbitrary HTTP request."""
parsed = urlparse(str(path)) # path can be lazy.
data = force_bytes(data, settings.DEFAULT_CHARSET)
s = {
- 'method': method,
- 'path': self._get_path(parsed),
- 'server': ('127.0.0.1', '443' if secure else '80'),
- 'scheme': 'https' if secure else 'http',
- 'headers': [(b'host', b'testserver')],
+ "method": method,
+ "path": self._get_path(parsed),
+ "server": ("127.0.0.1", "443" if secure else "80"),
+ "scheme": "https" if secure else "http",
+ "headers": [(b"host", b"testserver")],
}
if data:
- s['headers'].extend([
- (b'content-length', str(len(data)).encode('ascii')),
- (b'content-type', content_type.encode('ascii')),
- ])
- s['_body_file'] = FakePayload(data)
- follow = extra.pop('follow', None)
+ s["headers"].extend(
+ [
+ (b"content-length", str(len(data)).encode("ascii")),
+ (b"content-type", content_type.encode("ascii")),
+ ]
+ )
+ s["_body_file"] = FakePayload(data)
+ follow = extra.pop("follow", None)
if follow is not None:
- s['follow'] = follow
- if query_string := extra.pop('QUERY_STRING', None):
- s['query_string'] = query_string
- s['headers'] += [
- (key.lower().encode('ascii'), value.encode('latin1'))
+ s["follow"] = follow
+ if query_string := extra.pop("QUERY_STRING", None):
+ s["query_string"] = query_string
+ s["headers"] += [
+ (key.lower().encode("ascii"), value.encode("latin1"))
for key, value in extra.items()
]
# If QUERY_STRING is absent or empty, we want to extract it from the
# URL.
- if not s.get('query_string'):
- s['query_string'] = parsed[4]
+ if not s.get("query_string"):
+ s["query_string"] = parsed[4]
return self.request(**s)
@@ -564,6 +644,7 @@ class ClientMixin:
"""
Mixin with common methods between Client and AsyncClient.
"""
+
def store_exc_info(self, **kwargs):
"""Store exceptions when they are generated by a view."""
self.exc_info = sys.exc_info()
@@ -601,6 +682,7 @@ class ClientMixin:
are incorrect.
"""
from django.contrib.auth import authenticate
+
user = authenticate(**credentials)
if user:
self._login(user)
@@ -610,9 +692,10 @@ class ClientMixin:
def force_login(self, user, backend=None):
def get_backend():
from django.contrib.auth import load_backend
+
for backend_path in settings.AUTHENTICATION_BACKENDS:
backend = load_backend(backend_path)
- if hasattr(backend, 'get_user'):
+ if hasattr(backend, "get_user"):
return backend_path
if backend is None:
@@ -637,17 +720,18 @@ class ClientMixin:
session_cookie = settings.SESSION_COOKIE_NAME
self.cookies[session_cookie] = request.session.session_key
cookie_data = {
- 'max-age': None,
- 'path': '/',
- 'domain': settings.SESSION_COOKIE_DOMAIN,
- 'secure': settings.SESSION_COOKIE_SECURE or None,
- 'expires': None,
+ "max-age": None,
+ "path": "/",
+ "domain": settings.SESSION_COOKIE_DOMAIN,
+ "secure": settings.SESSION_COOKIE_SECURE or None,
+ "expires": None,
}
self.cookies[session_cookie].update(cookie_data)
def logout(self):
"""Log out the user by removing the cookies and session object."""
from django.contrib.auth import get_user, logout
+
request = HttpRequest()
if self.session:
request.session = self.session
@@ -659,13 +743,15 @@ class ClientMixin:
self.cookies = SimpleCookie()
def _parse_json(self, response, **extra):
- if not hasattr(response, '_json'):
- if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
+ if not hasattr(response, "_json"):
+ if not JSON_CONTENT_TYPE_RE.match(response.get("Content-Type")):
raise ValueError(
'Content-Type header is "%s", not "application/json"'
- % response.get('Content-Type')
+ % response.get("Content-Type")
)
- response._json = json.loads(response.content.decode(response.charset), **extra)
+ response._json = json.loads(
+ response.content.decode(response.charset), **extra
+ )
return response._json
@@ -687,7 +773,10 @@ class Client(ClientMixin, RequestFactory):
contexts and templates produced by a view, rather than the
HTML rendered to the end-user.
"""
- def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
+
+ def __init__(
+ self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
+ ):
super().__init__(**defaults)
self.handler = ClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception
@@ -723,13 +812,13 @@ class Client(ClientMixin, RequestFactory):
response.client = self
response.request = request
# Add any rendered template detail to the response.
- response.templates = data.get('templates', [])
- response.context = data.get('context')
+ response.templates = data.get("templates", [])
+ response.context = data.get("context")
response.json = partial(self._parse_json, response)
# Attach the ResolverMatch instance to the response.
- urlconf = getattr(response.wsgi_request, 'urlconf', None)
+ urlconf = getattr(response.wsgi_request, "urlconf", None)
response.resolver_match = SimpleLazyObject(
- lambda: resolve(request['PATH_INFO'], urlconf=urlconf),
+ lambda: resolve(request["PATH_INFO"], urlconf=urlconf),
)
# Flatten a single context. Not really necessary anymore thanks to the
# __getattr__ flattening in ContextList, but has some edge case
@@ -749,13 +838,24 @@ class Client(ClientMixin, RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
- def post(self, path, data=None, content_type=MULTIPART_CONTENT,
- follow=False, secure=False, **extra):
+ def post(
+ self,
+ path,
+ data=None,
+ content_type=MULTIPART_CONTENT,
+ follow=False,
+ secure=False,
+ **extra,
+ ):
"""Request a response from the server using POST."""
self.extra = extra
- response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
+ response = super().post(
+ path, data=data, content_type=content_type, secure=secure, **extra
+ )
if follow:
- response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
+ response = self._handle_redirects(
+ response, data=data, content_type=content_type, **extra
+ )
return response
def head(self, path, data=None, follow=False, secure=False, **extra):
@@ -766,43 +866,87 @@ class Client(ClientMixin, RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
- def options(self, path, data='', content_type='application/octet-stream',
- follow=False, secure=False, **extra):
+ def options(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ follow=False,
+ secure=False,
+ **extra,
+ ):
"""Request a response from the server using OPTIONS."""
self.extra = extra
- response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
+ response = super().options(
+ path, data=data, content_type=content_type, secure=secure, **extra
+ )
if follow:
- response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
+ response = self._handle_redirects(
+ response, data=data, content_type=content_type, **extra
+ )
return response
- def put(self, path, data='', content_type='application/octet-stream',
- follow=False, secure=False, **extra):
+ def put(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ follow=False,
+ secure=False,
+ **extra,
+ ):
"""Send a resource to the server using PUT."""
self.extra = extra
- response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
+ response = super().put(
+ path, data=data, content_type=content_type, secure=secure, **extra
+ )
if follow:
- response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
+ response = self._handle_redirects(
+ response, data=data, content_type=content_type, **extra
+ )
return response
- def patch(self, path, data='', content_type='application/octet-stream',
- follow=False, secure=False, **extra):
+ def patch(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ follow=False,
+ secure=False,
+ **extra,
+ ):
"""Send a resource to the server using PATCH."""
self.extra = extra
- response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
+ response = super().patch(
+ path, data=data, content_type=content_type, secure=secure, **extra
+ )
if follow:
- response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
+ response = self._handle_redirects(
+ response, data=data, content_type=content_type, **extra
+ )
return response
- def delete(self, path, data='', content_type='application/octet-stream',
- follow=False, secure=False, **extra):
+ def delete(
+ self,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ follow=False,
+ secure=False,
+ **extra,
+ ):
"""Send a DELETE request to the server."""
self.extra = extra
- response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
+ response = super().delete(
+ path, data=data, content_type=content_type, secure=secure, **extra
+ )
if follow:
- response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
+ response = self._handle_redirects(
+ response, data=data, content_type=content_type, **extra
+ )
return response
- def trace(self, path, data='', follow=False, secure=False, **extra):
+ def trace(self, path, data="", follow=False, secure=False, **extra):
"""Send a TRACE request to the server."""
self.extra = extra
response = super().trace(path, data=data, secure=secure, **extra)
@@ -810,7 +954,7 @@ class Client(ClientMixin, RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
- def _handle_redirects(self, response, data='', content_type='', **extra):
+ def _handle_redirects(self, response, data="", content_type="", **extra):
"""
Follow any redirects by requesting responses from the server using GET.
"""
@@ -829,39 +973,46 @@ class Client(ClientMixin, RequestFactory):
url = urlsplit(response_url)
if url.scheme:
- extra['wsgi.url_scheme'] = url.scheme
+ extra["wsgi.url_scheme"] = url.scheme
if url.hostname:
- extra['SERVER_NAME'] = url.hostname
+ extra["SERVER_NAME"] = url.hostname
if url.port:
- extra['SERVER_PORT'] = str(url.port)
+ extra["SERVER_PORT"] = str(url.port)
path = url.path
# RFC 2616: bare domains without path are treated as the root.
if not path and url.netloc:
- path = '/'
+ path = "/"
# Prepend the request path to handle relative path redirects
- if not path.startswith('/'):
- path = urljoin(response.request['PATH_INFO'], path)
+ if not path.startswith("/"):
+ path = urljoin(response.request["PATH_INFO"], path)
- if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
+ if response.status_code in (
+ HTTPStatus.TEMPORARY_REDIRECT,
+ HTTPStatus.PERMANENT_REDIRECT,
+ ):
# Preserve request method and query string (if needed)
# post-redirect for 307/308 responses.
- request_method = response.request['REQUEST_METHOD'].lower()
- if request_method not in ('get', 'head'):
- extra['QUERY_STRING'] = url.query
+ request_method = response.request["REQUEST_METHOD"].lower()
+ if request_method not in ("get", "head"):
+ extra["QUERY_STRING"] = url.query
request_method = getattr(self, request_method)
else:
request_method = self.get
data = QueryDict(url.query)
content_type = None
- response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
+ response = request_method(
+ path, data=data, content_type=content_type, follow=False, **extra
+ )
response.redirect_chain = redirect_chain
if redirect_chain[-1] in redirect_chain[:-1]:
# Check that we're not redirecting to somewhere we've already
# been to, to prevent loops.
- raise RedirectCycleError("Redirect loop detected.", last_response=response)
+ raise RedirectCycleError(
+ "Redirect loop detected.", last_response=response
+ )
if len(redirect_chain) > 20:
# Such a lengthy chain likely also means a loop, but one with
# a growing path, changing view, or changing query argument;
@@ -878,7 +1029,10 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
Does not currently support "follow" on its methods.
"""
- def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
+
+ def __init__(
+ self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
+ ):
super().__init__(**defaults)
self.handler = AsyncClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception
@@ -892,19 +1046,19 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
query environment, which can be overridden using the arguments to the
request.
"""
- if 'follow' in request:
+ if "follow" in request:
raise NotImplementedError(
- 'AsyncClient request methods do not accept the follow parameter.'
+ "AsyncClient request methods do not accept the follow parameter."
)
scope = self._base_scope(**request)
# Curry a data dictionary into an instance of the template renderer
# callback function.
data = {}
on_template_render = partial(store_rendered_templates, data)
- signal_uid = 'template-render-%s' % id(request)
+ signal_uid = "template-render-%s" % id(request)
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
# Capture exceptions created by the handler.
- exception_uid = 'request-exception-%s' % id(request)
+ exception_uid = "request-exception-%s" % id(request)
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
try:
response = await self.handler(scope)
@@ -917,13 +1071,13 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
response.client = self
response.request = request
# Add any rendered template detail to the response.
- response.templates = data.get('templates', [])
- response.context = data.get('context')
+ response.templates = data.get("templates", [])
+ response.context = data.get("context")
response.json = partial(self._parse_json, response)
# Attach the ResolverMatch instance to the response.
- urlconf = getattr(response.asgi_request, 'urlconf', None)
+ urlconf = getattr(response.asgi_request, "urlconf", None)
response.resolver_match = SimpleLazyObject(
- lambda: resolve(request['path'], urlconf=urlconf),
+ lambda: resolve(request["path"], urlconf=urlconf),
)
# Flatten a single context. Not really necessary anymore thanks to the
# __getattr__ flattening in ContextList, but has some edge case
diff --git a/django/test/html.py b/django/test/html.py
index 07e986439b..87e213d651 100644
--- a/django/test/html.py
+++ b/django/test/html.py
@@ -7,32 +7,52 @@ from django.utils.regex_helper import _lazy_re_compile
# ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020
# SPACE.
# https://infra.spec.whatwg.org/#ascii-whitespace
-ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+')
+ASCII_WHITESPACE = _lazy_re_compile(r"[\t\n\f\r ]+")
# https://html.spec.whatwg.org/#attributes-3
BOOLEAN_ATTRIBUTES = {
- 'allowfullscreen', 'async', 'autofocus', 'autoplay', 'checked', 'controls',
- 'default', 'defer ', 'disabled', 'formnovalidate', 'hidden', 'ismap',
- 'itemscope', 'loop', 'multiple', 'muted', 'nomodule', 'novalidate', 'open',
- 'playsinline', 'readonly', 'required', 'reversed', 'selected',
+ "allowfullscreen",
+ "async",
+ "autofocus",
+ "autoplay",
+ "checked",
+ "controls",
+ "default",
+ "defer ",
+ "disabled",
+ "formnovalidate",
+ "hidden",
+ "ismap",
+ "itemscope",
+ "loop",
+ "multiple",
+ "muted",
+ "nomodule",
+ "novalidate",
+ "open",
+ "playsinline",
+ "readonly",
+ "required",
+ "reversed",
+ "selected",
# Attributes for deprecated tags.
- 'truespeed',
+ "truespeed",
}
def normalize_whitespace(string):
- return ASCII_WHITESPACE.sub(' ', string)
+ return ASCII_WHITESPACE.sub(" ", string)
def normalize_attributes(attributes):
normalized = []
for name, value in attributes:
- if name == 'class' and value:
+ if name == "class" and value:
# Special case handling of 'class' attribute, so that comparisons
# of DOM instances are not sensitive to ordering of classes.
- value = ' '.join(sorted(
- value for value in ASCII_WHITESPACE.split(value) if value
- ))
+ value = " ".join(
+ sorted(value for value in ASCII_WHITESPACE.split(value) if value)
+ )
# Boolean attributes without a value is same as attribute with value
# that equals the attributes name. For example:
# <input checked> == <input checked="checked">
@@ -40,7 +60,7 @@ def normalize_attributes(attributes):
if not value or value == name:
value = None
elif value is None:
- value = ''
+ value = ""
normalized.append((name, value))
return normalized
@@ -80,11 +100,11 @@ class Element:
for i, child in enumerate(self.children):
if isinstance(child, str):
self.children[i] = child.strip()
- elif hasattr(child, 'finalize'):
+ elif hasattr(child, "finalize"):
child.finalize()
def __eq__(self, element):
- if not hasattr(element, 'name') or self.name != element.name:
+ if not hasattr(element, "name") or self.name != element.name:
return False
if self.attributes != element.attributes:
return False
@@ -142,21 +162,23 @@ class Element:
return self.children[key]
def __str__(self):
- output = '<%s' % self.name
+ output = "<%s" % self.name
for key, value in self.attributes:
if value is not None:
output += ' %s="%s"' % (key, value)
else:
- output += ' %s' % key
+ output += " %s" % key
if self.children:
- output += '>\n'
- output += ''.join([
- html.escape(c) if isinstance(c, str) else str(c)
- for c in self.children
- ])
- output += '\n</%s>' % self.name
+ output += ">\n"
+ output += "".join(
+ [
+ html.escape(c) if isinstance(c, str) else str(c)
+ for c in self.children
+ ]
+ )
+ output += "\n</%s>" % self.name
else:
- output += '>'
+ output += ">"
return output
def __repr__(self):
@@ -168,10 +190,9 @@ class RootElement(Element):
super().__init__(None, ())
def __str__(self):
- return ''.join([
- html.escape(c) if isinstance(c, str) else str(c)
- for c in self.children
- ])
+ return "".join(
+ [html.escape(c) if isinstance(c, str) else str(c) for c in self.children]
+ )
class HTMLParseError(Exception):
@@ -181,10 +202,23 @@ class HTMLParseError(Exception):
class Parser(HTMLParser):
# https://html.spec.whatwg.org/#void-elements
SELF_CLOSING_TAGS = {
- 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta',
- 'param', 'source', 'track', 'wbr',
+ "area",
+ "base",
+ "br",
+ "col",
+ "embed",
+ "hr",
+ "img",
+ "input",
+ "link",
+ "meta",
+ "param",
+ "source",
+ "track",
+ "wbr",
# Deprecated tags
- 'frame', 'spacer',
+ "frame",
+ "spacer",
}
def __init__(self):
@@ -201,9 +235,9 @@ class Parser(HTMLParser):
position = self.element_positions[element]
if position is None:
position = self.getpos()
- if hasattr(position, 'lineno'):
+ if hasattr(position, "lineno"):
position = position.lineno, position.offset
- return 'Line %d, Column %d' % position
+ return "Line %d, Column %d" % position
@property
def current(self):
@@ -227,13 +261,13 @@ class Parser(HTMLParser):
def handle_endtag(self, tag):
if not self.open_tags:
- self.error("Unexpected end tag `%s` (%s)" % (
- tag, self.format_position()))
+ self.error("Unexpected end tag `%s` (%s)" % (tag, self.format_position()))
element = self.open_tags.pop()
while element.name != tag:
if not self.open_tags:
- self.error("Unexpected end tag `%s` (%s)" % (
- tag, self.format_position()))
+ self.error(
+ "Unexpected end tag `%s` (%s)" % (tag, self.format_position())
+ )
element = self.open_tags.pop()
def handle_data(self, data):
diff --git a/django/test/runner.py b/django/test/runner.py
index 2e36514922..113d5216a6 100644
--- a/django/test/runner.py
+++ b/django/test/runner.py
@@ -20,11 +20,11 @@ from io import StringIO
from django.core.management import call_command
from django.db import connections
from django.test import SimpleTestCase, TestCase
-from django.test.utils import (
- NullTimeKeeper, TimeKeeper, iter_test_cases,
- setup_databases as _setup_databases, setup_test_environment,
- teardown_databases as _teardown_databases, teardown_test_environment,
-)
+from django.test.utils import NullTimeKeeper, TimeKeeper, iter_test_cases
+from django.test.utils import setup_databases as _setup_databases
+from django.test.utils import setup_test_environment
+from django.test.utils import teardown_databases as _teardown_databases
+from django.test.utils import teardown_test_environment
from django.utils.crypto import new_hash
from django.utils.datastructures import OrderedSet
from django.utils.deprecation import RemovedInDjango50Warning
@@ -42,7 +42,7 @@ except ImportError:
class DebugSQLTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity):
- self.logger = logging.getLogger('django.db.backends')
+ self.logger = logging.getLogger("django.db.backends")
self.logger.setLevel(logging.DEBUG)
self.debug_sql_stream = None
super().__init__(stream, descriptions, verbosity)
@@ -65,7 +65,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
super().addError(test, err)
if self.debug_sql_stream is None:
# Error before tests e.g. in setUpTestData().
- sql = ''
+ sql = ""
else:
self.debug_sql_stream.seek(0)
sql = self.debug_sql_stream.read()
@@ -80,7 +80,11 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
super().addSubTest(test, subtest, err)
if err is not None:
self.debug_sql_stream.seek(0)
- errors = self.failures if issubclass(err[0], test.failureException) else self.errors
+ errors = (
+ self.failures
+ if issubclass(err[0], test.failureException)
+ else self.errors
+ )
errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
def printErrorList(self, flavour, errors):
@@ -124,6 +128,7 @@ class DummyList:
"""
Dummy list class for faking storage of results in unittest.TestResult.
"""
+
__slots__ = ()
def append(self, item):
@@ -157,10 +162,10 @@ class RemoteTestResult(unittest.TestResult):
# attributes. This is possible since they aren't used after unpickling
# after being sent to ParallelTestSuite.
state = self.__dict__.copy()
- state.pop('_stdout_buffer', None)
- state.pop('_stderr_buffer', None)
- state.pop('_original_stdout', None)
- state.pop('_original_stderr', None)
+ state.pop("_stdout_buffer", None)
+ state.pop("_stderr_buffer", None)
+ state.pop("_original_stdout", None)
+ state.pop("_original_stderr", None)
return state
@property
@@ -176,7 +181,8 @@ class RemoteTestResult(unittest.TestResult):
pickle.loads(pickle.dumps(obj))
def _print_unpicklable_subtest(self, test, subtest, pickle_exc):
- print("""
+ print(
+ """
Subtest failed:
test: {}
@@ -189,7 +195,10 @@ test runner cannot handle it cleanly. Here is the pickling error:
You should re-run this test with --parallel=1 to reproduce the failure
with a cleaner failure message.
-""".format(test, subtest, pickle_exc))
+""".format(
+ test, subtest, pickle_exc
+ )
+ )
def check_picklable(self, test, err):
# Ensure that sys.exc_info() tuples are picklable. This displays a
@@ -202,11 +211,16 @@ with a cleaner failure message.
self._confirm_picklable(err)
except Exception as exc:
original_exc_txt = repr(err[1])
- original_exc_txt = textwrap.fill(original_exc_txt, 75, initial_indent=' ', subsequent_indent=' ')
+ original_exc_txt = textwrap.fill(
+ original_exc_txt, 75, initial_indent=" ", subsequent_indent=" "
+ )
pickle_exc_txt = repr(exc)
- pickle_exc_txt = textwrap.fill(pickle_exc_txt, 75, initial_indent=' ', subsequent_indent=' ')
+ pickle_exc_txt = textwrap.fill(
+ pickle_exc_txt, 75, initial_indent=" ", subsequent_indent=" "
+ )
if tblib is None:
- print("""
+ print(
+ """
{} failed:
@@ -218,9 +232,13 @@ parallel test runner to handle this exception cleanly.
In order to see the traceback, you should install tblib:
python -m pip install tblib
-""".format(test, original_exc_txt))
+""".format(
+ test, original_exc_txt
+ )
+ )
else:
- print("""
+ print(
+ """
{} failed:
@@ -235,7 +253,10 @@ Here's the error encountered while trying to pickle the exception:
You should re-run this test with the --parallel=1 option to reproduce the
failure and get a correct traceback.
-""".format(test, original_exc_txt, pickle_exc_txt))
+""".format(
+ test, original_exc_txt, pickle_exc_txt
+ )
+ )
raise
def check_subtest_picklable(self, test, subtest):
@@ -247,28 +268,28 @@ failure and get a correct traceback.
def startTestRun(self):
super().startTestRun()
- self.events.append(('startTestRun',))
+ self.events.append(("startTestRun",))
def stopTestRun(self):
super().stopTestRun()
- self.events.append(('stopTestRun',))
+ self.events.append(("stopTestRun",))
def startTest(self, test):
super().startTest(test)
- self.events.append(('startTest', self.test_index))
+ self.events.append(("startTest", self.test_index))
def stopTest(self, test):
super().stopTest(test)
- self.events.append(('stopTest', self.test_index))
+ self.events.append(("stopTest", self.test_index))
def addError(self, test, err):
self.check_picklable(test, err)
- self.events.append(('addError', self.test_index, err))
+ self.events.append(("addError", self.test_index, err))
super().addError(test, err)
def addFailure(self, test, err):
self.check_picklable(test, err)
- self.events.append(('addFailure', self.test_index, err))
+ self.events.append(("addFailure", self.test_index, err))
super().addFailure(test, err)
def addSubTest(self, test, subtest, err):
@@ -279,15 +300,15 @@ failure and get a correct traceback.
# check_picklable() performs the tblib check.
self.check_picklable(test, err)
self.check_subtest_picklable(test, subtest)
- self.events.append(('addSubTest', self.test_index, subtest, err))
+ self.events.append(("addSubTest", self.test_index, subtest, err))
super().addSubTest(test, subtest, err)
def addSuccess(self, test):
- self.events.append(('addSuccess', self.test_index))
+ self.events.append(("addSuccess", self.test_index))
super().addSuccess(test)
def addSkip(self, test, reason):
- self.events.append(('addSkip', self.test_index, reason))
+ self.events.append(("addSkip", self.test_index, reason))
super().addSkip(test, reason)
def addExpectedFailure(self, test, err):
@@ -298,23 +319,23 @@ failure and get a correct traceback.
if tblib is None:
err = err[0], err[1], None
self.check_picklable(test, err)
- self.events.append(('addExpectedFailure', self.test_index, err))
+ self.events.append(("addExpectedFailure", self.test_index, err))
super().addExpectedFailure(test, err)
def addUnexpectedSuccess(self, test):
- self.events.append(('addUnexpectedSuccess', self.test_index))
+ self.events.append(("addUnexpectedSuccess", self.test_index))
super().addUnexpectedSuccess(test)
def wasSuccessful(self):
"""Tells whether or not this result was a success."""
- failure_types = {'addError', 'addFailure', 'addSubTest', 'addUnexpectedSuccess'}
+ failure_types = {"addError", "addFailure", "addSubTest", "addUnexpectedSuccess"}
return all(e[0] not in failure_types for e in self.events)
def _exc_info_to_string(self, err, test):
# Make this method no-op. It only powers the default unittest behavior
# for recording errors, but this class pickles errors into 'events'
# instead.
- return ''
+ return ""
class RemoteTestRunner:
@@ -347,17 +368,17 @@ def get_max_test_processes():
"""
# The current implementation of the parallel test runner requires
# multiprocessing to start subprocesses with fork().
- if multiprocessing.get_start_method() != 'fork':
+ if multiprocessing.get_start_method() != "fork":
return 1
try:
- return int(os.environ['DJANGO_TEST_PROCESSES'])
+ return int(os.environ["DJANGO_TEST_PROCESSES"])
except KeyError:
return multiprocessing.cpu_count()
def parallel_type(value):
"""Parse value passed to the --parallel option."""
- if value == 'auto':
+ if value == "auto":
return value
try:
return int(value)
@@ -505,30 +526,30 @@ class Shuffler:
"""
# This doesn't need to be cryptographically strong, so use what's fastest.
- hash_algorithm = 'md5'
+ hash_algorithm = "md5"
@classmethod
def _hash_text(cls, text):
h = new_hash(cls.hash_algorithm, usedforsecurity=False)
- h.update(text.encode('utf-8'))
+ h.update(text.encode("utf-8"))
return h.hexdigest()
def __init__(self, seed=None):
if seed is None:
# Limit seeds to 10 digits for simpler output.
seed = random.randint(0, 10**10 - 1)
- seed_source = 'generated'
+ seed_source = "generated"
else:
- seed_source = 'given'
+ seed_source = "given"
self.seed = seed
self.seed_source = seed_source
@property
def seed_display(self):
- return f'{self.seed!r} ({self.seed_source})'
+ return f"{self.seed!r} ({self.seed_source})"
def _hash_item(self, item, key):
- text = '{}{}'.format(self.seed, key(item))
+ text = "{}{}".format(self.seed, key(item))
return self._hash_text(text)
def shuffle(self, items, key):
@@ -544,8 +565,10 @@ class Shuffler:
for item in items:
hashed = self._hash_item(item, key)
if hashed in hashes:
- msg = 'item {!r} has same hash {!r} as item {!r}'.format(
- item, hashed, hashes[hashed],
+ msg = "item {!r} has same hash {!r} as item {!r}".format(
+ item,
+ hashed,
+ hashes[hashed],
)
raise RuntimeError(msg)
hashes[hashed] = item
@@ -561,12 +584,29 @@ class DiscoverRunner:
test_loader = unittest.defaultTestLoader
reorder_by = (TestCase, SimpleTestCase)
- def __init__(self, pattern=None, top_level=None, verbosity=1,
- interactive=True, failfast=False, keepdb=False,
- reverse=False, debug_mode=False, debug_sql=False, parallel=0,
- tags=None, exclude_tags=None, test_name_patterns=None,
- pdb=False, buffer=False, enable_faulthandler=True,
- timing=False, shuffle=False, logger=None, **kwargs):
+ def __init__(
+ self,
+ pattern=None,
+ top_level=None,
+ verbosity=1,
+ interactive=True,
+ failfast=False,
+ keepdb=False,
+ reverse=False,
+ debug_mode=False,
+ debug_sql=False,
+ parallel=0,
+ tags=None,
+ exclude_tags=None,
+ test_name_patterns=None,
+ pdb=False,
+ buffer=False,
+ enable_faulthandler=True,
+ timing=False,
+ shuffle=False,
+ logger=None,
+ **kwargs,
+ ):
self.pattern = pattern
self.top_level = top_level
@@ -587,7 +627,9 @@ class DiscoverRunner:
faulthandler.enable(file=sys.__stderr__.fileno())
self.pdb = pdb
if self.pdb and self.parallel > 1:
- raise ValueError('You cannot use --pdb with parallel tests; pass --parallel=1 to use it.')
+ raise ValueError(
+ "You cannot use --pdb with parallel tests; pass --parallel=1 to use it."
+ )
self.buffer = buffer
self.test_name_patterns = None
self.time_keeper = TimeKeeper() if timing else NullTimeKeeper()
@@ -595,7 +637,7 @@ class DiscoverRunner:
# unittest does not export the _convert_select_pattern function
# that converts command-line arguments to patterns.
self.test_name_patterns = {
- pattern if '*' in pattern else '*%s*' % pattern
+ pattern if "*" in pattern else "*%s*" % pattern
for pattern in test_name_patterns
}
self.shuffle = shuffle
@@ -605,73 +647,99 @@ class DiscoverRunner:
@classmethod
def add_arguments(cls, parser):
parser.add_argument(
- '-t', '--top-level-directory', dest='top_level',
- help='Top level of project for unittest discovery.',
+ "-t",
+ "--top-level-directory",
+ dest="top_level",
+ help="Top level of project for unittest discovery.",
)
parser.add_argument(
- '-p', '--pattern', default="test*.py",
- help='The test matching pattern. Defaults to test*.py.',
+ "-p",
+ "--pattern",
+ default="test*.py",
+ help="The test matching pattern. Defaults to test*.py.",
)
parser.add_argument(
- '--keepdb', action='store_true',
- help='Preserves the test DB between runs.'
+ "--keepdb", action="store_true", help="Preserves the test DB between runs."
)
parser.add_argument(
- '--shuffle', nargs='?', default=False, type=int, metavar='SEED',
- help='Shuffles test case order.',
+ "--shuffle",
+ nargs="?",
+ default=False,
+ type=int,
+ metavar="SEED",
+ help="Shuffles test case order.",
)
parser.add_argument(
- '-r', '--reverse', action='store_true',
- help='Reverses test case order.',
+ "-r",
+ "--reverse",
+ action="store_true",
+ help="Reverses test case order.",
)
parser.add_argument(
- '--debug-mode', action='store_true',
- help='Sets settings.DEBUG to True.',
+ "--debug-mode",
+ action="store_true",
+ help="Sets settings.DEBUG to True.",
)
parser.add_argument(
- '-d', '--debug-sql', action='store_true',
- help='Prints logged SQL queries on failure.',
+ "-d",
+ "--debug-sql",
+ action="store_true",
+ help="Prints logged SQL queries on failure.",
)
parser.add_argument(
- '--parallel', nargs='?', const='auto', default=0,
- type=parallel_type, metavar='N',
+ "--parallel",
+ nargs="?",
+ const="auto",
+ default=0,
+ type=parallel_type,
+ metavar="N",
help=(
- 'Run tests using up to N parallel processes. Use the value '
+ "Run tests using up to N parallel processes. Use the value "
'"auto" to run one test process for each processor core.'
),
)
parser.add_argument(
- '--tag', action='append', dest='tags',
- help='Run only tests with the specified tag. Can be used multiple times.',
+ "--tag",
+ action="append",
+ dest="tags",
+ help="Run only tests with the specified tag. Can be used multiple times.",
)
parser.add_argument(
- '--exclude-tag', action='append', dest='exclude_tags',
- help='Do not run tests with the specified tag. Can be used multiple times.',
+ "--exclude-tag",
+ action="append",
+ dest="exclude_tags",
+ help="Do not run tests with the specified tag. Can be used multiple times.",
)
parser.add_argument(
- '--pdb', action='store_true',
- help='Runs a debugger (pdb, or ipdb if installed) on error or failure.'
+ "--pdb",
+ action="store_true",
+ help="Runs a debugger (pdb, or ipdb if installed) on error or failure.",
)
parser.add_argument(
- '-b', '--buffer', action='store_true',
- help='Discard output from passing tests.',
+ "-b",
+ "--buffer",
+ action="store_true",
+ help="Discard output from passing tests.",
)
parser.add_argument(
- '--no-faulthandler', action='store_false', dest='enable_faulthandler',
- help='Disables the Python faulthandler module during tests.',
+ "--no-faulthandler",
+ action="store_false",
+ dest="enable_faulthandler",
+ help="Disables the Python faulthandler module during tests.",
)
parser.add_argument(
- '--timing', action='store_true',
- help=(
- 'Output timings, including database set up and total run time.'
- ),
+ "--timing",
+ action="store_true",
+ help=("Output timings, including database set up and total run time."),
)
parser.add_argument(
- '-k', action='append', dest='test_name_patterns',
+ "-k",
+ action="append",
+ dest="test_name_patterns",
help=(
- 'Only run test methods and classes that match the pattern '
- 'or substring. Can be used multiple times. Same as '
- 'unittest -k option.'
+ "Only run test methods and classes that match the pattern "
+ "or substring. Can be used multiple times. Same as "
+ "unittest -k option."
),
)
@@ -693,9 +761,7 @@ class DiscoverRunner:
if level is None:
level = logging.INFO
if self.logger is None:
- if self.verbosity <= 0 or (
- self.verbosity == 1 and level < logging.INFO
- ):
+ if self.verbosity <= 0 or (self.verbosity == 1 and level < logging.INFO):
return
print(msg)
else:
@@ -709,7 +775,7 @@ class DiscoverRunner:
if self.shuffle is False:
return
shuffler = Shuffler(seed=self.shuffle)
- self.log(f'Using shuffle seed: {shuffler.seed_display}')
+ self.log(f"Using shuffle seed: {shuffler.seed_display}")
self._shuffler = shuffler
@contextmanager
@@ -741,15 +807,15 @@ class DiscoverRunner:
if os.path.exists(label_as_path):
assert tests is None
raise RuntimeError(
- f'One of the test labels is a path to a file: {label!r}, '
- f'which is not supported. Use a dotted module name or '
- f'path to a directory instead.'
+ f"One of the test labels is a path to a file: {label!r}, "
+ f"which is not supported. Use a dotted module name or "
+ f"path to a directory instead."
)
return tests
kwargs = discover_kwargs.copy()
if os.path.isdir(label_as_path) and not self.top_level:
- kwargs['top_level_dir'] = find_top_level(label_as_path)
+ kwargs["top_level_dir"] = find_top_level(label_as_path)
with self.load_with_patterns():
tests = self.test_loader.discover(start_dir=label, **kwargs)
@@ -762,18 +828,18 @@ class DiscoverRunner:
def build_suite(self, test_labels=None, extra_tests=None, **kwargs):
if extra_tests is not None:
warnings.warn(
- 'The extra_tests argument is deprecated.',
+ "The extra_tests argument is deprecated.",
RemovedInDjango50Warning,
stacklevel=2,
)
- test_labels = test_labels or ['.']
+ test_labels = test_labels or ["."]
extra_tests = extra_tests or []
discover_kwargs = {}
if self.pattern is not None:
- discover_kwargs['pattern'] = self.pattern
+ discover_kwargs["pattern"] = self.pattern
if self.top_level is not None:
- discover_kwargs['top_level_dir'] = self.top_level
+ discover_kwargs["top_level_dir"] = self.top_level
self.setup_shuffler()
all_tests = []
@@ -786,12 +852,12 @@ class DiscoverRunner:
if self.tags or self.exclude_tags:
if self.tags:
self.log(
- 'Including test tag(s): %s.' % ', '.join(sorted(self.tags)),
+ "Including test tag(s): %s." % ", ".join(sorted(self.tags)),
level=logging.DEBUG,
)
if self.exclude_tags:
self.log(
- 'Excluding test tag(s): %s.' % ', '.join(sorted(self.exclude_tags)),
+ "Excluding test tag(s): %s." % ", ".join(sorted(self.exclude_tags)),
level=logging.DEBUG,
)
all_tests = filter_tests_by_tags(all_tests, self.tags, self.exclude_tags)
@@ -800,13 +866,15 @@ class DiscoverRunner:
# _FailedTest objects include things like test modules that couldn't be
# found or that couldn't be loaded due to syntax errors.
test_types = (unittest.loader._FailedTest, *self.reorder_by)
- all_tests = list(reorder_tests(
- all_tests,
- test_types,
- shuffler=self._shuffler,
- reverse=self.reverse,
- ))
- self.log('Found %d test(s).' % len(all_tests))
+ all_tests = list(
+ reorder_tests(
+ all_tests,
+ test_types,
+ shuffler=self._shuffler,
+ reverse=self.reverse,
+ )
+ )
+ self.log("Found %d test(s)." % len(all_tests))
suite = self.test_suite(all_tests)
if self.parallel > 1:
@@ -828,8 +896,13 @@ class DiscoverRunner:
def setup_databases(self, **kwargs):
return _setup_databases(
- self.verbosity, self.interactive, time_keeper=self.time_keeper, keepdb=self.keepdb,
- debug_sql=self.debug_sql, parallel=self.parallel, **kwargs
+ self.verbosity,
+ self.interactive,
+ time_keeper=self.time_keeper,
+ keepdb=self.keepdb,
+ debug_sql=self.debug_sql,
+ parallel=self.parallel,
+ **kwargs,
)
def get_resultclass(self):
@@ -840,16 +913,16 @@ class DiscoverRunner:
def get_test_runner_kwargs(self):
return {
- 'failfast': self.failfast,
- 'resultclass': self.get_resultclass(),
- 'verbosity': self.verbosity,
- 'buffer': self.buffer,
+ "failfast": self.failfast,
+ "resultclass": self.get_resultclass(),
+ "verbosity": self.verbosity,
+ "buffer": self.buffer,
}
def run_checks(self, databases):
# Checks are run after database creation since some checks require
# database access.
- call_command('check', verbosity=self.verbosity, databases=databases)
+ call_command("check", verbosity=self.verbosity, databases=databases)
def run_suite(self, suite, **kwargs):
kwargs = self.get_test_runner_kwargs()
@@ -859,7 +932,7 @@ class DiscoverRunner:
finally:
if self._shuffler is not None:
seed_display = self._shuffler.seed_display
- self.log(f'Used shuffle seed: {seed_display}')
+ self.log(f"Used shuffle seed: {seed_display}")
def teardown_databases(self, old_config, **kwargs):
"""Destroy all the non-mirror databases."""
@@ -875,16 +948,18 @@ class DiscoverRunner:
teardown_test_environment()
def suite_result(self, suite, result, **kwargs):
- return len(result.failures) + len(result.errors) + len(result.unexpectedSuccesses)
+ return (
+ len(result.failures) + len(result.errors) + len(result.unexpectedSuccesses)
+ )
def _get_databases(self, suite):
databases = {}
for test in iter_test_cases(suite):
- test_databases = getattr(test, 'databases', None)
- if test_databases == '__all__':
+ test_databases = getattr(test, "databases", None)
+ if test_databases == "__all__":
test_databases = connections
if test_databases:
- serialized_rollback = getattr(test, 'serialized_rollback', False)
+ serialized_rollback = getattr(test, "serialized_rollback", False)
databases.update(
(alias, serialized_rollback or databases.get(alias, False))
for alias in test_databases
@@ -896,7 +971,8 @@ class DiscoverRunner:
unused_databases = [alias for alias in connections if alias not in databases]
if unused_databases:
self.log(
- 'Skipping setup of unused database(s): %s.' % ', '.join(sorted(unused_databases)),
+ "Skipping setup of unused database(s): %s."
+ % ", ".join(sorted(unused_databases)),
level=logging.DEBUG,
)
return databases
@@ -912,7 +988,7 @@ class DiscoverRunner:
"""
if extra_tests is not None:
warnings.warn(
- 'The extra_tests argument is deprecated.',
+ "The extra_tests argument is deprecated.",
RemovedInDjango50Warning,
stacklevel=2,
)
@@ -920,10 +996,9 @@ class DiscoverRunner:
suite = self.build_suite(test_labels, extra_tests)
databases = self.get_databases(suite)
serialized_aliases = set(
- alias
- for alias, serialize in databases.items() if serialize
+ alias for alias, serialize in databases.items() if serialize
)
- with self.time_keeper.timed('Total database setup'):
+ with self.time_keeper.timed("Total database setup"):
old_config = self.setup_databases(
aliases=databases,
serialized_aliases=serialized_aliases,
@@ -937,7 +1012,7 @@ class DiscoverRunner:
raise
finally:
try:
- with self.time_keeper.timed('Total database teardown'):
+ with self.time_keeper.timed("Total database teardown"):
self.teardown_databases(old_config)
self.teardown_test_environment()
except Exception:
@@ -960,7 +1035,7 @@ def try_importing(label):
except (ImportError, TypeError):
return (False, False)
- return (True, hasattr(mod, '__path__'))
+ return (True, hasattr(mod, "__path__"))
def find_top_level(top_level):
@@ -976,7 +1051,7 @@ def find_top_level(top_level):
# top-level module or as a directory path, unittest unfortunately prefers
# the latter.
while True:
- init_py = os.path.join(top_level, '__init__.py')
+ init_py = os.path.join(top_level, "__init__.py")
if not os.path.exists(init_py):
break
try_next = os.path.dirname(top_level)
@@ -988,7 +1063,7 @@ def find_top_level(top_level):
def _class_shuffle_key(cls):
- return f'{cls.__module__}.{cls.__qualname__}'
+ return f"{cls.__module__}.{cls.__qualname__}"
def shuffle_tests(tests, shuffler):
@@ -1073,9 +1148,7 @@ def partition_suite_by_case(suite):
"""Partition a test suite by test case, preserving the order of tests."""
suite_class = type(suite)
all_tests = iter_test_cases(suite)
- return [
- suite_class(tests) for _, tests in itertools.groupby(all_tests, type)
- ]
+ return [suite_class(tests) for _, tests in itertools.groupby(all_tests, type)]
def test_match_tags(test, tags, exclude_tags):
@@ -1083,11 +1156,11 @@ def test_match_tags(test, tags, exclude_tags):
# Tests that couldn't load always match to prevent tests from falsely
# passing due e.g. to syntax errors.
return True
- test_tags = set(getattr(test, 'tags', []))
- test_fn_name = getattr(test, '_testMethodName', str(test))
+ test_tags = set(getattr(test, "tags", []))
+ test_fn_name = getattr(test, "_testMethodName", str(test))
if hasattr(test, test_fn_name):
test_fn = getattr(test, test_fn_name)
- test_fn_tags = list(getattr(test_fn, 'tags', []))
+ test_fn_tags = list(getattr(test_fn, "tags", []))
test_tags = test_tags.union(test_fn_tags)
if tags and test_tags.isdisjoint(tags):
return False
diff --git a/django/test/selenium.py b/django/test/selenium.py
index 97a7840fea..aa714ad365 100644
--- a/django/test/selenium.py
+++ b/django/test/selenium.py
@@ -27,7 +27,9 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
"""
test_class = super().__new__(cls, name, bases, attrs)
# If the test class is either browser-specific or a test base, return it.
- if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()):
+ if test_class.browser or not any(
+ name.startswith("test") and callable(value) for name, value in attrs.items()
+ ):
return test_class
elif test_class.browsers:
# Reuse the created test class to make it browser-specific.
@@ -37,7 +39,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
first_browser = test_class.browsers[0]
test_class.browser = first_browser
# Listen on an external interface if using a selenium hub.
- host = test_class.host if not test_class.selenium_hub else '0.0.0.0'
+ host = test_class.host if not test_class.selenium_hub else "0.0.0.0"
test_class.host = host
test_class.external_host = cls.external_host
# Create subclasses for each of the remaining browsers and expose
@@ -49,16 +51,16 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
"%s%s" % (capfirst(browser), name),
(test_class,),
{
- 'browser': browser,
- 'host': host,
- 'external_host': cls.external_host,
- '__module__': test_class.__module__,
- }
+ "browser": browser,
+ "host": host,
+ "external_host": cls.external_host,
+ "__module__": test_class.__module__,
+ },
)
setattr(module, browser_test_class.__name__, browser_test_class)
return test_class
# If no browsers were specified, skip this class (it'll still be discovered).
- return unittest.skip('No browsers specified.')(test_class)
+ return unittest.skip("No browsers specified.")(test_class)
@classmethod
def import_webdriver(cls, browser):
@@ -66,13 +68,12 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
@classmethod
def import_options(cls, browser):
- return import_string('selenium.webdriver.%s.options.Options' % browser)
+ return import_string("selenium.webdriver.%s.options.Options" % browser)
@classmethod
def get_capability(cls, browser):
- from selenium.webdriver.common.desired_capabilities import (
- DesiredCapabilities,
- )
+ from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
+
return getattr(DesiredCapabilities, browser.upper())
def create_options(self):
@@ -87,6 +88,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
def create_webdriver(self):
if self.selenium_hub:
from selenium import webdriver
+
return webdriver.Remote(
command_executor=self.selenium_hub,
desired_capabilities=self.get_capability(self.browser),
@@ -94,14 +96,14 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
return self.import_webdriver(self.browser)(options=self.create_options())
-@tag('selenium')
+@tag("selenium")
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
implicit_wait = 10
external_host = None
@classproperty
def live_server_url(cls):
- return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port)
+ return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port)
@classproperty
def allowed_host(cls):
@@ -118,7 +120,7 @@ class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
# quit() the WebDriver before attempting to terminate and join the
# single-threaded LiveServerThread to avoid a dead lock if the browser
# kept a connection alive.
- if hasattr(cls, 'selenium'):
+ if hasattr(cls, "selenium"):
cls.selenium.quit()
super()._tearDownClassInternal()
diff --git a/django/test/signals.py b/django/test/signals.py
index c82b95013d..c874f220df 100644
--- a/django/test/signals.py
+++ b/django/test/signals.py
@@ -20,13 +20,14 @@ template_rendered = Signal()
# except for cases where the receiver is related to a contrib app.
# Settings that may not work well when using 'override_settings' (#19031)
-COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'}
+COMPLEX_OVERRIDE_SETTINGS = {"DATABASES"}
@receiver(setting_changed)
def clear_cache_handlers(*, setting, **kwargs):
- if setting == 'CACHES':
+ if setting == "CACHES":
from django.core.cache import caches, close_caches
+
close_caches()
caches._settings = caches.settings = caches.configure_settings(None)
caches._connections = Local()
@@ -34,37 +35,41 @@ def clear_cache_handlers(*, setting, **kwargs):
@receiver(setting_changed)
def update_installed_apps(*, setting, **kwargs):
- if setting == 'INSTALLED_APPS':
+ if setting == "INSTALLED_APPS":
# Rebuild any AppDirectoriesFinder instance.
from django.contrib.staticfiles.finders import get_finder
+
get_finder.cache_clear()
# Rebuild management commands cache
from django.core.management import get_commands
+
get_commands.cache_clear()
# Rebuild get_app_template_dirs cache.
from django.template.utils import get_app_template_dirs
+
get_app_template_dirs.cache_clear()
# Rebuild translations cache.
from django.utils.translation import trans_real
+
trans_real._translations = {}
@receiver(setting_changed)
def update_connections_time_zone(*, setting, **kwargs):
- if setting == 'TIME_ZONE':
+ if setting == "TIME_ZONE":
# Reset process time zone
- if hasattr(time, 'tzset'):
- if kwargs['value']:
- os.environ['TZ'] = kwargs['value']
+ if hasattr(time, "tzset"):
+ if kwargs["value"]:
+ os.environ["TZ"] = kwargs["value"]
else:
- os.environ.pop('TZ', None)
+ os.environ.pop("TZ", None)
time.tzset()
# Reset local time zone cache
timezone.get_default_timezone.cache_clear()
# Reset the database connections' time zone
- if setting in {'TIME_ZONE', 'USE_TZ'}:
+ if setting in {"TIME_ZONE", "USE_TZ"}:
for conn in connections.all():
try:
del conn.timezone
@@ -79,18 +84,19 @@ def update_connections_time_zone(*, setting, **kwargs):
@receiver(setting_changed)
def clear_routers_cache(*, setting, **kwargs):
- if setting == 'DATABASE_ROUTERS':
+ if setting == "DATABASE_ROUTERS":
router.routers = ConnectionRouter().routers
@receiver(setting_changed)
def reset_template_engines(*, setting, **kwargs):
if setting in {
- 'TEMPLATES',
- 'DEBUG',
- 'INSTALLED_APPS',
+ "TEMPLATES",
+ "DEBUG",
+ "INSTALLED_APPS",
}:
from django.template import engines
+
try:
del engines.templates
except AttributeError:
@@ -98,40 +104,46 @@ def reset_template_engines(*, setting, **kwargs):
engines._templates = None
engines._engines = {}
from django.template.engine import Engine
+
Engine.get_default.cache_clear()
from django.forms.renderers import get_default_renderer
+
get_default_renderer.cache_clear()
@receiver(setting_changed)
def clear_serializers_cache(*, setting, **kwargs):
- if setting == 'SERIALIZATION_MODULES':
+ if setting == "SERIALIZATION_MODULES":
from django.core import serializers
+
serializers._serializers = {}
@receiver(setting_changed)
def language_changed(*, setting, **kwargs):
- if setting in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}:
+ if setting in {"LANGUAGES", "LANGUAGE_CODE", "LOCALE_PATHS"}:
from django.utils.translation import trans_real
+
trans_real._default = None
trans_real._active = Local()
- if setting in {'LANGUAGES', 'LOCALE_PATHS'}:
+ if setting in {"LANGUAGES", "LOCALE_PATHS"}:
from django.utils.translation import trans_real
+
trans_real._translations = {}
trans_real.check_for_language.cache_clear()
@receiver(setting_changed)
def localize_settings_changed(*, setting, **kwargs):
- if setting in FORMAT_SETTINGS or setting == 'USE_THOUSAND_SEPARATOR':
+ if setting in FORMAT_SETTINGS or setting == "USE_THOUSAND_SEPARATOR":
reset_format_cache()
@receiver(setting_changed)
def file_storage_changed(*, setting, **kwargs):
- if setting == 'DEFAULT_FILE_STORAGE':
+ if setting == "DEFAULT_FILE_STORAGE":
from django.core.files.storage import default_storage
+
default_storage._wrapped = empty
@@ -141,15 +153,16 @@ def complex_setting_changed(*, enter, setting, **kwargs):
# Considering the current implementation of the signals framework,
# this stacklevel shows the line containing the override_settings call.
warnings.warn(
- f'Overriding setting {setting} can lead to unexpected behavior.',
+ f"Overriding setting {setting} can lead to unexpected behavior.",
stacklevel=6,
)
@receiver(setting_changed)
def root_urlconf_changed(*, setting, **kwargs):
- if setting == 'ROOT_URLCONF':
+ if setting == "ROOT_URLCONF":
from django.urls import clear_url_caches, set_urlconf
+
clear_url_caches()
set_urlconf(None)
@@ -157,55 +170,64 @@ def root_urlconf_changed(*, setting, **kwargs):
@receiver(setting_changed)
def static_storage_changed(*, setting, **kwargs):
if setting in {
- 'STATICFILES_STORAGE',
- 'STATIC_ROOT',
- 'STATIC_URL',
+ "STATICFILES_STORAGE",
+ "STATIC_ROOT",
+ "STATIC_URL",
}:
from django.contrib.staticfiles.storage import staticfiles_storage
+
staticfiles_storage._wrapped = empty
@receiver(setting_changed)
def static_finders_changed(*, setting, **kwargs):
if setting in {
- 'STATICFILES_DIRS',
- 'STATIC_ROOT',
+ "STATICFILES_DIRS",
+ "STATIC_ROOT",
}:
from django.contrib.staticfiles.finders import get_finder
+
get_finder.cache_clear()
@receiver(setting_changed)
def auth_password_validators_changed(*, setting, **kwargs):
- if setting == 'AUTH_PASSWORD_VALIDATORS':
+ if setting == "AUTH_PASSWORD_VALIDATORS":
from django.contrib.auth.password_validation import (
get_default_password_validators,
)
+
get_default_password_validators.cache_clear()
@receiver(setting_changed)
def user_model_swapped(*, setting, **kwargs):
- if setting == 'AUTH_USER_MODEL':
+ if setting == "AUTH_USER_MODEL":
apps.clear_cache()
try:
from django.contrib.auth import get_user_model
+
UserModel = get_user_model()
except ImproperlyConfigured:
# Some tests set an invalid AUTH_USER_MODEL.
pass
else:
from django.contrib.auth import backends
+
backends.UserModel = UserModel
from django.contrib.auth import forms
+
forms.UserModel = UserModel
from django.contrib.auth.handlers import modwsgi
+
modwsgi.UserModel = UserModel
from django.contrib.auth.management.commands import changepassword
+
changepassword.UserModel = UserModel
from django.contrib.auth import views
+
views.UserModel = UserModel
diff --git a/django/test/testcases.py b/django/test/testcases.py
index d24a065790..d514d06c7a 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -15,7 +15,13 @@ from functools import wraps
from unittest.suite import _DebugResult
from unittest.util import safe_repr
from urllib.parse import (
- parse_qsl, unquote, urlencode, urljoin, urlparse, urlsplit, urlunparse,
+ parse_qsl,
+ unquote,
+ urlencode,
+ urljoin,
+ urlparse,
+ urlsplit,
+ urlunparse,
)
from urllib.request import url2pathname
@@ -40,7 +46,10 @@ from django.test.client import AsyncClient, Client
from django.test.html import HTMLParseError, parse_html
from django.test.signals import template_rendered
from django.test.utils import (
- CaptureQueriesContext, ContextList, compare_xml, modify_settings,
+ CaptureQueriesContext,
+ ContextList,
+ compare_xml,
+ modify_settings,
override_settings,
)
from django.utils.deprecation import RemovedInDjango50Warning
@@ -48,8 +57,13 @@ from django.utils.functional import classproperty
from django.utils.version import PY310
from django.views.static import serve
-__all__ = ('TestCase', 'TransactionTestCase',
- 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
+__all__ = (
+ "TestCase",
+ "TransactionTestCase",
+ "SimpleTestCase",
+ "skipIfDBFeature",
+ "skipUnlessDBFeature",
+)
def to_list(value):
@@ -63,7 +77,7 @@ def assert_and_parse_html(self, html, user_msg, msg):
try:
dom = parse_html(html)
except HTMLParseError as e:
- standardMsg = '%s\n%s' % (msg, e)
+ standardMsg = "%s\n%s" % (msg, e)
self.fail(self._formatMessage(user_msg, standardMsg))
return dom
@@ -80,18 +94,22 @@ class _AssertNumQueriesContext(CaptureQueriesContext):
return
executed = len(self)
self.test_case.assertEqual(
- executed, self.num,
- "%d queries executed, %d expected\nCaptured queries were:\n%s" % (
- executed, self.num,
- '\n'.join(
- '%d. %s' % (i, query['sql']) for i, query in enumerate(self.captured_queries, start=1)
- )
- )
+ executed,
+ self.num,
+ "%d queries executed, %d expected\nCaptured queries were:\n%s"
+ % (
+ executed,
+ self.num,
+ "\n".join(
+ "%d. %s" % (i, query["sql"])
+ for i, query in enumerate(self.captured_queries, start=1)
+ ),
+ ),
)
class _AssertTemplateUsedContext:
- def __init__(self, test_case, template_name, msg_prefix='', count=None):
+ def __init__(self, test_case, template_name, msg_prefix="", count=None):
self.test_case = test_case
self.template_name = template_name
self.msg_prefix = msg_prefix
@@ -108,7 +126,9 @@ class _AssertTemplateUsedContext:
def test(self):
self.test_case._assert_template_used(
- self.template_name, self.rendered_template_names, self.msg_prefix,
+ self.template_name,
+ self.rendered_template_names,
+ self.msg_prefix,
self.count,
)
@@ -128,7 +148,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
self.test_case.assertFalse(
self.template_name in self.rendered_template_names,
f"{self.msg_prefix}Template '{self.template_name}' was used "
- f"unexpectedly in rendering the response"
+ f"unexpectedly in rendering the response",
)
@@ -156,16 +176,16 @@ class SimpleTestCase(unittest.TestCase):
databases = set()
_disallowed_database_msg = (
- 'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase '
- 'subclasses. Either subclass TestCase or TransactionTestCase to ensure '
- 'proper test isolation or add %(alias)r to %(test)s.databases to silence '
- 'this failure.'
+ "Database %(operation)s to %(alias)r are not allowed in SimpleTestCase "
+ "subclasses. Either subclass TestCase or TransactionTestCase to ensure "
+ "proper test isolation or add %(alias)r to %(test)s.databases to silence "
+ "this failure."
)
_disallowed_connection_methods = [
- ('connect', 'connections'),
- ('temporary_connection', 'connections'),
- ('cursor', 'queries'),
- ('chunked_cursor', 'queries'),
+ ("connect", "connections"),
+ ("temporary_connection", "connections"),
+ ("cursor", "queries"),
+ ("chunked_cursor", "queries"),
]
@classmethod
@@ -184,18 +204,21 @@ class SimpleTestCase(unittest.TestCase):
@classmethod
def _validate_databases(cls):
- if cls.databases == '__all__':
+ if cls.databases == "__all__":
return frozenset(connections)
for alias in cls.databases:
if alias not in connections:
- message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % (
- cls.__module__,
- cls.__qualname__,
- alias,
+ message = (
+ "%s.%s.databases refers to %r which is not defined in settings.DATABASES."
+ % (
+ cls.__module__,
+ cls.__qualname__,
+ alias,
+ )
)
close_matches = get_close_matches(alias, list(connections))
if close_matches:
- message += ' Did you mean %r?' % close_matches[0]
+ message += " Did you mean %r?" % close_matches[0]
raise ImproperlyConfigured(message)
return frozenset(cls.databases)
@@ -208,9 +231,9 @@ class SimpleTestCase(unittest.TestCase):
connection = connections[alias]
for name, operation in cls._disallowed_connection_methods:
message = cls._disallowed_database_msg % {
- 'test': '%s.%s' % (cls.__module__, cls.__qualname__),
- 'alias': alias,
- 'operation': operation,
+ "test": "%s.%s" % (cls.__module__, cls.__qualname__),
+ "alias": alias,
+ "operation": operation,
}
method = getattr(connection, name)
setattr(connection, name, _DatabaseFailure(method, message))
@@ -247,9 +270,8 @@ class SimpleTestCase(unittest.TestCase):
instead of __call__() to run the test.
"""
testMethod = getattr(self, self._testMethodName)
- skipped = (
- getattr(self.__class__, "__unittest_skip__", False) or
- getattr(testMethod, "__unittest_skip__", False)
+ skipped = getattr(self.__class__, "__unittest_skip__", False) or getattr(
+ testMethod, "__unittest_skip__", False
)
# Convert async test methods.
@@ -305,9 +327,15 @@ class SimpleTestCase(unittest.TestCase):
"""
return modify_settings(**kwargs)
- def assertRedirects(self, response, expected_url, status_code=302,
- target_status_code=200, msg_prefix='',
- fetch_redirect_response=True):
+ def assertRedirects(
+ self,
+ response,
+ expected_url,
+ status_code=302,
+ target_status_code=200,
+ msg_prefix="",
+ fetch_redirect_response=True,
+ ):
"""
Assert that a response redirected to a specific URL and that the
redirect URL can be loaded.
@@ -319,43 +347,50 @@ class SimpleTestCase(unittest.TestCase):
if msg_prefix:
msg_prefix += ": "
- if hasattr(response, 'redirect_chain'):
+ if hasattr(response, "redirect_chain"):
# The request was a followed redirect
self.assertTrue(
response.redirect_chain,
- msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)"
- % (response.status_code, status_code)
+ msg_prefix
+ + "Response didn't redirect as expected: Response code was %d (expected %d)"
+ % (response.status_code, status_code),
)
self.assertEqual(
- response.redirect_chain[0][1], status_code,
- msg_prefix + "Initial response didn't redirect as expected: Response code was %d (expected %d)"
- % (response.redirect_chain[0][1], status_code)
+ response.redirect_chain[0][1],
+ status_code,
+ msg_prefix
+ + "Initial response didn't redirect as expected: Response code was %d (expected %d)"
+ % (response.redirect_chain[0][1], status_code),
)
url, status_code = response.redirect_chain[-1]
self.assertEqual(
- response.status_code, target_status_code,
- msg_prefix + "Response didn't redirect as expected: Final Response code was %d (expected %d)"
- % (response.status_code, target_status_code)
+ response.status_code,
+ target_status_code,
+ msg_prefix
+ + "Response didn't redirect as expected: Final Response code was %d (expected %d)"
+ % (response.status_code, target_status_code),
)
else:
# Not a followed redirect
self.assertEqual(
- response.status_code, status_code,
- msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)"
- % (response.status_code, status_code)
+ response.status_code,
+ status_code,
+ msg_prefix
+ + "Response didn't redirect as expected: Response code was %d (expected %d)"
+ % (response.status_code, status_code),
)
url = response.url
scheme, netloc, path, query, fragment = urlsplit(url)
# Prepend the request path to handle relative path redirects.
- if not path.startswith('/'):
- url = urljoin(response.request['PATH_INFO'], url)
- path = urljoin(response.request['PATH_INFO'], path)
+ if not path.startswith("/"):
+ url = urljoin(response.request["PATH_INFO"], url)
+ path = urljoin(response.request["PATH_INFO"], path)
if fetch_redirect_response:
# netloc might be empty, or in cases where Django tests the
@@ -375,21 +410,25 @@ class SimpleTestCase(unittest.TestCase):
redirect_response = response.client.get(
path,
QueryDict(query),
- secure=(scheme == 'https'),
+ secure=(scheme == "https"),
**extra,
)
self.assertEqual(
- redirect_response.status_code, target_status_code,
- msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)"
- % (path, redirect_response.status_code, target_status_code)
+ redirect_response.status_code,
+ target_status_code,
+ msg_prefix
+ + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)"
+ % (path, redirect_response.status_code, target_status_code),
)
self.assertURLEqual(
- url, expected_url,
- msg_prefix + "Response redirected to '%s', expected '%s'" % (url, expected_url)
+ url,
+ expected_url,
+ msg_prefix
+ + "Response redirected to '%s', expected '%s'" % (url, expected_url),
)
- def assertURLEqual(self, url1, url2, msg_prefix=''):
+ def assertURLEqual(self, url1, url2, msg_prefix=""):
"""
Assert that two URLs are the same, ignoring the order of query string
parameters except for parameters with the same name.
@@ -397,35 +436,44 @@ class SimpleTestCase(unittest.TestCase):
For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but
/path/?a=1&a=2 isn't equal to /path/?a=2&a=1.
"""
+
def normalize(url):
"""Sort the URL's query string parameters."""
url = str(url) # Coerce reverse_lazy() URLs.
scheme, netloc, path, params, query, fragment = urlparse(url)
query_parts = sorted(parse_qsl(query))
- return urlunparse((scheme, netloc, path, params, urlencode(query_parts), fragment))
+ return urlunparse(
+ (scheme, netloc, path, params, urlencode(query_parts), fragment)
+ )
self.assertEqual(
- normalize(url1), normalize(url2),
- msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2)
+ normalize(url1),
+ normalize(url2),
+ msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2),
)
def _assert_contains(self, response, text, status_code, msg_prefix, html):
# If the response supports deferred rendering and hasn't been rendered
# yet, then ensure that it does get rendered before proceeding further.
- if hasattr(response, 'render') and callable(response.render) and not response.is_rendered:
+ if (
+ hasattr(response, "render")
+ and callable(response.render)
+ and not response.is_rendered
+ ):
response.render()
if msg_prefix:
msg_prefix += ": "
self.assertEqual(
- response.status_code, status_code,
+ response.status_code,
+ status_code,
msg_prefix + "Couldn't retrieve content: Response code was %d"
- " (expected %d)" % (response.status_code, status_code)
+ " (expected %d)" % (response.status_code, status_code),
)
if response.streaming:
- content = b''.join(response.streaming_content)
+ content = b"".join(response.streaming_content)
else:
content = response.content
if not isinstance(text, bytes) or html:
@@ -435,12 +483,18 @@ class SimpleTestCase(unittest.TestCase):
else:
text_repr = repr(text)
if html:
- content = assert_and_parse_html(self, content, None, "Response's content is not valid HTML:")
- text = assert_and_parse_html(self, text, None, "Second argument is not valid HTML:")
+ content = assert_and_parse_html(
+ self, content, None, "Response's content is not valid HTML:"
+ )
+ text = assert_and_parse_html(
+ self, text, None, "Second argument is not valid HTML:"
+ )
real_count = content.count(text)
return (text_repr, real_count, msg_prefix)
- def assertContains(self, response, text, count=None, status_code=200, msg_prefix='', html=False):
+ def assertContains(
+ self, response, text, count=None, status_code=200, msg_prefix="", html=False
+ ):
"""
Assert that a response indicates that some content was retrieved
successfully, (i.e., the HTTP status code was as expected) and that
@@ -449,26 +503,37 @@ class SimpleTestCase(unittest.TestCase):
if the text occurs at least once in the response.
"""
text_repr, real_count, msg_prefix = self._assert_contains(
- response, text, status_code, msg_prefix, html)
+ response, text, status_code, msg_prefix, html
+ )
if count is not None:
self.assertEqual(
- real_count, count,
- msg_prefix + "Found %d instances of %s in response (expected %d)" % (real_count, text_repr, count)
+ real_count,
+ count,
+ msg_prefix
+ + "Found %d instances of %s in response (expected %d)"
+ % (real_count, text_repr, count),
)
else:
- self.assertTrue(real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr)
+ self.assertTrue(
+ real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr
+ )
- def assertNotContains(self, response, text, status_code=200, msg_prefix='', html=False):
+ def assertNotContains(
+ self, response, text, status_code=200, msg_prefix="", html=False
+ ):
"""
Assert that a response indicates that some content was retrieved
successfully, (i.e., the HTTP status code was as expected) and that
``text`` doesn't occur in the content of the response.
"""
text_repr, real_count, msg_prefix = self._assert_contains(
- response, text, status_code, msg_prefix, html)
+ response, text, status_code, msg_prefix, html
+ )
- self.assertEqual(real_count, 0, msg_prefix + "Response should not contain %s" % text_repr)
+ self.assertEqual(
+ real_count, 0, msg_prefix + "Response should not contain %s" % text_repr
+ )
def _check_test_client_response(self, response, attribute, method_name):
"""
@@ -481,24 +546,26 @@ class SimpleTestCase(unittest.TestCase):
"the Django test Client."
)
- def assertFormError(self, response, form, field, errors, msg_prefix=''):
+ def assertFormError(self, response, form, field, errors, msg_prefix=""):
"""
Assert that a form used to render the response has a specific field
error.
"""
- self._check_test_client_response(response, 'context', 'assertFormError')
+ self._check_test_client_response(response, "context", "assertFormError")
if msg_prefix:
msg_prefix += ": "
# Put context(s) into a list to simplify processing.
contexts = [] if response.context is None else to_list(response.context)
if not contexts:
- self.fail(msg_prefix + "Response did not use any contexts to render the response")
+ self.fail(
+ msg_prefix + "Response did not use any contexts to render the response"
+ )
if errors is None:
warnings.warn(
- 'Passing errors=None to assertFormError() is deprecated, use '
- 'errors=[] instead.',
+ "Passing errors=None to assertFormError() is deprecated, use "
+ "errors=[] instead.",
RemovedInDjango50Warning,
stacklevel=2,
)
@@ -520,18 +587,20 @@ class SimpleTestCase(unittest.TestCase):
err in field_errors,
msg_prefix + "The field '%s' on form '%s' in"
" context %d does not contain the error '%s'"
- " (actual errors: %s)" %
- (field, form, i, err, repr(field_errors))
+ " (actual errors: %s)"
+ % (field, form, i, err, repr(field_errors)),
)
elif field in context[form].fields:
self.fail(
- msg_prefix + "The field '%s' on form '%s' in context %d contains no errors" %
- (field, form, i)
+ msg_prefix
+ + "The field '%s' on form '%s' in context %d contains no errors"
+ % (field, form, i)
)
else:
self.fail(
- msg_prefix + "The form '%s' in context %d does not contain the field '%s'" %
- (form, i, field)
+ msg_prefix
+ + "The form '%s' in context %d does not contain the field '%s'"
+ % (form, i, field)
)
else:
non_field_errors = context[form].non_field_errors()
@@ -539,14 +608,17 @@ class SimpleTestCase(unittest.TestCase):
err in non_field_errors,
msg_prefix + "The form '%s' in context %d does not"
" contain the non-field error '%s'"
- " (actual errors: %s)" %
- (form, i, err, non_field_errors or 'none')
+ " (actual errors: %s)"
+ % (form, i, err, non_field_errors or "none"),
)
if not found_form:
- self.fail(msg_prefix + "The form '%s' was not used to render the response" % form)
+ self.fail(
+ msg_prefix + "The form '%s' was not used to render the response" % form
+ )
- def assertFormsetError(self, response, formset, form_index, field, errors,
- msg_prefix=''):
+ def assertFormsetError(
+ self, response, formset, form_index, field, errors, msg_prefix=""
+ ):
"""
Assert that a formset used to render the response has a specific error.
@@ -556,7 +628,7 @@ class SimpleTestCase(unittest.TestCase):
For non-form errors, specify ``form_index`` as None and the ``field``
as None.
"""
- self._check_test_client_response(response, 'context', 'assertFormsetError')
+ self._check_test_client_response(response, "context", "assertFormsetError")
# Add punctuation to msg_prefix
if msg_prefix:
msg_prefix += ": "
@@ -564,13 +636,15 @@ class SimpleTestCase(unittest.TestCase):
# Put context(s) into a list to simplify processing.
contexts = [] if response.context is None else to_list(response.context)
if not contexts:
- self.fail(msg_prefix + 'Response did not use any contexts to '
- 'render the response')
+ self.fail(
+ msg_prefix + "Response did not use any contexts to "
+ "render the response"
+ )
if errors is None:
warnings.warn(
- 'Passing errors=None to assertFormsetError() is deprecated, '
- 'use errors=[] instead.',
+ "Passing errors=None to assertFormsetError() is deprecated, "
+ "use errors=[] instead.",
RemovedInDjango50Warning,
stacklevel=2,
)
@@ -581,7 +655,7 @@ class SimpleTestCase(unittest.TestCase):
# Search all contexts for the error.
found_formset = False
for i, context in enumerate(contexts):
- if formset not in context or not hasattr(context[formset], 'forms'):
+ if formset not in context or not hasattr(context[formset], "forms"):
continue
found_formset = True
for err in errors:
@@ -592,60 +666,68 @@ class SimpleTestCase(unittest.TestCase):
err in field_errors,
msg_prefix + "The field '%s' on formset '%s', "
"form %d in context %d does not contain the "
- "error '%s' (actual errors: %s)" %
- (field, formset, form_index, i, err, repr(field_errors))
+ "error '%s' (actual errors: %s)"
+ % (field, formset, form_index, i, err, repr(field_errors)),
)
elif field in context[formset].forms[form_index].fields:
self.fail(
- msg_prefix + "The field '%s' on formset '%s', form %d in context %d contains no errors"
+ msg_prefix
+ + "The field '%s' on formset '%s', form %d in context %d contains no errors"
% (field, formset, form_index, i)
)
else:
self.fail(
- msg_prefix + "The formset '%s', form %d in context %d does not contain the field '%s'"
+ msg_prefix
+ + "The formset '%s', form %d in context %d does not contain the field '%s'"
% (formset, form_index, i, field)
)
elif form_index is not None:
- non_field_errors = context[formset].forms[form_index].non_field_errors()
+ non_field_errors = (
+ context[formset].forms[form_index].non_field_errors()
+ )
self.assertFalse(
not non_field_errors,
msg_prefix + "The formset '%s', form %d in context %d "
- "does not contain any non-field errors." % (formset, form_index, i)
+ "does not contain any non-field errors."
+ % (formset, form_index, i),
)
self.assertTrue(
err in non_field_errors,
msg_prefix + "The formset '%s', form %d in context %d "
"does not contain the non-field error '%s' (actual errors: %s)"
- % (formset, form_index, i, err, repr(non_field_errors))
+ % (formset, form_index, i, err, repr(non_field_errors)),
)
else:
non_form_errors = context[formset].non_form_errors()
self.assertFalse(
not non_form_errors,
msg_prefix + "The formset '%s' in context %d does not "
- "contain any non-form errors." % (formset, i)
+ "contain any non-form errors." % (formset, i),
)
self.assertTrue(
err in non_form_errors,
msg_prefix + "The formset '%s' in context %d does not "
"contain the non-form error '%s' (actual errors: %s)"
- % (formset, i, err, repr(non_form_errors))
+ % (formset, i, err, repr(non_form_errors)),
)
if not found_formset:
- self.fail(msg_prefix + "The formset '%s' was not used to render the response" % formset)
+ self.fail(
+ msg_prefix
+ + "The formset '%s' was not used to render the response" % formset
+ )
def _get_template_used(self, response, template_name, msg_prefix, method_name):
if response is None and template_name is None:
- raise TypeError('response and/or template_name argument must be provided')
+ raise TypeError("response and/or template_name argument must be provided")
if msg_prefix:
msg_prefix += ": "
if template_name is not None and response is not None:
- self._check_test_client_response(response, 'templates', method_name)
+ self._check_test_client_response(response, "templates", method_name)
- if not hasattr(response, 'templates') or (response is None and template_name):
+ if not hasattr(response, "templates") or (response is None and template_name):
if response:
template_name = response
response = None
@@ -662,38 +744,49 @@ class SimpleTestCase(unittest.TestCase):
template_name in template_names,
msg_prefix + "Template '%s' was not a template used to render"
" the response. Actual template(s) used: %s"
- % (template_name, ', '.join(template_names))
+ % (template_name, ", ".join(template_names)),
)
if count is not None:
self.assertEqual(
- template_names.count(template_name), count,
+ template_names.count(template_name),
+ count,
msg_prefix + "Template '%s' was expected to be rendered %d "
"time(s) but was actually rendered %d time(s)."
- % (template_name, count, template_names.count(template_name))
+ % (template_name, count, template_names.count(template_name)),
)
- def assertTemplateUsed(self, response=None, template_name=None, msg_prefix='', count=None):
+ def assertTemplateUsed(
+ self, response=None, template_name=None, msg_prefix="", count=None
+ ):
"""
Assert that the template with the provided name was used in rendering
the response. Also usable as context manager.
"""
context_mgr_template, template_names, msg_prefix = self._get_template_used(
- response, template_name, msg_prefix, 'assertTemplateUsed',
+ response,
+ template_name,
+ msg_prefix,
+ "assertTemplateUsed",
)
if context_mgr_template:
# Use assertTemplateUsed as context manager.
- return _AssertTemplateUsedContext(self, context_mgr_template, msg_prefix, count)
+ return _AssertTemplateUsedContext(
+ self, context_mgr_template, msg_prefix, count
+ )
self._assert_template_used(template_name, template_names, msg_prefix, count)
- def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=''):
+ def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=""):
"""
Assert that the template with the provided name was NOT used in
rendering the response. Also usable as context manager.
"""
context_mgr_template, template_names, msg_prefix = self._get_template_used(
- response, template_name, msg_prefix, 'assertTemplateNotUsed',
+ response,
+ template_name,
+ msg_prefix,
+ "assertTemplateNotUsed",
)
if context_mgr_template:
# Use assertTemplateNotUsed as context manager.
@@ -701,20 +794,28 @@ class SimpleTestCase(unittest.TestCase):
self.assertFalse(
template_name in template_names,
- msg_prefix + "Template '%s' was used unexpectedly in rendering the response" % template_name
+ msg_prefix
+ + "Template '%s' was used unexpectedly in rendering the response"
+ % template_name,
)
@contextmanager
- def _assert_raises_or_warns_cm(self, func, cm_attr, expected_exception, expected_message):
+ def _assert_raises_or_warns_cm(
+ self, func, cm_attr, expected_exception, expected_message
+ ):
with func(expected_exception) as cm:
yield cm
self.assertIn(expected_message, str(getattr(cm, cm_attr)))
- def _assertFooMessage(self, func, cm_attr, expected_exception, expected_message, *args, **kwargs):
+ def _assertFooMessage(
+ self, func, cm_attr, expected_exception, expected_message, *args, **kwargs
+ ):
callable_obj = None
if args:
callable_obj, *args = args
- cm = self._assert_raises_or_warns_cm(func, cm_attr, expected_exception, expected_message)
+ cm = self._assert_raises_or_warns_cm(
+ func, cm_attr, expected_exception, expected_message
+ )
# Assertion used in context manager fashion.
if callable_obj is None:
return cm
@@ -722,7 +823,9 @@ class SimpleTestCase(unittest.TestCase):
with cm:
callable_obj(*args, **kwargs)
- def assertRaisesMessage(self, expected_exception, expected_message, *args, **kwargs):
+ def assertRaisesMessage(
+ self, expected_exception, expected_message, *args, **kwargs
+ ):
"""
Assert that expected_message is found in the message of a raised
exception.
@@ -734,8 +837,12 @@ class SimpleTestCase(unittest.TestCase):
kwargs: Extra kwargs.
"""
return self._assertFooMessage(
- self.assertRaises, 'exception', expected_exception, expected_message,
- *args, **kwargs
+ self.assertRaises,
+ "exception",
+ expected_exception,
+ expected_message,
+ *args,
+ **kwargs,
)
def assertWarnsMessage(self, expected_warning, expected_message, *args, **kwargs):
@@ -744,12 +851,17 @@ class SimpleTestCase(unittest.TestCase):
assertRaises().
"""
return self._assertFooMessage(
- self.assertWarns, 'warning', expected_warning, expected_message,
- *args, **kwargs
+ self.assertWarns,
+ "warning",
+ expected_warning,
+ expected_message,
+ *args,
+ **kwargs,
)
# A similar method is available in Python 3.10+.
if not PY310:
+
@contextmanager
def assertNoLogs(self, logger, level=None):
"""
@@ -759,20 +871,29 @@ class SimpleTestCase(unittest.TestCase):
if isinstance(level, int):
level = logging.getLevelName(level)
elif level is None:
- level = 'INFO'
+ level = "INFO"
try:
with self.assertLogs(logger, level) as cm:
yield
except AssertionError as e:
msg = e.args[0]
- expected_msg = f'no logs of level {level} or higher triggered on {logger}'
+ expected_msg = (
+ f"no logs of level {level} or higher triggered on {logger}"
+ )
if msg != expected_msg:
raise e
else:
- self.fail(f'Unexpected logs found: {cm.output!r}')
+ self.fail(f"Unexpected logs found: {cm.output!r}")
- def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
- field_kwargs=None, empty_value=''):
+ def assertFieldOutput(
+ self,
+ fieldclass,
+ valid,
+ invalid,
+ field_args=None,
+ field_kwargs=None,
+ empty_value="",
+ ):
"""
Assert that a form field behaves correctly with various inputs.
@@ -791,7 +912,7 @@ class SimpleTestCase(unittest.TestCase):
if field_kwargs is None:
field_kwargs = {}
required = fieldclass(*field_args, **field_kwargs)
- optional = fieldclass(*field_args, **{**field_kwargs, 'required': False})
+ optional = fieldclass(*field_args, **{**field_kwargs, "required": False})
# test valid inputs
for input, output in valid.items():
self.assertEqual(required.clean(input), output)
@@ -806,7 +927,7 @@ class SimpleTestCase(unittest.TestCase):
optional.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
# test required inputs
- error_required = [required.error_messages['required']]
+ error_required = [required.error_messages["required"]]
for e in required.empty_values:
with self.assertRaises(ValidationError) as context_manager:
required.clean(e)
@@ -814,7 +935,7 @@ class SimpleTestCase(unittest.TestCase):
self.assertEqual(optional.clean(e), empty_value)
# test that max_length and min_length are always accepted
if issubclass(fieldclass, CharField):
- field_kwargs.update({'min_length': 2, 'max_length': 20})
+ field_kwargs.update({"min_length": 2, "max_length": 20})
self.assertIsInstance(fieldclass(*field_args, **field_kwargs), fieldclass)
def assertHTMLEqual(self, html1, html2, msg=None):
@@ -823,39 +944,57 @@ class SimpleTestCase(unittest.TestCase):
Whitespace in most cases is ignored, and attribute ordering is not
significant. The arguments must be valid HTML.
"""
- dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:')
- dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:')
+ dom1 = assert_and_parse_html(
+ self, html1, msg, "First argument is not valid HTML:"
+ )
+ dom2 = assert_and_parse_html(
+ self, html2, msg, "Second argument is not valid HTML:"
+ )
if dom1 != dom2:
- standardMsg = '%s != %s' % (
- safe_repr(dom1, True), safe_repr(dom2, True))
- diff = ('\n' + '\n'.join(difflib.ndiff(
- str(dom1).splitlines(), str(dom2).splitlines(),
- )))
+ standardMsg = "%s != %s" % (safe_repr(dom1, True), safe_repr(dom2, True))
+ diff = "\n" + "\n".join(
+ difflib.ndiff(
+ str(dom1).splitlines(),
+ str(dom2).splitlines(),
+ )
+ )
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
def assertHTMLNotEqual(self, html1, html2, msg=None):
"""Assert that two HTML snippets are not semantically equivalent."""
- dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:')
- dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:')
+ dom1 = assert_and_parse_html(
+ self, html1, msg, "First argument is not valid HTML:"
+ )
+ dom2 = assert_and_parse_html(
+ self, html2, msg, "Second argument is not valid HTML:"
+ )
if dom1 == dom2:
- standardMsg = '%s == %s' % (
- safe_repr(dom1, True), safe_repr(dom2, True))
+ standardMsg = "%s == %s" % (safe_repr(dom1, True), safe_repr(dom2, True))
self.fail(self._formatMessage(msg, standardMsg))
- def assertInHTML(self, needle, haystack, count=None, msg_prefix=''):
- needle = assert_and_parse_html(self, needle, None, 'First argument is not valid HTML:')
- haystack = assert_and_parse_html(self, haystack, None, 'Second argument is not valid HTML:')
+ def assertInHTML(self, needle, haystack, count=None, msg_prefix=""):
+ needle = assert_and_parse_html(
+ self, needle, None, "First argument is not valid HTML:"
+ )
+ haystack = assert_and_parse_html(
+ self, haystack, None, "Second argument is not valid HTML:"
+ )
real_count = haystack.count(needle)
if count is not None:
self.assertEqual(
- real_count, count,
- msg_prefix + "Found %d instances of '%s' in response (expected %d)" % (real_count, needle, count)
+ real_count,
+ count,
+ msg_prefix
+ + "Found %d instances of '%s' in response (expected %d)"
+ % (real_count, needle, count),
)
else:
- self.assertTrue(real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle)
+ self.assertTrue(
+ real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle
+ )
def assertJSONEqual(self, raw, expected_data, msg=None):
"""
@@ -900,14 +1039,17 @@ class SimpleTestCase(unittest.TestCase):
try:
result = compare_xml(xml1, xml2)
except Exception as e:
- standardMsg = 'First or second argument is not valid XML\n%s' % e
+ standardMsg = "First or second argument is not valid XML\n%s" % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if not result:
- standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
- diff = ('\n' + '\n'.join(
+ standardMsg = "%s != %s" % (
+ safe_repr(xml1, True),
+ safe_repr(xml2, True),
+ )
+ diff = "\n" + "\n".join(
difflib.ndiff(xml1.splitlines(), xml2.splitlines())
- ))
+ )
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
@@ -920,11 +1062,14 @@ class SimpleTestCase(unittest.TestCase):
try:
result = compare_xml(xml1, xml2)
except Exception as e:
- standardMsg = 'First or second argument is not valid XML\n%s' % e
+ standardMsg = "First or second argument is not valid XML\n%s" % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if result:
- standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
+ standardMsg = "%s == %s" % (
+ safe_repr(xml1, True),
+ safe_repr(xml2, True),
+ )
self.fail(self._formatMessage(msg, standardMsg))
@@ -942,9 +1087,9 @@ class TransactionTestCase(SimpleTestCase):
databases = {DEFAULT_DB_ALIAS}
_disallowed_database_msg = (
- 'Database %(operation)s to %(alias)r are not allowed in this test. '
- 'Add %(alias)r to %(test)s.databases to ensure proper test isolation '
- 'and silence this failure.'
+ "Database %(operation)s to %(alias)r are not allowed in this test. "
+ "Add %(alias)r to %(test)s.databases to ensure proper test isolation "
+ "and silence this failure."
)
# If transactions aren't available, Django will serialize the database
@@ -966,7 +1111,7 @@ class TransactionTestCase(SimpleTestCase):
apps.set_available_apps(self.available_apps)
setting_changed.send(
sender=settings._wrapped.__class__,
- setting='INSTALLED_APPS',
+ setting="INSTALLED_APPS",
value=self.available_apps,
enter=True,
)
@@ -979,7 +1124,7 @@ class TransactionTestCase(SimpleTestCase):
apps.unset_available_apps()
setting_changed.send(
sender=settings._wrapped.__class__,
- setting='INSTALLED_APPS',
+ setting="INSTALLED_APPS",
value=settings.INSTALLED_APPS,
enter=False,
)
@@ -994,9 +1139,12 @@ class TransactionTestCase(SimpleTestCase):
def _databases_names(cls, include_mirrors=True):
# Only consider allowed database aliases, including mirrors or not.
return [
- alias for alias in connections
- if alias in cls.databases and (
- include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
+ alias
+ for alias in connections
+ if alias in cls.databases
+ and (
+ include_mirrors
+ or not connections[alias].settings_dict["TEST"]["MIRROR"]
)
]
@@ -1004,7 +1152,8 @@ class TransactionTestCase(SimpleTestCase):
conn = connections[db_name]
if conn.features.supports_sequence_reset:
sql_list = conn.ops.sequence_reset_by_name_sql(
- no_style(), conn.introspection.sequence_list())
+ no_style(), conn.introspection.sequence_list()
+ )
if sql_list:
with transaction.atomic(using=db_name):
with conn.cursor() as cursor:
@@ -1018,7 +1167,9 @@ class TransactionTestCase(SimpleTestCase):
self._reset_sequences(db_name)
# Provide replica initial data from migrated apps, if needed.
- if self.serialized_rollback and hasattr(connections[db_name], "_test_serialized_contents"):
+ if self.serialized_rollback and hasattr(
+ connections[db_name], "_test_serialized_contents"
+ ):
if self.available_apps is not None:
apps.unset_available_apps()
connections[db_name].creation.deserialize_db_from_string(
@@ -1030,8 +1181,9 @@ class TransactionTestCase(SimpleTestCase):
if self.fixtures:
# We have to use this slightly awkward syntax due to the fact
# that we're using *args and **kwargs together.
- call_command('loaddata', *self.fixtures,
- **{'verbosity': 0, 'database': db_name})
+ call_command(
+ "loaddata", *self.fixtures, **{"verbosity": 0, "database": db_name}
+ )
def _should_reload_connections(self):
return True
@@ -1058,10 +1210,12 @@ class TransactionTestCase(SimpleTestCase):
finally:
if self.available_apps is not None:
apps.unset_available_apps()
- setting_changed.send(sender=settings._wrapped.__class__,
- setting='INSTALLED_APPS',
- value=settings.INSTALLED_APPS,
- enter=False)
+ setting_changed.send(
+ sender=settings._wrapped.__class__,
+ setting="INSTALLED_APPS",
+ value=settings.INSTALLED_APPS,
+ enter=False,
+ )
def _fixture_teardown(self):
# Allow TRUNCATE ... CASCADE and don't emit the post_migrate signal
@@ -1069,17 +1223,22 @@ class TransactionTestCase(SimpleTestCase):
for db_name in self._databases_names(include_mirrors=False):
# Flush the database
inhibit_post_migrate = (
- self.available_apps is not None or
- ( # Inhibit the post_migrate signal when using serialized
+ self.available_apps is not None
+ or ( # Inhibit the post_migrate signal when using serialized
# rollback to avoid trying to recreate the serialized data.
- self.serialized_rollback and
- hasattr(connections[db_name], '_test_serialized_contents')
+ self.serialized_rollback
+ and hasattr(connections[db_name], "_test_serialized_contents")
)
)
- call_command('flush', verbosity=0, interactive=False,
- database=db_name, reset_sequences=False,
- allow_cascade=self.available_apps is not None,
- inhibit_post_migrate=inhibit_post_migrate)
+ call_command(
+ "flush",
+ verbosity=0,
+ interactive=False,
+ database=db_name,
+ reset_sequences=False,
+ allow_cascade=self.available_apps is not None,
+ inhibit_post_migrate=inhibit_post_migrate,
+ )
def assertQuerysetEqual(self, qs, values, transform=None, ordered=True, msg=None):
values = list(values)
@@ -1090,10 +1249,10 @@ class TransactionTestCase(SimpleTestCase):
return self.assertDictEqual(Counter(items), Counter(values), msg=msg)
# For example qs.iterator() could be passed as qs, but it does not
# have 'ordered' attribute.
- if len(values) > 1 and hasattr(qs, 'ordered') and not qs.ordered:
+ if len(values) > 1 and hasattr(qs, "ordered") and not qs.ordered:
raise ValueError(
- 'Trying to compare non-ordered queryset against more than one '
- 'ordered value.'
+ "Trying to compare non-ordered queryset against more than one "
+ "ordered value."
)
return self.assertEqual(list(items), values, msg=msg)
@@ -1113,7 +1272,11 @@ def connections_support_transactions(aliases=None):
Return whether or not all (or specified) connections support
transactions.
"""
- conns = connections.all() if aliases is None else (connections[alias] for alias in aliases)
+ conns = (
+ connections.all()
+ if aliases is None
+ else (connections[alias] for alias in aliases)
+ )
return all(conn.features.supports_transactions for conn in conns)
@@ -1128,7 +1291,8 @@ class TestData:
Objects are deep copied using a memo kept on the test case instance in
order to maintain their original relationships.
"""
- memo_attr = '_testdata_memo'
+
+ memo_attr = "_testdata_memo"
def __init__(self, name, data):
self.name = name
@@ -1151,7 +1315,7 @@ class TestData:
return data
def __repr__(self):
- return '<TestData: name=%r, data=%r>' % (self.name, self.data)
+ return "<TestData: name=%r, data=%r>" % (self.name, self.data)
class TestCase(TransactionTestCase):
@@ -1167,6 +1331,7 @@ class TestCase(TransactionTestCase):
On database backends with no transaction support, TestCase behaves as
TransactionTestCase.
"""
+
@classmethod
def _enter_atomics(cls):
"""Open atomic blocks for multiple databases."""
@@ -1199,7 +1364,11 @@ class TestCase(TransactionTestCase):
if cls.fixtures:
for db_name in cls._databases_names(include_mirrors=False):
try:
- call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
+ call_command(
+ "loaddata",
+ *cls.fixtures,
+ **{"verbosity": 0, "database": db_name},
+ )
except Exception:
cls._rollback_atomics(cls.cls_atomics)
raise
@@ -1239,7 +1408,7 @@ class TestCase(TransactionTestCase):
return super()._fixture_setup()
if self.reset_sequences:
- raise TypeError('reset_sequences cannot be used on TestCase instances')
+ raise TypeError("reset_sequences cannot be used on TestCase instances")
self.atomics = self._enter_atomics()
def _fixture_teardown(self):
@@ -1254,8 +1423,9 @@ class TestCase(TransactionTestCase):
def _should_check_constraints(self, connection):
return (
- connection.features.can_defer_constraint_checks and
- not connection.needs_rollback and connection.is_usable()
+ connection.features.can_defer_constraint_checks
+ and not connection.needs_rollback
+ and connection.is_usable()
)
@classmethod
@@ -1281,6 +1451,7 @@ class TestCase(TransactionTestCase):
class CheckCondition:
"""Descriptor class for deferred condition checking."""
+
def __init__(self, *conditions):
self.conditions = conditions
@@ -1289,7 +1460,7 @@ class CheckCondition:
def __get__(self, instance, cls=None):
# Trigger access for all bases.
- if any(getattr(base, '__unittest_skip__', False) for base in cls.__bases__):
+ if any(getattr(base, "__unittest_skip__", False) for base in cls.__bases__):
return True
for condition, reason in self.conditions:
if condition():
@@ -1303,15 +1474,21 @@ class CheckCondition:
def _deferredSkip(condition, reason, name):
def decorator(test_func):
nonlocal condition
- if not (isinstance(test_func, type) and
- issubclass(test_func, unittest.TestCase)):
+ if not (
+ isinstance(test_func, type) and issubclass(test_func, unittest.TestCase)
+ ):
+
@wraps(test_func)
def skip_wrapper(*args, **kwargs):
- if (args and isinstance(args[0], unittest.TestCase) and
- connection.alias not in getattr(args[0], 'databases', {})):
+ if (
+ args
+ and isinstance(args[0], unittest.TestCase)
+ and connection.alias not in getattr(args[0], "databases", {})
+ ):
raise ValueError(
"%s cannot be used on %s as %s doesn't allow queries "
- "against the %r database." % (
+ "against the %r database."
+ % (
name,
args[0],
args[0].__class__.__qualname__,
@@ -1321,55 +1498,67 @@ def _deferredSkip(condition, reason, name):
if condition():
raise unittest.SkipTest(reason)
return test_func(*args, **kwargs)
+
test_item = skip_wrapper
else:
# Assume a class is decorated
test_item = test_func
- databases = getattr(test_item, 'databases', None)
+ databases = getattr(test_item, "databases", None)
if not databases or connection.alias not in databases:
# Defer raising to allow importing test class's module.
def condition():
raise ValueError(
"%s cannot be used on %s as it doesn't allow queries "
- "against the '%s' database." % (
- name, test_item, connection.alias,
+ "against the '%s' database."
+ % (
+ name,
+ test_item,
+ connection.alias,
)
)
+
# Retrieve the possibly existing value from the class's dict to
# avoid triggering the descriptor.
- skip = test_func.__dict__.get('__unittest_skip__')
+ skip = test_func.__dict__.get("__unittest_skip__")
if isinstance(skip, CheckCondition):
test_item.__unittest_skip__ = skip.add_condition(condition, reason)
elif skip is not True:
test_item.__unittest_skip__ = CheckCondition((condition, reason))
return test_item
+
return decorator
def skipIfDBFeature(*features):
"""Skip a test if a database has at least one of the named features."""
return _deferredSkip(
- lambda: any(getattr(connection.features, feature, False) for feature in features),
+ lambda: any(
+ getattr(connection.features, feature, False) for feature in features
+ ),
"Database has feature(s) %s" % ", ".join(features),
- 'skipIfDBFeature',
+ "skipIfDBFeature",
)
def skipUnlessDBFeature(*features):
"""Skip a test unless a database has all the named features."""
return _deferredSkip(
- lambda: not all(getattr(connection.features, feature, False) for feature in features),
+ lambda: not all(
+ getattr(connection.features, feature, False) for feature in features
+ ),
"Database doesn't support feature(s): %s" % ", ".join(features),
- 'skipUnlessDBFeature',
+ "skipUnlessDBFeature",
)
def skipUnlessAnyDBFeature(*features):
"""Skip a test unless a database has any of the named features."""
return _deferredSkip(
- lambda: not any(getattr(connection.features, feature, False) for feature in features),
+ lambda: not any(
+ getattr(connection.features, feature, False) for feature in features
+ ),
"Database doesn't support any of the feature(s): %s" % ", ".join(features),
- 'skipUnlessAnyDBFeature',
+ "skipUnlessAnyDBFeature",
)
@@ -1378,6 +1567,7 @@ class QuietWSGIRequestHandler(WSGIRequestHandler):
A WSGIRequestHandler that doesn't log to standard output any of the
requests received, so as to not clutter the test result output.
"""
+
def log_message(*args):
pass
@@ -1387,6 +1577,7 @@ class FSFilesHandler(WSGIHandler):
WSGI middleware that intercepts calls to a directory, as defined by one of
the *_ROOT settings, and serves those files, publishing them under *_URL.
"""
+
def __init__(self, application):
self.application = application
self.base_url = urlparse(self.get_base_url())
@@ -1402,7 +1593,7 @@ class FSFilesHandler(WSGIHandler):
def file_path(self, url):
"""Return the relative path to the file on disk for the given URL."""
- relative_url = url[len(self.base_url[2]):]
+ relative_url = url[len(self.base_url[2]) :]
return url2pathname(relative_url)
def get_response(self, request):
@@ -1421,7 +1612,7 @@ class FSFilesHandler(WSGIHandler):
# Emulate behavior of django.contrib.staticfiles.views.serve() when it
# invokes staticfiles' finders functionality.
# TODO: Modify if/when that internal API is refactored
- final_rel_path = os_rel_path.replace('\\', '/').lstrip('/')
+ final_rel_path = os_rel_path.replace("\\", "/").lstrip("/")
return serve(request, final_rel_path, document_root=self.get_base_dir())
def __call__(self, environ, start_response):
@@ -1435,6 +1626,7 @@ class _StaticFilesHandler(FSFilesHandler):
Handler for serving static files. A private class that is meant to be used
solely as a convenience by LiveServerThread.
"""
+
def get_base_dir(self):
return settings.STATIC_ROOT
@@ -1447,6 +1639,7 @@ class _MediaFilesHandler(FSFilesHandler):
Handler for serving the media files. A private class that is meant to be
used solely as a convenience by LiveServerThread.
"""
+
def get_base_dir(self):
return settings.MEDIA_ROOT
@@ -1503,7 +1696,7 @@ class LiveServerThread(threading.Thread):
)
def terminate(self):
- if hasattr(self, 'httpd'):
+ if hasattr(self, "httpd"):
# Stop the WSGI server
self.httpd.shutdown()
self.httpd.server_close()
@@ -1521,14 +1714,15 @@ class LiveServerTestCase(TransactionTestCase):
and each thread needs to commit all their transactions so that the other
thread can see the changes.
"""
- host = 'localhost'
+
+ host = "localhost"
port = 0
server_thread_class = LiveServerThread
static_handler = _StaticFilesHandler
@classproperty
def live_server_url(cls):
- return 'http://%s:%s' % (cls.host, cls.server_thread.port)
+ return "http://%s:%s" % (cls.host, cls.server_thread.port)
@classproperty
def allowed_host(cls):
@@ -1540,7 +1734,7 @@ class LiveServerTestCase(TransactionTestCase):
for conn in connections.all():
# If using in-memory sqlite databases, pass the connections to
# the server thread.
- if conn.vendor == 'sqlite' and conn.is_in_memory_db():
+ if conn.vendor == "sqlite" and conn.is_in_memory_db():
connections_override[conn.alias] = conn
return connections_override
@@ -1548,7 +1742,7 @@ class LiveServerTestCase(TransactionTestCase):
def setUpClass(cls):
super().setUpClass()
cls._live_server_modified_settings = modify_settings(
- ALLOWED_HOSTS={'append': cls.allowed_host},
+ ALLOWED_HOSTS={"append": cls.allowed_host},
)
cls._live_server_modified_settings.enable()
cls.addClassCleanup(cls._live_server_modified_settings.disable)
@@ -1598,6 +1792,7 @@ class SerializeMixin:
Place it early in the MRO in order to isolate setUpClass()/tearDownClass().
"""
+
lockfile = None
def __init_subclass__(cls, /, **kwargs):
@@ -1605,7 +1800,8 @@ class SerializeMixin:
if cls.lockfile is None:
raise ValueError(
"{}.lockfile isn't set. Set it to a unique value "
- "in the base class.".format(cls.__name__))
+ "in the base class.".format(cls.__name__)
+ )
@classmethod
def setUpClass(cls):
diff --git a/django/test/utils.py b/django/test/utils.py
index 6c2f566909..ac0fc34b08 100644
--- a/django/test/utils.py
+++ b/django/test/utils.py
@@ -35,15 +35,24 @@ except ImportError:
__all__ = (
- 'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
- 'CaptureQueriesContext',
- 'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings',
- 'override_system_checks', 'tag',
- 'requires_tz_support',
- 'setup_databases', 'setup_test_environment', 'teardown_test_environment',
+ "Approximate",
+ "ContextList",
+ "isolate_lru_cache",
+ "get_runner",
+ "CaptureQueriesContext",
+ "ignore_warnings",
+ "isolate_apps",
+ "modify_settings",
+ "override_settings",
+ "override_system_checks",
+ "tag",
+ "requires_tz_support",
+ "setup_databases",
+ "setup_test_environment",
+ "teardown_test_environment",
)
-TZ_SUPPORT = hasattr(time, 'tzset')
+TZ_SUPPORT = hasattr(time, "tzset")
class Approximate:
@@ -63,6 +72,7 @@ class ContextList(list):
A wrapper that provides direct key access to context items contained
in a list of context objects.
"""
+
def __getitem__(self, key):
if isinstance(key, str):
for subcontext in self:
@@ -110,7 +120,7 @@ def setup_test_environment(debug=None):
Perform global pre-test setup, such as installing the instrumented template
renderer and setting the email backend to the locmem email backend.
"""
- if hasattr(_TestState, 'saved_data'):
+ if hasattr(_TestState, "saved_data"):
# Executing this function twice would overwrite the saved values.
raise RuntimeError(
"setup_test_environment() was already called and can't be called "
@@ -125,13 +135,13 @@ def setup_test_environment(debug=None):
saved_data.allowed_hosts = settings.ALLOWED_HOSTS
# Add the default host of the test client.
- settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']
+ settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"]
saved_data.debug = settings.DEBUG
settings.DEBUG = debug
saved_data.email_backend = settings.EMAIL_BACKEND
- settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
+ settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend"
saved_data.template_render = Template._render
Template._render = instrumented_test_render
@@ -191,18 +201,17 @@ def setup_databases(
# replace with:
# serialize_alias = serialized_aliases is None or alias in serialized_aliases
try:
- serialize_alias = connection.settings_dict['TEST']['SERIALIZE']
+ serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"]
except KeyError:
serialize_alias = (
- serialized_aliases is None or
- alias in serialized_aliases
+ serialized_aliases is None or alias in serialized_aliases
)
else:
warnings.warn(
- 'The SERIALIZE test database setting is '
- 'deprecated as it can be inferred from the '
- 'TestCase/TransactionTestCase.databases that '
- 'enable the serialized_rollback feature.',
+ "The SERIALIZE test database setting is "
+ "deprecated as it can be inferred from the "
+ "TestCase/TransactionTestCase.databases that "
+ "enable the serialized_rollback feature.",
category=RemovedInDjango50Warning,
)
connection.creation.create_test_db(
@@ -221,12 +230,15 @@ def setup_databases(
)
# Configure all other connections as mirrors of the first one
else:
- connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)
+ connections[alias].creation.set_as_test_mirror(
+ connections[first_alias].settings_dict
+ )
# Configure the test mirrors.
for alias, mirror_alias in mirrored_aliases.items():
connections[alias].creation.set_as_test_mirror(
- connections[mirror_alias].settings_dict)
+ connections[mirror_alias].settings_dict
+ )
if debug_sql:
for alias in connections:
@@ -246,8 +258,8 @@ def iter_test_cases(tests):
# Prevent an unfriendly RecursionError that can happen with
# strings.
raise TypeError(
- f'Test {test!r} must be a test case or test suite not string '
- f'(was found in {tests!r}).'
+ f"Test {test!r} must be a test case or test suite not string "
+ f"(was found in {tests!r})."
)
if isinstance(test, TestCase):
yield test
@@ -319,18 +331,18 @@ def get_unique_databases_and_mirrors(aliases=None):
for alias in connections:
connection = connections[alias]
- test_settings = connection.settings_dict['TEST']
+ test_settings = connection.settings_dict["TEST"]
- if test_settings['MIRROR']:
+ if test_settings["MIRROR"]:
# If the database is marked as a test mirror, save the alias.
- mirrored_aliases[alias] = test_settings['MIRROR']
+ mirrored_aliases[alias] = test_settings["MIRROR"]
elif alias in aliases:
# Store a tuple with DB parameters that uniquely identify it.
# If we have two aliases with the same values for that tuple,
# we only need to create the test database once.
item = test_databases.setdefault(
connection.creation.test_db_signature(),
- (connection.settings_dict['NAME'], []),
+ (connection.settings_dict["NAME"], []),
)
# The default database must be the first because data migrations
# use the default alias by default.
@@ -339,11 +351,16 @@ def get_unique_databases_and_mirrors(aliases=None):
else:
item[1].append(alias)
- if 'DEPENDENCIES' in test_settings:
- dependencies[alias] = test_settings['DEPENDENCIES']
+ if "DEPENDENCIES" in test_settings:
+ dependencies[alias] = test_settings["DEPENDENCIES"]
else:
- if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
- dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])
+ if (
+ alias != DEFAULT_DB_ALIAS
+ and connection.creation.test_db_signature() != default_sig
+ ):
+ dependencies[alias] = test_settings.get(
+ "DEPENDENCIES", [DEFAULT_DB_ALIAS]
+ )
test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
return test_databases, mirrored_aliases
@@ -365,12 +382,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
def get_runner(settings, test_runner_class=None):
test_runner_class = test_runner_class or settings.TEST_RUNNER
- test_path = test_runner_class.split('.')
+ test_path = test_runner_class.split(".")
# Allow for relative paths
if len(test_path) > 1:
- test_module_name = '.'.join(test_path[:-1])
+ test_module_name = ".".join(test_path[:-1])
else:
- test_module_name = '.'
+ test_module_name = "."
test_module = __import__(test_module_name, {}, {}, test_path[-1])
return getattr(test_module, test_path[-1])
@@ -387,6 +404,7 @@ class TestContextDecorator:
`kwarg_name`: keyword argument passing the return value of enable() if
used as a function decorator.
"""
+
def __init__(self, attr_name=None, kwarg_name=None):
self.attr_name = attr_name
self.kwarg_name = kwarg_name
@@ -416,7 +434,7 @@ class TestContextDecorator:
cls.setUp = setUp
return cls
- raise TypeError('Can only decorate subclasses of unittest.TestCase')
+ raise TypeError("Can only decorate subclasses of unittest.TestCase")
def decorate_callable(self, func):
if asyncio.iscoroutinefunction(func):
@@ -428,13 +446,16 @@ class TestContextDecorator:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return await func(*args, **kwargs)
+
else:
+
@wraps(func)
def inner(*args, **kwargs):
with self as context:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return func(*args, **kwargs)
+
return inner
def __call__(self, decorated):
@@ -442,7 +463,7 @@ class TestContextDecorator:
return self.decorate_class(decorated)
elif callable(decorated):
return self.decorate_callable(decorated)
- raise TypeError('Cannot decorate object of type %s' % type(decorated))
+ raise TypeError("Cannot decorate object of type %s" % type(decorated))
class override_settings(TestContextDecorator):
@@ -452,6 +473,7 @@ class override_settings(TestContextDecorator):
with the ``with`` statement. In either event, entering/exiting are called
before and after, respectively, the function/block is executed.
"""
+
enable_exception = None
def __init__(self, **kwargs):
@@ -461,9 +483,9 @@ class override_settings(TestContextDecorator):
def enable(self):
# Keep this code at the beginning to leave the settings unchanged
# in case it raises an exception because INSTALLED_APPS is invalid.
- if 'INSTALLED_APPS' in self.options:
+ if "INSTALLED_APPS" in self.options:
try:
- apps.set_installed_apps(self.options['INSTALLED_APPS'])
+ apps.set_installed_apps(self.options["INSTALLED_APPS"])
except Exception:
apps.unset_installed_apps()
raise
@@ -476,14 +498,16 @@ class override_settings(TestContextDecorator):
try:
setting_changed.send(
sender=settings._wrapped.__class__,
- setting=key, value=new_value, enter=True,
+ setting=key,
+ value=new_value,
+ enter=True,
)
except Exception as exc:
self.enable_exception = exc
self.disable()
def disable(self):
- if 'INSTALLED_APPS' in self.options:
+ if "INSTALLED_APPS" in self.options:
apps.unset_installed_apps()
settings._wrapped = self.wrapped
del self.wrapped
@@ -492,7 +516,9 @@ class override_settings(TestContextDecorator):
new_value = getattr(settings, key, None)
responses_for_setting = setting_changed.send_robust(
sender=settings._wrapped.__class__,
- setting=key, value=new_value, enter=False,
+ setting=key,
+ value=new_value,
+ enter=False,
)
responses.extend(responses_for_setting)
if self.enable_exception is not None:
@@ -515,10 +541,12 @@ class override_settings(TestContextDecorator):
def decorate_class(self, cls):
from django.test import SimpleTestCase
+
if not issubclass(cls, SimpleTestCase):
raise ValueError(
"Only subclasses of Django SimpleTestCase can be decorated "
- "with override_settings")
+ "with override_settings"
+ )
self.save_options(cls)
return cls
@@ -528,6 +556,7 @@ class modify_settings(override_settings):
Like override_settings, but makes it possible to append, prepend, or remove
items instead of redefining the entire list.
"""
+
def __init__(self, *args, **kwargs):
if args:
# Hack used when instantiating from SimpleTestCase.setUpClass.
@@ -543,8 +572,9 @@ class modify_settings(override_settings):
test_func._modified_settings = self.operations
else:
# Duplicate list to prevent subclasses from altering their parent.
- test_func._modified_settings = list(
- test_func._modified_settings) + self.operations
+ test_func._modified_settings = (
+ list(test_func._modified_settings) + self.operations
+ )
def enable(self):
self.options = {}
@@ -559,11 +589,11 @@ class modify_settings(override_settings):
# items my be a single value or an iterable.
if isinstance(items, str):
items = [items]
- if action == 'append':
+ if action == "append":
value = value + [item for item in items if item not in value]
- elif action == 'prepend':
+ elif action == "prepend":
value = [item for item in items if item not in value] + value
- elif action == 'remove':
+ elif action == "remove":
value = [item for item in value if item not in items]
else:
raise ValueError("Unsupported action: %s" % action)
@@ -577,8 +607,10 @@ class override_system_checks(TestContextDecorator):
Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
you also need to exclude its system checks.
"""
+
def __init__(self, new_checks, deployment_checks=None):
from django.core.checks.registry import registry
+
self.registry = registry
self.new_checks = new_checks
self.deployment_checks = deployment_checks
@@ -588,12 +620,12 @@ class override_system_checks(TestContextDecorator):
self.old_checks = self.registry.registered_checks
self.registry.registered_checks = set()
for check in self.new_checks:
- self.registry.register(check, *getattr(check, 'tags', ()))
+ self.registry.register(check, *getattr(check, "tags", ()))
self.old_deployment_checks = self.registry.deployment_checks
if self.deployment_checks is not None:
self.registry.deployment_checks = set()
for check in self.deployment_checks:
- self.registry.register(check, *getattr(check, 'tags', ()), deploy=True)
+ self.registry.register(check, *getattr(check, "tags", ()), deploy=True)
def disable(self):
self.registry.registered_checks = self.old_checks
@@ -609,18 +641,18 @@ def compare_xml(want, got):
Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
"""
- _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
+ _norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+")
def norm_whitespace(v):
- return _norm_whitespace_re.sub(' ', v)
+ return _norm_whitespace_re.sub(" ", v)
def child_text(element):
- return ''.join(c.data for c in element.childNodes
- if c.nodeType == Node.TEXT_NODE)
+ return "".join(
+ c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE
+ )
def children(element):
- return [c for c in element.childNodes
- if c.nodeType == Node.ELEMENT_NODE]
+ return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE]
def norm_child_text(element):
return norm_whitespace(child_text(element))
@@ -639,7 +671,9 @@ def compare_xml(want, got):
got_children = children(got_element)
if len(want_children) != len(got_children):
return False
- return all(check_element(want, got) for want, got in zip(want_children, got_children))
+ return all(
+ check_element(want, got) for want, got in zip(want_children, got_children)
+ )
def first_node(document):
for node in document.childNodes:
@@ -650,13 +684,13 @@ def compare_xml(want, got):
):
return node
- want = want.strip().replace('\\n', '\n')
- got = got.strip().replace('\\n', '\n')
+ want = want.strip().replace("\\n", "\n")
+ got = got.strip().replace("\\n", "\n")
# If the string is not a complete xml document, we may need to add a
# root element. This allow us to compare fragments, like "<foo/><bar/>"
- if not want.startswith('<?xml'):
- wrapper = '<root>%s</root>'
+ if not want.startswith("<?xml"):
+ wrapper = "<root>%s</root>"
want = wrapper % want
got = wrapper % got
@@ -671,6 +705,7 @@ class CaptureQueriesContext:
"""
Context manager that captures queries executed by the specified connection.
"""
+
def __init__(self, connection):
self.connection = connection
@@ -685,7 +720,7 @@ class CaptureQueriesContext:
@property
def captured_queries(self):
- return self.connection.queries[self.initial_queries:self.final_queries]
+ return self.connection.queries[self.initial_queries : self.final_queries]
def __enter__(self):
self.force_debug_cursor = self.connection.force_debug_cursor
@@ -709,7 +744,7 @@ class CaptureQueriesContext:
class ignore_warnings(TestContextDecorator):
def __init__(self, **kwargs):
self.ignore_kwargs = kwargs
- if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs:
+ if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs:
self.filter_func = warnings.filterwarnings
else:
self.filter_func = warnings.simplefilter
@@ -718,7 +753,7 @@ class ignore_warnings(TestContextDecorator):
def enable(self):
self.catch_warnings = warnings.catch_warnings()
self.catch_warnings.__enter__()
- self.filter_func('ignore', **self.ignore_kwargs)
+ self.filter_func("ignore", **self.ignore_kwargs)
def disable(self):
self.catch_warnings.__exit__(*sys.exc_info())
@@ -732,7 +767,7 @@ class ignore_warnings(TestContextDecorator):
requires_tz_support = skipUnless(
TZ_SUPPORT,
"This test relies on the ability to run a program in an arbitrary "
- "time zone, but your operating system isn't able to do that."
+ "time zone, but your operating system isn't able to do that.",
)
@@ -775,9 +810,9 @@ def captured_output(stream_name):
def captured_stdout():
"""Capture the output of sys.stdout:
- with captured_stdout() as stdout:
- print("hello")
- self.assertEqual(stdout.getvalue(), "hello\n")
+ with captured_stdout() as stdout:
+ print("hello")
+ self.assertEqual(stdout.getvalue(), "hello\n")
"""
return captured_output("stdout")
@@ -785,9 +820,9 @@ def captured_stdout():
def captured_stderr():
"""Capture the output of sys.stderr:
- with captured_stderr() as stderr:
- print("hello", file=sys.stderr)
- self.assertEqual(stderr.getvalue(), "hello\n")
+ with captured_stderr() as stderr:
+ print("hello", file=sys.stderr)
+ self.assertEqual(stderr.getvalue(), "hello\n")
"""
return captured_output("stderr")
@@ -795,12 +830,12 @@ def captured_stderr():
def captured_stdin():
"""Capture the input to sys.stdin:
- with captured_stdin() as stdin:
- stdin.write('hello\n')
- stdin.seek(0)
- # call test code that consumes from sys.stdin
- captured = input()
- self.assertEqual(captured, "hello")
+ with captured_stdin() as stdin:
+ stdin.write('hello\n')
+ stdin.seek(0)
+ # call test code that consumes from sys.stdin
+ captured = input()
+ self.assertEqual(captured, "hello")
"""
return captured_output("stdin")
@@ -828,18 +863,24 @@ def require_jinja2(test_func):
Django template engine for a test or skip it if Jinja2 isn't available.
"""
test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
- return override_settings(TEMPLATES=[{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'APP_DIRS': True,
- }, {
- 'BACKEND': 'django.template.backends.jinja2.Jinja2',
- 'APP_DIRS': True,
- 'OPTIONS': {'keep_trailing_newline': True},
- }])(test_func)
+ return override_settings(
+ TEMPLATES=[
+ {
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "APP_DIRS": True,
+ },
+ {
+ "BACKEND": "django.template.backends.jinja2.Jinja2",
+ "APP_DIRS": True,
+ "OPTIONS": {"keep_trailing_newline": True},
+ },
+ ]
+ )(test_func)
class override_script_prefix(TestContextDecorator):
"""Decorator or context manager to temporary override the script prefix."""
+
def __init__(self, prefix):
self.prefix = prefix
super().__init__()
@@ -857,8 +898,9 @@ class LoggingCaptureMixin:
Capture the output from the 'django' logger and store it on the class's
logger_output attribute.
"""
+
def setUp(self):
- self.logger = logging.getLogger('django')
+ self.logger = logging.getLogger("django")
self.old_stream = self.logger.handlers[0].stream
self.logger_output = StringIO()
self.logger.handlers[0].stream = self.logger_output
@@ -883,6 +925,7 @@ class isolate_apps(TestContextDecorator):
`kwarg_name`: keyword argument passing the isolated registry if used as a
function decorator.
"""
+
def __init__(self, *installed_apps, **kwargs):
self.installed_apps = installed_apps
super().__init__(**kwargs)
@@ -890,11 +933,11 @@ class isolate_apps(TestContextDecorator):
def enable(self):
self.old_apps = Options.default_apps
apps = Apps(self.installed_apps)
- setattr(Options, 'default_apps', apps)
+ setattr(Options, "default_apps", apps)
return apps
def disable(self):
- setattr(Options, 'default_apps', self.old_apps)
+ setattr(Options, "default_apps", self.old_apps)
class TimeKeeper:
@@ -914,7 +957,7 @@ class TimeKeeper:
def print_results(self):
for name, end_times in self.records.items():
for record_time in end_times:
- record = '%s took %.3fs' % (name, record_time)
+ record = "%s took %.3fs" % (name, record_time)
sys.stderr.write(record + os.linesep)
@@ -929,12 +972,14 @@ class NullTimeKeeper:
def tag(*tags):
"""Decorator to add tags to a test class or method."""
+
def decorator(obj):
- if hasattr(obj, 'tags'):
+ if hasattr(obj, "tags"):
obj.tags = obj.tags.union(tags)
else:
- setattr(obj, 'tags', set(tags))
+ setattr(obj, "tags", set(tags))
return obj
+
return decorator