summaryrefslogtreecommitdiff
path: root/django
diff options
context:
space:
mode:
authorAndrew Godwin <andrew@aeracode.org>2020-02-12 15:15:00 -0700
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2020-03-18 19:59:12 +0100
commitfc0fa72ff4cdbf5861a366e31cb8bbacd44da22d (patch)
treed419ce531586808b0a111664907b859cb6d22862 /django
parent3f7e4b16bf58f99c71570ba75dc97db8265071be (diff)
Fixed #31224 -- Added support for asynchronous views and middleware.
This implements support for asynchronous views, asynchronous tests, asynchronous middleware, and an asynchronous test client.
Diffstat (limited to 'django')
-rw-r--r--django/contrib/sessions/middleware.py1
-rw-r--r--django/core/handlers/asgi.py13
-rw-r--r--django/core/handlers/base.py236
-rw-r--r--django/core/handlers/exception.py29
-rw-r--r--django/test/__init__.py13
-rw-r--r--django/test/client.py408
-rw-r--r--django/test/testcases.py4
-rw-r--r--django/test/utils.py23
-rw-r--r--django/utils/decorators.py27
-rw-r--r--django/utils/deprecation.py33
10 files changed, 617 insertions, 170 deletions
diff --git a/django/contrib/sessions/middleware.py b/django/contrib/sessions/middleware.py
index e76c08ee5d..63013eef7a 100644
--- a/django/contrib/sessions/middleware.py
+++ b/django/contrib/sessions/middleware.py
@@ -15,6 +15,7 @@ class SessionMiddleware(MiddlewareMixin):
def __init__(self, get_response=None):
self._get_response_none_deprecation(get_response)
self.get_response = get_response
+ self._async_check()
engine = import_module(settings.SESSION_ENGINE)
self.SessionStore = engine.SessionStore
diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py
index bb782dad9b..82d2e1ab9d 100644
--- a/django/core/handlers/asgi.py
+++ b/django/core/handlers/asgi.py
@@ -1,4 +1,3 @@
-import asyncio
import logging
import sys
import tempfile
@@ -132,7 +131,7 @@ class ASGIHandler(base.BaseHandler):
def __init__(self):
super().__init__()
- self.load_middleware()
+ self.load_middleware(is_async=True)
async def __call__(self, scope, receive, send):
"""
@@ -158,12 +157,8 @@ class ASGIHandler(base.BaseHandler):
if request is None:
await self.send_response(error_response, send)
return
- # Get the response, using a threadpool via sync_to_async, if needed.
- if asyncio.iscoroutinefunction(self.get_response):
- response = await self.get_response(request)
- else:
- # If get_response is synchronous, run it non-blocking.
- response = await sync_to_async(self.get_response)(request)
+ # Get the response, using the async mode of BaseHandler.
+ response = await self.get_response_async(request)
response._handler_class = self.__class__
# Increase chunk size on file responses (ASGI servers handles low-level
# chunking).
@@ -264,7 +259,7 @@ class ASGIHandler(base.BaseHandler):
'body': chunk,
'more_body': not last,
})
- response.close()
+ await sync_to_async(response.close)()
@classmethod
def chunk_bytes(cls, data):
diff --git a/django/core/handlers/base.py b/django/core/handlers/base.py
index 418bc7a46b..e7fbaf594e 100644
--- a/django/core/handlers/base.py
+++ b/django/core/handlers/base.py
@@ -1,6 +1,9 @@
+import asyncio
import logging
import types
+from asgiref.sync import async_to_sync, sync_to_async
+
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, MiddlewareNotUsed
from django.core.signals import request_finished
@@ -20,7 +23,7 @@ class BaseHandler:
_exception_middleware = None
_middleware_chain = None
- def load_middleware(self):
+ def load_middleware(self, is_async=False):
"""
Populate middleware lists from settings.MIDDLEWARE.
@@ -30,10 +33,28 @@ class BaseHandler:
self._template_response_middleware = []
self._exception_middleware = []
- handler = convert_exception_to_response(self._get_response)
+ get_response = self._get_response_async if is_async else self._get_response
+ handler = convert_exception_to_response(get_response)
+ handler_is_async = is_async
for middleware_path in reversed(settings.MIDDLEWARE):
middleware = import_string(middleware_path)
+ middleware_can_sync = getattr(middleware, 'sync_capable', True)
+ middleware_can_async = getattr(middleware, 'async_capable', False)
+ if not middleware_can_sync and not middleware_can_async:
+ raise RuntimeError(
+ 'Middleware %s must have at least one of '
+ 'sync_capable/async_capable set to True.' % middleware_path
+ )
+ elif not handler_is_async and middleware_can_sync:
+ middleware_is_async = False
+ else:
+ middleware_is_async = middleware_can_async
try:
+ # Adapt handler, if needed.
+ handler = self.adapt_method_mode(
+ middleware_is_async, handler, handler_is_async,
+ debug=settings.DEBUG, name='middleware %s' % middleware_path,
+ )
mw_instance = middleware(handler)
except MiddlewareNotUsed as exc:
if settings.DEBUG:
@@ -49,24 +70,56 @@ class BaseHandler:
)
if hasattr(mw_instance, 'process_view'):
- self._view_middleware.insert(0, mw_instance.process_view)
+ self._view_middleware.insert(
+ 0,
+ self.adapt_method_mode(is_async, mw_instance.process_view),
+ )
if hasattr(mw_instance, 'process_template_response'):
- self._template_response_middleware.append(mw_instance.process_template_response)
+ self._template_response_middleware.append(
+ self.adapt_method_mode(is_async, mw_instance.process_template_response),
+ )
if hasattr(mw_instance, 'process_exception'):
- self._exception_middleware.append(mw_instance.process_exception)
+ # The exception-handling stack is still always synchronous for
+ # now, so adapt that way.
+ self._exception_middleware.append(
+ self.adapt_method_mode(False, mw_instance.process_exception),
+ )
handler = convert_exception_to_response(mw_instance)
+ handler_is_async = middleware_is_async
+ # Adapt the top of the stack, if needed.
+ handler = self.adapt_method_mode(is_async, handler, handler_is_async)
# We only assign to this when initialization is complete as it is used
# as a flag for initialization being complete.
self._middleware_chain = handler
- def make_view_atomic(self, view):
- non_atomic_requests = getattr(view, '_non_atomic_requests', set())
- for db in connections.all():
- if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests:
- view = transaction.atomic(using=db.alias)(view)
- return view
+ def adapt_method_mode(
+ self, is_async, method, method_is_async=None, debug=False, name=None,
+ ):
+ """
+ Adapt a method to be in the correct "mode":
+ - If is_async is False:
+ - Synchronous methods are left alone
+ - Asynchronous methods are wrapped with async_to_sync
+ - If is_async is True:
+ - Synchronous methods are wrapped with sync_to_async()
+ - Asynchronous methods are left alone
+ """
+ if method_is_async is None:
+ method_is_async = asyncio.iscoroutinefunction(method)
+ if debug and not name:
+ name = name or 'method %s()' % method.__qualname__
+ if is_async:
+ if not method_is_async:
+ if debug:
+ logger.debug('Synchronous %s adapted.', name)
+ return sync_to_async(method, thread_sensitive=True)
+ elif method_is_async:
+ if debug:
+ logger.debug('Asynchronous %s adapted.' % name)
+ return async_to_sync(method)
+ return method
def get_response(self, request):
"""Return an HttpResponse object for the given HttpRequest."""
@@ -82,6 +135,26 @@ class BaseHandler:
)
return response
+ async def get_response_async(self, request):
+ """
+ Asynchronous version of get_response.
+
+ Funneling everything, including WSGI, into a single async
+ get_response() is too slow. Avoid the context switch by using
+ a separate async response path.
+ """
+ # Setup default url resolver for this thread.
+ set_urlconf(settings.ROOT_URLCONF)
+ response = await self._middleware_chain(request)
+ response._resource_closers.append(request.close)
+ if response.status_code >= 400:
+ await sync_to_async(log_response)(
+ '%s: %s', response.reason_phrase, request.path,
+ response=response,
+ request=request,
+ )
+ return response
+
def _get_response(self, request):
"""
Resolve and call the view, then apply view, exception, and
@@ -89,17 +162,7 @@ class BaseHandler:
inside the request/response middleware.
"""
response = None
-
- if hasattr(request, 'urlconf'):
- urlconf = request.urlconf
- set_urlconf(urlconf)
- resolver = get_resolver(urlconf)
- else:
- resolver = get_resolver()
-
- resolver_match = resolver.resolve(request.path_info)
- callback, callback_args, callback_kwargs = resolver_match
- request.resolver_match = resolver_match
+ callback, callback_args, callback_kwargs = self.resolve_request(request)
# Apply view middleware
for middleware_method in self._view_middleware:
@@ -109,6 +172,9 @@ class BaseHandler:
if response is None:
wrapped_callback = self.make_view_atomic(callback)
+ # If it is an asynchronous view, run it in a subthread.
+ if asyncio.iscoroutinefunction(wrapped_callback):
+ wrapped_callback = async_to_sync(wrapped_callback)
try:
response = wrapped_callback(request, *callback_args, **callback_kwargs)
except Exception as e:
@@ -137,20 +203,89 @@ class BaseHandler:
return response
- def process_exception_by_middleware(self, exception, request):
+ async def _get_response_async(self, request):
"""
- Pass the exception to the exception middleware. If no middleware
- return a response for this exception, raise it.
+ Resolve and call the view, then apply view, exception, and
+ template_response middleware. This method is everything that happens
+ inside the request/response middleware.
"""
- for middleware_method in self._exception_middleware:
- response = middleware_method(request, exception)
+ response = None
+ callback, callback_args, callback_kwargs = self.resolve_request(request)
+
+ # Apply view middleware.
+ for middleware_method in self._view_middleware:
+ response = await middleware_method(request, callback, callback_args, callback_kwargs)
if response:
- return response
- raise
+ break
+
+ if response is None:
+ wrapped_callback = self.make_view_atomic(callback)
+ # If it is a synchronous view, run it in a subthread
+ if not asyncio.iscoroutinefunction(wrapped_callback):
+ wrapped_callback = sync_to_async(wrapped_callback, thread_sensitive=True)
+ try:
+ response = await wrapped_callback(request, *callback_args, **callback_kwargs)
+ except Exception as e:
+ response = await sync_to_async(
+ self.process_exception_by_middleware,
+ thread_sensitive=True,
+ )(e, request)
+
+ # Complain if the view returned None or an uncalled coroutine.
+ self.check_response(response, callback)
+
+ # If the response supports deferred rendering, apply template
+ # response middleware and then render the response
+ if hasattr(response, 'render') and callable(response.render):
+ for middleware_method in self._template_response_middleware:
+ response = await middleware_method(request, response)
+ # Complain if the template response middleware returned None or
+ # an uncalled coroutine.
+ self.check_response(
+ response,
+ middleware_method,
+ name='%s.process_template_response' % (
+ middleware_method.__self__.__class__.__name__,
+ )
+ )
+ try:
+ if asyncio.iscoroutinefunction(response.render):
+ response = await response.render()
+ else:
+ response = await sync_to_async(response.render, thread_sensitive=True)()
+ except Exception as e:
+ response = await sync_to_async(
+ self.process_exception_by_middleware,
+ thread_sensitive=True,
+ )(e, request)
+
+ # Make sure the response is not a coroutine
+ if asyncio.iscoroutine(response):
+ raise RuntimeError('Response is still a coroutine.')
+ return response
+
+ def resolve_request(self, request):
+ """
+ Retrieve/set the urlconf for the request. Return the view resolved,
+ with its args and kwargs.
+ """
+ # Work out the resolver.
+ if hasattr(request, 'urlconf'):
+ urlconf = request.urlconf
+ set_urlconf(urlconf)
+ resolver = get_resolver(urlconf)
+ else:
+ resolver = get_resolver()
+ # Resolve the view, and assign the match object back to the request.
+ resolver_match = resolver.resolve(request.path_info)
+ request.resolver_match = resolver_match
+ return resolver_match
def check_response(self, response, callback, name=None):
- """Raise an error if the view returned None."""
- if response is not None:
+ """
+ Raise an error if the view returned None or an uncalled coroutine.
+ """
+ if not(response is None or asyncio.iscoroutine(response)):
return
if not name:
if isinstance(callback, types.FunctionType): # FBV
@@ -160,10 +295,41 @@ class BaseHandler:
callback.__module__,
callback.__class__.__name__,
)
- raise ValueError(
- "%s didn't return an HttpResponse object. It returned None "
- "instead." % name
- )
+ if response is None:
+ raise ValueError(
+ "%s didn't return an HttpResponse object. It returned None "
+ "instead." % name
+ )
+ elif asyncio.iscoroutine(response):
+ raise ValueError(
+ "%s didn't return an HttpResponse object. It returned an "
+ "unawaited coroutine instead. You may need to add an 'await' "
+ "into your view." % name
+ )
+
+ # Other utility methods.
+
+ def make_view_atomic(self, view):
+ non_atomic_requests = getattr(view, '_non_atomic_requests', set())
+ for db in connections.all():
+ if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests:
+ if asyncio.iscoroutinefunction(view):
+ raise RuntimeError(
+ 'You cannot use ATOMIC_REQUESTS with async views.'
+ )
+ view = transaction.atomic(using=db.alias)(view)
+ return view
+
+ def process_exception_by_middleware(self, exception, request):
+ """
+ Pass the exception to the exception middleware. If no middleware
+ return a response for this exception, raise it.
+ """
+ for middleware_method in self._exception_middleware:
+ response = middleware_method(request, exception)
+ if response:
+ return response
+ raise
def reset_urlconf(sender, **kwargs):
diff --git a/django/core/handlers/exception.py b/django/core/handlers/exception.py
index 66443ce560..50880f2784 100644
--- a/django/core/handlers/exception.py
+++ b/django/core/handlers/exception.py
@@ -1,7 +1,10 @@
+import asyncio
import logging
import sys
from functools import wraps
+from asgiref.sync import sync_to_async
+
from django.conf import settings
from django.core import signals
from django.core.exceptions import (
@@ -28,14 +31,24 @@ def convert_exception_to_response(get_response):
no middleware leaks an exception and that the next middleware in the stack
can rely on getting a response instead of an exception.
"""
- @wraps(get_response)
- def inner(request):
- try:
- response = get_response(request)
- except Exception as exc:
- response = response_for_exception(request, exc)
- return response
- return inner
+ if asyncio.iscoroutinefunction(get_response):
+ @wraps(get_response)
+ async def inner(request):
+ try:
+ response = await get_response(request)
+ except Exception as exc:
+ response = await sync_to_async(response_for_exception)(request, exc)
+ return response
+ return inner
+ else:
+ @wraps(get_response)
+ def inner(request):
+ try:
+ response = get_response(request)
+ except Exception as exc:
+ response = response_for_exception(request, exc)
+ return response
+ return inner
def response_for_exception(request, exc):
diff --git a/django/test/__init__.py b/django/test/__init__.py
index 4782d72184..d1f953a8dd 100644
--- a/django/test/__init__.py
+++ b/django/test/__init__.py
@@ -1,6 +1,8 @@
"""Django Unit Test framework."""
-from django.test.client import Client, RequestFactory
+from django.test.client import (
+ AsyncClient, AsyncRequestFactory, Client, RequestFactory,
+)
from django.test.testcases import (
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
@@ -11,8 +13,9 @@ from django.test.utils import (
)
__all__ = [
- 'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase',
- 'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature',
- 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings',
- 'modify_settings', 'override_settings', 'override_system_checks', 'tag',
+ 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
+ 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
+ 'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
+ 'ignore_warnings', 'modify_settings', 'override_settings',
+ 'override_system_checks', 'tag',
]
diff --git a/django/test/client.py b/django/test/client.py
index 34fc9f3cf1..c460832461 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -9,7 +9,10 @@ from importlib import import_module
from io import BytesIO
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
+from asgiref.sync import sync_to_async
+
from django.conf import settings
+from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
@@ -157,6 +160,52 @@ class ClientHandler(BaseHandler):
return response
+class AsyncClientHandler(BaseHandler):
+ """An async version of ClientHandler."""
+ def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ super().__init__(*args, **kwargs)
+
+ async def __call__(self, scope):
+ # Set up middleware if needed. We couldn't do this earlier, because
+ # settings weren't available.
+ if self._middleware_chain is None:
+ self.load_middleware(is_async=True)
+ # Extract body file from the scope, if provided.
+ if '_body_file' in scope:
+ body_file = scope.pop('_body_file')
+ else:
+ body_file = FakePayload('')
+
+ request_started.disconnect(close_old_connections)
+ await sync_to_async(request_started.send)(sender=self.__class__, scope=scope)
+ request_started.connect(close_old_connections)
+ request = ASGIRequest(scope, body_file)
+ # Sneaky little hack so that we can easily get round
+ # CsrfViewMiddleware. This makes life easier, and is probably required
+ # for backwards compatibility with external tests against admin views.
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
+ # Request goes through middleware.
+ response = await self.get_response_async(request)
+ # Simulate behaviors of most Web servers.
+ conditional_content_removal(request, response)
+ # Attach the originating ASGI request to the response so that it could
+ # be later retrieved.
+ response.asgi_request = request
+ # Emulate a server by calling the close method on completion.
+ if response.streaming:
+ response.streaming_content = await sync_to_async(closing_iterator_wrapper)(
+ response.streaming_content,
+ response.close,
+ )
+ else:
+ request_finished.disconnect(close_old_connections)
+ # Will fire request_finished.
+ await sync_to_async(response.close)()
+ request_finished.connect(close_old_connections)
+ return response
+
+
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
"""
Store templates and contexts that are rendered.
@@ -421,7 +470,194 @@ class RequestFactory:
return self.request(**r)
-class Client(RequestFactory):
+class AsyncRequestFactory(RequestFactory):
+ """
+ Class that lets you create mock ASGI-like Request objects for use in
+ testing. Usage:
+
+ rf = AsyncRequestFactory()
+ get_request = await rf.get('/hello/')
+ post_request = await rf.post('/submit/', {'foo': 'bar'})
+
+ Once you have a request object you can pass it to any view function,
+ including synchronous ones. The reason we have a separate class here is:
+ a) this makes ASGIRequest subclasses, and
+ b) AsyncTestClient can subclass it.
+ """
+ def _base_scope(self, **request):
+ """The base scope for a request."""
+ # This is a minimal valid ASGI scope, plus:
+ # - headers['cookie'] for cookie support,
+ # - 'client' often useful, see #8551.
+ scope = {
+ 'asgi': {'version': '3.0'},
+ 'type': 'http',
+ 'http_version': '1.1',
+ 'client': ['127.0.0.1', 0],
+ 'server': ('testserver', '80'),
+ 'scheme': 'http',
+ 'method': 'GET',
+ 'headers': [],
+ **self.defaults,
+ **request,
+ }
+ scope['headers'].append((
+ b'cookie',
+ b'; '.join(sorted(
+ ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
+ for morsel in self.cookies.values()
+ )),
+ ))
+ return scope
+
+ def request(self, **request):
+ """Construct a generic request object."""
+ # This is synchronous, which means all methods on this class are.
+ # AsyncClient, however, has an async request function, which makes all
+ # its methods async.
+ if '_body_file' in request:
+ body_file = request.pop('_body_file')
+ else:
+ body_file = FakePayload('')
+ return ASGIRequest(self._base_scope(**request), body_file)
+
+ def generic(
+ self, method, path, data='', content_type='application/octet-stream',
+ secure=False, **extra,
+ ):
+ """Construct an arbitrary HTTP request."""
+ parsed = urlparse(str(path)) # path can be lazy.
+ data = force_bytes(data, settings.DEFAULT_CHARSET)
+ s = {
+ 'method': method,
+ 'path': self._get_path(parsed),
+ 'server': ('127.0.0.1', '443' if secure else '80'),
+ 'scheme': 'https' if secure else 'http',
+ 'headers': [(b'host', b'testserver')],
+ }
+ if data:
+ s['headers'].extend([
+ (b'content-length', bytes(len(data))),
+ (b'content-type', content_type.encode('ascii')),
+ ])
+ s['_body_file'] = FakePayload(data)
+ s.update(extra)
+ # If QUERY_STRING is absent or empty, we want to extract it from the
+ # URL.
+ if not s.get('query_string'):
+ s['query_string'] = parsed[4]
+ return self.request(**s)
+
+
+class ClientMixin:
+ """
+ Mixin with common methods between Client and AsyncClient.
+ """
+ def store_exc_info(self, **kwargs):
+ """Store exceptions when they are generated by a view."""
+ self.exc_info = sys.exc_info()
+
+ def check_exception(self, response):
+ """
+ Look for a signaled exception, clear the current context exception
+ data, re-raise the signaled exception, and clear the signaled exception
+ from the local cache.
+ """
+ response.exc_info = self.exc_info
+ if self.exc_info:
+ _, exc_value, _ = self.exc_info
+ self.exc_info = None
+ if self.raise_request_exception:
+ raise exc_value
+
+ @property
+ def session(self):
+ """Return the current session variables."""
+ engine = import_module(settings.SESSION_ENGINE)
+ cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
+ if cookie:
+ return engine.SessionStore(cookie.value)
+ session = engine.SessionStore()
+ session.save()
+ self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
+ return session
+
+ def login(self, **credentials):
+ """
+ Set the Factory to appear as if it has successfully logged into a site.
+
+ Return True if login is possible or False if the provided credentials
+ are incorrect.
+ """
+ from django.contrib.auth import authenticate
+ user = authenticate(**credentials)
+ if user:
+ self._login(user)
+ return True
+ return False
+
+ def force_login(self, user, backend=None):
+ def get_backend():
+ from django.contrib.auth import load_backend
+ for backend_path in settings.AUTHENTICATION_BACKENDS:
+ backend = load_backend(backend_path)
+ if hasattr(backend, 'get_user'):
+ return backend_path
+
+ if backend is None:
+ backend = get_backend()
+ user.backend = backend
+ self._login(user, backend)
+
+ def _login(self, user, backend=None):
+ from django.contrib.auth import login
+ # Create a fake request to store login details.
+ request = HttpRequest()
+ if self.session:
+ request.session = self.session
+ else:
+ engine = import_module(settings.SESSION_ENGINE)
+ request.session = engine.SessionStore()
+ login(request, user, backend)
+ # Save the session values.
+ request.session.save()
+ # Set the cookie to represent the session.
+ session_cookie = settings.SESSION_COOKIE_NAME
+ self.cookies[session_cookie] = request.session.session_key
+ cookie_data = {
+ 'max-age': None,
+ 'path': '/',
+ 'domain': settings.SESSION_COOKIE_DOMAIN,
+ 'secure': settings.SESSION_COOKIE_SECURE or None,
+ 'expires': None,
+ }
+ self.cookies[session_cookie].update(cookie_data)
+
+ def logout(self):
+ """Log out the user by removing the cookies and session object."""
+ from django.contrib.auth import get_user, logout
+ request = HttpRequest()
+ if self.session:
+ request.session = self.session
+ request.user = get_user(request)
+ else:
+ engine = import_module(settings.SESSION_ENGINE)
+ request.session = engine.SessionStore()
+ logout(request)
+ self.cookies = SimpleCookie()
+
+ def _parse_json(self, response, **extra):
+ if not hasattr(response, '_json'):
+ if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
+ raise ValueError(
+ 'Content-Type header is "%s", not "application/json"'
+ % response.get('Content-Type')
+ )
+ response._json = json.loads(response.content.decode(response.charset), **extra)
+ return response._json
+
+
+class Client(ClientMixin, RequestFactory):
"""
A class that can act as a client for testing purposes.
@@ -446,23 +682,6 @@ class Client(RequestFactory):
self.exc_info = None
self.extra = None
- def store_exc_info(self, **kwargs):
- """Store exceptions when they are generated by a view."""
- self.exc_info = sys.exc_info()
-
- @property
- def session(self):
- """Return the current session variables."""
- engine = import_module(settings.SESSION_ENGINE)
- cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
- if cookie:
- return engine.SessionStore(cookie.value)
-
- session = engine.SessionStore()
- session.save()
- self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
- return session
-
def request(self, **request):
"""
The master request method. Compose the environment dictionary and pass
@@ -486,15 +705,8 @@ class Client(RequestFactory):
finally:
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
got_request_exception.disconnect(dispatch_uid=exception_uid)
- # Look for a signaled exception, clear the current context exception
- # data, then re-raise the signaled exception. Also clear the signaled
- # exception from the local cache.
- response.exc_info = self.exc_info
- if self.exc_info:
- _, exc_value, _ = self.exc_info
- self.exc_info = None
- if self.raise_request_exception:
- raise exc_value
+ # Check for signaled exceptions.
+ self.check_exception(response)
# Save the client and request that stimulated the response.
response.client = self
response.request = request
@@ -583,85 +795,6 @@ class Client(RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
- def login(self, **credentials):
- """
- Set the Factory to appear as if it has successfully logged into a site.
-
- Return True if login is possible; False if the provided credentials
- are incorrect.
- """
- from django.contrib.auth import authenticate
- user = authenticate(**credentials)
- if user:
- self._login(user)
- return True
- else:
- return False
-
- def force_login(self, user, backend=None):
- def get_backend():
- from django.contrib.auth import load_backend
- for backend_path in settings.AUTHENTICATION_BACKENDS:
- backend = load_backend(backend_path)
- if hasattr(backend, 'get_user'):
- return backend_path
- if backend is None:
- backend = get_backend()
- user.backend = backend
- self._login(user, backend)
-
- def _login(self, user, backend=None):
- from django.contrib.auth import login
- engine = import_module(settings.SESSION_ENGINE)
-
- # Create a fake request to store login details.
- request = HttpRequest()
-
- if self.session:
- request.session = self.session
- else:
- request.session = engine.SessionStore()
- login(request, user, backend)
-
- # Save the session values.
- request.session.save()
-
- # Set the cookie to represent the session.
- session_cookie = settings.SESSION_COOKIE_NAME
- self.cookies[session_cookie] = request.session.session_key
- cookie_data = {
- 'max-age': None,
- 'path': '/',
- 'domain': settings.SESSION_COOKIE_DOMAIN,
- 'secure': settings.SESSION_COOKIE_SECURE or None,
- 'expires': None,
- }
- self.cookies[session_cookie].update(cookie_data)
-
- def logout(self):
- """Log out the user by removing the cookies and session object."""
- from django.contrib.auth import get_user, logout
-
- request = HttpRequest()
- engine = import_module(settings.SESSION_ENGINE)
- if self.session:
- request.session = self.session
- request.user = get_user(request)
- else:
- request.session = engine.SessionStore()
- logout(request)
- self.cookies = SimpleCookie()
-
- def _parse_json(self, response, **extra):
- if not hasattr(response, '_json'):
- if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
- raise ValueError(
- 'Content-Type header is "{}", not "application/json"'
- .format(response.get('Content-Type'))
- )
- response._json = json.loads(response.content.decode(response.charset), **extra)
- return response._json
-
def _handle_redirects(self, response, data='', content_type='', **extra):
"""
Follow any redirects by requesting responses from the server using GET.
@@ -714,3 +847,66 @@ class Client(RequestFactory):
raise RedirectCycleError("Too many redirects.", last_response=response)
return response
+
+
+class AsyncClient(ClientMixin, AsyncRequestFactory):
+ """
+ An async version of Client that creates ASGIRequests and calls through an
+ async request path.
+
+ Does not currently support "follow" on its methods.
+ """
+ def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
+ super().__init__(**defaults)
+ self.handler = AsyncClientHandler(enforce_csrf_checks)
+ self.raise_request_exception = raise_request_exception
+ self.exc_info = None
+ self.extra = None
+
+ async def request(self, **request):
+ """
+ The master request method. Compose the scope dictionary and pass to the
+ handler, return the result of the handler. Assume defaults for the
+ query environment, which can be overridden using the arguments to the
+ request.
+ """
+ if 'follow' in request:
+ raise NotImplementedError(
+ 'AsyncClient request methods do not accept the follow '
+ 'parameter.'
+ )
+ scope = self._base_scope(**request)
+ # Curry a data dictionary into an instance of the template renderer
+ # callback function.
+ data = {}
+ on_template_render = partial(store_rendered_templates, data)
+ signal_uid = 'template-render-%s' % id(request)
+ signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
+ # Capture exceptions created by the handler.
+ exception_uid = 'request-exception-%s' % id(request)
+ got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
+ try:
+ response = await self.handler(scope)
+ finally:
+ signals.template_rendered.disconnect(dispatch_uid=signal_uid)
+ got_request_exception.disconnect(dispatch_uid=exception_uid)
+ # Check for signaled exceptions.
+ self.check_exception(response)
+ # Save the client and request that stimulated the response.
+ response.client = self
+ response.request = request
+ # Add any rendered template detail to the response.
+ response.templates = data.get('templates', [])
+ response.context = data.get('context')
+ response.json = partial(self._parse_json, response)
+ # Attach the ResolverMatch instance to the response.
+ response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
+ # Flatten a single context. Not really necessary anymore thanks to the
+ # __getattr__ flattening in ContextList, but has some edge case
+ # backwards compatibility implications.
+ if response.context and len(response.context) == 1:
+ response.context = response.context[0]
+ # Update persistent cookie data.
+ if response.cookies:
+ self.cookies.update(response.cookies)
+ return response
diff --git a/django/test/testcases.py b/django/test/testcases.py
index 2c24708a3b..7ebddf80e5 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -33,7 +33,7 @@ from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.forms.fields import CharField
from django.http import QueryDict
from django.http.request import split_domain_port, validate_host
-from django.test.client import Client
+from django.test.client import AsyncClient, Client
from django.test.html import HTMLParseError, parse_html
from django.test.signals import setting_changed, template_rendered
from django.test.utils import (
@@ -151,6 +151,7 @@ class SimpleTestCase(unittest.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
client_class = Client
+ async_client_class = AsyncClient
_overridden_settings = None
_modified_settings = None
@@ -292,6 +293,7 @@ class SimpleTestCase(unittest.TestCase):
* Clear the mail test outbox.
"""
self.client = self.client_class()
+ self.async_client = self.async_client_class()
mail.outbox = []
def _post_teardown(self):
diff --git a/django/test/utils.py b/django/test/utils.py
index e626667b09..d1f7d19546 100644
--- a/django/test/utils.py
+++ b/django/test/utils.py
@@ -1,3 +1,4 @@
+import asyncio
import logging
import re
import sys
@@ -362,12 +363,22 @@ class TestContextDecorator:
raise TypeError('Can only decorate subclasses of unittest.TestCase')
def decorate_callable(self, func):
- @wraps(func)
- def inner(*args, **kwargs):
- with self as context:
- if self.kwarg_name:
- kwargs[self.kwarg_name] = context
- return func(*args, **kwargs)
+ if asyncio.iscoroutinefunction(func):
+ # If the inner function is an async function, we must execute async
+ # as well so that the `with` statement executes at the right time.
+ @wraps(func)
+ async def inner(*args, **kwargs):
+ with self as context:
+ if self.kwarg_name:
+ kwargs[self.kwarg_name] = context
+ return await func(*args, **kwargs)
+ else:
+ @wraps(func)
+ def inner(*args, **kwargs):
+ with self as context:
+ if self.kwarg_name:
+ kwargs[self.kwarg_name] = context
+ return func(*args, **kwargs)
return inner
def __call__(self, decorated):
diff --git a/django/utils/decorators.py b/django/utils/decorators.py
index bb2e498e46..5c9a5d01c7 100644
--- a/django/utils/decorators.py
+++ b/django/utils/decorators.py
@@ -150,3 +150,30 @@ def make_middleware_decorator(middleware_class):
return _wrapped_view
return _decorator
return _make_decorator
+
+
+def sync_and_async_middleware(func):
+ """
+ Mark a middleware factory as returning a hybrid middleware supporting both
+ types of request.
+ """
+ func.sync_capable = True
+ func.async_capable = True
+ return func
+
+
+def sync_only_middleware(func):
+ """
+ Mark a middleware factory as returning a sync middleware.
+ This is the default.
+ """
+ func.sync_capable = True
+ func.async_capable = False
+ return func
+
+
+def async_only_middleware(func):
+ """Mark a middleware factory as returning an async middleware."""
+ func.sync_capable = False
+ func.async_capable = True
+ return func
diff --git a/django/utils/deprecation.py b/django/utils/deprecation.py
index 81e7c3a15b..6336558a81 100644
--- a/django/utils/deprecation.py
+++ b/django/utils/deprecation.py
@@ -1,6 +1,9 @@
+import asyncio
import inspect
import warnings
+from asgiref.sync import sync_to_async
+
class RemovedInNextVersionWarning(DeprecationWarning):
pass
@@ -80,14 +83,31 @@ class DeprecationInstanceCheck(type):
class MiddlewareMixin:
+ sync_capable = True
+ async_capable = True
+
# RemovedInDjango40Warning: when the deprecation ends, replace with:
# def __init__(self, get_response):
def __init__(self, get_response=None):
self._get_response_none_deprecation(get_response)
self.get_response = get_response
+ self._async_check()
super().__init__()
+ def _async_check(self):
+ """
+ If get_response is a coroutine function, turns us into async mode so
+ a thread is not consumed during a whole request.
+ """
+ if asyncio.iscoroutinefunction(self.get_response):
+ # Mark the class as async-capable, but do the actual switch
+ # inside __call__ to avoid swapping out dunder methods
+ self._is_coroutine = asyncio.coroutines._is_coroutine
+
def __call__(self, request):
+ # Exit out to async mode, if needed
+ if asyncio.iscoroutinefunction(self.get_response):
+ return self.__acall__(request)
response = None
if hasattr(self, 'process_request'):
response = self.process_request(request)
@@ -96,6 +116,19 @@ class MiddlewareMixin:
response = self.process_response(request, response)
return response
+ async def __acall__(self, request):
+ """
+ Async version of __call__ that is swapped in when an async request
+ is running.
+ """
+ response = None
+ if hasattr(self, 'process_request'):
+ response = await sync_to_async(self.process_request)(request)
+ response = response or await self.get_response(request)
+ if hasattr(self, 'process_response'):
+ response = await sync_to_async(self.process_response)(request, response)
+ return response
+
def _get_response_none_deprecation(self, get_response):
if get_response is None:
warnings.warn(