diff options
| author | Jon Janzen <jon@jonjanzen.com> | 2020-11-07 13:19:20 +0300 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2023-03-07 08:39:25 +0100 |
| commit | e83a88566a71a2353cebc35992c110be0f8628af (patch) | |
| tree | 466c863fc3bfe6fc9946b5a3f7163c62e58ecbb9 /django/dispatch | |
| parent | 9a07999aef7958c9b5441e368cd90646d0edc5c9 (diff) | |
Fixed #32172 -- Adapted signals to allow async handlers.
co-authored-by: kozzztik <kozzztik@mail.ru>
co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
Diffstat (limited to 'django/dispatch')
| -rw-r--r-- | django/dispatch/dispatcher.py | 233 |
1 files changed, 209 insertions, 24 deletions
diff --git a/django/dispatch/dispatcher.py b/django/dispatch/dispatcher.py index 86eb1c3b20..26ef09ce49 100644 --- a/django/dispatch/dispatcher.py +++ b/django/dispatch/dispatcher.py @@ -1,7 +1,10 @@ +import asyncio import logging import threading import weakref +from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async + from django.utils.inspect import func_accepts_kwargs logger = logging.getLogger("django.dispatch") @@ -52,7 +55,8 @@ class Signal: receiver A function or an instance method which is to receive signals. - Receivers must be hashable objects. + Receivers must be hashable objects. Receivers can be + asynchronous. If weak is True, then receiver must be weak referenceable. @@ -94,6 +98,8 @@ class Signal: else: lookup_key = (_make_id(receiver), _make_id(sender)) + is_async = iscoroutinefunction(receiver) + if weak: ref = weakref.ref receiver_object = receiver @@ -106,8 +112,8 @@ class Signal: with self.lock: self._clear_dead_receivers() - if not any(r_key == lookup_key for r_key, _ in self.receivers): - self.receivers.append((lookup_key, receiver)) + if not any(r_key == lookup_key for r_key, _, _ in self.receivers): + self.receivers.append((lookup_key, receiver, is_async)) self.sender_receivers_cache.clear() def disconnect(self, receiver=None, sender=None, dispatch_uid=None): @@ -138,7 +144,7 @@ class Signal: with self.lock: self._clear_dead_receivers() for index in range(len(self.receivers)): - (r_key, _) = self.receivers[index] + r_key, *_ = self.receivers[index] if r_key == lookup_key: disconnected = True del self.receivers[index] @@ -147,7 +153,8 @@ class Signal: return disconnected def has_listeners(self, sender=None): - return bool(self._live_receivers(sender)) + sync_receivers, async_receivers = self._live_receivers(sender) + return bool(sync_receivers) or bool(async_receivers) def send(self, sender, **named): """ @@ -157,6 +164,10 @@ class Signal: terminating the dispatch loop. So it's possible that all receivers won't be called if an error is raised. + If any receivers are asynchronous, they are called after all the + synchronous receivers via a single call to async_to_sync(). They are + also executed concurrently with asyncio.gather(). + Arguments: sender @@ -172,16 +183,97 @@ class Signal: or self.sender_receivers_cache.get(sender) is NO_RECEIVERS ): return [] + responses = [] + sync_receivers, async_receivers = self._live_receivers(sender) + for receiver in sync_receivers: + response = receiver(signal=self, sender=sender, **named) + responses.append((receiver, response)) + if async_receivers: + + async def asend(): + async_responses = await asyncio.gather( + *( + receiver(signal=self, sender=sender, **named) + for receiver in async_receivers + ) + ) + return zip(async_receivers, async_responses) + + responses.extend(async_to_sync(asend)()) + return responses + + async def asend(self, sender, **named): + """ + Send signal from sender to all connected receivers in async mode. + + All sync receivers will be wrapped by sync_to_async() + If any receiver raises an error, the error propagates back through + send, terminating the dispatch loop. So it's possible that all + receivers won't be called if an error is raised. + + If any receivers are synchronous, they are grouped and called behind a + sync_to_async() adaption before executing any asynchronous receivers. + + If any receivers are asynchronous, they are grouped and executed + concurrently with asyncio.gather(). + + Arguments: + + sender + The sender of the signal. Either a specific object or None. + + named + Named arguments which will be passed to receivers. + + Return a list of tuple pairs [(receiver, response), ...]. + """ + if ( + not self.receivers + or self.sender_receivers_cache.get(sender) is NO_RECEIVERS + ): + return [] + sync_receivers, async_receivers = self._live_receivers(sender) + if sync_receivers: + + @sync_to_async + def sync_send(): + responses = [] + for receiver in sync_receivers: + response = receiver(signal=self, sender=sender, **named) + responses.append((receiver, response)) + return responses + + else: + sync_send = list - return [ - (receiver, receiver(signal=self, sender=sender, **named)) - for receiver in self._live_receivers(sender) - ] + responses, async_responses = await asyncio.gather( + sync_send(), + asyncio.gather( + *( + receiver(signal=self, sender=sender, **named) + for receiver in async_receivers + ) + ), + ) + responses.extend(zip(async_receivers, async_responses)) + return responses + + def _log_robust_failure(self, receiver, err): + logger.error( + "Error calling %s in Signal.send_robust() (%s)", + receiver.__qualname__, + err, + exc_info=err, + ) def send_robust(self, sender, **named): """ Send signal from sender to all connected receivers catching errors. + If any receivers are asynchronous, they are called after all the + synchronous receivers via a single call to async_to_sync(). They are + also executed concurrently with asyncio.gather(). + Arguments: sender @@ -206,19 +298,105 @@ class Signal: # Call each receiver with whatever arguments it can accept. # Return a list of tuple pairs [(receiver, response), ... ]. responses = [] - for receiver in self._live_receivers(sender): + sync_receivers, async_receivers = self._live_receivers(sender) + for receiver in sync_receivers: try: response = receiver(signal=self, sender=sender, **named) except Exception as err: - logger.error( - "Error calling %s in Signal.send_robust() (%s)", - receiver.__qualname__, - err, - exc_info=err, - ) + self._log_robust_failure(receiver, err) responses.append((receiver, err)) else: responses.append((receiver, response)) + if async_receivers: + + async def asend_and_wrap_exception(receiver): + try: + response = await receiver(signal=self, sender=sender, **named) + except Exception as err: + self._log_robust_failure(receiver, err) + return err + return response + + async def asend(): + async_responses = await asyncio.gather( + *( + asend_and_wrap_exception(receiver) + for receiver in async_receivers + ) + ) + return zip(async_receivers, async_responses) + + responses.extend(async_to_sync(asend)()) + return responses + + async def asend_robust(self, sender, **named): + """ + Send signal from sender to all connected receivers catching errors. + + If any receivers are synchronous, they are grouped and called behind a + sync_to_async() adaption before executing any asynchronous receivers. + + If any receivers are asynchronous, they are grouped and executed + concurrently with asyncio.gather. + + Arguments: + + sender + The sender of the signal. Can be any Python object (normally one + registered with a connect if you actually want something to + occur). + + named + Named arguments which will be passed to receivers. + + Return a list of tuple pairs [(receiver, response), ... ]. + + If any receiver raises an error (specifically any subclass of + Exception), return the error instance as the result for that receiver. + """ + if ( + not self.receivers + or self.sender_receivers_cache.get(sender) is NO_RECEIVERS + ): + return [] + + # Call each receiver with whatever arguments it can accept. + # Return a list of tuple pairs [(receiver, response), ... ]. + sync_receivers, async_receivers = self._live_receivers(sender) + + if sync_receivers: + + @sync_to_async + def sync_send(): + responses = [] + for receiver in sync_receivers: + try: + response = receiver(signal=self, sender=sender, **named) + except Exception as err: + self._log_robust_failure(receiver, err) + responses.append((receiver, err)) + else: + responses.append((receiver, response)) + return responses + + else: + sync_send = list + + async def asend_and_wrap_exception(receiver): + try: + response = await receiver(signal=self, sender=sender, **named) + except Exception as err: + self._log_robust_failure(receiver, err) + return err + return response + + responses, async_responses = await asyncio.gather( + sync_send(), + asyncio.gather( + *(asend_and_wrap_exception(receiver) for receiver in async_receivers), + ), + ) + responses.extend(zip(async_receivers, async_responses)) return responses def _clear_dead_receivers(self): @@ -244,31 +422,38 @@ class Signal: # We could end up here with NO_RECEIVERS even if we do check this case in # .send() prior to calling _live_receivers() due to concurrent .send() call. if receivers is NO_RECEIVERS: - return [] + return [], [] if receivers is None: with self.lock: self._clear_dead_receivers() senderkey = _make_id(sender) receivers = [] - for (receiverkey, r_senderkey), receiver in self.receivers: + for (_receiverkey, r_senderkey), receiver, is_async in self.receivers: if r_senderkey == NONE_ID or r_senderkey == senderkey: - receivers.append(receiver) + receivers.append((receiver, is_async)) if self.use_caching: if not receivers: self.sender_receivers_cache[sender] = NO_RECEIVERS else: # Note, we must cache the weakref versions. self.sender_receivers_cache[sender] = receivers - non_weak_receivers = [] - for receiver in receivers: + non_weak_sync_receivers = [] + non_weak_async_receivers = [] + for receiver, is_async in receivers: if isinstance(receiver, weakref.ReferenceType): # Dereference the weak reference. receiver = receiver() if receiver is not None: - non_weak_receivers.append(receiver) + if is_async: + non_weak_async_receivers.append(receiver) + else: + non_weak_sync_receivers.append(receiver) else: - non_weak_receivers.append(receiver) - return non_weak_receivers + if is_async: + non_weak_async_receivers.append(receiver) + else: + non_weak_sync_receivers.append(receiver) + return non_weak_sync_receivers, non_weak_async_receivers def _remove_receiver(self, receiver=None): # Mark that the self.receivers list has dead weakrefs. If so, we will |
