summaryrefslogtreecommitdiff
path: root/django/test
diff options
context:
space:
mode:
authorPatrick Jenkins <me@patjenk.com>2017-08-17 17:10:10 -0700
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2019-10-18 08:42:51 +0200
commit46e74a525671f2eef09021c06933c45b49f9d421 (patch)
treec7f82abf885b0c6f2a410070a93bd574e951f8a2 /django/test
parent3ca9df51c779510fbbfe296ca95a127d1dec2f87 (diff)
Fixed #28337 -- Preserved extra headers of requests made with django.test.Client in assertRedirects().
Co-Authored-By: Hasan Ramezani <hasan.r67@gmail.com>
Diffstat (limited to 'django/test')
-rw-r--r--django/test/client.py9
-rw-r--r--django/test/testcases.py9
2 files changed, 16 insertions, 2 deletions
diff --git a/django/test/client.py b/django/test/client.py
index 9b3fdf5936..98ede36499 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -444,6 +444,7 @@ class Client(RequestFactory):
self.handler = ClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception
self.exc_info = None
+ self.extra = None
def store_exc_info(self, **kwargs):
"""Store exceptions when they are generated by a view."""
@@ -515,6 +516,7 @@ class Client(RequestFactory):
def get(self, path, data=None, follow=False, secure=False, **extra):
"""Request a response from the server using GET."""
+ self.extra = extra
response = super().get(path, data=data, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, **extra)
@@ -523,6 +525,7 @@ class Client(RequestFactory):
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
follow=False, secure=False, **extra):
"""Request a response from the server using POST."""
+ self.extra = extra
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@@ -530,6 +533,7 @@ class Client(RequestFactory):
def head(self, path, data=None, follow=False, secure=False, **extra):
"""Request a response from the server using HEAD."""
+ self.extra = extra
response = super().head(path, data=data, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, **extra)
@@ -538,6 +542,7 @@ class Client(RequestFactory):
def options(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Request a response from the server using OPTIONS."""
+ self.extra = extra
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@@ -546,6 +551,7 @@ class Client(RequestFactory):
def put(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Send a resource to the server using PUT."""
+ self.extra = extra
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@@ -554,6 +560,7 @@ class Client(RequestFactory):
def patch(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Send a resource to the server using PATCH."""
+ self.extra = extra
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@@ -562,6 +569,7 @@ class Client(RequestFactory):
def delete(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Send a DELETE request to the server."""
+ self.extra = extra
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@@ -569,6 +577,7 @@ class Client(RequestFactory):
def trace(self, path, data='', follow=False, secure=False, **extra):
"""Send a TRACE request to the server."""
+ self.extra = extra
response = super().trace(path, data=data, secure=secure, **extra)
if follow:
response = self._handle_redirects(response, data=data, **extra)
diff --git a/django/test/testcases.py b/django/test/testcases.py
index 1a7414b91d..468c0c4fbc 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -347,10 +347,15 @@ class SimpleTestCase(unittest.TestCase):
"Otherwise, use assertRedirects(..., fetch_redirect_response=False)."
% (url, domain)
)
- redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https'))
-
# Get the redirection page, using the same client that was used
# to obtain the original response.
+ extra = response.client.extra or {}
+ redirect_response = response.client.get(
+ path,
+ QueryDict(query),
+ secure=(scheme == 'https'),
+ **extra,
+ )
self.assertEqual(
redirect_response.status_code, target_status_code,
msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)"