diff options
| author | django-bot <ops@djangoproject.com> | 2022-02-03 20:24:19 +0100 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-02-07 20:37:05 +0100 |
| commit | 9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch) | |
| tree | f0506b668a013d0063e5fba3dbf4863b466713ba /django/db/backends/postgresql | |
| parent | f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff) | |
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db/backends/postgresql')
| -rw-r--r-- | django/db/backends/postgresql/base.py | 211 | ||||
| -rw-r--r-- | django/db/backends/postgresql/client.py | 48 | ||||
| -rw-r--r-- | django/db/backends/postgresql/creation.py | 47 | ||||
| -rw-r--r-- | django/db/backends/postgresql/features.py | 30 | ||||
| -rw-r--r-- | django/db/backends/postgresql/introspection.py | 130 | ||||
| -rw-r--r-- | django/db/backends/postgresql/operations.py | 155 | ||||
| -rw-r--r-- | django/db/backends/postgresql/schema.py | 204 |
7 files changed, 498 insertions, 327 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index e49d453f94..92f393227e 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -11,11 +11,10 @@ from contextlib import contextmanager from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.db import DatabaseError as WrappedDatabaseError, connections +from django.db import DatabaseError as WrappedDatabaseError +from django.db import connections from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.utils import ( - CursorDebugWrapper as BaseCursorDebugWrapper, -) +from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property from django.utils.safestring import SafeString @@ -30,14 +29,17 @@ except ImportError as e: def psycopg2_version(): - version = psycopg2.__version__.split(' ', 1)[0] + version = psycopg2.__version__.split(" ", 1)[0] return get_version_tuple(version) PSYCOPG2_VERSION = psycopg2_version() if PSYCOPG2_VERSION < (2, 8, 4): - raise ImproperlyConfigured("psycopg2 version 2.8.4 or newer is required; you have %s" % psycopg2.__version__) + raise ImproperlyConfigured( + "psycopg2 version 2.8.4 or newer is required; you have %s" + % psycopg2.__version__ + ) # Some of these import psycopg2, so import them after checking if it's installed. @@ -56,68 +58,68 @@ psycopg2.extras.register_uuid() INETARRAY_OID = 1041 INETARRAY = psycopg2.extensions.new_array_type( (INETARRAY_OID,), - 'INETARRAY', + "INETARRAY", psycopg2.extensions.UNICODE, ) psycopg2.extensions.register_type(INETARRAY) class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'postgresql' - display_name = 'PostgreSQL' + vendor = "postgresql" + display_name = "PostgreSQL" # This dictionary maps Field objects to their associated PostgreSQL column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. # If a column type is set to None, it won't be included in the output. data_types = { - 'AutoField': 'serial', - 'BigAutoField': 'bigserial', - 'BinaryField': 'bytea', - 'BooleanField': 'boolean', - 'CharField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'timestamp with time zone', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'DurationField': 'interval', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'inet', - 'GenericIPAddressField': 'inet', - 'JSONField': 'jsonb', - 'OneToOneField': 'integer', - 'PositiveBigIntegerField': 'bigint', - 'PositiveIntegerField': 'integer', - 'PositiveSmallIntegerField': 'smallint', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallAutoField': 'smallserial', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', - 'UUIDField': 'uuid', + "AutoField": "serial", + "BigAutoField": "bigserial", + "BinaryField": "bytea", + "BooleanField": "boolean", + "CharField": "varchar(%(max_length)s)", + "DateField": "date", + "DateTimeField": "timestamp with time zone", + "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)", + "DurationField": "interval", + "FileField": "varchar(%(max_length)s)", + "FilePathField": "varchar(%(max_length)s)", + "FloatField": "double precision", + "IntegerField": "integer", + "BigIntegerField": "bigint", + "IPAddressField": "inet", + "GenericIPAddressField": "inet", + "JSONField": "jsonb", + "OneToOneField": "integer", + "PositiveBigIntegerField": "bigint", + "PositiveIntegerField": "integer", + "PositiveSmallIntegerField": "smallint", + "SlugField": "varchar(%(max_length)s)", + "SmallAutoField": "smallserial", + "SmallIntegerField": "smallint", + "TextField": "text", + "TimeField": "time", + "UUIDField": "uuid", } data_type_check_constraints = { - 'PositiveBigIntegerField': '"%(column)s" >= 0', - 'PositiveIntegerField': '"%(column)s" >= 0', - 'PositiveSmallIntegerField': '"%(column)s" >= 0', + "PositiveBigIntegerField": '"%(column)s" >= 0', + "PositiveIntegerField": '"%(column)s" >= 0', + "PositiveSmallIntegerField": '"%(column)s" >= 0', } operators = { - 'exact': '= %s', - 'iexact': '= UPPER(%s)', - 'contains': 'LIKE %s', - 'icontains': 'LIKE UPPER(%s)', - 'regex': '~ %s', - 'iregex': '~* %s', - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', - 'startswith': 'LIKE %s', - 'endswith': 'LIKE %s', - 'istartswith': 'LIKE UPPER(%s)', - 'iendswith': 'LIKE UPPER(%s)', + "exact": "= %s", + "iexact": "= UPPER(%s)", + "contains": "LIKE %s", + "icontains": "LIKE UPPER(%s)", + "regex": "~ %s", + "iregex": "~* %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE %s", + "endswith": "LIKE %s", + "istartswith": "LIKE UPPER(%s)", + "iendswith": "LIKE UPPER(%s)", } # The patterns below are used to generate SQL pattern lookup clauses when @@ -128,14 +130,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): # # Note: we use str.format() here for readability as '%' is used as a wildcard for # the LIKE operator. - pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')" + pattern_esc = ( + r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')" + ) pattern_ops = { - 'contains': "LIKE '%%' || {} || '%%'", - 'icontains': "LIKE '%%' || UPPER({}) || '%%'", - 'startswith': "LIKE {} || '%%'", - 'istartswith': "LIKE UPPER({}) || '%%'", - 'endswith': "LIKE '%%' || {}", - 'iendswith': "LIKE '%%' || UPPER({})", + "contains": "LIKE '%%' || {} || '%%'", + "icontains": "LIKE '%%' || UPPER({}) || '%%'", + "startswith": "LIKE {} || '%%'", + "istartswith": "LIKE UPPER({}) || '%%'", + "endswith": "LIKE '%%' || {}", + "iendswith": "LIKE '%%' || UPPER({})", } Database = Database @@ -152,46 +156,46 @@ class DatabaseWrapper(BaseDatabaseWrapper): def get_connection_params(self): settings_dict = self.settings_dict # None may be used to connect to the default 'postgres' db - if ( - settings_dict['NAME'] == '' and - not settings_dict.get('OPTIONS', {}).get('service') + if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get( + "service" ): raise ImproperlyConfigured( "settings.DATABASES is improperly configured. " "Please supply the NAME or OPTIONS['service'] value." ) - if len(settings_dict['NAME'] or '') > self.ops.max_name_length(): + if len(settings_dict["NAME"] or "") > self.ops.max_name_length(): raise ImproperlyConfigured( "The database name '%s' (%d characters) is longer than " "PostgreSQL's limit of %d characters. Supply a shorter NAME " - "in settings.DATABASES." % ( - settings_dict['NAME'], - len(settings_dict['NAME']), + "in settings.DATABASES." + % ( + settings_dict["NAME"], + len(settings_dict["NAME"]), self.ops.max_name_length(), ) ) conn_params = {} - if settings_dict['NAME']: + if settings_dict["NAME"]: conn_params = { - 'database': settings_dict['NAME'], - **settings_dict['OPTIONS'], + "database": settings_dict["NAME"], + **settings_dict["OPTIONS"], } - elif settings_dict['NAME'] is None: + elif settings_dict["NAME"] is None: # Connect to the default 'postgres' db. - settings_dict.get('OPTIONS', {}).pop('service', None) - conn_params = {'database': 'postgres', **settings_dict['OPTIONS']} + settings_dict.get("OPTIONS", {}).pop("service", None) + conn_params = {"database": "postgres", **settings_dict["OPTIONS"]} else: - conn_params = {**settings_dict['OPTIONS']} + conn_params = {**settings_dict["OPTIONS"]} - conn_params.pop('isolation_level', None) - if settings_dict['USER']: - conn_params['user'] = settings_dict['USER'] - if settings_dict['PASSWORD']: - conn_params['password'] = settings_dict['PASSWORD'] - if settings_dict['HOST']: - conn_params['host'] = settings_dict['HOST'] - if settings_dict['PORT']: - conn_params['port'] = settings_dict['PORT'] + conn_params.pop("isolation_level", None) + if settings_dict["USER"]: + conn_params["user"] = settings_dict["USER"] + if settings_dict["PASSWORD"]: + conn_params["password"] = settings_dict["PASSWORD"] + if settings_dict["HOST"]: + conn_params["host"] = settings_dict["HOST"] + if settings_dict["PORT"]: + conn_params["port"] = settings_dict["PORT"] return conn_params @async_unsafe @@ -203,9 +207,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): # default when no value is explicitly specified in options. # - before calling _set_autocommit() because if autocommit is on, that # will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT. - options = self.settings_dict['OPTIONS'] + options = self.settings_dict["OPTIONS"] try: - self.isolation_level = options['isolation_level'] + self.isolation_level = options["isolation_level"] except KeyError: self.isolation_level = connection.isolation_level else: @@ -215,13 +219,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): # Register dummy loads() to avoid a round trip from psycopg2's decode # to json.dumps() to json.loads(), when using a custom decoder in # JSONField. - psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x) + psycopg2.extras.register_default_jsonb( + conn_or_curs=connection, loads=lambda x: x + ) return connection def ensure_timezone(self): if self.connection is None: return False - conn_timezone_name = self.connection.get_parameter_status('TimeZone') + conn_timezone_name = self.connection.get_parameter_status("TimeZone") timezone_name = self.timezone_name if timezone_name and conn_timezone_name != timezone_name: with self.connection.cursor() as cursor: @@ -230,7 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): return False def init_connection_state(self): - self.connection.set_client_encoding('UTF8') + self.connection.set_client_encoding("UTF8") timezone_changed = self.ensure_timezone() if timezone_changed: @@ -243,7 +249,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): if name: # In autocommit mode, the cursor will be used outside of a # transaction, hence use a holdable cursor. - cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit) + cursor = self.connection.cursor( + name, scrollable=False, withhold=self.connection.autocommit + ) else: cursor = self.connection.cursor() cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None @@ -268,10 +276,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): if current_task: task_ident = str(id(current_task)) else: - task_ident = 'sync' + task_ident = "sync" # Use that and the thread ident to get a unique name return self._cursor( - name='_django_curs_%d_%s_%d' % ( + name="_django_curs_%d_%s_%d" + % ( # Avoid reusing name in other threads / tasks threading.current_thread().ident, task_ident, @@ -289,14 +298,14 @@ class DatabaseWrapper(BaseDatabaseWrapper): afterward. """ with self.cursor() as cursor: - cursor.execute('SET CONSTRAINTS ALL IMMEDIATE') - cursor.execute('SET CONSTRAINTS ALL DEFERRED') + cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") + cursor.execute("SET CONSTRAINTS ALL DEFERRED") def is_usable(self): try: # Use a psycopg cursor directly, bypassing Django's utilities. with self.connection.cursor() as cursor: - cursor.execute('SELECT 1') + cursor.execute("SELECT 1") except Database.Error: return False else: @@ -317,12 +326,18 @@ class DatabaseWrapper(BaseDatabaseWrapper): "database when it's not needed (for example, when running tests). " "Django was unable to create a connection to the 'postgres' database " "and will use the first PostgreSQL database instead.", - RuntimeWarning + RuntimeWarning, ) for connection in connections.all(): - if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres': + if ( + connection.vendor == "postgresql" + and connection.settings_dict["NAME"] != "postgres" + ): conn = self.__class__( - {**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, + { + **self.settings_dict, + "NAME": connection.settings_dict["NAME"], + }, alias=self.alias, ) try: @@ -349,5 +364,5 @@ class CursorDebugWrapper(BaseCursorDebugWrapper): return self.cursor.copy_expert(sql, file, *args) def copy_to(self, file, table, *args, **kwargs): - with self.debug_sql(sql='COPY %s TO STDOUT' % table): + with self.debug_sql(sql="COPY %s TO STDOUT" % table): return self.cursor.copy_to(file, table, *args, **kwargs) diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index 0effcc44e6..4c9bd63546 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -4,53 +4,53 @@ from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): - executable_name = 'psql' + executable_name = "psql" @classmethod def settings_to_cmd_args_env(cls, settings_dict, parameters): args = [cls.executable_name] - options = settings_dict.get('OPTIONS', {}) + options = settings_dict.get("OPTIONS", {}) - host = settings_dict.get('HOST') - port = settings_dict.get('PORT') - dbname = settings_dict.get('NAME') - user = settings_dict.get('USER') - passwd = settings_dict.get('PASSWORD') - passfile = options.get('passfile') - service = options.get('service') - sslmode = options.get('sslmode') - sslrootcert = options.get('sslrootcert') - sslcert = options.get('sslcert') - sslkey = options.get('sslkey') + host = settings_dict.get("HOST") + port = settings_dict.get("PORT") + dbname = settings_dict.get("NAME") + user = settings_dict.get("USER") + passwd = settings_dict.get("PASSWORD") + passfile = options.get("passfile") + service = options.get("service") + sslmode = options.get("sslmode") + sslrootcert = options.get("sslrootcert") + sslcert = options.get("sslcert") + sslkey = options.get("sslkey") if not dbname and not service: # Connect to the default 'postgres' db. - dbname = 'postgres' + dbname = "postgres" if user: - args += ['-U', user] + args += ["-U", user] if host: - args += ['-h', host] + args += ["-h", host] if port: - args += ['-p', str(port)] + args += ["-p", str(port)] if dbname: args += [dbname] args.extend(parameters) env = {} if passwd: - env['PGPASSWORD'] = str(passwd) + env["PGPASSWORD"] = str(passwd) if service: - env['PGSERVICE'] = str(service) + env["PGSERVICE"] = str(service) if sslmode: - env['PGSSLMODE'] = str(sslmode) + env["PGSSLMODE"] = str(sslmode) if sslrootcert: - env['PGSSLROOTCERT'] = str(sslrootcert) + env["PGSSLROOTCERT"] = str(sslrootcert) if sslcert: - env['PGSSLCERT'] = str(sslcert) + env["PGSSLCERT"] = str(sslcert) if sslkey: - env['PGSSLKEY'] = str(sslkey) + env["PGSSLKEY"] = str(sslkey) if passfile: - env['PGPASSFILE'] = str(passfile) + env["PGPASSFILE"] = str(passfile) return args, (env or None) def runshell(self, parameters): diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py index eb8ac3bcf5..70c3eda566 100644 --- a/django/db/backends/postgresql/creation.py +++ b/django/db/backends/postgresql/creation.py @@ -8,7 +8,6 @@ from django.db.backends.utils import strip_quotes class DatabaseCreation(BaseDatabaseCreation): - def _quote_name(self, name): return self.connection.ops.quote_name(name) @@ -21,32 +20,35 @@ class DatabaseCreation(BaseDatabaseCreation): return suffix and "WITH" + suffix def sql_table_creation_suffix(self): - test_settings = self.connection.settings_dict['TEST'] - if test_settings.get('COLLATION') is not None: + test_settings = self.connection.settings_dict["TEST"] + if test_settings.get("COLLATION") is not None: raise ImproperlyConfigured( - 'PostgreSQL does not support collation setting at database ' - 'creation time.' + "PostgreSQL does not support collation setting at database " + "creation time." ) return self._get_database_create_suffix( - encoding=test_settings['CHARSET'], - template=test_settings.get('TEMPLATE'), + encoding=test_settings["CHARSET"], + template=test_settings.get("TEMPLATE"), ) def _database_exists(self, cursor, database_name): - cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)]) + cursor.execute( + "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", + [strip_quotes(database_name)], + ) return cursor.fetchone() is not None def _execute_create_test_db(self, cursor, parameters, keepdb=False): try: - if keepdb and self._database_exists(cursor, parameters['dbname']): + if keepdb and self._database_exists(cursor, parameters["dbname"]): # If the database should be kept and it already exists, don't # try to create a new one. return super()._execute_create_test_db(cursor, parameters, keepdb) except Exception as e: - if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE: + if getattr(e.__cause__, "pgcode", "") != errorcodes.DUPLICATE_DATABASE: # All errors except "database already exists" cancel tests. - self.log('Got an error creating the test database: %s' % e) + self.log("Got an error creating the test database: %s" % e) sys.exit(2) elif not keepdb: # If the database should be kept, ignore "database already @@ -58,11 +60,11 @@ class DatabaseCreation(BaseDatabaseCreation): # to the template database. self.connection.close() - source_database_name = self.connection.settings_dict['NAME'] - target_database_name = self.get_test_db_clone_settings(suffix)['NAME'] + source_database_name = self.connection.settings_dict["NAME"] + target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] test_db_params = { - 'dbname': self._quote_name(target_database_name), - 'suffix': self._get_database_create_suffix(template=source_database_name), + "dbname": self._quote_name(target_database_name), + "suffix": self._get_database_create_suffix(template=source_database_name), } with self._nodb_cursor() as cursor: try: @@ -70,11 +72,16 @@ class DatabaseCreation(BaseDatabaseCreation): except Exception: try: if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, target_database_name), - )) - cursor.execute('DROP DATABASE %(dbname)s' % test_db_params) + self.log( + "Destroying old test database for alias %s..." + % ( + self._get_database_display_str( + verbosity, target_database_name + ), + ) + ) + cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) self._execute_create_test_db(cursor, test_db_params, keepdb) except Exception as e: - self.log('Got an error cloning the test database: %s' % e) + self.log("Got an error cloning the test database: %s" % e) sys.exit(2) diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 1ce73fb0a8..182c230c75 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -52,7 +52,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_over_clause = True only_supports_unbounded_with_preceding_and_following = True supports_aggregate_filter_clause = True - supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'} + supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"} validates_explain_options = False # A query will error on invalid options. supports_deferrable_unique_constraints = True has_json_operators = True @@ -60,14 +60,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_update_conflicts = True supports_update_conflicts_with_target = True test_collations = { - 'non_default': 'sv-x-icu', - 'swedish_ci': 'sv-x-icu', + "non_default": "sv-x-icu", + "swedish_ci": "sv-x-icu", } test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'" django_test_skips = { - 'opclasses are PostgreSQL only.': { - 'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses', + "opclasses are PostgreSQL only.": { + "indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses", }, } @@ -75,9 +75,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): def introspected_field_types(self): return { **super().introspected_field_types, - 'PositiveBigIntegerField': 'BigIntegerField', - 'PositiveIntegerField': 'IntegerField', - 'PositiveSmallIntegerField': 'SmallIntegerField', + "PositiveBigIntegerField": "BigIntegerField", + "PositiveIntegerField": "IntegerField", + "PositiveSmallIntegerField": "SmallIntegerField", } @cached_property @@ -96,9 +96,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): def is_postgresql_14(self): return self.connection.pg_version >= 140000 - has_bit_xor = property(operator.attrgetter('is_postgresql_14')) - has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11')) - supports_covering_indexes = property(operator.attrgetter('is_postgresql_11')) - supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12')) - supports_covering_spgist_indexes = property(operator.attrgetter('is_postgresql_14')) - supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12')) + has_bit_xor = property(operator.attrgetter("is_postgresql_14")) + has_websearch_to_tsquery = property(operator.attrgetter("is_postgresql_11")) + supports_covering_indexes = property(operator.attrgetter("is_postgresql_11")) + supports_covering_gist_indexes = property(operator.attrgetter("is_postgresql_12")) + supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14")) + supports_non_deterministic_collations = property( + operator.attrgetter("is_postgresql_12") + ) diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index f31d906a2f..a7e9a13d61 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -1,5 +1,7 @@ from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo, TableInfo, + BaseDatabaseIntrospection, + FieldInfo, + TableInfo, ) from django.db.models import Index @@ -7,46 +9,47 @@ from django.db.models import Index class DatabaseIntrospection(BaseDatabaseIntrospection): # Maps type codes to Django Field types. data_types_reverse = { - 16: 'BooleanField', - 17: 'BinaryField', - 20: 'BigIntegerField', - 21: 'SmallIntegerField', - 23: 'IntegerField', - 25: 'TextField', - 700: 'FloatField', - 701: 'FloatField', - 869: 'GenericIPAddressField', - 1042: 'CharField', # blank-padded - 1043: 'CharField', - 1082: 'DateField', - 1083: 'TimeField', - 1114: 'DateTimeField', - 1184: 'DateTimeField', - 1186: 'DurationField', - 1266: 'TimeField', - 1700: 'DecimalField', - 2950: 'UUIDField', - 3802: 'JSONField', + 16: "BooleanField", + 17: "BinaryField", + 20: "BigIntegerField", + 21: "SmallIntegerField", + 23: "IntegerField", + 25: "TextField", + 700: "FloatField", + 701: "FloatField", + 869: "GenericIPAddressField", + 1042: "CharField", # blank-padded + 1043: "CharField", + 1082: "DateField", + 1083: "TimeField", + 1114: "DateTimeField", + 1184: "DateTimeField", + 1186: "DurationField", + 1266: "TimeField", + 1700: "DecimalField", + 2950: "UUIDField", + 3802: "JSONField", } # A hook for subclasses. - index_default_access_method = 'btree' + index_default_access_method = "btree" ignored_tables = [] def get_field_type(self, data_type, description): field_type = super().get_field_type(data_type, description) - if description.default and 'nextval' in description.default: - if field_type == 'IntegerField': - return 'AutoField' - elif field_type == 'BigIntegerField': - return 'BigAutoField' - elif field_type == 'SmallIntegerField': - return 'SmallAutoField' + if description.default and "nextval" in description.default: + if field_type == "IntegerField": + return "AutoField" + elif field_type == "BigIntegerField": + return "BigAutoField" + elif field_type == "SmallIntegerField": + return "SmallAutoField" return field_type def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" - cursor.execute(""" + cursor.execute( + """ SELECT c.relname, CASE WHEN c.relispartition THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END FROM pg_catalog.pg_class c @@ -54,8 +57,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v') AND n.nspname NOT IN ('pg_catalog', 'pg_toast') AND pg_catalog.pg_table_is_visible(c.oid) - """) - return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables] + """ + ) + return [ + TableInfo(*row) + for row in cursor.fetchall() + if row[0] not in self.ignored_tables + ] def get_table_description(self, cursor, table_name): """ @@ -65,7 +73,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Query the pg_catalog tables as cursor.description does not reliably # return the nullable property and information_schema.columns does not # contain details of materialized views. - cursor.execute(""" + cursor.execute( + """ SELECT a.attname AS column_name, NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable, @@ -81,9 +90,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): AND c.relname = %s AND n.nspname NOT IN ('pg_catalog', 'pg_toast') AND pg_catalog.pg_table_is_visible(c.oid) - """, [table_name]) + """, + [table_name], + ) field_map = {line[0]: line[1:] for line in cursor.fetchall()} - cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name) + ) return [ FieldInfo( line.name, @@ -98,7 +111,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ] def get_sequences(self, cursor, table_name, table_fields=()): - cursor.execute(""" + cursor.execute( + """ SELECT s.relname as sequence_name, col.attname FROM pg_class s JOIN pg_namespace sn ON sn.oid = s.relnamespace @@ -110,9 +124,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): AND d.deptype in ('a', 'n') AND pg_catalog.pg_table_is_visible(tbl.oid) AND tbl.relname = %s - """, [table_name]) + """, + [table_name], + ) return [ - {'name': row[0], 'table': table_name, 'column': row[1]} + {"name": row[0], "table": table_name, "column": row[1]} for row in cursor.fetchall() ] @@ -121,7 +137,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a dictionary of {field_name: (field_name_other_table, other_table)} representing all foreign keys in the given table. """ - cursor.execute(""" + cursor.execute( + """ SELECT a1.attname, c2.relname, a2.attname FROM pg_constraint con LEFT JOIN pg_class c1 ON con.conrelid = c1.oid @@ -133,7 +150,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): con.contype = 'f' AND c1.relnamespace = c2.relnamespace AND pg_catalog.pg_table_is_visible(c1.oid) - """, [table_name]) + """, + [table_name], + ) return {row[0]: (row[2], row[1]) for row in cursor.fetchall()} def get_constraints(self, cursor, table_name): @@ -146,7 +165,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Loop over the key table, collecting things as constraints. The column # array must return column names in the same order in which they were # created. - cursor.execute(""" + cursor.execute( + """ SELECT c.conname, array( @@ -165,7 +185,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): FROM pg_constraint AS c JOIN pg_class AS cl ON c.conrelid = cl.oid WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid) - """, [table_name]) + """, + [table_name], + ) for constraint, columns, kind, used_cols, options in cursor.fetchall(): constraints[constraint] = { "columns": columns, @@ -178,7 +200,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "options": options, } # Now get indexes - cursor.execute(""" + cursor.execute( + """ SELECT indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary, array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions @@ -207,14 +230,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid) ) s2 GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions; - """, [self.index_default_access_method, table_name]) - for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall(): + """, + [self.index_default_access_method, table_name], + ) + for ( + index, + columns, + unique, + primary, + orders, + type_, + definition, + options, + ) in cursor.fetchall(): if index not in constraints: basic_index = ( - type_ == self.index_default_access_method and + type_ == self.index_default_access_method + and # '_btree' references # django.contrib.postgres.indexes.BTreeIndex.suffix. - not index.endswith('_btree') and options is None + not index.endswith("_btree") + and options is None ) constraints[index] = { "columns": columns if columns != [None] else [], diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 762cd8d23e..68448157ec 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -7,17 +7,22 @@ from django.db.models.constants import OnConflict class DatabaseOperations(BaseDatabaseOperations): - cast_char_field_without_max_length = 'varchar' - explain_prefix = 'EXPLAIN' + cast_char_field_without_max_length = "varchar" + explain_prefix = "EXPLAIN" cast_data_types = { - 'AutoField': 'integer', - 'BigAutoField': 'bigint', - 'SmallAutoField': 'smallint', + "AutoField": "integer", + "BigAutoField": "bigint", + "SmallAutoField": "smallint", } def unification_cast_sql(self, output_field): internal_type = output_field.get_internal_type() - if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"): + if internal_type in ( + "GenericIPAddressField", + "IPAddressField", + "TimeField", + "UUIDField", + ): # PostgreSQL will resolve a union as type 'text' if input types are # 'unknown'. # https://www.postgresql.org/docs/current/typeconv-union-case.html @@ -25,17 +30,19 @@ class DatabaseOperations(BaseDatabaseOperations): # PostgreSQL configuration so we need to explicitly cast them. # We must also remove components of the type within brackets: # varchar(255) -> varchar. - return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0] - return '%s' + return ( + "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0] + ) + return "%s" def date_extract_sql(self, lookup_type, field_name): # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT - if lookup_type == 'week_day': + if lookup_type == "week_day": # For consistency across backends, we return Sunday=1, Saturday=7. return "EXTRACT('dow' FROM %s) + 1" % field_name - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": return "EXTRACT('isodow' FROM %s)" % field_name - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": return "EXTRACT('isoyear' FROM %s)" % field_name else: return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) @@ -48,22 +55,25 @@ class DatabaseOperations(BaseDatabaseOperations): def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) if offset: - sign = '-' if sign == '+' else '+' - return f'{tzname}{sign}{offset}' + sign = "-" if sign == "+" else "+" + return f"{tzname}{sign}{offset}" return tzname def _convert_field_to_tz(self, field_name, tzname): if tzname and settings.USE_TZ: - field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname)) + field_name = "%s AT TIME ZONE '%s'" % ( + field_name, + self._prepare_tzname_delta(tzname), + ) return field_name def datetime_cast_date_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return '(%s)::date' % field_name + return "(%s)::date" % field_name def datetime_cast_time_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return '(%s)::time' % field_name + return "(%s)::time" % field_name def datetime_extract_sql(self, lookup_type, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) @@ -89,21 +99,30 @@ class DatabaseOperations(BaseDatabaseOperations): return cursor.fetchall() def lookup_cast(self, lookup_type, internal_type=None): - lookup = '%s' + lookup = "%s" # Cast text lookups to text to allow things like filter(x__contains=4) - if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): - if internal_type in ('IPAddressField', 'GenericIPAddressField'): + if lookup_type in ( + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", + ): + if internal_type in ("IPAddressField", "GenericIPAddressField"): lookup = "HOST(%s)" - elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'): - lookup = '%s::citext' + elif internal_type in ("CICharField", "CIEmailField", "CITextField"): + lookup = "%s::citext" else: lookup = "%s::text" # Use UPPER(x) for case-insensitive lookups; it's faster. - if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): - lookup = 'UPPER(%s)' % lookup + if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): + lookup = "UPPER(%s)" % lookup return lookup @@ -128,29 +147,32 @@ class DatabaseOperations(BaseDatabaseOperations): # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us # to truncate tables referenced by a foreign key in any other table. sql_parts = [ - style.SQL_KEYWORD('TRUNCATE'), - ', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables), + style.SQL_KEYWORD("TRUNCATE"), + ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables), ] if reset_sequences: - sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY')) + sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY")) if allow_cascade: - sql_parts.append(style.SQL_KEYWORD('CASCADE')) - return ['%s;' % ' '.join(sql_parts)] + sql_parts.append(style.SQL_KEYWORD("CASCADE")) + return ["%s;" % " ".join(sql_parts)] def sequence_reset_by_name_sql(self, style, sequences): # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements # to reset sequence indices sql = [] for sequence_info in sequences: - table_name = sequence_info['table'] + table_name = sequence_info["table"] # 'id' will be the case if it's an m2m using an autogenerated # intermediate table (see BaseDatabaseIntrospection.sequence_list). - column_name = sequence_info['column'] or 'id' - sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % ( - style.SQL_KEYWORD('SELECT'), - style.SQL_TABLE(self.quote_name(table_name)), - style.SQL_FIELD(column_name), - )) + column_name = sequence_info["column"] or "id" + sql.append( + "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" + % ( + style.SQL_KEYWORD("SELECT"), + style.SQL_TABLE(self.quote_name(table_name)), + style.SQL_FIELD(column_name), + ) + ) return sql def tablespace_sql(self, tablespace, inline=False): @@ -161,6 +183,7 @@ class DatabaseOperations(BaseDatabaseOperations): def sequence_reset_sql(self, style, model_list): from django.db import models + output = [] qn = self.quote_name for model in model_list: @@ -174,14 +197,15 @@ class DatabaseOperations(BaseDatabaseOperations): if isinstance(f, models.AutoField): output.append( "%s setval(pg_get_serial_sequence('%s','%s'), " - "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % ( - style.SQL_KEYWORD('SELECT'), + "coalesce(max(%s), 1), max(%s) %s null) %s %s;" + % ( + style.SQL_KEYWORD("SELECT"), style.SQL_TABLE(qn(model._meta.db_table)), style.SQL_FIELD(f.column), style.SQL_FIELD(qn(f.column)), style.SQL_FIELD(qn(f.column)), - style.SQL_KEYWORD('IS NOT'), - style.SQL_KEYWORD('FROM'), + style.SQL_KEYWORD("IS NOT"), + style.SQL_KEYWORD("FROM"), style.SQL_TABLE(qn(model._meta.db_table)), ) ) @@ -207,9 +231,9 @@ class DatabaseOperations(BaseDatabaseOperations): def distinct_sql(self, fields, params): if fields: params = [param for param_list in params for param in param_list] - return (['DISTINCT ON (%s)' % ', '.join(fields)], params) + return (["DISTINCT ON (%s)" % ", ".join(fields)], params) else: - return ['DISTINCT'], [] + return ["DISTINCT"], [] def last_executed_query(self, cursor, sql, params): # https://www.psycopg.org/docs/cursor.html#cursor.query @@ -220,14 +244,16 @@ class DatabaseOperations(BaseDatabaseOperations): def return_insert_columns(self, fields): if not fields: - return '', () + return "", () columns = [ - '%s.%s' % ( + "%s.%s" + % ( self.quote_name(field.model._meta.db_table), self.quote_name(field.column), - ) for field in fields + ) + for field in fields ] - return 'RETURNING %s' % ', '.join(columns), () + return "RETURNING %s" % ", ".join(columns), () def bulk_insert_sql(self, fields, placeholder_rows): placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) @@ -252,7 +278,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None def subtract_temporals(self, internal_type, lhs, rhs): - if internal_type == 'DateField': + if internal_type == "DateField": lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs params = (*lhs_params, *rhs_params) @@ -263,27 +289,34 @@ class DatabaseOperations(BaseDatabaseOperations): prefix = super().explain_query_prefix(format) extra = {} if format: - extra['FORMAT'] = format + extra["FORMAT"] = format if options: - extra.update({ - name.upper(): 'true' if value else 'false' - for name, value in options.items() - }) + extra.update( + { + name.upper(): "true" if value else "false" + for name, value in options.items() + } + ) if extra: - prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items()) + prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items()) return prefix def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): if on_conflict == OnConflict.IGNORE: - return 'ON CONFLICT DO NOTHING' + return "ON CONFLICT DO NOTHING" if on_conflict == OnConflict.UPDATE: - return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( - ', '.join(map(self.quote_name, unique_fields)), - ', '.join([ - f'{field} = EXCLUDED.{field}' - for field in map(self.quote_name, update_fields) - ]), + return "ON CONFLICT(%s) DO UPDATE SET %s" % ( + ", ".join(map(self.quote_name, unique_fields)), + ", ".join( + [ + f"{field} = EXCLUDED.{field}" + for field in map(self.quote_name, update_fields) + ] + ), ) return super().on_conflict_suffix_sql( - fields, on_conflict, update_fields, unique_fields, + fields, + on_conflict, + update_fields, + unique_fields, ) diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py index f3b5baecbe..47e9a6a8f3 100644 --- a/django/db/backends/postgresql/schema.py +++ b/django/db/backends/postgresql/schema.py @@ -9,16 +9,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_create_sequence = "CREATE SEQUENCE %(sequence)s" sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" - sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" - sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s' + sql_set_sequence_max = ( + "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" + ) + sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s" sql_create_index = ( - 'CREATE INDEX %(name)s ON %(table)s%(using)s ' - '(%(columns)s)%(include)s%(extra)s%(condition)s' + "CREATE INDEX %(name)s ON %(table)s%(using)s " + "(%(columns)s)%(include)s%(extra)s%(condition)s" ) sql_create_index_concurrently = ( - 'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s ' - '(%(columns)s)%(include)s%(extra)s%(condition)s' + "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s " + "(%(columns)s)%(include)s%(extra)s%(condition)s" ) sql_delete_index = "DROP INDEX IF EXISTS %(name)s" sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s" @@ -26,21 +28,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Setting the constraint to IMMEDIATE to allow changing data in the same # transaction. sql_create_column_inline_fk = ( - 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s' - '; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE' + "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s" + "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE" ) # Setting the constraint to IMMEDIATE runs any deferred checks to allow # dropping it in the same transaction. sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" - sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)' + sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)" def quote_value(self, value): if isinstance(value, str): - value = value.replace('%', '%%') + value = value.replace("%", "%%") adapted = psycopg2.extensions.adapt(value) - if hasattr(adapted, 'encoding'): - adapted.encoding = 'utf8' + if hasattr(adapted, "encoding"): + adapted.encoding = "utf8" # getquoted() returns a quoted bytestring of the adapted value. return adapted.getquoted().decode() @@ -61,7 +63,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _field_base_data_types(self, field): # Yield base data types for array fields. - if field.base_field.get_internal_type() == 'ArrayField': + if field.base_field.get_internal_type() == "ArrayField": yield from self._field_base_data_types(field.base_field) else: yield self._field_data_type(field.base_field) @@ -80,45 +82,52 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # # The same doesn't apply to array fields such as varchar[size] # and text[size], so skip them. - if '[' in db_type: + if "[" in db_type: return None - if db_type.startswith('varchar'): + if db_type.startswith("varchar"): return self._create_index_sql( model, fields=[field], - suffix='_like', - opclasses=['varchar_pattern_ops'], + suffix="_like", + opclasses=["varchar_pattern_ops"], ) - elif db_type.startswith('text'): + elif db_type.startswith("text"): return self._create_index_sql( model, fields=[field], - suffix='_like', - opclasses=['text_pattern_ops'], + suffix="_like", + opclasses=["text_pattern_ops"], ) return None def _alter_column_type_sql(self, model, old_field, new_field, new_type): - self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s' + self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s" # Cast when data type changed. - using_sql = ' USING %(column)s::%(type)s' + using_sql = " USING %(column)s::%(type)s" new_internal_type = new_field.get_internal_type() old_internal_type = old_field.get_internal_type() - if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type: + if new_internal_type == "ArrayField" and new_internal_type == old_internal_type: # Compare base data types for array fields. - if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)): + if list(self._field_base_data_types(old_field)) != list( + self._field_base_data_types(new_field) + ): self.sql_alter_column_type += using_sql elif self._field_data_type(old_field) != self._field_data_type(new_field): self.sql_alter_column_type += using_sql # Make ALTER TYPE with SERIAL make sense. table = strip_quotes(model._meta.db_table) - serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'} + serial_fields_map = { + "bigserial": "bigint", + "serial": "integer", + "smallserial": "smallint", + } if new_type.lower() in serial_fields_map: column = strip_quotes(new_field.column) sequence_name = "%s_%s_seq" % (table, column) return ( ( - self.sql_alter_column_type % { + self.sql_alter_column_type + % { "column": self.quote_name(column), "type": serial_fields_map[new_type.lower()], }, @@ -126,29 +135,35 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): ), [ ( - self.sql_delete_sequence % { + self.sql_delete_sequence + % { "sequence": self.quote_name(sequence_name), }, [], ), ( - self.sql_create_sequence % { + self.sql_create_sequence + % { "sequence": self.quote_name(sequence_name), }, [], ), ( - self.sql_alter_column % { + self.sql_alter_column + % { "table": self.quote_name(table), - "changes": self.sql_alter_column_default % { + "changes": self.sql_alter_column_default + % { "column": self.quote_name(column), - "default": "nextval('%s')" % self.quote_name(sequence_name), - } + "default": "nextval('%s')" + % self.quote_name(sequence_name), + }, }, [], ), ( - self.sql_set_sequence_max % { + self.sql_set_sequence_max + % { "table": self.quote_name(table), "column": self.quote_name(column), "sequence": self.quote_name(sequence_name), @@ -156,24 +171,31 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): [], ), ( - self.sql_set_sequence_owner % { - 'table': self.quote_name(table), - 'column': self.quote_name(column), - 'sequence': self.quote_name(sequence_name), + self.sql_set_sequence_owner + % { + "table": self.quote_name(table), + "column": self.quote_name(column), + "sequence": self.quote_name(sequence_name), }, [], ), ], ) - elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map: + elif ( + old_field.db_parameters(connection=self.connection)["type"] + in serial_fields_map + ): # Drop the sequence if migrating away from AutoField. column = strip_quotes(new_field.column) - sequence_name = '%s_%s_seq' % (table, column) - fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type) + sequence_name = "%s_%s_seq" % (table, column) + fragment, _ = super()._alter_column_type_sql( + model, old_field, new_field, new_type + ) return fragment, [ ( - self.sql_delete_sequence % { - 'sequence': self.quote_name(sequence_name), + self.sql_delete_sequence + % { + "sequence": self.quote_name(sequence_name), }, [], ), @@ -181,58 +203,114 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): else: return super()._alter_column_type_sql(model, old_field, new_field, new_type) - def _alter_field(self, model, old_field, new_field, old_type, new_type, - old_db_params, new_db_params, strict=False): + def _alter_field( + self, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict=False, + ): # Drop indexes on varchar/text/citext columns that are changing to a # different type. if (old_field.db_index or old_field.unique) and ( - (old_type.startswith('varchar') and not new_type.startswith('varchar')) or - (old_type.startswith('text') and not new_type.startswith('text')) or - (old_type.startswith('citext') and not new_type.startswith('citext')) + (old_type.startswith("varchar") and not new_type.startswith("varchar")) + or (old_type.startswith("text") and not new_type.startswith("text")) + or (old_type.startswith("citext") and not new_type.startswith("citext")) ): - index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like') + index_name = self._create_index_name( + model._meta.db_table, [old_field.column], suffix="_like" + ) self.execute(self._delete_index_sql(model, index_name)) super()._alter_field( - model, old_field, new_field, old_type, new_type, old_db_params, - new_db_params, strict, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict, ) # Added an index? Create any PostgreSQL-specific indexes. - if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or - (not old_field.unique and new_field.unique)): + if (not (old_field.db_index or old_field.unique) and new_field.db_index) or ( + not old_field.unique and new_field.unique + ): like_index_statement = self._create_like_index_sql(model, new_field) if like_index_statement is not None: self.execute(like_index_statement) # Removed an index? Drop any PostgreSQL-specific indexes. if old_field.unique and not (new_field.db_index or new_field.unique): - index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like') + index_to_remove = self._create_index_name( + model._meta.db_table, [old_field.column], suffix="_like" + ) self.execute(self._delete_index_sql(model, index_to_remove)) def _index_columns(self, table, columns, col_suffixes, opclasses): if opclasses: - return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses) + return IndexColumns( + table, + columns, + self.quote_name, + col_suffixes=col_suffixes, + opclasses=opclasses, + ) return super()._index_columns(table, columns, col_suffixes, opclasses) def add_index(self, model, index, concurrently=False): - self.execute(index.create_sql(model, self, concurrently=concurrently), params=None) + self.execute( + index.create_sql(model, self, concurrently=concurrently), params=None + ) def remove_index(self, model, index, concurrently=False): self.execute(index.remove_sql(model, self, concurrently=concurrently)) def _delete_index_sql(self, model, name, sql=None, concurrently=False): - sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index + sql = ( + self.sql_delete_index_concurrently + if concurrently + else self.sql_delete_index + ) return super()._delete_index_sql(model, name, sql) def _create_index_sql( - self, model, *, fields=None, name=None, suffix='', using='', - db_tablespace=None, col_suffixes=(), sql=None, opclasses=(), - condition=None, concurrently=False, include=None, expressions=None, + self, + model, + *, + fields=None, + name=None, + suffix="", + using="", + db_tablespace=None, + col_suffixes=(), + sql=None, + opclasses=(), + condition=None, + concurrently=False, + include=None, + expressions=None, ): - sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently + sql = ( + self.sql_create_index + if not concurrently + else self.sql_create_index_concurrently + ) return super()._create_index_sql( - model, fields=fields, name=name, suffix=suffix, using=using, - db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql, - opclasses=opclasses, condition=condition, include=include, + model, + fields=fields, + name=name, + suffix=suffix, + using=using, + db_tablespace=db_tablespace, + col_suffixes=col_suffixes, + sql=sql, + opclasses=opclasses, + condition=condition, + include=include, expressions=expressions, ) |
