summaryrefslogtreecommitdiff
path: root/django/test/client.py
diff options
context:
space:
mode:
authorAndrew Godwin <andrew@aeracode.org>2020-02-12 15:15:00 -0700
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2020-03-18 19:59:12 +0100
commitfc0fa72ff4cdbf5861a366e31cb8bbacd44da22d (patch)
treed419ce531586808b0a111664907b859cb6d22862 /django/test/client.py
parent3f7e4b16bf58f99c71570ba75dc97db8265071be (diff)
Fixed #31224 -- Added support for asynchronous views and middleware.
This implements support for asynchronous views, asynchronous tests, asynchronous middleware, and an asynchronous test client.
Diffstat (limited to 'django/test/client.py')
-rw-r--r--django/test/client.py408
1 files changed, 302 insertions, 106 deletions
diff --git a/django/test/client.py b/django/test/client.py
index 34fc9f3cf1..c460832461 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -9,7 +9,10 @@ from importlib import import_module
from io import BytesIO
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
+from asgiref.sync import sync_to_async
+
from django.conf import settings
+from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
@@ -157,6 +160,52 @@ class ClientHandler(BaseHandler):
return response
+class AsyncClientHandler(BaseHandler):
+ """An async version of ClientHandler."""
+ def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ super().__init__(*args, **kwargs)
+
+ async def __call__(self, scope):
+ # Set up middleware if needed. We couldn't do this earlier, because
+ # settings weren't available.
+ if self._middleware_chain is None:
+ self.load_middleware(is_async=True)
+ # Extract body file from the scope, if provided.
+ if '_body_file' in scope:
+ body_file = scope.pop('_body_file')
+ else:
+ body_file = FakePayload('')
+
+ request_started.disconnect(close_old_connections)
+ await sync_to_async(request_started.send)(sender=self.__class__, scope=scope)
+ request_started.connect(close_old_connections)
+ request = ASGIRequest(scope, body_file)
+ # Sneaky little hack so that we can easily get round
+ # CsrfViewMiddleware. This makes life easier, and is probably required
+ # for backwards compatibility with external tests against admin views.
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
+ # Request goes through middleware.
+ response = await self.get_response_async(request)
+ # Simulate behaviors of most Web servers.
+ conditional_content_removal(request, response)
+ # Attach the originating ASGI request to the response so that it could
+ # be later retrieved.
+ response.asgi_request = request
+ # Emulate a server by calling the close method on completion.
+ if response.streaming:
+ response.streaming_content = await sync_to_async(closing_iterator_wrapper)(
+ response.streaming_content,
+ response.close,
+ )
+ else:
+ request_finished.disconnect(close_old_connections)
+ # Will fire request_finished.
+ await sync_to_async(response.close)()
+ request_finished.connect(close_old_connections)
+ return response
+
+
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
"""
Store templates and contexts that are rendered.
@@ -421,7 +470,194 @@ class RequestFactory:
return self.request(**r)
-class Client(RequestFactory):
+class AsyncRequestFactory(RequestFactory):
+ """
+ Class that lets you create mock ASGI-like Request objects for use in
+ testing. Usage:
+
+ rf = AsyncRequestFactory()
+ get_request = await rf.get('/hello/')
+ post_request = await rf.post('/submit/', {'foo': 'bar'})
+
+ Once you have a request object you can pass it to any view function,
+ including synchronous ones. The reason we have a separate class here is:
+ a) this makes ASGIRequest subclasses, and
+ b) AsyncTestClient can subclass it.
+ """
+ def _base_scope(self, **request):
+ """The base scope for a request."""
+ # This is a minimal valid ASGI scope, plus:
+ # - headers['cookie'] for cookie support,
+ # - 'client' often useful, see #8551.
+ scope = {
+ 'asgi': {'version': '3.0'},
+ 'type': 'http',
+ 'http_version': '1.1',
+ 'client': ['127.0.0.1', 0],
+ 'server': ('testserver', '80'),
+ 'scheme': 'http',
+ 'method': 'GET',
+ 'headers': [],
+ **self.defaults,
+ **request,
+ }
+ scope['headers'].append((
+ b'cookie',
+ b'; '.join(sorted(
+ ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
+ for morsel in self.cookies.values()
+ )),
+ ))
+ return scope
+
+ def request(self, **request):
+ """Construct a generic request object."""
+ # This is synchronous, which means all methods on this class are.
+ # AsyncClient, however, has an async request function, which makes all
+ # its methods async.
+ if '_body_file' in request:
+ body_file = request.pop('_body_file')
+ else:
+ body_file = FakePayload('')
+ return ASGIRequest(self._base_scope(**request), body_file)
+
+ def generic(
+ self, method, path, data='', content_type='application/octet-stream',
+ secure=False, **extra,
+ ):
+ """Construct an arbitrary HTTP request."""
+ parsed = urlparse(str(path)) # path can be lazy.
+ data = force_bytes(data, settings.DEFAULT_CHARSET)
+ s = {
+ 'method': method,
+ 'path': self._get_path(parsed),
+ 'server': ('127.0.0.1', '443' if secure else '80'),
+ 'scheme': 'https' if secure else 'http',
+ 'headers': [(b'host', b'testserver')],
+ }
+ if data:
+ s['headers'].extend([
+ (b'content-length', bytes(len(data))),
+ (b'content-type', content_type.encode('ascii')),
+ ])
+ s['_body_file'] = FakePayload(data)
+ s.update(extra)
+ # If QUERY_STRING is absent or empty, we want to extract it from the
+ # URL.
+ if not s.get('query_string'):
+ s['query_string'] = parsed[4]
+ return self.request(**s)
+
+
+class ClientMixin:
+ """
+ Mixin with common methods between Client and AsyncClient.
+ """
+ def store_exc_info(self, **kwargs):
+ """Store exceptions when they are generated by a view."""
+ self.exc_info = sys.exc_info()
+
+ def check_exception(self, response):
+ """
+ Look for a signaled exception, clear the current context exception
+ data, re-raise the signaled exception, and clear the signaled exception
+ from the local cache.
+ """
+ response.exc_info = self.exc_info
+ if self.exc_info:
+ _, exc_value, _ = self.exc_info
+ self.exc_info = None
+ if self.raise_request_exception:
+ raise exc_value
+
+ @property
+ def session(self):
+ """Return the current session variables."""
+ engine = import_module(settings.SESSION_ENGINE)
+ cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
+ if cookie:
+ return engine.SessionStore(cookie.value)
+ session = engine.SessionStore()
+ session.save()
+ self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
+ return session
+
+ def login(self, **credentials):
+ """
+ Set the Factory to appear as if it has successfully logged into a site.
+
+ Return True if login is possible or False if the provided credentials
+ are incorrect.
+ """
+ from django.contrib.auth import authenticate
+ user = authenticate(**credentials)
+ if user:
+ self._login(user)
+ return True
+ return False
+
+ def force_login(self, user, backend=None):
+ def get_backend():
+ from django.contrib.auth import load_backend
+ for backend_path in settings.AUTHENTICATION_BACKENDS:
+ backend = load_backend(backend_path)
+ if hasattr(backend, 'get_user'):
+ return backend_path
+
+ if backend is None:
+ backend = get_backend()
+ user.backend = backend
+ self._login(user, backend)
+
+ def _login(self, user, backend=None):
+ from django.contrib.auth import login
+ # Create a fake request to store login details.
+ request = HttpRequest()
+ if self.session:
+ request.session = self.session
+ else:
+ engine = import_module(settings.SESSION_ENGINE)
+ request.session = engine.SessionStore()
+ login(request, user, backend)
+ # Save the session values.
+ request.session.save()
+ # Set the cookie to represent the session.
+ session_cookie = settings.SESSION_COOKIE_NAME
+ self.cookies[session_cookie] = request.session.session_key
+ cookie_data = {
+ 'max-age': None,
+ 'path': '/',
+ 'domain': settings.SESSION_COOKIE_DOMAIN,
+ 'secure': settings.SESSION_COOKIE_SECURE or None,
+ 'expires': None,
+ }
+ self.cookies[session_cookie].update(cookie_data)
+
+ def logout(self):
+ """Log out the user by removing the cookies and session object."""
+ from django.contrib.auth import get_user, logout
+ request = HttpRequest()
+ if self.session:
+ request.session = self.session
+ request.user = get_user(request)
+ else:
+ engine = import_module(settings.SESSION_ENGINE)
+ request.session = engine.SessionStore()
+ logout(request)
+ self.cookies = SimpleCookie()
+
+ def _parse_json(self, response, **extra):
+ if not hasattr(response, '_json'):
+ if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
+ raise ValueError(
+ 'Content-Type header is "%s", not "application/json"'
+ % response.get('Content-Type')
+ )
+ response._json = json.loads(response.content.decode(response.charset), **extra)
+ return response._json
+
+
+class Client(ClientMixin, RequestFactory):
"""
A class that can act as a client for testing purposes.
@@ -446,23 +682,6 @@ class Client(RequestFactory):
self.exc_info = None
self.extra = None
- def store_exc_info(self, **kwargs):
- """Store exceptions when they are generated by a view."""
- self.exc_info = sys.exc_info()
-
- @property
- def session(self):
- """Return the current session variables."""
- engine = import_module(settings.SESSION_ENGINE)
- cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
- if cookie:
- return engine.SessionStore(cookie.value)
-
- session = engine.SessionStore()
- session.save()
- self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
- return session
-
def request(self, **request):
"""
The master request method. Compose the environment dictionary and pass
@@ -486,15 +705,8 @@ class Client(RequestFactory):
finally:
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
got_request_exception.disconnect(dispatch_uid=exception_uid)
- # Look for a signaled exception, clear the current context exception
- # data, then re-raise the signaled exception. Also clear the signaled
- # exception from the local cache.
- response.exc_info = self.exc_info
- if self.exc_info:
- _, exc_value, _ = self.exc_info
- self.exc_info = None
- if self.raise_request_exception:
- raise exc_value
+ # Check for signaled exceptions.
+ self.check_exception(response)
# Save the client and request that stimulated the response.
response.client = self
response.request = request
@@ -583,85 +795,6 @@ class Client(RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
- def login(self, **credentials):
- """
- Set the Factory to appear as if it has successfully logged into a site.
-
- Return True if login is possible; False if the provided credentials
- are incorrect.
- """
- from django.contrib.auth import authenticate
- user = authenticate(**credentials)
- if user:
- self._login(user)
- return True
- else:
- return False
-
- def force_login(self, user, backend=None):
- def get_backend():
- from django.contrib.auth import load_backend
- for backend_path in settings.AUTHENTICATION_BACKENDS:
- backend = load_backend(backend_path)
- if hasattr(backend, 'get_user'):
- return backend_path
- if backend is None:
- backend = get_backend()
- user.backend = backend
- self._login(user, backend)
-
- def _login(self, user, backend=None):
- from django.contrib.auth import login
- engine = import_module(settings.SESSION_ENGINE)
-
- # Create a fake request to store login details.
- request = HttpRequest()
-
- if self.session:
- request.session = self.session
- else:
- request.session = engine.SessionStore()
- login(request, user, backend)
-
- # Save the session values.
- request.session.save()
-
- # Set the cookie to represent the session.
- session_cookie = settings.SESSION_COOKIE_NAME
- self.cookies[session_cookie] = request.session.session_key
- cookie_data = {
- 'max-age': None,
- 'path': '/',
- 'domain': settings.SESSION_COOKIE_DOMAIN,
- 'secure': settings.SESSION_COOKIE_SECURE or None,
- 'expires': None,
- }
- self.cookies[session_cookie].update(cookie_data)
-
- def logout(self):
- """Log out the user by removing the cookies and session object."""
- from django.contrib.auth import get_user, logout
-
- request = HttpRequest()
- engine = import_module(settings.SESSION_ENGINE)
- if self.session:
- request.session = self.session
- request.user = get_user(request)
- else:
- request.session = engine.SessionStore()
- logout(request)
- self.cookies = SimpleCookie()
-
- def _parse_json(self, response, **extra):
- if not hasattr(response, '_json'):
- if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
- raise ValueError(
- 'Content-Type header is "{}", not "application/json"'
- .format(response.get('Content-Type'))
- )
- response._json = json.loads(response.content.decode(response.charset), **extra)
- return response._json
-
def _handle_redirects(self, response, data='', content_type='', **extra):
"""
Follow any redirects by requesting responses from the server using GET.
@@ -714,3 +847,66 @@ class Client(RequestFactory):
raise RedirectCycleError("Too many redirects.", last_response=response)
return response
+
+
+class AsyncClient(ClientMixin, AsyncRequestFactory):
+ """
+ An async version of Client that creates ASGIRequests and calls through an
+ async request path.
+
+ Does not currently support "follow" on its methods.
+ """
+ def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
+ super().__init__(**defaults)
+ self.handler = AsyncClientHandler(enforce_csrf_checks)
+ self.raise_request_exception = raise_request_exception
+ self.exc_info = None
+ self.extra = None
+
+ async def request(self, **request):
+ """
+ The master request method. Compose the scope dictionary and pass to the
+ handler, return the result of the handler. Assume defaults for the
+ query environment, which can be overridden using the arguments to the
+ request.
+ """
+ if 'follow' in request:
+ raise NotImplementedError(
+ 'AsyncClient request methods do not accept the follow '
+ 'parameter.'
+ )
+ scope = self._base_scope(**request)
+ # Curry a data dictionary into an instance of the template renderer
+ # callback function.
+ data = {}
+ on_template_render = partial(store_rendered_templates, data)
+ signal_uid = 'template-render-%s' % id(request)
+ signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
+ # Capture exceptions created by the handler.
+ exception_uid = 'request-exception-%s' % id(request)
+ got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
+ try:
+ response = await self.handler(scope)
+ finally:
+ signals.template_rendered.disconnect(dispatch_uid=signal_uid)
+ got_request_exception.disconnect(dispatch_uid=exception_uid)
+ # Check for signaled exceptions.
+ self.check_exception(response)
+ # Save the client and request that stimulated the response.
+ response.client = self
+ response.request = request
+ # Add any rendered template detail to the response.
+ response.templates = data.get('templates', [])
+ response.context = data.get('context')
+ response.json = partial(self._parse_json, response)
+ # Attach the ResolverMatch instance to the response.
+ response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
+ # Flatten a single context. Not really necessary anymore thanks to the
+ # __getattr__ flattening in ContextList, but has some edge case
+ # backwards compatibility implications.
+ if response.context and len(response.context) == 1:
+ response.context = response.context[0]
+ # Update persistent cookie data.
+ if response.cookies:
+ self.cookies.update(response.cookies)
+ return response