diff options
Diffstat (limited to 'tests/check_framework/test_database.py')
| -rw-r--r-- | tests/check_framework/test_database.py | 37 |
1 files changed, 9 insertions, 28 deletions
diff --git a/tests/check_framework/test_database.py b/tests/check_framework/test_database.py index 06baf0e38d..bf291b24a1 100644 --- a/tests/check_framework/test_database.py +++ b/tests/check_framework/test_database.py @@ -1,8 +1,7 @@ import unittest from unittest import mock -from django.core.checks import Tags, run_checks -from django.core.checks.registry import CheckRegistry +from django.core.checks.database import check_database_backends from django.db import connection from django.test import TestCase @@ -10,30 +9,12 @@ from django.test import TestCase class DatabaseCheckTests(TestCase): databases = {'default', 'other'} - @property - def func(self): - from django.core.checks.database import check_database_backends - return check_database_backends - - def test_database_checks_not_run_by_default(self): - """ - `database` checks are only run when their tag is specified. - """ - def f1(**kwargs): - return [5] - - registry = CheckRegistry() - registry.register(Tags.database)(f1) - errors = registry.run_checks() - self.assertEqual(errors, []) - - errors2 = registry.run_checks(tags=[Tags.database]) - self.assertEqual(errors2, [5]) - - def test_database_checks_called(self): - with mock.patch('django.db.backends.base.validation.BaseDatabaseValidation.check') as mocked_check: - run_checks(tags=[Tags.database]) - self.assertTrue(mocked_check.called) + @mock.patch('django.db.backends.base.validation.BaseDatabaseValidation.check') + def test_database_checks_called(self, mocked_check): + check_database_backends() + self.assertFalse(mocked_check.called) + check_database_backends(databases=self.databases) + self.assertTrue(mocked_check.called) @unittest.skipUnless(connection.vendor == 'mysql', 'Test only for MySQL') def test_mysql_strict_mode(self): @@ -47,7 +28,7 @@ class DatabaseCheckTests(TestCase): 'django.db.backends.utils.CursorWrapper.fetchone', create=True, return_value=(response,) ): - self.assertEqual(self.func(None), []) + self.assertEqual(check_database_backends(databases=self.databases), []) bad_sql_modes = ['', 'WHATEVER'] for response in bad_sql_modes: @@ -56,6 +37,6 @@ class DatabaseCheckTests(TestCase): return_value=(response,) ): # One warning for each database alias - result = self.func(None) + result = check_database_backends(databases=self.databases) self.assertEqual(len(result), 2) self.assertEqual([r.id for r in result], ['mysql.W002', 'mysql.W002']) |
