summaryrefslogtreecommitdiff
path: root/tests/asgi/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/asgi/tests.py')
-rw-r--r--tests/asgi/tests.py29
1 files changed, 17 insertions, 12 deletions
diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py
index 3aeade4c05..0fbb586f85 100644
--- a/tests/asgi/tests.py
+++ b/tests/asgi/tests.py
@@ -26,6 +26,17 @@ from .urls import sync_waiter, test_filename
TEST_STATIC_ROOT = Path(__file__).parent / "project" / "static"
+class SignalHandler:
+ """Helper class to track threads and kwargs when signals are dispatched."""
+
+ def __init__(self):
+ super().__init__()
+ self.calls = []
+
+ def __call__(self, signal, **kwargs):
+ self.calls.append({"thread": threading.current_thread(), "kwargs": kwargs})
+
+
@override_settings(ROOT_URLCONF="asgi.urls")
class ASGITest(SimpleTestCase):
async_request_factory = AsyncRequestFactory()
@@ -310,17 +321,12 @@ class ASGITest(SimpleTestCase):
self.assertEqual(response_body["body"], b"")
async def test_request_lifecycle_signals_dispatched_with_thread_sensitive(self):
- class SignalHandler:
- """Track threads handler is dispatched on."""
-
- threads = []
-
- def __call__(self, **kwargs):
- self.threads.append(threading.current_thread())
-
+ # Track request_started and request_finished signals.
signal_handler = SignalHandler()
request_started.connect(signal_handler)
+ self.addCleanup(request_started.disconnect, signal_handler)
request_finished.connect(signal_handler)
+ self.addCleanup(request_finished.disconnect, signal_handler)
# Perform a basic request.
application = get_asgi_application()
@@ -337,10 +343,9 @@ class ASGITest(SimpleTestCase):
await communicator.wait()
# AsyncToSync should have executed the signals in the same thread.
- request_started_thread, request_finished_thread = signal_handler.threads
- self.assertEqual(request_started_thread, request_finished_thread)
- request_started.disconnect(signal_handler)
- request_finished.disconnect(signal_handler)
+ self.assertEqual(len(signal_handler.calls), 2)
+ request_started_call, request_finished_call = signal_handler.calls
+ self.assertEqual(request_started_call["thread"], request_finished_call["thread"])
async def test_concurrent_async_uses_multiple_thread_pools(self):
sync_waiter.active_threads.clear()