summaryrefslogtreecommitdiff
path: root/django
diff options
context:
space:
mode:
authorTom <tom@tomforb.es>2017-05-22 11:49:39 +0100
committerTim Graham <timograham@gmail.com>2018-02-06 09:03:43 -0500
commit272f685794de0b8dead220ee57b30e65c9aa097c (patch)
treea3b3efebeba4df9eb9a6a8d865b0e0fb407e7c85 /django
parent0f0a07ac278dc2be6da81e519188f77e2a2a00cf (diff)
Fixed #27999 -- Added test client support for HTTP 307 and 308 redirects.
Diffstat (limited to 'django')
-rw-r--r--django/test/client.py38
1 files changed, 27 insertions, 11 deletions
diff --git a/django/test/client.py b/django/test/client.py
index d69c33f1bd..9fce782dd4 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -5,6 +5,7 @@ import re
import sys
from copy import copy
from functools import partial
+from http import HTTPStatus
from importlib import import_module
from io import BytesIO
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
@@ -512,7 +513,7 @@ class Client(RequestFactory):
"""Request a response from the server using GET."""
response = super().get(path, data=data, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, **extra)
return response
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
@@ -520,14 +521,14 @@ class Client(RequestFactory):
"""Request a response from the server using POST."""
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def head(self, path, data=None, follow=False, secure=False, **extra):
"""Request a response from the server using HEAD."""
response = super().head(path, data=data, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, **extra)
return response
def options(self, path, data='', content_type='application/octet-stream',
@@ -535,7 +536,7 @@ class Client(RequestFactory):
"""Request a response from the server using OPTIONS."""
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def put(self, path, data='', content_type='application/octet-stream',
@@ -543,7 +544,7 @@ class Client(RequestFactory):
"""Send a resource to the server using PUT."""
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def patch(self, path, data='', content_type='application/octet-stream',
@@ -551,7 +552,7 @@ class Client(RequestFactory):
"""Send a resource to the server using PATCH."""
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def delete(self, path, data='', content_type='application/octet-stream',
@@ -559,14 +560,14 @@ class Client(RequestFactory):
"""Send a DELETE request to the server."""
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def trace(self, path, data='', follow=False, secure=False, **extra):
"""Send a TRACE request to the server."""
response = super().trace(path, data=data, secure=secure, **extra)
if follow:
- response = self._handle_redirects(response, **extra)
+ response = self._handle_redirects(response, data=data, **extra)
return response
def login(self, **credentials):
@@ -648,12 +649,19 @@ class Client(RequestFactory):
response._json = json.loads(response.content.decode(), **extra)
return response._json
- def _handle_redirects(self, response, **extra):
+ def _handle_redirects(self, response, data='', content_type='', **extra):
"""
Follow any redirects by requesting responses from the server using GET.
"""
response.redirect_chain = []
- while response.status_code in (301, 302, 303, 307):
+ redirect_status_codes = (
+ HTTPStatus.MOVED_PERMANENTLY,
+ HTTPStatus.FOUND,
+ HTTPStatus.SEE_OTHER,
+ HTTPStatus.TEMPORARY_REDIRECT,
+ HTTPStatus.PERMANENT_REDIRECT,
+ )
+ while response.status_code in redirect_status_codes:
response_url = response.url
redirect_chain = response.redirect_chain
redirect_chain.append((response_url, response.status_code))
@@ -671,7 +679,15 @@ class Client(RequestFactory):
if not path.startswith('/'):
path = urljoin(response.request['PATH_INFO'], path)
- response = self.get(path, QueryDict(url.query), follow=False, **extra)
+ if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
+ # Preserve request method post-redirect for 307/308 responses.
+ request_method = getattr(self, response.request['REQUEST_METHOD'].lower())
+ else:
+ request_method = self.get
+ data = QueryDict(url.query)
+ content_type = None
+
+ response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
response.redirect_chain = redirect_chain
if redirect_chain[-1] in redirect_chain[:-1]: