diff options
| author | Andrew Godwin <andrew@aeracode.org> | 2020-02-12 15:15:00 -0700 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2020-03-18 19:59:12 +0100 |
| commit | fc0fa72ff4cdbf5861a366e31cb8bbacd44da22d (patch) | |
| tree | d419ce531586808b0a111664907b859cb6d22862 /django | |
| parent | 3f7e4b16bf58f99c71570ba75dc97db8265071be (diff) | |
Fixed #31224 -- Added support for asynchronous views and middleware.
This implements support for asynchronous views, asynchronous tests,
asynchronous middleware, and an asynchronous test client.
Diffstat (limited to 'django')
| -rw-r--r-- | django/contrib/sessions/middleware.py | 1 | ||||
| -rw-r--r-- | django/core/handlers/asgi.py | 13 | ||||
| -rw-r--r-- | django/core/handlers/base.py | 236 | ||||
| -rw-r--r-- | django/core/handlers/exception.py | 29 | ||||
| -rw-r--r-- | django/test/__init__.py | 13 | ||||
| -rw-r--r-- | django/test/client.py | 408 | ||||
| -rw-r--r-- | django/test/testcases.py | 4 | ||||
| -rw-r--r-- | django/test/utils.py | 23 | ||||
| -rw-r--r-- | django/utils/decorators.py | 27 | ||||
| -rw-r--r-- | django/utils/deprecation.py | 33 |
10 files changed, 617 insertions, 170 deletions
diff --git a/django/contrib/sessions/middleware.py b/django/contrib/sessions/middleware.py index e76c08ee5d..63013eef7a 100644 --- a/django/contrib/sessions/middleware.py +++ b/django/contrib/sessions/middleware.py @@ -15,6 +15,7 @@ class SessionMiddleware(MiddlewareMixin): def __init__(self, get_response=None): self._get_response_none_deprecation(get_response) self.get_response = get_response + self._async_check() engine = import_module(settings.SESSION_ENGINE) self.SessionStore = engine.SessionStore diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index bb782dad9b..82d2e1ab9d 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -1,4 +1,3 @@ -import asyncio import logging import sys import tempfile @@ -132,7 +131,7 @@ class ASGIHandler(base.BaseHandler): def __init__(self): super().__init__() - self.load_middleware() + self.load_middleware(is_async=True) async def __call__(self, scope, receive, send): """ @@ -158,12 +157,8 @@ class ASGIHandler(base.BaseHandler): if request is None: await self.send_response(error_response, send) return - # Get the response, using a threadpool via sync_to_async, if needed. - if asyncio.iscoroutinefunction(self.get_response): - response = await self.get_response(request) - else: - # If get_response is synchronous, run it non-blocking. - response = await sync_to_async(self.get_response)(request) + # Get the response, using the async mode of BaseHandler. + response = await self.get_response_async(request) response._handler_class = self.__class__ # Increase chunk size on file responses (ASGI servers handles low-level # chunking). @@ -264,7 +259,7 @@ class ASGIHandler(base.BaseHandler): 'body': chunk, 'more_body': not last, }) - response.close() + await sync_to_async(response.close)() @classmethod def chunk_bytes(cls, data): diff --git a/django/core/handlers/base.py b/django/core/handlers/base.py index 418bc7a46b..e7fbaf594e 100644 --- a/django/core/handlers/base.py +++ b/django/core/handlers/base.py @@ -1,6 +1,9 @@ +import asyncio import logging import types +from asgiref.sync import async_to_sync, sync_to_async + from django.conf import settings from django.core.exceptions import ImproperlyConfigured, MiddlewareNotUsed from django.core.signals import request_finished @@ -20,7 +23,7 @@ class BaseHandler: _exception_middleware = None _middleware_chain = None - def load_middleware(self): + def load_middleware(self, is_async=False): """ Populate middleware lists from settings.MIDDLEWARE. @@ -30,10 +33,28 @@ class BaseHandler: self._template_response_middleware = [] self._exception_middleware = [] - handler = convert_exception_to_response(self._get_response) + get_response = self._get_response_async if is_async else self._get_response + handler = convert_exception_to_response(get_response) + handler_is_async = is_async for middleware_path in reversed(settings.MIDDLEWARE): middleware = import_string(middleware_path) + middleware_can_sync = getattr(middleware, 'sync_capable', True) + middleware_can_async = getattr(middleware, 'async_capable', False) + if not middleware_can_sync and not middleware_can_async: + raise RuntimeError( + 'Middleware %s must have at least one of ' + 'sync_capable/async_capable set to True.' % middleware_path + ) + elif not handler_is_async and middleware_can_sync: + middleware_is_async = False + else: + middleware_is_async = middleware_can_async try: + # Adapt handler, if needed. + handler = self.adapt_method_mode( + middleware_is_async, handler, handler_is_async, + debug=settings.DEBUG, name='middleware %s' % middleware_path, + ) mw_instance = middleware(handler) except MiddlewareNotUsed as exc: if settings.DEBUG: @@ -49,24 +70,56 @@ class BaseHandler: ) if hasattr(mw_instance, 'process_view'): - self._view_middleware.insert(0, mw_instance.process_view) + self._view_middleware.insert( + 0, + self.adapt_method_mode(is_async, mw_instance.process_view), + ) if hasattr(mw_instance, 'process_template_response'): - self._template_response_middleware.append(mw_instance.process_template_response) + self._template_response_middleware.append( + self.adapt_method_mode(is_async, mw_instance.process_template_response), + ) if hasattr(mw_instance, 'process_exception'): - self._exception_middleware.append(mw_instance.process_exception) + # The exception-handling stack is still always synchronous for + # now, so adapt that way. + self._exception_middleware.append( + self.adapt_method_mode(False, mw_instance.process_exception), + ) handler = convert_exception_to_response(mw_instance) + handler_is_async = middleware_is_async + # Adapt the top of the stack, if needed. + handler = self.adapt_method_mode(is_async, handler, handler_is_async) # We only assign to this when initialization is complete as it is used # as a flag for initialization being complete. self._middleware_chain = handler - def make_view_atomic(self, view): - non_atomic_requests = getattr(view, '_non_atomic_requests', set()) - for db in connections.all(): - if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: - view = transaction.atomic(using=db.alias)(view) - return view + def adapt_method_mode( + self, is_async, method, method_is_async=None, debug=False, name=None, + ): + """ + Adapt a method to be in the correct "mode": + - If is_async is False: + - Synchronous methods are left alone + - Asynchronous methods are wrapped with async_to_sync + - If is_async is True: + - Synchronous methods are wrapped with sync_to_async() + - Asynchronous methods are left alone + """ + if method_is_async is None: + method_is_async = asyncio.iscoroutinefunction(method) + if debug and not name: + name = name or 'method %s()' % method.__qualname__ + if is_async: + if not method_is_async: + if debug: + logger.debug('Synchronous %s adapted.', name) + return sync_to_async(method, thread_sensitive=True) + elif method_is_async: + if debug: + logger.debug('Asynchronous %s adapted.' % name) + return async_to_sync(method) + return method def get_response(self, request): """Return an HttpResponse object for the given HttpRequest.""" @@ -82,6 +135,26 @@ class BaseHandler: ) return response + async def get_response_async(self, request): + """ + Asynchronous version of get_response. + + Funneling everything, including WSGI, into a single async + get_response() is too slow. Avoid the context switch by using + a separate async response path. + """ + # Setup default url resolver for this thread. + set_urlconf(settings.ROOT_URLCONF) + response = await self._middleware_chain(request) + response._resource_closers.append(request.close) + if response.status_code >= 400: + await sync_to_async(log_response)( + '%s: %s', response.reason_phrase, request.path, + response=response, + request=request, + ) + return response + def _get_response(self, request): """ Resolve and call the view, then apply view, exception, and @@ -89,17 +162,7 @@ class BaseHandler: inside the request/response middleware. """ response = None - - if hasattr(request, 'urlconf'): - urlconf = request.urlconf - set_urlconf(urlconf) - resolver = get_resolver(urlconf) - else: - resolver = get_resolver() - - resolver_match = resolver.resolve(request.path_info) - callback, callback_args, callback_kwargs = resolver_match - request.resolver_match = resolver_match + callback, callback_args, callback_kwargs = self.resolve_request(request) # Apply view middleware for middleware_method in self._view_middleware: @@ -109,6 +172,9 @@ class BaseHandler: if response is None: wrapped_callback = self.make_view_atomic(callback) + # If it is an asynchronous view, run it in a subthread. + if asyncio.iscoroutinefunction(wrapped_callback): + wrapped_callback = async_to_sync(wrapped_callback) try: response = wrapped_callback(request, *callback_args, **callback_kwargs) except Exception as e: @@ -137,20 +203,89 @@ class BaseHandler: return response - def process_exception_by_middleware(self, exception, request): + async def _get_response_async(self, request): """ - Pass the exception to the exception middleware. If no middleware - return a response for this exception, raise it. + Resolve and call the view, then apply view, exception, and + template_response middleware. This method is everything that happens + inside the request/response middleware. """ - for middleware_method in self._exception_middleware: - response = middleware_method(request, exception) + response = None + callback, callback_args, callback_kwargs = self.resolve_request(request) + + # Apply view middleware. + for middleware_method in self._view_middleware: + response = await middleware_method(request, callback, callback_args, callback_kwargs) if response: - return response - raise + break + + if response is None: + wrapped_callback = self.make_view_atomic(callback) + # If it is a synchronous view, run it in a subthread + if not asyncio.iscoroutinefunction(wrapped_callback): + wrapped_callback = sync_to_async(wrapped_callback, thread_sensitive=True) + try: + response = await wrapped_callback(request, *callback_args, **callback_kwargs) + except Exception as e: + response = await sync_to_async( + self.process_exception_by_middleware, + thread_sensitive=True, + )(e, request) + + # Complain if the view returned None or an uncalled coroutine. + self.check_response(response, callback) + + # If the response supports deferred rendering, apply template + # response middleware and then render the response + if hasattr(response, 'render') and callable(response.render): + for middleware_method in self._template_response_middleware: + response = await middleware_method(request, response) + # Complain if the template response middleware returned None or + # an uncalled coroutine. + self.check_response( + response, + middleware_method, + name='%s.process_template_response' % ( + middleware_method.__self__.__class__.__name__, + ) + ) + try: + if asyncio.iscoroutinefunction(response.render): + response = await response.render() + else: + response = await sync_to_async(response.render, thread_sensitive=True)() + except Exception as e: + response = await sync_to_async( + self.process_exception_by_middleware, + thread_sensitive=True, + )(e, request) + + # Make sure the response is not a coroutine + if asyncio.iscoroutine(response): + raise RuntimeError('Response is still a coroutine.') + return response + + def resolve_request(self, request): + """ + Retrieve/set the urlconf for the request. Return the view resolved, + with its args and kwargs. + """ + # Work out the resolver. + if hasattr(request, 'urlconf'): + urlconf = request.urlconf + set_urlconf(urlconf) + resolver = get_resolver(urlconf) + else: + resolver = get_resolver() + # Resolve the view, and assign the match object back to the request. + resolver_match = resolver.resolve(request.path_info) + request.resolver_match = resolver_match + return resolver_match def check_response(self, response, callback, name=None): - """Raise an error if the view returned None.""" - if response is not None: + """ + Raise an error if the view returned None or an uncalled coroutine. + """ + if not(response is None or asyncio.iscoroutine(response)): return if not name: if isinstance(callback, types.FunctionType): # FBV @@ -160,10 +295,41 @@ class BaseHandler: callback.__module__, callback.__class__.__name__, ) - raise ValueError( - "%s didn't return an HttpResponse object. It returned None " - "instead." % name - ) + if response is None: + raise ValueError( + "%s didn't return an HttpResponse object. It returned None " + "instead." % name + ) + elif asyncio.iscoroutine(response): + raise ValueError( + "%s didn't return an HttpResponse object. It returned an " + "unawaited coroutine instead. You may need to add an 'await' " + "into your view." % name + ) + + # Other utility methods. + + def make_view_atomic(self, view): + non_atomic_requests = getattr(view, '_non_atomic_requests', set()) + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: + if asyncio.iscoroutinefunction(view): + raise RuntimeError( + 'You cannot use ATOMIC_REQUESTS with async views.' + ) + view = transaction.atomic(using=db.alias)(view) + return view + + def process_exception_by_middleware(self, exception, request): + """ + Pass the exception to the exception middleware. If no middleware + return a response for this exception, raise it. + """ + for middleware_method in self._exception_middleware: + response = middleware_method(request, exception) + if response: + return response + raise def reset_urlconf(sender, **kwargs): diff --git a/django/core/handlers/exception.py b/django/core/handlers/exception.py index 66443ce560..50880f2784 100644 --- a/django/core/handlers/exception.py +++ b/django/core/handlers/exception.py @@ -1,7 +1,10 @@ +import asyncio import logging import sys from functools import wraps +from asgiref.sync import sync_to_async + from django.conf import settings from django.core import signals from django.core.exceptions import ( @@ -28,14 +31,24 @@ def convert_exception_to_response(get_response): no middleware leaks an exception and that the next middleware in the stack can rely on getting a response instead of an exception. """ - @wraps(get_response) - def inner(request): - try: - response = get_response(request) - except Exception as exc: - response = response_for_exception(request, exc) - return response - return inner + if asyncio.iscoroutinefunction(get_response): + @wraps(get_response) + async def inner(request): + try: + response = await get_response(request) + except Exception as exc: + response = await sync_to_async(response_for_exception)(request, exc) + return response + return inner + else: + @wraps(get_response) + def inner(request): + try: + response = get_response(request) + except Exception as exc: + response = response_for_exception(request, exc) + return response + return inner def response_for_exception(request, exc): diff --git a/django/test/__init__.py b/django/test/__init__.py index 4782d72184..d1f953a8dd 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -1,6 +1,8 @@ """Django Unit Test framework.""" -from django.test.client import Client, RequestFactory +from django.test.client import ( + AsyncClient, AsyncRequestFactory, Client, RequestFactory, +) from django.test.testcases import ( LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature, @@ -11,8 +13,9 @@ from django.test.utils import ( ) __all__ = [ - 'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase', - 'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature', - 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings', - 'modify_settings', 'override_settings', 'override_system_checks', 'tag', + 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory', + 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase', + 'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', + 'ignore_warnings', 'modify_settings', 'override_settings', + 'override_system_checks', 'tag', ] diff --git a/django/test/client.py b/django/test/client.py index 34fc9f3cf1..c460832461 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -9,7 +9,10 @@ from importlib import import_module from io import BytesIO from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit +from asgiref.sync import sync_to_async + from django.conf import settings +from django.core.handlers.asgi import ASGIRequest from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.core.serializers.json import DjangoJSONEncoder @@ -157,6 +160,52 @@ class ClientHandler(BaseHandler): return response +class AsyncClientHandler(BaseHandler): + """An async version of ClientHandler.""" + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): + self.enforce_csrf_checks = enforce_csrf_checks + super().__init__(*args, **kwargs) + + async def __call__(self, scope): + # Set up middleware if needed. We couldn't do this earlier, because + # settings weren't available. + if self._middleware_chain is None: + self.load_middleware(is_async=True) + # Extract body file from the scope, if provided. + if '_body_file' in scope: + body_file = scope.pop('_body_file') + else: + body_file = FakePayload('') + + request_started.disconnect(close_old_connections) + await sync_to_async(request_started.send)(sender=self.__class__, scope=scope) + request_started.connect(close_old_connections) + request = ASGIRequest(scope, body_file) + # Sneaky little hack so that we can easily get round + # CsrfViewMiddleware. This makes life easier, and is probably required + # for backwards compatibility with external tests against admin views. + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + # Request goes through middleware. + response = await self.get_response_async(request) + # Simulate behaviors of most Web servers. + conditional_content_removal(request, response) + # Attach the originating ASGI request to the response so that it could + # be later retrieved. + response.asgi_request = request + # Emulate a server by calling the close method on completion. + if response.streaming: + response.streaming_content = await sync_to_async(closing_iterator_wrapper)( + response.streaming_content, + response.close, + ) + else: + request_finished.disconnect(close_old_connections) + # Will fire request_finished. + await sync_to_async(response.close)() + request_finished.connect(close_old_connections) + return response + + def store_rendered_templates(store, signal, sender, template, context, **kwargs): """ Store templates and contexts that are rendered. @@ -421,7 +470,194 @@ class RequestFactory: return self.request(**r) -class Client(RequestFactory): +class AsyncRequestFactory(RequestFactory): + """ + Class that lets you create mock ASGI-like Request objects for use in + testing. Usage: + + rf = AsyncRequestFactory() + get_request = await rf.get('/hello/') + post_request = await rf.post('/submit/', {'foo': 'bar'}) + + Once you have a request object you can pass it to any view function, + including synchronous ones. The reason we have a separate class here is: + a) this makes ASGIRequest subclasses, and + b) AsyncTestClient can subclass it. + """ + def _base_scope(self, **request): + """The base scope for a request.""" + # This is a minimal valid ASGI scope, plus: + # - headers['cookie'] for cookie support, + # - 'client' often useful, see #8551. + scope = { + 'asgi': {'version': '3.0'}, + 'type': 'http', + 'http_version': '1.1', + 'client': ['127.0.0.1', 0], + 'server': ('testserver', '80'), + 'scheme': 'http', + 'method': 'GET', + 'headers': [], + **self.defaults, + **request, + } + scope['headers'].append(( + b'cookie', + b'; '.join(sorted( + ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii') + for morsel in self.cookies.values() + )), + )) + return scope + + def request(self, **request): + """Construct a generic request object.""" + # This is synchronous, which means all methods on this class are. + # AsyncClient, however, has an async request function, which makes all + # its methods async. + if '_body_file' in request: + body_file = request.pop('_body_file') + else: + body_file = FakePayload('') + return ASGIRequest(self._base_scope(**request), body_file) + + def generic( + self, method, path, data='', content_type='application/octet-stream', + secure=False, **extra, + ): + """Construct an arbitrary HTTP request.""" + parsed = urlparse(str(path)) # path can be lazy. + data = force_bytes(data, settings.DEFAULT_CHARSET) + s = { + 'method': method, + 'path': self._get_path(parsed), + 'server': ('127.0.0.1', '443' if secure else '80'), + 'scheme': 'https' if secure else 'http', + 'headers': [(b'host', b'testserver')], + } + if data: + s['headers'].extend([ + (b'content-length', bytes(len(data))), + (b'content-type', content_type.encode('ascii')), + ]) + s['_body_file'] = FakePayload(data) + s.update(extra) + # If QUERY_STRING is absent or empty, we want to extract it from the + # URL. + if not s.get('query_string'): + s['query_string'] = parsed[4] + return self.request(**s) + + +class ClientMixin: + """ + Mixin with common methods between Client and AsyncClient. + """ + def store_exc_info(self, **kwargs): + """Store exceptions when they are generated by a view.""" + self.exc_info = sys.exc_info() + + def check_exception(self, response): + """ + Look for a signaled exception, clear the current context exception + data, re-raise the signaled exception, and clear the signaled exception + from the local cache. + """ + response.exc_info = self.exc_info + if self.exc_info: + _, exc_value, _ = self.exc_info + self.exc_info = None + if self.raise_request_exception: + raise exc_value + + @property + def session(self): + """Return the current session variables.""" + engine = import_module(settings.SESSION_ENGINE) + cookie = self.cookies.get(settings.SESSION_COOKIE_NAME) + if cookie: + return engine.SessionStore(cookie.value) + session = engine.SessionStore() + session.save() + self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key + return session + + def login(self, **credentials): + """ + Set the Factory to appear as if it has successfully logged into a site. + + Return True if login is possible or False if the provided credentials + are incorrect. + """ + from django.contrib.auth import authenticate + user = authenticate(**credentials) + if user: + self._login(user) + return True + return False + + def force_login(self, user, backend=None): + def get_backend(): + from django.contrib.auth import load_backend + for backend_path in settings.AUTHENTICATION_BACKENDS: + backend = load_backend(backend_path) + if hasattr(backend, 'get_user'): + return backend_path + + if backend is None: + backend = get_backend() + user.backend = backend + self._login(user, backend) + + def _login(self, user, backend=None): + from django.contrib.auth import login + # Create a fake request to store login details. + request = HttpRequest() + if self.session: + request.session = self.session + else: + engine = import_module(settings.SESSION_ENGINE) + request.session = engine.SessionStore() + login(request, user, backend) + # Save the session values. + request.session.save() + # Set the cookie to represent the session. + session_cookie = settings.SESSION_COOKIE_NAME + self.cookies[session_cookie] = request.session.session_key + cookie_data = { + 'max-age': None, + 'path': '/', + 'domain': settings.SESSION_COOKIE_DOMAIN, + 'secure': settings.SESSION_COOKIE_SECURE or None, + 'expires': None, + } + self.cookies[session_cookie].update(cookie_data) + + def logout(self): + """Log out the user by removing the cookies and session object.""" + from django.contrib.auth import get_user, logout + request = HttpRequest() + if self.session: + request.session = self.session + request.user = get_user(request) + else: + engine = import_module(settings.SESSION_ENGINE) + request.session = engine.SessionStore() + logout(request) + self.cookies = SimpleCookie() + + def _parse_json(self, response, **extra): + if not hasattr(response, '_json'): + if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): + raise ValueError( + 'Content-Type header is "%s", not "application/json"' + % response.get('Content-Type') + ) + response._json = json.loads(response.content.decode(response.charset), **extra) + return response._json + + +class Client(ClientMixin, RequestFactory): """ A class that can act as a client for testing purposes. @@ -446,23 +682,6 @@ class Client(RequestFactory): self.exc_info = None self.extra = None - def store_exc_info(self, **kwargs): - """Store exceptions when they are generated by a view.""" - self.exc_info = sys.exc_info() - - @property - def session(self): - """Return the current session variables.""" - engine = import_module(settings.SESSION_ENGINE) - cookie = self.cookies.get(settings.SESSION_COOKIE_NAME) - if cookie: - return engine.SessionStore(cookie.value) - - session = engine.SessionStore() - session.save() - self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key - return session - def request(self, **request): """ The master request method. Compose the environment dictionary and pass @@ -486,15 +705,8 @@ class Client(RequestFactory): finally: signals.template_rendered.disconnect(dispatch_uid=signal_uid) got_request_exception.disconnect(dispatch_uid=exception_uid) - # Look for a signaled exception, clear the current context exception - # data, then re-raise the signaled exception. Also clear the signaled - # exception from the local cache. - response.exc_info = self.exc_info - if self.exc_info: - _, exc_value, _ = self.exc_info - self.exc_info = None - if self.raise_request_exception: - raise exc_value + # Check for signaled exceptions. + self.check_exception(response) # Save the client and request that stimulated the response. response.client = self response.request = request @@ -583,85 +795,6 @@ class Client(RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def login(self, **credentials): - """ - Set the Factory to appear as if it has successfully logged into a site. - - Return True if login is possible; False if the provided credentials - are incorrect. - """ - from django.contrib.auth import authenticate - user = authenticate(**credentials) - if user: - self._login(user) - return True - else: - return False - - def force_login(self, user, backend=None): - def get_backend(): - from django.contrib.auth import load_backend - for backend_path in settings.AUTHENTICATION_BACKENDS: - backend = load_backend(backend_path) - if hasattr(backend, 'get_user'): - return backend_path - if backend is None: - backend = get_backend() - user.backend = backend - self._login(user, backend) - - def _login(self, user, backend=None): - from django.contrib.auth import login - engine = import_module(settings.SESSION_ENGINE) - - # Create a fake request to store login details. - request = HttpRequest() - - if self.session: - request.session = self.session - else: - request.session = engine.SessionStore() - login(request, user, backend) - - # Save the session values. - request.session.save() - - # Set the cookie to represent the session. - session_cookie = settings.SESSION_COOKIE_NAME - self.cookies[session_cookie] = request.session.session_key - cookie_data = { - 'max-age': None, - 'path': '/', - 'domain': settings.SESSION_COOKIE_DOMAIN, - 'secure': settings.SESSION_COOKIE_SECURE or None, - 'expires': None, - } - self.cookies[session_cookie].update(cookie_data) - - def logout(self): - """Log out the user by removing the cookies and session object.""" - from django.contrib.auth import get_user, logout - - request = HttpRequest() - engine = import_module(settings.SESSION_ENGINE) - if self.session: - request.session = self.session - request.user = get_user(request) - else: - request.session = engine.SessionStore() - logout(request) - self.cookies = SimpleCookie() - - def _parse_json(self, response, **extra): - if not hasattr(response, '_json'): - if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): - raise ValueError( - 'Content-Type header is "{}", not "application/json"' - .format(response.get('Content-Type')) - ) - response._json = json.loads(response.content.decode(response.charset), **extra) - return response._json - def _handle_redirects(self, response, data='', content_type='', **extra): """ Follow any redirects by requesting responses from the server using GET. @@ -714,3 +847,66 @@ class Client(RequestFactory): raise RedirectCycleError("Too many redirects.", last_response=response) return response + + +class AsyncClient(ClientMixin, AsyncRequestFactory): + """ + An async version of Client that creates ASGIRequests and calls through an + async request path. + + Does not currently support "follow" on its methods. + """ + def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + super().__init__(**defaults) + self.handler = AsyncClientHandler(enforce_csrf_checks) + self.raise_request_exception = raise_request_exception + self.exc_info = None + self.extra = None + + async def request(self, **request): + """ + The master request method. Compose the scope dictionary and pass to the + handler, return the result of the handler. Assume defaults for the + query environment, which can be overridden using the arguments to the + request. + """ + if 'follow' in request: + raise NotImplementedError( + 'AsyncClient request methods do not accept the follow ' + 'parameter.' + ) + scope = self._base_scope(**request) + # Curry a data dictionary into an instance of the template renderer + # callback function. + data = {} + on_template_render = partial(store_rendered_templates, data) + signal_uid = 'template-render-%s' % id(request) + signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) + # Capture exceptions created by the handler. + exception_uid = 'request-exception-%s' % id(request) + got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid) + try: + response = await self.handler(scope) + finally: + signals.template_rendered.disconnect(dispatch_uid=signal_uid) + got_request_exception.disconnect(dispatch_uid=exception_uid) + # Check for signaled exceptions. + self.check_exception(response) + # Save the client and request that stimulated the response. + response.client = self + response.request = request + # Add any rendered template detail to the response. + response.templates = data.get('templates', []) + response.context = data.get('context') + response.json = partial(self._parse_json, response) + # Attach the ResolverMatch instance to the response. + response.resolver_match = SimpleLazyObject(lambda: resolve(request['path'])) + # Flatten a single context. Not really necessary anymore thanks to the + # __getattr__ flattening in ContextList, but has some edge case + # backwards compatibility implications. + if response.context and len(response.context) == 1: + response.context = response.context[0] + # Update persistent cookie data. + if response.cookies: + self.cookies.update(response.cookies) + return response diff --git a/django/test/testcases.py b/django/test/testcases.py index 2c24708a3b..7ebddf80e5 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -33,7 +33,7 @@ from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction from django.forms.fields import CharField from django.http import QueryDict from django.http.request import split_domain_port, validate_host -from django.test.client import Client +from django.test.client import AsyncClient, Client from django.test.html import HTMLParseError, parse_html from django.test.signals import setting_changed, template_rendered from django.test.utils import ( @@ -151,6 +151,7 @@ class SimpleTestCase(unittest.TestCase): # The class we'll use for the test client self.client. # Can be overridden in derived classes. client_class = Client + async_client_class = AsyncClient _overridden_settings = None _modified_settings = None @@ -292,6 +293,7 @@ class SimpleTestCase(unittest.TestCase): * Clear the mail test outbox. """ self.client = self.client_class() + self.async_client = self.async_client_class() mail.outbox = [] def _post_teardown(self): diff --git a/django/test/utils.py b/django/test/utils.py index e626667b09..d1f7d19546 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -1,3 +1,4 @@ +import asyncio import logging import re import sys @@ -362,12 +363,22 @@ class TestContextDecorator: raise TypeError('Can only decorate subclasses of unittest.TestCase') def decorate_callable(self, func): - @wraps(func) - def inner(*args, **kwargs): - with self as context: - if self.kwarg_name: - kwargs[self.kwarg_name] = context - return func(*args, **kwargs) + if asyncio.iscoroutinefunction(func): + # If the inner function is an async function, we must execute async + # as well so that the `with` statement executes at the right time. + @wraps(func) + async def inner(*args, **kwargs): + with self as context: + if self.kwarg_name: + kwargs[self.kwarg_name] = context + return await func(*args, **kwargs) + else: + @wraps(func) + def inner(*args, **kwargs): + with self as context: + if self.kwarg_name: + kwargs[self.kwarg_name] = context + return func(*args, **kwargs) return inner def __call__(self, decorated): diff --git a/django/utils/decorators.py b/django/utils/decorators.py index bb2e498e46..5c9a5d01c7 100644 --- a/django/utils/decorators.py +++ b/django/utils/decorators.py @@ -150,3 +150,30 @@ def make_middleware_decorator(middleware_class): return _wrapped_view return _decorator return _make_decorator + + +def sync_and_async_middleware(func): + """ + Mark a middleware factory as returning a hybrid middleware supporting both + types of request. + """ + func.sync_capable = True + func.async_capable = True + return func + + +def sync_only_middleware(func): + """ + Mark a middleware factory as returning a sync middleware. + This is the default. + """ + func.sync_capable = True + func.async_capable = False + return func + + +def async_only_middleware(func): + """Mark a middleware factory as returning an async middleware.""" + func.sync_capable = False + func.async_capable = True + return func diff --git a/django/utils/deprecation.py b/django/utils/deprecation.py index 81e7c3a15b..6336558a81 100644 --- a/django/utils/deprecation.py +++ b/django/utils/deprecation.py @@ -1,6 +1,9 @@ +import asyncio import inspect import warnings +from asgiref.sync import sync_to_async + class RemovedInNextVersionWarning(DeprecationWarning): pass @@ -80,14 +83,31 @@ class DeprecationInstanceCheck(type): class MiddlewareMixin: + sync_capable = True + async_capable = True + # RemovedInDjango40Warning: when the deprecation ends, replace with: # def __init__(self, get_response): def __init__(self, get_response=None): self._get_response_none_deprecation(get_response) self.get_response = get_response + self._async_check() super().__init__() + def _async_check(self): + """ + If get_response is a coroutine function, turns us into async mode so + a thread is not consumed during a whole request. + """ + if asyncio.iscoroutinefunction(self.get_response): + # Mark the class as async-capable, but do the actual switch + # inside __call__ to avoid swapping out dunder methods + self._is_coroutine = asyncio.coroutines._is_coroutine + def __call__(self, request): + # Exit out to async mode, if needed + if asyncio.iscoroutinefunction(self.get_response): + return self.__acall__(request) response = None if hasattr(self, 'process_request'): response = self.process_request(request) @@ -96,6 +116,19 @@ class MiddlewareMixin: response = self.process_response(request, response) return response + async def __acall__(self, request): + """ + Async version of __call__ that is swapped in when an async request + is running. + """ + response = None + if hasattr(self, 'process_request'): + response = await sync_to_async(self.process_request)(request) + response = response or await self.get_response(request) + if hasattr(self, 'process_response'): + response = await sync_to_async(self.process_response)(request, response) + return response + def _get_response_none_deprecation(self, get_response): if get_response is None: warnings.warn( |
