summaryrefslogtreecommitdiff
path: root/django/contrib/postgres
diff options
context:
space:
mode:
Diffstat (limited to 'django/contrib/postgres')
-rw-r--r--django/contrib/postgres/aggregates/general.py48
-rw-r--r--django/contrib/postgres/aggregates/mixins.py1
-rw-r--r--django/contrib/postgres/aggregates/statistics.py42
-rw-r--r--django/contrib/postgres/apps.py32
-rw-r--r--django/contrib/postgres/constraints.py132
-rw-r--r--django/contrib/postgres/expressions.py2
-rw-r--r--django/contrib/postgres/fields/array.py128
-rw-r--r--django/contrib/postgres/fields/citext.py7
-rw-r--r--django/contrib/postgres/fields/hstore.py37
-rw-r--r--django/contrib/postgres/fields/jsonb.py12
-rw-r--r--django/contrib/postgres/fields/ranges.py152
-rw-r--r--django/contrib/postgres/forms/array.py99
-rw-r--r--django/contrib/postgres/forms/hstore.py15
-rw-r--r--django/contrib/postgres/forms/ranges.py51
-rw-r--r--django/contrib/postgres/functions.py4
-rw-r--r--django/contrib/postgres/indexes.py122
-rw-r--r--django/contrib/postgres/lookups.py40
-rw-r--r--django/contrib/postgres/operations.py149
-rw-r--r--django/contrib/postgres/search.py211
-rw-r--r--django/contrib/postgres/serializers.py4
-rw-r--r--django/contrib/postgres/signals.py10
-rw-r--r--django/contrib/postgres/utils.py8
-rw-r--r--django/contrib/postgres/validators.py55
23 files changed, 757 insertions, 604 deletions
diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py
index d90ca50e2b..f8b40fb709 100644
--- a/django/contrib/postgres/aggregates/general.py
+++ b/django/contrib/postgres/aggregates/general.py
@@ -1,16 +1,20 @@
import warnings
from django.contrib.postgres.fields import ArrayField
-from django.db.models import (
- Aggregate, BooleanField, JSONField, TextField, Value,
-)
+from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value
from django.utils.deprecation import RemovedInDjango50Warning
from .mixins import OrderableAggMixin
__all__ = [
- 'ArrayAgg', 'BitAnd', 'BitOr', 'BitXor', 'BoolAnd', 'BoolOr', 'JSONBAgg',
- 'StringAgg',
+ "ArrayAgg",
+ "BitAnd",
+ "BitOr",
+ "BitXor",
+ "BoolAnd",
+ "BoolOr",
+ "JSONBAgg",
+ "StringAgg",
]
@@ -35,17 +39,17 @@ class DeprecatedConvertValueMixin:
class ArrayAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate):
- function = 'ARRAY_AGG'
- template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
+ function = "ARRAY_AGG"
+ template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)"
allow_distinct = True
# RemovedInDjango50Warning
deprecation_value = property(lambda self: [])
deprecation_msg = (
- 'In Django 5.0, ArrayAgg() will return None instead of an empty list '
- 'if there are no rows. Pass default=None to opt into the new behavior '
- 'and silence this warning or default=Value([]) to keep the previous '
- 'behavior.'
+ "In Django 5.0, ArrayAgg() will return None instead of an empty list "
+ "if there are no rows. Pass default=None to opt into the new behavior "
+ "and silence this warning or default=Value([]) to keep the previous "
+ "behavior."
)
@property
@@ -54,35 +58,35 @@ class ArrayAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate):
class BitAnd(Aggregate):
- function = 'BIT_AND'
+ function = "BIT_AND"
class BitOr(Aggregate):
- function = 'BIT_OR'
+ function = "BIT_OR"
class BitXor(Aggregate):
- function = 'BIT_XOR'
+ function = "BIT_XOR"
class BoolAnd(Aggregate):
- function = 'BOOL_AND'
+ function = "BOOL_AND"
output_field = BooleanField()
class BoolOr(Aggregate):
- function = 'BOOL_OR'
+ function = "BOOL_OR"
output_field = BooleanField()
class JSONBAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate):
- function = 'JSONB_AGG'
- template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
+ function = "JSONB_AGG"
+ template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)"
allow_distinct = True
output_field = JSONField()
# RemovedInDjango50Warning
- deprecation_value = '[]'
+ deprecation_value = "[]"
deprecation_msg = (
"In Django 5.0, JSONBAgg() will return None instead of an empty list "
"if there are no rows. Pass default=None to opt into the new behavior "
@@ -92,13 +96,13 @@ class JSONBAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate):
class StringAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate):
- function = 'STRING_AGG'
- template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
+ function = "STRING_AGG"
+ template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)"
allow_distinct = True
output_field = TextField()
# RemovedInDjango50Warning
- deprecation_value = ''
+ deprecation_value = ""
deprecation_msg = (
"In Django 5.0, StringAgg() will return None instead of an empty "
"string if there are no rows. Pass default=None to opt into the new "
diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py
index 4fedb9bd98..b2f4097b8f 100644
--- a/django/contrib/postgres/aggregates/mixins.py
+++ b/django/contrib/postgres/aggregates/mixins.py
@@ -2,7 +2,6 @@ from django.db.models.expressions import OrderByList
class OrderableAggMixin:
-
def __init__(self, *expressions, ordering=(), **extra):
if isinstance(ordering, (list, tuple)):
self.order_by = OrderByList(*ordering)
diff --git a/django/contrib/postgres/aggregates/statistics.py b/django/contrib/postgres/aggregates/statistics.py
index 2c83b78c0e..3dc442b290 100644
--- a/django/contrib/postgres/aggregates/statistics.py
+++ b/django/contrib/postgres/aggregates/statistics.py
@@ -1,8 +1,18 @@
from django.db.models import Aggregate, FloatField, IntegerField
__all__ = [
- 'CovarPop', 'Corr', 'RegrAvgX', 'RegrAvgY', 'RegrCount', 'RegrIntercept',
- 'RegrR2', 'RegrSlope', 'RegrSXX', 'RegrSXY', 'RegrSYY', 'StatAggregate',
+ "CovarPop",
+ "Corr",
+ "RegrAvgX",
+ "RegrAvgY",
+ "RegrCount",
+ "RegrIntercept",
+ "RegrR2",
+ "RegrSlope",
+ "RegrSXX",
+ "RegrSXY",
+ "RegrSYY",
+ "StatAggregate",
]
@@ -11,53 +21,55 @@ class StatAggregate(Aggregate):
def __init__(self, y, x, output_field=None, filter=None, default=None):
if not x or not y:
- raise ValueError('Both y and x must be provided.')
- super().__init__(y, x, output_field=output_field, filter=filter, default=default)
+ raise ValueError("Both y and x must be provided.")
+ super().__init__(
+ y, x, output_field=output_field, filter=filter, default=default
+ )
class Corr(StatAggregate):
- function = 'CORR'
+ function = "CORR"
class CovarPop(StatAggregate):
def __init__(self, y, x, sample=False, filter=None, default=None):
- self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
+ self.function = "COVAR_SAMP" if sample else "COVAR_POP"
super().__init__(y, x, filter=filter, default=default)
class RegrAvgX(StatAggregate):
- function = 'REGR_AVGX'
+ function = "REGR_AVGX"
class RegrAvgY(StatAggregate):
- function = 'REGR_AVGY'
+ function = "REGR_AVGY"
class RegrCount(StatAggregate):
- function = 'REGR_COUNT'
+ function = "REGR_COUNT"
output_field = IntegerField()
empty_result_set_value = 0
class RegrIntercept(StatAggregate):
- function = 'REGR_INTERCEPT'
+ function = "REGR_INTERCEPT"
class RegrR2(StatAggregate):
- function = 'REGR_R2'
+ function = "REGR_R2"
class RegrSlope(StatAggregate):
- function = 'REGR_SLOPE'
+ function = "REGR_SLOPE"
class RegrSXX(StatAggregate):
- function = 'REGR_SXX'
+ function = "REGR_SXX"
class RegrSXY(StatAggregate):
- function = 'REGR_SXY'
+ function = "REGR_SXY"
class RegrSYY(StatAggregate):
- function = 'REGR_SYY'
+ function = "REGR_SYY"
diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py
index b8ec85b7a4..d917201f05 100644
--- a/django/contrib/postgres/apps.py
+++ b/django/contrib/postgres/apps.py
@@ -1,6 +1,4 @@
-from psycopg2.extras import (
- DateRange, DateTimeRange, DateTimeTZRange, NumericRange,
-)
+from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, NumericRange
from django.apps import AppConfig
from django.core.signals import setting_changed
@@ -25,7 +23,11 @@ def uninstall_if_needed(setting, value, enter, **kwargs):
Undo the effects of PostgresConfig.ready() when django.contrib.postgres
is "uninstalled" by override_settings().
"""
- if not enter and setting == 'INSTALLED_APPS' and 'django.contrib.postgres' not in set(value):
+ if (
+ not enter
+ and setting == "INSTALLED_APPS"
+ and "django.contrib.postgres" not in set(value)
+ ):
connection_created.disconnect(register_type_handlers)
CharField._unregister_lookup(Unaccent)
TextField._unregister_lookup(Unaccent)
@@ -43,21 +45,23 @@ def uninstall_if_needed(setting, value, enter, **kwargs):
class PostgresConfig(AppConfig):
- name = 'django.contrib.postgres'
- verbose_name = _('PostgreSQL extensions')
+ name = "django.contrib.postgres"
+ verbose_name = _("PostgreSQL extensions")
def ready(self):
setting_changed.connect(uninstall_if_needed)
# Connections may already exist before we are called.
for conn in connections.all():
- if conn.vendor == 'postgresql':
- conn.introspection.data_types_reverse.update({
- 3904: 'django.contrib.postgres.fields.IntegerRangeField',
- 3906: 'django.contrib.postgres.fields.DecimalRangeField',
- 3910: 'django.contrib.postgres.fields.DateTimeRangeField',
- 3912: 'django.contrib.postgres.fields.DateRangeField',
- 3926: 'django.contrib.postgres.fields.BigIntegerRangeField',
- })
+ if conn.vendor == "postgresql":
+ conn.introspection.data_types_reverse.update(
+ {
+ 3904: "django.contrib.postgres.fields.IntegerRangeField",
+ 3906: "django.contrib.postgres.fields.DecimalRangeField",
+ 3910: "django.contrib.postgres.fields.DateTimeRangeField",
+ 3912: "django.contrib.postgres.fields.DateRangeField",
+ 3926: "django.contrib.postgres.fields.BigIntegerRangeField",
+ }
+ )
if conn.connection is not None:
register_type_handlers(conn)
connection_created.connect(register_type_handlers)
diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py
index 06ccd3616e..c0bc0c444f 100644
--- a/django/contrib/postgres/constraints.py
+++ b/django/contrib/postgres/constraints.py
@@ -10,71 +10,69 @@ from django.db.models.indexes import IndexExpression
from django.db.models.sql import Query
from django.utils.deprecation import RemovedInDjango50Warning
-__all__ = ['ExclusionConstraint']
+__all__ = ["ExclusionConstraint"]
class ExclusionConstraintExpression(IndexExpression):
- template = '%(expressions)s WITH %(operator)s'
+ template = "%(expressions)s WITH %(operator)s"
class ExclusionConstraint(BaseConstraint):
- template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s'
+ template = "CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s"
def __init__(
- self, *, name, expressions, index_type=None, condition=None,
- deferrable=None, include=None, opclasses=(),
+ self,
+ *,
+ name,
+ expressions,
+ index_type=None,
+ condition=None,
+ deferrable=None,
+ include=None,
+ opclasses=(),
):
- if index_type and index_type.lower() not in {'gist', 'spgist'}:
+ if index_type and index_type.lower() not in {"gist", "spgist"}:
raise ValueError(
- 'Exclusion constraints only support GiST or SP-GiST indexes.'
+ "Exclusion constraints only support GiST or SP-GiST indexes."
)
if not expressions:
raise ValueError(
- 'At least one expression is required to define an exclusion '
- 'constraint.'
+ "At least one expression is required to define an exclusion "
+ "constraint."
)
if not all(
- isinstance(expr, (list, tuple)) and len(expr) == 2
- for expr in expressions
+ isinstance(expr, (list, tuple)) and len(expr) == 2 for expr in expressions
):
- raise ValueError('The expressions must be a list of 2-tuples.')
+ raise ValueError("The expressions must be a list of 2-tuples.")
if not isinstance(condition, (type(None), Q)):
- raise ValueError(
- 'ExclusionConstraint.condition must be a Q instance.'
- )
+ raise ValueError("ExclusionConstraint.condition must be a Q instance.")
if condition and deferrable:
- raise ValueError(
- 'ExclusionConstraint with conditions cannot be deferred.'
- )
+ raise ValueError("ExclusionConstraint with conditions cannot be deferred.")
if not isinstance(deferrable, (type(None), Deferrable)):
raise ValueError(
- 'ExclusionConstraint.deferrable must be a Deferrable instance.'
+ "ExclusionConstraint.deferrable must be a Deferrable instance."
)
if not isinstance(include, (type(None), list, tuple)):
- raise ValueError(
- 'ExclusionConstraint.include must be a list or tuple.'
- )
+ raise ValueError("ExclusionConstraint.include must be a list or tuple.")
if not isinstance(opclasses, (list, tuple)):
- raise ValueError(
- 'ExclusionConstraint.opclasses must be a list or tuple.'
- )
+ raise ValueError("ExclusionConstraint.opclasses must be a list or tuple.")
if opclasses and len(expressions) != len(opclasses):
raise ValueError(
- 'ExclusionConstraint.expressions and '
- 'ExclusionConstraint.opclasses must have the same number of '
- 'elements.'
+ "ExclusionConstraint.expressions and "
+ "ExclusionConstraint.opclasses must have the same number of "
+ "elements."
)
self.expressions = expressions
- self.index_type = index_type or 'GIST'
+ self.index_type = index_type or "GIST"
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
self.opclasses = opclasses
if self.opclasses:
warnings.warn(
- 'The opclasses argument is deprecated in favor of using '
- 'django.contrib.postgres.indexes.OpClass in '
- 'ExclusionConstraint.expressions.',
+ "The opclasses argument is deprecated in favor of using "
+ "django.contrib.postgres.indexes.OpClass in "
+ "ExclusionConstraint.expressions.",
category=RemovedInDjango50Warning,
stacklevel=2,
)
@@ -107,14 +105,18 @@ class ExclusionConstraint(BaseConstraint):
expressions = self._get_expressions(schema_editor, query)
table = model._meta.db_table
condition = self._get_condition_sql(compiler, schema_editor, query)
- include = [model._meta.get_field(field_name).column for field_name in self.include]
+ include = [
+ model._meta.get_field(field_name).column for field_name in self.include
+ ]
return Statement(
self.template,
table=Table(table, schema_editor.quote_name),
name=schema_editor.quote_name(self.name),
index_type=self.index_type,
- expressions=Expressions(table, expressions, compiler, schema_editor.quote_value),
- where=' WHERE (%s)' % condition if condition else '',
+ expressions=Expressions(
+ table, expressions, compiler, schema_editor.quote_value
+ ),
+ where=" WHERE (%s)" % condition if condition else "",
include=schema_editor._index_include_sql(model, include),
deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
)
@@ -122,7 +124,7 @@ class ExclusionConstraint(BaseConstraint):
def create_sql(self, model, schema_editor):
self.check_supported(schema_editor)
return Statement(
- 'ALTER TABLE %(table)s ADD %(constraint)s',
+ "ALTER TABLE %(table)s ADD %(constraint)s",
table=Table(model._meta.db_table, schema_editor.quote_name),
constraint=self.constraint_sql(model, schema_editor),
)
@@ -136,60 +138,60 @@ class ExclusionConstraint(BaseConstraint):
def check_supported(self, schema_editor):
if (
- self.include and
- self.index_type.lower() == 'gist' and
- not schema_editor.connection.features.supports_covering_gist_indexes
+ self.include
+ and self.index_type.lower() == "gist"
+ and not schema_editor.connection.features.supports_covering_gist_indexes
):
raise NotSupportedError(
- 'Covering exclusion constraints using a GiST index require '
- 'PostgreSQL 12+.'
+ "Covering exclusion constraints using a GiST index require "
+ "PostgreSQL 12+."
)
if (
- self.include and
- self.index_type.lower() == 'spgist' and
- not schema_editor.connection.features.supports_covering_spgist_indexes
+ self.include
+ and self.index_type.lower() == "spgist"
+ and not schema_editor.connection.features.supports_covering_spgist_indexes
):
raise NotSupportedError(
- 'Covering exclusion constraints using an SP-GiST index '
- 'require PostgreSQL 14+.'
+ "Covering exclusion constraints using an SP-GiST index "
+ "require PostgreSQL 14+."
)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
- kwargs['expressions'] = self.expressions
+ kwargs["expressions"] = self.expressions
if self.condition is not None:
- kwargs['condition'] = self.condition
- if self.index_type.lower() != 'gist':
- kwargs['index_type'] = self.index_type
+ kwargs["condition"] = self.condition
+ if self.index_type.lower() != "gist":
+ kwargs["index_type"] = self.index_type
if self.deferrable:
- kwargs['deferrable'] = self.deferrable
+ kwargs["deferrable"] = self.deferrable
if self.include:
- kwargs['include'] = self.include
+ kwargs["include"] = self.include
if self.opclasses:
- kwargs['opclasses'] = self.opclasses
+ kwargs["opclasses"] = self.opclasses
return path, args, kwargs
def __eq__(self, other):
if isinstance(other, self.__class__):
return (
- self.name == other.name and
- self.index_type == other.index_type and
- self.expressions == other.expressions and
- self.condition == other.condition and
- self.deferrable == other.deferrable and
- self.include == other.include and
- self.opclasses == other.opclasses
+ self.name == other.name
+ and self.index_type == other.index_type
+ and self.expressions == other.expressions
+ and self.condition == other.condition
+ and self.deferrable == other.deferrable
+ and self.include == other.include
+ and self.opclasses == other.opclasses
)
return super().__eq__(other)
def __repr__(self):
- return '<%s: index_type=%s expressions=%s name=%s%s%s%s%s>' % (
+ return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % (
self.__class__.__qualname__,
repr(self.index_type),
repr(self.expressions),
repr(self.name),
- '' if self.condition is None else ' condition=%s' % self.condition,
- '' if self.deferrable is None else ' deferrable=%r' % self.deferrable,
- '' if not self.include else ' include=%s' % repr(self.include),
- '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
+ "" if self.condition is None else " condition=%s" % self.condition,
+ "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
+ "" if not self.include else " include=%s" % repr(self.include),
+ "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
)
diff --git a/django/contrib/postgres/expressions.py b/django/contrib/postgres/expressions.py
index ea7cbe038d..469f4e9fb6 100644
--- a/django/contrib/postgres/expressions.py
+++ b/django/contrib/postgres/expressions.py
@@ -4,7 +4,7 @@ from django.utils.functional import cached_property
class ArraySubquery(Subquery):
- template = 'ARRAY(%(subquery)s)'
+ template = "ARRAY(%(subquery)s)"
def __init__(self, queryset, **kwargs):
super().__init__(queryset, **kwargs)
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
index 9c1bb96b61..7269198674 100644
--- a/django/contrib/postgres/fields/array.py
+++ b/django/contrib/postgres/fields/array.py
@@ -12,38 +12,43 @@ from django.utils.translation import gettext_lazy as _
from ..utils import prefix_validation_error
from .utils import AttributeSetter
-__all__ = ['ArrayField']
+__all__ = ["ArrayField"]
class ArrayField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
default_error_messages = {
- 'item_invalid': _('Item %(nth)s in the array did not validate:'),
- 'nested_array_mismatch': _('Nested arrays must have the same length.'),
+ "item_invalid": _("Item %(nth)s in the array did not validate:"),
+ "nested_array_mismatch": _("Nested arrays must have the same length."),
}
- _default_hint = ('list', '[]')
+ _default_hint = ("list", "[]")
def __init__(self, base_field, size=None, **kwargs):
self.base_field = base_field
self.size = size
if self.size:
- self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
+ self.default_validators = [
+ *self.default_validators,
+ ArrayMaxLengthValidator(self.size),
+ ]
# For performance, only add a from_db_value() method if the base field
# implements it.
- if hasattr(self.base_field, 'from_db_value'):
+ if hasattr(self.base_field, "from_db_value"):
self.from_db_value = self._from_db_value
super().__init__(**kwargs)
@property
def model(self):
try:
- return self.__dict__['model']
+ return self.__dict__["model"]
except KeyError:
- raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
+ raise AttributeError(
+ "'%s' object has no attribute 'model'" % self.__class__.__name__
+ )
@model.setter
def model(self, model):
- self.__dict__['model'] = model
+ self.__dict__["model"] = model
self.base_field.model = model
@classmethod
@@ -55,21 +60,23 @@ class ArrayField(CheckFieldDefaultMixin, Field):
if self.base_field.remote_field:
errors.append(
checks.Error(
- 'Base field for array cannot be a related field.',
+ "Base field for array cannot be a related field.",
obj=self,
- id='postgres.E002'
+ id="postgres.E002",
)
)
else:
# Remove the field name checks as they are not needed here.
base_errors = self.base_field.check()
if base_errors:
- messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
+ messages = "\n ".join(
+ "%s (%s)" % (error.msg, error.id) for error in base_errors
+ )
errors.append(
checks.Error(
- 'Base field for array has errors:\n %s' % messages,
+ "Base field for array has errors:\n %s" % messages,
obj=self,
- id='postgres.E001'
+ id="postgres.E001",
)
)
return errors
@@ -80,32 +87,37 @@ class ArrayField(CheckFieldDefaultMixin, Field):
@property
def description(self):
- return 'Array of %s' % self.base_field.description
+ return "Array of %s" % self.base_field.description
def db_type(self, connection):
- size = self.size or ''
- return '%s[%s]' % (self.base_field.db_type(connection), size)
+ size = self.size or ""
+ return "%s[%s]" % (self.base_field.db_type(connection), size)
def cast_db_type(self, connection):
- size = self.size or ''
- return '%s[%s]' % (self.base_field.cast_db_type(connection), size)
+ size = self.size or ""
+ return "%s[%s]" % (self.base_field.cast_db_type(connection), size)
def get_placeholder(self, value, compiler, connection):
- return '%s::{}'.format(self.db_type(connection))
+ return "%s::{}".format(self.db_type(connection))
def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, (list, tuple)):
- return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
+ return [
+ self.base_field.get_db_prep_value(i, connection, prepared=False)
+ for i in value
+ ]
return value
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- if path == 'django.contrib.postgres.fields.array.ArrayField':
- path = 'django.contrib.postgres.fields.ArrayField'
- kwargs.update({
- 'base_field': self.base_field.clone(),
- 'size': self.size,
- })
+ if path == "django.contrib.postgres.fields.array.ArrayField":
+ path = "django.contrib.postgres.fields.ArrayField"
+ kwargs.update(
+ {
+ "base_field": self.base_field.clone(),
+ "size": self.size,
+ }
+ )
return name, path, args, kwargs
def to_python(self, value):
@@ -140,7 +152,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
transform = super().get_transform(name)
if transform:
return transform
- if '_' not in name:
+ if "_" not in name:
try:
index = int(name)
except ValueError:
@@ -149,7 +161,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
index += 1 # postgres uses 1-indexing
return IndexTransformFactory(index, self.base_field)
try:
- start, end = name.split('_')
+ start, end = name.split("_")
start = int(start) + 1
end = int(end) # don't add one here because postgres slices are weird
except ValueError:
@@ -165,15 +177,15 @@ class ArrayField(CheckFieldDefaultMixin, Field):
except exceptions.ValidationError as error:
raise prefix_validation_error(
error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
)
if isinstance(self.base_field, ArrayField):
if len({len(i) for i in value}) > 1:
raise exceptions.ValidationError(
- self.error_messages['nested_array_mismatch'],
- code='nested_array_mismatch',
+ self.error_messages["nested_array_mismatch"],
+ code="nested_array_mismatch",
)
def run_validators(self, value):
@@ -184,18 +196,20 @@ class ArrayField(CheckFieldDefaultMixin, Field):
except exceptions.ValidationError as error:
raise prefix_validation_error(
error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': SimpleArrayField,
- 'base_field': self.base_field.formfield(),
- 'max_length': self.size,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": SimpleArrayField,
+ "base_field": self.base_field.formfield(),
+ "max_length": self.size,
+ **kwargs,
+ }
+ )
class ArrayRHSMixin:
@@ -203,21 +217,21 @@ class ArrayRHSMixin:
if isinstance(rhs, (tuple, list)):
expressions = []
for value in rhs:
- if not hasattr(value, 'resolve_expression'):
+ if not hasattr(value, "resolve_expression"):
field = lhs.output_field
value = Value(field.base_field.get_prep_value(value))
expressions.append(value)
rhs = Func(
*expressions,
- function='ARRAY',
- template='%(function)s[%(expressions)s]',
+ function="ARRAY",
+ template="%(function)s[%(expressions)s]",
)
super().__init__(lhs, rhs)
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
cast_type = self.lhs.output_field.cast_db_type(connection)
- return '%s::%s' % (rhs, cast_type), rhs_params
+ return "%s::%s" % (rhs, cast_type), rhs_params
@ArrayField.register_lookup
@@ -242,29 +256,29 @@ class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
- lookup_name = 'len'
+ lookup_name = "len"
output_field = IntegerField()
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
# Distinguish NULL and empty arrays
return (
- 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
- 'coalesce(array_length(%(lhs)s, 1), 0) END'
- ) % {'lhs': lhs}, params
+ "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
+ "coalesce(array_length(%(lhs)s, 1), 0) END"
+ ) % {"lhs": lhs}, params
@ArrayField.register_lookup
class ArrayInLookup(In):
def get_prep_lookup(self):
values = super().get_prep_lookup()
- if hasattr(values, 'resolve_expression'):
+ if hasattr(values, "resolve_expression"):
return values
# In.process_rhs() expects values to be hashable, so convert lists
# to tuples.
prepared_values = []
for value in values:
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
prepared_values.append(value)
else:
prepared_values.append(tuple(value))
@@ -272,7 +286,6 @@ class ArrayInLookup(In):
class IndexTransform(Transform):
-
def __init__(self, index, base_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.index = index
@@ -280,7 +293,7 @@ class IndexTransform(Transform):
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
- return '%s[%%s]' % lhs, params + [self.index]
+ return "%s[%%s]" % lhs, params + [self.index]
@property
def output_field(self):
@@ -288,7 +301,6 @@ class IndexTransform(Transform):
class IndexTransformFactory:
-
def __init__(self, index, base_field):
self.index = index
self.base_field = base_field
@@ -298,7 +310,6 @@ class IndexTransformFactory:
class SliceTransform(Transform):
-
def __init__(self, start, end, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start = start
@@ -306,11 +317,10 @@ class SliceTransform(Transform):
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
- return '%s[%%s:%%s]' % lhs, params + [self.start, self.end]
+ return "%s[%%s:%%s]" % lhs, params + [self.start, self.end]
class SliceTransformFactory:
-
def __init__(self, start, end):
self.start = start
self.end = end
diff --git a/django/contrib/postgres/fields/citext.py b/django/contrib/postgres/fields/citext.py
index 46f6d3d1c2..2b943614d2 100644
--- a/django/contrib/postgres/fields/citext.py
+++ b/django/contrib/postgres/fields/citext.py
@@ -1,15 +1,14 @@
from django.db.models import CharField, EmailField, TextField
-__all__ = ['CICharField', 'CIEmailField', 'CIText', 'CITextField']
+__all__ = ["CICharField", "CIEmailField", "CIText", "CITextField"]
class CIText:
-
def get_internal_type(self):
- return 'CI' + super().get_internal_type()
+ return "CI" + super().get_internal_type()
def db_type(self, connection):
- return 'citext'
+ return "citext"
class CICharField(CIText, CharField):
diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py
index 2ec5766041..cfc156ab59 100644
--- a/django/contrib/postgres/fields/hstore.py
+++ b/django/contrib/postgres/fields/hstore.py
@@ -7,19 +7,19 @@ from django.db.models import Field, TextField, Transform
from django.db.models.fields.mixins import CheckFieldDefaultMixin
from django.utils.translation import gettext_lazy as _
-__all__ = ['HStoreField']
+__all__ = ["HStoreField"]
class HStoreField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
- description = _('Map of strings to strings/nulls')
+ description = _("Map of strings to strings/nulls")
default_error_messages = {
- 'not_a_string': _('The value of “%(key)s” is not a string or null.'),
+ "not_a_string": _("The value of “%(key)s” is not a string or null."),
}
- _default_hint = ('dict', '{}')
+ _default_hint = ("dict", "{}")
def db_type(self, connection):
- return 'hstore'
+ return "hstore"
def get_transform(self, name):
transform = super().get_transform(name)
@@ -32,9 +32,9 @@ class HStoreField(CheckFieldDefaultMixin, Field):
for key, val in value.items():
if not isinstance(val, str) and val is not None:
raise exceptions.ValidationError(
- self.error_messages['not_a_string'],
- code='not_a_string',
- params={'key': key},
+ self.error_messages["not_a_string"],
+ code="not_a_string",
+ params={"key": key},
)
def to_python(self, value):
@@ -46,10 +46,12 @@ class HStoreField(CheckFieldDefaultMixin, Field):
return json.dumps(self.value_from_object(obj))
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.HStoreField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.HStoreField,
+ **kwargs,
+ }
+ )
def get_prep_value(self, value):
value = super().get_prep_value(value)
@@ -85,11 +87,10 @@ class KeyTransform(Transform):
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
- return '(%s -> %%s)' % lhs, tuple(params) + (self.key_name,)
+ return "(%s -> %%s)" % lhs, tuple(params) + (self.key_name,)
class KeyTransformFactory:
-
def __init__(self, key_name):
self.key_name = key_name
@@ -99,13 +100,13 @@ class KeyTransformFactory:
@HStoreField.register_lookup
class KeysTransform(Transform):
- lookup_name = 'keys'
- function = 'akeys'
+ lookup_name = "keys"
+ function = "akeys"
output_field = ArrayField(TextField())
@HStoreField.register_lookup
class ValuesTransform(Transform):
- lookup_name = 'values'
- function = 'avals'
+ lookup_name = "values"
+ function = "avals"
output_field = ArrayField(TextField())
diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py
index 29e8480665..760b5d8398 100644
--- a/django/contrib/postgres/fields/jsonb.py
+++ b/django/contrib/postgres/fields/jsonb.py
@@ -1,14 +1,14 @@
from django.db.models import JSONField as BuiltinJSONField
-__all__ = ['JSONField']
+__all__ = ["JSONField"]
class JSONField(BuiltinJSONField):
system_check_removed_details = {
- 'msg': (
- 'django.contrib.postgres.fields.JSONField is removed except for '
- 'support in historical migrations.'
+ "msg": (
+ "django.contrib.postgres.fields.JSONField is removed except for "
+ "support in historical migrations."
),
- 'hint': 'Use django.db.models.JSONField instead.',
- 'id': 'fields.E904',
+ "hint": "Use django.db.models.JSONField instead.",
+ "id": "fields.E904",
}
diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py
index b395e213a1..58ffacbbe5 100644
--- a/django/contrib/postgres/fields/ranges.py
+++ b/django/contrib/postgres/fields/ranges.py
@@ -10,17 +10,23 @@ from django.db.models.lookups import PostgresOperatorLookup
from .utils import AttributeSetter
__all__ = [
- 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
- 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
- 'RangeBoundary', 'RangeOperators',
+ "RangeField",
+ "IntegerRangeField",
+ "BigIntegerRangeField",
+ "DecimalRangeField",
+ "DateTimeRangeField",
+ "DateRangeField",
+ "RangeBoundary",
+ "RangeOperators",
]
class RangeBoundary(models.Expression):
"""A class that represents range boundaries."""
+
def __init__(self, inclusive_lower=True, inclusive_upper=False):
- self.lower = '[' if inclusive_lower else '('
- self.upper = ']' if inclusive_upper else ')'
+ self.lower = "[" if inclusive_lower else "("
+ self.upper = "]" if inclusive_upper else ")"
def as_sql(self, compiler, connection):
return "'%s%s'" % (self.lower, self.upper), []
@@ -28,41 +34,43 @@ class RangeBoundary(models.Expression):
class RangeOperators:
# https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
- EQUAL = '='
- NOT_EQUAL = '<>'
- CONTAINS = '@>'
- CONTAINED_BY = '<@'
- OVERLAPS = '&&'
- FULLY_LT = '<<'
- FULLY_GT = '>>'
- NOT_LT = '&>'
- NOT_GT = '&<'
- ADJACENT_TO = '-|-'
+ EQUAL = "="
+ NOT_EQUAL = "<>"
+ CONTAINS = "@>"
+ CONTAINED_BY = "<@"
+ OVERLAPS = "&&"
+ FULLY_LT = "<<"
+ FULLY_GT = ">>"
+ NOT_LT = "&>"
+ NOT_GT = "&<"
+ ADJACENT_TO = "-|-"
class RangeField(models.Field):
empty_strings_allowed = False
def __init__(self, *args, **kwargs):
- if 'default_bounds' in kwargs:
+ if "default_bounds" in kwargs:
raise TypeError(
f"Cannot use 'default_bounds' with {self.__class__.__name__}."
)
# Initializing base_field here ensures that its model matches the model for self.
- if hasattr(self, 'base_field'):
+ if hasattr(self, "base_field"):
self.base_field = self.base_field()
super().__init__(*args, **kwargs)
@property
def model(self):
try:
- return self.__dict__['model']
+ return self.__dict__["model"]
except KeyError:
- raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
+ raise AttributeError(
+ "'%s' object has no attribute 'model'" % self.__class__.__name__
+ )
@model.setter
def model(self, model):
- self.__dict__['model'] = model
+ self.__dict__["model"] = model
self.base_field.model = model
@classmethod
@@ -82,7 +90,7 @@ class RangeField(models.Field):
if isinstance(value, str):
# Assume we're deserializing
vals = json.loads(value)
- for end in ('lower', 'upper'):
+ for end in ("lower", "upper"):
if end in vals:
vals[end] = self.base_field.to_python(vals[end])
value = self.range_type(**vals)
@@ -102,7 +110,7 @@ class RangeField(models.Field):
return json.dumps({"empty": True})
base_field = self.base_field
result = {"bounds": value._bounds}
- for end in ('lower', 'upper'):
+ for end in ("lower", "upper"):
val = getattr(value, end)
if val is None:
result[end] = None
@@ -112,11 +120,11 @@ class RangeField(models.Field):
return json.dumps(result)
def formfield(self, **kwargs):
- kwargs.setdefault('form_class', self.form_field)
+ kwargs.setdefault("form_class", self.form_field)
return super().formfield(**kwargs)
-CANONICAL_RANGE_BOUNDS = '[)'
+CANONICAL_RANGE_BOUNDS = "[)"
class ContinuousRangeField(RangeField):
@@ -126,7 +134,7 @@ class ContinuousRangeField(RangeField):
"""
def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
- if default_bounds not in ('[)', '(]', '()', '[]'):
+ if default_bounds not in ("[)", "(]", "()", "[]"):
raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
self.default_bounds = default_bounds
super().__init__(*args, **kwargs)
@@ -137,13 +145,13 @@ class ContinuousRangeField(RangeField):
return super().get_prep_value(value)
def formfield(self, **kwargs):
- kwargs.setdefault('default_bounds', self.default_bounds)
+ kwargs.setdefault("default_bounds", self.default_bounds)
return super().formfield(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
- kwargs['default_bounds'] = self.default_bounds
+ kwargs["default_bounds"] = self.default_bounds
return name, path, args, kwargs
@@ -153,7 +161,7 @@ class IntegerRangeField(RangeField):
form_field = forms.IntegerRangeField
def db_type(self, connection):
- return 'int4range'
+ return "int4range"
class BigIntegerRangeField(RangeField):
@@ -162,7 +170,7 @@ class BigIntegerRangeField(RangeField):
form_field = forms.IntegerRangeField
def db_type(self, connection):
- return 'int8range'
+ return "int8range"
class DecimalRangeField(ContinuousRangeField):
@@ -171,7 +179,7 @@ class DecimalRangeField(ContinuousRangeField):
form_field = forms.DecimalRangeField
def db_type(self, connection):
- return 'numrange'
+ return "numrange"
class DateTimeRangeField(ContinuousRangeField):
@@ -180,7 +188,7 @@ class DateTimeRangeField(ContinuousRangeField):
form_field = forms.DateTimeRangeField
def db_type(self, connection):
- return 'tstzrange'
+ return "tstzrange"
class DateRangeField(RangeField):
@@ -189,7 +197,7 @@ class DateRangeField(RangeField):
form_field = forms.DateRangeField
def db_type(self, connection):
- return 'daterange'
+ return "daterange"
RangeField.register_lookup(lookups.DataContains)
@@ -202,7 +210,8 @@ class DateTimeRangeContains(PostgresOperatorLookup):
Lookup for Date/DateTimeRange containment to cast the rhs to the correct
type.
"""
- lookup_name = 'contains'
+
+ lookup_name = "contains"
postgres_operator = RangeOperators.CONTAINS
def process_rhs(self, compiler, connection):
@@ -215,16 +224,19 @@ class DateTimeRangeContains(PostgresOperatorLookup):
def as_postgresql(self, compiler, connection):
sql, params = super().as_postgresql(compiler, connection)
# Cast the rhs if needed.
- cast_sql = ''
+ cast_sql = ""
if (
- isinstance(self.rhs, models.Expression) and
- self.rhs._output_field_or_none and
+ isinstance(self.rhs, models.Expression)
+ and self.rhs._output_field_or_none
+ and
# Skip cast if rhs has a matching range type.
- not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
+ not isinstance(
+ self.rhs._output_field_or_none, self.lhs.output_field.__class__
+ )
):
cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
- cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
- return '%s%s' % (sql, cast_sql), params
+ cast_sql = "::{}".format(connection.data_types.get(cast_internal_type))
+ return "%s%s" % (sql, cast_sql), params
DateRangeField.register_lookup(DateTimeRangeContains)
@@ -232,31 +244,31 @@ DateTimeRangeField.register_lookup(DateTimeRangeContains)
class RangeContainedBy(PostgresOperatorLookup):
- lookup_name = 'contained_by'
+ lookup_name = "contained_by"
type_mapping = {
- 'smallint': 'int4range',
- 'integer': 'int4range',
- 'bigint': 'int8range',
- 'double precision': 'numrange',
- 'numeric': 'numrange',
- 'date': 'daterange',
- 'timestamp with time zone': 'tstzrange',
+ "smallint": "int4range",
+ "integer": "int4range",
+ "bigint": "int8range",
+ "double precision": "numrange",
+ "numeric": "numrange",
+ "date": "daterange",
+ "timestamp with time zone": "tstzrange",
}
postgres_operator = RangeOperators.CONTAINED_BY
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Ignore precision for DecimalFields.
- db_type = self.lhs.output_field.cast_db_type(connection).split('(')[0]
+ db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0]
cast_type = self.type_mapping[db_type]
- return '%s::%s' % (rhs, cast_type), rhs_params
+ return "%s::%s" % (rhs, cast_type), rhs_params
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if isinstance(self.lhs.output_field, models.FloatField):
- lhs = '%s::numeric' % lhs
+ lhs = "%s::numeric" % lhs
elif isinstance(self.lhs.output_field, models.SmallIntegerField):
- lhs = '%s::integer' % lhs
+ lhs = "%s::integer" % lhs
return lhs, lhs_params
def get_prep_lookup(self):
@@ -272,38 +284,38 @@ models.DecimalField.register_lookup(RangeContainedBy)
@RangeField.register_lookup
class FullyLessThan(PostgresOperatorLookup):
- lookup_name = 'fully_lt'
+ lookup_name = "fully_lt"
postgres_operator = RangeOperators.FULLY_LT
@RangeField.register_lookup
class FullGreaterThan(PostgresOperatorLookup):
- lookup_name = 'fully_gt'
+ lookup_name = "fully_gt"
postgres_operator = RangeOperators.FULLY_GT
@RangeField.register_lookup
class NotLessThan(PostgresOperatorLookup):
- lookup_name = 'not_lt'
+ lookup_name = "not_lt"
postgres_operator = RangeOperators.NOT_LT
@RangeField.register_lookup
class NotGreaterThan(PostgresOperatorLookup):
- lookup_name = 'not_gt'
+ lookup_name = "not_gt"
postgres_operator = RangeOperators.NOT_GT
@RangeField.register_lookup
class AdjacentToLookup(PostgresOperatorLookup):
- lookup_name = 'adjacent_to'
+ lookup_name = "adjacent_to"
postgres_operator = RangeOperators.ADJACENT_TO
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
- lookup_name = 'startswith'
- function = 'lower'
+ lookup_name = "startswith"
+ function = "lower"
@property
def output_field(self):
@@ -312,8 +324,8 @@ class RangeStartsWith(models.Transform):
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
- lookup_name = 'endswith'
- function = 'upper'
+ lookup_name = "endswith"
+ function = "upper"
@property
def output_field(self):
@@ -322,34 +334,34 @@ class RangeEndsWith(models.Transform):
@RangeField.register_lookup
class IsEmpty(models.Transform):
- lookup_name = 'isempty'
- function = 'isempty'
+ lookup_name = "isempty"
+ function = "isempty"
output_field = models.BooleanField()
@RangeField.register_lookup
class LowerInclusive(models.Transform):
- lookup_name = 'lower_inc'
- function = 'LOWER_INC'
+ lookup_name = "lower_inc"
+ function = "LOWER_INC"
output_field = models.BooleanField()
@RangeField.register_lookup
class LowerInfinite(models.Transform):
- lookup_name = 'lower_inf'
- function = 'LOWER_INF'
+ lookup_name = "lower_inf"
+ function = "LOWER_INF"
output_field = models.BooleanField()
@RangeField.register_lookup
class UpperInclusive(models.Transform):
- lookup_name = 'upper_inc'
- function = 'UPPER_INC'
+ lookup_name = "upper_inc"
+ function = "UPPER_INC"
output_field = models.BooleanField()
@RangeField.register_lookup
class UpperInfinite(models.Transform):
- lookup_name = 'upper_inf'
- function = 'UPPER_INF'
+ lookup_name = "upper_inf"
+ function = "UPPER_INF"
output_field = models.BooleanField()
diff --git a/django/contrib/postgres/forms/array.py b/django/contrib/postgres/forms/array.py
index 2e19cd574a..ddb022afc3 100644
--- a/django/contrib/postgres/forms/array.py
+++ b/django/contrib/postgres/forms/array.py
@@ -3,7 +3,8 @@ from itertools import chain
from django import forms
from django.contrib.postgres.validators import (
- ArrayMaxLengthValidator, ArrayMinLengthValidator,
+ ArrayMaxLengthValidator,
+ ArrayMinLengthValidator,
)
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
@@ -13,10 +14,12 @@ from ..utils import prefix_validation_error
class SimpleArrayField(forms.CharField):
default_error_messages = {
- 'item_invalid': _('Item %(nth)s in the array did not validate:'),
+ "item_invalid": _("Item %(nth)s in the array did not validate:"),
}
- def __init__(self, base_field, *, delimiter=',', max_length=None, min_length=None, **kwargs):
+ def __init__(
+ self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs
+ ):
self.base_field = base_field
self.delimiter = delimiter
super().__init__(**kwargs)
@@ -33,7 +36,9 @@ class SimpleArrayField(forms.CharField):
def prepare_value(self, value):
if isinstance(value, list):
- return self.delimiter.join(str(self.base_field.prepare_value(v)) for v in value)
+ return self.delimiter.join(
+ str(self.base_field.prepare_value(v)) for v in value
+ )
return value
def to_python(self, value):
@@ -49,12 +54,14 @@ class SimpleArrayField(forms.CharField):
try:
values.append(self.base_field.to_python(item))
except ValidationError as error:
- errors.append(prefix_validation_error(
- error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
- ))
+ errors.append(
+ prefix_validation_error(
+ error,
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
+ )
+ )
if errors:
raise ValidationError(errors)
return values
@@ -66,12 +73,14 @@ class SimpleArrayField(forms.CharField):
try:
self.base_field.validate(item)
except ValidationError as error:
- errors.append(prefix_validation_error(
- error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
- ))
+ errors.append(
+ prefix_validation_error(
+ error,
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
+ )
+ )
if errors:
raise ValidationError(errors)
@@ -82,12 +91,14 @@ class SimpleArrayField(forms.CharField):
try:
self.base_field.run_validators(item)
except ValidationError as error:
- errors.append(prefix_validation_error(
- error,
- prefix=self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
- ))
+ errors.append(
+ prefix_validation_error(
+ error,
+ prefix=self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
+ )
+ )
if errors:
raise ValidationError(errors)
@@ -103,7 +114,7 @@ class SimpleArrayField(forms.CharField):
class SplitArrayWidget(forms.Widget):
- template_name = 'postgres/widgets/split_array.html'
+ template_name = "postgres/widgets/split_array.html"
def __init__(self, widget, size, **kwargs):
self.widget = widget() if isinstance(widget, type) else widget
@@ -115,19 +126,21 @@ class SplitArrayWidget(forms.Widget):
return self.widget.is_hidden
def value_from_datadict(self, data, files, name):
- return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index))
- for index in range(self.size)]
+ return [
+ self.widget.value_from_datadict(data, files, "%s_%s" % (name, index))
+ for index in range(self.size)
+ ]
def value_omitted_from_data(self, data, files, name):
return all(
- self.widget.value_omitted_from_data(data, files, '%s_%s' % (name, index))
+ self.widget.value_omitted_from_data(data, files, "%s_%s" % (name, index))
for index in range(self.size)
)
def id_for_label(self, id_):
# See the comment for RadioSelect.id_for_label()
if id_:
- id_ += '_0'
+ id_ += "_0"
return id_
def get_context(self, name, value, attrs=None):
@@ -136,18 +149,20 @@ class SplitArrayWidget(forms.Widget):
if self.is_localized:
self.widget.is_localized = self.is_localized
value = value or []
- context['widget']['subwidgets'] = []
+ context["widget"]["subwidgets"] = []
final_attrs = self.build_attrs(attrs)
- id_ = final_attrs.get('id')
+ id_ = final_attrs.get("id")
for i in range(max(len(value), self.size)):
try:
widget_value = value[i]
except IndexError:
widget_value = None
if id_:
- final_attrs = {**final_attrs, 'id': '%s_%s' % (id_, i)}
- context['widget']['subwidgets'].append(
- self.widget.get_context(name + '_%s' % i, widget_value, final_attrs)['widget']
+ final_attrs = {**final_attrs, "id": "%s_%s" % (id_, i)}
+ context["widget"]["subwidgets"].append(
+ self.widget.get_context(name + "_%s" % i, widget_value, final_attrs)[
+ "widget"
+ ]
)
return context
@@ -167,7 +182,7 @@ class SplitArrayWidget(forms.Widget):
class SplitArrayField(forms.Field):
default_error_messages = {
- 'item_invalid': _('Item %(nth)s in the array did not validate:'),
+ "item_invalid": _("Item %(nth)s in the array did not validate:"),
}
def __init__(self, base_field, size, *, remove_trailing_nulls=False, **kwargs):
@@ -175,7 +190,7 @@ class SplitArrayField(forms.Field):
self.size = size
self.remove_trailing_nulls = remove_trailing_nulls
widget = SplitArrayWidget(widget=base_field.widget, size=size)
- kwargs.setdefault('widget', widget)
+ kwargs.setdefault("widget", widget)
super().__init__(**kwargs)
def _remove_trailing_nulls(self, values):
@@ -198,19 +213,21 @@ class SplitArrayField(forms.Field):
cleaned_data = []
errors = []
if not any(value) and self.required:
- raise ValidationError(self.error_messages['required'])
+ raise ValidationError(self.error_messages["required"])
max_size = max(self.size, len(value))
for index in range(max_size):
item = value[index]
try:
cleaned_data.append(self.base_field.clean(item))
except ValidationError as error:
- errors.append(prefix_validation_error(
- error,
- self.error_messages['item_invalid'],
- code='item_invalid',
- params={'nth': index + 1},
- ))
+ errors.append(
+ prefix_validation_error(
+ error,
+ self.error_messages["item_invalid"],
+ code="item_invalid",
+ params={"nth": index + 1},
+ )
+ )
cleaned_data.append(None)
else:
errors.append(None)
diff --git a/django/contrib/postgres/forms/hstore.py b/django/contrib/postgres/forms/hstore.py
index f5af8f10e3..6a20f7b729 100644
--- a/django/contrib/postgres/forms/hstore.py
+++ b/django/contrib/postgres/forms/hstore.py
@@ -4,17 +4,18 @@ from django import forms
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
-__all__ = ['HStoreField']
+__all__ = ["HStoreField"]
class HStoreField(forms.CharField):
"""
A field for HStore data which accepts dictionary JSON input.
"""
+
widget = forms.Textarea
default_error_messages = {
- 'invalid_json': _('Could not load JSON data.'),
- 'invalid_format': _('Input must be a JSON dictionary.'),
+ "invalid_json": _("Could not load JSON data."),
+ "invalid_format": _("Input must be a JSON dictionary."),
}
def prepare_value(self, value):
@@ -30,14 +31,14 @@ class HStoreField(forms.CharField):
value = json.loads(value)
except json.JSONDecodeError:
raise ValidationError(
- self.error_messages['invalid_json'],
- code='invalid_json',
+ self.error_messages["invalid_json"],
+ code="invalid_json",
)
if not isinstance(value, dict):
raise ValidationError(
- self.error_messages['invalid_format'],
- code='invalid_format',
+ self.error_messages["invalid_format"],
+ code="invalid_format",
)
# Cast everything to strings for ease.
diff --git a/django/contrib/postgres/forms/ranges.py b/django/contrib/postgres/forms/ranges.py
index 9c673ab40c..444991970d 100644
--- a/django/contrib/postgres/forms/ranges.py
+++ b/django/contrib/postgres/forms/ranges.py
@@ -6,8 +6,13 @@ from django.forms.widgets import HiddenInput, MultiWidget
from django.utils.translation import gettext_lazy as _
__all__ = [
- 'BaseRangeField', 'IntegerRangeField', 'DecimalRangeField',
- 'DateTimeRangeField', 'DateRangeField', 'HiddenRangeWidget', 'RangeWidget',
+ "BaseRangeField",
+ "IntegerRangeField",
+ "DecimalRangeField",
+ "DateTimeRangeField",
+ "DateRangeField",
+ "HiddenRangeWidget",
+ "RangeWidget",
]
@@ -24,27 +29,33 @@ class RangeWidget(MultiWidget):
class HiddenRangeWidget(RangeWidget):
"""A widget that splits input into two <input type="hidden"> inputs."""
+
def __init__(self, attrs=None):
super().__init__(HiddenInput, attrs)
class BaseRangeField(forms.MultiValueField):
default_error_messages = {
- 'invalid': _('Enter two valid values.'),
- 'bound_ordering': _('The start of the range must not exceed the end of the range.'),
+ "invalid": _("Enter two valid values."),
+ "bound_ordering": _(
+ "The start of the range must not exceed the end of the range."
+ ),
}
hidden_widget = HiddenRangeWidget
def __init__(self, **kwargs):
- if 'widget' not in kwargs:
- kwargs['widget'] = RangeWidget(self.base_field.widget)
- if 'fields' not in kwargs:
- kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)]
- kwargs.setdefault('required', False)
- kwargs.setdefault('require_all_fields', False)
+ if "widget" not in kwargs:
+ kwargs["widget"] = RangeWidget(self.base_field.widget)
+ if "fields" not in kwargs:
+ kwargs["fields"] = [
+ self.base_field(required=False),
+ self.base_field(required=False),
+ ]
+ kwargs.setdefault("required", False)
+ kwargs.setdefault("require_all_fields", False)
self.range_kwargs = {}
- if default_bounds := kwargs.pop('default_bounds', None):
- self.range_kwargs = {'bounds': default_bounds}
+ if default_bounds := kwargs.pop("default_bounds", None):
+ self.range_kwargs = {"bounds": default_bounds}
super().__init__(**kwargs)
def prepare_value(self, value):
@@ -67,39 +78,39 @@ class BaseRangeField(forms.MultiValueField):
lower, upper = values
if lower is not None and upper is not None and lower > upper:
raise exceptions.ValidationError(
- self.error_messages['bound_ordering'],
- code='bound_ordering',
+ self.error_messages["bound_ordering"],
+ code="bound_ordering",
)
try:
range_value = self.range_type(lower, upper, **self.range_kwargs)
except TypeError:
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
+ self.error_messages["invalid"],
+ code="invalid",
)
else:
return range_value
class IntegerRangeField(BaseRangeField):
- default_error_messages = {'invalid': _('Enter two whole numbers.')}
+ default_error_messages = {"invalid": _("Enter two whole numbers.")}
base_field = forms.IntegerField
range_type = NumericRange
class DecimalRangeField(BaseRangeField):
- default_error_messages = {'invalid': _('Enter two numbers.')}
+ default_error_messages = {"invalid": _("Enter two numbers.")}
base_field = forms.DecimalField
range_type = NumericRange
class DateTimeRangeField(BaseRangeField):
- default_error_messages = {'invalid': _('Enter two valid date/times.')}
+ default_error_messages = {"invalid": _("Enter two valid date/times.")}
base_field = forms.DateTimeField
range_type = DateTimeTZRange
class DateRangeField(BaseRangeField):
- default_error_messages = {'invalid': _('Enter two valid dates.')}
+ default_error_messages = {"invalid": _("Enter two valid dates.")}
base_field = forms.DateField
range_type = DateRange
diff --git a/django/contrib/postgres/functions.py b/django/contrib/postgres/functions.py
index 819ce058e5..f001a04fdc 100644
--- a/django/contrib/postgres/functions.py
+++ b/django/contrib/postgres/functions.py
@@ -2,10 +2,10 @@ from django.db.models import DateTimeField, Func, UUIDField
class RandomUUID(Func):
- template = 'GEN_RANDOM_UUID()'
+ template = "GEN_RANDOM_UUID()"
output_field = UUIDField()
class TransactionNow(Func):
- template = 'CURRENT_TIMESTAMP'
+ template = "CURRENT_TIMESTAMP"
output_field = DateTimeField()
diff --git a/django/contrib/postgres/indexes.py b/django/contrib/postgres/indexes.py
index 2e3de1d275..409f514147 100644
--- a/django/contrib/postgres/indexes.py
+++ b/django/contrib/postgres/indexes.py
@@ -3,13 +3,17 @@ from django.db.models import Func, Index
from django.utils.functional import cached_property
__all__ = [
- 'BloomIndex', 'BrinIndex', 'BTreeIndex', 'GinIndex', 'GistIndex',
- 'HashIndex', 'SpGistIndex',
+ "BloomIndex",
+ "BrinIndex",
+ "BTreeIndex",
+ "GinIndex",
+ "GistIndex",
+ "HashIndex",
+ "SpGistIndex",
]
class PostgresIndex(Index):
-
@cached_property
def max_name_length(self):
# Allow an index name longer than 30 characters when the suffix is
@@ -18,14 +22,16 @@ class PostgresIndex(Index):
# indexes.
return Index.max_name_length - len(Index.suffix) + len(self.suffix)
- def create_sql(self, model, schema_editor, using='', **kwargs):
+ def create_sql(self, model, schema_editor, using="", **kwargs):
self.check_supported(schema_editor)
- statement = super().create_sql(model, schema_editor, using=' USING %s' % self.suffix, **kwargs)
+ statement = super().create_sql(
+ model, schema_editor, using=" USING %s" % self.suffix, **kwargs
+ )
with_params = self.get_with_params()
if with_params:
- statement.parts['extra'] = 'WITH (%s) %s' % (
- ', '.join(with_params),
- statement.parts['extra'],
+ statement.parts["extra"] = "WITH (%s) %s" % (
+ ", ".join(with_params),
+ statement.parts["extra"],
)
return statement
@@ -37,25 +43,23 @@ class PostgresIndex(Index):
class BloomIndex(PostgresIndex):
- suffix = 'bloom'
+ suffix = "bloom"
def __init__(self, *expressions, length=None, columns=(), **kwargs):
super().__init__(*expressions, **kwargs)
if len(self.fields) > 32:
- raise ValueError('Bloom indexes support a maximum of 32 fields.')
+ raise ValueError("Bloom indexes support a maximum of 32 fields.")
if not isinstance(columns, (list, tuple)):
- raise ValueError('BloomIndex.columns must be a list or tuple.')
+ raise ValueError("BloomIndex.columns must be a list or tuple.")
if len(columns) > len(self.fields):
- raise ValueError(
- 'BloomIndex.columns cannot have more values than fields.'
- )
+ raise ValueError("BloomIndex.columns cannot have more values than fields.")
if not all(0 < col <= 4095 for col in columns):
raise ValueError(
- 'BloomIndex.columns must contain integers from 1 to 4095.',
+ "BloomIndex.columns must contain integers from 1 to 4095.",
)
if length is not None and not 0 < length <= 4096:
raise ValueError(
- 'BloomIndex.length must be None or an integer from 1 to 4096.',
+ "BloomIndex.length must be None or an integer from 1 to 4096.",
)
self.length = length
self.columns = columns
@@ -63,29 +67,30 @@ class BloomIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.length is not None:
- kwargs['length'] = self.length
+ kwargs["length"] = self.length
if self.columns:
- kwargs['columns'] = self.columns
+ kwargs["columns"] = self.columns
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.length is not None:
- with_params.append('length = %d' % self.length)
+ with_params.append("length = %d" % self.length)
if self.columns:
with_params.extend(
- 'col%d = %d' % (i, v)
- for i, v in enumerate(self.columns, start=1)
+ "col%d = %d" % (i, v) for i, v in enumerate(self.columns, start=1)
)
return with_params
class BrinIndex(PostgresIndex):
- suffix = 'brin'
+ suffix = "brin"
- def __init__(self, *expressions, autosummarize=None, pages_per_range=None, **kwargs):
+ def __init__(
+ self, *expressions, autosummarize=None, pages_per_range=None, **kwargs
+ ):
if pages_per_range is not None and pages_per_range <= 0:
- raise ValueError('pages_per_range must be None or a positive integer')
+ raise ValueError("pages_per_range must be None or a positive integer")
self.autosummarize = autosummarize
self.pages_per_range = pages_per_range
super().__init__(*expressions, **kwargs)
@@ -93,22 +98,24 @@ class BrinIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.autosummarize is not None:
- kwargs['autosummarize'] = self.autosummarize
+ kwargs["autosummarize"] = self.autosummarize
if self.pages_per_range is not None:
- kwargs['pages_per_range'] = self.pages_per_range
+ kwargs["pages_per_range"] = self.pages_per_range
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.autosummarize is not None:
- with_params.append('autosummarize = %s' % ('on' if self.autosummarize else 'off'))
+ with_params.append(
+ "autosummarize = %s" % ("on" if self.autosummarize else "off")
+ )
if self.pages_per_range is not None:
- with_params.append('pages_per_range = %d' % self.pages_per_range)
+ with_params.append("pages_per_range = %d" % self.pages_per_range)
return with_params
class BTreeIndex(PostgresIndex):
- suffix = 'btree'
+ suffix = "btree"
def __init__(self, *expressions, fillfactor=None, **kwargs):
self.fillfactor = fillfactor
@@ -117,20 +124,22 @@ class BTreeIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fillfactor is not None:
- kwargs['fillfactor'] = self.fillfactor
+ kwargs["fillfactor"] = self.fillfactor
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.fillfactor is not None:
- with_params.append('fillfactor = %d' % self.fillfactor)
+ with_params.append("fillfactor = %d" % self.fillfactor)
return with_params
class GinIndex(PostgresIndex):
- suffix = 'gin'
+ suffix = "gin"
- def __init__(self, *expressions, fastupdate=None, gin_pending_list_limit=None, **kwargs):
+ def __init__(
+ self, *expressions, fastupdate=None, gin_pending_list_limit=None, **kwargs
+ ):
self.fastupdate = fastupdate
self.gin_pending_list_limit = gin_pending_list_limit
super().__init__(*expressions, **kwargs)
@@ -138,22 +147,24 @@ class GinIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fastupdate is not None:
- kwargs['fastupdate'] = self.fastupdate
+ kwargs["fastupdate"] = self.fastupdate
if self.gin_pending_list_limit is not None:
- kwargs['gin_pending_list_limit'] = self.gin_pending_list_limit
+ kwargs["gin_pending_list_limit"] = self.gin_pending_list_limit
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.gin_pending_list_limit is not None:
- with_params.append('gin_pending_list_limit = %d' % self.gin_pending_list_limit)
+ with_params.append(
+ "gin_pending_list_limit = %d" % self.gin_pending_list_limit
+ )
if self.fastupdate is not None:
- with_params.append('fastupdate = %s' % ('on' if self.fastupdate else 'off'))
+ with_params.append("fastupdate = %s" % ("on" if self.fastupdate else "off"))
return with_params
class GistIndex(PostgresIndex):
- suffix = 'gist'
+ suffix = "gist"
def __init__(self, *expressions, buffering=None, fillfactor=None, **kwargs):
self.buffering = buffering
@@ -163,26 +174,29 @@ class GistIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.buffering is not None:
- kwargs['buffering'] = self.buffering
+ kwargs["buffering"] = self.buffering
if self.fillfactor is not None:
- kwargs['fillfactor'] = self.fillfactor
+ kwargs["fillfactor"] = self.fillfactor
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.buffering is not None:
- with_params.append('buffering = %s' % ('on' if self.buffering else 'off'))
+ with_params.append("buffering = %s" % ("on" if self.buffering else "off"))
if self.fillfactor is not None:
- with_params.append('fillfactor = %d' % self.fillfactor)
+ with_params.append("fillfactor = %d" % self.fillfactor)
return with_params
def check_supported(self, schema_editor):
- if self.include and not schema_editor.connection.features.supports_covering_gist_indexes:
- raise NotSupportedError('Covering GiST indexes require PostgreSQL 12+.')
+ if (
+ self.include
+ and not schema_editor.connection.features.supports_covering_gist_indexes
+ ):
+ raise NotSupportedError("Covering GiST indexes require PostgreSQL 12+.")
class HashIndex(PostgresIndex):
- suffix = 'hash'
+ suffix = "hash"
def __init__(self, *expressions, fillfactor=None, **kwargs):
self.fillfactor = fillfactor
@@ -191,18 +205,18 @@ class HashIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fillfactor is not None:
- kwargs['fillfactor'] = self.fillfactor
+ kwargs["fillfactor"] = self.fillfactor
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.fillfactor is not None:
- with_params.append('fillfactor = %d' % self.fillfactor)
+ with_params.append("fillfactor = %d" % self.fillfactor)
return with_params
class SpGistIndex(PostgresIndex):
- suffix = 'spgist'
+ suffix = "spgist"
def __init__(self, *expressions, fillfactor=None, **kwargs):
self.fillfactor = fillfactor
@@ -211,25 +225,25 @@ class SpGistIndex(PostgresIndex):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fillfactor is not None:
- kwargs['fillfactor'] = self.fillfactor
+ kwargs["fillfactor"] = self.fillfactor
return path, args, kwargs
def get_with_params(self):
with_params = []
if self.fillfactor is not None:
- with_params.append('fillfactor = %d' % self.fillfactor)
+ with_params.append("fillfactor = %d" % self.fillfactor)
return with_params
def check_supported(self, schema_editor):
if (
- self.include and
- not schema_editor.connection.features.supports_covering_spgist_indexes
+ self.include
+ and not schema_editor.connection.features.supports_covering_spgist_indexes
):
- raise NotSupportedError('Covering SP-GiST indexes require PostgreSQL 14+.')
+ raise NotSupportedError("Covering SP-GiST indexes require PostgreSQL 14+.")
class OpClass(Func):
- template = '%(expressions)s %(name)s'
+ template = "%(expressions)s %(name)s"
def __init__(self, expression, name):
super().__init__(expression, name=name)
diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py
index f7c6fc4b0c..9fed0eea30 100644
--- a/django/contrib/postgres/lookups.py
+++ b/django/contrib/postgres/lookups.py
@@ -5,61 +5,61 @@ from .search import SearchVector, SearchVectorExact, SearchVectorField
class DataContains(PostgresOperatorLookup):
- lookup_name = 'contains'
- postgres_operator = '@>'
+ lookup_name = "contains"
+ postgres_operator = "@>"
class ContainedBy(PostgresOperatorLookup):
- lookup_name = 'contained_by'
- postgres_operator = '<@'
+ lookup_name = "contained_by"
+ postgres_operator = "<@"
class Overlap(PostgresOperatorLookup):
- lookup_name = 'overlap'
- postgres_operator = '&&'
+ lookup_name = "overlap"
+ postgres_operator = "&&"
class HasKey(PostgresOperatorLookup):
- lookup_name = 'has_key'
- postgres_operator = '?'
+ lookup_name = "has_key"
+ postgres_operator = "?"
prepare_rhs = False
class HasKeys(PostgresOperatorLookup):
- lookup_name = 'has_keys'
- postgres_operator = '?&'
+ lookup_name = "has_keys"
+ postgres_operator = "?&"
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
- lookup_name = 'has_any_keys'
- postgres_operator = '?|'
+ lookup_name = "has_any_keys"
+ postgres_operator = "?|"
class Unaccent(Transform):
bilateral = True
- lookup_name = 'unaccent'
- function = 'UNACCENT'
+ lookup_name = "unaccent"
+ function = "UNACCENT"
class SearchLookup(SearchVectorExact):
- lookup_name = 'search'
+ lookup_name = "search"
def process_lhs(self, qn, connection):
if not isinstance(self.lhs.output_field, SearchVectorField):
- config = getattr(self.rhs, 'config', None)
+ config = getattr(self.rhs, "config", None)
self.lhs = SearchVector(self.lhs, config=config)
lhs, lhs_params = super().process_lhs(qn, connection)
return lhs, lhs_params
class TrigramSimilar(PostgresOperatorLookup):
- lookup_name = 'trigram_similar'
- postgres_operator = '%%'
+ lookup_name = "trigram_similar"
+ postgres_operator = "%%"
class TrigramWordSimilar(PostgresOperatorLookup):
- lookup_name = 'trigram_word_similar'
- postgres_operator = '%%>'
+ lookup_name = "trigram_word_similar"
+ postgres_operator = "%%>"
diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py
index 037bb4ec22..374f5ee1ec 100644
--- a/django/contrib/postgres/operations.py
+++ b/django/contrib/postgres/operations.py
@@ -1,5 +1,7 @@
from django.contrib.postgres.signals import (
- get_citext_oids, get_hstore_oids, register_type_handlers,
+ get_citext_oids,
+ get_hstore_oids,
+ register_type_handlers,
)
from django.db import NotSupportedError, router
from django.db.migrations import AddConstraint, AddIndex, RemoveIndex
@@ -17,14 +19,14 @@ class CreateExtension(Operation):
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
- if (
- schema_editor.connection.vendor != 'postgresql' or
- not router.allow_migrate(schema_editor.connection.alias, app_label)
+ if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate(
+ schema_editor.connection.alias, app_label
):
return
if not self.extension_exists(schema_editor, self.name):
schema_editor.execute(
- 'CREATE EXTENSION IF NOT EXISTS %s' % schema_editor.quote_name(self.name)
+ "CREATE EXTENSION IF NOT EXISTS %s"
+ % schema_editor.quote_name(self.name)
)
# Clear cached, stale oids.
get_hstore_oids.cache_clear()
@@ -39,7 +41,7 @@ class CreateExtension(Operation):
return
if self.extension_exists(schema_editor, self.name):
schema_editor.execute(
- 'DROP EXTENSION IF EXISTS %s' % schema_editor.quote_name(self.name)
+ "DROP EXTENSION IF EXISTS %s" % schema_editor.quote_name(self.name)
)
# Clear cached, stale oids.
get_hstore_oids.cache_clear()
@@ -48,7 +50,7 @@ class CreateExtension(Operation):
def extension_exists(self, schema_editor, extension):
with schema_editor.connection.cursor() as cursor:
cursor.execute(
- 'SELECT 1 FROM pg_extension WHERE extname = %s',
+ "SELECT 1 FROM pg_extension WHERE extname = %s",
[extension],
)
return bool(cursor.fetchone())
@@ -58,75 +60,67 @@ class CreateExtension(Operation):
@property
def migration_name_fragment(self):
- return 'create_extension_%s' % self.name
+ return "create_extension_%s" % self.name
class BloomExtension(CreateExtension):
-
def __init__(self):
- self.name = 'bloom'
+ self.name = "bloom"
class BtreeGinExtension(CreateExtension):
-
def __init__(self):
- self.name = 'btree_gin'
+ self.name = "btree_gin"
class BtreeGistExtension(CreateExtension):
-
def __init__(self):
- self.name = 'btree_gist'
+ self.name = "btree_gist"
class CITextExtension(CreateExtension):
-
def __init__(self):
- self.name = 'citext'
+ self.name = "citext"
class CryptoExtension(CreateExtension):
-
def __init__(self):
- self.name = 'pgcrypto'
+ self.name = "pgcrypto"
class HStoreExtension(CreateExtension):
-
def __init__(self):
- self.name = 'hstore'
+ self.name = "hstore"
class TrigramExtension(CreateExtension):
-
def __init__(self):
- self.name = 'pg_trgm'
+ self.name = "pg_trgm"
class UnaccentExtension(CreateExtension):
-
def __init__(self):
- self.name = 'unaccent'
+ self.name = "unaccent"
class NotInTransactionMixin:
def _ensure_not_in_transaction(self, schema_editor):
if schema_editor.connection.in_atomic_block:
raise NotSupportedError(
- 'The %s operation cannot be executed inside a transaction '
- '(set atomic = False on the migration).'
- % self.__class__.__name__
+ "The %s operation cannot be executed inside a transaction "
+ "(set atomic = False on the migration)." % self.__class__.__name__
)
class AddIndexConcurrently(NotInTransactionMixin, AddIndex):
"""Create an index using PostgreSQL's CREATE INDEX CONCURRENTLY syntax."""
+
atomic = False
def describe(self):
- return 'Concurrently create index %s on field(s) %s of model %s' % (
+ return "Concurrently create index %s on field(s) %s of model %s" % (
self.index.name,
- ', '.join(self.index.fields),
+ ", ".join(self.index.fields),
self.model_name,
)
@@ -145,10 +139,11 @@ class AddIndexConcurrently(NotInTransactionMixin, AddIndex):
class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
"""Remove an index using PostgreSQL's DROP INDEX CONCURRENTLY syntax."""
+
atomic = False
def describe(self):
- return 'Concurrently remove index %s from %s' % (self.name, self.model_name)
+ return "Concurrently remove index %s from %s" % (self.name, self.model_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
self._ensure_not_in_transaction(schema_editor)
@@ -168,7 +163,7 @@ class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
class CollationOperation(Operation):
- def __init__(self, name, locale, *, provider='libc', deterministic=True):
+ def __init__(self, name, locale, *, provider="libc", deterministic=True):
self.name = name
self.locale = locale
self.provider = provider
@@ -178,11 +173,11 @@ class CollationOperation(Operation):
pass
def deconstruct(self):
- kwargs = {'name': self.name, 'locale': self.locale}
- if self.provider and self.provider != 'libc':
- kwargs['provider'] = self.provider
+ kwargs = {"name": self.name, "locale": self.locale}
+ if self.provider and self.provider != "libc":
+ kwargs["provider"] = self.provider
if self.deterministic is False:
- kwargs['deterministic'] = self.deterministic
+ kwargs["deterministic"] = self.deterministic
return (
self.__class__.__qualname__,
[],
@@ -191,34 +186,39 @@ class CollationOperation(Operation):
def create_collation(self, schema_editor):
if (
- self.deterministic is False and
- not schema_editor.connection.features.supports_non_deterministic_collations
+ self.deterministic is False
+ and not schema_editor.connection.features.supports_non_deterministic_collations
):
raise NotSupportedError(
- 'Non-deterministic collations require PostgreSQL 12+.'
+ "Non-deterministic collations require PostgreSQL 12+."
)
- args = {'locale': schema_editor.quote_name(self.locale)}
- if self.provider != 'libc':
- args['provider'] = schema_editor.quote_name(self.provider)
+ args = {"locale": schema_editor.quote_name(self.locale)}
+ if self.provider != "libc":
+ args["provider"] = schema_editor.quote_name(self.provider)
if self.deterministic is False:
- args['deterministic'] = 'false'
- schema_editor.execute('CREATE COLLATION %(name)s (%(args)s)' % {
- 'name': schema_editor.quote_name(self.name),
- 'args': ', '.join(f'{option}={value}' for option, value in args.items()),
- })
+ args["deterministic"] = "false"
+ schema_editor.execute(
+ "CREATE COLLATION %(name)s (%(args)s)"
+ % {
+ "name": schema_editor.quote_name(self.name),
+ "args": ", ".join(
+ f"{option}={value}" for option, value in args.items()
+ ),
+ }
+ )
def remove_collation(self, schema_editor):
schema_editor.execute(
- 'DROP COLLATION %s' % schema_editor.quote_name(self.name),
+ "DROP COLLATION %s" % schema_editor.quote_name(self.name),
)
class CreateCollation(CollationOperation):
"""Create a collation."""
+
def database_forwards(self, app_label, schema_editor, from_state, to_state):
- if (
- schema_editor.connection.vendor != 'postgresql' or
- not router.allow_migrate(schema_editor.connection.alias, app_label)
+ if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate(
+ schema_editor.connection.alias, app_label
):
return
self.create_collation(schema_editor)
@@ -229,19 +229,19 @@ class CreateCollation(CollationOperation):
self.remove_collation(schema_editor)
def describe(self):
- return f'Create collation {self.name}'
+ return f"Create collation {self.name}"
@property
def migration_name_fragment(self):
- return 'create_collation_%s' % self.name.lower()
+ return "create_collation_%s" % self.name.lower()
class RemoveCollation(CollationOperation):
"""Remove a collation."""
+
def database_forwards(self, app_label, schema_editor, from_state, to_state):
- if (
- schema_editor.connection.vendor != 'postgresql' or
- not router.allow_migrate(schema_editor.connection.alias, app_label)
+ if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate(
+ schema_editor.connection.alias, app_label
):
return
self.remove_collation(schema_editor)
@@ -252,11 +252,11 @@ class RemoveCollation(CollationOperation):
self.create_collation(schema_editor)
def describe(self):
- return f'Remove collation {self.name}'
+ return f"Remove collation {self.name}"
@property
def migration_name_fragment(self):
- return 'remove_collation_%s' % self.name.lower()
+ return "remove_collation_%s" % self.name.lower()
class AddConstraintNotValid(AddConstraint):
@@ -268,12 +268,12 @@ class AddConstraintNotValid(AddConstraint):
def __init__(self, model_name, constraint):
if not isinstance(constraint, CheckConstraint):
raise TypeError(
- 'AddConstraintNotValid.constraint must be a check constraint.'
+ "AddConstraintNotValid.constraint must be a check constraint."
)
super().__init__(model_name, constraint)
def describe(self):
- return 'Create not valid constraint %s on model %s' % (
+ return "Create not valid constraint %s on model %s" % (
self.constraint.name,
self.model_name,
)
@@ -286,11 +286,11 @@ class AddConstraintNotValid(AddConstraint):
# Constraint.create_sql returns interpolated SQL which makes
# params=None a necessity to avoid escaping attempts on
# execution.
- schema_editor.execute(str(constraint_sql) + ' NOT VALID', params=None)
+ schema_editor.execute(str(constraint_sql) + " NOT VALID", params=None)
@property
def migration_name_fragment(self):
- return super().migration_name_fragment + '_not_valid'
+ return super().migration_name_fragment + "_not_valid"
class ValidateConstraint(Operation):
@@ -301,15 +301,18 @@ class ValidateConstraint(Operation):
self.name = name
def describe(self):
- return 'Validate constraint %s on model %s' % (self.name, self.model_name)
+ return "Validate constraint %s on model %s" % (self.name, self.model_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
- schema_editor.execute('ALTER TABLE %s VALIDATE CONSTRAINT %s' % (
- schema_editor.quote_name(model._meta.db_table),
- schema_editor.quote_name(self.name),
- ))
+ schema_editor.execute(
+ "ALTER TABLE %s VALIDATE CONSTRAINT %s"
+ % (
+ schema_editor.quote_name(model._meta.db_table),
+ schema_editor.quote_name(self.name),
+ )
+ )
def database_backwards(self, app_label, schema_editor, from_state, to_state):
# PostgreSQL does not provide a way to make a constraint invalid.
@@ -320,10 +323,14 @@ class ValidateConstraint(Operation):
@property
def migration_name_fragment(self):
- return '%s_validate_%s' % (self.model_name.lower(), self.name.lower())
+ return "%s_validate_%s" % (self.model_name.lower(), self.name.lower())
def deconstruct(self):
- return self.__class__.__name__, [], {
- 'model_name': self.model_name,
- 'name': self.name,
- }
+ return (
+ self.__class__.__name__,
+ [],
+ {
+ "model_name": self.model_name,
+ "name": self.name,
+ },
+ )
diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py
index 164d359b91..f652c1d346 100644
--- a/django/contrib/postgres/search.py
+++ b/django/contrib/postgres/search.py
@@ -1,18 +1,25 @@
import psycopg2
from django.db.models import (
- CharField, Expression, Field, FloatField, Func, Lookup, TextField, Value,
+ CharField,
+ Expression,
+ Field,
+ FloatField,
+ Func,
+ Lookup,
+ TextField,
+ Value,
)
from django.db.models.expressions import CombinedExpression
from django.db.models.functions import Cast, Coalesce
class SearchVectorExact(Lookup):
- lookup_name = 'exact'
+ lookup_name = "exact"
def process_rhs(self, qn, connection):
if not isinstance(self.rhs, (SearchQuery, CombinedSearchQuery)):
- config = getattr(self.lhs, 'config', None)
+ config = getattr(self.lhs, "config", None)
self.rhs = SearchQuery(self.rhs, config=config)
rhs, rhs_params = super().process_rhs(qn, connection)
return rhs, rhs_params
@@ -21,25 +28,23 @@ class SearchVectorExact(Lookup):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
- return '%s @@ %s' % (lhs, rhs), params
+ return "%s @@ %s" % (lhs, rhs), params
class SearchVectorField(Field):
-
def db_type(self, connection):
- return 'tsvector'
+ return "tsvector"
class SearchQueryField(Field):
-
def db_type(self, connection):
- return 'tsquery'
+ return "tsquery"
class SearchConfig(Expression):
def __init__(self, config):
super().__init__()
- if not hasattr(config, 'resolve_expression'):
+ if not hasattr(config, "resolve_expression"):
config = Value(config)
self.config = config
@@ -53,21 +58,21 @@ class SearchConfig(Expression):
return [self.config]
def set_source_expressions(self, exprs):
- self.config, = exprs
+ (self.config,) = exprs
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.config)
- return '%s::regconfig' % sql, params
+ return "%s::regconfig" % sql, params
class SearchVectorCombinable:
- ADD = '||'
+ ADD = "||"
def _combine(self, other, connector, reversed):
if not isinstance(other, SearchVectorCombinable):
raise TypeError(
- 'SearchVector can only be combined with other SearchVector '
- 'instances, got %s.' % type(other).__name__
+ "SearchVector can only be combined with other SearchVector "
+ "instances, got %s." % type(other).__name__
)
if reversed:
return CombinedSearchVector(other, connector, self, self.config)
@@ -75,49 +80,61 @@ class SearchVectorCombinable:
class SearchVector(SearchVectorCombinable, Func):
- function = 'to_tsvector'
+ function = "to_tsvector"
arg_joiner = " || ' ' || "
output_field = SearchVectorField()
def __init__(self, *expressions, config=None, weight=None):
super().__init__(*expressions)
self.config = SearchConfig.from_parameter(config)
- if weight is not None and not hasattr(weight, 'resolve_expression'):
+ if weight is not None and not hasattr(weight, "resolve_expression"):
weight = Value(weight)
self.weight = weight
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
- resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
+ resolved = super().resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
if self.config:
- resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ resolved.config = self.config.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
return resolved
def as_sql(self, compiler, connection, function=None, template=None):
clone = self.copy()
- clone.set_source_expressions([
- Coalesce(
- expression
- if isinstance(expression.output_field, (CharField, TextField))
- else Cast(expression, TextField()),
- Value('')
- ) for expression in clone.get_source_expressions()
- ])
+ clone.set_source_expressions(
+ [
+ Coalesce(
+ expression
+ if isinstance(expression.output_field, (CharField, TextField))
+ else Cast(expression, TextField()),
+ Value(""),
+ )
+ for expression in clone.get_source_expressions()
+ ]
+ )
config_sql = None
config_params = []
if template is None:
if clone.config:
config_sql, config_params = compiler.compile(clone.config)
- template = '%(function)s(%(config)s, %(expressions)s)'
+ template = "%(function)s(%(config)s, %(expressions)s)"
else:
template = clone.template
sql, params = super(SearchVector, clone).as_sql(
- compiler, connection, function=function, template=template,
+ compiler,
+ connection,
+ function=function,
+ template=template,
config=config_sql,
)
extra_params = []
if clone.weight:
weight_sql, extra_params = compiler.compile(clone.weight)
- sql = 'setweight({}, {})'.format(sql, weight_sql)
+ sql = "setweight({}, {})".format(sql, weight_sql)
return sql, config_params + params + extra_params
@@ -128,14 +145,14 @@ class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
class SearchQueryCombinable:
- BITAND = '&&'
- BITOR = '||'
+ BITAND = "&&"
+ BITOR = "||"
def _combine(self, other, connector, reversed):
if not isinstance(other, SearchQueryCombinable):
raise TypeError(
- 'SearchQuery can only be combined with other SearchQuery '
- 'instances, got %s.' % type(other).__name__
+ "SearchQuery can only be combined with other SearchQuery "
+ "instances, got %s." % type(other).__name__
)
if reversed:
return CombinedSearchQuery(other, connector, self, self.config)
@@ -160,17 +177,25 @@ class SearchQueryCombinable:
class SearchQuery(SearchQueryCombinable, Func):
output_field = SearchQueryField()
SEARCH_TYPES = {
- 'plain': 'plainto_tsquery',
- 'phrase': 'phraseto_tsquery',
- 'raw': 'to_tsquery',
- 'websearch': 'websearch_to_tsquery',
+ "plain": "plainto_tsquery",
+ "phrase": "phraseto_tsquery",
+ "raw": "to_tsquery",
+ "websearch": "websearch_to_tsquery",
}
- def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
+ def __init__(
+ self,
+ value,
+ output_field=None,
+ *,
+ config=None,
+ invert=False,
+ search_type="plain",
+ ):
self.function = self.SEARCH_TYPES.get(search_type)
if self.function is None:
raise ValueError("Unknown search_type argument '%s'." % search_type)
- if not hasattr(value, 'resolve_expression'):
+ if not hasattr(value, "resolve_expression"):
value = Value(value)
expressions = (value,)
self.config = SearchConfig.from_parameter(config)
@@ -182,7 +207,7 @@ class SearchQuery(SearchQueryCombinable, Func):
def as_sql(self, compiler, connection, function=None, template=None):
sql, params = super().as_sql(compiler, connection, function, template)
if self.invert:
- sql = '!!(%s)' % sql
+ sql = "!!(%s)" % sql
return sql, params
def __invert__(self):
@@ -192,7 +217,7 @@ class SearchQuery(SearchQueryCombinable, Func):
def __str__(self):
result = super().__str__()
- return ('~%s' % result) if self.invert else result
+ return ("~%s" % result) if self.invert else result
class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
@@ -201,60 +226,73 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
super().__init__(lhs, connector, rhs, output_field)
def __str__(self):
- return '(%s)' % super().__str__()
+ return "(%s)" % super().__str__()
class SearchRank(Func):
- function = 'ts_rank'
+ function = "ts_rank"
output_field = FloatField()
def __init__(
- self, vector, query, weights=None, normalization=None,
+ self,
+ vector,
+ query,
+ weights=None,
+ normalization=None,
cover_density=False,
):
- if not hasattr(vector, 'resolve_expression'):
+ if not hasattr(vector, "resolve_expression"):
vector = SearchVector(vector)
- if not hasattr(query, 'resolve_expression'):
+ if not hasattr(query, "resolve_expression"):
query = SearchQuery(query)
expressions = (vector, query)
if weights is not None:
- if not hasattr(weights, 'resolve_expression'):
+ if not hasattr(weights, "resolve_expression"):
weights = Value(weights)
expressions = (weights,) + expressions
if normalization is not None:
- if not hasattr(normalization, 'resolve_expression'):
+ if not hasattr(normalization, "resolve_expression"):
normalization = Value(normalization)
expressions += (normalization,)
if cover_density:
- self.function = 'ts_rank_cd'
+ self.function = "ts_rank_cd"
super().__init__(*expressions)
class SearchHeadline(Func):
- function = 'ts_headline'
- template = '%(function)s(%(expressions)s%(options)s)'
+ function = "ts_headline"
+ template = "%(function)s(%(expressions)s%(options)s)"
output_field = TextField()
def __init__(
- self, expression, query, *, config=None, start_sel=None, stop_sel=None,
- max_words=None, min_words=None, short_word=None, highlight_all=None,
- max_fragments=None, fragment_delimiter=None,
+ self,
+ expression,
+ query,
+ *,
+ config=None,
+ start_sel=None,
+ stop_sel=None,
+ max_words=None,
+ min_words=None,
+ short_word=None,
+ highlight_all=None,
+ max_fragments=None,
+ fragment_delimiter=None,
):
- if not hasattr(query, 'resolve_expression'):
+ if not hasattr(query, "resolve_expression"):
query = SearchQuery(query)
options = {
- 'StartSel': start_sel,
- 'StopSel': stop_sel,
- 'MaxWords': max_words,
- 'MinWords': min_words,
- 'ShortWord': short_word,
- 'HighlightAll': highlight_all,
- 'MaxFragments': max_fragments,
- 'FragmentDelimiter': fragment_delimiter,
+ "StartSel": start_sel,
+ "StopSel": stop_sel,
+ "MaxWords": max_words,
+ "MinWords": min_words,
+ "ShortWord": short_word,
+ "HighlightAll": highlight_all,
+ "MaxFragments": max_fragments,
+ "FragmentDelimiter": fragment_delimiter,
}
self.options = {
- option: value
- for option, value in options.items() if value is not None
+ option: value for option, value in options.items() if value is not None
}
expressions = (expression, query)
if config is not None:
@@ -263,19 +301,26 @@ class SearchHeadline(Func):
super().__init__(*expressions)
def as_sql(self, compiler, connection, function=None, template=None):
- options_sql = ''
+ options_sql = ""
options_params = []
if self.options:
# getquoted() returns a quoted bytestring of the adapted value.
- options_params.append(', '.join(
- '%s=%s' % (
- option,
- psycopg2.extensions.adapt(value).getquoted().decode(),
- ) for option, value in self.options.items()
- ))
- options_sql = ', %s'
+ options_params.append(
+ ", ".join(
+ "%s=%s"
+ % (
+ option,
+ psycopg2.extensions.adapt(value).getquoted().decode(),
+ )
+ for option, value in self.options.items()
+ )
+ )
+ options_sql = ", %s"
sql, params = super().as_sql(
- compiler, connection, function=function, template=template,
+ compiler,
+ connection,
+ function=function,
+ template=template,
options=options_sql,
)
return sql, params + options_params
@@ -288,7 +333,7 @@ class TrigramBase(Func):
output_field = FloatField()
def __init__(self, expression, string, **extra):
- if not hasattr(string, 'resolve_expression'):
+ if not hasattr(string, "resolve_expression"):
string = Value(string)
super().__init__(expression, string, **extra)
@@ -297,24 +342,24 @@ class TrigramWordBase(Func):
output_field = FloatField()
def __init__(self, string, expression, **extra):
- if not hasattr(string, 'resolve_expression'):
+ if not hasattr(string, "resolve_expression"):
string = Value(string)
super().__init__(string, expression, **extra)
class TrigramSimilarity(TrigramBase):
- function = 'SIMILARITY'
+ function = "SIMILARITY"
class TrigramDistance(TrigramBase):
- function = ''
- arg_joiner = ' <-> '
+ function = ""
+ arg_joiner = " <-> "
class TrigramWordDistance(TrigramWordBase):
- function = ''
- arg_joiner = ' <<-> '
+ function = ""
+ arg_joiner = " <<-> "
class TrigramWordSimilarity(TrigramWordBase):
- function = 'WORD_SIMILARITY'
+ function = "WORD_SIMILARITY"
diff --git a/django/contrib/postgres/serializers.py b/django/contrib/postgres/serializers.py
index 1b1c2f1112..d04bfdbc69 100644
--- a/django/contrib/postgres/serializers.py
+++ b/django/contrib/postgres/serializers.py
@@ -6,5 +6,5 @@ class RangeSerializer(BaseSerializer):
module = self.value.__class__.__module__
# Ranges are implemented in psycopg2._range but the public import
# location is psycopg2.extras.
- module = 'psycopg2.extras' if module == 'psycopg2._range' else module
- return '%s.%r' % (module, self.value), {'import %s' % module}
+ module = "psycopg2.extras" if module == "psycopg2._range" else module
+ return "%s.%r" % (module, self.value), {"import %s" % module}
diff --git a/django/contrib/postgres/signals.py b/django/contrib/postgres/signals.py
index 420b5f6033..b61673fe1f 100644
--- a/django/contrib/postgres/signals.py
+++ b/django/contrib/postgres/signals.py
@@ -35,12 +35,14 @@ def get_citext_oids(connection_alias):
def register_type_handlers(connection, **kwargs):
- if connection.vendor != 'postgresql' or connection.alias == NO_DB_ALIAS:
+ if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return
try:
oids, array_oids = get_hstore_oids(connection.alias)
- register_hstore(connection.connection, globally=True, oid=oids, array_oid=array_oids)
+ register_hstore(
+ connection.connection, globally=True, oid=oids, array_oid=array_oids
+ )
except ProgrammingError:
# Hstore is not available on the database.
#
@@ -54,7 +56,9 @@ def register_type_handlers(connection, **kwargs):
try:
citext_oids = get_citext_oids(connection.alias)
- array_type = psycopg2.extensions.new_array_type(citext_oids, 'citext[]', psycopg2.STRING)
+ array_type = psycopg2.extensions.new_array_type(
+ citext_oids, "citext[]", psycopg2.STRING
+ )
psycopg2.extensions.register_type(array_type, None)
except ProgrammingError:
# citext is not available on the database.
diff --git a/django/contrib/postgres/utils.py b/django/contrib/postgres/utils.py
index f3c022f474..e4f4d81514 100644
--- a/django/contrib/postgres/utils.py
+++ b/django/contrib/postgres/utils.py
@@ -17,13 +17,13 @@ def prefix_validation_error(error, prefix, code, params):
# ngettext calls require a count parameter and are converted
# to an empty string if they are missing it.
message=format_lazy(
- '{} {}',
+ "{} {}",
SimpleLazyObject(lambda: prefix % params),
SimpleLazyObject(lambda: error.message % error_params),
),
code=code,
params={**error_params, **params},
)
- return ValidationError([
- prefix_validation_error(e, prefix, code, params) for e in error.error_list
- ])
+ return ValidationError(
+ [prefix_validation_error(e, prefix, code, params) for e in error.error_list]
+ )
diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py
index db6205f356..df2bd88eb9 100644
--- a/django/contrib/postgres/validators.py
+++ b/django/contrib/postgres/validators.py
@@ -1,24 +1,29 @@
from django.core.exceptions import ValidationError
from django.core.validators import (
- MaxLengthValidator, MaxValueValidator, MinLengthValidator,
+ MaxLengthValidator,
+ MaxValueValidator,
+ MinLengthValidator,
MinValueValidator,
)
from django.utils.deconstruct import deconstructible
-from django.utils.translation import gettext_lazy as _, ngettext_lazy
+from django.utils.translation import gettext_lazy as _
+from django.utils.translation import ngettext_lazy
class ArrayMaxLengthValidator(MaxLengthValidator):
message = ngettext_lazy(
- 'List contains %(show_value)d item, it should contain no more than %(limit_value)d.',
- 'List contains %(show_value)d items, it should contain no more than %(limit_value)d.',
- 'limit_value')
+ "List contains %(show_value)d item, it should contain no more than %(limit_value)d.",
+ "List contains %(show_value)d items, it should contain no more than %(limit_value)d.",
+ "limit_value",
+ )
class ArrayMinLengthValidator(MinLengthValidator):
message = ngettext_lazy(
- 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.',
- 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.',
- 'limit_value')
+ "List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.",
+ "List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.",
+ "limit_value",
+ )
@deconstructible
@@ -26,8 +31,8 @@ class KeysValidator:
"""A validator designed for HStore to require/restrict keys."""
messages = {
- 'missing_keys': _('Some keys were missing: %(keys)s'),
- 'extra_keys': _('Some unknown keys were provided: %(keys)s'),
+ "missing_keys": _("Some keys were missing: %(keys)s"),
+ "extra_keys": _("Some unknown keys were provided: %(keys)s"),
}
strict = False
@@ -42,35 +47,41 @@ class KeysValidator:
missing_keys = self.keys - keys
if missing_keys:
raise ValidationError(
- self.messages['missing_keys'],
- code='missing_keys',
- params={'keys': ', '.join(missing_keys)},
+ self.messages["missing_keys"],
+ code="missing_keys",
+ params={"keys": ", ".join(missing_keys)},
)
if self.strict:
extra_keys = keys - self.keys
if extra_keys:
raise ValidationError(
- self.messages['extra_keys'],
- code='extra_keys',
- params={'keys': ', '.join(extra_keys)},
+ self.messages["extra_keys"],
+ code="extra_keys",
+ params={"keys": ", ".join(extra_keys)},
)
def __eq__(self, other):
return (
- isinstance(other, self.__class__) and
- self.keys == other.keys and
- self.messages == other.messages and
- self.strict == other.strict
+ isinstance(other, self.__class__)
+ and self.keys == other.keys
+ and self.messages == other.messages
+ and self.strict == other.strict
)
class RangeMaxValueValidator(MaxValueValidator):
def compare(self, a, b):
return a.upper is None or a.upper > b
- message = _('Ensure that this range is completely less than or equal to %(limit_value)s.')
+
+ message = _(
+ "Ensure that this range is completely less than or equal to %(limit_value)s."
+ )
class RangeMinValueValidator(MinValueValidator):
def compare(self, a, b):
return a.lower is None or a.lower < b
- message = _('Ensure that this range is completely greater than or equal to %(limit_value)s.')
+
+ message = _(
+ "Ensure that this range is completely greater than or equal to %(limit_value)s."
+ )