summaryrefslogtreecommitdiff
path: root/django/dispatch
diff options
context:
space:
mode:
authorArfey <Arfey17.mg@gmail.com>2025-11-10 01:10:32 +0200
committerJacob Walls <jacobtylerwalls@gmail.com>2025-12-29 09:48:11 -0500
commitcc0f6c4f74cc278fdab79b269401127f2d869334 (patch)
tree55c4ebf38ffbc8247d6bf1688bf9bc2e9a747852 /django/dispatch
parent1c34b8716afde049f95ad1c72c2f8e148f826662 (diff)
Fixed #36714 -- Fixed context sharing among async signal handlers.
Diffstat (limited to 'django/dispatch')
-rw-r--r--django/dispatch/dispatcher.py57
1 files changed, 37 insertions, 20 deletions
diff --git a/django/dispatch/dispatcher.py b/django/dispatch/dispatcher.py
index 63fb75285e..21d77bd884 100644
--- a/django/dispatch/dispatcher.py
+++ b/django/dispatch/dispatcher.py
@@ -1,4 +1,5 @@
import asyncio
+import contextvars
import logging
import threading
import weakref
@@ -22,13 +23,30 @@ NONE_ID = _make_id(None)
NO_RECEIVERS = object()
-async def _gather(*coros):
+def _restore_context(context):
+ """
+ Check for changes in contextvars, and set them to the current
+ context for downstream consumers.
+ """
+ for cvar in context:
+ cvalue = context.get(cvar)
+ try:
+ if cvar.get() != cvalue:
+ cvar.set(cvalue)
+ except LookupError:
+ cvar.set(cvalue)
+
+
+async def _run_parallel(*coros):
+ """
+ Execute multiple asynchronous coroutines in parallel,
+ sharing the current context between them.
+ """
+ context = contextvars.copy_context()
+
if len(coros) == 0:
return []
- if len(coros) == 1:
- return [await coros[0]]
-
async def run(i, coro):
results[i] = await coro
@@ -36,12 +54,14 @@ async def _gather(*coros):
async with asyncio.TaskGroup() as tg:
results = [None] * len(coros)
for i, coro in enumerate(coros):
- tg.create_task(run(i, coro))
+ tg.create_task(run(i, coro), context=context)
return results
except BaseExceptionGroup as exception_group:
if len(exception_group.exceptions) == 1:
raise exception_group.exceptions[0]
raise
+ finally:
+ _restore_context(context=context)
class Signal:
@@ -233,7 +253,7 @@ class Signal:
if async_receivers:
async def asend():
- async_responses = await _gather(
+ async_responses = await _run_parallel(
*(
receiver(signal=self, sender=sender, **named)
for receiver in async_receivers
@@ -275,6 +295,7 @@ class Signal:
):
return []
sync_receivers, async_receivers = self._live_receivers(sender)
+
if sync_receivers:
@sync_to_async
@@ -290,14 +311,12 @@ class Signal:
async def sync_send():
return []
- responses, async_responses = await _gather(
- sync_send(),
- _gather(
- *(
- receiver(signal=self, sender=sender, **named)
- for receiver in async_receivers
- )
- ),
+ responses = await sync_send()
+ async_responses = await _run_parallel(
+ *(
+ receiver(signal=self, sender=sender, **named)
+ for receiver in async_receivers
+ )
)
responses.extend(zip(async_receivers, async_responses))
return responses
@@ -362,7 +381,7 @@ class Signal:
return response
async def asend():
- async_responses = await _gather(
+ async_responses = await _run_parallel(
*(
asend_and_wrap_exception(receiver)
for receiver in async_receivers
@@ -436,11 +455,9 @@ class Signal:
return err
return response
- responses, async_responses = await _gather(
- sync_send(),
- _gather(
- *(asend_and_wrap_exception(receiver) for receiver in async_receivers),
- ),
+ responses = await sync_send()
+ async_responses = await _run_parallel(
+ *(asend_and_wrap_exception(receiver) for receiver in async_receivers),
)
responses.extend(zip(async_receivers, async_responses))
return responses