diff options
Diffstat (limited to 'django/http/request.py')
| -rw-r--r-- | django/http/request.py | 246 |
1 files changed, 146 insertions, 100 deletions
diff --git a/django/http/request.py b/django/http/request.py index 5971203261..d975aadf25 100644 --- a/django/http/request.py +++ b/django/http/request.py @@ -8,12 +8,17 @@ from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlsplit from django.conf import settings from django.core import signing from django.core.exceptions import ( - DisallowedHost, ImproperlyConfigured, RequestDataTooBig, TooManyFieldsSent, + DisallowedHost, + ImproperlyConfigured, + RequestDataTooBig, + TooManyFieldsSent, ) from django.core.files import uploadhandler from django.http.multipartparser import MultiPartParser, MultiPartParserError from django.utils.datastructures import ( - CaseInsensitiveMapping, ImmutableList, MultiValueDict, + CaseInsensitiveMapping, + ImmutableList, + MultiValueDict, ) from django.utils.encoding import escape_uri_path, iri_to_uri from django.utils.functional import cached_property @@ -23,7 +28,9 @@ from django.utils.regex_helper import _lazy_re_compile from .multipartparser import parse_header RAISE_ERROR = object() -host_validation_re = _lazy_re_compile(r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9\.:]+\])(:[0-9]+)?$") +host_validation_re = _lazy_re_compile( + r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9\.:]+\])(:[0-9]+)?$" +) class UnreadablePostError(OSError): @@ -36,6 +43,7 @@ class RawPostDataException(Exception): multipart/* POST data if it has been accessed via POST, FILES, etc.. """ + pass @@ -57,8 +65,8 @@ class HttpRequest: self.META = {} self.FILES = MultiValueDict() - self.path = '' - self.path_info = '' + self.path = "" + self.path_info = "" self.method = None self.resolver_match = None self.content_type = None @@ -66,8 +74,12 @@ class HttpRequest: def __repr__(self): if self.method is None or not self.get_full_path(): - return '<%s>' % self.__class__.__name__ - return '<%s: %s %r>' % (self.__class__.__name__, self.method, self.get_full_path()) + return "<%s>" % self.__class__.__name__ + return "<%s: %s %r>" % ( + self.__class__.__name__, + self.method, + self.get_full_path(), + ) @cached_property def headers(self): @@ -76,24 +88,25 @@ class HttpRequest: @cached_property def accepted_types(self): """Return a list of MediaType instances.""" - return parse_accept_header(self.headers.get('Accept', '*/*')) + return parse_accept_header(self.headers.get("Accept", "*/*")) def accepts(self, media_type): return any( - accepted_type.match(media_type) - for accepted_type in self.accepted_types + accepted_type.match(media_type) for accepted_type in self.accepted_types ) def _set_content_type_params(self, meta): """Set content_type, content_params, and encoding.""" - self.content_type, self.content_params = cgi.parse_header(meta.get('CONTENT_TYPE', '')) - if 'charset' in self.content_params: + self.content_type, self.content_params = cgi.parse_header( + meta.get("CONTENT_TYPE", "") + ) + if "charset" in self.content_params: try: - codecs.lookup(self.content_params['charset']) + codecs.lookup(self.content_params["charset"]) except LookupError: pass else: - self.encoding = self.content_params['charset'] + self.encoding = self.content_params["charset"] def _get_raw_host(self): """ @@ -101,17 +114,16 @@ class HttpRequest: allowed hosts protection, so may return an insecure host. """ # We try three options, in order of decreasing preference. - if settings.USE_X_FORWARDED_HOST and ( - 'HTTP_X_FORWARDED_HOST' in self.META): - host = self.META['HTTP_X_FORWARDED_HOST'] - elif 'HTTP_HOST' in self.META: - host = self.META['HTTP_HOST'] + if settings.USE_X_FORWARDED_HOST and ("HTTP_X_FORWARDED_HOST" in self.META): + host = self.META["HTTP_X_FORWARDED_HOST"] + elif "HTTP_HOST" in self.META: + host = self.META["HTTP_HOST"] else: # Reconstruct the host using the algorithm from PEP 333. - host = self.META['SERVER_NAME'] + host = self.META["SERVER_NAME"] server_port = self.get_port() - if server_port != ('443' if self.is_secure() else '80'): - host = '%s:%s' % (host, server_port) + if server_port != ("443" if self.is_secure() else "80"): + host = "%s:%s" % (host, server_port) return host def get_host(self): @@ -121,7 +133,7 @@ class HttpRequest: # Allow variants of localhost if ALLOWED_HOSTS is empty and DEBUG=True. allowed_hosts = settings.ALLOWED_HOSTS if settings.DEBUG and not allowed_hosts: - allowed_hosts = ['.localhost', '127.0.0.1', '[::1]'] + allowed_hosts = [".localhost", "127.0.0.1", "[::1]"] domain, port = split_domain_port(host) if domain and validate_host(domain, allowed_hosts): @@ -131,15 +143,17 @@ class HttpRequest: if domain: msg += " You may need to add %r to ALLOWED_HOSTS." % domain else: - msg += " The domain name provided is not valid according to RFC 1034/1035." + msg += ( + " The domain name provided is not valid according to RFC 1034/1035." + ) raise DisallowedHost(msg) def get_port(self): """Return the port number for the request as a string.""" - if settings.USE_X_FORWARDED_PORT and 'HTTP_X_FORWARDED_PORT' in self.META: - port = self.META['HTTP_X_FORWARDED_PORT'] + if settings.USE_X_FORWARDED_PORT and "HTTP_X_FORWARDED_PORT" in self.META: + port = self.META["HTTP_X_FORWARDED_PORT"] else: - port = self.META['SERVER_PORT'] + port = self.META["SERVER_PORT"] return str(port) def get_full_path(self, force_append_slash=False): @@ -151,13 +165,15 @@ class HttpRequest: def _get_full_path(self, path, force_append_slash): # RFC 3986 requires query string arguments to be in the ASCII range. # Rather than crash if this doesn't happen, we encode defensively. - return '%s%s%s' % ( + return "%s%s%s" % ( escape_uri_path(path), - '/' if force_append_slash and not path.endswith('/') else '', - ('?' + iri_to_uri(self.META.get('QUERY_STRING', ''))) if self.META.get('QUERY_STRING', '') else '' + "/" if force_append_slash and not path.endswith("/") else "", + ("?" + iri_to_uri(self.META.get("QUERY_STRING", ""))) + if self.META.get("QUERY_STRING", "") + else "", ) - def get_signed_cookie(self, key, default=RAISE_ERROR, salt='', max_age=None): + def get_signed_cookie(self, key, default=RAISE_ERROR, salt="", max_age=None): """ Attempt to return a signed cookie. If the signature fails or the cookie has expired, raise an exception, unless the `default` argument @@ -172,7 +188,8 @@ class HttpRequest: raise try: value = signing.get_cookie_signer(salt=key + salt).unsign( - cookie_value, max_age=max_age) + cookie_value, max_age=max_age + ) except signing.BadSignature: if default is not RAISE_ERROR: return default @@ -192,7 +209,7 @@ class HttpRequest: if location is None: # Make it an absolute url (but schemeless and domainless) for the # edge case that the path starts with '//'. - location = '//%s' % self.get_full_path() + location = "//%s" % self.get_full_path() else: # Coerce lazy locations. location = str(location) @@ -201,12 +218,17 @@ class HttpRequest: # Handle the simple, most common case. If the location is absolute # and a scheme or host (netloc) isn't provided, skip an expensive # urljoin() as long as no path segments are '.' or '..'. - if (bits.path.startswith('/') and not bits.scheme and not bits.netloc and - '/./' not in bits.path and '/../' not in bits.path): + if ( + bits.path.startswith("/") + and not bits.scheme + and not bits.netloc + and "/./" not in bits.path + and "/../" not in bits.path + ): # If location starts with '//' but has no netloc, reuse the # schema and netloc from the current request. Strip the double # slashes and continue as if it wasn't specified. - if location.startswith('//'): + if location.startswith("//"): location = location[2:] location = self._current_scheme_host + location else: @@ -218,14 +240,14 @@ class HttpRequest: @cached_property def _current_scheme_host(self): - return '{}://{}'.format(self.scheme, self.get_host()) + return "{}://{}".format(self.scheme, self.get_host()) def _get_scheme(self): """ Hook for subclasses like WSGIRequest to implement. Return 'http' by default. """ - return 'http' + return "http" @property def scheme(self): @@ -234,15 +256,15 @@ class HttpRequest: header, secure_value = settings.SECURE_PROXY_SSL_HEADER except ValueError: raise ImproperlyConfigured( - 'The SECURE_PROXY_SSL_HEADER setting must be a tuple containing two values.' + "The SECURE_PROXY_SSL_HEADER setting must be a tuple containing two values." ) header_value = self.META.get(header) if header_value is not None: - return 'https' if header_value == secure_value else 'http' + return "https" if header_value == secure_value else "http" return self._get_scheme() def is_secure(self): - return self.scheme == 'https' + return self.scheme == "https" @property def encoding(self): @@ -256,14 +278,16 @@ class HttpRequest: next access (so that it is decoded correctly). """ self._encoding = val - if hasattr(self, 'GET'): + if hasattr(self, "GET"): del self.GET - if hasattr(self, '_post'): + if hasattr(self, "_post"): del self._post def _initialize_handlers(self): - self._upload_handlers = [uploadhandler.load_handler(handler, self) - for handler in settings.FILE_UPLOAD_HANDLERS] + self._upload_handlers = [ + uploadhandler.load_handler(handler, self) + for handler in settings.FILE_UPLOAD_HANDLERS + ] @property def upload_handlers(self): @@ -274,29 +298,38 @@ class HttpRequest: @upload_handlers.setter def upload_handlers(self, upload_handlers): - if hasattr(self, '_files'): - raise AttributeError("You cannot set the upload handlers after the upload has been processed.") + if hasattr(self, "_files"): + raise AttributeError( + "You cannot set the upload handlers after the upload has been processed." + ) self._upload_handlers = upload_handlers def parse_file_upload(self, META, post_data): """Return a tuple of (POST QueryDict, FILES MultiValueDict).""" self.upload_handlers = ImmutableList( self.upload_handlers, - warning="You cannot alter upload handlers after the upload has been processed." + warning="You cannot alter upload handlers after the upload has been processed.", ) parser = MultiPartParser(META, post_data, self.upload_handlers, self.encoding) return parser.parse() @property def body(self): - if not hasattr(self, '_body'): + if not hasattr(self, "_body"): if self._read_started: - raise RawPostDataException("You cannot access body after reading from request's data stream") + raise RawPostDataException( + "You cannot access body after reading from request's data stream" + ) # Limit the maximum request data size that will be handled in-memory. - if (settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None and - int(self.META.get('CONTENT_LENGTH') or 0) > settings.DATA_UPLOAD_MAX_MEMORY_SIZE): - raise RequestDataTooBig('Request body exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE.') + if ( + settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None + and int(self.META.get("CONTENT_LENGTH") or 0) + > settings.DATA_UPLOAD_MAX_MEMORY_SIZE + ): + raise RequestDataTooBig( + "Request body exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE." + ) try: self._body = self.read() @@ -311,15 +344,18 @@ class HttpRequest: def _load_post_and_files(self): """Populate self._post and self._files if the content-type is a form type""" - if self.method != 'POST': - self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict() + if self.method != "POST": + self._post, self._files = ( + QueryDict(encoding=self._encoding), + MultiValueDict(), + ) return - if self._read_started and not hasattr(self, '_body'): + if self._read_started and not hasattr(self, "_body"): self._mark_post_parse_error() return - if self.content_type == 'multipart/form-data': - if hasattr(self, '_body'): + if self.content_type == "multipart/form-data": + if hasattr(self, "_body"): # Use already read data data = BytesIO(self._body) else: @@ -333,13 +369,19 @@ class HttpRequest: # attempts to parse POST data again. self._mark_post_parse_error() raise - elif self.content_type == 'application/x-www-form-urlencoded': - self._post, self._files = QueryDict(self.body, encoding=self._encoding), MultiValueDict() + elif self.content_type == "application/x-www-form-urlencoded": + self._post, self._files = ( + QueryDict(self.body, encoding=self._encoding), + MultiValueDict(), + ) else: - self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict() + self._post, self._files = ( + QueryDict(encoding=self._encoding), + MultiValueDict(), + ) def close(self): - if hasattr(self, '_files'): + if hasattr(self, "_files"): for f in chain.from_iterable(list_[1] for list_ in self._files.lists()): f.close() @@ -366,16 +408,16 @@ class HttpRequest: raise UnreadablePostError(*e.args) from e def __iter__(self): - return iter(self.readline, b'') + return iter(self.readline, b"") def readlines(self): return list(self) class HttpHeaders(CaseInsensitiveMapping): - HTTP_PREFIX = 'HTTP_' + HTTP_PREFIX = "HTTP_" # PEP 333 gives two headers which aren't prepended with HTTP_. - UNPREFIXED_HEADERS = {'CONTENT_TYPE', 'CONTENT_LENGTH'} + UNPREFIXED_HEADERS = {"CONTENT_TYPE", "CONTENT_LENGTH"} def __init__(self, environ): headers = {} @@ -387,15 +429,15 @@ class HttpHeaders(CaseInsensitiveMapping): def __getitem__(self, key): """Allow header lookup using underscores in place of hyphens.""" - return super().__getitem__(key.replace('_', '-')) + return super().__getitem__(key.replace("_", "-")) @classmethod def parse_header_name(cls, header): if header.startswith(cls.HTTP_PREFIX): - header = header[len(cls.HTTP_PREFIX):] + header = header[len(cls.HTTP_PREFIX) :] elif header not in cls.UNPREFIXED_HEADERS: return None - return header.replace('_', '-').title() + return header.replace("_", "-").title() class QueryDict(MultiValueDict): @@ -421,11 +463,11 @@ class QueryDict(MultiValueDict): def __init__(self, query_string=None, mutable=False, encoding=None): super().__init__() self.encoding = encoding or settings.DEFAULT_CHARSET - query_string = query_string or '' + query_string = query_string or "" parse_qsl_kwargs = { - 'keep_blank_values': True, - 'encoding': self.encoding, - 'max_num_fields': settings.DATA_UPLOAD_MAX_NUMBER_FIELDS, + "keep_blank_values": True, + "encoding": self.encoding, + "max_num_fields": settings.DATA_UPLOAD_MAX_NUMBER_FIELDS, } if isinstance(query_string, bytes): # query_string normally contains URL-encoded data, a subset of ASCII. @@ -433,7 +475,7 @@ class QueryDict(MultiValueDict): query_string = query_string.decode(self.encoding) except UnicodeDecodeError: # ... but some user agents are misbehaving :-( - query_string = query_string.decode('iso-8859-1') + query_string = query_string.decode("iso-8859-1") try: for key, value in parse_qsl(query_string, **parse_qsl_kwargs): self.appendlist(key, value) @@ -443,18 +485,18 @@ class QueryDict(MultiValueDict): # the exception was raised by exceeding the value of max_num_fields # instead of fragile checks of exception message strings. raise TooManyFieldsSent( - 'The number of GET/POST parameters exceeded ' - 'settings.DATA_UPLOAD_MAX_NUMBER_FIELDS.' + "The number of GET/POST parameters exceeded " + "settings.DATA_UPLOAD_MAX_NUMBER_FIELDS." ) from e self._mutable = mutable @classmethod - def fromkeys(cls, iterable, value='', mutable=False, encoding=None): + def fromkeys(cls, iterable, value="", mutable=False, encoding=None): """ Return a new QueryDict with keys (may be repeated) from an iterable and values from value. """ - q = cls('', mutable=True, encoding=encoding) + q = cls("", mutable=True, encoding=encoding) for key in iterable: q.appendlist(key, value) if not mutable: @@ -486,13 +528,13 @@ class QueryDict(MultiValueDict): super().__delitem__(key) def __copy__(self): - result = self.__class__('', mutable=True, encoding=self.encoding) + result = self.__class__("", mutable=True, encoding=self.encoding) for key, value in self.lists(): result.setlist(key, value) return result def __deepcopy__(self, memo): - result = self.__class__('', mutable=True, encoding=self.encoding) + result = self.__class__("", mutable=True, encoding=self.encoding) memo[id(self)] = result for key, value in self.lists(): result.setlist(copy.deepcopy(key, memo), copy.deepcopy(value, memo)) @@ -554,48 +596,50 @@ class QueryDict(MultiValueDict): safe = safe.encode(self.encoding) def encode(k, v): - return '%s=%s' % ((quote(k, safe), quote(v, safe))) + return "%s=%s" % ((quote(k, safe), quote(v, safe))) + else: + def encode(k, v): return urlencode({k: v}) + for k, list_ in self.lists(): output.extend( encode(k.encode(self.encoding), str(v).encode(self.encoding)) for v in list_ ) - return '&'.join(output) + return "&".join(output) class MediaType: def __init__(self, media_type_raw_line): full_type, self.params = parse_header( - media_type_raw_line.encode('ascii') if media_type_raw_line else b'' + media_type_raw_line.encode("ascii") if media_type_raw_line else b"" ) - self.main_type, _, self.sub_type = full_type.partition('/') + self.main_type, _, self.sub_type = full_type.partition("/") def __str__(self): - params_str = ''.join( - '; %s=%s' % (k, v.decode('ascii')) - for k, v in self.params.items() + params_str = "".join( + "; %s=%s" % (k, v.decode("ascii")) for k, v in self.params.items() ) - return '%s%s%s' % ( + return "%s%s%s" % ( self.main_type, - ('/%s' % self.sub_type) if self.sub_type else '', + ("/%s" % self.sub_type) if self.sub_type else "", params_str, ) def __repr__(self): - return '<%s: %s>' % (self.__class__.__qualname__, self) + return "<%s: %s>" % (self.__class__.__qualname__, self) @property def is_all_types(self): - return self.main_type == '*' and self.sub_type == '*' + return self.main_type == "*" and self.sub_type == "*" def match(self, other): if self.is_all_types: return True other = MediaType(other) - if self.main_type == other.main_type and self.sub_type in {'*', other.sub_type}: + if self.main_type == other.main_type and self.sub_type in {"*", other.sub_type}: return True return False @@ -612,7 +656,7 @@ def bytes_to_text(s, encoding): Return any non-bytes objects without change. """ if isinstance(s, bytes): - return str(s, encoding, 'replace') + return str(s, encoding, "replace") else: return s @@ -627,15 +671,15 @@ def split_domain_port(host): host = host.lower() if not host_validation_re.match(host): - return '', '' + return "", "" - if host[-1] == ']': + if host[-1] == "]": # It's an IPv6 address without a port. - return host, '' - bits = host.rsplit(':', 1) - domain, port = bits if len(bits) == 2 else (bits[0], '') + return host, "" + bits = host.rsplit(":", 1) + domain, port = bits if len(bits) == 2 else (bits[0], "") # Remove a trailing dot (if present) from the domain. - domain = domain[:-1] if domain.endswith('.') else domain + domain = domain[:-1] if domain.endswith(".") else domain return domain, port @@ -654,8 +698,10 @@ def validate_host(host, allowed_hosts): Return ``True`` for a valid host, ``False`` otherwise. """ - return any(pattern == '*' or is_same_domain(host, pattern) for pattern in allowed_hosts) + return any( + pattern == "*" or is_same_domain(host, pattern) for pattern in allowed_hosts + ) def parse_accept_header(header): - return [MediaType(token) for token in header.split(',') if token.strip()] + return [MediaType(token) for token in header.split(",") if token.strip()] |
