diff options
| author | Patrick Jenkins <me@patjenk.com> | 2017-08-17 17:10:10 -0700 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2019-10-18 08:42:51 +0200 |
| commit | 46e74a525671f2eef09021c06933c45b49f9d421 (patch) | |
| tree | c7f82abf885b0c6f2a410070a93bd574e951f8a2 /django/test | |
| parent | 3ca9df51c779510fbbfe296ca95a127d1dec2f87 (diff) | |
Fixed #28337 -- Preserved extra headers of requests made with django.test.Client in assertRedirects().
Co-Authored-By: Hasan Ramezani <hasan.r67@gmail.com>
Diffstat (limited to 'django/test')
| -rw-r--r-- | django/test/client.py | 9 | ||||
| -rw-r--r-- | django/test/testcases.py | 9 |
2 files changed, 16 insertions, 2 deletions
diff --git a/django/test/client.py b/django/test/client.py index 9b3fdf5936..98ede36499 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -444,6 +444,7 @@ class Client(RequestFactory): self.handler = ClientHandler(enforce_csrf_checks) self.raise_request_exception = raise_request_exception self.exc_info = None + self.extra = None def store_exc_info(self, **kwargs): """Store exceptions when they are generated by a view.""" @@ -515,6 +516,7 @@ class Client(RequestFactory): def get(self, path, data=None, follow=False, secure=False, **extra): """Request a response from the server using GET.""" + self.extra = extra response = super().get(path, data=data, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, **extra) @@ -523,6 +525,7 @@ class Client(RequestFactory): def post(self, path, data=None, content_type=MULTIPART_CONTENT, follow=False, secure=False, **extra): """Request a response from the server using POST.""" + self.extra = extra response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -530,6 +533,7 @@ class Client(RequestFactory): def head(self, path, data=None, follow=False, secure=False, **extra): """Request a response from the server using HEAD.""" + self.extra = extra response = super().head(path, data=data, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, **extra) @@ -538,6 +542,7 @@ class Client(RequestFactory): def options(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Request a response from the server using OPTIONS.""" + self.extra = extra response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -546,6 +551,7 @@ class Client(RequestFactory): def put(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Send a resource to the server using PUT.""" + self.extra = extra response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -554,6 +560,7 @@ class Client(RequestFactory): def patch(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Send a resource to the server using PATCH.""" + self.extra = extra response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -562,6 +569,7 @@ class Client(RequestFactory): def delete(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Send a DELETE request to the server.""" + self.extra = extra response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -569,6 +577,7 @@ class Client(RequestFactory): def trace(self, path, data='', follow=False, secure=False, **extra): """Send a TRACE request to the server.""" + self.extra = extra response = super().trace(path, data=data, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, **extra) diff --git a/django/test/testcases.py b/django/test/testcases.py index 1a7414b91d..468c0c4fbc 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -347,10 +347,15 @@ class SimpleTestCase(unittest.TestCase): "Otherwise, use assertRedirects(..., fetch_redirect_response=False)." % (url, domain) ) - redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https')) - # Get the redirection page, using the same client that was used # to obtain the original response. + extra = response.client.extra or {} + redirect_response = response.client.get( + path, + QueryDict(query), + secure=(scheme == 'https'), + **extra, + ) self.assertEqual( redirect_response.status_code, target_status_code, msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" |
