summaryrefslogtreecommitdiff
path: root/django/dispatch
diff options
context:
space:
mode:
authorJon Janzen <jon@jonjanzen.com>2020-11-07 13:19:20 +0300
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2023-03-07 08:39:25 +0100
commite83a88566a71a2353cebc35992c110be0f8628af (patch)
tree466c863fc3bfe6fc9946b5a3f7163c62e58ecbb9 /django/dispatch
parent9a07999aef7958c9b5441e368cd90646d0edc5c9 (diff)
Fixed #32172 -- Adapted signals to allow async handlers.
co-authored-by: kozzztik <kozzztik@mail.ru> co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
Diffstat (limited to 'django/dispatch')
-rw-r--r--django/dispatch/dispatcher.py233
1 files changed, 209 insertions, 24 deletions
diff --git a/django/dispatch/dispatcher.py b/django/dispatch/dispatcher.py
index 86eb1c3b20..26ef09ce49 100644
--- a/django/dispatch/dispatcher.py
+++ b/django/dispatch/dispatcher.py
@@ -1,7 +1,10 @@
+import asyncio
import logging
import threading
import weakref
+from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
+
from django.utils.inspect import func_accepts_kwargs
logger = logging.getLogger("django.dispatch")
@@ -52,7 +55,8 @@ class Signal:
receiver
A function or an instance method which is to receive signals.
- Receivers must be hashable objects.
+ Receivers must be hashable objects. Receivers can be
+ asynchronous.
If weak is True, then receiver must be weak referenceable.
@@ -94,6 +98,8 @@ class Signal:
else:
lookup_key = (_make_id(receiver), _make_id(sender))
+ is_async = iscoroutinefunction(receiver)
+
if weak:
ref = weakref.ref
receiver_object = receiver
@@ -106,8 +112,8 @@ class Signal:
with self.lock:
self._clear_dead_receivers()
- if not any(r_key == lookup_key for r_key, _ in self.receivers):
- self.receivers.append((lookup_key, receiver))
+ if not any(r_key == lookup_key for r_key, _, _ in self.receivers):
+ self.receivers.append((lookup_key, receiver, is_async))
self.sender_receivers_cache.clear()
def disconnect(self, receiver=None, sender=None, dispatch_uid=None):
@@ -138,7 +144,7 @@ class Signal:
with self.lock:
self._clear_dead_receivers()
for index in range(len(self.receivers)):
- (r_key, _) = self.receivers[index]
+ r_key, *_ = self.receivers[index]
if r_key == lookup_key:
disconnected = True
del self.receivers[index]
@@ -147,7 +153,8 @@ class Signal:
return disconnected
def has_listeners(self, sender=None):
- return bool(self._live_receivers(sender))
+ sync_receivers, async_receivers = self._live_receivers(sender)
+ return bool(sync_receivers) or bool(async_receivers)
def send(self, sender, **named):
"""
@@ -157,6 +164,10 @@ class Signal:
terminating the dispatch loop. So it's possible that all receivers
won't be called if an error is raised.
+ If any receivers are asynchronous, they are called after all the
+ synchronous receivers via a single call to async_to_sync(). They are
+ also executed concurrently with asyncio.gather().
+
Arguments:
sender
@@ -172,16 +183,97 @@ class Signal:
or self.sender_receivers_cache.get(sender) is NO_RECEIVERS
):
return []
+ responses = []
+ sync_receivers, async_receivers = self._live_receivers(sender)
+ for receiver in sync_receivers:
+ response = receiver(signal=self, sender=sender, **named)
+ responses.append((receiver, response))
+ if async_receivers:
+
+ async def asend():
+ async_responses = await asyncio.gather(
+ *(
+ receiver(signal=self, sender=sender, **named)
+ for receiver in async_receivers
+ )
+ )
+ return zip(async_receivers, async_responses)
+
+ responses.extend(async_to_sync(asend)())
+ return responses
+
+ async def asend(self, sender, **named):
+ """
+ Send signal from sender to all connected receivers in async mode.
+
+ All sync receivers will be wrapped by sync_to_async()
+ If any receiver raises an error, the error propagates back through
+ send, terminating the dispatch loop. So it's possible that all
+ receivers won't be called if an error is raised.
+
+ If any receivers are synchronous, they are grouped and called behind a
+ sync_to_async() adaption before executing any asynchronous receivers.
+
+ If any receivers are asynchronous, they are grouped and executed
+ concurrently with asyncio.gather().
+
+ Arguments:
+
+ sender
+ The sender of the signal. Either a specific object or None.
+
+ named
+ Named arguments which will be passed to receivers.
+
+ Return a list of tuple pairs [(receiver, response), ...].
+ """
+ if (
+ not self.receivers
+ or self.sender_receivers_cache.get(sender) is NO_RECEIVERS
+ ):
+ return []
+ sync_receivers, async_receivers = self._live_receivers(sender)
+ if sync_receivers:
+
+ @sync_to_async
+ def sync_send():
+ responses = []
+ for receiver in sync_receivers:
+ response = receiver(signal=self, sender=sender, **named)
+ responses.append((receiver, response))
+ return responses
+
+ else:
+ sync_send = list
- return [
- (receiver, receiver(signal=self, sender=sender, **named))
- for receiver in self._live_receivers(sender)
- ]
+ responses, async_responses = await asyncio.gather(
+ sync_send(),
+ asyncio.gather(
+ *(
+ receiver(signal=self, sender=sender, **named)
+ for receiver in async_receivers
+ )
+ ),
+ )
+ responses.extend(zip(async_receivers, async_responses))
+ return responses
+
+ def _log_robust_failure(self, receiver, err):
+ logger.error(
+ "Error calling %s in Signal.send_robust() (%s)",
+ receiver.__qualname__,
+ err,
+ exc_info=err,
+ )
def send_robust(self, sender, **named):
"""
Send signal from sender to all connected receivers catching errors.
+ If any receivers are asynchronous, they are called after all the
+ synchronous receivers via a single call to async_to_sync(). They are
+ also executed concurrently with asyncio.gather().
+
Arguments:
sender
@@ -206,19 +298,105 @@ class Signal:
# Call each receiver with whatever arguments it can accept.
# Return a list of tuple pairs [(receiver, response), ... ].
responses = []
- for receiver in self._live_receivers(sender):
+ sync_receivers, async_receivers = self._live_receivers(sender)
+ for receiver in sync_receivers:
try:
response = receiver(signal=self, sender=sender, **named)
except Exception as err:
- logger.error(
- "Error calling %s in Signal.send_robust() (%s)",
- receiver.__qualname__,
- err,
- exc_info=err,
- )
+ self._log_robust_failure(receiver, err)
responses.append((receiver, err))
else:
responses.append((receiver, response))
+ if async_receivers:
+
+ async def asend_and_wrap_exception(receiver):
+ try:
+ response = await receiver(signal=self, sender=sender, **named)
+ except Exception as err:
+ self._log_robust_failure(receiver, err)
+ return err
+ return response
+
+ async def asend():
+ async_responses = await asyncio.gather(
+ *(
+ asend_and_wrap_exception(receiver)
+ for receiver in async_receivers
+ )
+ )
+ return zip(async_receivers, async_responses)
+
+ responses.extend(async_to_sync(asend)())
+ return responses
+
+ async def asend_robust(self, sender, **named):
+ """
+ Send signal from sender to all connected receivers catching errors.
+
+ If any receivers are synchronous, they are grouped and called behind a
+ sync_to_async() adaption before executing any asynchronous receivers.
+
+ If any receivers are asynchronous, they are grouped and executed
+ concurrently with asyncio.gather.
+
+ Arguments:
+
+ sender
+ The sender of the signal. Can be any Python object (normally one
+ registered with a connect if you actually want something to
+ occur).
+
+ named
+ Named arguments which will be passed to receivers.
+
+ Return a list of tuple pairs [(receiver, response), ... ].
+
+ If any receiver raises an error (specifically any subclass of
+ Exception), return the error instance as the result for that receiver.
+ """
+ if (
+ not self.receivers
+ or self.sender_receivers_cache.get(sender) is NO_RECEIVERS
+ ):
+ return []
+
+ # Call each receiver with whatever arguments it can accept.
+ # Return a list of tuple pairs [(receiver, response), ... ].
+ sync_receivers, async_receivers = self._live_receivers(sender)
+
+ if sync_receivers:
+
+ @sync_to_async
+ def sync_send():
+ responses = []
+ for receiver in sync_receivers:
+ try:
+ response = receiver(signal=self, sender=sender, **named)
+ except Exception as err:
+ self._log_robust_failure(receiver, err)
+ responses.append((receiver, err))
+ else:
+ responses.append((receiver, response))
+ return responses
+
+ else:
+ sync_send = list
+
+ async def asend_and_wrap_exception(receiver):
+ try:
+ response = await receiver(signal=self, sender=sender, **named)
+ except Exception as err:
+ self._log_robust_failure(receiver, err)
+ return err
+ return response
+
+ responses, async_responses = await asyncio.gather(
+ sync_send(),
+ asyncio.gather(
+ *(asend_and_wrap_exception(receiver) for receiver in async_receivers),
+ ),
+ )
+ responses.extend(zip(async_receivers, async_responses))
return responses
def _clear_dead_receivers(self):
@@ -244,31 +422,38 @@ class Signal:
# We could end up here with NO_RECEIVERS even if we do check this case in
# .send() prior to calling _live_receivers() due to concurrent .send() call.
if receivers is NO_RECEIVERS:
- return []
+ return [], []
if receivers is None:
with self.lock:
self._clear_dead_receivers()
senderkey = _make_id(sender)
receivers = []
- for (receiverkey, r_senderkey), receiver in self.receivers:
+ for (_receiverkey, r_senderkey), receiver, is_async in self.receivers:
if r_senderkey == NONE_ID or r_senderkey == senderkey:
- receivers.append(receiver)
+ receivers.append((receiver, is_async))
if self.use_caching:
if not receivers:
self.sender_receivers_cache[sender] = NO_RECEIVERS
else:
# Note, we must cache the weakref versions.
self.sender_receivers_cache[sender] = receivers
- non_weak_receivers = []
- for receiver in receivers:
+ non_weak_sync_receivers = []
+ non_weak_async_receivers = []
+ for receiver, is_async in receivers:
if isinstance(receiver, weakref.ReferenceType):
# Dereference the weak reference.
receiver = receiver()
if receiver is not None:
- non_weak_receivers.append(receiver)
+ if is_async:
+ non_weak_async_receivers.append(receiver)
+ else:
+ non_weak_sync_receivers.append(receiver)
else:
- non_weak_receivers.append(receiver)
- return non_weak_receivers
+ if is_async:
+ non_weak_async_receivers.append(receiver)
+ else:
+ non_weak_sync_receivers.append(receiver)
+ return non_weak_sync_receivers, non_weak_async_receivers
def _remove_receiver(self, receiver=None):
# Mark that the self.receivers list has dead weakrefs. If so, we will