import copy
import os
import sys
import threading
import traceback
import unittest
import warnings
from contextlib import contextmanager
from functools import partial
from io import StringIO
from unittest import mock
from django.conf import STATICFILES_STORAGE_ALIAS, settings
from django.contrib.staticfiles.finders import get_finder, get_finders
from django.contrib.staticfiles.storage import staticfiles_storage
from django.core.exceptions import ImproperlyConfigured
from django.core.files.storage import default_storage
from django.db import (
IntegrityError,
connection,
connections,
models,
router,
transaction,
)
from django.forms import (
CharField,
EmailField,
Form,
IntegerField,
ValidationError,
formset_factory,
)
from django.http import HttpResponse, StreamingHttpResponse
from django.template import Context, Template
from django.template.loader import render_to_string
from django.test import (
SimpleTestCase,
TestCase,
TransactionTestCase,
ignore_warnings,
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.test.html import HTMLParseError, parse_html
from django.test.testcases import DatabaseOperationForbidden
from django.test.utils import (
CaptureQueriesContext,
TestContextDecorator,
isolate_apps,
override_settings,
setup_test_environment,
teardown_test_environment,
)
from django.urls import NoReverseMatch, path, reverse, reverse_lazy
from django.utils.deprecation import RemovedInDjango70Warning
from django.utils.html import VOID_ELEMENTS
from .models import Car, Person, PossessedCar
from .views import empty_response
class SkippingTestCase(SimpleTestCase):
    """
    Tests for the skipUnlessDBFeature and skipIfDBFeature decorators applied
    to plain functions and to test methods.
    """

    def _assert_skipping(self, func, expected_exc, msg=None):
        """
        Call func() and assert that it raises expected_exc.

        When msg is given, the exception message must contain it. If func()
        raises unittest.SkipTest instead, the test fails, because in these
        scenarios the decorator under test must not skip.
        """
        try:
            if msg is not None:
                with self.assertRaisesMessage(expected_exc, msg):
                    func()
            else:
                with self.assertRaises(expected_exc):
                    func()
        except unittest.SkipTest:
            self.fail("%s should not result in a skipped test." % func.__name__)

    def test_skip_unless_db_feature(self):
        """
        Testing the django.test.skipUnlessDBFeature decorator.
        """

        # Total hack, but it works, just want an attribute that's always true.
        @skipUnlessDBFeature("__class__")
        def test_func():
            raise ValueError

        @skipUnlessDBFeature("notprovided")
        def test_func2():
            raise ValueError

        @skipUnlessDBFeature("__class__", "__class__")
        def test_func3():
            raise ValueError

        @skipUnlessDBFeature("__class__", "notprovided")
        def test_func4():
            raise ValueError

        self._assert_skipping(test_func, ValueError)
        self._assert_skipping(test_func2, AttributeError)
        self._assert_skipping(test_func3, ValueError)
        self._assert_skipping(test_func4, AttributeError)

        class SkipTestCase(SimpleTestCase):
            @skipUnlessDBFeature("missing")
            def test_foo(self):
                pass

        # SkipTestCase is defined inside this method, so its __qualname__
        # (used in both the test id and the class name in the message)
        # contains a literal "<locals>" segment.
        self._assert_skipping(
            SkipTestCase("test_foo").test_foo,
            ValueError,
            "skipUnlessDBFeature cannot be used on test_foo (test_utils.tests."
            "SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase."
            "test_foo) as SkippingTestCase.test_skip_unless_db_feature.<locals>."
            "SkipTestCase doesn't allow queries against the 'default' database.",
        )

    def test_skip_if_db_feature(self):
        """
        Testing the django.test.skipIfDBFeature decorator.
        """

        @skipIfDBFeature("__class__")
        def test_func():
            raise ValueError

        @skipIfDBFeature("notprovided")
        def test_func2():
            raise ValueError

        @skipIfDBFeature("__class__", "__class__")
        def test_func3():
            raise ValueError

        @skipIfDBFeature("__class__", "notprovided")
        def test_func4():
            raise ValueError

        @skipIfDBFeature("notprovided", "notprovided")
        def test_func5():
            raise ValueError

        self._assert_skipping(test_func, unittest.SkipTest)
        self._assert_skipping(test_func2, AttributeError)
        self._assert_skipping(test_func3, unittest.SkipTest)
        self._assert_skipping(test_func4, unittest.SkipTest)
        self._assert_skipping(test_func5, AttributeError)

        class SkipTestCase(SimpleTestCase):
            @skipIfDBFeature("missing")
            def test_foo(self):
                pass

        # As above, the nested class's __qualname__ includes "<locals>".
        self._assert_skipping(
            SkipTestCase("test_foo").test_foo,
            ValueError,
            "skipIfDBFeature cannot be used on test_foo (test_utils.tests."
            "SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase."
            "test_foo) as SkippingTestCase.test_skip_if_db_feature.<locals>."
            "SkipTestCase doesn't allow queries against the 'default' database.",
        )
class SkippingClassTestCase(TransactionTestCase):
    """
    Tests for skipUnlessDBFeature/skipIfDBFeature applied to TestCase classes.
    """

    available_apps = []

    def test_skip_class_unless_db_feature(self):
        @skipUnlessDBFeature("__class__")
        class NotSkippedTests(TestCase):
            def test_dummy(self):
                return

        @skipUnlessDBFeature("missing")
        @skipIfDBFeature("__class__")
        class SkippedTests(TestCase):
            def test_will_be_skipped(self):
                self.fail("We should never arrive here.")

        @skipIfDBFeature("__dict__")
        class SkippedTestsSubclass(SkippedTests):
            pass

        test_suite = unittest.TestSuite()
        test_suite.addTest(NotSkippedTests("test_dummy"))
        # Class-level skips are expected to be deferred until the suite runs,
        # not raised while the tests are being added.
        try:
            test_suite.addTest(SkippedTests("test_will_be_skipped"))
            test_suite.addTest(SkippedTestsSubclass("test_will_be_skipped"))
        except unittest.SkipTest:
            self.fail("SkipTest should not be raised here.")
        result = unittest.TextTestRunner(stream=StringIO()).run(test_suite)
        # PY312: Python 3.12.1 does not include skipped tests in the number of
        # running tests.
        self.assertEqual(
            result.testsRun, 1 if sys.version_info[:3] == (3, 12, 1) else 3
        )
        self.assertEqual(len(result.skipped), 2)
        self.assertEqual(result.skipped[0][1], "Database has feature(s) __class__")
        self.assertEqual(result.skipped[1][1], "Database has feature(s) __class__")

    def test_missing_default_databases(self):
        @skipIfDBFeature("missing")
        class MissingDatabases(SimpleTestCase):
            def test_assertion_error(self):
                pass

        suite = unittest.TestSuite()
        try:
            suite.addTest(MissingDatabases("test_assertion_error"))
        except unittest.SkipTest:
            self.fail("SkipTest should not be raised at this stage")
        runner = unittest.TextTestRunner(stream=StringIO())
        # The message embeds str(MissingDatabases), i.e. "<class '...'>" with
        # the nested class's full dotted path, including "<locals>".
        msg = (
            "skipIfDBFeature cannot be used on <class 'test_utils.tests."
            "SkippingClassTestCase.test_missing_default_databases.<locals>."
            "MissingDatabases'> as it doesn't allow queries against the "
            "'default' database."
        )
        with self.assertRaisesMessage(ValueError, msg):
            runner.run(suite)
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertNumQueriesTests(TestCase):
    """Tests for the callable (non-context-manager) form of assertNumQueries()."""

    def test_assert_num_queries(self):
        # An exception raised by the wrapped callable propagates out of
        # assertNumQueries().
        def raise_value_error():
            raise ValueError

        with self.assertRaises(ValueError):
            self.assertNumQueries(2, raise_value_error)

    def test_assert_num_queries_with_client(self):
        person = Person.objects.create(name="test")
        person_url = "/test_utils/get_person/%s/" % person.pk

        # A single client request costs one query, and repeating the request
        # yields the same count.
        for _ in range(2):
            self.assertNumQueries(1, self.client.get, person_url)

        def fetch_twice():
            self.client.get(person_url)
            self.client.get(person_url)

        self.assertNumQueries(2, fetch_twice)
class AssertNumQueriesUponConnectionTests(TransactionTestCase):
    """
    assertNumQueries() must not count queries that are executed while a
    database connection is being opened.
    """

    available_apps = []

    def test_ignores_connection_configuration_queries(self):
        # Capture the real (unpatched) bound method before ensure_connection
        # is mocked below. Then close the connection so the next ORM query
        # has to open a new one.
        real_ensure_connection = connection.ensure_connection
        connection.close()

        def make_configuration_query():
            # Replacement for ensure_connection(): delegate to the real
            # method, then, only if a connection was just opened, run one
            # extra "SELECT 1" to simulate a query issued during connection
            # setup.
            is_opening_connection = connection.connection is None
            real_ensure_connection()
            if is_opening_connection:
                # Avoid infinite recursion. Creating a cursor calls
                # ensure_connection() which is currently mocked by this method.
                with connection.cursor() as cursor:
                    cursor.execute("SELECT 1" + connection.features.bare_select_suffix)

        ensure_connection = (
            "django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection"
        )
        with mock.patch(ensure_connection, side_effect=make_configuration_query):
            # Only the SELECT fetching Car rows should be counted; the
            # "SELECT 1" run while connecting must be ignored.
            with self.assertNumQueries(1):
                list(Car.objects.all())
class AssertQuerySetEqualTests(TestCase):
    """Tests for TestCase.assertQuerySetEqual()."""

    @classmethod
    def setUpTestData(cls):
        cls.p1 = Person.objects.create(name="p1")
        cls.p2 = Person.objects.create(name="p2")

    def test_empty(self):
        no_match = Person.objects.filter(name="p3")
        self.assertQuerySetEqual(no_match, [])

    def test_ordered(self):
        by_name = Person.objects.order_by("name")
        self.assertQuerySetEqual(by_name, [self.p1, self.p2])

    def test_unordered(self):
        by_name = Person.objects.order_by("name")
        self.assertQuerySetEqual(by_name, [self.p2, self.p1], ordered=False)

    def test_queryset(self):
        # The expected value may itself be a queryset.
        actual = Person.objects.order_by("name")
        expected = Person.objects.order_by("name")
        self.assertQuerySetEqual(actual, expected)

    def test_flat_values_list(self):
        flat_names = Person.objects.order_by("name").values_list("name", flat=True)
        self.assertQuerySetEqual(flat_names, ["p1", "p2"])

    def test_transform(self):
        by_name = Person.objects.order_by("name")
        expected_pks = [self.p1.pk, self.p2.pk]
        self.assertQuerySetEqual(by_name, expected_pks, transform=lambda obj: obj.pk)

    def test_repr_transform(self):
        by_name = Person.objects.order_by("name")
        expected_reprs = [repr(self.p1), repr(self.p2)]
        self.assertQuerySetEqual(by_name, expected_reprs, transform=repr)

    def test_undefined_order(self):
        # Comparing an unordered queryset with several values is ambiguous
        # and rejected; comparing it with a single value is allowed.
        msg = (
            "Trying to compare non-ordered queryset against more than one "
            "ordered value."
        )
        with self.assertRaisesMessage(ValueError, msg):
            self.assertQuerySetEqual(Person.objects.all(), [self.p1, self.p2])
        self.assertQuerySetEqual(Person.objects.filter(name="p1"), [self.p1])

    def test_repeated_values(self):
        """
        With ordered=False, assertQuerySetEqual compares how many times each
        item occurs, not just which items are present.
        """
        batmobile = Car.objects.create(name="Batmobile")
        k2000 = Car.objects.create(name="K 2000")
        owned_cars = [batmobile] * 2 + [k2000] * 4
        PossessedCar.objects.bulk_create(
            [PossessedCar(car=car, belongs_to=self.p1) for car in owned_cars]
        )
        with self.assertRaises(AssertionError):
            self.assertQuerySetEqual(
                self.p1.cars.all(), [batmobile, k2000], ordered=False
            )
        self.assertQuerySetEqual(self.p1.cars.all(), owned_cars, ordered=False)

    def test_maxdiff(self):
        names = ["Joe Smith %s" % i for i in range(20)]
        Person.objects.bulk_create([Person(name=name) for name in names])
        names.append("Extra Person")
        truncation_hint = "Set self.maxDiff to None to see it."

        def assert_joes_match():
            self.assertQuerySetEqual(
                Person.objects.filter(name__startswith="Joe"),
                names,
                ordered=False,
                transform=lambda p: p.name,
            )

        # With the default maxDiff the diff is truncated.
        with self.assertRaises(AssertionError) as ctx:
            assert_joes_match()
        self.assertIn(truncation_hint, str(ctx.exception))

        # With maxDiff disabled, the full diff (every name) is reported.
        saved_max_diff = self.maxDiff
        self.maxDiff = None
        try:
            with self.assertRaises(AssertionError) as ctx:
                assert_joes_match()
        finally:
            self.maxDiff = saved_max_diff
        full_message = str(ctx.exception)
        self.assertNotIn(truncation_hint, full_message)
        for name in names:
            self.assertIn(name, full_message)
@override_settings(ROOT_URLCONF="test_utils.urls")
class CaptureQueriesContextManagerTests(TestCase):
    """Tests for the CaptureQueriesContext context manager."""

    @classmethod
    def setUpTestData(cls):
        cls.person_pk = str(Person.objects.create(name="test").pk)
        cls.url = f"/test_utils/get_person/{cls.person_pk}/"

    def test_simple(self):
        with CaptureQueriesContext(connection) as queries:
            Person.objects.get(pk=self.person_pk)
        self.assertEqual(len(queries), 1)
        self.assertIn(self.person_pk, queries[0]["sql"])

        # Nothing executed, nothing captured.
        with CaptureQueriesContext(connection) as queries:
            pass
        self.assertEqual(0, len(queries))

    def test_within(self):
        # Captured queries can be inspected before the block exits.
        with CaptureQueriesContext(connection) as queries:
            Person.objects.get(pk=self.person_pk)
            self.assertEqual(len(queries), 1)
            self.assertIn(self.person_pk, queries[0]["sql"])

    def test_nested(self):
        # The outer context also sees queries run inside the inner one.
        with CaptureQueriesContext(connection) as outer:
            Person.objects.count()
            with CaptureQueriesContext(connection) as inner:
                Person.objects.count()
        self.assertEqual(1, len(inner))
        self.assertEqual(2, len(outer))

    def test_failure(self):
        # Exceptions raised in the block are not swallowed.
        with self.assertRaises(TypeError):
            with CaptureQueriesContext(connection):
                raise TypeError

    def test_with_client(self):
        # Two independent single-request captures each see one query.
        for _ in range(2):
            with CaptureQueriesContext(connection) as queries:
                self.client.get(self.url)
            self.assertEqual(len(queries), 1)
            self.assertIn(self.person_pk, queries[0]["sql"])

        with CaptureQueriesContext(connection) as queries:
            self.client.get(self.url)
            self.client.get(self.url)
        self.assertEqual(len(queries), 2)
        for index in (0, 1):
            self.assertIn(self.person_pk, queries[index]["sql"])

    def test_with_client_nested(self):
        with CaptureQueriesContext(connection) as queries:
            Person.objects.count()
            with CaptureQueriesContext(connection):
                pass
            self.client.get(self.url)
        self.assertEqual(2, len(queries))
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertNumQueriesContextManagerTests(TestCase):
    """Tests for the context-manager form of assertNumQueries()."""

    @classmethod
    def setUpTestData(cls):
        cls.person_pk = str(Person.objects.create(name="test").pk)
        cls.url = f"/test_utils/get_person/{cls.person_pk}/"

    def test_simple(self):
        with self.assertNumQueries(0):
            pass
        # Each count() call issues exactly one query.
        for expected in (1, 2):
            with self.assertNumQueries(expected):
                for _ in range(expected):
                    Person.objects.count()

    def test_failure(self):
        msg = "1 != 2 : 1 queries executed, 2 expected\nCaptured queries were:\n1."
        with (
            self.assertRaisesMessage(AssertionError, msg),
            self.assertNumQueries(2),
        ):
            Person.objects.count()

        # An exception from the block propagates instead of being replaced
        # by the query-count assertion.
        with self.assertRaises(TypeError), self.assertNumQueries(4000):
            raise TypeError

    def test_with_client(self):
        for _ in range(2):
            with self.assertNumQueries(1):
                self.client.get(self.url)

        with self.assertNumQueries(2):
            self.client.get(self.url)
            self.client.get(self.url)

    def test_with_client_nested(self):
        with self.assertNumQueries(2):
            Person.objects.count()
            # An inner block that runs no queries leaves the outer count alone.
            with self.assertNumQueries(0):
                pass
            self.client.get(self.url)
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertTemplateUsedContextManagerTests(SimpleTestCase):
    """
    Tests for assertTemplateUsed()/assertTemplateNotUsed(), mostly in their
    context-manager form.
    """

    def test_usage(self):
        # The name may be passed positionally or as template_name. It also
        # matches when base.html is pulled in by include.html/extends.html, or
        # is rendered more than once.
        with self.assertTemplateUsed("template_used/base.html"):
            render_to_string("template_used/base.html")
        with self.assertTemplateUsed(template_name="template_used/base.html"):
            render_to_string("template_used/base.html")
        with self.assertTemplateUsed("template_used/base.html"):
            render_to_string("template_used/include.html")
        with self.assertTemplateUsed("template_used/base.html"):
            render_to_string("template_used/extends.html")
        with self.assertTemplateUsed("template_used/base.html"):
            render_to_string("template_used/base.html")
            render_to_string("template_used/base.html")

    def test_nested_usage(self):
        # Nested managers each check what was rendered inside their own
        # block. The exact placement of every render relative to the nested
        # `with` statements is what is being tested here.
        with self.assertTemplateUsed("template_used/base.html"):
            with self.assertTemplateUsed("template_used/include.html"):
                render_to_string("template_used/include.html")
        with self.assertTemplateUsed("template_used/extends.html"):
            with self.assertTemplateUsed("template_used/base.html"):
                render_to_string("template_used/extends.html")
        with self.assertTemplateUsed("template_used/base.html"):
            with self.assertTemplateUsed("template_used/alternative.html"):
                render_to_string("template_used/alternative.html")
            render_to_string("template_used/base.html")
        with self.assertTemplateUsed("template_used/base.html"):
            render_to_string("template_used/extends.html")
            with self.assertTemplateNotUsed("template_used/base.html"):
                render_to_string("template_used/alternative.html")
            render_to_string("template_used/base.html")

    def test_not_used(self):
        with self.assertTemplateNotUsed("template_used/base.html"):
            pass
        with self.assertTemplateNotUsed("template_used/alternative.html"):
            pass

    def test_error_message_no_template_used(self):
        msg = "No templates used to render the response"
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("template_used/base.html"):
                pass
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(template_name="template_used/base.html"):
                pass
        with self.assertRaisesMessage(AssertionError, msg):
            response = self.client.get("/test_utils/no_template_used/")
            self.assertTemplateUsed(response, "template_used/base.html")
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("template_used/base.html"):
                self.client.get("/test_utils/no_template_used/")
        # The first argument to Template() is template source, not a path.
        # Rendering this unnamed template is expected not to count as using
        # any template.
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("template_used/base.html"):
                template = Template("template_used/alternative.html", name=None)
                template.render(Context())

    def test_error_message_unexpected_template_used(self):
        msg = (
            "Template 'template_used/base.html' was not a template used to render "
            "the response. Actual template(s) used: template_used/alternative.html"
        )
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("template_used/base.html"):
                render_to_string("template_used/alternative.html")

    def test_msg_prefix(self):
        # msg_prefix is prepended to the failure message, followed by ": ".
        msg_prefix = "Prefix"
        msg = f"{msg_prefix}: No templates used to render the response"
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(
                "template_used/base.html", msg_prefix=msg_prefix
            ):
                pass
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(
                template_name="template_used/base.html",
                msg_prefix=msg_prefix,
            ):
                pass
        msg = (
            f"{msg_prefix}: Template 'template_used/base.html' was not a "
            f"template used to render the response. Actual template(s) used: "
            f"template_used/alternative.html"
        )
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(
                "template_used/base.html", msg_prefix=msg_prefix
            ):
                render_to_string("template_used/alternative.html")

    def test_count(self):
        # count= requires the template to be rendered exactly that many times.
        with self.assertTemplateUsed("template_used/base.html", count=2):
            render_to_string("template_used/base.html")
            render_to_string("template_used/base.html")
        msg = (
            "Template 'template_used/base.html' was expected to be rendered "
            "3 time(s) but was actually rendered 2 time(s)."
        )
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("template_used/base.html", count=3):
                render_to_string("template_used/base.html")
                render_to_string("template_used/base.html")

    def test_failure(self):
        # Calling with neither a response nor a template name is a TypeError.
        msg = "response and/or template_name argument must be provided"
        with self.assertRaisesMessage(TypeError, msg):
            with self.assertTemplateUsed():
                pass
        # An empty name is expected to fail with the "No templates used"
        # message, even when a template was actually rendered.
        msg = "No templates used to render the response"
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(""):
                pass
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(""):
                render_to_string("template_used/base.html")
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(template_name=""):
                pass
        msg = (
            "Template 'template_used/base.html' was not a template used to "
            "render the response. Actual template(s) used: "
            "template_used/alternative.html"
        )
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("template_used/base.html"):
                render_to_string("template_used/alternative.html")

    def test_assert_used_on_http_response(self):
        # A plain HttpResponse (not fetched via the test client) is rejected
        # with ValueError by both assertions.
        response = HttpResponse()
        msg = "%s() is only usable on responses fetched using the Django test Client."
        with self.assertRaisesMessage(ValueError, msg % "assertTemplateUsed"):
            self.assertTemplateUsed(response, "template.html")
        with self.assertRaisesMessage(ValueError, msg % "assertTemplateNotUsed"):
            self.assertTemplateNotUsed(response, "template.html")
@override_settings(ROOT_URLCONF="test_utils.urls")
class AssertTemplateUsedPartialTests(SimpleTestCase):
    """
    assertTemplateUsed()/assertTemplateNotUsed() with "template#partial" names.
    """

    def test_template_used_pass(self):
        partial_name = "template_used/partials.html#hello"
        with self.assertTemplateUsed(partial_name):
            render_to_string(partial_name)

    def test_template_not_used_pass(self):
        # The bare partial name does not match the fully qualified one.
        with self.assertTemplateNotUsed("hello"):
            render_to_string("template_used/partials.html#hello")

    def test_template_used_fail(self):
        msg = "Template 'hello' was not a template used to render the response."
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed("hello"):
                render_to_string("template_used/base.html")

    def test_template_not_used_fail(self):
        partial_name = "template_used/partials.html#hello"
        msg = (
            "Template 'template_used/partials.html#hello' was used "
            "unexpectedly in rendering the response"
        )
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateNotUsed(partial_name):
                render_to_string(partial_name)

    def test_template_not_used_pass_non_partial(self):
        # A regular template is not recorded under a "name#name" partial name.
        bogus_partial = "template_used/base.html#template_used/base.html"
        with self.assertTemplateNotUsed(bogus_partial):
            render_to_string("template_used/base.html")

    def test_template_used_fail_non_partial(self):
        bogus_partial = "template_used/base.html#template_used/base.html"
        msg = (
            "Template 'template_used/base.html#template_used/base.html' was not a "
            "template used to render the response."
        )
        with self.assertRaisesMessage(AssertionError, msg):
            with self.assertTemplateUsed(bogus_partial):
                render_to_string("template_used/base.html")
class HTMLEqualTests(SimpleTestCase):
def test_html_parser(self):
element = parse_html("
")
self.assertEqual(dom2.count(dom1), 0)
# HTML with a root element contains the same HTML with no root element.
dom1 = parse_html("
foo
bar
")
dom2 = parse_html("
foo
bar
")
self.assertEqual(dom2.count(dom1), 1)
# Target of search is a sequence of child elements and appears more
# than once.
dom2 = parse_html("
foo
bar
foo
bar
")
self.assertEqual(dom2.count(dom1), 2)
# Searched HTML has additional children.
dom1 = parse_html("")
dom2 = parse_html("")
self.assertEqual(dom2.count(dom1), 1)
# No match found in children.
dom1 = parse_html("")
self.assertEqual(dom2.count(dom1), 0)
# Target of search found among children and grandchildren.
dom1 = parse_html("")
dom2 = parse_html("")
self.assertEqual(dom2.count(dom1), 2)
def test_root_element_escaped_html(self):
html = "<br>"
parsed = parse_html(html)
self.assertEqual(str(parsed), html)
def test_parsing_errors(self):
with self.assertRaises(AssertionError):
self.assertHTMLEqual("
", "")
with self.assertRaises(AssertionError):
self.assertHTMLEqual("", "
")
error_msg = (
"First argument is not valid HTML:\n"
"('Unexpected end tag `div` (Line 1, Column 6)', (1, 6))"
)
with self.assertRaisesMessage(AssertionError, error_msg):
self.assertHTMLEqual("< div>", "
")
with self.assertRaises(HTMLParseError):
parse_html("")
def test_escaped_html_errors(self):
msg = "
\n\n
!=
\n<foo>\n
\n"
with self.assertRaisesMessage(AssertionError, msg):
self.assertHTMLEqual("
", "
<foo>
")
with self.assertRaisesMessage(AssertionError, msg):
self.assertHTMLEqual("
", "
<foo>
")
def test_contains_html(self):
response = HttpResponse("""
This is a form: """)
self.assertNotContains(response, "")
self.assertContains(response, '