diff options
Diffstat (limited to 'django/test/client.py')
| -rw-r--r-- | django/test/client.py | 408 |
1 files changed, 302 insertions, 106 deletions
diff --git a/django/test/client.py b/django/test/client.py index 34fc9f3cf1..c460832461 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -9,7 +9,10 @@ from importlib import import_module from io import BytesIO from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit +from asgiref.sync import sync_to_async + from django.conf import settings +from django.core.handlers.asgi import ASGIRequest from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.core.serializers.json import DjangoJSONEncoder @@ -157,6 +160,52 @@ class ClientHandler(BaseHandler): return response +class AsyncClientHandler(BaseHandler): + """An async version of ClientHandler.""" + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): + self.enforce_csrf_checks = enforce_csrf_checks + super().__init__(*args, **kwargs) + + async def __call__(self, scope): + # Set up middleware if needed. We couldn't do this earlier, because + # settings weren't available. + if self._middleware_chain is None: + self.load_middleware(is_async=True) + # Extract body file from the scope, if provided. + if '_body_file' in scope: + body_file = scope.pop('_body_file') + else: + body_file = FakePayload('') + + request_started.disconnect(close_old_connections) + await sync_to_async(request_started.send)(sender=self.__class__, scope=scope) + request_started.connect(close_old_connections) + request = ASGIRequest(scope, body_file) + # Sneaky little hack so that we can easily get round + # CsrfViewMiddleware. This makes life easier, and is probably required + # for backwards compatibility with external tests against admin views. + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + # Request goes through middleware. + response = await self.get_response_async(request) + # Simulate behaviors of most Web servers. + conditional_content_removal(request, response) + # Attach the originating ASGI request to the response so that it could + # be later retrieved. + response.asgi_request = request + # Emulate a server by calling the close method on completion. + if response.streaming: + response.streaming_content = await sync_to_async(closing_iterator_wrapper)( + response.streaming_content, + response.close, + ) + else: + request_finished.disconnect(close_old_connections) + # Will fire request_finished. + await sync_to_async(response.close)() + request_finished.connect(close_old_connections) + return response + + def store_rendered_templates(store, signal, sender, template, context, **kwargs): """ Store templates and contexts that are rendered. @@ -421,7 +470,194 @@ class RequestFactory: return self.request(**r) -class Client(RequestFactory): +class AsyncRequestFactory(RequestFactory): + """ + Class that lets you create mock ASGI-like Request objects for use in + testing. Usage: + + rf = AsyncRequestFactory() + get_request = await rf.get('/hello/') + post_request = await rf.post('/submit/', {'foo': 'bar'}) + + Once you have a request object you can pass it to any view function, + including synchronous ones. The reason we have a separate class here is: + a) this makes ASGIRequest subclasses, and + b) AsyncTestClient can subclass it. + """ + def _base_scope(self, **request): + """The base scope for a request.""" + # This is a minimal valid ASGI scope, plus: + # - headers['cookie'] for cookie support, + # - 'client' often useful, see #8551. + scope = { + 'asgi': {'version': '3.0'}, + 'type': 'http', + 'http_version': '1.1', + 'client': ['127.0.0.1', 0], + 'server': ('testserver', '80'), + 'scheme': 'http', + 'method': 'GET', + 'headers': [], + **self.defaults, + **request, + } + scope['headers'].append(( + b'cookie', + b'; '.join(sorted( + ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii') + for morsel in self.cookies.values() + )), + )) + return scope + + def request(self, **request): + """Construct a generic request object.""" + # This is synchronous, which means all methods on this class are. + # AsyncClient, however, has an async request function, which makes all + # its methods async. + if '_body_file' in request: + body_file = request.pop('_body_file') + else: + body_file = FakePayload('') + return ASGIRequest(self._base_scope(**request), body_file) + + def generic( + self, method, path, data='', content_type='application/octet-stream', + secure=False, **extra, + ): + """Construct an arbitrary HTTP request.""" + parsed = urlparse(str(path)) # path can be lazy. + data = force_bytes(data, settings.DEFAULT_CHARSET) + s = { + 'method': method, + 'path': self._get_path(parsed), + 'server': ('127.0.0.1', '443' if secure else '80'), + 'scheme': 'https' if secure else 'http', + 'headers': [(b'host', b'testserver')], + } + if data: + s['headers'].extend([ + (b'content-length', bytes(len(data))), + (b'content-type', content_type.encode('ascii')), + ]) + s['_body_file'] = FakePayload(data) + s.update(extra) + # If QUERY_STRING is absent or empty, we want to extract it from the + # URL. + if not s.get('query_string'): + s['query_string'] = parsed[4] + return self.request(**s) + + +class ClientMixin: + """ + Mixin with common methods between Client and AsyncClient. + """ + def store_exc_info(self, **kwargs): + """Store exceptions when they are generated by a view.""" + self.exc_info = sys.exc_info() + + def check_exception(self, response): + """ + Look for a signaled exception, clear the current context exception + data, re-raise the signaled exception, and clear the signaled exception + from the local cache. + """ + response.exc_info = self.exc_info + if self.exc_info: + _, exc_value, _ = self.exc_info + self.exc_info = None + if self.raise_request_exception: + raise exc_value + + @property + def session(self): + """Return the current session variables.""" + engine = import_module(settings.SESSION_ENGINE) + cookie = self.cookies.get(settings.SESSION_COOKIE_NAME) + if cookie: + return engine.SessionStore(cookie.value) + session = engine.SessionStore() + session.save() + self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key + return session + + def login(self, **credentials): + """ + Set the Factory to appear as if it has successfully logged into a site. + + Return True if login is possible or False if the provided credentials + are incorrect. + """ + from django.contrib.auth import authenticate + user = authenticate(**credentials) + if user: + self._login(user) + return True + return False + + def force_login(self, user, backend=None): + def get_backend(): + from django.contrib.auth import load_backend + for backend_path in settings.AUTHENTICATION_BACKENDS: + backend = load_backend(backend_path) + if hasattr(backend, 'get_user'): + return backend_path + + if backend is None: + backend = get_backend() + user.backend = backend + self._login(user, backend) + + def _login(self, user, backend=None): + from django.contrib.auth import login + # Create a fake request to store login details. + request = HttpRequest() + if self.session: + request.session = self.session + else: + engine = import_module(settings.SESSION_ENGINE) + request.session = engine.SessionStore() + login(request, user, backend) + # Save the session values. + request.session.save() + # Set the cookie to represent the session. + session_cookie = settings.SESSION_COOKIE_NAME + self.cookies[session_cookie] = request.session.session_key + cookie_data = { + 'max-age': None, + 'path': '/', + 'domain': settings.SESSION_COOKIE_DOMAIN, + 'secure': settings.SESSION_COOKIE_SECURE or None, + 'expires': None, + } + self.cookies[session_cookie].update(cookie_data) + + def logout(self): + """Log out the user by removing the cookies and session object.""" + from django.contrib.auth import get_user, logout + request = HttpRequest() + if self.session: + request.session = self.session + request.user = get_user(request) + else: + engine = import_module(settings.SESSION_ENGINE) + request.session = engine.SessionStore() + logout(request) + self.cookies = SimpleCookie() + + def _parse_json(self, response, **extra): + if not hasattr(response, '_json'): + if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): + raise ValueError( + 'Content-Type header is "%s", not "application/json"' + % response.get('Content-Type') + ) + response._json = json.loads(response.content.decode(response.charset), **extra) + return response._json + + +class Client(ClientMixin, RequestFactory): """ A class that can act as a client for testing purposes. @@ -446,23 +682,6 @@ class Client(RequestFactory): self.exc_info = None self.extra = None - def store_exc_info(self, **kwargs): - """Store exceptions when they are generated by a view.""" - self.exc_info = sys.exc_info() - - @property - def session(self): - """Return the current session variables.""" - engine = import_module(settings.SESSION_ENGINE) - cookie = self.cookies.get(settings.SESSION_COOKIE_NAME) - if cookie: - return engine.SessionStore(cookie.value) - - session = engine.SessionStore() - session.save() - self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key - return session - def request(self, **request): """ The master request method. Compose the environment dictionary and pass @@ -486,15 +705,8 @@ class Client(RequestFactory): finally: signals.template_rendered.disconnect(dispatch_uid=signal_uid) got_request_exception.disconnect(dispatch_uid=exception_uid) - # Look for a signaled exception, clear the current context exception - # data, then re-raise the signaled exception. Also clear the signaled - # exception from the local cache. - response.exc_info = self.exc_info - if self.exc_info: - _, exc_value, _ = self.exc_info - self.exc_info = None - if self.raise_request_exception: - raise exc_value + # Check for signaled exceptions. + self.check_exception(response) # Save the client and request that stimulated the response. response.client = self response.request = request @@ -583,85 +795,6 @@ class Client(RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def login(self, **credentials): - """ - Set the Factory to appear as if it has successfully logged into a site. - - Return True if login is possible; False if the provided credentials - are incorrect. - """ - from django.contrib.auth import authenticate - user = authenticate(**credentials) - if user: - self._login(user) - return True - else: - return False - - def force_login(self, user, backend=None): - def get_backend(): - from django.contrib.auth import load_backend - for backend_path in settings.AUTHENTICATION_BACKENDS: - backend = load_backend(backend_path) - if hasattr(backend, 'get_user'): - return backend_path - if backend is None: - backend = get_backend() - user.backend = backend - self._login(user, backend) - - def _login(self, user, backend=None): - from django.contrib.auth import login - engine = import_module(settings.SESSION_ENGINE) - - # Create a fake request to store login details. - request = HttpRequest() - - if self.session: - request.session = self.session - else: - request.session = engine.SessionStore() - login(request, user, backend) - - # Save the session values. - request.session.save() - - # Set the cookie to represent the session. - session_cookie = settings.SESSION_COOKIE_NAME - self.cookies[session_cookie] = request.session.session_key - cookie_data = { - 'max-age': None, - 'path': '/', - 'domain': settings.SESSION_COOKIE_DOMAIN, - 'secure': settings.SESSION_COOKIE_SECURE or None, - 'expires': None, - } - self.cookies[session_cookie].update(cookie_data) - - def logout(self): - """Log out the user by removing the cookies and session object.""" - from django.contrib.auth import get_user, logout - - request = HttpRequest() - engine = import_module(settings.SESSION_ENGINE) - if self.session: - request.session = self.session - request.user = get_user(request) - else: - request.session = engine.SessionStore() - logout(request) - self.cookies = SimpleCookie() - - def _parse_json(self, response, **extra): - if not hasattr(response, '_json'): - if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): - raise ValueError( - 'Content-Type header is "{}", not "application/json"' - .format(response.get('Content-Type')) - ) - response._json = json.loads(response.content.decode(response.charset), **extra) - return response._json - def _handle_redirects(self, response, data='', content_type='', **extra): """ Follow any redirects by requesting responses from the server using GET. @@ -714,3 +847,66 @@ class Client(RequestFactory): raise RedirectCycleError("Too many redirects.", last_response=response) return response + + +class AsyncClient(ClientMixin, AsyncRequestFactory): + """ + An async version of Client that creates ASGIRequests and calls through an + async request path. + + Does not currently support "follow" on its methods. + """ + def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + super().__init__(**defaults) + self.handler = AsyncClientHandler(enforce_csrf_checks) + self.raise_request_exception = raise_request_exception + self.exc_info = None + self.extra = None + + async def request(self, **request): + """ + The master request method. Compose the scope dictionary and pass to the + handler, return the result of the handler. Assume defaults for the + query environment, which can be overridden using the arguments to the + request. + """ + if 'follow' in request: + raise NotImplementedError( + 'AsyncClient request methods do not accept the follow ' + 'parameter.' + ) + scope = self._base_scope(**request) + # Curry a data dictionary into an instance of the template renderer + # callback function. + data = {} + on_template_render = partial(store_rendered_templates, data) + signal_uid = 'template-render-%s' % id(request) + signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) + # Capture exceptions created by the handler. + exception_uid = 'request-exception-%s' % id(request) + got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid) + try: + response = await self.handler(scope) + finally: + signals.template_rendered.disconnect(dispatch_uid=signal_uid) + got_request_exception.disconnect(dispatch_uid=exception_uid) + # Check for signaled exceptions. + self.check_exception(response) + # Save the client and request that stimulated the response. + response.client = self + response.request = request + # Add any rendered template detail to the response. + response.templates = data.get('templates', []) + response.context = data.get('context') + response.json = partial(self._parse_json, response) + # Attach the ResolverMatch instance to the response. + response.resolver_match = SimpleLazyObject(lambda: resolve(request['path'])) + # Flatten a single context. Not really necessary anymore thanks to the + # __getattr__ flattening in ContextList, but has some edge case + # backwards compatibility implications. + if response.context and len(response.context) == 1: + response.context = response.context[0] + # Update persistent cookie data. + if response.cookies: + self.cookies.update(response.cookies) + return response |
