summaryrefslogtreecommitdiff
path: root/tests/auth_tests
diff options
context:
space:
mode:
authorJon Janzen <jon@jonjanzen.com>2024-03-31 12:29:10 -0700
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-10-07 14:19:41 +0200
commit50f89ae850f6b4e35819fe725a08c7e579bfd099 (patch)
tree856a0e954e0be928c55f6070f2ac8b766459b3e7 /tests/auth_tests
parent4cad317ff1f9a79d54c1d5b12f1ccbd260ca009f (diff)
Fixed #35303 -- Implemented async auth backends and utils.
Diffstat (limited to 'tests/auth_tests')
-rw-r--r--tests/auth_tests/models/custom_user.py13
-rw-r--r--tests/auth_tests/test_auth_backends.py324
-rw-r--r--tests/auth_tests/test_basic.py39
-rw-r--r--tests/auth_tests/test_decorators.py6
-rw-r--r--tests/auth_tests/test_models.py41
-rw-r--r--tests/auth_tests/test_remote_user.py180
6 files changed, 585 insertions, 18 deletions
diff --git a/tests/auth_tests/models/custom_user.py b/tests/auth_tests/models/custom_user.py
index b9938681ca..4586e452cd 100644
--- a/tests/auth_tests/models/custom_user.py
+++ b/tests/auth_tests/models/custom_user.py
@@ -29,6 +29,19 @@ class CustomUserManager(BaseUserManager):
user.save(using=self._db)
return user
+ async def acreate_user(self, email, date_of_birth, password=None, **fields):
+ """See create_user()"""
+ if not email:
+ raise ValueError("Users must have an email address")
+
+ user = self.model(
+ email=self.normalize_email(email), date_of_birth=date_of_birth, **fields
+ )
+
+ user.set_password(password)
+ await user.asave(using=self._db)
+ return user
+
def create_superuser(self, email, password, date_of_birth, **fields):
u = self.create_user(
email, password=password, date_of_birth=date_of_birth, **fields
diff --git a/tests/auth_tests/test_auth_backends.py b/tests/auth_tests/test_auth_backends.py
index 3b4f40e6e0..b612d27ab0 100644
--- a/tests/auth_tests/test_auth_backends.py
+++ b/tests/auth_tests/test_auth_backends.py
@@ -2,6 +2,8 @@ import sys
from datetime import date
from unittest import mock
+from asgiref.sync import sync_to_async
+
from django.contrib.auth import (
BACKEND_SESSION_KEY,
SESSION_KEY,
@@ -55,17 +57,33 @@ class BaseBackendTest(TestCase):
def test_get_user_permissions(self):
self.assertEqual(self.user.get_user_permissions(), {"user_perm"})
+ async def test_aget_user_permissions(self):
+ self.assertEqual(await self.user.aget_user_permissions(), {"user_perm"})
+
def test_get_group_permissions(self):
self.assertEqual(self.user.get_group_permissions(), {"group_perm"})
+ async def test_aget_group_permissions(self):
+ self.assertEqual(await self.user.aget_group_permissions(), {"group_perm"})
+
def test_get_all_permissions(self):
self.assertEqual(self.user.get_all_permissions(), {"user_perm", "group_perm"})
+ async def test_aget_all_permissions(self):
+ self.assertEqual(
+ await self.user.aget_all_permissions(), {"user_perm", "group_perm"}
+ )
+
def test_has_perm(self):
self.assertIs(self.user.has_perm("user_perm"), True)
self.assertIs(self.user.has_perm("group_perm"), True)
self.assertIs(self.user.has_perm("other_perm", TestObj()), False)
+ async def test_ahas_perm(self):
+ self.assertIs(await self.user.ahas_perm("user_perm"), True)
+ self.assertIs(await self.user.ahas_perm("group_perm"), True)
+ self.assertIs(await self.user.ahas_perm("other_perm", TestObj()), False)
+
def test_has_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
@@ -73,6 +91,13 @@ class BaseBackendTest(TestCase):
with self.assertRaisesMessage(ValueError, msg):
self.user.has_perms(object())
+ async def test_ahas_perms_perm_list_invalid(self):
+ msg = "perm_list must be an iterable of permissions."
+ with self.assertRaisesMessage(ValueError, msg):
+ await self.user.ahas_perms("user_perm")
+ with self.assertRaisesMessage(ValueError, msg):
+ await self.user.ahas_perms(object())
+
class CountingMD5PasswordHasher(MD5PasswordHasher):
"""Hasher that counts how many times it computes a hash."""
@@ -125,6 +150,25 @@ class BaseModelBackendTest:
user.save()
self.assertIs(user.has_perm("auth.test"), False)
+ async def test_ahas_perm(self):
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ self.assertIs(await user.ahas_perm("auth.test"), False)
+
+ user.is_staff = True
+ await user.asave()
+ self.assertIs(await user.ahas_perm("auth.test"), False)
+
+ user.is_superuser = True
+ await user.asave()
+ self.assertIs(await user.ahas_perm("auth.test"), True)
+ self.assertIs(await user.ahas_module_perms("auth"), True)
+
+ user.is_staff = True
+ user.is_superuser = True
+ user.is_active = False
+ await user.asave()
+ self.assertIs(await user.ahas_perm("auth.test"), False)
+
def test_custom_perms(self):
user = self.UserModel._default_manager.get(pk=self.user.pk)
content_type = ContentType.objects.get_for_model(Group)
@@ -174,6 +218,55 @@ class BaseModelBackendTest:
self.assertIs(user.has_perm("test"), False)
self.assertIs(user.has_perms(["auth.test2", "auth.test3"]), False)
+ async def test_acustom_perms(self):
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
+ perm = await Permission.objects.acreate(
+ name="test", content_type=content_type, codename="test"
+ )
+ await user.user_permissions.aadd(perm)
+
+ # Reloading user to purge the _perm_cache.
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ self.assertEqual(await user.aget_all_permissions(), {"auth.test"})
+ self.assertEqual(await user.aget_user_permissions(), {"auth.test"})
+ self.assertEqual(await user.aget_group_permissions(), set())
+ self.assertIs(await user.ahas_module_perms("Group"), False)
+ self.assertIs(await user.ahas_module_perms("auth"), True)
+
+ perm = await Permission.objects.acreate(
+ name="test2", content_type=content_type, codename="test2"
+ )
+ await user.user_permissions.aadd(perm)
+ perm = await Permission.objects.acreate(
+ name="test3", content_type=content_type, codename="test3"
+ )
+ await user.user_permissions.aadd(perm)
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ expected_user_perms = {"auth.test2", "auth.test", "auth.test3"}
+ self.assertEqual(await user.aget_all_permissions(), expected_user_perms)
+ self.assertIs(await user.ahas_perm("test"), False)
+ self.assertIs(await user.ahas_perm("auth.test"), True)
+ self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), True)
+
+ perm = await Permission.objects.acreate(
+ name="test_group", content_type=content_type, codename="test_group"
+ )
+ group = await Group.objects.acreate(name="test_group")
+ await group.permissions.aadd(perm)
+ await user.groups.aadd(group)
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ self.assertEqual(
+ await user.aget_all_permissions(), {*expected_user_perms, "auth.test_group"}
+ )
+ self.assertEqual(await user.aget_user_permissions(), expected_user_perms)
+ self.assertEqual(await user.aget_group_permissions(), {"auth.test_group"})
+ self.assertIs(await user.ahas_perms(["auth.test3", "auth.test_group"]), True)
+
+ user = AnonymousUser()
+ self.assertIs(await user.ahas_perm("test"), False)
+ self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), False)
+
def test_has_no_object_perm(self):
"""Regressiontest for #12462"""
user = self.UserModel._default_manager.get(pk=self.user.pk)
@@ -188,6 +281,20 @@ class BaseModelBackendTest:
self.assertIs(user.has_perm("auth.test"), True)
self.assertEqual(user.get_all_permissions(), {"auth.test"})
+ async def test_ahas_no_object_perm(self):
+ """See test_has_no_object_perm()"""
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
+ perm = await Permission.objects.acreate(
+ name="test", content_type=content_type, codename="test"
+ )
+ await user.user_permissions.aadd(perm)
+
+ self.assertIs(await user.ahas_perm("auth.test", "object"), False)
+ self.assertEqual(await user.aget_all_permissions("object"), set())
+ self.assertIs(await user.ahas_perm("auth.test"), True)
+ self.assertEqual(await user.aget_all_permissions(), {"auth.test"})
+
def test_anonymous_has_no_permissions(self):
"""
#17903 -- Anonymous users shouldn't have permissions in
@@ -220,6 +327,38 @@ class BaseModelBackendTest:
self.assertEqual(backend.get_user_permissions(user), set())
self.assertEqual(backend.get_group_permissions(user), set())
+ async def test_aanonymous_has_no_permissions(self):
+ """See test_anonymous_has_no_permissions()"""
+ backend = ModelBackend()
+
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
+ user_perm = await Permission.objects.acreate(
+ name="test", content_type=content_type, codename="test_user"
+ )
+ group_perm = await Permission.objects.acreate(
+ name="test2", content_type=content_type, codename="test_group"
+ )
+ await user.user_permissions.aadd(user_perm)
+
+ group = await Group.objects.acreate(name="test_group")
+ await user.groups.aadd(group)
+ await group.permissions.aadd(group_perm)
+
+ self.assertEqual(
+ await backend.aget_all_permissions(user),
+ {"auth.test_user", "auth.test_group"},
+ )
+ self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"})
+ self.assertEqual(
+ await backend.aget_group_permissions(user), {"auth.test_group"}
+ )
+
+ with mock.patch.object(self.UserModel, "is_anonymous", True):
+ self.assertEqual(await backend.aget_all_permissions(user), set())
+ self.assertEqual(await backend.aget_user_permissions(user), set())
+ self.assertEqual(await backend.aget_group_permissions(user), set())
+
def test_inactive_has_no_permissions(self):
"""
#17903 -- Inactive users shouldn't have permissions in
@@ -254,11 +393,52 @@ class BaseModelBackendTest:
self.assertEqual(backend.get_user_permissions(user), set())
self.assertEqual(backend.get_group_permissions(user), set())
+ async def test_ainactive_has_no_permissions(self):
+ """See test_inactive_has_no_permissions()"""
+ backend = ModelBackend()
+
+ user = await self.UserModel._default_manager.aget(pk=self.user.pk)
+ content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
+ user_perm = await Permission.objects.acreate(
+ name="test", content_type=content_type, codename="test_user"
+ )
+ group_perm = await Permission.objects.acreate(
+ name="test2", content_type=content_type, codename="test_group"
+ )
+ await user.user_permissions.aadd(user_perm)
+
+ group = await Group.objects.acreate(name="test_group")
+ await user.groups.aadd(group)
+ await group.permissions.aadd(group_perm)
+
+ self.assertEqual(
+ await backend.aget_all_permissions(user),
+ {"auth.test_user", "auth.test_group"},
+ )
+ self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"})
+ self.assertEqual(
+ await backend.aget_group_permissions(user), {"auth.test_group"}
+ )
+
+ user.is_active = False
+ await user.asave()
+
+ self.assertEqual(await backend.aget_all_permissions(user), set())
+ self.assertEqual(await backend.aget_user_permissions(user), set())
+ self.assertEqual(await backend.aget_group_permissions(user), set())
+
def test_get_all_superuser_permissions(self):
"""A superuser has all permissions. Refs #14795."""
user = self.UserModel._default_manager.get(pk=self.superuser.pk)
self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all()))
+ async def test_aget_all_superuser_permissions(self):
+ """See test_get_all_superuser_permissions()"""
+ user = await self.UserModel._default_manager.aget(pk=self.superuser.pk)
+ self.assertEqual(
+ len(await user.aget_all_permissions()), await Permission.objects.acount()
+ )
+
@override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
)
@@ -280,6 +460,24 @@ class BaseModelBackendTest:
@override_settings(
PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
)
+ async def test_aauthentication_timing(self):
+ """See test_authentication_timing()"""
+ # Re-set the password, because this tests overrides PASSWORD_HASHERS.
+ self.user.set_password("test")
+ await self.user.asave()
+
+ CountingMD5PasswordHasher.calls = 0
+ username = getattr(self.user, self.UserModel.USERNAME_FIELD)
+ await aauthenticate(username=username, password="test")
+ self.assertEqual(CountingMD5PasswordHasher.calls, 1)
+
+ CountingMD5PasswordHasher.calls = 0
+ await aauthenticate(username="no_such_user", password="test")
+ self.assertEqual(CountingMD5PasswordHasher.calls, 1)
+
+ @override_settings(
+ PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"]
+ )
def test_authentication_without_credentials(self):
CountingMD5PasswordHasher.calls = 0
for credentials in (
@@ -320,6 +518,15 @@ class ModelBackendTest(BaseModelBackendTest, TestCase):
self.user.save()
self.assertIsNone(authenticate(**self.user_credentials))
+ async def test_aauthenticate_inactive(self):
+ """
+ An inactive user can't authenticate.
+ """
+ self.assertEqual(await aauthenticate(**self.user_credentials), self.user)
+ self.user.is_active = False
+ await self.user.asave()
+ self.assertIsNone(await aauthenticate(**self.user_credentials))
+
@override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField")
def test_authenticate_user_without_is_active_field(self):
"""
@@ -332,6 +539,18 @@ class ModelBackendTest(BaseModelBackendTest, TestCase):
)
self.assertEqual(authenticate(username="test", password="test"), user)
+ @override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField")
+ async def test_aauthenticate_user_without_is_active_field(self):
+ """
+ A custom user without an `is_active` field is allowed to authenticate.
+ """
+ user = await CustomUserWithoutIsActiveField.objects._acreate_user(
+ username="test",
+ email="test@example.com",
+ password="test",
+ )
+ self.assertEqual(await aauthenticate(username="test", password="test"), user)
+
@override_settings(AUTH_USER_MODEL="auth_tests.ExtensionUser")
class ExtensionUserModelBackendTest(BaseModelBackendTest, TestCase):
@@ -403,6 +622,15 @@ class CustomUserModelBackendAuthenticateTest(TestCase):
authenticated_user = authenticate(email="test@example.com", password="test")
self.assertEqual(test_user, authenticated_user)
+ async def test_aauthenticate(self):
+ test_user = await CustomUser._default_manager.acreate_user(
+ email="test@example.com", password="test", date_of_birth=date(2006, 4, 25)
+ )
+ authenticated_user = await aauthenticate(
+ email="test@example.com", password="test"
+ )
+ self.assertEqual(test_user, authenticated_user)
+
@override_settings(AUTH_USER_MODEL="auth_tests.UUIDUser")
class UUIDUserTests(TestCase):
@@ -416,6 +644,13 @@ class UUIDUserTests(TestCase):
UUIDUser.objects.get(pk=self.client.session[SESSION_KEY]), user
)
+ async def test_alogin(self):
+ """See test_login()"""
+ user = await UUIDUser.objects.acreate_user(username="uuid", password="test")
+ self.assertTrue(await self.client.alogin(username="uuid", password="test"))
+ session_key = await self.client.session.aget(SESSION_KEY)
+ self.assertEqual(await UUIDUser.objects.aget(pk=session_key), user)
+
class TestObj:
pass
@@ -435,9 +670,15 @@ class SimpleRowlevelBackend:
return True
return False
+ async def ahas_perm(self, user, perm, obj=None):
+ return self.has_perm(user, perm, obj)
+
def has_module_perms(self, user, app_label):
return (user.is_anonymous or user.is_active) and app_label == "app1"
+ async def ahas_module_perms(self, user, app_label):
+ return self.has_module_perms(user, app_label)
+
def get_all_permissions(self, user, obj=None):
if not obj:
return [] # We only support row level perms
@@ -452,6 +693,9 @@ class SimpleRowlevelBackend:
else:
return ["simple"]
+ async def aget_all_permissions(self, user, obj=None):
+ return self.get_all_permissions(user, obj)
+
def get_group_permissions(self, user, obj=None):
if not obj:
return # We only support row level perms
@@ -524,10 +768,18 @@ class AnonymousUserBackendTest(SimpleTestCase):
self.assertIs(self.user1.has_perm("perm", TestObj()), False)
self.assertIs(self.user1.has_perm("anon", TestObj()), True)
+ async def test_ahas_perm(self):
+ self.assertIs(await self.user1.ahas_perm("perm", TestObj()), False)
+ self.assertIs(await self.user1.ahas_perm("anon", TestObj()), True)
+
def test_has_perms(self):
self.assertIs(self.user1.has_perms(["anon"], TestObj()), True)
self.assertIs(self.user1.has_perms(["anon", "perm"], TestObj()), False)
+ async def test_ahas_perms(self):
+ self.assertIs(await self.user1.ahas_perms(["anon"], TestObj()), True)
+ self.assertIs(await self.user1.ahas_perms(["anon", "perm"], TestObj()), False)
+
def test_has_perms_perm_list_invalid(self):
msg = "perm_list must be an iterable of permissions."
with self.assertRaisesMessage(ValueError, msg):
@@ -535,13 +787,27 @@ class AnonymousUserBackendTest(SimpleTestCase):
with self.assertRaisesMessage(ValueError, msg):
self.user1.has_perms(object())
+ async def test_ahas_perms_perm_list_invalid(self):
+ msg = "perm_list must be an iterable of permissions."
+ with self.assertRaisesMessage(ValueError, msg):
+ await self.user1.ahas_perms("perm")
+ with self.assertRaisesMessage(ValueError, msg):
+ await self.user1.ahas_perms(object())
+
def test_has_module_perms(self):
self.assertIs(self.user1.has_module_perms("app1"), True)
self.assertIs(self.user1.has_module_perms("app2"), False)
+ async def test_ahas_module_perms(self):
+ self.assertIs(await self.user1.ahas_module_perms("app1"), True)
+ self.assertIs(await self.user1.ahas_module_perms("app2"), False)
+
def test_get_all_permissions(self):
self.assertEqual(self.user1.get_all_permissions(TestObj()), {"anon"})
+ async def test_aget_all_permissions(self):
+ self.assertEqual(await self.user1.aget_all_permissions(TestObj()), {"anon"})
+
@override_settings(AUTHENTICATION_BACKENDS=[])
class NoBackendsTest(TestCase):
@@ -561,6 +827,14 @@ class NoBackendsTest(TestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg):
self.user.has_perm(("perm", TestObj()))
+ async def test_araises_exception(self):
+ msg = (
+ "No authentication backends have been defined. "
+ "Does AUTHENTICATION_BACKENDS contain anything?"
+ )
+ with self.assertRaisesMessage(ImproperlyConfigured, msg):
+ await self.user.ahas_perm(("perm", TestObj()))
+
@override_settings(
AUTHENTICATION_BACKENDS=["auth_tests.test_auth_backends.SimpleRowlevelBackend"]
@@ -593,12 +867,21 @@ class PermissionDeniedBackend:
def authenticate(self, request, username=None, password=None):
raise PermissionDenied
+ async def aauthenticate(self, request, username=None, password=None):
+ raise PermissionDenied
+
def has_perm(self, user_obj, perm, obj=None):
raise PermissionDenied
+ async def ahas_perm(self, user_obj, perm, obj=None):
+ raise PermissionDenied
+
def has_module_perms(self, user_obj, app_label):
raise PermissionDenied
+ async def ahas_module_perms(self, user_obj, app_label):
+ raise PermissionDenied
+
class PermissionDeniedBackendTest(TestCase):
"""
@@ -631,10 +914,25 @@ class PermissionDeniedBackendTest(TestCase):
[{"password": "********************", "username": "test"}],
)
+ @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
+ async def test_aauthenticate_permission_denied(self):
+ self.assertIsNone(await aauthenticate(username="test", password="test"))
+ # user_login_failed signal is sent.
+ self.assertEqual(
+ self.user_login_failed,
+ [{"password": "********************", "username": "test"}],
+ )
+
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
def test_authenticates(self):
self.assertEqual(authenticate(username="test", password="test"), self.user1)
+ @modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
+ async def test_aauthenticate(self):
+ self.assertEqual(
+ await aauthenticate(username="test", password="test"), self.user1
+ )
+
@modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
def test_has_perm_denied(self):
content_type = ContentType.objects.get_for_model(Group)
@@ -646,6 +944,17 @@ class PermissionDeniedBackendTest(TestCase):
self.assertIs(self.user1.has_perm("auth.test"), False)
self.assertIs(self.user1.has_module_perms("auth"), False)
+ @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend})
+ async def test_ahas_perm_denied(self):
+ content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
+ perm = await Permission.objects.acreate(
+ name="test", content_type=content_type, codename="test"
+ )
+ await self.user1.user_permissions.aadd(perm)
+
+ self.assertIs(await self.user1.ahas_perm("auth.test"), False)
+ self.assertIs(await self.user1.ahas_module_perms("auth"), False)
+
@modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
def test_has_perm(self):
content_type = ContentType.objects.get_for_model(Group)
@@ -657,6 +966,17 @@ class PermissionDeniedBackendTest(TestCase):
self.assertIs(self.user1.has_perm("auth.test"), True)
self.assertIs(self.user1.has_module_perms("auth"), True)
+ @modify_settings(AUTHENTICATION_BACKENDS={"append": backend})
+ async def test_ahas_perm(self):
+ content_type = await sync_to_async(ContentType.objects.get_for_model)(Group)
+ perm = await Permission.objects.acreate(
+ name="test", content_type=content_type, codename="test"
+ )
+ await self.user1.user_permissions.aadd(perm)
+
+ self.assertIs(await self.user1.ahas_perm("auth.test"), True)
+ self.assertIs(await self.user1.ahas_module_perms("auth"), True)
+
class NewModelBackend(ModelBackend):
pass
@@ -715,6 +1035,10 @@ class TypeErrorBackend:
def authenticate(self, request, username=None, password=None):
raise TypeError
+ @sensitive_variables("password")
+ async def aauthenticate(self, request, username=None, password=None):
+ raise TypeError
+
class SkippedBackend:
def authenticate(self):
diff --git a/tests/auth_tests/test_basic.py b/tests/auth_tests/test_basic.py
index d7a7750b54..8d54e187fc 100644
--- a/tests/auth_tests/test_basic.py
+++ b/tests/auth_tests/test_basic.py
@@ -1,5 +1,3 @@
-from asgiref.sync import sync_to_async
-
from django.conf import settings
from django.contrib.auth import aget_user, get_user, get_user_model
from django.contrib.auth.models import AnonymousUser, User
@@ -44,6 +42,12 @@ class BasicTestCase(TestCase):
u2 = User.objects.create_user("testuser2", "test2@example.com")
self.assertFalse(u2.has_usable_password())
+ async def test_acreate(self):
+ u = await User.objects.acreate_user("testuser", "test@example.com", "testpw")
+ self.assertTrue(u.has_usable_password())
+ self.assertFalse(await u.acheck_password("bad"))
+ self.assertTrue(await u.acheck_password("testpw"))
+
def test_unicode_username(self):
User.objects.create_user("jörg")
User.objects.create_user("Григорий")
@@ -73,6 +77,15 @@ class BasicTestCase(TestCase):
self.assertTrue(super.is_active)
self.assertTrue(super.is_staff)
+ async def test_asuperuser(self):
+ "Check the creation and properties of a superuser"
+ super = await User.objects.acreate_superuser(
+ "super", "super@example.com", "super"
+ )
+ self.assertTrue(super.is_superuser)
+ self.assertTrue(super.is_active)
+ self.assertTrue(super.is_staff)
+
def test_superuser_no_email_or_password(self):
cases = [
{},
@@ -171,13 +184,25 @@ class TestGetUser(TestCase):
self.assertIsInstance(user, User)
self.assertEqual(user.username, created_user.username)
- async def test_aget_user(self):
- created_user = await sync_to_async(User.objects.create_user)(
+ async def test_aget_user_fallback_secret(self):
+ created_user = await User.objects.acreate_user(
"testuser", "test@example.com", "testpw"
)
await self.client.alogin(username="testuser", password="testpw")
request = HttpRequest()
request.session = await self.client.asession()
- user = await aget_user(request)
- self.assertIsInstance(user, User)
- self.assertEqual(user.username, created_user.username)
+ prev_session_key = request.session.session_key
+ with override_settings(
+ SECRET_KEY="newsecret",
+ SECRET_KEY_FALLBACKS=[settings.SECRET_KEY],
+ ):
+ user = await aget_user(request)
+ self.assertIsInstance(user, User)
+ self.assertEqual(user.username, created_user.username)
+ self.assertNotEqual(request.session.session_key, prev_session_key)
+ # Remove the fallback secret.
+ # The session hash should be updated using the current secret.
+ with override_settings(SECRET_KEY="newsecret"):
+ user = await aget_user(request)
+ self.assertIsInstance(user, User)
+ self.assertEqual(user.username, created_user.username)
diff --git a/tests/auth_tests/test_decorators.py b/tests/auth_tests/test_decorators.py
index e585b28bd5..fa2672beb4 100644
--- a/tests/auth_tests/test_decorators.py
+++ b/tests/auth_tests/test_decorators.py
@@ -1,7 +1,5 @@
from asyncio import iscoroutinefunction
-from asgiref.sync import sync_to_async
-
from django.conf import settings
from django.contrib.auth import models
from django.contrib.auth.decorators import (
@@ -374,7 +372,7 @@ class UserPassesTestDecoratorTest(TestCase):
def test_decorator_async_test_func(self):
async def async_test_func(user):
- return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
+ return await user.ahas_perms(["auth_tests.add_customuser"])
@user_passes_test(async_test_func)
def sync_view(request):
@@ -410,7 +408,7 @@ class UserPassesTestDecoratorTest(TestCase):
async def test_decorator_async_view_async_test_func(self):
async def async_test_func(user):
- return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
+ return await user.ahas_perms(["auth_tests.add_customuser"])
@user_passes_test(async_test_func)
async def async_view(request):
diff --git a/tests/auth_tests/test_models.py b/tests/auth_tests/test_models.py
index 983424843c..a3e7a3205b 100644
--- a/tests/auth_tests/test_models.py
+++ b/tests/auth_tests/test_models.py
@@ -1,7 +1,5 @@
from unittest import mock
-from asgiref.sync import sync_to_async
-
from django.conf.global_settings import PASSWORD_HASHERS
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
@@ -30,10 +28,19 @@ class NaturalKeysTestCase(TestCase):
self.assertEqual(User.objects.get_by_natural_key("staff"), staff_user)
self.assertEqual(staff_user.natural_key(), ("staff",))
+ async def test_auser_natural_key(self):
+ staff_user = await User.objects.acreate_user(username="staff")
+ self.assertEqual(await User.objects.aget_by_natural_key("staff"), staff_user)
+ self.assertEqual(staff_user.natural_key(), ("staff",))
+
def test_group_natural_key(self):
users_group = Group.objects.create(name="users")
self.assertEqual(Group.objects.get_by_natural_key("users"), users_group)
+ async def test_agroup_natural_key(self):
+ users_group = await Group.objects.acreate(name="users")
+ self.assertEqual(await Group.objects.aget_by_natural_key("users"), users_group)
+
class LoadDataWithoutNaturalKeysTestCase(TestCase):
fixtures = ["regular.json"]
@@ -157,6 +164,17 @@ class UserManagerTestCase(TransactionTestCase):
is_superuser=False,
)
+ async def test_acreate_super_user_raises_error_on_false_is_superuser(self):
+ with self.assertRaisesMessage(
+ ValueError, "Superuser must have is_superuser=True."
+ ):
+ await User.objects.acreate_superuser(
+ username="test",
+ email="test@test.com",
+ password="test",
+ is_superuser=False,
+ )
+
def test_create_superuser_raises_error_on_false_is_staff(self):
with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."):
User.objects.create_superuser(
@@ -166,6 +184,15 @@ class UserManagerTestCase(TransactionTestCase):
is_staff=False,
)
+ async def test_acreate_superuser_raises_error_on_false_is_staff(self):
+ with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."):
+ await User.objects.acreate_superuser(
+ username="test",
+ email="test@test.com",
+ password="test",
+ is_staff=False,
+ )
+
def test_runpython_manager_methods(self):
def forwards(apps, schema_editor):
UserModel = apps.get_model("auth", "User")
@@ -301,9 +328,7 @@ class AbstractUserTestCase(TestCase):
@override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS)
async def test_acheck_password_upgrade(self):
- user = await sync_to_async(User.objects.create_user)(
- username="user", password="foo"
- )
+ user = await User.objects.acreate_user(username="user", password="foo")
initial_password = user.password
self.assertIs(await user.acheck_password("foo"), True)
hasher = get_hasher("default")
@@ -557,6 +582,12 @@ class AnonymousUserTests(SimpleTestCase):
self.assertEqual(self.user.get_user_permissions(), set())
self.assertEqual(self.user.get_group_permissions(), set())
+ async def test_properties_async_versions(self):
+ self.assertEqual(await self.user.groups.acount(), 0)
+ self.assertEqual(await self.user.user_permissions.acount(), 0)
+ self.assertEqual(await self.user.aget_user_permissions(), set())
+ self.assertEqual(await self.user.aget_group_permissions(), set())
+
def test_str(self):
self.assertEqual(str(self.user), "AnonymousUser")
diff --git a/tests/auth_tests/test_remote_user.py b/tests/auth_tests/test_remote_user.py
index d3cf4b9da5..85de931c1a 100644
--- a/tests/auth_tests/test_remote_user.py
+++ b/tests/auth_tests/test_remote_user.py
@@ -1,12 +1,18 @@
from datetime import datetime, timezone
from django.conf import settings
-from django.contrib.auth import authenticate
+from django.contrib.auth import aauthenticate, authenticate
from django.contrib.auth.backends import RemoteUserBackend
from django.contrib.auth.middleware import RemoteUserMiddleware
from django.contrib.auth.models import User
from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret
-from django.test import Client, TestCase, modify_settings, override_settings
+from django.test import (
+ AsyncClient,
+ Client,
+ TestCase,
+ modify_settings,
+ override_settings,
+)
@override_settings(ROOT_URLCONF="auth_tests.urls")
@@ -30,6 +36,11 @@ class RemoteUserTest(TestCase):
)
super().setUpClass()
+ def test_passing_explicit_none(self):
+ msg = "get_response must be provided."
+ with self.assertRaisesMessage(ValueError, msg):
+ RemoteUserMiddleware(None)
+
def test_no_remote_user(self):
"""Users are not created when remote user is not specified."""
num_users = User.objects.count()
@@ -46,6 +57,18 @@ class RemoteUserTest(TestCase):
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(User.objects.count(), num_users)
+ async def test_no_remote_user_async(self):
+ """See test_no_remote_user."""
+ num_users = await User.objects.acount()
+
+ response = await self.async_client.get("/remote_user/")
+ self.assertTrue(response.context["user"].is_anonymous)
+ self.assertEqual(await User.objects.acount(), num_users)
+
+ response = await self.async_client.get("/remote_user/", **{self.header: ""})
+ self.assertTrue(response.context["user"].is_anonymous)
+ self.assertEqual(await User.objects.acount(), num_users)
+
def test_csrf_validation_passes_after_process_request_login(self):
"""
CSRF check must access the CSRF token from the session or cookie,
@@ -75,6 +98,31 @@ class RemoteUserTest(TestCase):
response = csrf_client.post("/remote_user/", data, **headers)
self.assertEqual(response.status_code, 200)
+ async def test_csrf_validation_passes_after_process_request_login_async(self):
+ """See test_csrf_validation_passes_after_process_request_login."""
+ csrf_client = AsyncClient(enforce_csrf_checks=True)
+ csrf_secret = _get_new_csrf_string()
+ csrf_token = _mask_cipher_secret(csrf_secret)
+ csrf_token_form = _mask_cipher_secret(csrf_secret)
+ headers = {self.header: "fakeuser"}
+ data = {"csrfmiddlewaretoken": csrf_token_form}
+
+ # Verify that CSRF is configured for the view
+ csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
+ response = await csrf_client.post("/remote_user/", **headers)
+ self.assertEqual(response.status_code, 403)
+ self.assertIn(b"CSRF verification failed.", response.content)
+
+ # This request will call django.contrib.auth.alogin() which will call
+ # django.middleware.csrf.rotate_token() thus changing the value of
+ # request.META['CSRF_COOKIE'] from the user submitted value set by
+ # CsrfViewMiddleware.process_request() to the new csrftoken value set
+ # by rotate_token(). Csrf validation should still pass when the view is
+ # later processed by CsrfViewMiddleware.process_view()
+ csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token})
+ response = await csrf_client.post("/remote_user/", data, **headers)
+ self.assertEqual(response.status_code, 200)
+
def test_unknown_user(self):
"""
Tests the case where the username passed in the header does not exist
@@ -90,6 +138,22 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/", **{self.header: "newuser"})
self.assertEqual(User.objects.count(), num_users + 1)
+ async def test_unknown_user_async(self):
+ """See test_unknown_user."""
+ num_users = await User.objects.acount()
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: "newuser"}
+ )
+ self.assertEqual(response.context["user"].username, "newuser")
+ self.assertEqual(await User.objects.acount(), num_users + 1)
+ await User.objects.aget(username="newuser")
+
+ # Another request with same user should not create any new users.
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: "newuser"}
+ )
+ self.assertEqual(await User.objects.acount(), num_users + 1)
+
def test_known_user(self):
"""
Tests the case where the username passed in the header is a valid User.
@@ -106,6 +170,24 @@ class RemoteUserTest(TestCase):
self.assertEqual(response.context["user"].username, "knownuser2")
self.assertEqual(User.objects.count(), num_users)
+ async def test_known_user_async(self):
+ """See test_known_user."""
+ await User.objects.acreate(username="knownuser")
+ await User.objects.acreate(username="knownuser2")
+ num_users = await User.objects.acount()
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertEqual(response.context["user"].username, "knownuser")
+ self.assertEqual(await User.objects.acount(), num_users)
+ # A different user passed in the headers causes the new user
+ # to be logged in.
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user2}
+ )
+ self.assertEqual(response.context["user"].username, "knownuser2")
+ self.assertEqual(await User.objects.acount(), num_users)
+
def test_last_login(self):
"""
A user's last_login is set the first time they make a
@@ -128,6 +210,29 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/", **{self.header: self.known_user})
self.assertEqual(default_login, response.context["user"].last_login)
+ async def test_last_login_async(self):
+ """See test_last_login."""
+ user = await User.objects.acreate(username="knownuser")
+ # Set last_login to something so we can determine if it changes.
+ default_login = datetime(2000, 1, 1)
+ if settings.USE_TZ:
+ default_login = default_login.replace(tzinfo=timezone.utc)
+ user.last_login = default_login
+ await user.asave()
+
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertNotEqual(default_login, response.context["user"].last_login)
+
+ user = await User.objects.aget(username="knownuser")
+ user.last_login = default_login
+ await user.asave()
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertEqual(default_login, response.context["user"].last_login)
+
def test_header_disappears(self):
"""
A logged in user is logged out automatically when
@@ -148,6 +253,25 @@ class RemoteUserTest(TestCase):
response = self.client.get("/remote_user/")
self.assertEqual(response.context["user"].username, "modeluser")
+ async def test_header_disappears_async(self):
+ """See test_header_disappears."""
+ await User.objects.acreate(username="knownuser")
+ # Known user authenticates
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertEqual(response.context["user"].username, "knownuser")
+ # During the session, the REMOTE_USER header disappears. Should trigger logout.
+ response = await self.async_client.get("/remote_user/")
+ self.assertTrue(response.context["user"].is_anonymous)
+ # verify the remoteuser middleware will not remove a user
+ # authenticated via another backend
+ await User.objects.acreate_user(username="modeluser", password="foo")
+ await self.async_client.alogin(username="modeluser", password="foo")
+ await aauthenticate(username="modeluser", password="foo")
+ response = await self.async_client.get("/remote_user/")
+ self.assertEqual(response.context["user"].username, "modeluser")
+
def test_user_switch_forces_new_login(self):
"""
If the username in the header changes between requests
@@ -164,11 +288,35 @@ class RemoteUserTest(TestCase):
# In backends that do not create new users, it is '' (anonymous user)
self.assertNotEqual(response.context["user"].username, "knownuser")
+ async def test_user_switch_forces_new_login_async(self):
+ """See test_user_switch_forces_new_login."""
+ await User.objects.acreate(username="knownuser")
+ # Known user authenticates
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertEqual(response.context["user"].username, "knownuser")
+ # During the session, the REMOTE_USER changes to a different user.
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: "newnewuser"}
+ )
+ # The current user is not the prior remote_user.
+ # In backends that create a new user, username is "newnewuser"
+ # In backends that do not create new users, it is '' (anonymous user)
+ self.assertNotEqual(response.context["user"].username, "knownuser")
+
def test_inactive_user(self):
User.objects.create(username="knownuser", is_active=False)
response = self.client.get("/remote_user/", **{self.header: "knownuser"})
self.assertTrue(response.context["user"].is_anonymous)
+ async def test_inactive_user_async(self):
+ await User.objects.acreate(username="knownuser", is_active=False)
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: "knownuser"}
+ )
+ self.assertTrue(response.context["user"].is_anonymous)
+
class RemoteUserNoCreateBackend(RemoteUserBackend):
"""Backend that doesn't create unknown users."""
@@ -190,6 +338,14 @@ class RemoteUserNoCreateTest(RemoteUserTest):
self.assertTrue(response.context["user"].is_anonymous)
self.assertEqual(User.objects.count(), num_users)
+ async def test_unknown_user_async(self):
+ num_users = await User.objects.acount()
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: "newuser"}
+ )
+ self.assertTrue(response.context["user"].is_anonymous)
+ self.assertEqual(await User.objects.acount(), num_users)
+
class AllowAllUsersRemoteUserBackendTest(RemoteUserTest):
"""Backend that allows inactive users."""
@@ -201,6 +357,13 @@ class AllowAllUsersRemoteUserBackendTest(RemoteUserTest):
response = self.client.get("/remote_user/", **{self.header: self.known_user})
self.assertEqual(response.context["user"].username, user.username)
+ async def test_inactive_user_async(self):
+ user = await User.objects.acreate(username="knownuser", is_active=False)
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertEqual(response.context["user"].username, user.username)
+
class CustomRemoteUserBackend(RemoteUserBackend):
"""
@@ -311,3 +474,16 @@ class PersistentRemoteUserTest(RemoteUserTest):
response = self.client.get("/remote_user/")
self.assertFalse(response.context["user"].is_anonymous)
self.assertEqual(response.context["user"].username, "knownuser")
+
+ async def test_header_disappears_async(self):
+ """See test_header_disappears."""
+ await User.objects.acreate(username="knownuser")
+ # Known user authenticates
+ response = await self.async_client.get(
+ "/remote_user/", **{self.header: self.known_user}
+ )
+ self.assertEqual(response.context["user"].username, "knownuser")
+ # Should stay logged in if the REMOTE_USER header disappears.
+ response = await self.async_client.get("/remote_user/")
+ self.assertFalse(response.context["user"].is_anonymous)
+ self.assertEqual(response.context["user"].username, "knownuser")