summaryrefslogtreecommitdiff
path: root/tests/asgi/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/asgi/tests.py')
-rw-r--r--tests/asgi/tests.py134
1 files changed, 133 insertions, 1 deletions
diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py
index 0fbb586f85..963f45f798 100644
--- a/tests/asgi/tests.py
+++ b/tests/asgi/tests.py
@@ -1,12 +1,15 @@
import asyncio
import sys
import threading
+import time
from pathlib import Path
+from asgiref.sync import sync_to_async
from asgiref.testing import ApplicationCommunicator
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.asgi import get_asgi_application
+from django.core.exceptions import RequestDataTooBig
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
from django.core.signals import request_finished, request_started
from django.db import close_old_connections
@@ -20,6 +23,7 @@ from django.test import (
)
from django.urls import path
from django.utils.http import http_date
+from django.views.decorators.csrf import csrf_exempt
from .urls import sync_waiter, test_filename
@@ -205,6 +209,96 @@ class ASGITest(SimpleTestCase):
self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_body["body"], b"Echo!")
+ async def test_create_request_error(self):
+ # Track request_finished signal.
+ signal_handler = SignalHandler()
+ request_finished.connect(signal_handler)
+ self.addCleanup(request_finished.disconnect, signal_handler)
+
+ # Request class that always fails creation with RequestDataTooBig.
+ class TestASGIRequest(ASGIRequest):
+
+ def __init__(self, scope, body_file):
+ super().__init__(scope, body_file)
+ raise RequestDataTooBig()
+
+ # Handler to use the custom request class.
+ class TestASGIHandler(ASGIHandler):
+ request_class = TestASGIRequest
+
+ application = TestASGIHandler()
+ scope = self.async_request_factory._base_scope(path="/not-important/")
+ communicator = ApplicationCommunicator(application, scope)
+
+ # Initiate request.
+ await communicator.send_input({"type": "http.request"})
+ # Give response.close() time to finish.
+ await communicator.wait()
+
+ self.assertEqual(len(signal_handler.calls), 1)
+ self.assertNotEqual(
+ signal_handler.calls[0]["thread"], threading.current_thread()
+ )
+
+ async def test_cancel_post_request_with_sync_processing(self):
+ """
+ The request.body object should be available and readable in view
+ code, even if the ASGIHandler cancels processing part way through.
+ """
+ loop = asyncio.get_event_loop()
+ # Events to monitor the view processing from the parent test code.
+ view_started_event = asyncio.Event()
+ view_finished_event = asyncio.Event()
+ # Record received request body or exceptions raised in the test view
+ outcome = []
+
+ # This view will run in a new thread because it is wrapped in
+ # sync_to_async. The view consumes the POST body data after a short
+ # delay. The test will cancel the request using http.disconnect during
+ # the delay, but because this is a sync view the code runs to
+ # completion. There should be no exceptions raised inside the view
+ # code.
+ @csrf_exempt
+ @sync_to_async
+ def post_view(request):
+ try:
+ loop.call_soon_threadsafe(view_started_event.set)
+ time.sleep(0.1)
+ # Do something to read request.body after pause
+ outcome.append({"request_body": request.body})
+ return HttpResponse("ok")
+ except Exception as e:
+ outcome.append({"exception": e})
+ finally:
+ loop.call_soon_threadsafe(view_finished_event.set)
+
+ # Request class to use the view.
+ class TestASGIRequest(ASGIRequest):
+ urlconf = (path("post/", post_view),)
+
+ # Handler to use request class.
+ class TestASGIHandler(ASGIHandler):
+ request_class = TestASGIRequest
+
+ application = TestASGIHandler()
+ scope = self.async_request_factory._base_scope(
+ method="POST",
+ path="/post/",
+ )
+ communicator = ApplicationCommunicator(application, scope)
+
+ await communicator.send_input({"type": "http.request", "body": b"Body data!"})
+
+ # Wait until the view code has started, then send http.disconnect.
+ await view_started_event.wait()
+ await communicator.send_input({"type": "http.disconnect"})
+ # Wait until view code has finished.
+ await view_finished_event.wait()
+ with self.assertRaises(asyncio.TimeoutError):
+ await communicator.receive_output()
+
+ self.assertEqual(outcome, [{"request_body": b"Body data!"}])
+
async def test_untouched_request_body_gets_closed(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(method="POST", path="/post/")
@@ -345,7 +439,9 @@ class ASGITest(SimpleTestCase):
# AsyncToSync should have executed the signals in the same thread.
self.assertEqual(len(signal_handler.calls), 2)
request_started_call, request_finished_call = signal_handler.calls
- self.assertEqual(request_started_call["thread"], request_finished_call["thread"])
+ self.assertEqual(
+ request_started_call["thread"], request_finished_call["thread"]
+ )
async def test_concurrent_async_uses_multiple_thread_pools(self):
sync_waiter.active_threads.clear()
@@ -381,6 +477,10 @@ class ASGITest(SimpleTestCase):
async def test_asyncio_cancel_error(self):
# Flag to check if the view was cancelled.
view_did_cancel = False
+ # Track request_finished signal.
+ signal_handler = SignalHandler()
+ request_finished.connect(signal_handler)
+ self.addCleanup(request_finished.disconnect, signal_handler)
# A view that will listen for the cancelled error.
async def view(request):
@@ -415,6 +515,13 @@ class ASGITest(SimpleTestCase):
# Give response.close() time to finish.
await communicator.wait()
self.assertIs(view_did_cancel, False)
+ # Exactly one call to request_finished handler.
+ self.assertEqual(len(signal_handler.calls), 1)
+ handler_call = signal_handler.calls.pop()
+ # It was NOT on the async thread.
+ self.assertNotEqual(handler_call["thread"], threading.current_thread())
+ # The signal sender is the handler class.
+ self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
# Request cycle with a disconnect before the view can respond.
application = TestASGIHandler()
@@ -430,11 +537,22 @@ class ASGITest(SimpleTestCase):
await communicator.receive_output()
await communicator.wait()
self.assertIs(view_did_cancel, True)
+ # Exactly one call to request_finished handler.
+ self.assertEqual(len(signal_handler.calls), 1)
+ handler_call = signal_handler.calls.pop()
+ # It was NOT on the async thread.
+ self.assertNotEqual(handler_call["thread"], threading.current_thread())
+ # The signal sender is the handler class.
+ self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
async def test_asyncio_streaming_cancel_error(self):
# Similar to test_asyncio_cancel_error(), but during a streaming
# response.
view_did_cancel = False
+ # Track request_finished signals.
+ signal_handler = SignalHandler()
+ request_finished.connect(signal_handler)
+ self.addCleanup(request_finished.disconnect, signal_handler)
async def streaming_response():
nonlocal view_did_cancel
@@ -469,6 +587,13 @@ class ASGITest(SimpleTestCase):
self.assertEqual(response_body["body"], b"Hello World!")
await communicator.wait()
self.assertIs(view_did_cancel, False)
+ # Exactly one call to request_finished handler.
+ self.assertEqual(len(signal_handler.calls), 1)
+ handler_call = signal_handler.calls.pop()
+ # It was NOT on the async thread.
+ self.assertNotEqual(handler_call["thread"], threading.current_thread())
+ # The signal sender is the handler class.
+ self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
# Request cycle with a disconnect.
application = TestASGIHandler()
@@ -487,6 +612,13 @@ class ASGITest(SimpleTestCase):
await communicator.receive_output()
await communicator.wait()
self.assertIs(view_did_cancel, True)
+ # Exactly one call to request_finished handler.
+ self.assertEqual(len(signal_handler.calls), 1)
+ handler_call = signal_handler.calls.pop()
+ # It was NOT on the async thread.
+ self.assertNotEqual(handler_call["thread"], threading.current_thread())
+ # The signal sender is the handler class.
+ self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
async def test_streaming(self):
scope = self.async_request_factory._base_scope(