diff options
Diffstat (limited to 'django/contrib/postgres')
23 files changed, 757 insertions, 604 deletions
diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index d90ca50e2b..f8b40fb709 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -1,16 +1,20 @@ import warnings from django.contrib.postgres.fields import ArrayField -from django.db.models import ( - Aggregate, BooleanField, JSONField, TextField, Value, -) +from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value from django.utils.deprecation import RemovedInDjango50Warning from .mixins import OrderableAggMixin __all__ = [ - 'ArrayAgg', 'BitAnd', 'BitOr', 'BitXor', 'BoolAnd', 'BoolOr', 'JSONBAgg', - 'StringAgg', + "ArrayAgg", + "BitAnd", + "BitOr", + "BitXor", + "BoolAnd", + "BoolOr", + "JSONBAgg", + "StringAgg", ] @@ -35,17 +39,17 @@ class DeprecatedConvertValueMixin: class ArrayAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): - function = 'ARRAY_AGG' - template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + function = "ARRAY_AGG" + template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)" allow_distinct = True # RemovedInDjango50Warning deprecation_value = property(lambda self: []) deprecation_msg = ( - 'In Django 5.0, ArrayAgg() will return None instead of an empty list ' - 'if there are no rows. Pass default=None to opt into the new behavior ' - 'and silence this warning or default=Value([]) to keep the previous ' - 'behavior.' + "In Django 5.0, ArrayAgg() will return None instead of an empty list " + "if there are no rows. Pass default=None to opt into the new behavior " + "and silence this warning or default=Value([]) to keep the previous " + "behavior." ) @property @@ -54,35 +58,35 @@ class ArrayAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): class BitAnd(Aggregate): - function = 'BIT_AND' + function = "BIT_AND" class BitOr(Aggregate): - function = 'BIT_OR' + function = "BIT_OR" class BitXor(Aggregate): - function = 'BIT_XOR' + function = "BIT_XOR" class BoolAnd(Aggregate): - function = 'BOOL_AND' + function = "BOOL_AND" output_field = BooleanField() class BoolOr(Aggregate): - function = 'BOOL_OR' + function = "BOOL_OR" output_field = BooleanField() class JSONBAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): - function = 'JSONB_AGG' - template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + function = "JSONB_AGG" + template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)" allow_distinct = True output_field = JSONField() # RemovedInDjango50Warning - deprecation_value = '[]' + deprecation_value = "[]" deprecation_msg = ( "In Django 5.0, JSONBAgg() will return None instead of an empty list " "if there are no rows. Pass default=None to opt into the new behavior " @@ -92,13 +96,13 @@ class JSONBAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): class StringAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): - function = 'STRING_AGG' - template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + function = "STRING_AGG" + template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)" allow_distinct = True output_field = TextField() # RemovedInDjango50Warning - deprecation_value = '' + deprecation_value = "" deprecation_msg = ( "In Django 5.0, StringAgg() will return None instead of an empty " "string if there are no rows. Pass default=None to opt into the new " diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py index 4fedb9bd98..b2f4097b8f 100644 --- a/django/contrib/postgres/aggregates/mixins.py +++ b/django/contrib/postgres/aggregates/mixins.py @@ -2,7 +2,6 @@ from django.db.models.expressions import OrderByList class OrderableAggMixin: - def __init__(self, *expressions, ordering=(), **extra): if isinstance(ordering, (list, tuple)): self.order_by = OrderByList(*ordering) diff --git a/django/contrib/postgres/aggregates/statistics.py b/django/contrib/postgres/aggregates/statistics.py index 2c83b78c0e..3dc442b290 100644 --- a/django/contrib/postgres/aggregates/statistics.py +++ b/django/contrib/postgres/aggregates/statistics.py @@ -1,8 +1,18 @@ from django.db.models import Aggregate, FloatField, IntegerField __all__ = [ - 'CovarPop', 'Corr', 'RegrAvgX', 'RegrAvgY', 'RegrCount', 'RegrIntercept', - 'RegrR2', 'RegrSlope', 'RegrSXX', 'RegrSXY', 'RegrSYY', 'StatAggregate', + "CovarPop", + "Corr", + "RegrAvgX", + "RegrAvgY", + "RegrCount", + "RegrIntercept", + "RegrR2", + "RegrSlope", + "RegrSXX", + "RegrSXY", + "RegrSYY", + "StatAggregate", ] @@ -11,53 +21,55 @@ class StatAggregate(Aggregate): def __init__(self, y, x, output_field=None, filter=None, default=None): if not x or not y: - raise ValueError('Both y and x must be provided.') - super().__init__(y, x, output_field=output_field, filter=filter, default=default) + raise ValueError("Both y and x must be provided.") + super().__init__( + y, x, output_field=output_field, filter=filter, default=default + ) class Corr(StatAggregate): - function = 'CORR' + function = "CORR" class CovarPop(StatAggregate): def __init__(self, y, x, sample=False, filter=None, default=None): - self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' + self.function = "COVAR_SAMP" if sample else "COVAR_POP" super().__init__(y, x, filter=filter, default=default) class RegrAvgX(StatAggregate): - function = 'REGR_AVGX' + function = "REGR_AVGX" class RegrAvgY(StatAggregate): - function = 'REGR_AVGY' + function = "REGR_AVGY" class RegrCount(StatAggregate): - function = 'REGR_COUNT' + function = "REGR_COUNT" output_field = IntegerField() empty_result_set_value = 0 class RegrIntercept(StatAggregate): - function = 'REGR_INTERCEPT' + function = "REGR_INTERCEPT" class RegrR2(StatAggregate): - function = 'REGR_R2' + function = "REGR_R2" class RegrSlope(StatAggregate): - function = 'REGR_SLOPE' + function = "REGR_SLOPE" class RegrSXX(StatAggregate): - function = 'REGR_SXX' + function = "REGR_SXX" class RegrSXY(StatAggregate): - function = 'REGR_SXY' + function = "REGR_SXY" class RegrSYY(StatAggregate): - function = 'REGR_SYY' + function = "REGR_SYY" diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py index b8ec85b7a4..d917201f05 100644 --- a/django/contrib/postgres/apps.py +++ b/django/contrib/postgres/apps.py @@ -1,6 +1,4 @@ -from psycopg2.extras import ( - DateRange, DateTimeRange, DateTimeTZRange, NumericRange, -) +from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, NumericRange from django.apps import AppConfig from django.core.signals import setting_changed @@ -25,7 +23,11 @@ def uninstall_if_needed(setting, value, enter, **kwargs): Undo the effects of PostgresConfig.ready() when django.contrib.postgres is "uninstalled" by override_settings(). """ - if not enter and setting == 'INSTALLED_APPS' and 'django.contrib.postgres' not in set(value): + if ( + not enter + and setting == "INSTALLED_APPS" + and "django.contrib.postgres" not in set(value) + ): connection_created.disconnect(register_type_handlers) CharField._unregister_lookup(Unaccent) TextField._unregister_lookup(Unaccent) @@ -43,21 +45,23 @@ def uninstall_if_needed(setting, value, enter, **kwargs): class PostgresConfig(AppConfig): - name = 'django.contrib.postgres' - verbose_name = _('PostgreSQL extensions') + name = "django.contrib.postgres" + verbose_name = _("PostgreSQL extensions") def ready(self): setting_changed.connect(uninstall_if_needed) # Connections may already exist before we are called. for conn in connections.all(): - if conn.vendor == 'postgresql': - conn.introspection.data_types_reverse.update({ - 3904: 'django.contrib.postgres.fields.IntegerRangeField', - 3906: 'django.contrib.postgres.fields.DecimalRangeField', - 3910: 'django.contrib.postgres.fields.DateTimeRangeField', - 3912: 'django.contrib.postgres.fields.DateRangeField', - 3926: 'django.contrib.postgres.fields.BigIntegerRangeField', - }) + if conn.vendor == "postgresql": + conn.introspection.data_types_reverse.update( + { + 3904: "django.contrib.postgres.fields.IntegerRangeField", + 3906: "django.contrib.postgres.fields.DecimalRangeField", + 3910: "django.contrib.postgres.fields.DateTimeRangeField", + 3912: "django.contrib.postgres.fields.DateRangeField", + 3926: "django.contrib.postgres.fields.BigIntegerRangeField", + } + ) if conn.connection is not None: register_type_handlers(conn) connection_created.connect(register_type_handlers) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index 06ccd3616e..c0bc0c444f 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -10,71 +10,69 @@ from django.db.models.indexes import IndexExpression from django.db.models.sql import Query from django.utils.deprecation import RemovedInDjango50Warning -__all__ = ['ExclusionConstraint'] +__all__ = ["ExclusionConstraint"] class ExclusionConstraintExpression(IndexExpression): - template = '%(expressions)s WITH %(operator)s' + template = "%(expressions)s WITH %(operator)s" class ExclusionConstraint(BaseConstraint): - template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s' + template = "CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s" def __init__( - self, *, name, expressions, index_type=None, condition=None, - deferrable=None, include=None, opclasses=(), + self, + *, + name, + expressions, + index_type=None, + condition=None, + deferrable=None, + include=None, + opclasses=(), ): - if index_type and index_type.lower() not in {'gist', 'spgist'}: + if index_type and index_type.lower() not in {"gist", "spgist"}: raise ValueError( - 'Exclusion constraints only support GiST or SP-GiST indexes.' + "Exclusion constraints only support GiST or SP-GiST indexes." ) if not expressions: raise ValueError( - 'At least one expression is required to define an exclusion ' - 'constraint.' + "At least one expression is required to define an exclusion " + "constraint." ) if not all( - isinstance(expr, (list, tuple)) and len(expr) == 2 - for expr in expressions + isinstance(expr, (list, tuple)) and len(expr) == 2 for expr in expressions ): - raise ValueError('The expressions must be a list of 2-tuples.') + raise ValueError("The expressions must be a list of 2-tuples.") if not isinstance(condition, (type(None), Q)): - raise ValueError( - 'ExclusionConstraint.condition must be a Q instance.' - ) + raise ValueError("ExclusionConstraint.condition must be a Q instance.") if condition and deferrable: - raise ValueError( - 'ExclusionConstraint with conditions cannot be deferred.' - ) + raise ValueError("ExclusionConstraint with conditions cannot be deferred.") if not isinstance(deferrable, (type(None), Deferrable)): raise ValueError( - 'ExclusionConstraint.deferrable must be a Deferrable instance.' + "ExclusionConstraint.deferrable must be a Deferrable instance." ) if not isinstance(include, (type(None), list, tuple)): - raise ValueError( - 'ExclusionConstraint.include must be a list or tuple.' - ) + raise ValueError("ExclusionConstraint.include must be a list or tuple.") if not isinstance(opclasses, (list, tuple)): - raise ValueError( - 'ExclusionConstraint.opclasses must be a list or tuple.' - ) + raise ValueError("ExclusionConstraint.opclasses must be a list or tuple.") if opclasses and len(expressions) != len(opclasses): raise ValueError( - 'ExclusionConstraint.expressions and ' - 'ExclusionConstraint.opclasses must have the same number of ' - 'elements.' + "ExclusionConstraint.expressions and " + "ExclusionConstraint.opclasses must have the same number of " + "elements." ) self.expressions = expressions - self.index_type = index_type or 'GIST' + self.index_type = index_type or "GIST" self.condition = condition self.deferrable = deferrable self.include = tuple(include) if include else () self.opclasses = opclasses if self.opclasses: warnings.warn( - 'The opclasses argument is deprecated in favor of using ' - 'django.contrib.postgres.indexes.OpClass in ' - 'ExclusionConstraint.expressions.', + "The opclasses argument is deprecated in favor of using " + "django.contrib.postgres.indexes.OpClass in " + "ExclusionConstraint.expressions.", category=RemovedInDjango50Warning, stacklevel=2, ) @@ -107,14 +105,18 @@ class ExclusionConstraint(BaseConstraint): expressions = self._get_expressions(schema_editor, query) table = model._meta.db_table condition = self._get_condition_sql(compiler, schema_editor, query) - include = [model._meta.get_field(field_name).column for field_name in self.include] + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] return Statement( self.template, table=Table(table, schema_editor.quote_name), name=schema_editor.quote_name(self.name), index_type=self.index_type, - expressions=Expressions(table, expressions, compiler, schema_editor.quote_value), - where=' WHERE (%s)' % condition if condition else '', + expressions=Expressions( + table, expressions, compiler, schema_editor.quote_value + ), + where=" WHERE (%s)" % condition if condition else "", include=schema_editor._index_include_sql(model, include), deferrable=schema_editor._deferrable_constraint_sql(self.deferrable), ) @@ -122,7 +124,7 @@ class ExclusionConstraint(BaseConstraint): def create_sql(self, model, schema_editor): self.check_supported(schema_editor) return Statement( - 'ALTER TABLE %(table)s ADD %(constraint)s', + "ALTER TABLE %(table)s ADD %(constraint)s", table=Table(model._meta.db_table, schema_editor.quote_name), constraint=self.constraint_sql(model, schema_editor), ) @@ -136,60 +138,60 @@ class ExclusionConstraint(BaseConstraint): def check_supported(self, schema_editor): if ( - self.include and - self.index_type.lower() == 'gist' and - not schema_editor.connection.features.supports_covering_gist_indexes + self.include + and self.index_type.lower() == "gist" + and not schema_editor.connection.features.supports_covering_gist_indexes ): raise NotSupportedError( - 'Covering exclusion constraints using a GiST index require ' - 'PostgreSQL 12+.' + "Covering exclusion constraints using a GiST index require " + "PostgreSQL 12+." ) if ( - self.include and - self.index_type.lower() == 'spgist' and - not schema_editor.connection.features.supports_covering_spgist_indexes + self.include + and self.index_type.lower() == "spgist" + and not schema_editor.connection.features.supports_covering_spgist_indexes ): raise NotSupportedError( - 'Covering exclusion constraints using an SP-GiST index ' - 'require PostgreSQL 14+.' + "Covering exclusion constraints using an SP-GiST index " + "require PostgreSQL 14+." ) def deconstruct(self): path, args, kwargs = super().deconstruct() - kwargs['expressions'] = self.expressions + kwargs["expressions"] = self.expressions if self.condition is not None: - kwargs['condition'] = self.condition - if self.index_type.lower() != 'gist': - kwargs['index_type'] = self.index_type + kwargs["condition"] = self.condition + if self.index_type.lower() != "gist": + kwargs["index_type"] = self.index_type if self.deferrable: - kwargs['deferrable'] = self.deferrable + kwargs["deferrable"] = self.deferrable if self.include: - kwargs['include'] = self.include + kwargs["include"] = self.include if self.opclasses: - kwargs['opclasses'] = self.opclasses + kwargs["opclasses"] = self.opclasses return path, args, kwargs def __eq__(self, other): if isinstance(other, self.__class__): return ( - self.name == other.name and - self.index_type == other.index_type and - self.expressions == other.expressions and - self.condition == other.condition and - self.deferrable == other.deferrable and - self.include == other.include and - self.opclasses == other.opclasses + self.name == other.name + and self.index_type == other.index_type + and self.expressions == other.expressions + and self.condition == other.condition + and self.deferrable == other.deferrable + and self.include == other.include + and self.opclasses == other.opclasses ) return super().__eq__(other) def __repr__(self): - return '<%s: index_type=%s expressions=%s name=%s%s%s%s%s>' % ( + return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % ( self.__class__.__qualname__, repr(self.index_type), repr(self.expressions), repr(self.name), - '' if self.condition is None else ' condition=%s' % self.condition, - '' if self.deferrable is None else ' deferrable=%r' % self.deferrable, - '' if not self.include else ' include=%s' % repr(self.include), - '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses), + "" if self.condition is None else " condition=%s" % self.condition, + "" if self.deferrable is None else " deferrable=%r" % self.deferrable, + "" if not self.include else " include=%s" % repr(self.include), + "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), ) diff --git a/django/contrib/postgres/expressions.py b/django/contrib/postgres/expressions.py index ea7cbe038d..469f4e9fb6 100644 --- a/django/contrib/postgres/expressions.py +++ b/django/contrib/postgres/expressions.py @@ -4,7 +4,7 @@ from django.utils.functional import cached_property class ArraySubquery(Subquery): - template = 'ARRAY(%(subquery)s)' + template = "ARRAY(%(subquery)s)" def __init__(self, queryset, **kwargs): super().__init__(queryset, **kwargs) diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 9c1bb96b61..7269198674 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -12,38 +12,43 @@ from django.utils.translation import gettext_lazy as _ from ..utils import prefix_validation_error from .utils import AttributeSetter -__all__ = ['ArrayField'] +__all__ = ["ArrayField"] class ArrayField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), - 'nested_array_mismatch': _('Nested arrays must have the same length.'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), + "nested_array_mismatch": _("Nested arrays must have the same length."), } - _default_hint = ('list', '[]') + _default_hint = ("list", "[]") def __init__(self, base_field, size=None, **kwargs): self.base_field = base_field self.size = size if self.size: - self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)] + self.default_validators = [ + *self.default_validators, + ArrayMaxLengthValidator(self.size), + ] # For performance, only add a from_db_value() method if the base field # implements it. - if hasattr(self.base_field, 'from_db_value'): + if hasattr(self.base_field, "from_db_value"): self.from_db_value = self._from_db_value super().__init__(**kwargs) @property def model(self): try: - return self.__dict__['model'] + return self.__dict__["model"] except KeyError: - raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'model'" % self.__class__.__name__ + ) @model.setter def model(self, model): - self.__dict__['model'] = model + self.__dict__["model"] = model self.base_field.model = model @classmethod @@ -55,21 +60,23 @@ class ArrayField(CheckFieldDefaultMixin, Field): if self.base_field.remote_field: errors.append( checks.Error( - 'Base field for array cannot be a related field.', + "Base field for array cannot be a related field.", obj=self, - id='postgres.E002' + id="postgres.E002", ) ) else: # Remove the field name checks as they are not needed here. base_errors = self.base_field.check() if base_errors: - messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors) + messages = "\n ".join( + "%s (%s)" % (error.msg, error.id) for error in base_errors + ) errors.append( checks.Error( - 'Base field for array has errors:\n %s' % messages, + "Base field for array has errors:\n %s" % messages, obj=self, - id='postgres.E001' + id="postgres.E001", ) ) return errors @@ -80,32 +87,37 @@ class ArrayField(CheckFieldDefaultMixin, Field): @property def description(self): - return 'Array of %s' % self.base_field.description + return "Array of %s" % self.base_field.description def db_type(self, connection): - size = self.size or '' - return '%s[%s]' % (self.base_field.db_type(connection), size) + size = self.size or "" + return "%s[%s]" % (self.base_field.db_type(connection), size) def cast_db_type(self, connection): - size = self.size or '' - return '%s[%s]' % (self.base_field.cast_db_type(connection), size) + size = self.size or "" + return "%s[%s]" % (self.base_field.cast_db_type(connection), size) def get_placeholder(self, value, compiler, connection): - return '%s::{}'.format(self.db_type(connection)) + return "%s::{}".format(self.db_type(connection)) def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, (list, tuple)): - return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value] + return [ + self.base_field.get_db_prep_value(i, connection, prepared=False) + for i in value + ] return value def deconstruct(self): name, path, args, kwargs = super().deconstruct() - if path == 'django.contrib.postgres.fields.array.ArrayField': - path = 'django.contrib.postgres.fields.ArrayField' - kwargs.update({ - 'base_field': self.base_field.clone(), - 'size': self.size, - }) + if path == "django.contrib.postgres.fields.array.ArrayField": + path = "django.contrib.postgres.fields.ArrayField" + kwargs.update( + { + "base_field": self.base_field.clone(), + "size": self.size, + } + ) return name, path, args, kwargs def to_python(self, value): @@ -140,7 +152,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): transform = super().get_transform(name) if transform: return transform - if '_' not in name: + if "_" not in name: try: index = int(name) except ValueError: @@ -149,7 +161,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): index += 1 # postgres uses 1-indexing return IndexTransformFactory(index, self.base_field) try: - start, end = name.split('_') + start, end = name.split("_") start = int(start) + 1 end = int(end) # don't add one here because postgres slices are weird except ValueError: @@ -165,15 +177,15 @@ class ArrayField(CheckFieldDefaultMixin, Field): except exceptions.ValidationError as error: raise prefix_validation_error( error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, ) if isinstance(self.base_field, ArrayField): if len({len(i) for i in value}) > 1: raise exceptions.ValidationError( - self.error_messages['nested_array_mismatch'], - code='nested_array_mismatch', + self.error_messages["nested_array_mismatch"], + code="nested_array_mismatch", ) def run_validators(self, value): @@ -184,18 +196,20 @@ class ArrayField(CheckFieldDefaultMixin, Field): except exceptions.ValidationError as error: raise prefix_validation_error( error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, ) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': SimpleArrayField, - 'base_field': self.base_field.formfield(), - 'max_length': self.size, - **kwargs, - }) + return super().formfield( + **{ + "form_class": SimpleArrayField, + "base_field": self.base_field.formfield(), + "max_length": self.size, + **kwargs, + } + ) class ArrayRHSMixin: @@ -203,21 +217,21 @@ class ArrayRHSMixin: if isinstance(rhs, (tuple, list)): expressions = [] for value in rhs: - if not hasattr(value, 'resolve_expression'): + if not hasattr(value, "resolve_expression"): field = lhs.output_field value = Value(field.base_field.get_prep_value(value)) expressions.append(value) rhs = Func( *expressions, - function='ARRAY', - template='%(function)s[%(expressions)s]', + function="ARRAY", + template="%(function)s[%(expressions)s]", ) super().__init__(lhs, rhs) def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) cast_type = self.lhs.output_field.cast_db_type(connection) - return '%s::%s' % (rhs, cast_type), rhs_params + return "%s::%s" % (rhs, cast_type), rhs_params @ArrayField.register_lookup @@ -242,29 +256,29 @@ class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): @ArrayField.register_lookup class ArrayLenTransform(Transform): - lookup_name = 'len' + lookup_name = "len" output_field = IntegerField() def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) # Distinguish NULL and empty arrays return ( - 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE ' - 'coalesce(array_length(%(lhs)s, 1), 0) END' - ) % {'lhs': lhs}, params + "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " + "coalesce(array_length(%(lhs)s, 1), 0) END" + ) % {"lhs": lhs}, params @ArrayField.register_lookup class ArrayInLookup(In): def get_prep_lookup(self): values = super().get_prep_lookup() - if hasattr(values, 'resolve_expression'): + if hasattr(values, "resolve_expression"): return values # In.process_rhs() expects values to be hashable, so convert lists # to tuples. prepared_values = [] for value in values: - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): prepared_values.append(value) else: prepared_values.append(tuple(value)) @@ -272,7 +286,6 @@ class ArrayInLookup(In): class IndexTransform(Transform): - def __init__(self, index, base_field, *args, **kwargs): super().__init__(*args, **kwargs) self.index = index @@ -280,7 +293,7 @@ class IndexTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '%s[%%s]' % lhs, params + [self.index] + return "%s[%%s]" % lhs, params + [self.index] @property def output_field(self): @@ -288,7 +301,6 @@ class IndexTransform(Transform): class IndexTransformFactory: - def __init__(self, index, base_field): self.index = index self.base_field = base_field @@ -298,7 +310,6 @@ class IndexTransformFactory: class SliceTransform(Transform): - def __init__(self, start, end, *args, **kwargs): super().__init__(*args, **kwargs) self.start = start @@ -306,11 +317,10 @@ class SliceTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '%s[%%s:%%s]' % lhs, params + [self.start, self.end] + return "%s[%%s:%%s]" % lhs, params + [self.start, self.end] class SliceTransformFactory: - def __init__(self, start, end): self.start = start self.end = end diff --git a/django/contrib/postgres/fields/citext.py b/django/contrib/postgres/fields/citext.py index 46f6d3d1c2..2b943614d2 100644 --- a/django/contrib/postgres/fields/citext.py +++ b/django/contrib/postgres/fields/citext.py @@ -1,15 +1,14 @@ from django.db.models import CharField, EmailField, TextField -__all__ = ['CICharField', 'CIEmailField', 'CIText', 'CITextField'] +__all__ = ["CICharField", "CIEmailField", "CIText", "CITextField"] class CIText: - def get_internal_type(self): - return 'CI' + super().get_internal_type() + return "CI" + super().get_internal_type() def db_type(self, connection): - return 'citext' + return "citext" class CICharField(CIText, CharField): diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py index 2ec5766041..cfc156ab59 100644 --- a/django/contrib/postgres/fields/hstore.py +++ b/django/contrib/postgres/fields/hstore.py @@ -7,19 +7,19 @@ from django.db.models import Field, TextField, Transform from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.utils.translation import gettext_lazy as _ -__all__ = ['HStoreField'] +__all__ = ["HStoreField"] class HStoreField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False - description = _('Map of strings to strings/nulls') + description = _("Map of strings to strings/nulls") default_error_messages = { - 'not_a_string': _('The value of “%(key)s” is not a string or null.'), + "not_a_string": _("The value of “%(key)s” is not a string or null."), } - _default_hint = ('dict', '{}') + _default_hint = ("dict", "{}") def db_type(self, connection): - return 'hstore' + return "hstore" def get_transform(self, name): transform = super().get_transform(name) @@ -32,9 +32,9 @@ class HStoreField(CheckFieldDefaultMixin, Field): for key, val in value.items(): if not isinstance(val, str) and val is not None: raise exceptions.ValidationError( - self.error_messages['not_a_string'], - code='not_a_string', - params={'key': key}, + self.error_messages["not_a_string"], + code="not_a_string", + params={"key": key}, ) def to_python(self, value): @@ -46,10 +46,12 @@ class HStoreField(CheckFieldDefaultMixin, Field): return json.dumps(self.value_from_object(obj)) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.HStoreField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.HStoreField, + **kwargs, + } + ) def get_prep_value(self, value): value = super().get_prep_value(value) @@ -85,11 +87,10 @@ class KeyTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '(%s -> %%s)' % lhs, tuple(params) + (self.key_name,) + return "(%s -> %%s)" % lhs, tuple(params) + (self.key_name,) class KeyTransformFactory: - def __init__(self, key_name): self.key_name = key_name @@ -99,13 +100,13 @@ class KeyTransformFactory: @HStoreField.register_lookup class KeysTransform(Transform): - lookup_name = 'keys' - function = 'akeys' + lookup_name = "keys" + function = "akeys" output_field = ArrayField(TextField()) @HStoreField.register_lookup class ValuesTransform(Transform): - lookup_name = 'values' - function = 'avals' + lookup_name = "values" + function = "avals" output_field = ArrayField(TextField()) diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index 29e8480665..760b5d8398 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -1,14 +1,14 @@ from django.db.models import JSONField as BuiltinJSONField -__all__ = ['JSONField'] +__all__ = ["JSONField"] class JSONField(BuiltinJSONField): system_check_removed_details = { - 'msg': ( - 'django.contrib.postgres.fields.JSONField is removed except for ' - 'support in historical migrations.' + "msg": ( + "django.contrib.postgres.fields.JSONField is removed except for " + "support in historical migrations." ), - 'hint': 'Use django.db.models.JSONField instead.', - 'id': 'fields.E904', + "hint": "Use django.db.models.JSONField instead.", + "id": "fields.E904", } diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index b395e213a1..58ffacbbe5 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -10,17 +10,23 @@ from django.db.models.lookups import PostgresOperatorLookup from .utils import AttributeSetter __all__ = [ - 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField', - 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField', - 'RangeBoundary', 'RangeOperators', + "RangeField", + "IntegerRangeField", + "BigIntegerRangeField", + "DecimalRangeField", + "DateTimeRangeField", + "DateRangeField", + "RangeBoundary", + "RangeOperators", ] class RangeBoundary(models.Expression): """A class that represents range boundaries.""" + def __init__(self, inclusive_lower=True, inclusive_upper=False): - self.lower = '[' if inclusive_lower else '(' - self.upper = ']' if inclusive_upper else ')' + self.lower = "[" if inclusive_lower else "(" + self.upper = "]" if inclusive_upper else ")" def as_sql(self, compiler, connection): return "'%s%s'" % (self.lower, self.upper), [] @@ -28,41 +34,43 @@ class RangeBoundary(models.Expression): class RangeOperators: # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE - EQUAL = '=' - NOT_EQUAL = '<>' - CONTAINS = '@>' - CONTAINED_BY = '<@' - OVERLAPS = '&&' - FULLY_LT = '<<' - FULLY_GT = '>>' - NOT_LT = '&>' - NOT_GT = '&<' - ADJACENT_TO = '-|-' + EQUAL = "=" + NOT_EQUAL = "<>" + CONTAINS = "@>" + CONTAINED_BY = "<@" + OVERLAPS = "&&" + FULLY_LT = "<<" + FULLY_GT = ">>" + NOT_LT = "&>" + NOT_GT = "&<" + ADJACENT_TO = "-|-" class RangeField(models.Field): empty_strings_allowed = False def __init__(self, *args, **kwargs): - if 'default_bounds' in kwargs: + if "default_bounds" in kwargs: raise TypeError( f"Cannot use 'default_bounds' with {self.__class__.__name__}." ) # Initializing base_field here ensures that its model matches the model for self. - if hasattr(self, 'base_field'): + if hasattr(self, "base_field"): self.base_field = self.base_field() super().__init__(*args, **kwargs) @property def model(self): try: - return self.__dict__['model'] + return self.__dict__["model"] except KeyError: - raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'model'" % self.__class__.__name__ + ) @model.setter def model(self, model): - self.__dict__['model'] = model + self.__dict__["model"] = model self.base_field.model = model @classmethod @@ -82,7 +90,7 @@ class RangeField(models.Field): if isinstance(value, str): # Assume we're deserializing vals = json.loads(value) - for end in ('lower', 'upper'): + for end in ("lower", "upper"): if end in vals: vals[end] = self.base_field.to_python(vals[end]) value = self.range_type(**vals) @@ -102,7 +110,7 @@ class RangeField(models.Field): return json.dumps({"empty": True}) base_field = self.base_field result = {"bounds": value._bounds} - for end in ('lower', 'upper'): + for end in ("lower", "upper"): val = getattr(value, end) if val is None: result[end] = None @@ -112,11 +120,11 @@ class RangeField(models.Field): return json.dumps(result) def formfield(self, **kwargs): - kwargs.setdefault('form_class', self.form_field) + kwargs.setdefault("form_class", self.form_field) return super().formfield(**kwargs) -CANONICAL_RANGE_BOUNDS = '[)' +CANONICAL_RANGE_BOUNDS = "[)" class ContinuousRangeField(RangeField): @@ -126,7 +134,7 @@ class ContinuousRangeField(RangeField): """ def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs): - if default_bounds not in ('[)', '(]', '()', '[]'): + if default_bounds not in ("[)", "(]", "()", "[]"): raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.") self.default_bounds = default_bounds super().__init__(*args, **kwargs) @@ -137,13 +145,13 @@ class ContinuousRangeField(RangeField): return super().get_prep_value(value) def formfield(self, **kwargs): - kwargs.setdefault('default_bounds', self.default_bounds) + kwargs.setdefault("default_bounds", self.default_bounds) return super().formfield(**kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS: - kwargs['default_bounds'] = self.default_bounds + kwargs["default_bounds"] = self.default_bounds return name, path, args, kwargs @@ -153,7 +161,7 @@ class IntegerRangeField(RangeField): form_field = forms.IntegerRangeField def db_type(self, connection): - return 'int4range' + return "int4range" class BigIntegerRangeField(RangeField): @@ -162,7 +170,7 @@ class BigIntegerRangeField(RangeField): form_field = forms.IntegerRangeField def db_type(self, connection): - return 'int8range' + return "int8range" class DecimalRangeField(ContinuousRangeField): @@ -171,7 +179,7 @@ class DecimalRangeField(ContinuousRangeField): form_field = forms.DecimalRangeField def db_type(self, connection): - return 'numrange' + return "numrange" class DateTimeRangeField(ContinuousRangeField): @@ -180,7 +188,7 @@ class DateTimeRangeField(ContinuousRangeField): form_field = forms.DateTimeRangeField def db_type(self, connection): - return 'tstzrange' + return "tstzrange" class DateRangeField(RangeField): @@ -189,7 +197,7 @@ class DateRangeField(RangeField): form_field = forms.DateRangeField def db_type(self, connection): - return 'daterange' + return "daterange" RangeField.register_lookup(lookups.DataContains) @@ -202,7 +210,8 @@ class DateTimeRangeContains(PostgresOperatorLookup): Lookup for Date/DateTimeRange containment to cast the rhs to the correct type. """ - lookup_name = 'contains' + + lookup_name = "contains" postgres_operator = RangeOperators.CONTAINS def process_rhs(self, compiler, connection): @@ -215,16 +224,19 @@ class DateTimeRangeContains(PostgresOperatorLookup): def as_postgresql(self, compiler, connection): sql, params = super().as_postgresql(compiler, connection) # Cast the rhs if needed. - cast_sql = '' + cast_sql = "" if ( - isinstance(self.rhs, models.Expression) and - self.rhs._output_field_or_none and + isinstance(self.rhs, models.Expression) + and self.rhs._output_field_or_none + and # Skip cast if rhs has a matching range type. - not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__) + not isinstance( + self.rhs._output_field_or_none, self.lhs.output_field.__class__ + ) ): cast_internal_type = self.lhs.output_field.base_field.get_internal_type() - cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type)) - return '%s%s' % (sql, cast_sql), params + cast_sql = "::{}".format(connection.data_types.get(cast_internal_type)) + return "%s%s" % (sql, cast_sql), params DateRangeField.register_lookup(DateTimeRangeContains) @@ -232,31 +244,31 @@ DateTimeRangeField.register_lookup(DateTimeRangeContains) class RangeContainedBy(PostgresOperatorLookup): - lookup_name = 'contained_by' + lookup_name = "contained_by" type_mapping = { - 'smallint': 'int4range', - 'integer': 'int4range', - 'bigint': 'int8range', - 'double precision': 'numrange', - 'numeric': 'numrange', - 'date': 'daterange', - 'timestamp with time zone': 'tstzrange', + "smallint": "int4range", + "integer": "int4range", + "bigint": "int8range", + "double precision": "numrange", + "numeric": "numrange", + "date": "daterange", + "timestamp with time zone": "tstzrange", } postgres_operator = RangeOperators.CONTAINED_BY def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) # Ignore precision for DecimalFields. - db_type = self.lhs.output_field.cast_db_type(connection).split('(')[0] + db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0] cast_type = self.type_mapping[db_type] - return '%s::%s' % (rhs, cast_type), rhs_params + return "%s::%s" % (rhs, cast_type), rhs_params def process_lhs(self, compiler, connection): lhs, lhs_params = super().process_lhs(compiler, connection) if isinstance(self.lhs.output_field, models.FloatField): - lhs = '%s::numeric' % lhs + lhs = "%s::numeric" % lhs elif isinstance(self.lhs.output_field, models.SmallIntegerField): - lhs = '%s::integer' % lhs + lhs = "%s::integer" % lhs return lhs, lhs_params def get_prep_lookup(self): @@ -272,38 +284,38 @@ models.DecimalField.register_lookup(RangeContainedBy) @RangeField.register_lookup class FullyLessThan(PostgresOperatorLookup): - lookup_name = 'fully_lt' + lookup_name = "fully_lt" postgres_operator = RangeOperators.FULLY_LT @RangeField.register_lookup class FullGreaterThan(PostgresOperatorLookup): - lookup_name = 'fully_gt' + lookup_name = "fully_gt" postgres_operator = RangeOperators.FULLY_GT @RangeField.register_lookup class NotLessThan(PostgresOperatorLookup): - lookup_name = 'not_lt' + lookup_name = "not_lt" postgres_operator = RangeOperators.NOT_LT @RangeField.register_lookup class NotGreaterThan(PostgresOperatorLookup): - lookup_name = 'not_gt' + lookup_name = "not_gt" postgres_operator = RangeOperators.NOT_GT @RangeField.register_lookup class AdjacentToLookup(PostgresOperatorLookup): - lookup_name = 'adjacent_to' + lookup_name = "adjacent_to" postgres_operator = RangeOperators.ADJACENT_TO @RangeField.register_lookup class RangeStartsWith(models.Transform): - lookup_name = 'startswith' - function = 'lower' + lookup_name = "startswith" + function = "lower" @property def output_field(self): @@ -312,8 +324,8 @@ class RangeStartsWith(models.Transform): @RangeField.register_lookup class RangeEndsWith(models.Transform): - lookup_name = 'endswith' - function = 'upper' + lookup_name = "endswith" + function = "upper" @property def output_field(self): @@ -322,34 +334,34 @@ class RangeEndsWith(models.Transform): @RangeField.register_lookup class IsEmpty(models.Transform): - lookup_name = 'isempty' - function = 'isempty' + lookup_name = "isempty" + function = "isempty" output_field = models.BooleanField() @RangeField.register_lookup class LowerInclusive(models.Transform): - lookup_name = 'lower_inc' - function = 'LOWER_INC' + lookup_name = "lower_inc" + function = "LOWER_INC" output_field = models.BooleanField() @RangeField.register_lookup class LowerInfinite(models.Transform): - lookup_name = 'lower_inf' - function = 'LOWER_INF' + lookup_name = "lower_inf" + function = "LOWER_INF" output_field = models.BooleanField() @RangeField.register_lookup class UpperInclusive(models.Transform): - lookup_name = 'upper_inc' - function = 'UPPER_INC' + lookup_name = "upper_inc" + function = "UPPER_INC" output_field = models.BooleanField() @RangeField.register_lookup class UpperInfinite(models.Transform): - lookup_name = 'upper_inf' - function = 'UPPER_INF' + lookup_name = "upper_inf" + function = "UPPER_INF" output_field = models.BooleanField() diff --git a/django/contrib/postgres/forms/array.py b/django/contrib/postgres/forms/array.py index 2e19cd574a..ddb022afc3 100644 --- a/django/contrib/postgres/forms/array.py +++ b/django/contrib/postgres/forms/array.py @@ -3,7 +3,8 @@ from itertools import chain from django import forms from django.contrib.postgres.validators import ( - ArrayMaxLengthValidator, ArrayMinLengthValidator, + ArrayMaxLengthValidator, + ArrayMinLengthValidator, ) from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ @@ -13,10 +14,12 @@ from ..utils import prefix_validation_error class SimpleArrayField(forms.CharField): default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), } - def __init__(self, base_field, *, delimiter=',', max_length=None, min_length=None, **kwargs): + def __init__( + self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs + ): self.base_field = base_field self.delimiter = delimiter super().__init__(**kwargs) @@ -33,7 +36,9 @@ class SimpleArrayField(forms.CharField): def prepare_value(self, value): if isinstance(value, list): - return self.delimiter.join(str(self.base_field.prepare_value(v)) for v in value) + return self.delimiter.join( + str(self.base_field.prepare_value(v)) for v in value + ) return value def to_python(self, value): @@ -49,12 +54,14 @@ class SimpleArrayField(forms.CharField): try: values.append(self.base_field.to_python(item)) except ValidationError as error: - errors.append(prefix_validation_error( - error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) if errors: raise ValidationError(errors) return values @@ -66,12 +73,14 @@ class SimpleArrayField(forms.CharField): try: self.base_field.validate(item) except ValidationError as error: - errors.append(prefix_validation_error( - error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) if errors: raise ValidationError(errors) @@ -82,12 +91,14 @@ class SimpleArrayField(forms.CharField): try: self.base_field.run_validators(item) except ValidationError as error: - errors.append(prefix_validation_error( - error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) if errors: raise ValidationError(errors) @@ -103,7 +114,7 @@ class SimpleArrayField(forms.CharField): class SplitArrayWidget(forms.Widget): - template_name = 'postgres/widgets/split_array.html' + template_name = "postgres/widgets/split_array.html" def __init__(self, widget, size, **kwargs): self.widget = widget() if isinstance(widget, type) else widget @@ -115,19 +126,21 @@ class SplitArrayWidget(forms.Widget): return self.widget.is_hidden def value_from_datadict(self, data, files, name): - return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index)) - for index in range(self.size)] + return [ + self.widget.value_from_datadict(data, files, "%s_%s" % (name, index)) + for index in range(self.size) + ] def value_omitted_from_data(self, data, files, name): return all( - self.widget.value_omitted_from_data(data, files, '%s_%s' % (name, index)) + self.widget.value_omitted_from_data(data, files, "%s_%s" % (name, index)) for index in range(self.size) ) def id_for_label(self, id_): # See the comment for RadioSelect.id_for_label() if id_: - id_ += '_0' + id_ += "_0" return id_ def get_context(self, name, value, attrs=None): @@ -136,18 +149,20 @@ class SplitArrayWidget(forms.Widget): if self.is_localized: self.widget.is_localized = self.is_localized value = value or [] - context['widget']['subwidgets'] = [] + context["widget"]["subwidgets"] = [] final_attrs = self.build_attrs(attrs) - id_ = final_attrs.get('id') + id_ = final_attrs.get("id") for i in range(max(len(value), self.size)): try: widget_value = value[i] except IndexError: widget_value = None if id_: - final_attrs = {**final_attrs, 'id': '%s_%s' % (id_, i)} - context['widget']['subwidgets'].append( - self.widget.get_context(name + '_%s' % i, widget_value, final_attrs)['widget'] + final_attrs = {**final_attrs, "id": "%s_%s" % (id_, i)} + context["widget"]["subwidgets"].append( + self.widget.get_context(name + "_%s" % i, widget_value, final_attrs)[ + "widget" + ] ) return context @@ -167,7 +182,7 @@ class SplitArrayWidget(forms.Widget): class SplitArrayField(forms.Field): default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), } def __init__(self, base_field, size, *, remove_trailing_nulls=False, **kwargs): @@ -175,7 +190,7 @@ class SplitArrayField(forms.Field): self.size = size self.remove_trailing_nulls = remove_trailing_nulls widget = SplitArrayWidget(widget=base_field.widget, size=size) - kwargs.setdefault('widget', widget) + kwargs.setdefault("widget", widget) super().__init__(**kwargs) def _remove_trailing_nulls(self, values): @@ -198,19 +213,21 @@ class SplitArrayField(forms.Field): cleaned_data = [] errors = [] if not any(value) and self.required: - raise ValidationError(self.error_messages['required']) + raise ValidationError(self.error_messages["required"]) max_size = max(self.size, len(value)) for index in range(max_size): item = value[index] try: cleaned_data.append(self.base_field.clean(item)) except ValidationError as error: - errors.append(prefix_validation_error( - error, - self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) cleaned_data.append(None) else: errors.append(None) diff --git a/django/contrib/postgres/forms/hstore.py b/django/contrib/postgres/forms/hstore.py index f5af8f10e3..6a20f7b729 100644 --- a/django/contrib/postgres/forms/hstore.py +++ b/django/contrib/postgres/forms/hstore.py @@ -4,17 +4,18 @@ from django import forms from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ -__all__ = ['HStoreField'] +__all__ = ["HStoreField"] class HStoreField(forms.CharField): """ A field for HStore data which accepts dictionary JSON input. """ + widget = forms.Textarea default_error_messages = { - 'invalid_json': _('Could not load JSON data.'), - 'invalid_format': _('Input must be a JSON dictionary.'), + "invalid_json": _("Could not load JSON data."), + "invalid_format": _("Input must be a JSON dictionary."), } def prepare_value(self, value): @@ -30,14 +31,14 @@ class HStoreField(forms.CharField): value = json.loads(value) except json.JSONDecodeError: raise ValidationError( - self.error_messages['invalid_json'], - code='invalid_json', + self.error_messages["invalid_json"], + code="invalid_json", ) if not isinstance(value, dict): raise ValidationError( - self.error_messages['invalid_format'], - code='invalid_format', + self.error_messages["invalid_format"], + code="invalid_format", ) # Cast everything to strings for ease. diff --git a/django/contrib/postgres/forms/ranges.py b/django/contrib/postgres/forms/ranges.py index 9c673ab40c..444991970d 100644 --- a/django/contrib/postgres/forms/ranges.py +++ b/django/contrib/postgres/forms/ranges.py @@ -6,8 +6,13 @@ from django.forms.widgets import HiddenInput, MultiWidget from django.utils.translation import gettext_lazy as _ __all__ = [ - 'BaseRangeField', 'IntegerRangeField', 'DecimalRangeField', - 'DateTimeRangeField', 'DateRangeField', 'HiddenRangeWidget', 'RangeWidget', + "BaseRangeField", + "IntegerRangeField", + "DecimalRangeField", + "DateTimeRangeField", + "DateRangeField", + "HiddenRangeWidget", + "RangeWidget", ] @@ -24,27 +29,33 @@ class RangeWidget(MultiWidget): class HiddenRangeWidget(RangeWidget): """A widget that splits input into two <input type="hidden"> inputs.""" + def __init__(self, attrs=None): super().__init__(HiddenInput, attrs) class BaseRangeField(forms.MultiValueField): default_error_messages = { - 'invalid': _('Enter two valid values.'), - 'bound_ordering': _('The start of the range must not exceed the end of the range.'), + "invalid": _("Enter two valid values."), + "bound_ordering": _( + "The start of the range must not exceed the end of the range." + ), } hidden_widget = HiddenRangeWidget def __init__(self, **kwargs): - if 'widget' not in kwargs: - kwargs['widget'] = RangeWidget(self.base_field.widget) - if 'fields' not in kwargs: - kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)] - kwargs.setdefault('required', False) - kwargs.setdefault('require_all_fields', False) + if "widget" not in kwargs: + kwargs["widget"] = RangeWidget(self.base_field.widget) + if "fields" not in kwargs: + kwargs["fields"] = [ + self.base_field(required=False), + self.base_field(required=False), + ] + kwargs.setdefault("required", False) + kwargs.setdefault("require_all_fields", False) self.range_kwargs = {} - if default_bounds := kwargs.pop('default_bounds', None): - self.range_kwargs = {'bounds': default_bounds} + if default_bounds := kwargs.pop("default_bounds", None): + self.range_kwargs = {"bounds": default_bounds} super().__init__(**kwargs) def prepare_value(self, value): @@ -67,39 +78,39 @@ class BaseRangeField(forms.MultiValueField): lower, upper = values if lower is not None and upper is not None and lower > upper: raise exceptions.ValidationError( - self.error_messages['bound_ordering'], - code='bound_ordering', + self.error_messages["bound_ordering"], + code="bound_ordering", ) try: range_value = self.range_type(lower, upper, **self.range_kwargs) except TypeError: raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', + self.error_messages["invalid"], + code="invalid", ) else: return range_value class IntegerRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two whole numbers.')} + default_error_messages = {"invalid": _("Enter two whole numbers.")} base_field = forms.IntegerField range_type = NumericRange class DecimalRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two numbers.')} + default_error_messages = {"invalid": _("Enter two numbers.")} base_field = forms.DecimalField range_type = NumericRange class DateTimeRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two valid date/times.')} + default_error_messages = {"invalid": _("Enter two valid date/times.")} base_field = forms.DateTimeField range_type = DateTimeTZRange class DateRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two valid dates.')} + default_error_messages = {"invalid": _("Enter two valid dates.")} base_field = forms.DateField range_type = DateRange diff --git a/django/contrib/postgres/functions.py b/django/contrib/postgres/functions.py index 819ce058e5..f001a04fdc 100644 --- a/django/contrib/postgres/functions.py +++ b/django/contrib/postgres/functions.py @@ -2,10 +2,10 @@ from django.db.models import DateTimeField, Func, UUIDField class RandomUUID(Func): - template = 'GEN_RANDOM_UUID()' + template = "GEN_RANDOM_UUID()" output_field = UUIDField() class TransactionNow(Func): - template = 'CURRENT_TIMESTAMP' + template = "CURRENT_TIMESTAMP" output_field = DateTimeField() diff --git a/django/contrib/postgres/indexes.py b/django/contrib/postgres/indexes.py index 2e3de1d275..409f514147 100644 --- a/django/contrib/postgres/indexes.py +++ b/django/contrib/postgres/indexes.py @@ -3,13 +3,17 @@ from django.db.models import Func, Index from django.utils.functional import cached_property __all__ = [ - 'BloomIndex', 'BrinIndex', 'BTreeIndex', 'GinIndex', 'GistIndex', - 'HashIndex', 'SpGistIndex', + "BloomIndex", + "BrinIndex", + "BTreeIndex", + "GinIndex", + "GistIndex", + "HashIndex", + "SpGistIndex", ] class PostgresIndex(Index): - @cached_property def max_name_length(self): # Allow an index name longer than 30 characters when the suffix is @@ -18,14 +22,16 @@ class PostgresIndex(Index): # indexes. return Index.max_name_length - len(Index.suffix) + len(self.suffix) - def create_sql(self, model, schema_editor, using='', **kwargs): + def create_sql(self, model, schema_editor, using="", **kwargs): self.check_supported(schema_editor) - statement = super().create_sql(model, schema_editor, using=' USING %s' % self.suffix, **kwargs) + statement = super().create_sql( + model, schema_editor, using=" USING %s" % self.suffix, **kwargs + ) with_params = self.get_with_params() if with_params: - statement.parts['extra'] = 'WITH (%s) %s' % ( - ', '.join(with_params), - statement.parts['extra'], + statement.parts["extra"] = "WITH (%s) %s" % ( + ", ".join(with_params), + statement.parts["extra"], ) return statement @@ -37,25 +43,23 @@ class PostgresIndex(Index): class BloomIndex(PostgresIndex): - suffix = 'bloom' + suffix = "bloom" def __init__(self, *expressions, length=None, columns=(), **kwargs): super().__init__(*expressions, **kwargs) if len(self.fields) > 32: - raise ValueError('Bloom indexes support a maximum of 32 fields.') + raise ValueError("Bloom indexes support a maximum of 32 fields.") if not isinstance(columns, (list, tuple)): - raise ValueError('BloomIndex.columns must be a list or tuple.') + raise ValueError("BloomIndex.columns must be a list or tuple.") if len(columns) > len(self.fields): - raise ValueError( - 'BloomIndex.columns cannot have more values than fields.' - ) + raise ValueError("BloomIndex.columns cannot have more values than fields.") if not all(0 < col <= 4095 for col in columns): raise ValueError( - 'BloomIndex.columns must contain integers from 1 to 4095.', + "BloomIndex.columns must contain integers from 1 to 4095.", ) if length is not None and not 0 < length <= 4096: raise ValueError( - 'BloomIndex.length must be None or an integer from 1 to 4096.', + "BloomIndex.length must be None or an integer from 1 to 4096.", ) self.length = length self.columns = columns @@ -63,29 +67,30 @@ class BloomIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.length is not None: - kwargs['length'] = self.length + kwargs["length"] = self.length if self.columns: - kwargs['columns'] = self.columns + kwargs["columns"] = self.columns return path, args, kwargs def get_with_params(self): with_params = [] if self.length is not None: - with_params.append('length = %d' % self.length) + with_params.append("length = %d" % self.length) if self.columns: with_params.extend( - 'col%d = %d' % (i, v) - for i, v in enumerate(self.columns, start=1) + "col%d = %d" % (i, v) for i, v in enumerate(self.columns, start=1) ) return with_params class BrinIndex(PostgresIndex): - suffix = 'brin' + suffix = "brin" - def __init__(self, *expressions, autosummarize=None, pages_per_range=None, **kwargs): + def __init__( + self, *expressions, autosummarize=None, pages_per_range=None, **kwargs + ): if pages_per_range is not None and pages_per_range <= 0: - raise ValueError('pages_per_range must be None or a positive integer') + raise ValueError("pages_per_range must be None or a positive integer") self.autosummarize = autosummarize self.pages_per_range = pages_per_range super().__init__(*expressions, **kwargs) @@ -93,22 +98,24 @@ class BrinIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.autosummarize is not None: - kwargs['autosummarize'] = self.autosummarize + kwargs["autosummarize"] = self.autosummarize if self.pages_per_range is not None: - kwargs['pages_per_range'] = self.pages_per_range + kwargs["pages_per_range"] = self.pages_per_range return path, args, kwargs def get_with_params(self): with_params = [] if self.autosummarize is not None: - with_params.append('autosummarize = %s' % ('on' if self.autosummarize else 'off')) + with_params.append( + "autosummarize = %s" % ("on" if self.autosummarize else "off") + ) if self.pages_per_range is not None: - with_params.append('pages_per_range = %d' % self.pages_per_range) + with_params.append("pages_per_range = %d" % self.pages_per_range) return with_params class BTreeIndex(PostgresIndex): - suffix = 'btree' + suffix = "btree" def __init__(self, *expressions, fillfactor=None, **kwargs): self.fillfactor = fillfactor @@ -117,20 +124,22 @@ class BTreeIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor + kwargs["fillfactor"] = self.fillfactor return path, args, kwargs def get_with_params(self): with_params = [] if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) + with_params.append("fillfactor = %d" % self.fillfactor) return with_params class GinIndex(PostgresIndex): - suffix = 'gin' + suffix = "gin" - def __init__(self, *expressions, fastupdate=None, gin_pending_list_limit=None, **kwargs): + def __init__( + self, *expressions, fastupdate=None, gin_pending_list_limit=None, **kwargs + ): self.fastupdate = fastupdate self.gin_pending_list_limit = gin_pending_list_limit super().__init__(*expressions, **kwargs) @@ -138,22 +147,24 @@ class GinIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fastupdate is not None: - kwargs['fastupdate'] = self.fastupdate + kwargs["fastupdate"] = self.fastupdate if self.gin_pending_list_limit is not None: - kwargs['gin_pending_list_limit'] = self.gin_pending_list_limit + kwargs["gin_pending_list_limit"] = self.gin_pending_list_limit return path, args, kwargs def get_with_params(self): with_params = [] if self.gin_pending_list_limit is not None: - with_params.append('gin_pending_list_limit = %d' % self.gin_pending_list_limit) + with_params.append( + "gin_pending_list_limit = %d" % self.gin_pending_list_limit + ) if self.fastupdate is not None: - with_params.append('fastupdate = %s' % ('on' if self.fastupdate else 'off')) + with_params.append("fastupdate = %s" % ("on" if self.fastupdate else "off")) return with_params class GistIndex(PostgresIndex): - suffix = 'gist' + suffix = "gist" def __init__(self, *expressions, buffering=None, fillfactor=None, **kwargs): self.buffering = buffering @@ -163,26 +174,29 @@ class GistIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.buffering is not None: - kwargs['buffering'] = self.buffering + kwargs["buffering"] = self.buffering if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor + kwargs["fillfactor"] = self.fillfactor return path, args, kwargs def get_with_params(self): with_params = [] if self.buffering is not None: - with_params.append('buffering = %s' % ('on' if self.buffering else 'off')) + with_params.append("buffering = %s" % ("on" if self.buffering else "off")) if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) + with_params.append("fillfactor = %d" % self.fillfactor) return with_params def check_supported(self, schema_editor): - if self.include and not schema_editor.connection.features.supports_covering_gist_indexes: - raise NotSupportedError('Covering GiST indexes require PostgreSQL 12+.') + if ( + self.include + and not schema_editor.connection.features.supports_covering_gist_indexes + ): + raise NotSupportedError("Covering GiST indexes require PostgreSQL 12+.") class HashIndex(PostgresIndex): - suffix = 'hash' + suffix = "hash" def __init__(self, *expressions, fillfactor=None, **kwargs): self.fillfactor = fillfactor @@ -191,18 +205,18 @@ class HashIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor + kwargs["fillfactor"] = self.fillfactor return path, args, kwargs def get_with_params(self): with_params = [] if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) + with_params.append("fillfactor = %d" % self.fillfactor) return with_params class SpGistIndex(PostgresIndex): - suffix = 'spgist' + suffix = "spgist" def __init__(self, *expressions, fillfactor=None, **kwargs): self.fillfactor = fillfactor @@ -211,25 +225,25 @@ class SpGistIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor + kwargs["fillfactor"] = self.fillfactor return path, args, kwargs def get_with_params(self): with_params = [] if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) + with_params.append("fillfactor = %d" % self.fillfactor) return with_params def check_supported(self, schema_editor): if ( - self.include and - not schema_editor.connection.features.supports_covering_spgist_indexes + self.include + and not schema_editor.connection.features.supports_covering_spgist_indexes ): - raise NotSupportedError('Covering SP-GiST indexes require PostgreSQL 14+.') + raise NotSupportedError("Covering SP-GiST indexes require PostgreSQL 14+.") class OpClass(Func): - template = '%(expressions)s %(name)s' + template = "%(expressions)s %(name)s" def __init__(self, expression, name): super().__init__(expression, name=name) diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index f7c6fc4b0c..9fed0eea30 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -5,61 +5,61 @@ from .search import SearchVector, SearchVectorExact, SearchVectorField class DataContains(PostgresOperatorLookup): - lookup_name = 'contains' - postgres_operator = '@>' + lookup_name = "contains" + postgres_operator = "@>" class ContainedBy(PostgresOperatorLookup): - lookup_name = 'contained_by' - postgres_operator = '<@' + lookup_name = "contained_by" + postgres_operator = "<@" class Overlap(PostgresOperatorLookup): - lookup_name = 'overlap' - postgres_operator = '&&' + lookup_name = "overlap" + postgres_operator = "&&" class HasKey(PostgresOperatorLookup): - lookup_name = 'has_key' - postgres_operator = '?' + lookup_name = "has_key" + postgres_operator = "?" prepare_rhs = False class HasKeys(PostgresOperatorLookup): - lookup_name = 'has_keys' - postgres_operator = '?&' + lookup_name = "has_keys" + postgres_operator = "?&" def get_prep_lookup(self): return [str(item) for item in self.rhs] class HasAnyKeys(HasKeys): - lookup_name = 'has_any_keys' - postgres_operator = '?|' + lookup_name = "has_any_keys" + postgres_operator = "?|" class Unaccent(Transform): bilateral = True - lookup_name = 'unaccent' - function = 'UNACCENT' + lookup_name = "unaccent" + function = "UNACCENT" class SearchLookup(SearchVectorExact): - lookup_name = 'search' + lookup_name = "search" def process_lhs(self, qn, connection): if not isinstance(self.lhs.output_field, SearchVectorField): - config = getattr(self.rhs, 'config', None) + config = getattr(self.rhs, "config", None) self.lhs = SearchVector(self.lhs, config=config) lhs, lhs_params = super().process_lhs(qn, connection) return lhs, lhs_params class TrigramSimilar(PostgresOperatorLookup): - lookup_name = 'trigram_similar' - postgres_operator = '%%' + lookup_name = "trigram_similar" + postgres_operator = "%%" class TrigramWordSimilar(PostgresOperatorLookup): - lookup_name = 'trigram_word_similar' - postgres_operator = '%%>' + lookup_name = "trigram_word_similar" + postgres_operator = "%%>" diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py index 037bb4ec22..374f5ee1ec 100644 --- a/django/contrib/postgres/operations.py +++ b/django/contrib/postgres/operations.py @@ -1,5 +1,7 @@ from django.contrib.postgres.signals import ( - get_citext_oids, get_hstore_oids, register_type_handlers, + get_citext_oids, + get_hstore_oids, + register_type_handlers, ) from django.db import NotSupportedError, router from django.db.migrations import AddConstraint, AddIndex, RemoveIndex @@ -17,14 +19,14 @@ class CreateExtension(Operation): pass def database_forwards(self, app_label, schema_editor, from_state, to_state): - if ( - schema_editor.connection.vendor != 'postgresql' or - not router.allow_migrate(schema_editor.connection.alias, app_label) + if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate( + schema_editor.connection.alias, app_label ): return if not self.extension_exists(schema_editor, self.name): schema_editor.execute( - 'CREATE EXTENSION IF NOT EXISTS %s' % schema_editor.quote_name(self.name) + "CREATE EXTENSION IF NOT EXISTS %s" + % schema_editor.quote_name(self.name) ) # Clear cached, stale oids. get_hstore_oids.cache_clear() @@ -39,7 +41,7 @@ class CreateExtension(Operation): return if self.extension_exists(schema_editor, self.name): schema_editor.execute( - 'DROP EXTENSION IF EXISTS %s' % schema_editor.quote_name(self.name) + "DROP EXTENSION IF EXISTS %s" % schema_editor.quote_name(self.name) ) # Clear cached, stale oids. get_hstore_oids.cache_clear() @@ -48,7 +50,7 @@ class CreateExtension(Operation): def extension_exists(self, schema_editor, extension): with schema_editor.connection.cursor() as cursor: cursor.execute( - 'SELECT 1 FROM pg_extension WHERE extname = %s', + "SELECT 1 FROM pg_extension WHERE extname = %s", [extension], ) return bool(cursor.fetchone()) @@ -58,75 +60,67 @@ class CreateExtension(Operation): @property def migration_name_fragment(self): - return 'create_extension_%s' % self.name + return "create_extension_%s" % self.name class BloomExtension(CreateExtension): - def __init__(self): - self.name = 'bloom' + self.name = "bloom" class BtreeGinExtension(CreateExtension): - def __init__(self): - self.name = 'btree_gin' + self.name = "btree_gin" class BtreeGistExtension(CreateExtension): - def __init__(self): - self.name = 'btree_gist' + self.name = "btree_gist" class CITextExtension(CreateExtension): - def __init__(self): - self.name = 'citext' + self.name = "citext" class CryptoExtension(CreateExtension): - def __init__(self): - self.name = 'pgcrypto' + self.name = "pgcrypto" class HStoreExtension(CreateExtension): - def __init__(self): - self.name = 'hstore' + self.name = "hstore" class TrigramExtension(CreateExtension): - def __init__(self): - self.name = 'pg_trgm' + self.name = "pg_trgm" class UnaccentExtension(CreateExtension): - def __init__(self): - self.name = 'unaccent' + self.name = "unaccent" class NotInTransactionMixin: def _ensure_not_in_transaction(self, schema_editor): if schema_editor.connection.in_atomic_block: raise NotSupportedError( - 'The %s operation cannot be executed inside a transaction ' - '(set atomic = False on the migration).' - % self.__class__.__name__ + "The %s operation cannot be executed inside a transaction " + "(set atomic = False on the migration)." % self.__class__.__name__ ) class AddIndexConcurrently(NotInTransactionMixin, AddIndex): """Create an index using PostgreSQL's CREATE INDEX CONCURRENTLY syntax.""" + atomic = False def describe(self): - return 'Concurrently create index %s on field(s) %s of model %s' % ( + return "Concurrently create index %s on field(s) %s of model %s" % ( self.index.name, - ', '.join(self.index.fields), + ", ".join(self.index.fields), self.model_name, ) @@ -145,10 +139,11 @@ class AddIndexConcurrently(NotInTransactionMixin, AddIndex): class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex): """Remove an index using PostgreSQL's DROP INDEX CONCURRENTLY syntax.""" + atomic = False def describe(self): - return 'Concurrently remove index %s from %s' % (self.name, self.model_name) + return "Concurrently remove index %s from %s" % (self.name, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): self._ensure_not_in_transaction(schema_editor) @@ -168,7 +163,7 @@ class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex): class CollationOperation(Operation): - def __init__(self, name, locale, *, provider='libc', deterministic=True): + def __init__(self, name, locale, *, provider="libc", deterministic=True): self.name = name self.locale = locale self.provider = provider @@ -178,11 +173,11 @@ class CollationOperation(Operation): pass def deconstruct(self): - kwargs = {'name': self.name, 'locale': self.locale} - if self.provider and self.provider != 'libc': - kwargs['provider'] = self.provider + kwargs = {"name": self.name, "locale": self.locale} + if self.provider and self.provider != "libc": + kwargs["provider"] = self.provider if self.deterministic is False: - kwargs['deterministic'] = self.deterministic + kwargs["deterministic"] = self.deterministic return ( self.__class__.__qualname__, [], @@ -191,34 +186,39 @@ class CollationOperation(Operation): def create_collation(self, schema_editor): if ( - self.deterministic is False and - not schema_editor.connection.features.supports_non_deterministic_collations + self.deterministic is False + and not schema_editor.connection.features.supports_non_deterministic_collations ): raise NotSupportedError( - 'Non-deterministic collations require PostgreSQL 12+.' + "Non-deterministic collations require PostgreSQL 12+." ) - args = {'locale': schema_editor.quote_name(self.locale)} - if self.provider != 'libc': - args['provider'] = schema_editor.quote_name(self.provider) + args = {"locale": schema_editor.quote_name(self.locale)} + if self.provider != "libc": + args["provider"] = schema_editor.quote_name(self.provider) if self.deterministic is False: - args['deterministic'] = 'false' - schema_editor.execute('CREATE COLLATION %(name)s (%(args)s)' % { - 'name': schema_editor.quote_name(self.name), - 'args': ', '.join(f'{option}={value}' for option, value in args.items()), - }) + args["deterministic"] = "false" + schema_editor.execute( + "CREATE COLLATION %(name)s (%(args)s)" + % { + "name": schema_editor.quote_name(self.name), + "args": ", ".join( + f"{option}={value}" for option, value in args.items() + ), + } + ) def remove_collation(self, schema_editor): schema_editor.execute( - 'DROP COLLATION %s' % schema_editor.quote_name(self.name), + "DROP COLLATION %s" % schema_editor.quote_name(self.name), ) class CreateCollation(CollationOperation): """Create a collation.""" + def database_forwards(self, app_label, schema_editor, from_state, to_state): - if ( - schema_editor.connection.vendor != 'postgresql' or - not router.allow_migrate(schema_editor.connection.alias, app_label) + if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate( + schema_editor.connection.alias, app_label ): return self.create_collation(schema_editor) @@ -229,19 +229,19 @@ class CreateCollation(CollationOperation): self.remove_collation(schema_editor) def describe(self): - return f'Create collation {self.name}' + return f"Create collation {self.name}" @property def migration_name_fragment(self): - return 'create_collation_%s' % self.name.lower() + return "create_collation_%s" % self.name.lower() class RemoveCollation(CollationOperation): """Remove a collation.""" + def database_forwards(self, app_label, schema_editor, from_state, to_state): - if ( - schema_editor.connection.vendor != 'postgresql' or - not router.allow_migrate(schema_editor.connection.alias, app_label) + if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate( + schema_editor.connection.alias, app_label ): return self.remove_collation(schema_editor) @@ -252,11 +252,11 @@ class RemoveCollation(CollationOperation): self.create_collation(schema_editor) def describe(self): - return f'Remove collation {self.name}' + return f"Remove collation {self.name}" @property def migration_name_fragment(self): - return 'remove_collation_%s' % self.name.lower() + return "remove_collation_%s" % self.name.lower() class AddConstraintNotValid(AddConstraint): @@ -268,12 +268,12 @@ class AddConstraintNotValid(AddConstraint): def __init__(self, model_name, constraint): if not isinstance(constraint, CheckConstraint): raise TypeError( - 'AddConstraintNotValid.constraint must be a check constraint.' + "AddConstraintNotValid.constraint must be a check constraint." ) super().__init__(model_name, constraint) def describe(self): - return 'Create not valid constraint %s on model %s' % ( + return "Create not valid constraint %s on model %s" % ( self.constraint.name, self.model_name, ) @@ -286,11 +286,11 @@ class AddConstraintNotValid(AddConstraint): # Constraint.create_sql returns interpolated SQL which makes # params=None a necessity to avoid escaping attempts on # execution. - schema_editor.execute(str(constraint_sql) + ' NOT VALID', params=None) + schema_editor.execute(str(constraint_sql) + " NOT VALID", params=None) @property def migration_name_fragment(self): - return super().migration_name_fragment + '_not_valid' + return super().migration_name_fragment + "_not_valid" class ValidateConstraint(Operation): @@ -301,15 +301,18 @@ class ValidateConstraint(Operation): self.name = name def describe(self): - return 'Validate constraint %s on model %s' % (self.name, self.model_name) + return "Validate constraint %s on model %s" % (self.name, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = from_state.apps.get_model(app_label, self.model_name) if self.allow_migrate_model(schema_editor.connection.alias, model): - schema_editor.execute('ALTER TABLE %s VALIDATE CONSTRAINT %s' % ( - schema_editor.quote_name(model._meta.db_table), - schema_editor.quote_name(self.name), - )) + schema_editor.execute( + "ALTER TABLE %s VALIDATE CONSTRAINT %s" + % ( + schema_editor.quote_name(model._meta.db_table), + schema_editor.quote_name(self.name), + ) + ) def database_backwards(self, app_label, schema_editor, from_state, to_state): # PostgreSQL does not provide a way to make a constraint invalid. @@ -320,10 +323,14 @@ class ValidateConstraint(Operation): @property def migration_name_fragment(self): - return '%s_validate_%s' % (self.model_name.lower(), self.name.lower()) + return "%s_validate_%s" % (self.model_name.lower(), self.name.lower()) def deconstruct(self): - return self.__class__.__name__, [], { - 'model_name': self.model_name, - 'name': self.name, - } + return ( + self.__class__.__name__, + [], + { + "model_name": self.model_name, + "name": self.name, + }, + ) diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 164d359b91..f652c1d346 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -1,18 +1,25 @@ import psycopg2 from django.db.models import ( - CharField, Expression, Field, FloatField, Func, Lookup, TextField, Value, + CharField, + Expression, + Field, + FloatField, + Func, + Lookup, + TextField, + Value, ) from django.db.models.expressions import CombinedExpression from django.db.models.functions import Cast, Coalesce class SearchVectorExact(Lookup): - lookup_name = 'exact' + lookup_name = "exact" def process_rhs(self, qn, connection): if not isinstance(self.rhs, (SearchQuery, CombinedSearchQuery)): - config = getattr(self.lhs, 'config', None) + config = getattr(self.lhs, "config", None) self.rhs = SearchQuery(self.rhs, config=config) rhs, rhs_params = super().process_rhs(qn, connection) return rhs, rhs_params @@ -21,25 +28,23 @@ class SearchVectorExact(Lookup): lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) params = lhs_params + rhs_params - return '%s @@ %s' % (lhs, rhs), params + return "%s @@ %s" % (lhs, rhs), params class SearchVectorField(Field): - def db_type(self, connection): - return 'tsvector' + return "tsvector" class SearchQueryField(Field): - def db_type(self, connection): - return 'tsquery' + return "tsquery" class SearchConfig(Expression): def __init__(self, config): super().__init__() - if not hasattr(config, 'resolve_expression'): + if not hasattr(config, "resolve_expression"): config = Value(config) self.config = config @@ -53,21 +58,21 @@ class SearchConfig(Expression): return [self.config] def set_source_expressions(self, exprs): - self.config, = exprs + (self.config,) = exprs def as_sql(self, compiler, connection): sql, params = compiler.compile(self.config) - return '%s::regconfig' % sql, params + return "%s::regconfig" % sql, params class SearchVectorCombinable: - ADD = '||' + ADD = "||" def _combine(self, other, connector, reversed): if not isinstance(other, SearchVectorCombinable): raise TypeError( - 'SearchVector can only be combined with other SearchVector ' - 'instances, got %s.' % type(other).__name__ + "SearchVector can only be combined with other SearchVector " + "instances, got %s." % type(other).__name__ ) if reversed: return CombinedSearchVector(other, connector, self, self.config) @@ -75,49 +80,61 @@ class SearchVectorCombinable: class SearchVector(SearchVectorCombinable, Func): - function = 'to_tsvector' + function = "to_tsvector" arg_joiner = " || ' ' || " output_field = SearchVectorField() def __init__(self, *expressions, config=None, weight=None): super().__init__(*expressions) self.config = SearchConfig.from_parameter(config) - if weight is not None and not hasattr(weight, 'resolve_expression'): + if weight is not None and not hasattr(weight, "resolve_expression"): weight = Value(weight) self.weight = weight - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + resolved = super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) if self.config: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = self.config.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) return resolved def as_sql(self, compiler, connection, function=None, template=None): clone = self.copy() - clone.set_source_expressions([ - Coalesce( - expression - if isinstance(expression.output_field, (CharField, TextField)) - else Cast(expression, TextField()), - Value('') - ) for expression in clone.get_source_expressions() - ]) + clone.set_source_expressions( + [ + Coalesce( + expression + if isinstance(expression.output_field, (CharField, TextField)) + else Cast(expression, TextField()), + Value(""), + ) + for expression in clone.get_source_expressions() + ] + ) config_sql = None config_params = [] if template is None: if clone.config: config_sql, config_params = compiler.compile(clone.config) - template = '%(function)s(%(config)s, %(expressions)s)' + template = "%(function)s(%(config)s, %(expressions)s)" else: template = clone.template sql, params = super(SearchVector, clone).as_sql( - compiler, connection, function=function, template=template, + compiler, + connection, + function=function, + template=template, config=config_sql, ) extra_params = [] if clone.weight: weight_sql, extra_params = compiler.compile(clone.weight) - sql = 'setweight({}, {})'.format(sql, weight_sql) + sql = "setweight({}, {})".format(sql, weight_sql) return sql, config_params + params + extra_params @@ -128,14 +145,14 @@ class CombinedSearchVector(SearchVectorCombinable, CombinedExpression): class SearchQueryCombinable: - BITAND = '&&' - BITOR = '||' + BITAND = "&&" + BITOR = "||" def _combine(self, other, connector, reversed): if not isinstance(other, SearchQueryCombinable): raise TypeError( - 'SearchQuery can only be combined with other SearchQuery ' - 'instances, got %s.' % type(other).__name__ + "SearchQuery can only be combined with other SearchQuery " + "instances, got %s." % type(other).__name__ ) if reversed: return CombinedSearchQuery(other, connector, self, self.config) @@ -160,17 +177,25 @@ class SearchQueryCombinable: class SearchQuery(SearchQueryCombinable, Func): output_field = SearchQueryField() SEARCH_TYPES = { - 'plain': 'plainto_tsquery', - 'phrase': 'phraseto_tsquery', - 'raw': 'to_tsquery', - 'websearch': 'websearch_to_tsquery', + "plain": "plainto_tsquery", + "phrase": "phraseto_tsquery", + "raw": "to_tsquery", + "websearch": "websearch_to_tsquery", } - def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): + def __init__( + self, + value, + output_field=None, + *, + config=None, + invert=False, + search_type="plain", + ): self.function = self.SEARCH_TYPES.get(search_type) if self.function is None: raise ValueError("Unknown search_type argument '%s'." % search_type) - if not hasattr(value, 'resolve_expression'): + if not hasattr(value, "resolve_expression"): value = Value(value) expressions = (value,) self.config = SearchConfig.from_parameter(config) @@ -182,7 +207,7 @@ class SearchQuery(SearchQueryCombinable, Func): def as_sql(self, compiler, connection, function=None, template=None): sql, params = super().as_sql(compiler, connection, function, template) if self.invert: - sql = '!!(%s)' % sql + sql = "!!(%s)" % sql return sql, params def __invert__(self): @@ -192,7 +217,7 @@ class SearchQuery(SearchQueryCombinable, Func): def __str__(self): result = super().__str__() - return ('~%s' % result) if self.invert else result + return ("~%s" % result) if self.invert else result class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression): @@ -201,60 +226,73 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression): super().__init__(lhs, connector, rhs, output_field) def __str__(self): - return '(%s)' % super().__str__() + return "(%s)" % super().__str__() class SearchRank(Func): - function = 'ts_rank' + function = "ts_rank" output_field = FloatField() def __init__( - self, vector, query, weights=None, normalization=None, + self, + vector, + query, + weights=None, + normalization=None, cover_density=False, ): - if not hasattr(vector, 'resolve_expression'): + if not hasattr(vector, "resolve_expression"): vector = SearchVector(vector) - if not hasattr(query, 'resolve_expression'): + if not hasattr(query, "resolve_expression"): query = SearchQuery(query) expressions = (vector, query) if weights is not None: - if not hasattr(weights, 'resolve_expression'): + if not hasattr(weights, "resolve_expression"): weights = Value(weights) expressions = (weights,) + expressions if normalization is not None: - if not hasattr(normalization, 'resolve_expression'): + if not hasattr(normalization, "resolve_expression"): normalization = Value(normalization) expressions += (normalization,) if cover_density: - self.function = 'ts_rank_cd' + self.function = "ts_rank_cd" super().__init__(*expressions) class SearchHeadline(Func): - function = 'ts_headline' - template = '%(function)s(%(expressions)s%(options)s)' + function = "ts_headline" + template = "%(function)s(%(expressions)s%(options)s)" output_field = TextField() def __init__( - self, expression, query, *, config=None, start_sel=None, stop_sel=None, - max_words=None, min_words=None, short_word=None, highlight_all=None, - max_fragments=None, fragment_delimiter=None, + self, + expression, + query, + *, + config=None, + start_sel=None, + stop_sel=None, + max_words=None, + min_words=None, + short_word=None, + highlight_all=None, + max_fragments=None, + fragment_delimiter=None, ): - if not hasattr(query, 'resolve_expression'): + if not hasattr(query, "resolve_expression"): query = SearchQuery(query) options = { - 'StartSel': start_sel, - 'StopSel': stop_sel, - 'MaxWords': max_words, - 'MinWords': min_words, - 'ShortWord': short_word, - 'HighlightAll': highlight_all, - 'MaxFragments': max_fragments, - 'FragmentDelimiter': fragment_delimiter, + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxWords": max_words, + "MinWords": min_words, + "ShortWord": short_word, + "HighlightAll": highlight_all, + "MaxFragments": max_fragments, + "FragmentDelimiter": fragment_delimiter, } self.options = { - option: value - for option, value in options.items() if value is not None + option: value for option, value in options.items() if value is not None } expressions = (expression, query) if config is not None: @@ -263,19 +301,26 @@ class SearchHeadline(Func): super().__init__(*expressions) def as_sql(self, compiler, connection, function=None, template=None): - options_sql = '' + options_sql = "" options_params = [] if self.options: # getquoted() returns a quoted bytestring of the adapted value. - options_params.append(', '.join( - '%s=%s' % ( - option, - psycopg2.extensions.adapt(value).getquoted().decode(), - ) for option, value in self.options.items() - )) - options_sql = ', %s' + options_params.append( + ", ".join( + "%s=%s" + % ( + option, + psycopg2.extensions.adapt(value).getquoted().decode(), + ) + for option, value in self.options.items() + ) + ) + options_sql = ", %s" sql, params = super().as_sql( - compiler, connection, function=function, template=template, + compiler, + connection, + function=function, + template=template, options=options_sql, ) return sql, params + options_params @@ -288,7 +333,7 @@ class TrigramBase(Func): output_field = FloatField() def __init__(self, expression, string, **extra): - if not hasattr(string, 'resolve_expression'): + if not hasattr(string, "resolve_expression"): string = Value(string) super().__init__(expression, string, **extra) @@ -297,24 +342,24 @@ class TrigramWordBase(Func): output_field = FloatField() def __init__(self, string, expression, **extra): - if not hasattr(string, 'resolve_expression'): + if not hasattr(string, "resolve_expression"): string = Value(string) super().__init__(string, expression, **extra) class TrigramSimilarity(TrigramBase): - function = 'SIMILARITY' + function = "SIMILARITY" class TrigramDistance(TrigramBase): - function = '' - arg_joiner = ' <-> ' + function = "" + arg_joiner = " <-> " class TrigramWordDistance(TrigramWordBase): - function = '' - arg_joiner = ' <<-> ' + function = "" + arg_joiner = " <<-> " class TrigramWordSimilarity(TrigramWordBase): - function = 'WORD_SIMILARITY' + function = "WORD_SIMILARITY" diff --git a/django/contrib/postgres/serializers.py b/django/contrib/postgres/serializers.py index 1b1c2f1112..d04bfdbc69 100644 --- a/django/contrib/postgres/serializers.py +++ b/django/contrib/postgres/serializers.py @@ -6,5 +6,5 @@ class RangeSerializer(BaseSerializer): module = self.value.__class__.__module__ # Ranges are implemented in psycopg2._range but the public import # location is psycopg2.extras. - module = 'psycopg2.extras' if module == 'psycopg2._range' else module - return '%s.%r' % (module, self.value), {'import %s' % module} + module = "psycopg2.extras" if module == "psycopg2._range" else module + return "%s.%r" % (module, self.value), {"import %s" % module} diff --git a/django/contrib/postgres/signals.py b/django/contrib/postgres/signals.py index 420b5f6033..b61673fe1f 100644 --- a/django/contrib/postgres/signals.py +++ b/django/contrib/postgres/signals.py @@ -35,12 +35,14 @@ def get_citext_oids(connection_alias): def register_type_handlers(connection, **kwargs): - if connection.vendor != 'postgresql' or connection.alias == NO_DB_ALIAS: + if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS: return try: oids, array_oids = get_hstore_oids(connection.alias) - register_hstore(connection.connection, globally=True, oid=oids, array_oid=array_oids) + register_hstore( + connection.connection, globally=True, oid=oids, array_oid=array_oids + ) except ProgrammingError: # Hstore is not available on the database. # @@ -54,7 +56,9 @@ def register_type_handlers(connection, **kwargs): try: citext_oids = get_citext_oids(connection.alias) - array_type = psycopg2.extensions.new_array_type(citext_oids, 'citext[]', psycopg2.STRING) + array_type = psycopg2.extensions.new_array_type( + citext_oids, "citext[]", psycopg2.STRING + ) psycopg2.extensions.register_type(array_type, None) except ProgrammingError: # citext is not available on the database. diff --git a/django/contrib/postgres/utils.py b/django/contrib/postgres/utils.py index f3c022f474..e4f4d81514 100644 --- a/django/contrib/postgres/utils.py +++ b/django/contrib/postgres/utils.py @@ -17,13 +17,13 @@ def prefix_validation_error(error, prefix, code, params): # ngettext calls require a count parameter and are converted # to an empty string if they are missing it. message=format_lazy( - '{} {}', + "{} {}", SimpleLazyObject(lambda: prefix % params), SimpleLazyObject(lambda: error.message % error_params), ), code=code, params={**error_params, **params}, ) - return ValidationError([ - prefix_validation_error(e, prefix, code, params) for e in error.error_list - ]) + return ValidationError( + [prefix_validation_error(e, prefix, code, params) for e in error.error_list] + ) diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py index db6205f356..df2bd88eb9 100644 --- a/django/contrib/postgres/validators.py +++ b/django/contrib/postgres/validators.py @@ -1,24 +1,29 @@ from django.core.exceptions import ValidationError from django.core.validators import ( - MaxLengthValidator, MaxValueValidator, MinLengthValidator, + MaxLengthValidator, + MaxValueValidator, + MinLengthValidator, MinValueValidator, ) from django.utils.deconstruct import deconstructible -from django.utils.translation import gettext_lazy as _, ngettext_lazy +from django.utils.translation import gettext_lazy as _ +from django.utils.translation import ngettext_lazy class ArrayMaxLengthValidator(MaxLengthValidator): message = ngettext_lazy( - 'List contains %(show_value)d item, it should contain no more than %(limit_value)d.', - 'List contains %(show_value)d items, it should contain no more than %(limit_value)d.', - 'limit_value') + "List contains %(show_value)d item, it should contain no more than %(limit_value)d.", + "List contains %(show_value)d items, it should contain no more than %(limit_value)d.", + "limit_value", + ) class ArrayMinLengthValidator(MinLengthValidator): message = ngettext_lazy( - 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.', - 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.', - 'limit_value') + "List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.", + "List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.", + "limit_value", + ) @deconstructible @@ -26,8 +31,8 @@ class KeysValidator: """A validator designed for HStore to require/restrict keys.""" messages = { - 'missing_keys': _('Some keys were missing: %(keys)s'), - 'extra_keys': _('Some unknown keys were provided: %(keys)s'), + "missing_keys": _("Some keys were missing: %(keys)s"), + "extra_keys": _("Some unknown keys were provided: %(keys)s"), } strict = False @@ -42,35 +47,41 @@ class KeysValidator: missing_keys = self.keys - keys if missing_keys: raise ValidationError( - self.messages['missing_keys'], - code='missing_keys', - params={'keys': ', '.join(missing_keys)}, + self.messages["missing_keys"], + code="missing_keys", + params={"keys": ", ".join(missing_keys)}, ) if self.strict: extra_keys = keys - self.keys if extra_keys: raise ValidationError( - self.messages['extra_keys'], - code='extra_keys', - params={'keys': ', '.join(extra_keys)}, + self.messages["extra_keys"], + code="extra_keys", + params={"keys": ", ".join(extra_keys)}, ) def __eq__(self, other): return ( - isinstance(other, self.__class__) and - self.keys == other.keys and - self.messages == other.messages and - self.strict == other.strict + isinstance(other, self.__class__) + and self.keys == other.keys + and self.messages == other.messages + and self.strict == other.strict ) class RangeMaxValueValidator(MaxValueValidator): def compare(self, a, b): return a.upper is None or a.upper > b - message = _('Ensure that this range is completely less than or equal to %(limit_value)s.') + + message = _( + "Ensure that this range is completely less than or equal to %(limit_value)s." + ) class RangeMinValueValidator(MinValueValidator): def compare(self, a, b): return a.lower is None or a.lower < b - message = _('Ensure that this range is completely greater than or equal to %(limit_value)s.') + + message = _( + "Ensure that this range is completely greater than or equal to %(limit_value)s." + ) |
