diff options
| -rw-r--r-- | django/contrib/auth/decorators.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/django/contrib/auth/decorators.py b/django/contrib/auth/decorators.py index 3407555852..d2e6845abe 100644 --- a/django/contrib/auth/decorators.py +++ b/django/contrib/auth/decorators.py @@ -36,25 +36,27 @@ def user_passes_test( return redirect_to_login(path, resolved_login_url, redirect_field_name) if iscoroutinefunction(view_func): + if iscoroutinefunction(test_func): + _async_test_func = test_func + else: + _async_test_func = sync_to_async(test_func) async def _view_wrapper(request, *args, **kwargs): auser = await request.auser() - if iscoroutinefunction(test_func): - test_pass = await test_func(auser) - else: - test_pass = await sync_to_async(test_func)(auser) + test_pass = await _async_test_func(auser) if test_pass: return await view_func(request, *args, **kwargs) return _redirect_to_login(request) else: + if iscoroutinefunction(test_func): + _sync_test_func = async_to_sync(test_func) + else: + _sync_test_func = test_func def _view_wrapper(request, *args, **kwargs): - if iscoroutinefunction(test_func): - test_pass = async_to_sync(test_func)(request.user) - else: - test_pass = test_func(request.user) + test_pass = _sync_test_func(request.user) if test_pass: return view_func(request, *args, **kwargs) |
