diff options
| author | Arfey <Arfey17.mg@gmail.com> | 2025-11-10 01:10:32 +0200 |
|---|---|---|
| committer | Jacob Walls <jacobtylerwalls@gmail.com> | 2025-12-29 09:48:11 -0500 |
| commit | cc0f6c4f74cc278fdab79b269401127f2d869334 (patch) | |
| tree | 55c4ebf38ffbc8247d6bf1688bf9bc2e9a747852 /django | |
| parent | 1c34b8716afde049f95ad1c72c2f8e148f826662 (diff) | |
Fixed #36714 -- Fixed context sharing among async signal handlers.
Diffstat (limited to 'django')
| -rw-r--r-- | django/dispatch/dispatcher.py | 57 |
1 files changed, 37 insertions, 20 deletions
diff --git a/django/dispatch/dispatcher.py b/django/dispatch/dispatcher.py index 63fb75285e..21d77bd884 100644 --- a/django/dispatch/dispatcher.py +++ b/django/dispatch/dispatcher.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import logging import threading import weakref @@ -22,13 +23,30 @@ NONE_ID = _make_id(None) NO_RECEIVERS = object() -async def _gather(*coros): +def _restore_context(context): + """ + Check for changes in contextvars, and set them to the current + context for downstream consumers. + """ + for cvar in context: + cvalue = context.get(cvar) + try: + if cvar.get() != cvalue: + cvar.set(cvalue) + except LookupError: + cvar.set(cvalue) + + +async def _run_parallel(*coros): + """ + Execute multiple asynchronous coroutines in parallel, + sharing the current context between them. + """ + context = contextvars.copy_context() + if len(coros) == 0: return [] - if len(coros) == 1: - return [await coros[0]] - async def run(i, coro): results[i] = await coro @@ -36,12 +54,14 @@ async def _gather(*coros): async with asyncio.TaskGroup() as tg: results = [None] * len(coros) for i, coro in enumerate(coros): - tg.create_task(run(i, coro)) + tg.create_task(run(i, coro), context=context) return results except BaseExceptionGroup as exception_group: if len(exception_group.exceptions) == 1: raise exception_group.exceptions[0] raise + finally: + _restore_context(context=context) class Signal: @@ -233,7 +253,7 @@ class Signal: if async_receivers: async def asend(): - async_responses = await _gather( + async_responses = await _run_parallel( *( receiver(signal=self, sender=sender, **named) for receiver in async_receivers @@ -275,6 +295,7 @@ class Signal: ): return [] sync_receivers, async_receivers = self._live_receivers(sender) + if sync_receivers: @sync_to_async @@ -290,14 +311,12 @@ class Signal: async def sync_send(): return [] - responses, async_responses = await _gather( - sync_send(), - _gather( - *( - receiver(signal=self, sender=sender, **named) - for receiver in async_receivers - ) - ), + responses = await sync_send() + async_responses = await _run_parallel( + *( + receiver(signal=self, sender=sender, **named) + for receiver in async_receivers + ) ) responses.extend(zip(async_receivers, async_responses)) return responses @@ -362,7 +381,7 @@ class Signal: return response async def asend(): - async_responses = await _gather( + async_responses = await _run_parallel( *( asend_and_wrap_exception(receiver) for receiver in async_receivers @@ -436,11 +455,9 @@ class Signal: return err return response - responses, async_responses = await _gather( - sync_send(), - _gather( - *(asend_and_wrap_exception(receiver) for receiver in async_receivers), - ), + responses = await sync_send() + async_responses = await _run_parallel( + *(asend_and_wrap_exception(receiver) for receiver in async_receivers), ) responses.extend(zip(async_receivers, async_responses)) return responses |
