diff options
| author | Jon Janzen <jon@jonjanzen.com> | 2024-03-31 12:29:10 -0700 |
|---|---|---|
| committer | Sarah Boyce <42296566+sarahboyce@users.noreply.github.com> | 2024-10-07 14:19:41 +0200 |
| commit | 50f89ae850f6b4e35819fe725a08c7e579bfd099 (patch) | |
| tree | 856a0e954e0be928c55f6070f2ac8b766459b3e7 /tests/auth_tests | |
| parent | 4cad317ff1f9a79d54c1d5b12f1ccbd260ca009f (diff) | |
Fixed #35303 -- Implemented async auth backends and utils.
Diffstat (limited to 'tests/auth_tests')
| -rw-r--r-- | tests/auth_tests/models/custom_user.py | 13 | ||||
| -rw-r--r-- | tests/auth_tests/test_auth_backends.py | 324 | ||||
| -rw-r--r-- | tests/auth_tests/test_basic.py | 39 | ||||
| -rw-r--r-- | tests/auth_tests/test_decorators.py | 6 | ||||
| -rw-r--r-- | tests/auth_tests/test_models.py | 41 | ||||
| -rw-r--r-- | tests/auth_tests/test_remote_user.py | 180 |
6 files changed, 585 insertions, 18 deletions
diff --git a/tests/auth_tests/models/custom_user.py b/tests/auth_tests/models/custom_user.py index b9938681ca..4586e452cd 100644 --- a/tests/auth_tests/models/custom_user.py +++ b/tests/auth_tests/models/custom_user.py @@ -29,6 +29,19 @@ class CustomUserManager(BaseUserManager): user.save(using=self._db) return user + async def acreate_user(self, email, date_of_birth, password=None, **fields): + """See create_user()""" + if not email: + raise ValueError("Users must have an email address") + + user = self.model( + email=self.normalize_email(email), date_of_birth=date_of_birth, **fields + ) + + user.set_password(password) + await user.asave(using=self._db) + return user + def create_superuser(self, email, password, date_of_birth, **fields): u = self.create_user( email, password=password, date_of_birth=date_of_birth, **fields diff --git a/tests/auth_tests/test_auth_backends.py b/tests/auth_tests/test_auth_backends.py index 3b4f40e6e0..b612d27ab0 100644 --- a/tests/auth_tests/test_auth_backends.py +++ b/tests/auth_tests/test_auth_backends.py @@ -2,6 +2,8 @@ import sys from datetime import date from unittest import mock +from asgiref.sync import sync_to_async + from django.contrib.auth import ( BACKEND_SESSION_KEY, SESSION_KEY, @@ -55,17 +57,33 @@ class BaseBackendTest(TestCase): def test_get_user_permissions(self): self.assertEqual(self.user.get_user_permissions(), {"user_perm"}) + async def test_aget_user_permissions(self): + self.assertEqual(await self.user.aget_user_permissions(), {"user_perm"}) + def test_get_group_permissions(self): self.assertEqual(self.user.get_group_permissions(), {"group_perm"}) + async def test_aget_group_permissions(self): + self.assertEqual(await self.user.aget_group_permissions(), {"group_perm"}) + def test_get_all_permissions(self): self.assertEqual(self.user.get_all_permissions(), {"user_perm", "group_perm"}) + async def test_aget_all_permissions(self): + self.assertEqual( + await self.user.aget_all_permissions(), {"user_perm", "group_perm"} + ) + def test_has_perm(self): self.assertIs(self.user.has_perm("user_perm"), True) self.assertIs(self.user.has_perm("group_perm"), True) self.assertIs(self.user.has_perm("other_perm", TestObj()), False) + async def test_ahas_perm(self): + self.assertIs(await self.user.ahas_perm("user_perm"), True) + self.assertIs(await self.user.ahas_perm("group_perm"), True) + self.assertIs(await self.user.ahas_perm("other_perm", TestObj()), False) + def test_has_perms_perm_list_invalid(self): msg = "perm_list must be an iterable of permissions." with self.assertRaisesMessage(ValueError, msg): @@ -73,6 +91,13 @@ class BaseBackendTest(TestCase): with self.assertRaisesMessage(ValueError, msg): self.user.has_perms(object()) + async def test_ahas_perms_perm_list_invalid(self): + msg = "perm_list must be an iterable of permissions." + with self.assertRaisesMessage(ValueError, msg): + await self.user.ahas_perms("user_perm") + with self.assertRaisesMessage(ValueError, msg): + await self.user.ahas_perms(object()) + class CountingMD5PasswordHasher(MD5PasswordHasher): """Hasher that counts how many times it computes a hash.""" @@ -125,6 +150,25 @@ class BaseModelBackendTest: user.save() self.assertIs(user.has_perm("auth.test"), False) + async def test_ahas_perm(self): + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + self.assertIs(await user.ahas_perm("auth.test"), False) + + user.is_staff = True + await user.asave() + self.assertIs(await user.ahas_perm("auth.test"), False) + + user.is_superuser = True + await user.asave() + self.assertIs(await user.ahas_perm("auth.test"), True) + self.assertIs(await user.ahas_module_perms("auth"), True) + + user.is_staff = True + user.is_superuser = True + user.is_active = False + await user.asave() + self.assertIs(await user.ahas_perm("auth.test"), False) + def test_custom_perms(self): user = self.UserModel._default_manager.get(pk=self.user.pk) content_type = ContentType.objects.get_for_model(Group) @@ -174,6 +218,55 @@ class BaseModelBackendTest: self.assertIs(user.has_perm("test"), False) self.assertIs(user.has_perms(["auth.test2", "auth.test3"]), False) + async def test_acustom_perms(self): + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await user.user_permissions.aadd(perm) + + # Reloading user to purge the _perm_cache. + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + self.assertEqual(await user.aget_all_permissions(), {"auth.test"}) + self.assertEqual(await user.aget_user_permissions(), {"auth.test"}) + self.assertEqual(await user.aget_group_permissions(), set()) + self.assertIs(await user.ahas_module_perms("Group"), False) + self.assertIs(await user.ahas_module_perms("auth"), True) + + perm = await Permission.objects.acreate( + name="test2", content_type=content_type, codename="test2" + ) + await user.user_permissions.aadd(perm) + perm = await Permission.objects.acreate( + name="test3", content_type=content_type, codename="test3" + ) + await user.user_permissions.aadd(perm) + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + expected_user_perms = {"auth.test2", "auth.test", "auth.test3"} + self.assertEqual(await user.aget_all_permissions(), expected_user_perms) + self.assertIs(await user.ahas_perm("test"), False) + self.assertIs(await user.ahas_perm("auth.test"), True) + self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), True) + + perm = await Permission.objects.acreate( + name="test_group", content_type=content_type, codename="test_group" + ) + group = await Group.objects.acreate(name="test_group") + await group.permissions.aadd(perm) + await user.groups.aadd(group) + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + self.assertEqual( + await user.aget_all_permissions(), {*expected_user_perms, "auth.test_group"} + ) + self.assertEqual(await user.aget_user_permissions(), expected_user_perms) + self.assertEqual(await user.aget_group_permissions(), {"auth.test_group"}) + self.assertIs(await user.ahas_perms(["auth.test3", "auth.test_group"]), True) + + user = AnonymousUser() + self.assertIs(await user.ahas_perm("test"), False) + self.assertIs(await user.ahas_perms(["auth.test2", "auth.test3"]), False) + def test_has_no_object_perm(self): """Regressiontest for #12462""" user = self.UserModel._default_manager.get(pk=self.user.pk) @@ -188,6 +281,20 @@ class BaseModelBackendTest: self.assertIs(user.has_perm("auth.test"), True) self.assertEqual(user.get_all_permissions(), {"auth.test"}) + async def test_ahas_no_object_perm(self): + """See test_has_no_object_perm()""" + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await user.user_permissions.aadd(perm) + + self.assertIs(await user.ahas_perm("auth.test", "object"), False) + self.assertEqual(await user.aget_all_permissions("object"), set()) + self.assertIs(await user.ahas_perm("auth.test"), True) + self.assertEqual(await user.aget_all_permissions(), {"auth.test"}) + def test_anonymous_has_no_permissions(self): """ #17903 -- Anonymous users shouldn't have permissions in @@ -220,6 +327,38 @@ class BaseModelBackendTest: self.assertEqual(backend.get_user_permissions(user), set()) self.assertEqual(backend.get_group_permissions(user), set()) + async def test_aanonymous_has_no_permissions(self): + """See test_anonymous_has_no_permissions()""" + backend = ModelBackend() + + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + user_perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test_user" + ) + group_perm = await Permission.objects.acreate( + name="test2", content_type=content_type, codename="test_group" + ) + await user.user_permissions.aadd(user_perm) + + group = await Group.objects.acreate(name="test_group") + await user.groups.aadd(group) + await group.permissions.aadd(group_perm) + + self.assertEqual( + await backend.aget_all_permissions(user), + {"auth.test_user", "auth.test_group"}, + ) + self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"}) + self.assertEqual( + await backend.aget_group_permissions(user), {"auth.test_group"} + ) + + with mock.patch.object(self.UserModel, "is_anonymous", True): + self.assertEqual(await backend.aget_all_permissions(user), set()) + self.assertEqual(await backend.aget_user_permissions(user), set()) + self.assertEqual(await backend.aget_group_permissions(user), set()) + def test_inactive_has_no_permissions(self): """ #17903 -- Inactive users shouldn't have permissions in @@ -254,11 +393,52 @@ class BaseModelBackendTest: self.assertEqual(backend.get_user_permissions(user), set()) self.assertEqual(backend.get_group_permissions(user), set()) + async def test_ainactive_has_no_permissions(self): + """See test_inactive_has_no_permissions()""" + backend = ModelBackend() + + user = await self.UserModel._default_manager.aget(pk=self.user.pk) + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + user_perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test_user" + ) + group_perm = await Permission.objects.acreate( + name="test2", content_type=content_type, codename="test_group" + ) + await user.user_permissions.aadd(user_perm) + + group = await Group.objects.acreate(name="test_group") + await user.groups.aadd(group) + await group.permissions.aadd(group_perm) + + self.assertEqual( + await backend.aget_all_permissions(user), + {"auth.test_user", "auth.test_group"}, + ) + self.assertEqual(await backend.aget_user_permissions(user), {"auth.test_user"}) + self.assertEqual( + await backend.aget_group_permissions(user), {"auth.test_group"} + ) + + user.is_active = False + await user.asave() + + self.assertEqual(await backend.aget_all_permissions(user), set()) + self.assertEqual(await backend.aget_user_permissions(user), set()) + self.assertEqual(await backend.aget_group_permissions(user), set()) + def test_get_all_superuser_permissions(self): """A superuser has all permissions. Refs #14795.""" user = self.UserModel._default_manager.get(pk=self.superuser.pk) self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all())) + async def test_aget_all_superuser_permissions(self): + """See test_get_all_superuser_permissions()""" + user = await self.UserModel._default_manager.aget(pk=self.superuser.pk) + self.assertEqual( + len(await user.aget_all_permissions()), await Permission.objects.acount() + ) + @override_settings( PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] ) @@ -280,6 +460,24 @@ class BaseModelBackendTest: @override_settings( PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] ) + async def test_aauthentication_timing(self): + """See test_authentication_timing()""" + # Re-set the password, because this tests overrides PASSWORD_HASHERS. + self.user.set_password("test") + await self.user.asave() + + CountingMD5PasswordHasher.calls = 0 + username = getattr(self.user, self.UserModel.USERNAME_FIELD) + await aauthenticate(username=username, password="test") + self.assertEqual(CountingMD5PasswordHasher.calls, 1) + + CountingMD5PasswordHasher.calls = 0 + await aauthenticate(username="no_such_user", password="test") + self.assertEqual(CountingMD5PasswordHasher.calls, 1) + + @override_settings( + PASSWORD_HASHERS=["auth_tests.test_auth_backends.CountingMD5PasswordHasher"] + ) def test_authentication_without_credentials(self): CountingMD5PasswordHasher.calls = 0 for credentials in ( @@ -320,6 +518,15 @@ class ModelBackendTest(BaseModelBackendTest, TestCase): self.user.save() self.assertIsNone(authenticate(**self.user_credentials)) + async def test_aauthenticate_inactive(self): + """ + An inactive user can't authenticate. + """ + self.assertEqual(await aauthenticate(**self.user_credentials), self.user) + self.user.is_active = False + await self.user.asave() + self.assertIsNone(await aauthenticate(**self.user_credentials)) + @override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField") def test_authenticate_user_without_is_active_field(self): """ @@ -332,6 +539,18 @@ class ModelBackendTest(BaseModelBackendTest, TestCase): ) self.assertEqual(authenticate(username="test", password="test"), user) + @override_settings(AUTH_USER_MODEL="auth_tests.CustomUserWithoutIsActiveField") + async def test_aauthenticate_user_without_is_active_field(self): + """ + A custom user without an `is_active` field is allowed to authenticate. + """ + user = await CustomUserWithoutIsActiveField.objects._acreate_user( + username="test", + email="test@example.com", + password="test", + ) + self.assertEqual(await aauthenticate(username="test", password="test"), user) + @override_settings(AUTH_USER_MODEL="auth_tests.ExtensionUser") class ExtensionUserModelBackendTest(BaseModelBackendTest, TestCase): @@ -403,6 +622,15 @@ class CustomUserModelBackendAuthenticateTest(TestCase): authenticated_user = authenticate(email="test@example.com", password="test") self.assertEqual(test_user, authenticated_user) + async def test_aauthenticate(self): + test_user = await CustomUser._default_manager.acreate_user( + email="test@example.com", password="test", date_of_birth=date(2006, 4, 25) + ) + authenticated_user = await aauthenticate( + email="test@example.com", password="test" + ) + self.assertEqual(test_user, authenticated_user) + @override_settings(AUTH_USER_MODEL="auth_tests.UUIDUser") class UUIDUserTests(TestCase): @@ -416,6 +644,13 @@ class UUIDUserTests(TestCase): UUIDUser.objects.get(pk=self.client.session[SESSION_KEY]), user ) + async def test_alogin(self): + """See test_login()""" + user = await UUIDUser.objects.acreate_user(username="uuid", password="test") + self.assertTrue(await self.client.alogin(username="uuid", password="test")) + session_key = await self.client.session.aget(SESSION_KEY) + self.assertEqual(await UUIDUser.objects.aget(pk=session_key), user) + class TestObj: pass @@ -435,9 +670,15 @@ class SimpleRowlevelBackend: return True return False + async def ahas_perm(self, user, perm, obj=None): + return self.has_perm(user, perm, obj) + def has_module_perms(self, user, app_label): return (user.is_anonymous or user.is_active) and app_label == "app1" + async def ahas_module_perms(self, user, app_label): + return self.has_module_perms(user, app_label) + def get_all_permissions(self, user, obj=None): if not obj: return [] # We only support row level perms @@ -452,6 +693,9 @@ class SimpleRowlevelBackend: else: return ["simple"] + async def aget_all_permissions(self, user, obj=None): + return self.get_all_permissions(user, obj) + def get_group_permissions(self, user, obj=None): if not obj: return # We only support row level perms @@ -524,10 +768,18 @@ class AnonymousUserBackendTest(SimpleTestCase): self.assertIs(self.user1.has_perm("perm", TestObj()), False) self.assertIs(self.user1.has_perm("anon", TestObj()), True) + async def test_ahas_perm(self): + self.assertIs(await self.user1.ahas_perm("perm", TestObj()), False) + self.assertIs(await self.user1.ahas_perm("anon", TestObj()), True) + def test_has_perms(self): self.assertIs(self.user1.has_perms(["anon"], TestObj()), True) self.assertIs(self.user1.has_perms(["anon", "perm"], TestObj()), False) + async def test_ahas_perms(self): + self.assertIs(await self.user1.ahas_perms(["anon"], TestObj()), True) + self.assertIs(await self.user1.ahas_perms(["anon", "perm"], TestObj()), False) + def test_has_perms_perm_list_invalid(self): msg = "perm_list must be an iterable of permissions." with self.assertRaisesMessage(ValueError, msg): @@ -535,13 +787,27 @@ class AnonymousUserBackendTest(SimpleTestCase): with self.assertRaisesMessage(ValueError, msg): self.user1.has_perms(object()) + async def test_ahas_perms_perm_list_invalid(self): + msg = "perm_list must be an iterable of permissions." + with self.assertRaisesMessage(ValueError, msg): + await self.user1.ahas_perms("perm") + with self.assertRaisesMessage(ValueError, msg): + await self.user1.ahas_perms(object()) + def test_has_module_perms(self): self.assertIs(self.user1.has_module_perms("app1"), True) self.assertIs(self.user1.has_module_perms("app2"), False) + async def test_ahas_module_perms(self): + self.assertIs(await self.user1.ahas_module_perms("app1"), True) + self.assertIs(await self.user1.ahas_module_perms("app2"), False) + def test_get_all_permissions(self): self.assertEqual(self.user1.get_all_permissions(TestObj()), {"anon"}) + async def test_aget_all_permissions(self): + self.assertEqual(await self.user1.aget_all_permissions(TestObj()), {"anon"}) + @override_settings(AUTHENTICATION_BACKENDS=[]) class NoBackendsTest(TestCase): @@ -561,6 +827,14 @@ class NoBackendsTest(TestCase): with self.assertRaisesMessage(ImproperlyConfigured, msg): self.user.has_perm(("perm", TestObj())) + async def test_araises_exception(self): + msg = ( + "No authentication backends have been defined. " + "Does AUTHENTICATION_BACKENDS contain anything?" + ) + with self.assertRaisesMessage(ImproperlyConfigured, msg): + await self.user.ahas_perm(("perm", TestObj())) + @override_settings( AUTHENTICATION_BACKENDS=["auth_tests.test_auth_backends.SimpleRowlevelBackend"] @@ -593,12 +867,21 @@ class PermissionDeniedBackend: def authenticate(self, request, username=None, password=None): raise PermissionDenied + async def aauthenticate(self, request, username=None, password=None): + raise PermissionDenied + def has_perm(self, user_obj, perm, obj=None): raise PermissionDenied + async def ahas_perm(self, user_obj, perm, obj=None): + raise PermissionDenied + def has_module_perms(self, user_obj, app_label): raise PermissionDenied + async def ahas_module_perms(self, user_obj, app_label): + raise PermissionDenied + class PermissionDeniedBackendTest(TestCase): """ @@ -631,10 +914,25 @@ class PermissionDeniedBackendTest(TestCase): [{"password": "********************", "username": "test"}], ) + @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) + async def test_aauthenticate_permission_denied(self): + self.assertIsNone(await aauthenticate(username="test", password="test")) + # user_login_failed signal is sent. + self.assertEqual( + self.user_login_failed, + [{"password": "********************", "username": "test"}], + ) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) def test_authenticates(self): self.assertEqual(authenticate(username="test", password="test"), self.user1) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) + async def test_aauthenticate(self): + self.assertEqual( + await aauthenticate(username="test", password="test"), self.user1 + ) + @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) def test_has_perm_denied(self): content_type = ContentType.objects.get_for_model(Group) @@ -646,6 +944,17 @@ class PermissionDeniedBackendTest(TestCase): self.assertIs(self.user1.has_perm("auth.test"), False) self.assertIs(self.user1.has_module_perms("auth"), False) + @modify_settings(AUTHENTICATION_BACKENDS={"prepend": backend}) + async def test_ahas_perm_denied(self): + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await self.user1.user_permissions.aadd(perm) + + self.assertIs(await self.user1.ahas_perm("auth.test"), False) + self.assertIs(await self.user1.ahas_module_perms("auth"), False) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) def test_has_perm(self): content_type = ContentType.objects.get_for_model(Group) @@ -657,6 +966,17 @@ class PermissionDeniedBackendTest(TestCase): self.assertIs(self.user1.has_perm("auth.test"), True) self.assertIs(self.user1.has_module_perms("auth"), True) + @modify_settings(AUTHENTICATION_BACKENDS={"append": backend}) + async def test_ahas_perm(self): + content_type = await sync_to_async(ContentType.objects.get_for_model)(Group) + perm = await Permission.objects.acreate( + name="test", content_type=content_type, codename="test" + ) + await self.user1.user_permissions.aadd(perm) + + self.assertIs(await self.user1.ahas_perm("auth.test"), True) + self.assertIs(await self.user1.ahas_module_perms("auth"), True) + class NewModelBackend(ModelBackend): pass @@ -715,6 +1035,10 @@ class TypeErrorBackend: def authenticate(self, request, username=None, password=None): raise TypeError + @sensitive_variables("password") + async def aauthenticate(self, request, username=None, password=None): + raise TypeError + class SkippedBackend: def authenticate(self): diff --git a/tests/auth_tests/test_basic.py b/tests/auth_tests/test_basic.py index d7a7750b54..8d54e187fc 100644 --- a/tests/auth_tests/test_basic.py +++ b/tests/auth_tests/test_basic.py @@ -1,5 +1,3 @@ -from asgiref.sync import sync_to_async - from django.conf import settings from django.contrib.auth import aget_user, get_user, get_user_model from django.contrib.auth.models import AnonymousUser, User @@ -44,6 +42,12 @@ class BasicTestCase(TestCase): u2 = User.objects.create_user("testuser2", "test2@example.com") self.assertFalse(u2.has_usable_password()) + async def test_acreate(self): + u = await User.objects.acreate_user("testuser", "test@example.com", "testpw") + self.assertTrue(u.has_usable_password()) + self.assertFalse(await u.acheck_password("bad")) + self.assertTrue(await u.acheck_password("testpw")) + def test_unicode_username(self): User.objects.create_user("jörg") User.objects.create_user("Григорий") @@ -73,6 +77,15 @@ class BasicTestCase(TestCase): self.assertTrue(super.is_active) self.assertTrue(super.is_staff) + async def test_asuperuser(self): + "Check the creation and properties of a superuser" + super = await User.objects.acreate_superuser( + "super", "super@example.com", "super" + ) + self.assertTrue(super.is_superuser) + self.assertTrue(super.is_active) + self.assertTrue(super.is_staff) + def test_superuser_no_email_or_password(self): cases = [ {}, @@ -171,13 +184,25 @@ class TestGetUser(TestCase): self.assertIsInstance(user, User) self.assertEqual(user.username, created_user.username) - async def test_aget_user(self): - created_user = await sync_to_async(User.objects.create_user)( + async def test_aget_user_fallback_secret(self): + created_user = await User.objects.acreate_user( "testuser", "test@example.com", "testpw" ) await self.client.alogin(username="testuser", password="testpw") request = HttpRequest() request.session = await self.client.asession() - user = await aget_user(request) - self.assertIsInstance(user, User) - self.assertEqual(user.username, created_user.username) + prev_session_key = request.session.session_key + with override_settings( + SECRET_KEY="newsecret", + SECRET_KEY_FALLBACKS=[settings.SECRET_KEY], + ): + user = await aget_user(request) + self.assertIsInstance(user, User) + self.assertEqual(user.username, created_user.username) + self.assertNotEqual(request.session.session_key, prev_session_key) + # Remove the fallback secret. + # The session hash should be updated using the current secret. + with override_settings(SECRET_KEY="newsecret"): + user = await aget_user(request) + self.assertIsInstance(user, User) + self.assertEqual(user.username, created_user.username) diff --git a/tests/auth_tests/test_decorators.py b/tests/auth_tests/test_decorators.py index e585b28bd5..fa2672beb4 100644 --- a/tests/auth_tests/test_decorators.py +++ b/tests/auth_tests/test_decorators.py @@ -1,7 +1,5 @@ from asyncio import iscoroutinefunction -from asgiref.sync import sync_to_async - from django.conf import settings from django.contrib.auth import models from django.contrib.auth.decorators import ( @@ -374,7 +372,7 @@ class UserPassesTestDecoratorTest(TestCase): def test_decorator_async_test_func(self): async def async_test_func(user): - return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) + return await user.ahas_perms(["auth_tests.add_customuser"]) @user_passes_test(async_test_func) def sync_view(request): @@ -410,7 +408,7 @@ class UserPassesTestDecoratorTest(TestCase): async def test_decorator_async_view_async_test_func(self): async def async_test_func(user): - return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) + return await user.ahas_perms(["auth_tests.add_customuser"]) @user_passes_test(async_test_func) async def async_view(request): diff --git a/tests/auth_tests/test_models.py b/tests/auth_tests/test_models.py index 983424843c..a3e7a3205b 100644 --- a/tests/auth_tests/test_models.py +++ b/tests/auth_tests/test_models.py @@ -1,7 +1,5 @@ from unittest import mock -from asgiref.sync import sync_to_async - from django.conf.global_settings import PASSWORD_HASHERS from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend @@ -30,10 +28,19 @@ class NaturalKeysTestCase(TestCase): self.assertEqual(User.objects.get_by_natural_key("staff"), staff_user) self.assertEqual(staff_user.natural_key(), ("staff",)) + async def test_auser_natural_key(self): + staff_user = await User.objects.acreate_user(username="staff") + self.assertEqual(await User.objects.aget_by_natural_key("staff"), staff_user) + self.assertEqual(staff_user.natural_key(), ("staff",)) + def test_group_natural_key(self): users_group = Group.objects.create(name="users") self.assertEqual(Group.objects.get_by_natural_key("users"), users_group) + async def test_agroup_natural_key(self): + users_group = await Group.objects.acreate(name="users") + self.assertEqual(await Group.objects.aget_by_natural_key("users"), users_group) + class LoadDataWithoutNaturalKeysTestCase(TestCase): fixtures = ["regular.json"] @@ -157,6 +164,17 @@ class UserManagerTestCase(TransactionTestCase): is_superuser=False, ) + async def test_acreate_super_user_raises_error_on_false_is_superuser(self): + with self.assertRaisesMessage( + ValueError, "Superuser must have is_superuser=True." + ): + await User.objects.acreate_superuser( + username="test", + email="test@test.com", + password="test", + is_superuser=False, + ) + def test_create_superuser_raises_error_on_false_is_staff(self): with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."): User.objects.create_superuser( @@ -166,6 +184,15 @@ class UserManagerTestCase(TransactionTestCase): is_staff=False, ) + async def test_acreate_superuser_raises_error_on_false_is_staff(self): + with self.assertRaisesMessage(ValueError, "Superuser must have is_staff=True."): + await User.objects.acreate_superuser( + username="test", + email="test@test.com", + password="test", + is_staff=False, + ) + def test_runpython_manager_methods(self): def forwards(apps, schema_editor): UserModel = apps.get_model("auth", "User") @@ -301,9 +328,7 @@ class AbstractUserTestCase(TestCase): @override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS) async def test_acheck_password_upgrade(self): - user = await sync_to_async(User.objects.create_user)( - username="user", password="foo" - ) + user = await User.objects.acreate_user(username="user", password="foo") initial_password = user.password self.assertIs(await user.acheck_password("foo"), True) hasher = get_hasher("default") @@ -557,6 +582,12 @@ class AnonymousUserTests(SimpleTestCase): self.assertEqual(self.user.get_user_permissions(), set()) self.assertEqual(self.user.get_group_permissions(), set()) + async def test_properties_async_versions(self): + self.assertEqual(await self.user.groups.acount(), 0) + self.assertEqual(await self.user.user_permissions.acount(), 0) + self.assertEqual(await self.user.aget_user_permissions(), set()) + self.assertEqual(await self.user.aget_group_permissions(), set()) + def test_str(self): self.assertEqual(str(self.user), "AnonymousUser") diff --git a/tests/auth_tests/test_remote_user.py b/tests/auth_tests/test_remote_user.py index d3cf4b9da5..85de931c1a 100644 --- a/tests/auth_tests/test_remote_user.py +++ b/tests/auth_tests/test_remote_user.py @@ -1,12 +1,18 @@ from datetime import datetime, timezone from django.conf import settings -from django.contrib.auth import authenticate +from django.contrib.auth import aauthenticate, authenticate from django.contrib.auth.backends import RemoteUserBackend from django.contrib.auth.middleware import RemoteUserMiddleware from django.contrib.auth.models import User from django.middleware.csrf import _get_new_csrf_string, _mask_cipher_secret -from django.test import Client, TestCase, modify_settings, override_settings +from django.test import ( + AsyncClient, + Client, + TestCase, + modify_settings, + override_settings, +) @override_settings(ROOT_URLCONF="auth_tests.urls") @@ -30,6 +36,11 @@ class RemoteUserTest(TestCase): ) super().setUpClass() + def test_passing_explicit_none(self): + msg = "get_response must be provided." + with self.assertRaisesMessage(ValueError, msg): + RemoteUserMiddleware(None) + def test_no_remote_user(self): """Users are not created when remote user is not specified.""" num_users = User.objects.count() @@ -46,6 +57,18 @@ class RemoteUserTest(TestCase): self.assertTrue(response.context["user"].is_anonymous) self.assertEqual(User.objects.count(), num_users) + async def test_no_remote_user_async(self): + """See test_no_remote_user.""" + num_users = await User.objects.acount() + + response = await self.async_client.get("/remote_user/") + self.assertTrue(response.context["user"].is_anonymous) + self.assertEqual(await User.objects.acount(), num_users) + + response = await self.async_client.get("/remote_user/", **{self.header: ""}) + self.assertTrue(response.context["user"].is_anonymous) + self.assertEqual(await User.objects.acount(), num_users) + def test_csrf_validation_passes_after_process_request_login(self): """ CSRF check must access the CSRF token from the session or cookie, @@ -75,6 +98,31 @@ class RemoteUserTest(TestCase): response = csrf_client.post("/remote_user/", data, **headers) self.assertEqual(response.status_code, 200) + async def test_csrf_validation_passes_after_process_request_login_async(self): + """See test_csrf_validation_passes_after_process_request_login.""" + csrf_client = AsyncClient(enforce_csrf_checks=True) + csrf_secret = _get_new_csrf_string() + csrf_token = _mask_cipher_secret(csrf_secret) + csrf_token_form = _mask_cipher_secret(csrf_secret) + headers = {self.header: "fakeuser"} + data = {"csrfmiddlewaretoken": csrf_token_form} + + # Verify that CSRF is configured for the view + csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token}) + response = await csrf_client.post("/remote_user/", **headers) + self.assertEqual(response.status_code, 403) + self.assertIn(b"CSRF verification failed.", response.content) + + # This request will call django.contrib.auth.alogin() which will call + # django.middleware.csrf.rotate_token() thus changing the value of + # request.META['CSRF_COOKIE'] from the user submitted value set by + # CsrfViewMiddleware.process_request() to the new csrftoken value set + # by rotate_token(). Csrf validation should still pass when the view is + # later processed by CsrfViewMiddleware.process_view() + csrf_client.cookies.load({settings.CSRF_COOKIE_NAME: csrf_token}) + response = await csrf_client.post("/remote_user/", data, **headers) + self.assertEqual(response.status_code, 200) + def test_unknown_user(self): """ Tests the case where the username passed in the header does not exist @@ -90,6 +138,22 @@ class RemoteUserTest(TestCase): response = self.client.get("/remote_user/", **{self.header: "newuser"}) self.assertEqual(User.objects.count(), num_users + 1) + async def test_unknown_user_async(self): + """See test_unknown_user.""" + num_users = await User.objects.acount() + response = await self.async_client.get( + "/remote_user/", **{self.header: "newuser"} + ) + self.assertEqual(response.context["user"].username, "newuser") + self.assertEqual(await User.objects.acount(), num_users + 1) + await User.objects.aget(username="newuser") + + # Another request with same user should not create any new users. + response = await self.async_client.get( + "/remote_user/", **{self.header: "newuser"} + ) + self.assertEqual(await User.objects.acount(), num_users + 1) + def test_known_user(self): """ Tests the case where the username passed in the header is a valid User. @@ -106,6 +170,24 @@ class RemoteUserTest(TestCase): self.assertEqual(response.context["user"].username, "knownuser2") self.assertEqual(User.objects.count(), num_users) + async def test_known_user_async(self): + """See test_known_user.""" + await User.objects.acreate(username="knownuser") + await User.objects.acreate(username="knownuser2") + num_users = await User.objects.acount() + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + self.assertEqual(await User.objects.acount(), num_users) + # A different user passed in the headers causes the new user + # to be logged in. + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user2} + ) + self.assertEqual(response.context["user"].username, "knownuser2") + self.assertEqual(await User.objects.acount(), num_users) + def test_last_login(self): """ A user's last_login is set the first time they make a @@ -128,6 +210,29 @@ class RemoteUserTest(TestCase): response = self.client.get("/remote_user/", **{self.header: self.known_user}) self.assertEqual(default_login, response.context["user"].last_login) + async def test_last_login_async(self): + """See test_last_login.""" + user = await User.objects.acreate(username="knownuser") + # Set last_login to something so we can determine if it changes. + default_login = datetime(2000, 1, 1) + if settings.USE_TZ: + default_login = default_login.replace(tzinfo=timezone.utc) + user.last_login = default_login + await user.asave() + + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertNotEqual(default_login, response.context["user"].last_login) + + user = await User.objects.aget(username="knownuser") + user.last_login = default_login + await user.asave() + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(default_login, response.context["user"].last_login) + def test_header_disappears(self): """ A logged in user is logged out automatically when @@ -148,6 +253,25 @@ class RemoteUserTest(TestCase): response = self.client.get("/remote_user/") self.assertEqual(response.context["user"].username, "modeluser") + async def test_header_disappears_async(self): + """See test_header_disappears.""" + await User.objects.acreate(username="knownuser") + # Known user authenticates + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + # During the session, the REMOTE_USER header disappears. Should trigger logout. + response = await self.async_client.get("/remote_user/") + self.assertTrue(response.context["user"].is_anonymous) + # verify the remoteuser middleware will not remove a user + # authenticated via another backend + await User.objects.acreate_user(username="modeluser", password="foo") + await self.async_client.alogin(username="modeluser", password="foo") + await aauthenticate(username="modeluser", password="foo") + response = await self.async_client.get("/remote_user/") + self.assertEqual(response.context["user"].username, "modeluser") + def test_user_switch_forces_new_login(self): """ If the username in the header changes between requests @@ -164,11 +288,35 @@ class RemoteUserTest(TestCase): # In backends that do not create new users, it is '' (anonymous user) self.assertNotEqual(response.context["user"].username, "knownuser") + async def test_user_switch_forces_new_login_async(self): + """See test_user_switch_forces_new_login.""" + await User.objects.acreate(username="knownuser") + # Known user authenticates + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + # During the session, the REMOTE_USER changes to a different user. + response = await self.async_client.get( + "/remote_user/", **{self.header: "newnewuser"} + ) + # The current user is not the prior remote_user. + # In backends that create a new user, username is "newnewuser" + # In backends that do not create new users, it is '' (anonymous user) + self.assertNotEqual(response.context["user"].username, "knownuser") + def test_inactive_user(self): User.objects.create(username="knownuser", is_active=False) response = self.client.get("/remote_user/", **{self.header: "knownuser"}) self.assertTrue(response.context["user"].is_anonymous) + async def test_inactive_user_async(self): + await User.objects.acreate(username="knownuser", is_active=False) + response = await self.async_client.get( + "/remote_user/", **{self.header: "knownuser"} + ) + self.assertTrue(response.context["user"].is_anonymous) + class RemoteUserNoCreateBackend(RemoteUserBackend): """Backend that doesn't create unknown users.""" @@ -190,6 +338,14 @@ class RemoteUserNoCreateTest(RemoteUserTest): self.assertTrue(response.context["user"].is_anonymous) self.assertEqual(User.objects.count(), num_users) + async def test_unknown_user_async(self): + num_users = await User.objects.acount() + response = await self.async_client.get( + "/remote_user/", **{self.header: "newuser"} + ) + self.assertTrue(response.context["user"].is_anonymous) + self.assertEqual(await User.objects.acount(), num_users) + class AllowAllUsersRemoteUserBackendTest(RemoteUserTest): """Backend that allows inactive users.""" @@ -201,6 +357,13 @@ class AllowAllUsersRemoteUserBackendTest(RemoteUserTest): response = self.client.get("/remote_user/", **{self.header: self.known_user}) self.assertEqual(response.context["user"].username, user.username) + async def test_inactive_user_async(self): + user = await User.objects.acreate(username="knownuser", is_active=False) + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, user.username) + class CustomRemoteUserBackend(RemoteUserBackend): """ @@ -311,3 +474,16 @@ class PersistentRemoteUserTest(RemoteUserTest): response = self.client.get("/remote_user/") self.assertFalse(response.context["user"].is_anonymous) self.assertEqual(response.context["user"].username, "knownuser") + + async def test_header_disappears_async(self): + """See test_header_disappears.""" + await User.objects.acreate(username="knownuser") + # Known user authenticates + response = await self.async_client.get( + "/remote_user/", **{self.header: self.known_user} + ) + self.assertEqual(response.context["user"].username, "knownuser") + # Should stay logged in if the REMOTE_USER header disappears. + response = await self.async_client.get("/remote_user/") + self.assertFalse(response.context["user"].is_anonymous) + self.assertEqual(response.context["user"].username, "knownuser") |
