diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/asgi/tests.py | 46 | ||||
| -rw-r--r-- | tests/handlers/tests.py | 51 | ||||
| -rw-r--r-- | tests/handlers/urls.py | 2 | ||||
| -rw-r--r-- | tests/handlers/views.py | 10 | ||||
| -rw-r--r-- | tests/middleware_exceptions/middleware.py | 64 | ||||
| -rw-r--r-- | tests/middleware_exceptions/tests.py | 159 | ||||
| -rw-r--r-- | tests/middleware_exceptions/urls.py | 5 | ||||
| -rw-r--r-- | tests/middleware_exceptions/views.py | 8 | ||||
| -rw-r--r-- | tests/test_client/tests.py | 59 | ||||
| -rw-r--r-- | tests/test_client/urls.py | 2 | ||||
| -rw-r--r-- | tests/test_client/views.py | 4 |
11 files changed, 381 insertions, 29 deletions
diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index fada34d0d8..c123f027fb 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -7,7 +7,7 @@ from asgiref.testing import ApplicationCommunicator from django.core.asgi import get_asgi_application from django.core.signals import request_started from django.db import close_old_connections -from django.test import SimpleTestCase, override_settings +from django.test import AsyncRequestFactory, SimpleTestCase, override_settings from .urls import test_filename @@ -15,21 +15,11 @@ from .urls import test_filename @skipIf(sys.platform == 'win32' and (3, 8, 0) < sys.version_info < (3, 8, 1), 'https://bugs.python.org/issue38563') @override_settings(ROOT_URLCONF='asgi.urls') class ASGITest(SimpleTestCase): + async_request_factory = AsyncRequestFactory() def setUp(self): request_started.disconnect(close_old_connections) - def _get_scope(self, **kwargs): - return { - 'type': 'http', - 'asgi': {'version': '3.0', 'spec_version': '2.1'}, - 'http_version': '1.1', - 'method': 'GET', - 'query_string': b'', - 'server': ('testserver', 80), - **kwargs, - } - def tearDown(self): request_started.connect(close_old_connections) @@ -39,7 +29,8 @@ class ASGITest(SimpleTestCase): """ application = get_asgi_application() # Construct HTTP request. - communicator = ApplicationCommunicator(application, self._get_scope(path='/')) + scope = self.async_request_factory._base_scope(path='/') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) # Read the response. response_start = await communicator.receive_output() @@ -62,7 +53,8 @@ class ASGITest(SimpleTestCase): """ application = get_asgi_application() # Construct HTTP request. - communicator = ApplicationCommunicator(application, self._get_scope(path='/file/')) + scope = self.async_request_factory._base_scope(path='/file/') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) # Get the file content. with open(test_filename, 'rb') as test_file: @@ -82,12 +74,14 @@ class ASGITest(SimpleTestCase): response_body = await communicator.receive_output() self.assertEqual(response_body['type'], 'http.response.body') self.assertEqual(response_body['body'], test_file_contents) + # Allow response.close() to finish. + await communicator.wait() async def test_headers(self): application = get_asgi_application() communicator = ApplicationCommunicator( application, - self._get_scope( + self.async_request_factory._base_scope( path='/meta/', headers=[ [b'content-type', b'text/plain; charset=utf-8'], @@ -116,10 +110,11 @@ class ASGITest(SimpleTestCase): application = get_asgi_application() for query_string in (b'name=Andrew', 'name=Andrew'): with self.subTest(query_string=query_string): - communicator = ApplicationCommunicator( - application, - self._get_scope(path='/', query_string=query_string), + scope = self.async_request_factory._base_scope( + path='/', + query_string=query_string, ) + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) response_start = await communicator.receive_output() self.assertEqual(response_start['type'], 'http.response.start') @@ -130,17 +125,16 @@ class ASGITest(SimpleTestCase): async def test_disconnect(self): application = get_asgi_application() - communicator = ApplicationCommunicator(application, self._get_scope(path='/')) + scope = self.async_request_factory._base_scope(path='/') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.disconnect'}) with self.assertRaises(asyncio.TimeoutError): await communicator.receive_output() async def test_wrong_connection_type(self): application = get_asgi_application() - communicator = ApplicationCommunicator( - application, - self._get_scope(path='/', type='other'), - ) + scope = self.async_request_factory._base_scope(path='/', type='other') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) msg = 'Django can only handle ASGI/HTTP connections, not other.' with self.assertRaisesMessage(ValueError, msg): @@ -148,10 +142,8 @@ class ASGITest(SimpleTestCase): async def test_non_unicode_query_string(self): application = get_asgi_application() - communicator = ApplicationCommunicator( - application, - self._get_scope(path='/', query_string=b'\xff'), - ) + scope = self.async_request_factory._base_scope(path='/', query_string=b'\xff') + communicator = ApplicationCommunicator(application, scope) await communicator.send_input({'type': 'http.request'}) response_start = await communicator.receive_output() self.assertEqual(response_start['type'], 'http.response.start') diff --git a/tests/handlers/tests.py b/tests/handlers/tests.py index fc7074833b..dac2d967f3 100644 --- a/tests/handlers/tests.py +++ b/tests/handlers/tests.py @@ -106,6 +106,16 @@ class TransactionsPerRequestTests(TransactionTestCase): connection.settings_dict['ATOMIC_REQUESTS'] = old_atomic_requests self.assertContains(response, 'True') + async def test_auto_transaction_async_view(self): + old_atomic_requests = connection.settings_dict['ATOMIC_REQUESTS'] + try: + connection.settings_dict['ATOMIC_REQUESTS'] = True + msg = 'You cannot use ATOMIC_REQUESTS with async views.' + with self.assertRaisesMessage(RuntimeError, msg): + await self.async_client.get('/async_regular/') + finally: + connection.settings_dict['ATOMIC_REQUESTS'] = old_atomic_requests + def test_no_auto_transaction(self): old_atomic_requests = connection.settings_dict['ATOMIC_REQUESTS'] try: @@ -157,6 +167,11 @@ def empty_middleware(get_response): class HandlerRequestTests(SimpleTestCase): request_factory = RequestFactory() + def test_async_view(self): + """Calling an async view down the normal synchronous path.""" + response = self.client.get('/async_regular/') + self.assertEqual(response.status_code, 200) + def test_suspiciousop_in_view_returns_400(self): response = self.client.get('/suspicious/') self.assertEqual(response.status_code, 400) @@ -224,3 +239,39 @@ class ScriptNameTests(SimpleTestCase): 'PATH_INFO': '/milestones/accounts/login/help', }) self.assertEqual(script_name, '/mst') + + +@override_settings(ROOT_URLCONF='handlers.urls') +class AsyncHandlerRequestTests(SimpleTestCase): + """Async variants of the normal handler request tests.""" + + async def test_sync_view(self): + """Calling a sync view down the asynchronous path.""" + response = await self.async_client.get('/regular/') + self.assertEqual(response.status_code, 200) + + async def test_async_view(self): + """Calling an async view down the asynchronous path.""" + response = await self.async_client.get('/async_regular/') + self.assertEqual(response.status_code, 200) + + async def test_suspiciousop_in_view_returns_400(self): + response = await self.async_client.get('/suspicious/') + self.assertEqual(response.status_code, 400) + + async def test_no_response(self): + msg = ( + "The view handlers.views.no_response didn't return an " + "HttpResponse object. It returned None instead." + ) + with self.assertRaisesMessage(ValueError, msg): + await self.async_client.get('/no_response_fbv/') + + async def test_unawaited_response(self): + msg = ( + "The view handlers.views.async_unawaited didn't return an " + "HttpResponse object. It returned an unawaited coroutine instead. " + "You may need to add an 'await' into your view." + ) + with self.assertRaisesMessage(ValueError, msg): + await self.async_client.get('/unawaited/') diff --git a/tests/handlers/urls.py b/tests/handlers/urls.py index b008395267..a438da55b4 100644 --- a/tests/handlers/urls.py +++ b/tests/handlers/urls.py @@ -4,6 +4,7 @@ from . import views urlpatterns = [ path('regular/', views.regular), + path('async_regular/', views.async_regular), path('no_response_fbv/', views.no_response), path('no_response_cbv/', views.NoResponse()), path('streaming/', views.streaming), @@ -12,4 +13,5 @@ urlpatterns = [ path('suspicious/', views.suspicious), path('malformed_post/', views.malformed_post), path('httpstatus_enum/', views.httpstatus_enum), + path('unawaited/', views.async_unawaited), ] diff --git a/tests/handlers/views.py b/tests/handlers/views.py index 872fd52676..9180c5e5a4 100644 --- a/tests/handlers/views.py +++ b/tests/handlers/views.py @@ -1,3 +1,4 @@ +import asyncio from http import HTTPStatus from django.core.exceptions import SuspiciousOperation @@ -44,3 +45,12 @@ def malformed_post(request): def httpstatus_enum(request): return HttpResponse(status=HTTPStatus.OK) + + +async def async_regular(request): + return HttpResponse(b'regular content') + + +async def async_unawaited(request): + """Return an unawaited coroutine (common error for async views).""" + return asyncio.sleep(0) diff --git a/tests/middleware_exceptions/middleware.py b/tests/middleware_exceptions/middleware.py index 63502c6902..69c6db57e7 100644 --- a/tests/middleware_exceptions/middleware.py +++ b/tests/middleware_exceptions/middleware.py @@ -1,6 +1,9 @@ from django.http import Http404, HttpResponse from django.template import engines from django.template.response import TemplateResponse +from django.utils.decorators import ( + async_only_middleware, sync_and_async_middleware, sync_only_middleware, +) log = [] @@ -18,6 +21,12 @@ class ProcessExceptionMiddleware(BaseMiddleware): return HttpResponse('Exception caught') +@async_only_middleware +class AsyncProcessExceptionMiddleware(BaseMiddleware): + async def process_exception(self, request, exception): + return HttpResponse('Exception caught') + + class ProcessExceptionLogMiddleware(BaseMiddleware): def process_exception(self, request, exception): log.append('process-exception') @@ -33,6 +42,12 @@ class ProcessViewMiddleware(BaseMiddleware): return HttpResponse('Processed view %s' % view_func.__name__) +@async_only_middleware +class AsyncProcessViewMiddleware(BaseMiddleware): + async def process_view(self, request, view_func, view_args, view_kwargs): + return HttpResponse('Processed view %s' % view_func.__name__) + + class ProcessViewNoneMiddleware(BaseMiddleware): def process_view(self, request, view_func, view_args, view_kwargs): log.append('processed view %s' % view_func.__name__) @@ -51,6 +66,13 @@ class TemplateResponseMiddleware(BaseMiddleware): return response +@async_only_middleware +class AsyncTemplateResponseMiddleware(BaseMiddleware): + async def process_template_response(self, request, response): + response.context_data['mw'].append(self.__class__.__name__) + return response + + class LogMiddleware(BaseMiddleware): def __call__(self, request): response = self.get_response(request) @@ -63,6 +85,48 @@ class NoTemplateResponseMiddleware(BaseMiddleware): return None +@async_only_middleware +class AsyncNoTemplateResponseMiddleware(BaseMiddleware): + async def process_template_response(self, request, response): + return None + + class NotFoundMiddleware(BaseMiddleware): def __call__(self, request): raise Http404('not found') + + +class TeapotMiddleware(BaseMiddleware): + def __call__(self, request): + response = self.get_response(request) + response.status_code = 418 + return response + + +@async_only_middleware +def async_teapot_middleware(get_response): + async def middleware(request): + response = await get_response(request) + response.status_code = 418 + return response + + return middleware + + +@sync_and_async_middleware +class SyncAndAsyncMiddleware(BaseMiddleware): + pass + + +@sync_only_middleware +class DecoratedTeapotMiddleware(TeapotMiddleware): + pass + + +class NotSyncOrAsyncMiddleware(BaseMiddleware): + """Middleware that is deliberately neither sync or async.""" + sync_capable = False + async_capable = False + + def __call__(self, request): + return self.get_response(request) diff --git a/tests/middleware_exceptions/tests.py b/tests/middleware_exceptions/tests.py index 3e614ae0de..697841a35d 100644 --- a/tests/middleware_exceptions/tests.py +++ b/tests/middleware_exceptions/tests.py @@ -180,3 +180,162 @@ class MiddlewareNotUsedTests(SimpleTestCase): with self.assertRaisesMessage(AssertionError, 'no logs'): with self.assertLogs('django.request', 'DEBUG'): self.client.get('/middleware_exceptions/view/') + + +@override_settings( + DEBUG=True, + ROOT_URLCONF='middleware_exceptions.urls', +) +class MiddlewareSyncAsyncTests(SimpleTestCase): + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.TeapotMiddleware', + ]) + def test_sync_teapot_middleware(self): + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.DecoratedTeapotMiddleware', + ]) + def test_sync_decorated_teapot_middleware(self): + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.async_teapot_middleware', + ]) + def test_async_teapot_middleware(self): + with self.assertLogs('django.request', 'DEBUG') as cm: + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + self.assertEqual( + cm.records[0].getMessage(), + "Synchronous middleware " + "middleware_exceptions.middleware.async_teapot_middleware " + "adapted.", + ) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.NotSyncOrAsyncMiddleware', + ]) + def test_not_sync_or_async_middleware(self): + msg = ( + 'Middleware ' + 'middleware_exceptions.middleware.NotSyncOrAsyncMiddleware must ' + 'have at least one of sync_capable/async_capable set to True.' + ) + with self.assertRaisesMessage(RuntimeError, msg): + self.client.get('/middleware_exceptions/view/') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.TeapotMiddleware', + ]) + async def test_sync_teapot_middleware_async(self): + with self.assertLogs('django.request', 'DEBUG') as cm: + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + self.assertEqual( + cm.records[0].getMessage(), + "Asynchronous middleware " + "middleware_exceptions.middleware.TeapotMiddleware adapted.", + ) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.async_teapot_middleware', + ]) + async def test_async_teapot_middleware_async(self): + with self.assertLogs('django.request', 'WARNING') as cm: + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.status_code, 418) + self.assertEqual( + cm.records[0].getMessage(), + 'Unknown Status Code: /middleware_exceptions/view/', + ) + + @override_settings( + DEBUG=False, + MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncNoTemplateResponseMiddleware', + ], + ) + def test_async_process_template_response_returns_none_with_sync_client(self): + msg = ( + "AsyncNoTemplateResponseMiddleware.process_template_response " + "didn't return an HttpResponse object." + ) + with self.assertRaisesMessage(ValueError, msg): + self.client.get('/middleware_exceptions/template_response/') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.SyncAndAsyncMiddleware', + ]) + async def test_async_and_sync_middleware_async_call(self): + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.content, b'OK') + self.assertEqual(response.status_code, 200) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.SyncAndAsyncMiddleware', + ]) + def test_async_and_sync_middleware_sync_call(self): + response = self.client.get('/middleware_exceptions/view/') + self.assertEqual(response.content, b'OK') + self.assertEqual(response.status_code, 200) + + +@override_settings(ROOT_URLCONF='middleware_exceptions.urls') +class AsyncMiddlewareTests(SimpleTestCase): + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncTemplateResponseMiddleware', + ]) + async def test_process_template_response(self): + response = await self.async_client.get( + '/middleware_exceptions/template_response/' + ) + self.assertEqual( + response.content, + b'template_response OK\nAsyncTemplateResponseMiddleware', + ) + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncNoTemplateResponseMiddleware', + ]) + async def test_process_template_response_returns_none(self): + msg = ( + "AsyncNoTemplateResponseMiddleware.process_template_response " + "didn't return an HttpResponse object. It returned None instead." + ) + with self.assertRaisesMessage(ValueError, msg): + await self.async_client.get('/middleware_exceptions/template_response/') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessExceptionMiddleware', + ]) + async def test_exception_in_render_passed_to_process_exception(self): + response = await self.async_client.get( + '/middleware_exceptions/exception_in_render/' + ) + self.assertEqual(response.content, b'Exception caught') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessExceptionMiddleware', + ]) + async def test_exception_in_async_render_passed_to_process_exception(self): + response = await self.async_client.get( + '/middleware_exceptions/async_exception_in_render/' + ) + self.assertEqual(response.content, b'Exception caught') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessExceptionMiddleware', + ]) + async def test_view_exception_handled_by_process_exception(self): + response = await self.async_client.get('/middleware_exceptions/error/') + self.assertEqual(response.content, b'Exception caught') + + @override_settings(MIDDLEWARE=[ + 'middleware_exceptions.middleware.AsyncProcessViewMiddleware', + ]) + async def test_process_view_return_response(self): + response = await self.async_client.get('/middleware_exceptions/view/') + self.assertEqual(response.content, b'Processed view normal_view') diff --git a/tests/middleware_exceptions/urls.py b/tests/middleware_exceptions/urls.py index 46332916b6..d676ef470c 100644 --- a/tests/middleware_exceptions/urls.py +++ b/tests/middleware_exceptions/urls.py @@ -8,4 +8,9 @@ urlpatterns = [ path('middleware_exceptions/permission_denied/', views.permission_denied), path('middleware_exceptions/exception_in_render/', views.exception_in_render), path('middleware_exceptions/template_response/', views.template_response), + # Async views. + path( + 'middleware_exceptions/async_exception_in_render/', + views.async_exception_in_render, + ), ] diff --git a/tests/middleware_exceptions/views.py b/tests/middleware_exceptions/views.py index 3ae54081ab..7a1d244863 100644 --- a/tests/middleware_exceptions/views.py +++ b/tests/middleware_exceptions/views.py @@ -27,3 +27,11 @@ def exception_in_render(request): raise Exception('Exception in HttpResponse.render()') return CustomHttpResponse('Error') + + +async def async_exception_in_render(request): + class CustomHttpResponse(HttpResponse): + async def render(self): + raise Exception('Exception in HttpResponse.render()') + + return CustomHttpResponse('Error') diff --git a/tests/test_client/tests.py b/tests/test_client/tests.py index ce9ce40de5..93ac3e37a7 100644 --- a/tests/test_client/tests.py +++ b/tests/test_client/tests.py @@ -25,9 +25,10 @@ from unittest import mock from django.contrib.auth.models import User from django.core import mail -from django.http import HttpResponse +from django.http import HttpResponse, HttpResponseNotAllowed from django.test import ( - Client, RequestFactory, SimpleTestCase, TestCase, override_settings, + AsyncRequestFactory, Client, RequestFactory, SimpleTestCase, TestCase, + override_settings, ) from django.urls import reverse_lazy @@ -918,3 +919,57 @@ class RequestFactoryTest(SimpleTestCase): protocol = request.META["SERVER_PROTOCOL"] echoed_request_line = "TRACE {} {}".format(url_path, protocol) self.assertContains(response, echoed_request_line) + + +@override_settings(ROOT_URLCONF='test_client.urls') +class AsyncClientTest(TestCase): + async def test_response_resolver_match(self): + response = await self.async_client.get('/async_get_view/') + self.assertTrue(hasattr(response, 'resolver_match')) + self.assertEqual(response.resolver_match.url_name, 'async_get_view') + + async def test_follow_parameter_not_implemented(self): + msg = 'AsyncClient request methods do not accept the follow parameter.' + tests = ( + 'get', + 'post', + 'put', + 'patch', + 'delete', + 'head', + 'options', + 'trace', + ) + for method_name in tests: + with self.subTest(method=method_name): + method = getattr(self.async_client, method_name) + with self.assertRaisesMessage(NotImplementedError, msg): + await method('/redirect_view/', follow=True) + + +@override_settings(ROOT_URLCONF='test_client.urls') +class AsyncRequestFactoryTest(SimpleTestCase): + request_factory = AsyncRequestFactory() + + async def test_request_factory(self): + tests = ( + 'get', + 'post', + 'put', + 'patch', + 'delete', + 'head', + 'options', + 'trace', + ) + for method_name in tests: + with self.subTest(method=method_name): + async def async_generic_view(request): + if request.method.lower() != method_name: + return HttpResponseNotAllowed(method_name) + return HttpResponse(status=200) + + method = getattr(self.request_factory, method_name) + request = method('/somewhere/') + response = await async_generic_view(request) + self.assertEqual(response.status_code, 200) diff --git a/tests/test_client/urls.py b/tests/test_client/urls.py index 61cbe00547..16cca52c38 100644 --- a/tests/test_client/urls.py +++ b/tests/test_client/urls.py @@ -44,4 +44,6 @@ urlpatterns = [ path('accounts/no_trailing_slash', RedirectView.as_view(url='login/')), path('accounts/login/', auth_views.LoginView.as_view(template_name='login.html')), path('accounts/logout/', auth_views.LogoutView.as_view()), + # Async views. + path('async_get_view/', views.async_get_view, name='async_get_view'), ] diff --git a/tests/test_client/views.py b/tests/test_client/views.py index 2d076fafaf..c2aef76508 100644 --- a/tests/test_client/views.py +++ b/tests/test_client/views.py @@ -25,6 +25,10 @@ def get_view(request): return HttpResponse(t.render(c)) +async def async_get_view(request): + return HttpResponse(b'GET content.') + + def trace_view(request): """ A simple view that expects a TRACE request and echoes its status line. |
