author     django-bot <ops@djangoproject.com>    2022-02-03 20:24:19 +0100
committer  Mariusz Felisiak <felisiak.mariusz@gmail.com>    2022-02-07 20:37:05 +0100
commit     9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
tree       f0506b668a013d0063e5fba3dbf4863b466713ba /django/db/backends/postgresql
parent     f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db/backends/postgresql')
-rw-r--r--  django/db/backends/postgresql/base.py            211
-rw-r--r--  django/db/backends/postgresql/client.py           48
-rw-r--r--  django/db/backends/postgresql/creation.py         47
-rw-r--r--  django/db/backends/postgresql/features.py         30
-rw-r--r--  django/db/backends/postgresql/introspection.py    130
-rw-r--r--  django/db/backends/postgresql/operations.py       155
-rw-r--r--  django/db/backends/postgresql/schema.py           204
7 files changed, 498 insertions, 327 deletions
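
Every hunk below applies the same two Black conventions: string literals are normalized to double quotes (e.g. 'postgresql' becomes "postgresql") and statements longer than Black's default 88-character line length are rewrapped. As a quick orientation before the full diff, the version check near the top of base.py is a typical wrapping case (this snippet simply restates the first hunk):

    # Before: one statement on a single over-long line.
    raise ImproperlyConfigured("psycopg2 version 2.8.4 or newer is required; you have %s" % psycopg2.__version__)

    # After: Black splits the call so each line fits within 88 characters.
    raise ImproperlyConfigured(
        "psycopg2 version 2.8.4 or newer is required; you have %s"
        % psycopg2.__version__
    )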
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index e49d453f94..92f393227e 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -11,11 +11,10 @@ from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
-from django.db import DatabaseError as WrappedDatabaseError, connections
+from django.db import DatabaseError as WrappedDatabaseError
+from django.db import connections
from django.db.backends.base.base import BaseDatabaseWrapper
-from django.db.backends.utils import (
- CursorDebugWrapper as BaseCursorDebugWrapper,
-)
+from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
@@ -30,14 +29,17 @@ except ImportError as e:
def psycopg2_version():
- version = psycopg2.__version__.split(' ', 1)[0]
+ version = psycopg2.__version__.split(" ", 1)[0]
return get_version_tuple(version)
PSYCOPG2_VERSION = psycopg2_version()
if PSYCOPG2_VERSION < (2, 8, 4):
- raise ImproperlyConfigured("psycopg2 version 2.8.4 or newer is required; you have %s" % psycopg2.__version__)
+ raise ImproperlyConfigured(
+ "psycopg2 version 2.8.4 or newer is required; you have %s"
+ % psycopg2.__version__
+ )
# Some of these import psycopg2, so import them after checking if it's installed.
@@ -56,68 +58,68 @@ psycopg2.extras.register_uuid()
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
- 'INETARRAY',
+ "INETARRAY",
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
- vendor = 'postgresql'
- display_name = 'PostgreSQL'
+ vendor = "postgresql"
+ display_name = "PostgreSQL"
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
- 'AutoField': 'serial',
- 'BigAutoField': 'bigserial',
- 'BinaryField': 'bytea',
- 'BooleanField': 'boolean',
- 'CharField': 'varchar(%(max_length)s)',
- 'DateField': 'date',
- 'DateTimeField': 'timestamp with time zone',
- 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
- 'DurationField': 'interval',
- 'FileField': 'varchar(%(max_length)s)',
- 'FilePathField': 'varchar(%(max_length)s)',
- 'FloatField': 'double precision',
- 'IntegerField': 'integer',
- 'BigIntegerField': 'bigint',
- 'IPAddressField': 'inet',
- 'GenericIPAddressField': 'inet',
- 'JSONField': 'jsonb',
- 'OneToOneField': 'integer',
- 'PositiveBigIntegerField': 'bigint',
- 'PositiveIntegerField': 'integer',
- 'PositiveSmallIntegerField': 'smallint',
- 'SlugField': 'varchar(%(max_length)s)',
- 'SmallAutoField': 'smallserial',
- 'SmallIntegerField': 'smallint',
- 'TextField': 'text',
- 'TimeField': 'time',
- 'UUIDField': 'uuid',
+ "AutoField": "serial",
+ "BigAutoField": "bigserial",
+ "BinaryField": "bytea",
+ "BooleanField": "boolean",
+ "CharField": "varchar(%(max_length)s)",
+ "DateField": "date",
+ "DateTimeField": "timestamp with time zone",
+ "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
+ "DurationField": "interval",
+ "FileField": "varchar(%(max_length)s)",
+ "FilePathField": "varchar(%(max_length)s)",
+ "FloatField": "double precision",
+ "IntegerField": "integer",
+ "BigIntegerField": "bigint",
+ "IPAddressField": "inet",
+ "GenericIPAddressField": "inet",
+ "JSONField": "jsonb",
+ "OneToOneField": "integer",
+ "PositiveBigIntegerField": "bigint",
+ "PositiveIntegerField": "integer",
+ "PositiveSmallIntegerField": "smallint",
+ "SlugField": "varchar(%(max_length)s)",
+ "SmallAutoField": "smallserial",
+ "SmallIntegerField": "smallint",
+ "TextField": "text",
+ "TimeField": "time",
+ "UUIDField": "uuid",
}
data_type_check_constraints = {
- 'PositiveBigIntegerField': '"%(column)s" >= 0',
- 'PositiveIntegerField': '"%(column)s" >= 0',
- 'PositiveSmallIntegerField': '"%(column)s" >= 0',
+ "PositiveBigIntegerField": '"%(column)s" >= 0',
+ "PositiveIntegerField": '"%(column)s" >= 0',
+ "PositiveSmallIntegerField": '"%(column)s" >= 0',
}
operators = {
- 'exact': '= %s',
- 'iexact': '= UPPER(%s)',
- 'contains': 'LIKE %s',
- 'icontains': 'LIKE UPPER(%s)',
- 'regex': '~ %s',
- 'iregex': '~* %s',
- 'gt': '> %s',
- 'gte': '>= %s',
- 'lt': '< %s',
- 'lte': '<= %s',
- 'startswith': 'LIKE %s',
- 'endswith': 'LIKE %s',
- 'istartswith': 'LIKE UPPER(%s)',
- 'iendswith': 'LIKE UPPER(%s)',
+ "exact": "= %s",
+ "iexact": "= UPPER(%s)",
+ "contains": "LIKE %s",
+ "icontains": "LIKE UPPER(%s)",
+ "regex": "~ %s",
+ "iregex": "~* %s",
+ "gt": "> %s",
+ "gte": ">= %s",
+ "lt": "< %s",
+ "lte": "<= %s",
+ "startswith": "LIKE %s",
+ "endswith": "LIKE %s",
+ "istartswith": "LIKE UPPER(%s)",
+ "iendswith": "LIKE UPPER(%s)",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -128,14 +130,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
- pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
+ pattern_esc = (
+ r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
+ )
pattern_ops = {
- 'contains': "LIKE '%%' || {} || '%%'",
- 'icontains': "LIKE '%%' || UPPER({}) || '%%'",
- 'startswith': "LIKE {} || '%%'",
- 'istartswith': "LIKE UPPER({}) || '%%'",
- 'endswith': "LIKE '%%' || {}",
- 'iendswith': "LIKE '%%' || UPPER({})",
+ "contains": "LIKE '%%' || {} || '%%'",
+ "icontains": "LIKE '%%' || UPPER({}) || '%%'",
+ "startswith": "LIKE {} || '%%'",
+ "istartswith": "LIKE UPPER({}) || '%%'",
+ "endswith": "LIKE '%%' || {}",
+ "iendswith": "LIKE '%%' || UPPER({})",
}
Database = Database
@@ -152,46 +156,46 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
- if (
- settings_dict['NAME'] == '' and
- not settings_dict.get('OPTIONS', {}).get('service')
+ if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
+ "service"
):
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME or OPTIONS['service'] value."
)
- if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
+ if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
- "in settings.DATABASES." % (
- settings_dict['NAME'],
- len(settings_dict['NAME']),
+ "in settings.DATABASES."
+ % (
+ settings_dict["NAME"],
+ len(settings_dict["NAME"]),
self.ops.max_name_length(),
)
)
conn_params = {}
- if settings_dict['NAME']:
+ if settings_dict["NAME"]:
conn_params = {
- 'database': settings_dict['NAME'],
- **settings_dict['OPTIONS'],
+ "database": settings_dict["NAME"],
+ **settings_dict["OPTIONS"],
}
- elif settings_dict['NAME'] is None:
+ elif settings_dict["NAME"] is None:
# Connect to the default 'postgres' db.
- settings_dict.get('OPTIONS', {}).pop('service', None)
- conn_params = {'database': 'postgres', **settings_dict['OPTIONS']}
+ settings_dict.get("OPTIONS", {}).pop("service", None)
+ conn_params = {"database": "postgres", **settings_dict["OPTIONS"]}
else:
- conn_params = {**settings_dict['OPTIONS']}
+ conn_params = {**settings_dict["OPTIONS"]}
- conn_params.pop('isolation_level', None)
- if settings_dict['USER']:
- conn_params['user'] = settings_dict['USER']
- if settings_dict['PASSWORD']:
- conn_params['password'] = settings_dict['PASSWORD']
- if settings_dict['HOST']:
- conn_params['host'] = settings_dict['HOST']
- if settings_dict['PORT']:
- conn_params['port'] = settings_dict['PORT']
+ conn_params.pop("isolation_level", None)
+ if settings_dict["USER"]:
+ conn_params["user"] = settings_dict["USER"]
+ if settings_dict["PASSWORD"]:
+ conn_params["password"] = settings_dict["PASSWORD"]
+ if settings_dict["HOST"]:
+ conn_params["host"] = settings_dict["HOST"]
+ if settings_dict["PORT"]:
+ conn_params["port"] = settings_dict["PORT"]
return conn_params
@async_unsafe
@@ -203,9 +207,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
- options = self.settings_dict['OPTIONS']
+ options = self.settings_dict["OPTIONS"]
try:
- self.isolation_level = options['isolation_level']
+ self.isolation_level = options["isolation_level"]
except KeyError:
self.isolation_level = connection.isolation_level
else:
@@ -215,13 +219,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Register dummy loads() to avoid a round trip from psycopg2's decode
# to json.dumps() to json.loads(), when using a custom decoder in
# JSONField.
- psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
+ psycopg2.extras.register_default_jsonb(
+ conn_or_curs=connection, loads=lambda x: x
+ )
return connection
def ensure_timezone(self):
if self.connection is None:
return False
- conn_timezone_name = self.connection.get_parameter_status('TimeZone')
+ conn_timezone_name = self.connection.get_parameter_status("TimeZone")
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
@@ -230,7 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
def init_connection_state(self):
- self.connection.set_client_encoding('UTF8')
+ self.connection.set_client_encoding("UTF8")
timezone_changed = self.ensure_timezone()
if timezone_changed:
@@ -243,7 +249,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if name:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
- cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
+ cursor = self.connection.cursor(
+ name, scrollable=False, withhold=self.connection.autocommit
+ )
else:
cursor = self.connection.cursor()
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
@@ -268,10 +276,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if current_task:
task_ident = str(id(current_task))
else:
- task_ident = 'sync'
+ task_ident = "sync"
# Use that and the thread ident to get a unique name
return self._cursor(
- name='_django_curs_%d_%s_%d' % (
+ name="_django_curs_%d_%s_%d"
+ % (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
@@ -289,14 +298,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
afterward.
"""
with self.cursor() as cursor:
- cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
- cursor.execute('SET CONSTRAINTS ALL DEFERRED')
+ cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
+ cursor.execute("SET CONSTRAINTS ALL DEFERRED")
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
- cursor.execute('SELECT 1')
+ cursor.execute("SELECT 1")
except Database.Error:
return False
else:
@@ -317,12 +326,18 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
- RuntimeWarning
+ RuntimeWarning,
)
for connection in connections.all():
- if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
+ if (
+ connection.vendor == "postgresql"
+ and connection.settings_dict["NAME"] != "postgres"
+ ):
conn = self.__class__(
- {**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
+ {
+ **self.settings_dict,
+ "NAME": connection.settings_dict["NAME"],
+ },
alias=self.alias,
)
try:
@@ -349,5 +364,5 @@ class CursorDebugWrapper(BaseCursorDebugWrapper):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
- with self.debug_sql(sql='COPY %s TO STDOUT' % table):
+ with self.debug_sql(sql="COPY %s TO STDOUT" % table):
return self.cursor.copy_to(file, table, *args, **kwargs)
diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py
index 0effcc44e6..4c9bd63546 100644
--- a/django/db/backends/postgresql/client.py
+++ b/django/db/backends/postgresql/client.py
@@ -4,53 +4,53 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
- executable_name = 'psql'
+ executable_name = "psql"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
- options = settings_dict.get('OPTIONS', {})
+ options = settings_dict.get("OPTIONS", {})
- host = settings_dict.get('HOST')
- port = settings_dict.get('PORT')
- dbname = settings_dict.get('NAME')
- user = settings_dict.get('USER')
- passwd = settings_dict.get('PASSWORD')
- passfile = options.get('passfile')
- service = options.get('service')
- sslmode = options.get('sslmode')
- sslrootcert = options.get('sslrootcert')
- sslcert = options.get('sslcert')
- sslkey = options.get('sslkey')
+ host = settings_dict.get("HOST")
+ port = settings_dict.get("PORT")
+ dbname = settings_dict.get("NAME")
+ user = settings_dict.get("USER")
+ passwd = settings_dict.get("PASSWORD")
+ passfile = options.get("passfile")
+ service = options.get("service")
+ sslmode = options.get("sslmode")
+ sslrootcert = options.get("sslrootcert")
+ sslcert = options.get("sslcert")
+ sslkey = options.get("sslkey")
if not dbname and not service:
# Connect to the default 'postgres' db.
- dbname = 'postgres'
+ dbname = "postgres"
if user:
- args += ['-U', user]
+ args += ["-U", user]
if host:
- args += ['-h', host]
+ args += ["-h", host]
if port:
- args += ['-p', str(port)]
+ args += ["-p", str(port)]
if dbname:
args += [dbname]
args.extend(parameters)
env = {}
if passwd:
- env['PGPASSWORD'] = str(passwd)
+ env["PGPASSWORD"] = str(passwd)
if service:
- env['PGSERVICE'] = str(service)
+ env["PGSERVICE"] = str(service)
if sslmode:
- env['PGSSLMODE'] = str(sslmode)
+ env["PGSSLMODE"] = str(sslmode)
if sslrootcert:
- env['PGSSLROOTCERT'] = str(sslrootcert)
+ env["PGSSLROOTCERT"] = str(sslrootcert)
if sslcert:
- env['PGSSLCERT'] = str(sslcert)
+ env["PGSSLCERT"] = str(sslcert)
if sslkey:
- env['PGSSLKEY'] = str(sslkey)
+ env["PGSSLKEY"] = str(sslkey)
if passfile:
- env['PGPASSFILE'] = str(passfile)
+ env["PGPASSFILE"] = str(passfile)
return args, (env or None)
def runshell(self, parameters):
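
The reformatting above does not change behaviour; for orientation, a minimal usage sketch of settings_to_cmd_args_env as shown in this hunk (the settings values are illustrative, not part of the commit):

    from django.db.backends.postgresql.client import DatabaseClient

    # Hypothetical DATABASES entry; only the keys read by the method are included.
    settings_dict = {
        "NAME": "mydb",
        "USER": "me",
        "PASSWORD": "secret",
        "HOST": "localhost",
        "PORT": "5432",
        "OPTIONS": {"sslmode": "require"},
    }
    args, env = DatabaseClient.settings_to_cmd_args_env(settings_dict, [])
    # args -> ['psql', '-U', 'me', '-h', 'localhost', '-p', '5432', 'mydb']
    # env  -> {'PGPASSWORD': 'secret', 'PGSSLMODE': 'require'}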
diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py
index eb8ac3bcf5..70c3eda566 100644
--- a/django/db/backends/postgresql/creation.py
+++ b/django/db/backends/postgresql/creation.py
@@ -8,7 +8,6 @@ from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
-
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
@@ -21,32 +20,35 @@ class DatabaseCreation(BaseDatabaseCreation):
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
- test_settings = self.connection.settings_dict['TEST']
- if test_settings.get('COLLATION') is not None:
+ test_settings = self.connection.settings_dict["TEST"]
+ if test_settings.get("COLLATION") is not None:
raise ImproperlyConfigured(
- 'PostgreSQL does not support collation setting at database '
- 'creation time.'
+ "PostgreSQL does not support collation setting at database "
+ "creation time."
)
return self._get_database_create_suffix(
- encoding=test_settings['CHARSET'],
- template=test_settings.get('TEMPLATE'),
+ encoding=test_settings["CHARSET"],
+ template=test_settings.get("TEMPLATE"),
)
def _database_exists(self, cursor, database_name):
- cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
+ cursor.execute(
+ "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
+ [strip_quotes(database_name)],
+ )
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
- if keepdb and self._database_exists(cursor, parameters['dbname']):
+ if keepdb and self._database_exists(cursor, parameters["dbname"]):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
- if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
+ if getattr(e.__cause__, "pgcode", "") != errorcodes.DUPLICATE_DATABASE:
# All errors except "database already exists" cancel tests.
- self.log('Got an error creating the test database: %s' % e)
+ self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
@@ -58,11 +60,11 @@ class DatabaseCreation(BaseDatabaseCreation):
# to the template database.
self.connection.close()
- source_database_name = self.connection.settings_dict['NAME']
- target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
+ source_database_name = self.connection.settings_dict["NAME"]
+ target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_db_params = {
- 'dbname': self._quote_name(target_database_name),
- 'suffix': self._get_database_create_suffix(template=source_database_name),
+ "dbname": self._quote_name(target_database_name),
+ "suffix": self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_cursor() as cursor:
try:
@@ -70,11 +72,16 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception:
try:
if verbosity >= 1:
- self.log('Destroying old test database for alias %s...' % (
- self._get_database_display_str(verbosity, target_database_name),
- ))
- cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
+ self.log(
+ "Destroying old test database for alias %s..."
+ % (
+ self._get_database_display_str(
+ verbosity, target_database_name
+ ),
+ )
+ )
+ cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
- self.log('Got an error cloning the test database: %s' % e)
+ self.log("Got an error cloning the test database: %s" % e)
sys.exit(2)
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 1ce73fb0a8..182c230c75 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -52,7 +52,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_over_clause = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
- supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
+ supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
validates_explain_options = False # A query will error on invalid options.
supports_deferrable_unique_constraints = True
has_json_operators = True
@@ -60,14 +60,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_update_conflicts = True
supports_update_conflicts_with_target = True
test_collations = {
- 'non_default': 'sv-x-icu',
- 'swedish_ci': 'sv-x-icu',
+ "non_default": "sv-x-icu",
+ "swedish_ci": "sv-x-icu",
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
django_test_skips = {
- 'opclasses are PostgreSQL only.': {
- 'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses',
+ "opclasses are PostgreSQL only.": {
+ "indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses",
},
}
@@ -75,9 +75,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def introspected_field_types(self):
return {
**super().introspected_field_types,
- 'PositiveBigIntegerField': 'BigIntegerField',
- 'PositiveIntegerField': 'IntegerField',
- 'PositiveSmallIntegerField': 'SmallIntegerField',
+ "PositiveBigIntegerField": "BigIntegerField",
+ "PositiveIntegerField": "IntegerField",
+ "PositiveSmallIntegerField": "SmallIntegerField",
}
@cached_property
@@ -96,9 +96,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def is_postgresql_14(self):
return self.connection.pg_version >= 140000
- has_bit_xor = property(operator.attrgetter('is_postgresql_14'))
- has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11'))
- supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
- supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
- supports_covering_spgist_indexes = property(operator.attrgetter('is_postgresql_14'))
- supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
+ has_bit_xor = property(operator.attrgetter("is_postgresql_14"))
+ has_websearch_to_tsquery = property(operator.attrgetter("is_postgresql_11"))
+ supports_covering_indexes = property(operator.attrgetter("is_postgresql_11"))
+ supports_covering_gist_indexes = property(operator.attrgetter("is_postgresql_12"))
+ supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14"))
+ supports_non_deterministic_collations = property(
+ operator.attrgetter("is_postgresql_12")
+ )
diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py
index f31d906a2f..a7e9a13d61 100644
--- a/django/db/backends/postgresql/introspection.py
+++ b/django/db/backends/postgresql/introspection.py
@@ -1,5 +1,7 @@
from django.db.backends.base.introspection import (
- BaseDatabaseIntrospection, FieldInfo, TableInfo,
+ BaseDatabaseIntrospection,
+ FieldInfo,
+ TableInfo,
)
from django.db.models import Index
@@ -7,46 +9,47 @@ from django.db.models import Index
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
- 16: 'BooleanField',
- 17: 'BinaryField',
- 20: 'BigIntegerField',
- 21: 'SmallIntegerField',
- 23: 'IntegerField',
- 25: 'TextField',
- 700: 'FloatField',
- 701: 'FloatField',
- 869: 'GenericIPAddressField',
- 1042: 'CharField', # blank-padded
- 1043: 'CharField',
- 1082: 'DateField',
- 1083: 'TimeField',
- 1114: 'DateTimeField',
- 1184: 'DateTimeField',
- 1186: 'DurationField',
- 1266: 'TimeField',
- 1700: 'DecimalField',
- 2950: 'UUIDField',
- 3802: 'JSONField',
+ 16: "BooleanField",
+ 17: "BinaryField",
+ 20: "BigIntegerField",
+ 21: "SmallIntegerField",
+ 23: "IntegerField",
+ 25: "TextField",
+ 700: "FloatField",
+ 701: "FloatField",
+ 869: "GenericIPAddressField",
+ 1042: "CharField", # blank-padded
+ 1043: "CharField",
+ 1082: "DateField",
+ 1083: "TimeField",
+ 1114: "DateTimeField",
+ 1184: "DateTimeField",
+ 1186: "DurationField",
+ 1266: "TimeField",
+ 1700: "DecimalField",
+ 2950: "UUIDField",
+ 3802: "JSONField",
}
# A hook for subclasses.
- index_default_access_method = 'btree'
+ index_default_access_method = "btree"
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
- if description.default and 'nextval' in description.default:
- if field_type == 'IntegerField':
- return 'AutoField'
- elif field_type == 'BigIntegerField':
- return 'BigAutoField'
- elif field_type == 'SmallIntegerField':
- return 'SmallAutoField'
+ if description.default and "nextval" in description.default:
+ if field_type == "IntegerField":
+ return "AutoField"
+ elif field_type == "BigIntegerField":
+ return "BigAutoField"
+ elif field_type == "SmallIntegerField":
+ return "SmallAutoField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
- cursor.execute("""
+ cursor.execute(
+ """
SELECT c.relname,
CASE WHEN c.relispartition THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
FROM pg_catalog.pg_class c
@@ -54,8 +57,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
- """)
- return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]
+ """
+ )
+ return [
+ TableInfo(*row)
+ for row in cursor.fetchall()
+ if row[0] not in self.ignored_tables
+ ]
def get_table_description(self, cursor, table_name):
"""
@@ -65,7 +73,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
- cursor.execute("""
+ cursor.execute(
+ """
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
@@ -81,9 +90,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
- """, [table_name])
+ """,
+ [table_name],
+ )
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
- cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
+ cursor.execute(
+ "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
+ )
return [
FieldInfo(
line.name,
@@ -98,7 +111,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
]
def get_sequences(self, cursor, table_name, table_fields=()):
- cursor.execute("""
+ cursor.execute(
+ """
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
@@ -110,9 +124,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
- """, [table_name])
+ """,
+ [table_name],
+ )
return [
- {'name': row[0], 'table': table_name, 'column': row[1]}
+ {"name": row[0], "table": table_name, "column": row[1]}
for row in cursor.fetchall()
]
@@ -121,7 +137,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all foreign keys in the given table.
"""
- cursor.execute("""
+ cursor.execute(
+ """
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
@@ -133,7 +150,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
- """, [table_name])
+ """,
+ [table_name],
+ )
return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
def get_constraints(self, cursor, table_name):
@@ -146,7 +165,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
- cursor.execute("""
+ cursor.execute(
+ """
SELECT
c.conname,
array(
@@ -165,7 +185,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
- """, [table_name])
+ """,
+ [table_name],
+ )
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
@@ -178,7 +200,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"options": options,
}
# Now get indexes
- cursor.execute("""
+ cursor.execute(
+ """
SELECT
indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
@@ -207,14 +230,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
- """, [self.index_default_access_method, table_name])
- for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
+ """,
+ [self.index_default_access_method, table_name],
+ )
+ for (
+ index,
+ columns,
+ unique,
+ primary,
+ orders,
+ type_,
+ definition,
+ options,
+ ) in cursor.fetchall():
if index not in constraints:
basic_index = (
- type_ == self.index_default_access_method and
+ type_ == self.index_default_access_method
+ and
# '_btree' references
# django.contrib.postgres.indexes.BTreeIndex.suffix.
- not index.endswith('_btree') and options is None
+ not index.endswith("_btree")
+ and options is None
)
constraints[index] = {
"columns": columns if columns != [None] else [],
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 762cd8d23e..68448157ec 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -7,17 +7,22 @@ from django.db.models.constants import OnConflict
class DatabaseOperations(BaseDatabaseOperations):
- cast_char_field_without_max_length = 'varchar'
- explain_prefix = 'EXPLAIN'
+ cast_char_field_without_max_length = "varchar"
+ explain_prefix = "EXPLAIN"
cast_data_types = {
- 'AutoField': 'integer',
- 'BigAutoField': 'bigint',
- 'SmallAutoField': 'smallint',
+ "AutoField": "integer",
+ "BigAutoField": "bigint",
+ "SmallAutoField": "smallint",
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
- if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
+ if internal_type in (
+ "GenericIPAddressField",
+ "IPAddressField",
+ "TimeField",
+ "UUIDField",
+ ):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
@@ -25,17 +30,19 @@ class DatabaseOperations(BaseDatabaseOperations):
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
- return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
- return '%s'
+ return (
+ "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
+ )
+ return "%s"
def date_extract_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
- if lookup_type == 'week_day':
+ if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
- elif lookup_type == 'iso_week_day':
+ elif lookup_type == "iso_week_day":
return "EXTRACT('isodow' FROM %s)" % field_name
- elif lookup_type == 'iso_year':
+ elif lookup_type == "iso_year":
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
@@ -48,22 +55,25 @@ class DatabaseOperations(BaseDatabaseOperations):
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
- sign = '-' if sign == '+' else '+'
- return f'{tzname}{sign}{offset}'
+ sign = "-" if sign == "+" else "+"
+ return f"{tzname}{sign}{offset}"
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ:
- field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
+ field_name = "%s AT TIME ZONE '%s'" % (
+ field_name,
+ self._prepare_tzname_delta(tzname),
+ )
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
- return '(%s)::date' % field_name
+ return "(%s)::date" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
- return '(%s)::time' % field_name
+ return "(%s)::time" % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
@@ -89,21 +99,30 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
- lookup = '%s'
+ lookup = "%s"
# Cast text lookups to text to allow things like filter(x__contains=4)
- if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
- 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
- if internal_type in ('IPAddressField', 'GenericIPAddressField'):
+ if lookup_type in (
+ "iexact",
+ "contains",
+ "icontains",
+ "startswith",
+ "istartswith",
+ "endswith",
+ "iendswith",
+ "regex",
+ "iregex",
+ ):
+ if internal_type in ("IPAddressField", "GenericIPAddressField"):
lookup = "HOST(%s)"
- elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
- lookup = '%s::citext'
+ elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
+ lookup = "%s::citext"
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
- if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
- lookup = 'UPPER(%s)' % lookup
+ if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
+ lookup = "UPPER(%s)" % lookup
return lookup
@@ -128,29 +147,32 @@ class DatabaseOperations(BaseDatabaseOperations):
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
sql_parts = [
- style.SQL_KEYWORD('TRUNCATE'),
- ', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
+ style.SQL_KEYWORD("TRUNCATE"),
+ ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
]
if reset_sequences:
- sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY'))
+ sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
if allow_cascade:
- sql_parts.append(style.SQL_KEYWORD('CASCADE'))
- return ['%s;' % ' '.join(sql_parts)]
+ sql_parts.append(style.SQL_KEYWORD("CASCADE"))
+ return ["%s;" % " ".join(sql_parts)]
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
- table_name = sequence_info['table']
+ table_name = sequence_info["table"]
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
- column_name = sequence_info['column'] or 'id'
- sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
- style.SQL_KEYWORD('SELECT'),
- style.SQL_TABLE(self.quote_name(table_name)),
- style.SQL_FIELD(column_name),
- ))
+ column_name = sequence_info["column"] or "id"
+ sql.append(
+ "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
+ % (
+ style.SQL_KEYWORD("SELECT"),
+ style.SQL_TABLE(self.quote_name(table_name)),
+ style.SQL_FIELD(column_name),
+ )
+ )
return sql
def tablespace_sql(self, tablespace, inline=False):
@@ -161,6 +183,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def sequence_reset_sql(self, style, model_list):
from django.db import models
+
output = []
qn = self.quote_name
for model in model_list:
@@ -174,14 +197,15 @@ class DatabaseOperations(BaseDatabaseOperations):
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
- "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
- style.SQL_KEYWORD('SELECT'),
+ "coalesce(max(%s), 1), max(%s) %s null) %s %s;"
+ % (
+ style.SQL_KEYWORD("SELECT"),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
- style.SQL_KEYWORD('IS NOT'),
- style.SQL_KEYWORD('FROM'),
+ style.SQL_KEYWORD("IS NOT"),
+ style.SQL_KEYWORD("FROM"),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
@@ -207,9 +231,9 @@ class DatabaseOperations(BaseDatabaseOperations):
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
- return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
+ return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
else:
- return ['DISTINCT'], []
+ return ["DISTINCT"], []
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
@@ -220,14 +244,16 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_columns(self, fields):
if not fields:
- return '', ()
+ return "", ()
columns = [
- '%s.%s' % (
+ "%s.%s"
+ % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
- ) for field in fields
+ )
+ for field in fields
]
- return 'RETURNING %s' % ', '.join(columns), ()
+ return "RETURNING %s" % ", ".join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
@@ -252,7 +278,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
def subtract_temporals(self, internal_type, lhs, rhs):
- if internal_type == 'DateField':
+ if internal_type == "DateField":
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
@@ -263,27 +289,34 @@ class DatabaseOperations(BaseDatabaseOperations):
prefix = super().explain_query_prefix(format)
extra = {}
if format:
- extra['FORMAT'] = format
+ extra["FORMAT"] = format
if options:
- extra.update({
- name.upper(): 'true' if value else 'false'
- for name, value in options.items()
- })
+ extra.update(
+ {
+ name.upper(): "true" if value else "false"
+ for name, value in options.items()
+ }
+ )
if extra:
- prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
+ prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
return prefix
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.IGNORE:
- return 'ON CONFLICT DO NOTHING'
+ return "ON CONFLICT DO NOTHING"
if on_conflict == OnConflict.UPDATE:
- return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
- ', '.join(map(self.quote_name, unique_fields)),
- ', '.join([
- f'{field} = EXCLUDED.{field}'
- for field in map(self.quote_name, update_fields)
- ]),
+ return "ON CONFLICT(%s) DO UPDATE SET %s" % (
+ ", ".join(map(self.quote_name, unique_fields)),
+ ", ".join(
+ [
+ f"{field} = EXCLUDED.{field}"
+ for field in map(self.quote_name, update_fields)
+ ]
+ ),
)
return super().on_conflict_suffix_sql(
- fields, on_conflict, update_fields, unique_fields,
+ fields,
+ on_conflict,
+ update_fields,
+ unique_fields,
)
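
To see what the rewrapped explain_query_prefix produces, a small sketch, assuming a project configured with this PostgreSQL backend (the option values are illustrative):

    from django.db import connection

    # QuerySet.explain() ends up calling this; the FORMAT option plus any
    # boolean options are folded into the EXPLAIN prefix.
    prefix = connection.ops.explain_query_prefix("JSON", analyze=True)
    # prefix -> "EXPLAIN (FORMAT JSON, ANALYZE true)"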
diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py
index f3b5baecbe..47e9a6a8f3 100644
--- a/django/db/backends/postgresql/schema.py
+++ b/django/db/backends/postgresql/schema.py
@@ -9,16 +9,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
- sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
- sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'
+ sql_set_sequence_max = (
+ "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
+ )
+ sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_create_index = (
- 'CREATE INDEX %(name)s ON %(table)s%(using)s '
- '(%(columns)s)%(include)s%(extra)s%(condition)s'
+ "CREATE INDEX %(name)s ON %(table)s%(using)s "
+ "(%(columns)s)%(include)s%(extra)s%(condition)s"
)
sql_create_index_concurrently = (
- 'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s '
- '(%(columns)s)%(include)s%(extra)s%(condition)s'
+ "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
+ "(%(columns)s)%(include)s%(extra)s%(condition)s"
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
@@ -26,21 +28,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
- 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
- '; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE'
+ "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
+ "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
- sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'
+ sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
def quote_value(self, value):
if isinstance(value, str):
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
adapted = psycopg2.extensions.adapt(value)
- if hasattr(adapted, 'encoding'):
- adapted.encoding = 'utf8'
+ if hasattr(adapted, "encoding"):
+ adapted.encoding = "utf8"
# getquoted() returns a quoted bytestring of the adapted value.
return adapted.getquoted().decode()
@@ -61,7 +63,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _field_base_data_types(self, field):
# Yield base data types for array fields.
- if field.base_field.get_internal_type() == 'ArrayField':
+ if field.base_field.get_internal_type() == "ArrayField":
yield from self._field_base_data_types(field.base_field)
else:
yield self._field_data_type(field.base_field)
@@ -80,45 +82,52 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
- if '[' in db_type:
+ if "[" in db_type:
return None
- if db_type.startswith('varchar'):
+ if db_type.startswith("varchar"):
return self._create_index_sql(
model,
fields=[field],
- suffix='_like',
- opclasses=['varchar_pattern_ops'],
+ suffix="_like",
+ opclasses=["varchar_pattern_ops"],
)
- elif db_type.startswith('text'):
+ elif db_type.startswith("text"):
return self._create_index_sql(
model,
fields=[field],
- suffix='_like',
- opclasses=['text_pattern_ops'],
+ suffix="_like",
+ opclasses=["text_pattern_ops"],
)
return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
- self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
+ self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
# Cast when data type changed.
- using_sql = ' USING %(column)s::%(type)s'
+ using_sql = " USING %(column)s::%(type)s"
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
- if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type:
+ if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
# Compare base data types for array fields.
- if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)):
+ if list(self._field_base_data_types(old_field)) != list(
+ self._field_base_data_types(new_field)
+ ):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense.
table = strip_quotes(model._meta.db_table)
- serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
+ serial_fields_map = {
+ "bigserial": "bigint",
+ "serial": "integer",
+ "smallserial": "smallint",
+ }
if new_type.lower() in serial_fields_map:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
- self.sql_alter_column_type % {
+ self.sql_alter_column_type
+ % {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
},
@@ -126,29 +135,35 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
),
[
(
- self.sql_delete_sequence % {
+ self.sql_delete_sequence
+ % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
- self.sql_create_sequence % {
+ self.sql_create_sequence
+ % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
- self.sql_alter_column % {
+ self.sql_alter_column
+ % {
"table": self.quote_name(table),
- "changes": self.sql_alter_column_default % {
+ "changes": self.sql_alter_column_default
+ % {
"column": self.quote_name(column),
- "default": "nextval('%s')" % self.quote_name(sequence_name),
- }
+ "default": "nextval('%s')"
+ % self.quote_name(sequence_name),
+ },
},
[],
),
(
- self.sql_set_sequence_max % {
+ self.sql_set_sequence_max
+ % {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
@@ -156,24 +171,31 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
[],
),
(
- self.sql_set_sequence_owner % {
- 'table': self.quote_name(table),
- 'column': self.quote_name(column),
- 'sequence': self.quote_name(sequence_name),
+ self.sql_set_sequence_owner
+ % {
+ "table": self.quote_name(table),
+ "column": self.quote_name(column),
+ "sequence": self.quote_name(sequence_name),
},
[],
),
],
)
- elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map:
+ elif (
+ old_field.db_parameters(connection=self.connection)["type"]
+ in serial_fields_map
+ ):
# Drop the sequence if migrating away from AutoField.
column = strip_quotes(new_field.column)
- sequence_name = '%s_%s_seq' % (table, column)
- fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type)
+ sequence_name = "%s_%s_seq" % (table, column)
+ fragment, _ = super()._alter_column_type_sql(
+ model, old_field, new_field, new_type
+ )
return fragment, [
(
- self.sql_delete_sequence % {
- 'sequence': self.quote_name(sequence_name),
+ self.sql_delete_sequence
+ % {
+ "sequence": self.quote_name(sequence_name),
},
[],
),
@@ -181,58 +203,114 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
else:
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
- def _alter_field(self, model, old_field, new_field, old_type, new_type,
- old_db_params, new_db_params, strict=False):
+ def _alter_field(
+ self,
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict=False,
+ ):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
if (old_field.db_index or old_field.unique) and (
- (old_type.startswith('varchar') and not new_type.startswith('varchar')) or
- (old_type.startswith('text') and not new_type.startswith('text')) or
- (old_type.startswith('citext') and not new_type.startswith('citext'))
+ (old_type.startswith("varchar") and not new_type.startswith("varchar"))
+ or (old_type.startswith("text") and not new_type.startswith("text"))
+ or (old_type.startswith("citext") and not new_type.startswith("citext"))
):
- index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
+ index_name = self._create_index_name(
+ model._meta.db_table, [old_field.column], suffix="_like"
+ )
self.execute(self._delete_index_sql(model, index_name))
super()._alter_field(
- model, old_field, new_field, old_type, new_type, old_db_params,
- new_db_params, strict,
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
- if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
- (not old_field.unique and new_field.unique)):
+ if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
+ not old_field.unique and new_field.unique
+ ):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
- index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
+ index_to_remove = self._create_index_name(
+ model._meta.db_table, [old_field.column], suffix="_like"
+ )
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
- return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
+ return IndexColumns(
+ table,
+ columns,
+ self.quote_name,
+ col_suffixes=col_suffixes,
+ opclasses=opclasses,
+ )
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
- self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)
+ self.execute(
+ index.create_sql(model, self, concurrently=concurrently), params=None
+ )
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
- sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
+ sql = (
+ self.sql_delete_index_concurrently
+ if concurrently
+ else self.sql_delete_index
+ )
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
- self, model, *, fields=None, name=None, suffix='', using='',
- db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
- condition=None, concurrently=False, include=None, expressions=None,
+ self,
+ model,
+ *,
+ fields=None,
+ name=None,
+ suffix="",
+ using="",
+ db_tablespace=None,
+ col_suffixes=(),
+ sql=None,
+ opclasses=(),
+ condition=None,
+ concurrently=False,
+ include=None,
+ expressions=None,
):
- sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
+ sql = (
+ self.sql_create_index
+ if not concurrently
+ else self.sql_create_index_concurrently
+ )
return super()._create_index_sql(
- model, fields=fields, name=name, suffix=suffix, using=using,
- db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql,
- opclasses=opclasses, condition=condition, include=include,
+ model,
+ fields=fields,
+ name=name,
+ suffix=suffix,
+ using=using,
+ db_tablespace=db_tablespace,
+ col_suffixes=col_suffixes,
+ sql=sql,
+ opclasses=opclasses,
+ condition=condition,
+ include=include,
expressions=expressions,
)