summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
authordjango-bot <ops@djangoproject.com>2022-02-03 20:24:19 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-02-07 20:37:05 +0100
commit9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch)
treef0506b668a013d0063e5fba3dbf4863b466713ba /django/db
parentf68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff)
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db')
-rw-r--r--django/db/__init__.py33
-rw-r--r--django/db/backends/base/base.py103
-rw-r--r--django/db/backends/base/client.py9
-rw-r--r--django/db/backends/base/creation.py162
-rw-r--r--django/db/backends/base/features.py50
-rw-r--r--django/db/backends/base/introspection.py63
-rw-r--r--django/db/backends/base/operations.py201
-rw-r--r--django/db/backends/base/schema.py858
-rw-r--r--django/db/backends/base/validation.py10
-rw-r--r--django/db/backends/ddl_references.py64
-rw-r--r--django/db/backends/dummy/base.py8
-rw-r--r--django/db/backends/mysql/base.py248
-rw-r--r--django/db/backends/mysql/client.py36
-rw-r--r--django/db/backends/mysql/compiler.py24
-rw-r--r--django/db/backends/mysql/creation.py61
-rw-r--r--django/db/backends/mysql/features.py202
-rw-r--r--django/db/backends/mysql/introspection.py205
-rw-r--r--django/db/backends/mysql/operations.py279
-rw-r--r--django/db/backends/mysql/schema.py65
-rw-r--r--django/db/backends/mysql/validation.py59
-rw-r--r--django/db/backends/oracle/base.py256
-rw-r--r--django/db/backends/oracle/client.py10
-rw-r--r--django/db/backends/oracle/creation.py306
-rw-r--r--django/db/backends/oracle/features.py44
-rw-r--r--django/db/backends/oracle/functions.py12
-rw-r--r--django/db/backends/oracle/introspection.py183
-rw-r--r--django/db/backends/oracle/operations.py364
-rw-r--r--django/db/backends/oracle/schema.py98
-rw-r--r--django/db/backends/oracle/utils.py95
-rw-r--r--django/db/backends/oracle/validation.py4
-rw-r--r--django/db/backends/postgresql/base.py211
-rw-r--r--django/db/backends/postgresql/client.py48
-rw-r--r--django/db/backends/postgresql/creation.py47
-rw-r--r--django/db/backends/postgresql/features.py30
-rw-r--r--django/db/backends/postgresql/introspection.py130
-rw-r--r--django/db/backends/postgresql/operations.py155
-rw-r--r--django/db/backends/postgresql/schema.py204
-rw-r--r--django/db/backends/sqlite3/_functions.py234
-rw-r--r--django/db/backends/sqlite3/base.py224
-rw-r--r--django/db/backends/sqlite3/client.py4
-rw-r--r--django/db/backends/sqlite3/creation.py49
-rw-r--r--django/db/backends/sqlite3/features.py76
-rw-r--r--django/db/backends/sqlite3/introspection.py241
-rw-r--r--django/db/backends/sqlite3/operations.py183
-rw-r--r--django/db/backends/sqlite3/schema.py307
-rw-r--r--django/db/backends/utils.py107
-rw-r--r--django/db/migrations/autodetector.py669
-rw-r--r--django/db/migrations/exceptions.py6
-rw-r--r--django/db/migrations/executor.py69
-rw-r--r--django/db/migrations/graph.py38
-rw-r--r--django/db/migrations/loader.py91
-rw-r--r--django/db/migrations/migration.py42
-rw-r--r--django/db/migrations/operations/__init__.py43
-rw-r--r--django/db/migrations/operations/base.py12
-rw-r--r--django/db/migrations/operations/fields.py156
-rw-r--r--django/db/migrations/operations/models.py410
-rw-r--r--django/db/migrations/operations/special.py75
-rw-r--r--django/db/migrations/optimizer.py8
-rw-r--r--django/db/migrations/questioner.py138
-rw-r--r--django/db/migrations/recorder.py17
-rw-r--r--django/db/migrations/serializer.py88
-rw-r--r--django/db/migrations/state.py250
-rw-r--r--django/db/migrations/utils.py40
-rw-r--r--django/db/migrations/writer.py77
-rw-r--r--django/db/models/__init__.py99
-rw-r--r--django/db/models/aggregates.py104
-rw-r--r--django/db/models/base.py841
-rw-r--r--django/db/models/constants.py6
-rw-r--r--django/db/models/constraints.py151
-rw-r--r--django/db/models/deletion.py169
-rw-r--r--django/db/models/enums.py17
-rw-r--r--django/db/models/expressions.py504
-rw-r--r--django/db/models/fields/__init__.py1094
-rw-r--r--django/db/models/fields/files.py123
-rw-r--r--django/db/models/fields/json.py249
-rw-r--r--django/db/models/fields/mixins.py19
-rw-r--r--django/db/models/fields/proxy.py4
-rw-r--r--django/db/models/fields/related.py883
-rw-r--r--django/db/models/fields/related_descriptors.py425
-rw-r--r--django/db/models/fields/related_lookups.py81
-rw-r--r--django/db/models/fields/reverse_related.py97
-rw-r--r--django/db/models/functions/__init__.py202
-rw-r--r--django/db/models/functions/comparison.py100
-rw-r--r--django/db/models/functions/datetime.py200
-rw-r--r--django/db/models/functions/math.py135
-rw-r--r--django/db/models/functions/mixins.py29
-rw-r--r--django/db/models/functions/text.py178
-rw-r--r--django/db/models/functions/window.py52
-rw-r--r--django/db/models/indexes.py154
-rw-r--r--django/db/models/lookups.py230
-rw-r--r--django/db/models/manager.py49
-rw-r--r--django/db/models/options.py261
-rw-r--r--django/db/models/query.py729
-rw-r--r--django/db/models/query_utils.py82
-rw-r--r--django/db/models/signals.py9
-rw-r--r--django/db/models/sql/__init__.py2
-rw-r--r--django/db/models/sql/compiler.py770
-rw-r--r--django/db/models/sql/constants.py16
-rw-r--r--django/db/models/sql/datastructures.py74
-rw-r--r--django/db/models/sql/query.py616
-rw-r--r--django/db/models/sql/subqueries.py42
-rw-r--r--django/db/models/sql/where.py57
-rw-r--r--django/db/models/utils.py6
-rw-r--r--django/db/transaction.py22
-rw-r--r--django/db/utils.py87
105 files changed, 10846 insertions, 6713 deletions
diff --git a/django/db/__init__.py b/django/db/__init__.py
index 26127860ed..b0cae97e01 100644
--- a/django/db/__init__.py
+++ b/django/db/__init__.py
@@ -1,17 +1,36 @@
from django.core import signals
from django.db.utils import (
- DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, ConnectionHandler,
- ConnectionRouter, DatabaseError, DataError, Error, IntegrityError,
- InterfaceError, InternalError, NotSupportedError, OperationalError,
+ DEFAULT_DB_ALIAS,
+ DJANGO_VERSION_PICKLE_KEY,
+ ConnectionHandler,
+ ConnectionRouter,
+ DatabaseError,
+ DataError,
+ Error,
+ IntegrityError,
+ InterfaceError,
+ InternalError,
+ NotSupportedError,
+ OperationalError,
ProgrammingError,
)
from django.utils.connection import ConnectionProxy
__all__ = [
- 'connection', 'connections', 'router', 'DatabaseError', 'IntegrityError',
- 'InternalError', 'ProgrammingError', 'DataError', 'NotSupportedError',
- 'Error', 'InterfaceError', 'OperationalError', 'DEFAULT_DB_ALIAS',
- 'DJANGO_VERSION_PICKLE_KEY',
+ "connection",
+ "connections",
+ "router",
+ "DatabaseError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "DataError",
+ "NotSupportedError",
+ "Error",
+ "InterfaceError",
+ "OperationalError",
+ "DEFAULT_DB_ALIAS",
+ "DJANGO_VERSION_PICKLE_KEY",
]
connections = ConnectionHandler()
diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py
index 58dd6d43bd..1aee03848b 100644
--- a/django/db/backends/base/base.py
+++ b/django/db/backends/base/base.py
@@ -23,19 +23,21 @@ from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
-NO_DB_ALIAS = '__no_db__'
+NO_DB_ALIAS = "__no_db__"
# RemovedInDjango50Warning
def timezone_constructor(tzname):
if settings.USE_DEPRECATED_PYTZ:
import pytz
+
return pytz.timezone(tzname)
return zoneinfo.ZoneInfo(tzname)
class BaseDatabaseWrapper:
"""Represent a database connection."""
+
# Mapping of Field objects to their column types.
data_types = {}
# Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
@@ -43,8 +45,8 @@ class BaseDatabaseWrapper:
# Mapping of Field objects to their SQL for CHECK constraints.
data_type_check_constraints = {}
ops = None
- vendor = 'unknown'
- display_name = 'unknown'
+ vendor = "unknown"
+ display_name = "unknown"
SchemaEditorClass = None
# Classes instantiated in __init__().
client_class = None
@@ -124,8 +126,8 @@ class BaseDatabaseWrapper:
def __repr__(self):
return (
- f'<{self.__class__.__qualname__} '
- f'vendor={self.vendor!r} alias={self.alias!r}>'
+ f"<{self.__class__.__qualname__} "
+ f"vendor={self.vendor!r} alias={self.alias!r}>"
)
def ensure_timezone(self):
@@ -153,10 +155,10 @@ class BaseDatabaseWrapper:
"""
if not settings.USE_TZ:
return None
- elif self.settings_dict['TIME_ZONE'] is None:
+ elif self.settings_dict["TIME_ZONE"] is None:
return timezone.utc
else:
- return timezone_constructor(self.settings_dict['TIME_ZONE'])
+ return timezone_constructor(self.settings_dict["TIME_ZONE"])
@cached_property
def timezone_name(self):
@@ -165,10 +167,10 @@ class BaseDatabaseWrapper:
"""
if not settings.USE_TZ:
return settings.TIME_ZONE
- elif self.settings_dict['TIME_ZONE'] is None:
- return 'UTC'
+ elif self.settings_dict["TIME_ZONE"] is None:
+ return "UTC"
else:
- return self.settings_dict['TIME_ZONE']
+ return self.settings_dict["TIME_ZONE"]
@property
def queries_logged(self):
@@ -179,26 +181,35 @@ class BaseDatabaseWrapper:
if len(self.queries_log) == self.queries_log.maxlen:
warnings.warn(
"Limit for query logging exceeded, only the last {} queries "
- "will be returned.".format(self.queries_log.maxlen))
+ "will be returned.".format(self.queries_log.maxlen)
+ )
return list(self.queries_log)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
"""Return a dict of parameters suitable for get_new_connection."""
- raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_connection_params() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseWrapper may require a get_connection_params() method"
+ )
def get_new_connection(self, conn_params):
"""Open a connection to the database."""
- raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_new_connection() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseWrapper may require a get_new_connection() method"
+ )
def init_connection_state(self):
"""Initialize the database connection settings."""
- raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseWrapper may require an init_connection_state() method"
+ )
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
- raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
+ )
# ##### Backend-specific methods for creating connections #####
@@ -213,8 +224,8 @@ class BaseDatabaseWrapper:
self.atomic_blocks = []
self.needs_rollback = False
# Reset parameters defining when to close/health-check the connection.
- self.health_check_enabled = self.settings_dict['CONN_HEALTH_CHECKS']
- max_age = self.settings_dict['CONN_MAX_AGE']
+ self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
+ max_age = self.settings_dict["CONN_MAX_AGE"]
self.close_at = None if max_age is None else time.monotonic() + max_age
self.closed_in_transaction = False
self.errors_occurred = False
@@ -223,14 +234,14 @@ class BaseDatabaseWrapper:
# Establish the connection
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
- self.set_autocommit(self.settings_dict['AUTOCOMMIT'])
+ self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
def check_settings(self):
- if self.settings_dict['TIME_ZONE'] is not None and not settings.USE_TZ:
+ if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
raise ImproperlyConfigured(
"Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
% self.alias
@@ -356,7 +367,7 @@ class BaseDatabaseWrapper:
return
thread_ident = _thread.get_ident()
- tid = str(thread_ident).replace('-', '')
+ tid = str(thread_ident).replace("-", "")
self.savepoint_state += 1
sid = "s%s_x%d" % (tid, self.savepoint_state)
@@ -406,7 +417,9 @@ class BaseDatabaseWrapper:
"""
Backend-specific implementation to enable or disable autocommit.
"""
- raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a _set_autocommit() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
+ )
# ##### Generic transaction management methods #####
@@ -415,7 +428,9 @@ class BaseDatabaseWrapper:
self.ensure_connection()
return self.autocommit
- def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
+ def set_autocommit(
+ self, autocommit, force_begin_transaction_with_broken_autocommit=False
+ ):
"""
Enable or disable autocommit.
@@ -432,8 +447,9 @@ class BaseDatabaseWrapper:
self.ensure_connection()
start_transaction_under_autocommit = (
- force_begin_transaction_with_broken_autocommit and not autocommit and
- hasattr(self, '_start_transaction_under_autocommit')
+ force_begin_transaction_with_broken_autocommit
+ and not autocommit
+ and hasattr(self, "_start_transaction_under_autocommit")
)
if start_transaction_under_autocommit:
@@ -451,7 +467,8 @@ class BaseDatabaseWrapper:
"""Get the "needs rollback" flag -- for *advanced use* only."""
if not self.in_atomic_block:
raise TransactionManagementError(
- "The rollback flag doesn't work outside of an 'atomic' block.")
+ "The rollback flag doesn't work outside of an 'atomic' block."
+ )
return self.needs_rollback
def set_rollback(self, rollback):
@@ -460,20 +477,23 @@ class BaseDatabaseWrapper:
"""
if not self.in_atomic_block:
raise TransactionManagementError(
- "The rollback flag doesn't work outside of an 'atomic' block.")
+ "The rollback flag doesn't work outside of an 'atomic' block."
+ )
self.needs_rollback = rollback
def validate_no_atomic_block(self):
"""Raise an error if an atomic block is active."""
if self.in_atomic_block:
raise TransactionManagementError(
- "This is forbidden when an 'atomic' block is active.")
+ "This is forbidden when an 'atomic' block is active."
+ )
def validate_no_broken_transaction(self):
if self.needs_rollback:
raise TransactionManagementError(
"An error occurred in the current transaction. You can't "
- "execute queries until the end of the 'atomic' block.")
+ "execute queries until the end of the 'atomic' block."
+ )
# ##### Foreign key constraints checks handling #####
@@ -524,14 +544,15 @@ class BaseDatabaseWrapper:
as that may prevent Django from recycling unusable connections.
"""
raise NotImplementedError(
- "subclasses of BaseDatabaseWrapper may require an is_usable() method")
+ "subclasses of BaseDatabaseWrapper may require an is_usable() method"
+ )
def close_if_health_check_failed(self):
"""Close existing connection if it fails a health check."""
if (
- self.connection is None or
- not self.health_check_enabled or
- self.health_check_done
+ self.connection is None
+ or not self.health_check_enabled
+ or self.health_check_done
):
return
@@ -548,7 +569,7 @@ class BaseDatabaseWrapper:
self.health_check_done = False
# If the application didn't restore the original autocommit setting,
# don't take chances, drop the connection.
- if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']:
+ if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
self.close()
return
@@ -580,7 +601,9 @@ class BaseDatabaseWrapper:
def dec_thread_sharing(self):
with self._thread_sharing_lock:
if self._thread_sharing_count <= 0:
- raise RuntimeError('Cannot decrement the thread sharing count below zero.')
+ raise RuntimeError(
+ "Cannot decrement the thread sharing count below zero."
+ )
self._thread_sharing_count -= 1
def validate_thread_sharing(self):
@@ -595,8 +618,7 @@ class BaseDatabaseWrapper:
"DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
- "thread id %s."
- % (self.alias, self._thread_ident, _thread.get_ident())
+ "thread id %s." % (self.alias, self._thread_ident, _thread.get_ident())
)
# ##### Miscellaneous #####
@@ -657,7 +679,7 @@ class BaseDatabaseWrapper:
being exposed to potential child threads while (or after) the test
database is destroyed. Refs #10868, #17786, #16969.
"""
- conn = self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
+ conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
try:
with conn.cursor() as cursor:
yield cursor
@@ -670,7 +692,8 @@ class BaseDatabaseWrapper:
"""
if self.SchemaEditorClass is None:
raise NotImplementedError(
- 'The SchemaEditorClass attribute of this database wrapper is still None')
+ "The SchemaEditorClass attribute of this database wrapper is still None"
+ )
return self.SchemaEditorClass(self, *args, **kwargs)
def on_commit(self, func):
@@ -680,7 +703,9 @@ class BaseDatabaseWrapper:
# Transaction in progress; save for execution on commit.
self.run_on_commit.append((set(self.savepoint_ids), func))
elif not self.get_autocommit():
- raise TransactionManagementError('on_commit() cannot be used in manual transaction management')
+ raise TransactionManagementError(
+ "on_commit() cannot be used in manual transaction management"
+ )
else:
# No transaction in progress and in autocommit mode; execute
# immediately.
diff --git a/django/db/backends/base/client.py b/django/db/backends/base/client.py
index 8aca821fd2..031056372d 100644
--- a/django/db/backends/base/client.py
+++ b/django/db/backends/base/client.py
@@ -4,6 +4,7 @@ import subprocess
class BaseDatabaseClient:
"""Encapsulate backend-specific methods for opening a client shell."""
+
# This should be a string representing the name of the executable
# (e.g., "psql"). Subclasses must override this.
executable_name = None
@@ -15,11 +16,13 @@ class BaseDatabaseClient:
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
raise NotImplementedError(
- 'subclasses of BaseDatabaseClient must provide a '
- 'settings_to_cmd_args_env() method or override a runshell().'
+ "subclasses of BaseDatabaseClient must provide a "
+ "settings_to_cmd_args_env() method or override a runshell()."
)
def runshell(self, parameters):
- args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters)
+ args, env = self.settings_to_cmd_args_env(
+ self.connection.settings_dict, parameters
+ )
env = {**os.environ, **env} if env else None
subprocess.run(args, env=env, check=True)
diff --git a/django/db/backends/base/creation.py b/django/db/backends/base/creation.py
index d1c0e1ac96..78480fc0f8 100644
--- a/django/db/backends/base/creation.py
+++ b/django/db/backends/base/creation.py
@@ -11,7 +11,7 @@ from django.utils.module_loading import import_string
# The prefix to put on the default database name when creating
# the test database.
-TEST_DATABASE_PREFIX = 'test_'
+TEST_DATABASE_PREFIX = "test_"
class BaseDatabaseCreation:
@@ -19,6 +19,7 @@ class BaseDatabaseCreation:
Encapsulate backend-specific differences pertaining to creation and
destruction of the test database.
"""
+
def __init__(self, connection):
self.connection = connection
@@ -28,7 +29,9 @@ class BaseDatabaseCreation:
def log(self, msg):
sys.stderr.write(msg + os.linesep)
- def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False):
+ def create_test_db(
+ self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
+ ):
"""
Create a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
@@ -39,14 +42,17 @@ class BaseDatabaseCreation:
test_database_name = self._get_test_db_name()
if verbosity >= 1:
- action = 'Creating'
+ action = "Creating"
if keepdb:
action = "Using existing"
- self.log('%s test database for alias %s...' % (
- action,
- self._get_database_display_str(verbosity, test_database_name),
- ))
+ self.log(
+ "%s test database for alias %s..."
+ % (
+ action,
+ self._get_database_display_str(verbosity, test_database_name),
+ )
+ )
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. This is to handle the case
@@ -60,25 +66,24 @@ class BaseDatabaseCreation:
self.connection.settings_dict["NAME"] = test_database_name
try:
- if self.connection.settings_dict['TEST']['MIGRATE'] is False:
+ if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
# Disable migrations for all apps.
old_migration_modules = settings.MIGRATION_MODULES
settings.MIGRATION_MODULES = {
- app.label: None
- for app in apps.get_app_configs()
+ app.label: None for app in apps.get_app_configs()
}
# We report migrate messages at one level lower than that
# requested. This ensures we don't get flooded with messages during
# testing (unless you really ask to be flooded).
call_command(
- 'migrate',
+ "migrate",
verbosity=max(verbosity - 1, 0),
interactive=False,
database=self.connection.alias,
run_syncdb=True,
)
finally:
- if self.connection.settings_dict['TEST']['MIGRATE'] is False:
+ if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
settings.MIGRATION_MODULES = old_migration_modules
# We then serialize the current state of the database into a string
@@ -88,12 +93,12 @@ class BaseDatabaseCreation:
if serialize:
self.connection._test_serialized_contents = self.serialize_db_to_string()
- call_command('createcachetable', database=self.connection.alias)
+ call_command("createcachetable", database=self.connection.alias)
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
- if os.environ.get('RUNNING_DJANGOS_TEST_SUITE') == 'true':
+ if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
self.mark_expected_failures_and_skips()
return test_database_name
@@ -103,7 +108,7 @@ class BaseDatabaseCreation:
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
- self.connection.settings_dict['NAME'] = primary_settings_dict['NAME']
+ self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]
def serialize_db_to_string(self):
"""
@@ -114,22 +119,23 @@ class BaseDatabaseCreation:
# Iteratively return every object for all models to serialize.
def get_objects():
from django.db.migrations.loader import MigrationLoader
+
loader = MigrationLoader(self.connection)
for app_config in apps.get_app_configs():
if (
- app_config.models_module is not None and
- app_config.label in loader.migrated_apps and
- app_config.name not in settings.TEST_NON_SERIALIZED_APPS
+ app_config.models_module is not None
+ and app_config.label in loader.migrated_apps
+ and app_config.name not in settings.TEST_NON_SERIALIZED_APPS
):
for model in app_config.get_models():
- if (
- model._meta.can_migrate(self.connection) and
- router.allow_migrate_model(self.connection.alias, model)
- ):
+ if model._meta.can_migrate(
+ self.connection
+ ) and router.allow_migrate_model(self.connection.alias, model):
queryset = model._base_manager.using(
self.connection.alias,
).order_by(model._meta.pk.name)
yield from queryset.iterator()
+
# Serialize to a string
out = StringIO()
serializers.serialize("json", get_objects(), indent=None, stream=out)
@@ -147,7 +153,9 @@ class BaseDatabaseCreation:
# Disable constraint checks, because some databases (MySQL) doesn't
# support deferred checks.
with self.connection.constraint_checks_disabled():
- for obj in serializers.deserialize('json', data, using=self.connection.alias):
+ for obj in serializers.deserialize(
+ "json", data, using=self.connection.alias
+ ):
obj.save()
table_names.add(obj.object.__class__._meta.db_table)
# Manually check for any invalid keys that might have been added,
@@ -160,7 +168,7 @@ class BaseDatabaseCreation:
"""
return "'%s'%s" % (
self.connection.alias,
- (" ('%s')" % database_name) if verbosity >= 2 else '',
+ (" ('%s')" % database_name) if verbosity >= 2 else "",
)
def _get_test_db_name(self):
@@ -170,12 +178,12 @@ class BaseDatabaseCreation:
_create_test_db() and when no external munging is done with the 'NAME'
settings.
"""
- if self.connection.settings_dict['TEST']['NAME']:
- return self.connection.settings_dict['TEST']['NAME']
- return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
+ if self.connection.settings_dict["TEST"]["NAME"]:
+ return self.connection.settings_dict["TEST"]["NAME"]
+ return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
- cursor.execute('CREATE DATABASE %(dbname)s %(suffix)s' % parameters)
+ cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
"""
@@ -183,8 +191,8 @@ class BaseDatabaseCreation:
"""
test_database_name = self._get_test_db_name()
test_db_params = {
- 'dbname': self.connection.ops.quote_name(test_database_name),
- 'suffix': self.sql_table_creation_suffix(),
+ "dbname": self.connection.ops.quote_name(test_database_name),
+ "suffix": self.sql_table_creation_suffix(),
}
# Create the test database and connect to it.
with self._nodb_cursor() as cursor:
@@ -196,24 +204,30 @@ class BaseDatabaseCreation:
if keepdb:
return test_database_name
- self.log('Got an error creating the test database: %s' % e)
+ self.log("Got an error creating the test database: %s" % e)
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
- "database '%s', or 'no' to cancel: " % test_database_name)
- if autoclobber or confirm == 'yes':
+ "database '%s', or 'no' to cancel: " % test_database_name
+ )
+ if autoclobber or confirm == "yes":
try:
if verbosity >= 1:
- self.log('Destroying old test database for alias %s...' % (
- self._get_database_display_str(verbosity, test_database_name),
- ))
- cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
+ self.log(
+ "Destroying old test database for alias %s..."
+ % (
+ self._get_database_display_str(
+ verbosity, test_database_name
+ ),
+ )
+ )
+ cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
- self.log('Got an error recreating the test database: %s' % e)
+ self.log("Got an error recreating the test database: %s" % e)
sys.exit(2)
else:
- self.log('Tests cancelled.')
+ self.log("Tests cancelled.")
sys.exit(1)
return test_database_name
@@ -222,16 +236,19 @@ class BaseDatabaseCreation:
"""
Clone a test database.
"""
- source_database_name = self.connection.settings_dict['NAME']
+ source_database_name = self.connection.settings_dict["NAME"]
if verbosity >= 1:
- action = 'Cloning test database'
+ action = "Cloning test database"
if keepdb:
- action = 'Using existing clone'
- self.log('%s for alias %s...' % (
- action,
- self._get_database_display_str(verbosity, source_database_name),
- ))
+ action = "Using existing clone"
+ self.log(
+ "%s for alias %s..."
+ % (
+ action,
+ self._get_database_display_str(verbosity, source_database_name),
+ )
+ )
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. See create_test_db for details.
@@ -245,7 +262,10 @@ class BaseDatabaseCreation:
# already and its name has been copied to settings_dict['NAME'] so
# we don't need to call _get_test_db_name.
orig_settings_dict = self.connection.settings_dict
- return {**orig_settings_dict, 'NAME': '{}_{}'.format(orig_settings_dict['NAME'], suffix)}
+ return {
+ **orig_settings_dict,
+ "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
+ }
def _clone_test_db(self, suffix, verbosity, keepdb=False):
"""
@@ -253,27 +273,33 @@ class BaseDatabaseCreation:
"""
raise NotImplementedError(
"The database backend doesn't support cloning databases. "
- "Disable the option to run tests in parallel processes.")
+ "Disable the option to run tests in parallel processes."
+ )
- def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, suffix=None):
+ def destroy_test_db(
+ self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
+ ):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists.
"""
self.connection.close()
if suffix is None:
- test_database_name = self.connection.settings_dict['NAME']
+ test_database_name = self.connection.settings_dict["NAME"]
else:
- test_database_name = self.get_test_db_clone_settings(suffix)['NAME']
+ test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
if verbosity >= 1:
- action = 'Destroying'
+ action = "Destroying"
if keepdb:
- action = 'Preserving'
- self.log('%s test database for alias %s...' % (
- action,
- self._get_database_display_str(verbosity, test_database_name),
- ))
+ action = "Preserving"
+ self.log(
+ "%s test database for alias %s..."
+ % (
+ action,
+ self._get_database_display_str(verbosity, test_database_name),
+ )
+ )
# if we want to preserve the database
# skip the actual destroying piece.
@@ -294,8 +320,9 @@ class BaseDatabaseCreation:
# to do so, because it's not allowed to delete a database while being
# connected to it.
with self._nodb_cursor() as cursor:
- cursor.execute("DROP DATABASE %s"
- % self.connection.ops.quote_name(test_database_name))
+ cursor.execute(
+ "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
+ )
def mark_expected_failures_and_skips(self):
"""
@@ -304,9 +331,10 @@ class BaseDatabaseCreation:
"""
# Only load unittest if we're actually testing.
from unittest import expectedFailure, skip
+
for test_name in self.connection.features.django_test_expected_failures:
- test_case_name, _, test_method_name = test_name.rpartition('.')
- test_app = test_name.split('.')[0]
+ test_case_name, _, test_method_name = test_name.rpartition(".")
+ test_app = test_name.split(".")[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
@@ -314,8 +342,8 @@ class BaseDatabaseCreation:
setattr(test_case, test_method_name, expectedFailure(test_method))
for reason, tests in self.connection.features.django_test_skips.items():
for test_name in tests:
- test_case_name, _, test_method_name = test_name.rpartition('.')
- test_app = test_name.split('.')[0]
+ test_case_name, _, test_method_name = test_name.rpartition(".")
+ test_app = test_name.split(".")[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
@@ -326,7 +354,7 @@ class BaseDatabaseCreation:
"""
SQL to append to the end of the test table creation statements.
"""
- return ''
+ return ""
def test_db_signature(self):
"""
@@ -336,8 +364,8 @@ class BaseDatabaseCreation:
"""
settings_dict = self.connection.settings_dict
return (
- settings_dict['HOST'],
- settings_dict['PORT'],
- settings_dict['ENGINE'],
+ settings_dict["HOST"],
+ settings_dict["PORT"],
+ settings_dict["ENGINE"],
self._get_test_db_name(),
)
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index 20d5c8f772..42399b769a 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -130,21 +130,21 @@ class BaseDatabaseFeatures:
# Map fields which some backends may not be able to differentiate to the
# field it's introspected as.
introspected_field_types = {
- 'AutoField': 'AutoField',
- 'BigAutoField': 'BigAutoField',
- 'BigIntegerField': 'BigIntegerField',
- 'BinaryField': 'BinaryField',
- 'BooleanField': 'BooleanField',
- 'CharField': 'CharField',
- 'DurationField': 'DurationField',
- 'GenericIPAddressField': 'GenericIPAddressField',
- 'IntegerField': 'IntegerField',
- 'PositiveBigIntegerField': 'PositiveBigIntegerField',
- 'PositiveIntegerField': 'PositiveIntegerField',
- 'PositiveSmallIntegerField': 'PositiveSmallIntegerField',
- 'SmallAutoField': 'SmallAutoField',
- 'SmallIntegerField': 'SmallIntegerField',
- 'TimeField': 'TimeField',
+ "AutoField": "AutoField",
+ "BigAutoField": "BigAutoField",
+ "BigIntegerField": "BigIntegerField",
+ "BinaryField": "BinaryField",
+ "BooleanField": "BooleanField",
+ "CharField": "CharField",
+ "DurationField": "DurationField",
+ "GenericIPAddressField": "GenericIPAddressField",
+ "IntegerField": "IntegerField",
+ "PositiveBigIntegerField": "PositiveBigIntegerField",
+ "PositiveIntegerField": "PositiveIntegerField",
+ "PositiveSmallIntegerField": "PositiveSmallIntegerField",
+ "SmallAutoField": "SmallAutoField",
+ "SmallIntegerField": "SmallIntegerField",
+ "TimeField": "TimeField",
}
# Can the backend introspect the column order (ASC/DESC) for indexes?
@@ -201,7 +201,7 @@ class BaseDatabaseFeatures:
has_case_insensitive_like = False
# Suffix for backends that don't support "SELECT xxx;" queries.
- bare_select_suffix = ''
+ bare_select_suffix = ""
# If NULL is implied on columns without needing to be explicitly specified
implied_column_null = False
@@ -325,10 +325,10 @@ class BaseDatabaseFeatures:
# Collation names for use by the Django test suite.
test_collations = {
- 'ci': None, # Case-insensitive.
- 'cs': None, # Case-sensitive.
- 'non_default': None, # Non-default.
- 'swedish_ci': None # Swedish case-insensitive.
+ "ci": None, # Case-insensitive.
+ "cs": None, # Case-sensitive.
+ "non_default": None, # Non-default.
+ "swedish_ci": None, # Swedish case-insensitive.
}
# SQL template override for tests.aggregation.tests.NowUTC
test_now_utc_template = None
@@ -352,14 +352,14 @@ class BaseDatabaseFeatures:
def supports_transactions(self):
"""Confirm support for transactions."""
with self.connection.cursor() as cursor:
- cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
+ cursor.execute("CREATE TABLE ROLLBACK_TEST (X INT)")
self.connection.set_autocommit(False)
- cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
+ cursor.execute("INSERT INTO ROLLBACK_TEST (X) VALUES (8)")
self.connection.rollback()
self.connection.set_autocommit(True)
- cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
- count, = cursor.fetchone()
- cursor.execute('DROP TABLE ROLLBACK_TEST')
+ cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST")
+ (count,) = cursor.fetchone()
+ cursor.execute("DROP TABLE ROLLBACK_TEST")
return count == 0
def allows_group_by_selected_pks_on_model(self, model):
diff --git a/django/db/backends/base/introspection.py b/django/db/backends/base/introspection.py
index 079c1835b0..c8036ef1e9 100644
--- a/django/db/backends/base/introspection.py
+++ b/django/db/backends/base/introspection.py
@@ -1,18 +1,19 @@
from collections import namedtuple
# Structure returned by DatabaseIntrospection.get_table_list()
-TableInfo = namedtuple('TableInfo', ['name', 'type'])
+TableInfo = namedtuple("TableInfo", ["name", "type"])
# Structure returned by the DB-API cursor.description interface (PEP 249)
FieldInfo = namedtuple(
- 'FieldInfo',
- 'name type_code display_size internal_size precision scale null_ok '
- 'default collation'
+ "FieldInfo",
+ "name type_code display_size internal_size precision scale null_ok "
+ "default collation",
)
class BaseDatabaseIntrospection:
"""Encapsulate backend-specific introspection utilities."""
+
data_types_reverse = {}
def __init__(self, connection):
@@ -43,9 +44,14 @@ class BaseDatabaseIntrospection:
the database's ORDER BY here to avoid subtle differences in sorting
order between databases.
"""
+
def get_names(cursor):
- return sorted(ti.name for ti in self.get_table_list(cursor)
- if include_views or ti.type == 't')
+ return sorted(
+ ti.name
+ for ti in self.get_table_list(cursor)
+ if include_views or ti.type == "t"
+ )
+
if cursor is None:
with self.connection.cursor() as cursor:
return get_names(cursor)
@@ -56,7 +62,9 @@ class BaseDatabaseIntrospection:
Return an unsorted list of TableInfo named tuples of all tables and
views that exist in the database.
"""
- raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseIntrospection may require a get_table_list() method"
+ )
def get_table_description(self, cursor, table_name):
"""
@@ -64,13 +72,14 @@ class BaseDatabaseIntrospection:
interface.
"""
raise NotImplementedError(
- 'subclasses of BaseDatabaseIntrospection may require a '
- 'get_table_description() method.'
+ "subclasses of BaseDatabaseIntrospection may require a "
+ "get_table_description() method."
)
def get_migratable_models(self):
from django.apps import apps
from django.db import router
+
return (
model
for app_config in apps.get_app_configs()
@@ -91,16 +100,15 @@ class BaseDatabaseIntrospection:
continue
tables.add(model._meta.db_table)
tables.update(
- f.m2m_db_table() for f in model._meta.local_many_to_many
+ f.m2m_db_table()
+ for f in model._meta.local_many_to_many
if f.remote_field.through._meta.managed
)
tables = list(tables)
if only_existing:
existing_tables = set(self.table_names(include_views=include_views))
tables = [
- t
- for t in tables
- if self.identifier_converter(t) in existing_tables
+ t for t in tables if self.identifier_converter(t) in existing_tables
]
return tables
@@ -111,7 +119,8 @@ class BaseDatabaseIntrospection:
"""
tables = set(map(self.identifier_converter, tables))
return {
- m for m in self.get_migratable_models()
+ m
+ for m in self.get_migratable_models()
if self.identifier_converter(m._meta.db_table) in tables
}
@@ -127,13 +136,19 @@ class BaseDatabaseIntrospection:
continue
if model._meta.swapped:
continue
- sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
+ sequence_list.extend(
+ self.get_sequences(
+ cursor, model._meta.db_table, model._meta.local_fields
+ )
+ )
for f in model._meta.local_many_to_many:
# If this is an m2m using an intermediate table,
# we don't need to reset the sequence.
if f.remote_field.through._meta.auto_created:
sequence = self.get_sequences(cursor, f.m2m_db_table())
- sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
+ sequence_list.extend(
+ sequence or [{"table": f.m2m_db_table(), "column": None}]
+ )
return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()):
@@ -142,7 +157,9 @@ class BaseDatabaseIntrospection:
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
'name' key can be added if the backend supports named sequences.
"""
- raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseIntrospection may require a get_sequences() method"
+ )
def get_relations(self, cursor, table_name):
"""
@@ -150,8 +167,8 @@ class BaseDatabaseIntrospection:
representing all foreign keys in the given table.
"""
raise NotImplementedError(
- 'subclasses of BaseDatabaseIntrospection may require a '
- 'get_relations() method.'
+ "subclasses of BaseDatabaseIntrospection may require a "
+ "get_relations() method."
)
def get_primary_key_column(self, cursor, table_name):
@@ -159,8 +176,8 @@ class BaseDatabaseIntrospection:
Return the name of the primary key column for the given table.
"""
for constraint in self.get_constraints(cursor, table_name).values():
- if constraint['primary_key']:
- return constraint['columns'][0]
+ if constraint["primary_key"]:
+ return constraint["columns"][0]
return None
def get_constraints(self, cursor, table_name):
@@ -182,4 +199,6 @@ class BaseDatabaseIntrospection:
Some backends may return special constraint names that don't exist
if they don't name constraints of a certain type (e.g. SQLite)
"""
- raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_constraints() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseIntrospection may require a get_constraints() method"
+ )
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index 7422137304..5201e53af6 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -16,25 +16,26 @@ class BaseDatabaseOperations:
Encapsulate backend-specific differences, such as the way a backend
performs ordering or calculates the ID of a recently-inserted row.
"""
+
compiler_module = "django.db.models.sql.compiler"
# Integer field safe ranges by `internal_type` as documented
# in docs/ref/models/fields.txt.
integer_field_ranges = {
- 'SmallIntegerField': (-32768, 32767),
- 'IntegerField': (-2147483648, 2147483647),
- 'BigIntegerField': (-9223372036854775808, 9223372036854775807),
- 'PositiveBigIntegerField': (0, 9223372036854775807),
- 'PositiveSmallIntegerField': (0, 32767),
- 'PositiveIntegerField': (0, 2147483647),
- 'SmallAutoField': (-32768, 32767),
- 'AutoField': (-2147483648, 2147483647),
- 'BigAutoField': (-9223372036854775808, 9223372036854775807),
+ "SmallIntegerField": (-32768, 32767),
+ "IntegerField": (-2147483648, 2147483647),
+ "BigIntegerField": (-9223372036854775808, 9223372036854775807),
+ "PositiveBigIntegerField": (0, 9223372036854775807),
+ "PositiveSmallIntegerField": (0, 32767),
+ "PositiveIntegerField": (0, 2147483647),
+ "SmallAutoField": (-32768, 32767),
+ "AutoField": (-2147483648, 2147483647),
+ "BigAutoField": (-9223372036854775808, 9223372036854775807),
}
set_operators = {
- 'union': 'UNION',
- 'intersection': 'INTERSECT',
- 'difference': 'EXCEPT',
+ "union": "UNION",
+ "intersection": "INTERSECT",
+ "difference": "EXCEPT",
}
# Mapping of Field.get_internal_type() (typically the model field's class
# name) to the data type to use for the Cast() function, if different from
@@ -44,11 +45,11 @@ class BaseDatabaseOperations:
cast_char_field_without_max_length = None
# Start and end points for window expressions.
- PRECEDING = 'PRECEDING'
- FOLLOWING = 'FOLLOWING'
- UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
- UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
- CURRENT_ROW = 'CURRENT ROW'
+ PRECEDING = "PRECEDING"
+ FOLLOWING = "FOLLOWING"
+ UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
+ UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
+ CURRENT_ROW = "CURRENT ROW"
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
explain_prefix = None
@@ -76,8 +77,8 @@ class BaseDatabaseOperations:
def format_for_duration_arithmetic(self, sql):
raise NotImplementedError(
- 'subclasses of BaseDatabaseOperations may require a '
- 'format_for_duration_arithmetic() method.'
+ "subclasses of BaseDatabaseOperations may require a "
+ "format_for_duration_arithmetic() method."
)
def cache_key_culling_sql(self):
@@ -88,8 +89,8 @@ class BaseDatabaseOperations:
This is used by the 'db' cache backend to determine where to start
culling.
"""
- cache_key = self.quote_name('cache_key')
- return f'SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s'
+ cache_key = self.quote_name("cache_key")
+ return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s"
def unification_cast_sql(self, output_field):
"""
@@ -97,14 +98,16 @@ class BaseDatabaseOperations:
to that type. The resulting string should contain a '%s' placeholder
for the expression being cast.
"""
- return '%s'
+ return "%s"
def date_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
extracts a value from the given date field field_name.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a date_extract_sql() method"
+ )
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
@@ -115,22 +118,26 @@ class BaseDatabaseOperations:
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a date_trunc_sql() method."
+ )
def datetime_cast_date_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to date value.
"""
raise NotImplementedError(
- 'subclasses of BaseDatabaseOperations may require a '
- 'datetime_cast_date_sql() method.'
+ "subclasses of BaseDatabaseOperations may require a "
+ "datetime_cast_date_sql() method."
)
def datetime_cast_time_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to time value.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method"
+ )
def datetime_extract_sql(self, lookup_type, field_name, tzname):
"""
@@ -138,7 +145,9 @@ class BaseDatabaseOperations:
'second', return the SQL that extracts a value from the given
datetime field field_name.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method"
+ )
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
"""
@@ -146,7 +155,9 @@ class BaseDatabaseOperations:
'second', return the SQL that truncates the given datetime field
field_name to a datetime object with only the given specificity.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method"
+ )
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
@@ -157,7 +168,9 @@ class BaseDatabaseOperations:
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
+ )
def time_extract_sql(self, lookup_type, field_name):
"""
@@ -171,7 +184,7 @@ class BaseDatabaseOperations:
Return the SQL to make a constraint "initially deferred" during a
CREATE TABLE statement.
"""
- return ''
+ return ""
def distinct_sql(self, fields, params):
"""
@@ -180,9 +193,11 @@ class BaseDatabaseOperations:
duplicates.
"""
if fields:
- raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
+ raise NotSupportedError(
+ "DISTINCT ON fields is not supported by this database backend"
+ )
else:
- return ['DISTINCT'], []
+ return ["DISTINCT"], []
def fetch_returned_insert_columns(self, cursor, returning_params):
"""
@@ -198,7 +213,7 @@ class BaseDatabaseOperations:
it in a WHERE statement. The resulting string should contain a '%s'
placeholder for the column being searched against.
"""
- return '%s'
+ return "%s"
def force_no_ordering(self):
"""
@@ -211,11 +226,11 @@ class BaseDatabaseOperations:
"""
Return the FOR UPDATE SQL clause to lock rows for an update operation.
"""
- return 'FOR%s UPDATE%s%s%s' % (
- ' NO KEY' if no_key else '',
- ' OF %s' % ', '.join(of) if of else '',
- ' NOWAIT' if nowait else '',
- ' SKIP LOCKED' if skip_locked else '',
+ return "FOR%s UPDATE%s%s%s" % (
+ " NO KEY" if no_key else "",
+ " OF %s" % ", ".join(of) if of else "",
+ " NOWAIT" if nowait else "",
+ " SKIP LOCKED" if skip_locked else "",
)
def _get_limit_offset_params(self, low_mark, high_mark):
@@ -229,10 +244,14 @@ class BaseDatabaseOperations:
def limit_offset_sql(self, low_mark, high_mark):
"""Return LIMIT/OFFSET SQL clause."""
limit, offset = self._get_limit_offset_params(low_mark, high_mark)
- return ' '.join(sql for sql in (
- ('LIMIT %d' % limit) if limit else None,
- ('OFFSET %d' % offset) if offset else None,
- ) if sql)
+ return " ".join(
+ sql
+ for sql in (
+ ("LIMIT %d" % limit) if limit else None,
+ ("OFFSET %d" % offset) if offset else None,
+ )
+ if sql
+ )
def last_executed_query(self, cursor, sql, params):
"""
@@ -246,7 +265,8 @@ class BaseDatabaseOperations:
"""
# Convert params to contain string values.
def to_string(s):
- return force_str(s, strings_only=True, errors='replace')
+ return force_str(s, strings_only=True, errors="replace")
+
if isinstance(params, (list, tuple)):
u_params = tuple(to_string(val) for val in params)
elif params is None:
@@ -292,14 +312,16 @@ class BaseDatabaseOperations:
Return the value to use for the LIMIT when we are wanting "LIMIT
infinity". Return None if the limit clause can be omitted in this case.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a no_limit_value() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
+ )
def pk_default_value(self):
"""
Return the value to use during an INSERT statement to specify that
the field should use its default value.
"""
- return 'DEFAULT'
+ return "DEFAULT"
def prepare_sql_script(self, sql):
"""
@@ -312,7 +334,8 @@ class BaseDatabaseOperations:
"""
return [
sqlparse.format(statement, strip_comments=True)
- for statement in sqlparse.split(sql) if statement
+ for statement in sqlparse.split(sql)
+ if statement
]
def process_clob(self, value):
@@ -345,7 +368,9 @@ class BaseDatabaseOperations:
Return a quoted version of the given table, index, or column name. Do
not quote the given name if it's already been quoted.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a quote_name() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a quote_name() method"
+ )
def regex_lookup(self, lookup_type):
"""
@@ -356,7 +381,9 @@ class BaseDatabaseOperations:
If the feature is not supported (or part of it is not supported), raise
NotImplementedError.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations may require a regex_lookup() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
+ )
def savepoint_create_sql(self, sid):
"""
@@ -384,7 +411,7 @@ class BaseDatabaseOperations:
Return '' if the backend doesn't support time zones.
"""
- return ''
+ return ""
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
"""
@@ -402,7 +429,9 @@ class BaseDatabaseOperations:
to tables with foreign keys pointing the tables being truncated.
PostgreSQL requires a cascade even if these tables are empty.
"""
- raise NotImplementedError('subclasses of BaseDatabaseOperations must provide an sql_flush() method')
+ raise NotImplementedError(
+ "subclasses of BaseDatabaseOperations must provide an sql_flush() method"
+ )
def execute_sql_flush(self, sql_list):
"""Execute a list of SQL statements to flush the database."""
@@ -453,7 +482,7 @@ class BaseDatabaseOperations:
If `inline` is True, append the SQL to a row; otherwise append it to
the entire CREATE TABLE or CREATE INDEX statement.
"""
- return ''
+ return ""
def prep_for_like_query(self, x):
"""Prepare a value for use in a LIKE query."""
@@ -479,7 +508,7 @@ class BaseDatabaseOperations:
cases where the target type isn't known, such as .raw() SQL queries.
As a consequence it may not work perfectly in all circumstances.
"""
- if isinstance(value, datetime.datetime): # must be before date
+ if isinstance(value, datetime.datetime): # must be before date
return self.adapt_datetimefield_value(value)
elif isinstance(value, datetime.date):
return self.adapt_datefield_value(value)
@@ -507,7 +536,7 @@ class BaseDatabaseOperations:
if value is None:
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
return str(value)
@@ -520,7 +549,7 @@ class BaseDatabaseOperations:
if value is None:
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
if timezone.is_aware(value):
@@ -552,10 +581,9 @@ class BaseDatabaseOperations:
"""
if iso_year:
first = datetime.date.fromisocalendar(value, 1, 1)
- second = (
- datetime.date.fromisocalendar(value + 1, 1, 1) -
- datetime.timedelta(days=1)
- )
+ second = datetime.date.fromisocalendar(
+ value + 1, 1, 1
+ ) - datetime.timedelta(days=1)
else:
first = datetime.date(value, 1, 1)
second = datetime.date(value, 12, 31)
@@ -574,10 +602,9 @@ class BaseDatabaseOperations:
"""
if iso_year:
first = datetime.datetime.fromisocalendar(value, 1, 1)
- second = (
- datetime.datetime.fromisocalendar(value + 1, 1, 1) -
- datetime.timedelta(microseconds=1)
- )
+ second = datetime.datetime.fromisocalendar(
+ value + 1, 1, 1
+ ) - datetime.timedelta(microseconds=1)
else:
first = datetime.datetime(value, 1, 1)
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
@@ -627,7 +654,7 @@ class BaseDatabaseOperations:
can vary between backends (e.g., Oracle with %% and &) and between
subexpression types (e.g., date expressions).
"""
- conn = ' %s ' % connector
+ conn = " %s " % connector
return conn.join(sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
@@ -638,7 +665,7 @@ class BaseDatabaseOperations:
Some backends require special syntax to insert binary content (MySQL
for example uses '_binary %s').
"""
- return '%s'
+ return "%s"
def modify_insert_params(self, placeholder, params):
"""
@@ -659,66 +686,76 @@ class BaseDatabaseOperations:
if self.connection.features.supports_temporal_subtraction:
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
- return '(%s - %s)' % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
- raise NotSupportedError("This backend does not support %s subtraction." % internal_type)
+ return "(%s - %s)" % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
+ raise NotSupportedError(
+ "This backend does not support %s subtraction." % internal_type
+ )
def window_frame_start(self, start):
if isinstance(start, int):
if start < 0:
- return '%d %s' % (abs(start), self.PRECEDING)
+ return "%d %s" % (abs(start), self.PRECEDING)
elif start == 0:
return self.CURRENT_ROW
elif start is None:
return self.UNBOUNDED_PRECEDING
- raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
+ raise ValueError(
+ "start argument must be a negative integer, zero, or None, but got '%s'."
+ % start
+ )
def window_frame_end(self, end):
if isinstance(end, int):
if end == 0:
return self.CURRENT_ROW
elif end > 0:
- return '%d %s' % (end, self.FOLLOWING)
+ return "%d %s" % (end, self.FOLLOWING)
elif end is None:
return self.UNBOUNDED_FOLLOWING
- raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
+ raise ValueError(
+ "end argument must be a positive integer, zero, or None, but got '%s'."
+ % end
+ )
def window_frame_rows_start_end(self, start=None, end=None):
"""
Return SQL for start and end points in an OVER clause window frame.
"""
if not self.connection.features.supports_over_clause:
- raise NotSupportedError('This backend does not support window expressions.')
+ raise NotSupportedError("This backend does not support window expressions.")
return self.window_frame_start(start), self.window_frame_end(end)
def window_frame_range_start_end(self, start=None, end=None):
start_, end_ = self.window_frame_rows_start_end(start, end)
if (
- self.connection.features.only_supports_unbounded_with_preceding_and_following and
- ((start and start < 0) or (end and end > 0))
+ self.connection.features.only_supports_unbounded_with_preceding_and_following
+ and ((start and start < 0) or (end and end > 0))
):
raise NotSupportedError(
- '%s only supports UNBOUNDED together with PRECEDING and '
- 'FOLLOWING.' % self.connection.display_name
+ "%s only supports UNBOUNDED together with PRECEDING and "
+ "FOLLOWING." % self.connection.display_name
)
return start_, end_
def explain_query_prefix(self, format=None, **options):
if not self.connection.features.supports_explaining_query_execution:
- raise NotSupportedError('This backend does not support explaining query execution.')
+ raise NotSupportedError(
+ "This backend does not support explaining query execution."
+ )
if format:
supported_formats = self.connection.features.supported_explain_formats
normalized_format = format.upper()
if normalized_format not in supported_formats:
- msg = '%s is not a recognized format.' % normalized_format
+ msg = "%s is not a recognized format." % normalized_format
if supported_formats:
- msg += ' Allowed formats: %s' % ', '.join(sorted(supported_formats))
+ msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
raise ValueError(msg)
if options:
- raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
+ raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))
return self.explain_prefix
def insert_statement(self, on_conflict=None):
- return 'INSERT INTO'
+ return "INSERT INTO"
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
- return ''
+ return ""
diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py
index 4cd4567cbc..ea98e86b77 100644
--- a/django/db/backends/base/schema.py
+++ b/django/db/backends/base/schema.py
@@ -2,7 +2,12 @@ import logging
from datetime import datetime
from django.db.backends.ddl_references import (
- Columns, Expressions, ForeignKeyName, IndexName, Statement, Table,
+ Columns,
+ Expressions,
+ ForeignKeyName,
+ IndexName,
+ Statement,
+ Table,
)
from django.db.backends.utils import names_digest, split_identifier
from django.db.models import Deferrable, Index
@@ -10,7 +15,7 @@ from django.db.models.sql import Query
from django.db.transaction import TransactionManagementError, atomic
from django.utils import timezone
-logger = logging.getLogger('django.db.backends.schema')
+logger = logging.getLogger("django.db.backends.schema")
def _is_relevant_relation(relation, altered_field):
@@ -31,7 +36,10 @@ def _is_relevant_relation(relation, altered_field):
def _all_related_fields(model):
return model._meta._get_fields(
- forward=False, reverse=True, include_hidden=True, include_parents=False,
+ forward=False,
+ reverse=True,
+ include_hidden=True,
+ include_parents=False,
)
@@ -39,8 +47,16 @@ def _related_non_m2m_objects(old_field, new_field):
# Filter out m2m objects from reverse relations.
# Return (old_relation, new_relation) tuples.
related_fields = zip(
- (obj for obj in _all_related_fields(old_field.model) if _is_relevant_relation(obj, old_field)),
- (obj for obj in _all_related_fields(new_field.model) if _is_relevant_relation(obj, new_field)),
+ (
+ obj
+ for obj in _all_related_fields(old_field.model)
+ if _is_relevant_relation(obj, old_field)
+ ),
+ (
+ obj
+ for obj in _all_related_fields(new_field.model)
+ if _is_relevant_relation(obj, new_field)
+ ),
)
for old_rel, new_rel in related_fields:
yield old_rel, new_rel
@@ -73,8 +89,12 @@ class BaseDatabaseSchemaEditor:
sql_alter_column_no_default_null = sql_alter_column_no_default
sql_alter_column_collate = "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
- sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
- sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
+ sql_rename_column = (
+ "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
+ )
+ sql_update_with_default = (
+ "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
+ )
sql_unique_constraint = "UNIQUE (%(columns)s)%(deferrable)s"
sql_check_constraint = "CHECK (%(check)s)"
@@ -99,10 +119,12 @@ class BaseDatabaseSchemaEditor:
sql_create_unique_index = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)%(include)s%(condition)s"
sql_delete_index = "DROP INDEX %(name)s"
- sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
+ sql_create_pk = (
+ "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
+ )
sql_delete_pk = sql_delete_constraint
- sql_delete_procedure = 'DROP PROCEDURE %(procedure)s'
+ sql_delete_procedure = "DROP PROCEDURE %(procedure)s"
def __init__(self, connection, collect_sql=False, atomic=True):
self.connection = connection
@@ -133,7 +155,11 @@ class BaseDatabaseSchemaEditor:
"""Execute the given SQL statement, with optional parameters."""
# Don't perform the transactional DDL check if SQL is being collected
# as it's not going to be executed anyway.
- if not self.collect_sql and self.connection.in_atomic_block and not self.connection.features.can_rollback_ddl:
+ if (
+ not self.collect_sql
+ and self.connection.in_atomic_block
+ and not self.connection.features.can_rollback_ddl
+ ):
raise TransactionManagementError(
"Executing DDL statements while in a transaction on databases "
"that can't perform a rollback is prohibited."
@@ -141,11 +167,15 @@ class BaseDatabaseSchemaEditor:
# Account for non-string statement objects.
sql = str(sql)
# Log the command we're running, then run it
- logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql})
+ logger.debug(
+ "%s; (params %r)", sql, params, extra={"params": params, "sql": sql}
+ )
if self.collect_sql:
ending = "" if sql.rstrip().endswith(";") else ";"
if params is not None:
- self.collected_sql.append((sql % tuple(map(self.quote_value, params))) + ending)
+ self.collected_sql.append(
+ (sql % tuple(map(self.quote_value, params))) + ending
+ )
else:
self.collected_sql.append(sql + ending)
else:
@@ -172,59 +202,82 @@ class BaseDatabaseSchemaEditor:
continue
# Check constraints can go on the column SQL here.
db_params = field.db_parameters(connection=self.connection)
- if db_params['check']:
- definition += ' ' + self.sql_check_constraint % db_params
+ if db_params["check"]:
+ definition += " " + self.sql_check_constraint % db_params
# Autoincrement SQL (for backends with inline variant).
col_type_suffix = field.db_type_suffix(connection=self.connection)
if col_type_suffix:
- definition += ' %s' % col_type_suffix
+ definition += " %s" % col_type_suffix
params.extend(extra_params)
# FK.
if field.remote_field and field.db_constraint:
to_table = field.remote_field.model._meta.db_table
- to_column = field.remote_field.model._meta.get_field(field.remote_field.field_name).column
+ to_column = field.remote_field.model._meta.get_field(
+ field.remote_field.field_name
+ ).column
if self.sql_create_inline_fk:
- definition += ' ' + self.sql_create_inline_fk % {
- 'to_table': self.quote_name(to_table),
- 'to_column': self.quote_name(to_column),
+ definition += " " + self.sql_create_inline_fk % {
+ "to_table": self.quote_name(to_table),
+ "to_column": self.quote_name(to_column),
}
elif self.connection.features.supports_foreign_keys:
- self.deferred_sql.append(self._create_fk_sql(model, field, '_fk_%(to_table)s_%(to_column)s'))
+ self.deferred_sql.append(
+ self._create_fk_sql(
+ model, field, "_fk_%(to_table)s_%(to_column)s"
+ )
+ )
# Add the SQL to our big list.
- column_sqls.append('%s %s' % (
- self.quote_name(field.column),
- definition,
- ))
+ column_sqls.append(
+ "%s %s"
+ % (
+ self.quote_name(field.column),
+ definition,
+ )
+ )
# Autoincrement SQL (for backends with post table definition
# variant).
- if field.get_internal_type() in ('AutoField', 'BigAutoField', 'SmallAutoField'):
- autoinc_sql = self.connection.ops.autoinc_sql(model._meta.db_table, field.column)
+ if field.get_internal_type() in (
+ "AutoField",
+ "BigAutoField",
+ "SmallAutoField",
+ ):
+ autoinc_sql = self.connection.ops.autoinc_sql(
+ model._meta.db_table, field.column
+ )
if autoinc_sql:
self.deferred_sql.extend(autoinc_sql)
- constraints = [constraint.constraint_sql(model, self) for constraint in model._meta.constraints]
+ constraints = [
+ constraint.constraint_sql(model, self)
+ for constraint in model._meta.constraints
+ ]
sql = self.sql_create_table % {
- 'table': self.quote_name(model._meta.db_table),
- 'definition': ', '.join(constraint for constraint in (*column_sqls, *constraints) if constraint),
+ "table": self.quote_name(model._meta.db_table),
+ "definition": ", ".join(
+ constraint for constraint in (*column_sqls, *constraints) if constraint
+ ),
}
if model._meta.db_tablespace:
- tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace)
+ tablespace_sql = self.connection.ops.tablespace_sql(
+ model._meta.db_tablespace
+ )
if tablespace_sql:
- sql += ' ' + tablespace_sql
+ sql += " " + tablespace_sql
return sql, params
# Field <-> database mapping functions
def _iter_column_sql(self, column_db_type, params, model, field, include_default):
yield column_db_type
- collation = getattr(field, 'db_collation', None)
+ collation = getattr(field, "db_collation", None)
if collation:
yield self._collate_sql(collation)
# Work out nullability.
null = field.null
# Include a default value, if requested.
include_default = (
- include_default and
- not self.skip_default(field) and
+ include_default
+ and not self.skip_default(field)
+ and
# Don't include a default value if it's a nullable field and the
# default cannot be dropped in the ALTER COLUMN statement (e.g.
# MySQL longtext and longblob).
@@ -233,7 +286,7 @@ class BaseDatabaseSchemaEditor:
if include_default:
default_value = self.effective_default(field)
if default_value is not None:
- column_default = 'DEFAULT ' + self._column_default_sql(field)
+ column_default = "DEFAULT " + self._column_default_sql(field)
if self.connection.features.requires_literal_defaults:
# Some databases can't take defaults as a parameter (Oracle).
# If this is the case, the individual schema backend should
@@ -244,20 +297,27 @@ class BaseDatabaseSchemaEditor:
params.append(default_value)
# Oracle treats the empty string ('') as null, so coerce the null
# option whenever '' is a possible value.
- if (field.empty_strings_allowed and not field.primary_key and
- self.connection.features.interprets_empty_strings_as_nulls):
+ if (
+ field.empty_strings_allowed
+ and not field.primary_key
+ and self.connection.features.interprets_empty_strings_as_nulls
+ ):
null = True
if not null:
- yield 'NOT NULL'
+ yield "NOT NULL"
elif not self.connection.features.implied_column_null:
- yield 'NULL'
+ yield "NULL"
if field.primary_key:
- yield 'PRIMARY KEY'
+ yield "PRIMARY KEY"
elif field.unique:
- yield 'UNIQUE'
+ yield "UNIQUE"
# Optionally add the tablespace if it's an implicitly indexed column.
tablespace = field.db_tablespace or model._meta.db_tablespace
- if tablespace and self.connection.features.supports_tablespaces and field.unique:
+ if (
+ tablespace
+ and self.connection.features.supports_tablespaces
+ and field.unique
+ ):
yield self.connection.ops.tablespace_sql(tablespace, inline=True)
def column_sql(self, model, field, include_default=False):
@@ -267,15 +327,20 @@ class BaseDatabaseSchemaEditor:
"""
# Get the column's type and use that as the basis of the SQL.
db_params = field.db_parameters(connection=self.connection)
- column_db_type = db_params['type']
+ column_db_type = db_params["type"]
# Check for fields that aren't actually columns (e.g. M2M).
if column_db_type is None:
return None, None
params = []
- return ' '.join(
- # This appends to the params being returned.
- self._iter_column_sql(column_db_type, params, model, field, include_default)
- ), params
+ return (
+ " ".join(
+ # This appends to the params being returned.
+ self._iter_column_sql(
+ column_db_type, params, model, field, include_default
+ )
+ ),
+ params,
+ )
def skip_default(self, field):
"""
@@ -296,8 +361,8 @@ class BaseDatabaseSchemaEditor:
Only used for backends which have requires_literal_defaults feature
"""
raise NotImplementedError(
- 'subclasses of BaseDatabaseSchemaEditor for backends which have '
- 'requires_literal_defaults must provide a prepare_default() method'
+ "subclasses of BaseDatabaseSchemaEditor for backends which have "
+ "requires_literal_defaults must provide a prepare_default() method"
)
def _column_default_sql(self, field):
@@ -305,7 +370,7 @@ class BaseDatabaseSchemaEditor:
Return the SQL to use in a DEFAULT clause. The resulting string should
contain a '%s' placeholder for a default value.
"""
- return '%s'
+ return "%s"
@staticmethod
def _effective_default(field):
@@ -314,18 +379,18 @@ class BaseDatabaseSchemaEditor:
default = field.get_default()
elif not field.null and field.blank and field.empty_strings_allowed:
if field.get_internal_type() == "BinaryField":
- default = b''
+ default = b""
else:
- default = ''
- elif getattr(field, 'auto_now', False) or getattr(field, 'auto_now_add', False):
+ default = ""
+ elif getattr(field, "auto_now", False) or getattr(field, "auto_now_add", False):
internal_type = field.get_internal_type()
- if internal_type == 'DateTimeField':
+ if internal_type == "DateTimeField":
default = timezone.now()
else:
default = datetime.now()
- if internal_type == 'DateField':
+ if internal_type == "DateField":
default = default.date()
- elif internal_type == 'TimeField':
+ elif internal_type == "TimeField":
default = default.time()
else:
default = None
@@ -372,19 +437,24 @@ class BaseDatabaseSchemaEditor:
self.delete_model(field.remote_field.through)
# Delete the table
- self.execute(self.sql_delete_table % {
- "table": self.quote_name(model._meta.db_table),
- })
+ self.execute(
+ self.sql_delete_table
+ % {
+ "table": self.quote_name(model._meta.db_table),
+ }
+ )
# Remove all deferred statements referencing the deleted table.
for sql in list(self.deferred_sql):
- if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
+ if isinstance(sql, Statement) and sql.references_table(
+ model._meta.db_table
+ ):
self.deferred_sql.remove(sql)
def add_index(self, model, index):
"""Add an index on a model."""
if (
- index.contains_expressions and
- not self.connection.features.supports_expression_indexes
+ index.contains_expressions
+ and not self.connection.features.supports_expression_indexes
):
return None
# Index.create_sql returns interpolated SQL which makes params=None a
@@ -394,8 +464,8 @@ class BaseDatabaseSchemaEditor:
def remove_index(self, model, index):
"""Remove an index from a model."""
if (
- index.contains_expressions and
- not self.connection.features.supports_expression_indexes
+ index.contains_expressions
+ and not self.connection.features.supports_expression_indexes
):
return None
self.execute(index.remove_sql(model, self))
@@ -424,7 +494,9 @@ class BaseDatabaseSchemaEditor:
news = {tuple(fields) for fields in new_unique_together}
# Deleted uniques
for fields in olds.difference(news):
- self._delete_composed_index(model, fields, {'unique': True}, self.sql_delete_unique)
+ self._delete_composed_index(
+ model, fields, {"unique": True}, self.sql_delete_unique
+ )
# Created uniques
for field_names in news.difference(olds):
fields = [model._meta.get_field(field) for field in field_names]
@@ -443,40 +515,51 @@ class BaseDatabaseSchemaEditor:
self._delete_composed_index(
model,
fields,
- {'index': True, 'unique': False},
+ {"index": True, "unique": False},
self.sql_delete_index,
)
# Created indexes
for field_names in news.difference(olds):
fields = [model._meta.get_field(field) for field in field_names]
- self.execute(self._create_index_sql(model, fields=fields, suffix='_idx'))
+ self.execute(self._create_index_sql(model, fields=fields, suffix="_idx"))
def _delete_composed_index(self, model, fields, constraint_kwargs, sql):
- meta_constraint_names = {constraint.name for constraint in model._meta.constraints}
+ meta_constraint_names = {
+ constraint.name for constraint in model._meta.constraints
+ }
meta_index_names = {constraint.name for constraint in model._meta.indexes}
columns = [model._meta.get_field(field).column for field in fields]
constraint_names = self._constraint_names(
- model, columns, exclude=meta_constraint_names | meta_index_names,
- **constraint_kwargs
+ model,
+ columns,
+ exclude=meta_constraint_names | meta_index_names,
+ **constraint_kwargs,
)
if len(constraint_names) != 1:
- raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % (
- len(constraint_names),
- model._meta.db_table,
- ", ".join(columns),
- ))
+ raise ValueError(
+ "Found wrong number (%s) of constraints for %s(%s)"
+ % (
+ len(constraint_names),
+ model._meta.db_table,
+ ", ".join(columns),
+ )
+ )
self.execute(self._delete_constraint_sql(sql, model, constraint_names[0]))
def alter_db_table(self, model, old_db_table, new_db_table):
"""Rename the table a model points to."""
- if (old_db_table == new_db_table or
- (self.connection.features.ignores_table_name_case and
- old_db_table.lower() == new_db_table.lower())):
+ if old_db_table == new_db_table or (
+ self.connection.features.ignores_table_name_case
+ and old_db_table.lower() == new_db_table.lower()
+ ):
return
- self.execute(self.sql_rename_table % {
- "old_table": self.quote_name(old_db_table),
- "new_table": self.quote_name(new_db_table),
- })
+ self.execute(
+ self.sql_rename_table
+ % {
+ "old_table": self.quote_name(old_db_table),
+ "new_table": self.quote_name(new_db_table),
+ }
+ )
# Rename all references to the old table name.
for sql in self.deferred_sql:
if isinstance(sql, Statement):
@@ -484,11 +567,14 @@ class BaseDatabaseSchemaEditor:
def alter_db_tablespace(self, model, old_db_tablespace, new_db_tablespace):
"""Move a model's table between tablespaces."""
- self.execute(self.sql_retablespace_table % {
- "table": self.quote_name(model._meta.db_table),
- "old_tablespace": self.quote_name(old_db_tablespace),
- "new_tablespace": self.quote_name(new_db_tablespace),
- })
+ self.execute(
+ self.sql_retablespace_table
+ % {
+ "table": self.quote_name(model._meta.db_table),
+ "old_tablespace": self.quote_name(old_db_tablespace),
+ "new_tablespace": self.quote_name(new_db_tablespace),
+ }
+ )
def add_field(self, model, field):
"""
@@ -505,26 +591,36 @@ class BaseDatabaseSchemaEditor:
return
# Check constraints can go on the column SQL here
db_params = field.db_parameters(connection=self.connection)
- if db_params['check']:
+ if db_params["check"]:
definition += " " + self.sql_check_constraint % db_params
- if field.remote_field and self.connection.features.supports_foreign_keys and field.db_constraint:
- constraint_suffix = '_fk_%(to_table)s_%(to_column)s'
+ if (
+ field.remote_field
+ and self.connection.features.supports_foreign_keys
+ and field.db_constraint
+ ):
+ constraint_suffix = "_fk_%(to_table)s_%(to_column)s"
# Add FK constraint inline, if supported.
if self.sql_create_column_inline_fk:
to_table = field.remote_field.model._meta.db_table
- to_column = field.remote_field.model._meta.get_field(field.remote_field.field_name).column
+ to_column = field.remote_field.model._meta.get_field(
+ field.remote_field.field_name
+ ).column
namespace, _ = split_identifier(model._meta.db_table)
definition += " " + self.sql_create_column_inline_fk % {
- 'name': self._fk_constraint_name(model, field, constraint_suffix),
- 'namespace': '%s.' % self.quote_name(namespace) if namespace else '',
- 'column': self.quote_name(field.column),
- 'to_table': self.quote_name(to_table),
- 'to_column': self.quote_name(to_column),
- 'deferrable': self.connection.ops.deferrable_sql()
+ "name": self._fk_constraint_name(model, field, constraint_suffix),
+ "namespace": "%s." % self.quote_name(namespace)
+ if namespace
+ else "",
+ "column": self.quote_name(field.column),
+ "to_table": self.quote_name(to_table),
+ "to_column": self.quote_name(to_column),
+ "deferrable": self.connection.ops.deferrable_sql(),
}
# Otherwise, add FK constraints later.
else:
- self.deferred_sql.append(self._create_fk_sql(model, field, constraint_suffix))
+ self.deferred_sql.append(
+ self._create_fk_sql(model, field, constraint_suffix)
+ )
# Build the SQL and run it
sql = self.sql_create_column % {
"table": self.quote_name(model._meta.db_table),
@@ -534,8 +630,13 @@ class BaseDatabaseSchemaEditor:
self.execute(sql, params)
# Drop the default if we need to
# (Django usually does not use in-database defaults)
- if not self.skip_default_on_alter(field) and self.effective_default(field) is not None:
- changes_sql, params = self._alter_column_default_sql(model, None, field, drop=True)
+ if (
+ not self.skip_default_on_alter(field)
+ and self.effective_default(field) is not None
+ ):
+ changes_sql, params = self._alter_column_default_sql(
+ model, None, field, drop=True
+ )
sql = self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": changes_sql,
@@ -556,7 +657,7 @@ class BaseDatabaseSchemaEditor:
if field.many_to_many and field.remote_field.through._meta.auto_created:
return self.delete_model(field.remote_field.through)
# It might not actually have a column behind it
- if field.db_parameters(connection=self.connection)['type'] is None:
+ if field.db_parameters(connection=self.connection)["type"] is None:
return
# Drop any FK constraints, MySQL requires explicit deletion
if field.remote_field:
@@ -574,7 +675,9 @@ class BaseDatabaseSchemaEditor:
self.connection.close()
# Remove all deferred statements referencing the deleted column.
for sql in list(self.deferred_sql):
- if isinstance(sql, Statement) and sql.references_column(model._meta.db_table, field.column):
+ if isinstance(sql, Statement) and sql.references_column(
+ model._meta.db_table, field.column
+ ):
self.deferred_sql.remove(sql)
def alter_field(self, model, old_field, new_field, strict=False):
@@ -589,25 +692,38 @@ class BaseDatabaseSchemaEditor:
return
# Ensure this field is even column-based
old_db_params = old_field.db_parameters(connection=self.connection)
- old_type = old_db_params['type']
+ old_type = old_db_params["type"]
new_db_params = new_field.db_parameters(connection=self.connection)
- new_type = new_db_params['type']
- if ((old_type is None and old_field.remote_field is None) or
- (new_type is None and new_field.remote_field is None)):
+ new_type = new_db_params["type"]
+ if (old_type is None and old_field.remote_field is None) or (
+ new_type is None and new_field.remote_field is None
+ ):
raise ValueError(
"Cannot alter field %s into %s - they do not properly define "
- "db_type (are you using a badly-written custom field?)" %
- (old_field, new_field),
+ "db_type (are you using a badly-written custom field?)"
+ % (old_field, new_field),
+ )
+ elif (
+ old_type is None
+ and new_type is None
+ and (
+ old_field.remote_field.through
+ and new_field.remote_field.through
+ and old_field.remote_field.through._meta.auto_created
+ and new_field.remote_field.through._meta.auto_created
)
- elif old_type is None and new_type is None and (
- old_field.remote_field.through and new_field.remote_field.through and
- old_field.remote_field.through._meta.auto_created and
- new_field.remote_field.through._meta.auto_created):
+ ):
return self._alter_many_to_many(model, old_field, new_field, strict)
- elif old_type is None and new_type is None and (
- old_field.remote_field.through and new_field.remote_field.through and
- not old_field.remote_field.through._meta.auto_created and
- not new_field.remote_field.through._meta.auto_created):
+ elif (
+ old_type is None
+ and new_type is None
+ and (
+ old_field.remote_field.through
+ and new_field.remote_field.through
+ and not old_field.remote_field.through._meta.auto_created
+ and not new_field.remote_field.through._meta.auto_created
+ )
+ ):
# Both sides have through models; this is a no-op.
return
elif old_type is None or new_type is None:
@@ -617,52 +733,86 @@ class BaseDatabaseSchemaEditor:
"through= on M2M fields)" % (old_field, new_field)
)
- self._alter_field(model, old_field, new_field, old_type, new_type,
- old_db_params, new_db_params, strict)
+ self._alter_field(
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict,
+ )
- def _alter_field(self, model, old_field, new_field, old_type, new_type,
- old_db_params, new_db_params, strict=False):
+ def _alter_field(
+ self,
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict=False,
+ ):
"""Perform a "physical" (non-ManyToMany) field update."""
# Drop any FK constraints, we'll remake them later
fks_dropped = set()
if (
- self.connection.features.supports_foreign_keys and
- old_field.remote_field and
- old_field.db_constraint
+ self.connection.features.supports_foreign_keys
+ and old_field.remote_field
+ and old_field.db_constraint
):
- fk_names = self._constraint_names(model, [old_field.column], foreign_key=True)
+ fk_names = self._constraint_names(
+ model, [old_field.column], foreign_key=True
+ )
if strict and len(fk_names) != 1:
- raise ValueError("Found wrong number (%s) of foreign key constraints for %s.%s" % (
- len(fk_names),
- model._meta.db_table,
- old_field.column,
- ))
+ raise ValueError(
+ "Found wrong number (%s) of foreign key constraints for %s.%s"
+ % (
+ len(fk_names),
+ model._meta.db_table,
+ old_field.column,
+ )
+ )
for fk_name in fk_names:
fks_dropped.add((old_field.column,))
self.execute(self._delete_fk_sql(model, fk_name))
# Has unique been removed?
- if old_field.unique and (not new_field.unique or self._field_became_primary_key(old_field, new_field)):
+ if old_field.unique and (
+ not new_field.unique or self._field_became_primary_key(old_field, new_field)
+ ):
# Find the unique constraint for this field
- meta_constraint_names = {constraint.name for constraint in model._meta.constraints}
+ meta_constraint_names = {
+ constraint.name for constraint in model._meta.constraints
+ }
constraint_names = self._constraint_names(
- model, [old_field.column], unique=True, primary_key=False,
+ model,
+ [old_field.column],
+ unique=True,
+ primary_key=False,
exclude=meta_constraint_names,
)
if strict and len(constraint_names) != 1:
- raise ValueError("Found wrong number (%s) of unique constraints for %s.%s" % (
- len(constraint_names),
- model._meta.db_table,
- old_field.column,
- ))
+ raise ValueError(
+ "Found wrong number (%s) of unique constraints for %s.%s"
+ % (
+ len(constraint_names),
+ model._meta.db_table,
+ old_field.column,
+ )
+ )
for constraint_name in constraint_names:
self.execute(self._delete_unique_sql(model, constraint_name))
# Drop incoming FK constraints if the field is a primary key or unique,
# which might be a to_field target, and things are going to change.
drop_foreign_keys = (
- self.connection.features.supports_foreign_keys and (
- (old_field.primary_key and new_field.primary_key) or
- (old_field.unique and new_field.unique)
- ) and old_type != new_type
+ self.connection.features.supports_foreign_keys
+ and (
+ (old_field.primary_key and new_field.primary_key)
+ or (old_field.unique and new_field.unique)
+ )
+ and old_type != new_type
)
if drop_foreign_keys:
# '_meta.related_field' also contains M2M reverse fields, these
@@ -683,13 +833,20 @@ class BaseDatabaseSchemaEditor:
# True | False | False | False
# True | False | False | True
# True | False | True | True
- if old_field.db_index and not old_field.unique and (not new_field.db_index or new_field.unique):
+ if (
+ old_field.db_index
+ and not old_field.unique
+ and (not new_field.db_index or new_field.unique)
+ ):
# Find the index for this field
meta_index_names = {index.name for index in model._meta.indexes}
# Retrieve only BTREE indexes since this is what's created with
# db_index=True.
index_names = self._constraint_names(
- model, [old_field.column], index=True, type_=Index.suffix,
+ model,
+ [old_field.column],
+ index=True,
+ type_=Index.suffix,
exclude=meta_index_names,
)
for index_name in index_names:
@@ -698,41 +855,58 @@ class BaseDatabaseSchemaEditor:
# is to look at its name (refs #28053).
self.execute(self._delete_index_sql(model, index_name))
# Change check constraints?
- if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
- meta_constraint_names = {constraint.name for constraint in model._meta.constraints}
+ if old_db_params["check"] != new_db_params["check"] and old_db_params["check"]:
+ meta_constraint_names = {
+ constraint.name for constraint in model._meta.constraints
+ }
constraint_names = self._constraint_names(
- model, [old_field.column], check=True,
+ model,
+ [old_field.column],
+ check=True,
exclude=meta_constraint_names,
)
if strict and len(constraint_names) != 1:
- raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % (
- len(constraint_names),
- model._meta.db_table,
- old_field.column,
- ))
+ raise ValueError(
+ "Found wrong number (%s) of check constraints for %s.%s"
+ % (
+ len(constraint_names),
+ model._meta.db_table,
+ old_field.column,
+ )
+ )
for constraint_name in constraint_names:
self.execute(self._delete_check_sql(model, constraint_name))
# Have they renamed the column?
if old_field.column != new_field.column:
- self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
+ self.execute(
+ self._rename_field_sql(
+ model._meta.db_table, old_field, new_field, new_type
+ )
+ )
# Rename all references to the renamed column.
for sql in self.deferred_sql:
if isinstance(sql, Statement):
- sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column)
+ sql.rename_column_references(
+ model._meta.db_table, old_field.column, new_field.column
+ )
# Next, start accumulating actions to do
actions = []
null_actions = []
post_actions = []
# Collation change?
- old_collation = getattr(old_field, 'db_collation', None)
- new_collation = getattr(new_field, 'db_collation', None)
+ old_collation = getattr(old_field, "db_collation", None)
+ new_collation = getattr(new_field, "db_collation", None)
if old_collation != new_collation:
# Collation change handles also a type change.
- fragment = self._alter_column_collation_sql(model, new_field, new_type, new_collation)
+ fragment = self._alter_column_collation_sql(
+ model, new_field, new_type, new_collation
+ )
actions.append(fragment)
# Type change?
elif old_type != new_type:
- fragment, other_actions = self._alter_column_type_sql(model, old_field, new_field, new_type)
+ fragment, other_actions = self._alter_column_type_sql(
+ model, old_field, new_field, new_type
+ )
actions.append(fragment)
post_actions.extend(other_actions)
# When changing a column NULL constraint to NOT NULL with a given
@@ -747,21 +921,22 @@ class BaseDatabaseSchemaEditor:
old_default = self.effective_default(old_field)
new_default = self.effective_default(new_field)
if (
- not self.skip_default_on_alter(new_field) and
- old_default != new_default and
- new_default is not None
+ not self.skip_default_on_alter(new_field)
+ and old_default != new_default
+ and new_default is not None
):
needs_database_default = True
- actions.append(self._alter_column_default_sql(model, old_field, new_field))
+ actions.append(
+ self._alter_column_default_sql(model, old_field, new_field)
+ )
# Nullability change?
if old_field.null != new_field.null:
fragment = self._alter_column_null_sql(model, old_field, new_field)
if fragment:
null_actions.append(fragment)
# Only if we have a default and there is a change from NULL to NOT NULL
- four_way_default_alteration = (
- new_field.has_default() and
- (old_field.null and not new_field.null)
+ four_way_default_alteration = new_field.has_default() and (
+ old_field.null and not new_field.null
)
if actions or null_actions:
if not four_way_default_alteration:
@@ -775,7 +950,8 @@ class BaseDatabaseSchemaEditor:
# Apply those actions
for sql, params in actions:
self.execute(
- self.sql_alter_column % {
+ self.sql_alter_column
+ % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
@@ -784,7 +960,8 @@ class BaseDatabaseSchemaEditor:
if four_way_default_alteration:
# Update existing rows with default value
self.execute(
- self.sql_update_with_default % {
+ self.sql_update_with_default
+ % {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(new_field.column),
"default": "%s",
@@ -795,7 +972,8 @@ class BaseDatabaseSchemaEditor:
# now
for sql, params in null_actions:
self.execute(
- self.sql_alter_column % {
+ self.sql_alter_column
+ % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
@@ -819,7 +997,11 @@ class BaseDatabaseSchemaEditor:
# False | False | True | False
# False | True | True | False
# True | True | True | False
- if (not old_field.db_index or old_field.unique) and new_field.db_index and not new_field.unique:
+ if (
+ (not old_field.db_index or old_field.unique)
+ and new_field.db_index
+ and not new_field.unique
+ ):
self.execute(self._create_index_sql(model, fields=[new_field]))
# Type alteration on primary key? Then we need to alter the column
# referring to us.
@@ -835,12 +1017,13 @@ class BaseDatabaseSchemaEditor:
# Handle our type alters on the other end of rels from the PK stuff above
for old_rel, new_rel in rels_to_update:
rel_db_params = new_rel.field.db_parameters(connection=self.connection)
- rel_type = rel_db_params['type']
+ rel_type = rel_db_params["type"]
fragment, other_actions = self._alter_column_type_sql(
new_rel.related_model, old_rel.field, new_rel.field, rel_type
)
self.execute(
- self.sql_alter_column % {
+ self.sql_alter_column
+ % {
"table": self.quote_name(new_rel.related_model._meta.db_table),
"changes": fragment[0],
},
@@ -849,23 +1032,38 @@ class BaseDatabaseSchemaEditor:
for sql, params in other_actions:
self.execute(sql, params)
# Does it have a foreign key?
- if (self.connection.features.supports_foreign_keys and new_field.remote_field and
- (fks_dropped or not old_field.remote_field or not old_field.db_constraint) and
- new_field.db_constraint):
- self.execute(self._create_fk_sql(model, new_field, "_fk_%(to_table)s_%(to_column)s"))
+ if (
+ self.connection.features.supports_foreign_keys
+ and new_field.remote_field
+ and (
+ fks_dropped or not old_field.remote_field or not old_field.db_constraint
+ )
+ and new_field.db_constraint
+ ):
+ self.execute(
+ self._create_fk_sql(model, new_field, "_fk_%(to_table)s_%(to_column)s")
+ )
# Rebuild FKs that pointed to us if we previously had to drop them
if drop_foreign_keys:
for _, rel in rels_to_update:
if rel.field.db_constraint:
- self.execute(self._create_fk_sql(rel.related_model, rel.field, "_fk"))
+ self.execute(
+ self._create_fk_sql(rel.related_model, rel.field, "_fk")
+ )
# Does it have check constraints we need to add?
- if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
- constraint_name = self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check')
- self.execute(self._create_check_sql(model, constraint_name, new_db_params['check']))
+ if old_db_params["check"] != new_db_params["check"] and new_db_params["check"]:
+ constraint_name = self._create_index_name(
+ model._meta.db_table, [new_field.column], suffix="_check"
+ )
+ self.execute(
+ self._create_check_sql(model, constraint_name, new_db_params["check"])
+ )
# Drop the default if we need to
# (Django usually does not use in-database defaults)
if needs_database_default:
- changes_sql, params = self._alter_column_default_sql(model, old_field, new_field, drop=True)
+ changes_sql, params = self._alter_column_default_sql(
+ model, old_field, new_field, drop=True
+ )
sql = self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": changes_sql,
@@ -883,18 +1081,23 @@ class BaseDatabaseSchemaEditor:
as required by new_field, or None if no changes are required.
"""
if (
- self.connection.features.interprets_empty_strings_as_nulls and
- new_field.empty_strings_allowed
+ self.connection.features.interprets_empty_strings_as_nulls
+ and new_field.empty_strings_allowed
):
# The field is nullable in the database anyway, leave it alone.
return
else:
new_db_params = new_field.db_parameters(connection=self.connection)
- sql = self.sql_alter_column_null if new_field.null else self.sql_alter_column_not_null
+ sql = (
+ self.sql_alter_column_null
+ if new_field.null
+ else self.sql_alter_column_not_null
+ )
return (
- sql % {
- 'column': self.quote_name(new_field.column),
- 'type': new_db_params['type'],
+ sql
+ % {
+ "column": self.quote_name(new_field.column),
+ "type": new_db_params["type"],
},
[],
)
@@ -928,10 +1131,11 @@ class BaseDatabaseSchemaEditor:
else:
sql = self.sql_alter_column_default
return (
- sql % {
- 'column': self.quote_name(new_field.column),
- 'type': new_db_params['type'],
- 'default': default,
+ sql
+ % {
+ "column": self.quote_name(new_field.column),
+ "type": new_db_params["type"],
+ "default": default,
},
params,
)
@@ -948,7 +1152,8 @@ class BaseDatabaseSchemaEditor:
"""
return (
(
- self.sql_alter_column_type % {
+ self.sql_alter_column_type
+ % {
"column": self.quote_name(new_field.column),
"type": new_type,
},
@@ -959,10 +1164,13 @@ class BaseDatabaseSchemaEditor:
def _alter_column_collation_sql(self, model, new_field, new_type, new_collation):
return (
- self.sql_alter_column_collate % {
- 'column': self.quote_name(new_field.column),
- 'type': new_type,
- 'collation': ' ' + self._collate_sql(new_collation) if new_collation else '',
+ self.sql_alter_column_collate
+ % {
+ "column": self.quote_name(new_field.column),
+ "type": new_type,
+ "collation": " " + self._collate_sql(new_collation)
+ if new_collation
+ else "",
},
[],
)
@@ -970,16 +1178,26 @@ class BaseDatabaseSchemaEditor:
def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""
# Rename the through table
- if old_field.remote_field.through._meta.db_table != new_field.remote_field.through._meta.db_table:
- self.alter_db_table(old_field.remote_field.through, old_field.remote_field.through._meta.db_table,
- new_field.remote_field.through._meta.db_table)
+ if (
+ old_field.remote_field.through._meta.db_table
+ != new_field.remote_field.through._meta.db_table
+ ):
+ self.alter_db_table(
+ old_field.remote_field.through,
+ old_field.remote_field.through._meta.db_table,
+ new_field.remote_field.through._meta.db_table,
+ )
# Repoint the FK to the other side
self.alter_field(
new_field.remote_field.through,
# We need the field that points to the target model, so we can tell alter_field to change it -
# this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
- old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
- new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
+ old_field.remote_field.through._meta.get_field(
+ old_field.m2m_reverse_field_name()
+ ),
+ new_field.remote_field.through._meta.get_field(
+ new_field.m2m_reverse_field_name()
+ ),
)
self.alter_field(
new_field.remote_field.through,
@@ -996,19 +1214,22 @@ class BaseDatabaseSchemaEditor:
and a unique digest and suffix.
"""
_, table_name = split_identifier(table_name)
- hash_suffix_part = '%s%s' % (names_digest(table_name, *column_names, length=8), suffix)
+ hash_suffix_part = "%s%s" % (
+ names_digest(table_name, *column_names, length=8),
+ suffix,
+ )
max_length = self.connection.ops.max_name_length() or 200
# If everything fits into max_length, use that name.
- index_name = '%s_%s_%s' % (table_name, '_'.join(column_names), hash_suffix_part)
+ index_name = "%s_%s_%s" % (table_name, "_".join(column_names), hash_suffix_part)
if len(index_name) <= max_length:
return index_name
# Shorten a long suffix.
if len(hash_suffix_part) > max_length / 3:
- hash_suffix_part = hash_suffix_part[:max_length // 3]
+ hash_suffix_part = hash_suffix_part[: max_length // 3]
other_length = (max_length - len(hash_suffix_part)) // 2 - 1
- index_name = '%s_%s_%s' % (
+ index_name = "%s_%s_%s" % (
table_name[:other_length],
- '_'.join(column_names)[:other_length],
+ "_".join(column_names)[:other_length],
hash_suffix_part,
)
# Prepend D if needed to prevent the name from starting with an
@@ -1024,25 +1245,38 @@ class BaseDatabaseSchemaEditor:
elif model._meta.db_tablespace:
db_tablespace = model._meta.db_tablespace
if db_tablespace is not None:
- return ' ' + self.connection.ops.tablespace_sql(db_tablespace)
- return ''
+ return " " + self.connection.ops.tablespace_sql(db_tablespace)
+ return ""
def _index_condition_sql(self, condition):
if condition:
- return ' WHERE ' + condition
- return ''
+ return " WHERE " + condition
+ return ""
def _index_include_sql(self, model, columns):
if not columns or not self.connection.features.supports_covering_indexes:
- return ''
+ return ""
return Statement(
- ' INCLUDE (%(columns)s)',
+ " INCLUDE (%(columns)s)",
columns=Columns(model._meta.db_table, columns, self.quote_name),
)
- def _create_index_sql(self, model, *, fields=None, name=None, suffix='', using='',
- db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
- condition=None, include=None, expressions=None):
+ def _create_index_sql(
+ self,
+ model,
+ *,
+ fields=None,
+ name=None,
+ suffix="",
+ using="",
+ db_tablespace=None,
+ col_suffixes=(),
+ sql=None,
+ opclasses=(),
+ condition=None,
+ include=None,
+ expressions=None,
+ ):
"""
Return the SQL statement to create the index for one or several fields
or expressions. `sql` can be specified if the syntax differs from the
@@ -1053,7 +1287,9 @@ class BaseDatabaseSchemaEditor:
compiler = Query(model, alias_cols=False).get_compiler(
connection=self.connection,
)
- tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace)
+ tablespace_sql = self._get_index_tablespace_sql(
+ model, fields, db_tablespace=db_tablespace
+ )
columns = [field.column for field in fields]
sql_create_index = sql or self.sql_create_index
table = model._meta.db_table
@@ -1102,12 +1338,12 @@ class BaseDatabaseSchemaEditor:
for field_names in model._meta.index_together:
fields = [model._meta.get_field(field) for field in field_names]
- output.append(self._create_index_sql(model, fields=fields, suffix='_idx'))
+ output.append(self._create_index_sql(model, fields=fields, suffix="_idx"))
for index in model._meta.indexes:
if (
- not index.contains_expressions or
- self.connection.features.supports_expression_indexes
+ not index.contains_expressions
+ or self.connection.features.supports_expression_indexes
):
output.append(index.create_sql(model, self))
return output
@@ -1129,26 +1365,25 @@ class BaseDatabaseSchemaEditor:
# - changing an attribute that doesn't affect the schema
# - adding only a db_column and the column name is not changed
non_database_attrs = [
- 'blank',
- 'db_column',
- 'editable',
- 'error_messages',
- 'help_text',
- 'limit_choices_to',
+ "blank",
+ "db_column",
+ "editable",
+ "error_messages",
+ "help_text",
+ "limit_choices_to",
# Database-level options are not supported, see #21961.
- 'on_delete',
- 'related_name',
- 'related_query_name',
- 'validators',
- 'verbose_name',
+ "on_delete",
+ "related_name",
+ "related_query_name",
+ "validators",
+ "verbose_name",
]
for attr in non_database_attrs:
old_kwargs.pop(attr, None)
new_kwargs.pop(attr, None)
- return (
- self.quote_name(old_field.column) != self.quote_name(new_field.column) or
- (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs)
- )
+ return self.quote_name(old_field.column) != self.quote_name(
+ new_field.column
+ ) or (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs)
def _field_should_be_indexed(self, model, field):
return field.db_index and not field.unique
@@ -1158,9 +1393,9 @@ class BaseDatabaseSchemaEditor:
def _unique_should_be_added(self, old_field, new_field):
return (
- not new_field.primary_key and
- new_field.unique and
- (not old_field.unique or old_field.primary_key)
+ not new_field.primary_key
+ and new_field.unique
+ and (not old_field.unique or old_field.primary_key)
)
def _rename_field_sql(self, table, old_field, new_field, new_type):
@@ -1176,7 +1411,11 @@ class BaseDatabaseSchemaEditor:
name = self._fk_constraint_name(model, field, suffix)
column = Columns(model._meta.db_table, [field.column], self.quote_name)
to_table = Table(field.target_field.model._meta.db_table, self.quote_name)
- to_column = Columns(field.target_field.model._meta.db_table, [field.target_field.column], self.quote_name)
+ to_column = Columns(
+ field.target_field.model._meta.db_table,
+ [field.target_field.column],
+ self.quote_name,
+ )
deferrable = self.connection.ops.deferrable_sql()
return Statement(
self.sql_create_fk,
@@ -1206,19 +1445,26 @@ class BaseDatabaseSchemaEditor:
def _deferrable_constraint_sql(self, deferrable):
if deferrable is None:
- return ''
+ return ""
if deferrable == Deferrable.DEFERRED:
- return ' DEFERRABLE INITIALLY DEFERRED'
+ return " DEFERRABLE INITIALLY DEFERRED"
if deferrable == Deferrable.IMMEDIATE:
- return ' DEFERRABLE INITIALLY IMMEDIATE'
+ return " DEFERRABLE INITIALLY IMMEDIATE"
def _unique_sql(
- self, model, fields, name, condition=None, deferrable=None,
- include=None, opclasses=None, expressions=None,
+ self,
+ model,
+ fields,
+ name,
+ condition=None,
+ deferrable=None,
+ include=None,
+ opclasses=None,
+ expressions=None,
):
if (
- deferrable and
- not self.connection.features.supports_deferrable_unique_constraints
+ deferrable
+ and not self.connection.features.supports_deferrable_unique_constraints
):
return None
if condition or include or opclasses or expressions:
@@ -1237,37 +1483,48 @@ class BaseDatabaseSchemaEditor:
self.deferred_sql.append(sql)
return None
constraint = self.sql_unique_constraint % {
- 'columns': ', '.join([self.quote_name(field.column) for field in fields]),
- 'deferrable': self._deferrable_constraint_sql(deferrable),
+ "columns": ", ".join([self.quote_name(field.column) for field in fields]),
+ "deferrable": self._deferrable_constraint_sql(deferrable),
}
return self.sql_constraint % {
- 'name': self.quote_name(name),
- 'constraint': constraint,
+ "name": self.quote_name(name),
+ "constraint": constraint,
}
def _create_unique_sql(
- self, model, fields, name=None, condition=None, deferrable=None,
- include=None, opclasses=None, expressions=None,
+ self,
+ model,
+ fields,
+ name=None,
+ condition=None,
+ deferrable=None,
+ include=None,
+ opclasses=None,
+ expressions=None,
):
if (
(
- deferrable and
- not self.connection.features.supports_deferrable_unique_constraints
- ) or
- (condition and not self.connection.features.supports_partial_indexes) or
- (include and not self.connection.features.supports_covering_indexes) or
- (expressions and not self.connection.features.supports_expression_indexes)
+ deferrable
+ and not self.connection.features.supports_deferrable_unique_constraints
+ )
+ or (condition and not self.connection.features.supports_partial_indexes)
+ or (include and not self.connection.features.supports_covering_indexes)
+ or (
+ expressions and not self.connection.features.supports_expression_indexes
+ )
):
return None
def create_unique_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs))
- compiler = Query(model, alias_cols=False).get_compiler(connection=self.connection)
+ compiler = Query(model, alias_cols=False).get_compiler(
+ connection=self.connection
+ )
table = model._meta.db_table
columns = [field.column for field in fields]
if name is None:
- name = IndexName(table, columns, '_uniq', create_unique_name)
+ name = IndexName(table, columns, "_uniq", create_unique_name)
else:
name = self.quote_name(name)
if condition or include or opclasses or expressions:
@@ -1275,7 +1532,9 @@ class BaseDatabaseSchemaEditor:
else:
sql = self.sql_create_unique
if columns:
- columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses)
+ columns = self._index_columns(
+ table, columns, col_suffixes=(), opclasses=opclasses
+ )
else:
columns = Expressions(table, expressions, compiler, self.quote_value)
return Statement(
@@ -1289,18 +1548,25 @@ class BaseDatabaseSchemaEditor:
)
def _delete_unique_sql(
- self, model, name, condition=None, deferrable=None, include=None,
- opclasses=None, expressions=None,
+ self,
+ model,
+ name,
+ condition=None,
+ deferrable=None,
+ include=None,
+ opclasses=None,
+ expressions=None,
):
if (
(
- deferrable and
- not self.connection.features.supports_deferrable_unique_constraints
- ) or
- (condition and not self.connection.features.supports_partial_indexes) or
- (include and not self.connection.features.supports_covering_indexes) or
- (expressions and not self.connection.features.supports_expression_indexes)
-
+ deferrable
+ and not self.connection.features.supports_deferrable_unique_constraints
+ )
+ or (condition and not self.connection.features.supports_partial_indexes)
+ or (include and not self.connection.features.supports_covering_indexes)
+ or (
+ expressions and not self.connection.features.supports_expression_indexes
+ )
):
return None
if condition or include or opclasses or expressions:
@@ -1311,8 +1577,8 @@ class BaseDatabaseSchemaEditor:
def _check_sql(self, name, check):
return self.sql_constraint % {
- 'name': self.quote_name(name),
- 'constraint': self.sql_check_constraint % {'check': check},
+ "name": self.quote_name(name),
+ "constraint": self.sql_check_constraint % {"check": check},
}
def _create_check_sql(self, model, name, check):
@@ -1333,9 +1599,18 @@ class BaseDatabaseSchemaEditor:
name=self.quote_name(name),
)
- def _constraint_names(self, model, column_names=None, unique=None,
- primary_key=None, index=None, foreign_key=None,
- check=None, type_=None, exclude=None):
+ def _constraint_names(
+ self,
+ model,
+ column_names=None,
+ unique=None,
+ primary_key=None,
+ index=None,
+ foreign_key=None,
+ check=None,
+ type_=None,
+ exclude=None,
+ ):
"""Return all constraint names matching the columns and conditions."""
if column_names is not None:
column_names = [
@@ -1343,21 +1618,23 @@ class BaseDatabaseSchemaEditor:
for name in column_names
]
with self.connection.cursor() as cursor:
- constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
+ constraints = self.connection.introspection.get_constraints(
+ cursor, model._meta.db_table
+ )
result = []
for name, infodict in constraints.items():
- if column_names is None or column_names == infodict['columns']:
- if unique is not None and infodict['unique'] != unique:
+ if column_names is None or column_names == infodict["columns"]:
+ if unique is not None and infodict["unique"] != unique:
continue
- if primary_key is not None and infodict['primary_key'] != primary_key:
+ if primary_key is not None and infodict["primary_key"] != primary_key:
continue
- if index is not None and infodict['index'] != index:
+ if index is not None and infodict["index"] != index:
continue
- if check is not None and infodict['check'] != check:
+ if check is not None and infodict["check"] != check:
continue
- if foreign_key is not None and not infodict['foreign_key']:
+ if foreign_key is not None and not infodict["foreign_key"]:
continue
- if type_ is not None and infodict['type'] != type_:
+ if type_ is not None and infodict["type"] != type_:
continue
if not exclude or name not in exclude:
result.append(name)
@@ -1366,10 +1643,13 @@ class BaseDatabaseSchemaEditor:
def _delete_primary_key(self, model, strict=False):
constraint_names = self._constraint_names(model, primary_key=True)
if strict and len(constraint_names) != 1:
- raise ValueError('Found wrong number (%s) of PK constraints for %s' % (
- len(constraint_names),
- model._meta.db_table,
- ))
+ raise ValueError(
+ "Found wrong number (%s) of PK constraints for %s"
+ % (
+ len(constraint_names),
+ model._meta.db_table,
+ )
+ )
for constraint_name in constraint_names:
self.execute(self._delete_primary_key_sql(model, constraint_name))
@@ -1378,7 +1658,9 @@ class BaseDatabaseSchemaEditor:
self.sql_create_pk,
table=Table(model._meta.db_table, self.quote_name),
name=self.quote_name(
- self._create_index_name(model._meta.db_table, [field.column], suffix="_pk")
+ self._create_index_name(
+ model._meta.db_table, [field.column], suffix="_pk"
+ )
),
columns=Columns(model._meta.db_table, [field.column], self.quote_name),
)
@@ -1387,11 +1669,11 @@ class BaseDatabaseSchemaEditor:
return self._delete_constraint_sql(self.sql_delete_pk, model, name)
def _collate_sql(self, collation):
- return 'COLLATE ' + self.quote_name(collation)
+ return "COLLATE " + self.quote_name(collation)
def remove_procedure(self, procedure_name, param_types=()):
sql = self.sql_delete_procedure % {
- 'procedure': self.quote_name(procedure_name),
- 'param_types': ','.join(param_types),
+ "procedure": self.quote_name(procedure_name),
+ "param_types": ",".join(param_types),
}
self.execute(sql)
diff --git a/django/db/backends/base/validation.py b/django/db/backends/base/validation.py
index a02780a694..d0e3e2157d 100644
--- a/django/db/backends/base/validation.py
+++ b/django/db/backends/base/validation.py
@@ -1,5 +1,6 @@
class BaseDatabaseValidation:
"""Encapsulate backend-specific validation."""
+
def __init__(self, connection):
self.connection = connection
@@ -9,9 +10,12 @@ class BaseDatabaseValidation:
def check_field(self, field, **kwargs):
errors = []
# Backends may implement a check_field_type() method.
- if (hasattr(self, 'check_field_type') and
- # Ignore any related fields.
- not getattr(field, 'remote_field', None)):
+ if (
+ hasattr(self, "check_field_type")
+ and
+ # Ignore any related fields.
+ not getattr(field, "remote_field", None)
+ ):
# Ignore fields with unsupported features.
db_supports_all_required_features = all(
getattr(self.connection.features, feature, False)
diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py
index f798fd648b..412d07a993 100644
--- a/django/db/backends/ddl_references.py
+++ b/django/db/backends/ddl_references.py
@@ -33,10 +33,12 @@ class Reference:
pass
def __repr__(self):
- return '<%s %r>' % (self.__class__.__name__, str(self))
+ return "<%s %r>" % (self.__class__.__name__, str(self))
def __str__(self):
- raise NotImplementedError('Subclasses must define how they should be converted to string.')
+ raise NotImplementedError(
+ "Subclasses must define how they should be converted to string."
+ )
class Table(Reference):
@@ -88,12 +90,14 @@ class Columns(TableColumns):
try:
suffix = self.col_suffixes[idx]
if suffix:
- col = '{} {}'.format(col, suffix)
+ col = "{} {}".format(col, suffix)
except IndexError:
pass
return col
- return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
+ return ", ".join(
+ col_str(column, idx) for idx, column in enumerate(self.columns)
+ )
class IndexName(TableColumns):
@@ -117,35 +121,49 @@ class IndexColumns(Columns):
def col_str(column, idx):
# Index.__init__() guarantees that self.opclasses is the same
# length as self.columns.
- col = '{} {}'.format(self.quote_name(column), self.opclasses[idx])
+ col = "{} {}".format(self.quote_name(column), self.opclasses[idx])
try:
suffix = self.col_suffixes[idx]
if suffix:
- col = '{} {}'.format(col, suffix)
+ col = "{} {}".format(col, suffix)
except IndexError:
pass
return col
- return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
+ return ", ".join(
+ col_str(column, idx) for idx, column in enumerate(self.columns)
+ )
class ForeignKeyName(TableColumns):
"""Hold a reference to a foreign key name."""
- def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
+ def __init__(
+ self,
+ from_table,
+ from_columns,
+ to_table,
+ to_columns,
+ suffix_template,
+ create_fk_name,
+ ):
self.to_reference = TableColumns(to_table, to_columns)
self.suffix_template = suffix_template
self.create_fk_name = create_fk_name
- super().__init__(from_table, from_columns,)
+ super().__init__(
+ from_table,
+ from_columns,
+ )
def references_table(self, table):
- return super().references_table(table) or self.to_reference.references_table(table)
+ return super().references_table(table) or self.to_reference.references_table(
+ table
+ )
def references_column(self, table, column):
- return (
- super().references_column(table, column) or
- self.to_reference.references_column(table, column)
- )
+ return super().references_column(
+ table, column
+ ) or self.to_reference.references_column(table, column)
def rename_table_references(self, old_table, new_table):
super().rename_table_references(old_table, new_table)
@@ -157,8 +175,8 @@ class ForeignKeyName(TableColumns):
def __str__(self):
suffix = self.suffix_template % {
- 'to_table': self.to_reference.table,
- 'to_column': self.to_reference.columns[0],
+ "to_table": self.to_reference.table,
+ "to_column": self.to_reference.columns[0],
}
return self.create_fk_name(self.table, self.columns, suffix)
@@ -171,30 +189,31 @@ class Statement(Reference):
that might have to be adjusted if they're referencing a table or column
that is removed
"""
+
def __init__(self, template, **parts):
self.template = template
self.parts = parts
def references_table(self, table):
return any(
- hasattr(part, 'references_table') and part.references_table(table)
+ hasattr(part, "references_table") and part.references_table(table)
for part in self.parts.values()
)
def references_column(self, table, column):
return any(
- hasattr(part, 'references_column') and part.references_column(table, column)
+ hasattr(part, "references_column") and part.references_column(table, column)
for part in self.parts.values()
)
def rename_table_references(self, old_table, new_table):
for part in self.parts.values():
- if hasattr(part, 'rename_table_references'):
+ if hasattr(part, "rename_table_references"):
part.rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
for part in self.parts.values():
- if hasattr(part, 'rename_column_references'):
+ if hasattr(part, "rename_column_references"):
part.rename_column_references(table, old_column, new_column)
def __str__(self):
@@ -206,7 +225,10 @@ class Expressions(TableColumns):
self.compiler = compiler
self.expressions = expressions
self.quote_value = quote_value
- columns = [col.target.column for col in self.compiler.query._gen_cols([self.expressions])]
+ columns = [
+ col.target.column
+ for col in self.compiler.query._gen_cols([self.expressions])
+ ]
super().__init__(table, columns)
def rename_table_references(self, old_table, new_table):
diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py
index 06a25f2276..36c6480a78 100644
--- a/django/db/backends/dummy/base.py
+++ b/django/db/backends/dummy/base.py
@@ -17,9 +17,11 @@ from django.db.backends.dummy.features import DummyDatabaseFeatures
def complain(*args, **kwargs):
- raise ImproperlyConfigured("settings.DATABASES is improperly configured. "
- "Please supply the ENGINE value. Check "
- "settings documentation for more details.")
+ raise ImproperlyConfigured(
+ "settings.DATABASES is improperly configured. "
+ "Please supply the ENGINE value. Check "
+ "settings documentation for more details."
+ )
def ignore(*args, **kwargs):
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index ef2cd3a5ae..b689040f7f 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -15,7 +15,7 @@ try:
import MySQLdb as Database
except ImportError as err:
raise ImproperlyConfigured(
- 'Error loading MySQLdb module.\nDid you install mysqlclient?'
+ "Error loading MySQLdb module.\nDid you install mysqlclient?"
) from err
from MySQLdb.constants import CLIENT, FIELD_TYPE
@@ -32,7 +32,9 @@ from .validation import DatabaseValidation
version = Database.version_info
if version < (1, 4, 0):
- raise ImproperlyConfigured('mysqlclient 1.4.0 or newer is required; you have %s.' % Database.__version__)
+ raise ImproperlyConfigured(
+ "mysqlclient 1.4.0 or newer is required; you have %s." % Database.__version__
+ )
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
@@ -45,7 +47,7 @@ django_conversions = {
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
-server_version_re = _lazy_re_compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
+server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
class CursorWrapper:
@@ -56,6 +58,7 @@ class CursorWrapper:
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
to the particular underlying representation returned by Connection.cursor().
"""
+
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
@@ -95,39 +98,39 @@ class CursorWrapper:
class DatabaseWrapper(BaseDatabaseWrapper):
- vendor = 'mysql'
+ vendor = "mysql"
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
- 'AutoField': 'integer AUTO_INCREMENT',
- 'BigAutoField': 'bigint AUTO_INCREMENT',
- 'BinaryField': 'longblob',
- 'BooleanField': 'bool',
- 'CharField': 'varchar(%(max_length)s)',
- 'DateField': 'date',
- 'DateTimeField': 'datetime(6)',
- 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
- 'DurationField': 'bigint',
- 'FileField': 'varchar(%(max_length)s)',
- 'FilePathField': 'varchar(%(max_length)s)',
- 'FloatField': 'double precision',
- 'IntegerField': 'integer',
- 'BigIntegerField': 'bigint',
- 'IPAddressField': 'char(15)',
- 'GenericIPAddressField': 'char(39)',
- 'JSONField': 'json',
- 'OneToOneField': 'integer',
- 'PositiveBigIntegerField': 'bigint UNSIGNED',
- 'PositiveIntegerField': 'integer UNSIGNED',
- 'PositiveSmallIntegerField': 'smallint UNSIGNED',
- 'SlugField': 'varchar(%(max_length)s)',
- 'SmallAutoField': 'smallint AUTO_INCREMENT',
- 'SmallIntegerField': 'smallint',
- 'TextField': 'longtext',
- 'TimeField': 'time(6)',
- 'UUIDField': 'char(32)',
+ "AutoField": "integer AUTO_INCREMENT",
+ "BigAutoField": "bigint AUTO_INCREMENT",
+ "BinaryField": "longblob",
+ "BooleanField": "bool",
+ "CharField": "varchar(%(max_length)s)",
+ "DateField": "date",
+ "DateTimeField": "datetime(6)",
+ "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
+ "DurationField": "bigint",
+ "FileField": "varchar(%(max_length)s)",
+ "FilePathField": "varchar(%(max_length)s)",
+ "FloatField": "double precision",
+ "IntegerField": "integer",
+ "BigIntegerField": "bigint",
+ "IPAddressField": "char(15)",
+ "GenericIPAddressField": "char(39)",
+ "JSONField": "json",
+ "OneToOneField": "integer",
+ "PositiveBigIntegerField": "bigint UNSIGNED",
+ "PositiveIntegerField": "integer UNSIGNED",
+ "PositiveSmallIntegerField": "smallint UNSIGNED",
+ "SlugField": "varchar(%(max_length)s)",
+ "SmallAutoField": "smallint AUTO_INCREMENT",
+ "SmallIntegerField": "smallint",
+ "TextField": "longtext",
+ "TimeField": "time(6)",
+ "UUIDField": "char(32)",
}
# For these data types:
@@ -136,23 +139,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# - all versions of MySQL and MariaDB don't support full width database
# indexes
_limited_data_types = (
- 'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text',
- 'mediumtext', 'longtext', 'json',
+ "tinyblob",
+ "blob",
+ "mediumblob",
+ "longblob",
+ "tinytext",
+ "text",
+ "mediumtext",
+ "longtext",
+ "json",
)
operators = {
- 'exact': '= %s',
- 'iexact': 'LIKE %s',
- 'contains': 'LIKE BINARY %s',
- 'icontains': 'LIKE %s',
- 'gt': '> %s',
- 'gte': '>= %s',
- 'lt': '< %s',
- 'lte': '<= %s',
- 'startswith': 'LIKE BINARY %s',
- 'endswith': 'LIKE BINARY %s',
- 'istartswith': 'LIKE %s',
- 'iendswith': 'LIKE %s',
+ "exact": "= %s",
+ "iexact": "LIKE %s",
+ "contains": "LIKE BINARY %s",
+ "icontains": "LIKE %s",
+ "gt": "> %s",
+ "gte": ">= %s",
+ "lt": "< %s",
+ "lte": "<= %s",
+ "startswith": "LIKE BINARY %s",
+ "endswith": "LIKE BINARY %s",
+ "istartswith": "LIKE %s",
+ "iendswith": "LIKE %s",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -165,19 +175,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
- 'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
- 'icontains': "LIKE CONCAT('%%', {}, '%%')",
- 'startswith': "LIKE BINARY CONCAT({}, '%%')",
- 'istartswith': "LIKE CONCAT({}, '%%')",
- 'endswith': "LIKE BINARY CONCAT('%%', {})",
- 'iendswith': "LIKE CONCAT('%%', {})",
+ "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
+ "icontains": "LIKE CONCAT('%%', {}, '%%')",
+ "startswith": "LIKE BINARY CONCAT({}, '%%')",
+ "istartswith": "LIKE CONCAT({}, '%%')",
+ "endswith": "LIKE BINARY CONCAT('%%', {})",
+ "iendswith": "LIKE CONCAT('%%', {})",
}
isolation_levels = {
- 'read uncommitted',
- 'read committed',
- 'repeatable read',
- 'serializable',
+ "read uncommitted",
+ "read committed",
+ "repeatable read",
+ "serializable",
}
Database = Database
@@ -192,37 +202,39 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
kwargs = {
- 'conv': django_conversions,
- 'charset': 'utf8',
+ "conv": django_conversions,
+ "charset": "utf8",
}
settings_dict = self.settings_dict
- if settings_dict['USER']:
- kwargs['user'] = settings_dict['USER']
- if settings_dict['NAME']:
- kwargs['database'] = settings_dict['NAME']
- if settings_dict['PASSWORD']:
- kwargs['password'] = settings_dict['PASSWORD']
- if settings_dict['HOST'].startswith('/'):
- kwargs['unix_socket'] = settings_dict['HOST']
- elif settings_dict['HOST']:
- kwargs['host'] = settings_dict['HOST']
- if settings_dict['PORT']:
- kwargs['port'] = int(settings_dict['PORT'])
+ if settings_dict["USER"]:
+ kwargs["user"] = settings_dict["USER"]
+ if settings_dict["NAME"]:
+ kwargs["database"] = settings_dict["NAME"]
+ if settings_dict["PASSWORD"]:
+ kwargs["password"] = settings_dict["PASSWORD"]
+ if settings_dict["HOST"].startswith("/"):
+ kwargs["unix_socket"] = settings_dict["HOST"]
+ elif settings_dict["HOST"]:
+ kwargs["host"] = settings_dict["HOST"]
+ if settings_dict["PORT"]:
+ kwargs["port"] = int(settings_dict["PORT"])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
- kwargs['client_flag'] = CLIENT.FOUND_ROWS
+ kwargs["client_flag"] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
- options = settings_dict['OPTIONS'].copy()
- isolation_level = options.pop('isolation_level', 'read committed')
+ options = settings_dict["OPTIONS"].copy()
+ isolation_level = options.pop("isolation_level", "read committed")
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
- "Use one of %s, or None." % (
+ "Use one of %s, or None."
+ % (
isolation_level,
- ', '.join("'%s'" % s for s in sorted(self.isolation_levels))
- ))
+ ", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
+ )
+ )
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
@@ -245,14 +257,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
- assignments.append('SET SQL_AUTO_IS_NULL = 0')
+ assignments.append("SET SQL_AUTO_IS_NULL = 0")
if self.isolation_level:
- assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper())
+ assignments.append(
+ "SET SESSION TRANSACTION ISOLATION LEVEL %s"
+ % self.isolation_level.upper()
+ )
if assignments:
with self.cursor() as cursor:
- cursor.execute('; '.join(assignments))
+ cursor.execute("; ".join(assignments))
@async_unsafe
def create_cursor(self, name=None):
@@ -276,7 +291,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
need to be re-enabled.
"""
with self.cursor() as cursor:
- cursor.execute('SET foreign_key_checks=0')
+ cursor.execute("SET foreign_key_checks=0")
return True
def enable_constraint_checking(self):
@@ -288,7 +303,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
with self.cursor() as cursor:
- cursor.execute('SET foreign_key_checks=1')
+ cursor.execute("SET foreign_key_checks=1")
finally:
self.needs_rollback = needs_rollback
@@ -304,21 +319,32 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
- primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+ primary_key_column_name = self.introspection.get_primary_key_column(
+ cursor, table_name
+ )
if not primary_key_column_name:
continue
relations = self.introspection.get_relations(cursor, table_name)
- for column_name, (referenced_column_name, referenced_table_name) in relations.items():
+ for column_name, (
+ referenced_column_name,
+ referenced_table_name,
+ ) in relations.items():
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
- """ % (
- primary_key_column_name, column_name, table_name,
- referenced_table_name, column_name, referenced_column_name,
- column_name, referenced_column_name,
+ """
+ % (
+ primary_key_column_name,
+ column_name,
+ table_name,
+ referenced_table_name,
+ column_name,
+ referenced_column_name,
+ column_name,
+ referenced_column_name,
)
)
for bad_row in cursor.fetchall():
@@ -327,8 +353,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
- table_name, bad_row[0], table_name, column_name,
- bad_row[1], referenced_table_name, referenced_column_name,
+ table_name,
+ bad_row[0],
+ table_name,
+ column_name,
+ bad_row[1],
+ referenced_table_name,
+ referenced_column_name,
)
)
@@ -342,20 +373,20 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def display_name(self):
- return 'MariaDB' if self.mysql_is_mariadb else 'MySQL'
+ return "MariaDB" if self.mysql_is_mariadb else "MySQL"
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
check_constraints = {
- 'PositiveBigIntegerField': '`%(column)s` >= 0',
- 'PositiveIntegerField': '`%(column)s` >= 0',
- 'PositiveSmallIntegerField': '`%(column)s` >= 0',
+ "PositiveBigIntegerField": "`%(column)s` >= 0",
+ "PositiveIntegerField": "`%(column)s` >= 0",
+ "PositiveSmallIntegerField": "`%(column)s` >= 0",
}
if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
# MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
# a check constraint.
- check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)'
+ check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
return check_constraints
return {}
@@ -365,40 +396,45 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Select some server variables and test if the time zone
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
# timezone isn't loaded into the mysql.time_zone table.
- cursor.execute("""
+ cursor.execute(
+ """
SELECT VERSION(),
@@sql_mode,
@@default_storage_engine,
@@sql_auto_is_null,
@@lower_case_table_names,
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
- """)
+ """
+ )
row = cursor.fetchone()
return {
- 'version': row[0],
- 'sql_mode': row[1],
- 'default_storage_engine': row[2],
- 'sql_auto_is_null': bool(row[3]),
- 'lower_case_table_names': bool(row[4]),
- 'has_zoneinfo_database': bool(row[5]),
+ "version": row[0],
+ "sql_mode": row[1],
+ "default_storage_engine": row[2],
+ "sql_auto_is_null": bool(row[3]),
+ "lower_case_table_names": bool(row[4]),
+ "has_zoneinfo_database": bool(row[5]),
}
@cached_property
def mysql_server_info(self):
- return self.mysql_server_data['version']
+ return self.mysql_server_data["version"]
@cached_property
def mysql_version(self):
match = server_version_re.match(self.mysql_server_info)
if not match:
- raise Exception('Unable to determine MySQL version from version string %r' % self.mysql_server_info)
+ raise Exception(
+ "Unable to determine MySQL version from version string %r"
+ % self.mysql_server_info
+ )
return tuple(int(x) for x in match.groups())
@cached_property
def mysql_is_mariadb(self):
- return 'mariadb' in self.mysql_server_info.lower()
+ return "mariadb" in self.mysql_server_info.lower()
@cached_property
def sql_mode(self):
- sql_mode = self.mysql_server_data['sql_mode']
- return set(sql_mode.split(',') if sql_mode else ())
+ sql_mode = self.mysql_server_data["sql_mode"]
+ return set(sql_mode.split(",") if sql_mode else ())
diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py
index 7cbe314afe..0c09a2ca1e 100644
--- a/django/db/backends/mysql/client.py
+++ b/django/db/backends/mysql/client.py
@@ -2,28 +2,28 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
- executable_name = 'mysql'
+ executable_name = "mysql"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
env = None
- database = settings_dict['OPTIONS'].get(
- 'database',
- settings_dict['OPTIONS'].get('db', settings_dict['NAME']),
+ database = settings_dict["OPTIONS"].get(
+ "database",
+ settings_dict["OPTIONS"].get("db", settings_dict["NAME"]),
)
- user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
- password = settings_dict['OPTIONS'].get(
- 'password',
- settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD'])
+ user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
+ password = settings_dict["OPTIONS"].get(
+ "password",
+ settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]),
)
- host = settings_dict['OPTIONS'].get('host', settings_dict['HOST'])
- port = settings_dict['OPTIONS'].get('port', settings_dict['PORT'])
- server_ca = settings_dict['OPTIONS'].get('ssl', {}).get('ca')
- client_cert = settings_dict['OPTIONS'].get('ssl', {}).get('cert')
- client_key = settings_dict['OPTIONS'].get('ssl', {}).get('key')
- defaults_file = settings_dict['OPTIONS'].get('read_default_file')
- charset = settings_dict['OPTIONS'].get('charset')
+ host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
+ port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
+ server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
+ client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert")
+ client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key")
+ defaults_file = settings_dict["OPTIONS"].get("read_default_file")
+ charset = settings_dict["OPTIONS"].get("charset")
# Seems to be no good way to set sql_mode with CLI.
if defaults_file:
@@ -38,9 +38,9 @@ class DatabaseClient(BaseDatabaseClient):
# prevents password exposure if the subprocess.run(check=True) call
# raises a CalledProcessError since the string representation of
# the latter includes all of the provided `args`.
- env = {'MYSQL_PWD': password}
+ env = {"MYSQL_PWD": password}
if host:
- if '/' in host:
+ if "/" in host:
args += ["--socket=%s" % host]
else:
args += ["--host=%s" % host]
@@ -53,7 +53,7 @@ class DatabaseClient(BaseDatabaseClient):
if client_key:
args += ["--ssl-key=%s" % client_key]
if charset:
- args += ['--default-character-set=%s' % charset]
+ args += ["--default-character-set=%s" % charset]
if database:
args += [database]
args.extend(parameters)
diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py
index 49b47961a1..a8ab03a55e 100644
--- a/django/db/backends/mysql/compiler.py
+++ b/django/db/backends/mysql/compiler.py
@@ -8,7 +8,14 @@ class SQLCompiler(compiler.SQLCompiler):
qn = compiler.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
- return '(%s) IN (%s)' % (', '.join('%s.%s' % (qn(alias), qn2(column)) for column in columns), sql), params
+ return (
+ "(%s) IN (%s)"
+ % (
+ ", ".join("%s.%s" % (qn(alias), qn2(column)) for column in columns),
+ sql,
+ ),
+ params,
+ )
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
@@ -27,16 +34,15 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
# since it doesn't allow for GROUP BY and HAVING clauses.
return super().as_sql()
result = [
- 'DELETE %s FROM' % self.quote_name_unless_alias(
- self.query.get_initial_alias()
- )
+ "DELETE %s FROM"
+ % self.quote_name_unless_alias(self.query.get_initial_alias())
]
from_sql, from_params = self.get_from_clause()
result.extend(from_sql)
where_sql, where_params = self.compile(where)
if where_sql:
- result.append('WHERE %s' % where_sql)
- return ' '.join(result), tuple(from_params) + tuple(where_params)
+ result.append("WHERE %s" % where_sql)
+ return " ".join(result), tuple(from_params) + tuple(where_params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
@@ -50,15 +56,15 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
try:
for resolved, (sql, params, _) in self.get_order_by():
if (
- isinstance(resolved.expression, Col) and
- resolved.expression.alias != db_table
+ isinstance(resolved.expression, Col)
+ and resolved.expression.alias != db_table
):
# Ignore ordering if it contains joined fields, because
# they cannot be used in the ORDER BY clause.
raise FieldError
order_by_sql.append(sql)
order_by_params.extend(params)
- update_query += ' ORDER BY ' + ', '.join(order_by_sql)
+ update_query += " ORDER BY " + ", ".join(order_by_sql)
update_params += tuple(order_by_params)
except FieldError:
# Ignore ordering if it contains annotations, because they're
diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py
index 1f0261b667..a060f41d18 100644
--- a/django/db/backends/mysql/creation.py
+++ b/django/db/backends/mysql/creation.py
@@ -8,15 +8,14 @@ from .client import DatabaseClient
class DatabaseCreation(BaseDatabaseCreation):
-
def sql_table_creation_suffix(self):
suffix = []
- test_settings = self.connection.settings_dict['TEST']
- if test_settings['CHARSET']:
- suffix.append('CHARACTER SET %s' % test_settings['CHARSET'])
- if test_settings['COLLATION']:
- suffix.append('COLLATE %s' % test_settings['COLLATION'])
- return ' '.join(suffix)
+ test_settings = self.connection.settings_dict["TEST"]
+ if test_settings["CHARSET"]:
+ suffix.append("CHARACTER SET %s" % test_settings["CHARSET"])
+ if test_settings["COLLATION"]:
+ suffix.append("COLLATE %s" % test_settings["COLLATION"])
+ return " ".join(suffix)
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
@@ -24,17 +23,17 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception as e:
if len(e.args) < 1 or e.args[0] != 1007:
# All errors except "database exists" (1007) cancel tests.
- self.log('Got an error creating the test database: %s' % e)
+ self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
else:
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
- source_database_name = self.connection.settings_dict['NAME']
- target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
+ source_database_name = self.connection.settings_dict["NAME"]
+ target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_db_params = {
- 'dbname': self.connection.ops.quote_name(target_database_name),
- 'suffix': self.sql_table_creation_suffix(),
+ "dbname": self.connection.ops.quote_name(target_database_name),
+ "suffix": self.sql_table_creation_suffix(),
}
with self._nodb_cursor() as cursor:
try:
@@ -45,24 +44,44 @@ class DatabaseCreation(BaseDatabaseCreation):
return
try:
if verbosity >= 1:
- self.log('Destroying old test database for alias %s...' % (
- self._get_database_display_str(verbosity, target_database_name),
- ))
- cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
+ self.log(
+ "Destroying old test database for alias %s..."
+ % (
+ self._get_database_display_str(
+ verbosity, target_database_name
+ ),
+ )
+ )
+ cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
- self.log('Got an error recreating the test database: %s' % e)
+ self.log("Got an error recreating the test database: %s" % e)
sys.exit(2)
self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name):
- cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, [])
- dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name]
+ cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(
+ self.connection.settings_dict, []
+ )
+ dump_cmd = [
+ "mysqldump",
+ *cmd_args[1:-1],
+ "--routines",
+ "--events",
+ source_database_name,
+ ]
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name
- with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc:
- with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env):
+ with subprocess.Popen(
+ dump_cmd, stdout=subprocess.PIPE, env=dump_env
+ ) as dump_proc:
+ with subprocess.Popen(
+ load_cmd,
+ stdin=dump_proc.stdout,
+ stdout=subprocess.DEVNULL,
+ env=load_env,
+ ):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()
diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py
index 5d6c4afde0..d485d40d60 100644
--- a/django/db/backends/mysql/features.py
+++ b/django/db/backends/mysql/features.py
@@ -50,87 +50,104 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def test_collations(self):
- charset = 'utf8'
- if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 6):
+ charset = "utf8"
+ if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
+ 10,
+ 6,
+ ):
# utf8 is an alias for utf8mb3 in MariaDB 10.6+.
- charset = 'utf8mb3'
+ charset = "utf8mb3"
return {
- 'ci': f'{charset}_general_ci',
- 'non_default': f'{charset}_esperanto_ci',
- 'swedish_ci': f'{charset}_swedish_ci',
+ "ci": f"{charset}_general_ci",
+ "non_default": f"{charset}_esperanto_ci",
+ "swedish_ci": f"{charset}_swedish_ci",
}
- test_now_utc_template = 'UTC_TIMESTAMP'
+ test_now_utc_template = "UTC_TIMESTAMP"
@cached_property
def django_test_skips(self):
skips = {
"This doesn't work on MySQL.": {
- 'db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround',
- 'db_functions.comparison.test_least.LeastTests.test_coalesce_workaround',
+ "db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround",
+ "db_functions.comparison.test_least.LeastTests.test_coalesce_workaround",
},
- 'Running on MySQL requires utf8mb4 encoding (#18392).': {
- 'model_fields.test_textfield.TextFieldTests.test_emoji',
- 'model_fields.test_charfield.TestCharField.test_emoji',
+ "Running on MySQL requires utf8mb4 encoding (#18392).": {
+ "model_fields.test_textfield.TextFieldTests.test_emoji",
+ "model_fields.test_charfield.TestCharField.test_emoji",
},
"MySQL doesn't support functional indexes on a function that "
"returns JSON": {
- 'schema.tests.SchemaTests.test_func_index_json_key_transform',
+ "schema.tests.SchemaTests.test_func_index_json_key_transform",
},
"MySQL supports multiplying and dividing DurationFields by a "
"scalar value but it's not implemented (#25287).": {
- 'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide',
+ "expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
},
}
- if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
- skips.update({
- 'GROUP BY optimization does not work properly when '
- 'ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.': {
- 'aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued',
- 'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
- },
- })
- if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
- skips.update({
- 'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
- 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
- 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
- },
- 'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
- 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
- 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
- },
- })
- if (
- self.connection.mysql_is_mariadb and
- (10, 4, 3) < self.connection.mysql_version < (10, 5, 2)
- ):
- skips.update({
- 'https://jira.mariadb.org/browse/MDEV-19598': {
- 'schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key',
- },
- })
- if (
- self.connection.mysql_is_mariadb and
- (10, 4, 12) < self.connection.mysql_version < (10, 5)
+ if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode:
+ skips.update(
+ {
+ "GROUP BY optimization does not work properly when "
+ "ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.": {
+ "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued",
+ "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o",
+ },
+ }
+ )
+ if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (
+ 8,
):
- skips.update({
- 'https://jira.mariadb.org/browse/MDEV-22775': {
- 'schema.tests.SchemaTests.test_alter_pk_with_self_referential_field',
- },
- })
+ skips.update(
+ {
+ "Casting to datetime/time is not supported by MySQL < 8.0. (#30224)": {
+ "aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python",
+ "aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python",
+ },
+ "MySQL < 8.0 returns string type instead of datetime/time. (#30224)": {
+ "aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database",
+ "aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database",
+ },
+ }
+ )
+ if self.connection.mysql_is_mariadb and (
+ 10,
+ 4,
+ 3,
+ ) < self.connection.mysql_version < (10, 5, 2):
+ skips.update(
+ {
+ "https://jira.mariadb.org/browse/MDEV-19598": {
+ "schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key",
+ },
+ }
+ )
+ if self.connection.mysql_is_mariadb and (
+ 10,
+ 4,
+ 12,
+ ) < self.connection.mysql_version < (10, 5):
+ skips.update(
+ {
+ "https://jira.mariadb.org/browse/MDEV-22775": {
+ "schema.tests.SchemaTests.test_alter_pk_with_self_referential_field",
+ },
+ }
+ )
if not self.supports_explain_analyze:
- skips.update({
- 'MariaDB and MySQL >= 8.0.18 specific.': {
- 'queries.test_explain.ExplainTests.test_mysql_analyze',
- },
- })
+ skips.update(
+ {
+ "MariaDB and MySQL >= 8.0.18 specific.": {
+ "queries.test_explain.ExplainTests.test_mysql_analyze",
+ },
+ }
+ )
return skips
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
- return self.connection.mysql_server_data['default_storage_engine']
+ return self.connection.mysql_server_data["default_storage_engine"]
@cached_property
def allows_auto_pk_0(self):
@@ -138,40 +155,50 @@ class DatabaseFeatures(BaseDatabaseFeatures):
Autoincrement primary key can be set to 0 if it doesn't generate new
autoincrement values.
"""
- return 'NO_AUTO_VALUE_ON_ZERO' in self.connection.sql_mode
+ return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode
@cached_property
def update_can_self_select(self):
- return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 3, 2)
+ return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
+ 10,
+ 3,
+ 2,
+ )
@cached_property
def can_introspect_foreign_keys(self):
"Confirm support for introspected foreign keys"
- return self._mysql_storage_engine != 'MyISAM'
+ return self._mysql_storage_engine != "MyISAM"
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
- 'BinaryField': 'TextField',
- 'BooleanField': 'IntegerField',
- 'DurationField': 'BigIntegerField',
- 'GenericIPAddressField': 'CharField',
+ "BinaryField": "TextField",
+ "BooleanField": "IntegerField",
+ "DurationField": "BigIntegerField",
+ "GenericIPAddressField": "CharField",
}
@cached_property
def can_return_columns_from_insert(self):
- return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 5, 0)
+ return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
+ 10,
+ 5,
+ 0,
+ )
- can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))
+ can_return_rows_from_bulk_insert = property(
+ operator.attrgetter("can_return_columns_from_insert")
+ )
@cached_property
def has_zoneinfo_database(self):
- return self.connection.mysql_server_data['has_zoneinfo_database']
+ return self.connection.mysql_server_data["has_zoneinfo_database"]
@cached_property
def is_sql_auto_is_null_enabled(self):
- return self.connection.mysql_server_data['sql_auto_is_null']
+ return self.connection.mysql_server_data["sql_auto_is_null"]
@cached_property
def supports_over_clause(self):
@@ -179,7 +206,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return True
return self.connection.mysql_version >= (8, 0, 2)
- supports_frame_range_fixed_distance = property(operator.attrgetter('supports_over_clause'))
+ supports_frame_range_fixed_distance = property(
+ operator.attrgetter("supports_over_clause")
+ )
@cached_property
def supports_column_check_constraints(self):
@@ -187,7 +216,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return True
return self.connection.mysql_version >= (8, 0, 16)
- supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints'))
+ supports_table_check_constraints = property(
+ operator.attrgetter("supports_column_check_constraints")
+ )
@cached_property
def can_introspect_check_constraints(self):
@@ -210,19 +241,30 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def has_select_for_update_of(self):
- return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)
+ return (
+ not self.connection.mysql_is_mariadb
+ and self.connection.mysql_version >= (8, 0, 1)
+ )
@cached_property
def supports_explain_analyze(self):
- return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (8, 0, 18)
+ return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (
+ 8,
+ 0,
+ 18,
+ )
@cached_property
def supported_explain_formats(self):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
# backends.
- formats = {'JSON', 'TEXT', 'TRADITIONAL'}
- if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 16):
- formats.add('TREE')
+ formats = {"JSON", "TEXT", "TRADITIONAL"}
+ if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
+ 8,
+ 0,
+ 16,
+ ):
+ formats.add("TREE")
return formats
@cached_property
@@ -230,11 +272,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"""
All storage engines except MyISAM support transactions.
"""
- return self._mysql_storage_engine != 'MyISAM'
+ return self._mysql_storage_engine != "MyISAM"
@cached_property
def ignores_table_name_case(self):
- return self.connection.mysql_server_data['lower_case_table_names']
+ return self.connection.mysql_server_data["lower_case_table_names"]
@cached_property
def supports_default_in_lead_lag(self):
@@ -256,13 +298,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def supports_index_column_ordering(self):
return (
- not self.connection.mysql_is_mariadb and
- self.connection.mysql_version >= (8, 0, 1)
+ not self.connection.mysql_is_mariadb
+ and self.connection.mysql_version >= (8, 0, 1)
)
@cached_property
def supports_expression_indexes(self):
return (
- not self.connection.mysql_is_mariadb and
- self.connection.mysql_version >= (8, 0, 13)
+ not self.connection.mysql_is_mariadb
+ and self.connection.mysql_version >= (8, 0, 13)
)
diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py
index 3a76168227..3cf56dffce 100644
--- a/django/db/backends/mysql/introspection.py
+++ b/django/db/backends/mysql/introspection.py
@@ -3,72 +3,76 @@ from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE
-from django.db.backends.base.introspection import (
- BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
-)
+from django.db.backends.base.introspection import BaseDatabaseIntrospection
+from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
+from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
from django.utils.datastructures import OrderedSet
-FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint'))
+FieldInfo = namedtuple(
+ "FieldInfo", BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint")
+)
InfoLine = namedtuple(
- 'InfoLine',
- 'col_name data_type max_len num_prec num_scale extra column_default '
- 'collation is_unsigned'
+ "InfoLine",
+ "col_name data_type max_len num_prec num_scale extra column_default "
+ "collation is_unsigned",
)
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
- FIELD_TYPE.BLOB: 'TextField',
- FIELD_TYPE.CHAR: 'CharField',
- FIELD_TYPE.DECIMAL: 'DecimalField',
- FIELD_TYPE.NEWDECIMAL: 'DecimalField',
- FIELD_TYPE.DATE: 'DateField',
- FIELD_TYPE.DATETIME: 'DateTimeField',
- FIELD_TYPE.DOUBLE: 'FloatField',
- FIELD_TYPE.FLOAT: 'FloatField',
- FIELD_TYPE.INT24: 'IntegerField',
- FIELD_TYPE.JSON: 'JSONField',
- FIELD_TYPE.LONG: 'IntegerField',
- FIELD_TYPE.LONGLONG: 'BigIntegerField',
- FIELD_TYPE.SHORT: 'SmallIntegerField',
- FIELD_TYPE.STRING: 'CharField',
- FIELD_TYPE.TIME: 'TimeField',
- FIELD_TYPE.TIMESTAMP: 'DateTimeField',
- FIELD_TYPE.TINY: 'IntegerField',
- FIELD_TYPE.TINY_BLOB: 'TextField',
- FIELD_TYPE.MEDIUM_BLOB: 'TextField',
- FIELD_TYPE.LONG_BLOB: 'TextField',
- FIELD_TYPE.VAR_STRING: 'CharField',
+ FIELD_TYPE.BLOB: "TextField",
+ FIELD_TYPE.CHAR: "CharField",
+ FIELD_TYPE.DECIMAL: "DecimalField",
+ FIELD_TYPE.NEWDECIMAL: "DecimalField",
+ FIELD_TYPE.DATE: "DateField",
+ FIELD_TYPE.DATETIME: "DateTimeField",
+ FIELD_TYPE.DOUBLE: "FloatField",
+ FIELD_TYPE.FLOAT: "FloatField",
+ FIELD_TYPE.INT24: "IntegerField",
+ FIELD_TYPE.JSON: "JSONField",
+ FIELD_TYPE.LONG: "IntegerField",
+ FIELD_TYPE.LONGLONG: "BigIntegerField",
+ FIELD_TYPE.SHORT: "SmallIntegerField",
+ FIELD_TYPE.STRING: "CharField",
+ FIELD_TYPE.TIME: "TimeField",
+ FIELD_TYPE.TIMESTAMP: "DateTimeField",
+ FIELD_TYPE.TINY: "IntegerField",
+ FIELD_TYPE.TINY_BLOB: "TextField",
+ FIELD_TYPE.MEDIUM_BLOB: "TextField",
+ FIELD_TYPE.LONG_BLOB: "TextField",
+ FIELD_TYPE.VAR_STRING: "CharField",
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
- if 'auto_increment' in description.extra:
- if field_type == 'IntegerField':
- return 'AutoField'
- elif field_type == 'BigIntegerField':
- return 'BigAutoField'
- elif field_type == 'SmallIntegerField':
- return 'SmallAutoField'
+ if "auto_increment" in description.extra:
+ if field_type == "IntegerField":
+ return "AutoField"
+ elif field_type == "BigIntegerField":
+ return "BigAutoField"
+ elif field_type == "SmallIntegerField":
+ return "SmallAutoField"
if description.is_unsigned:
- if field_type == 'BigIntegerField':
- return 'PositiveBigIntegerField'
- elif field_type == 'IntegerField':
- return 'PositiveIntegerField'
- elif field_type == 'SmallIntegerField':
- return 'PositiveSmallIntegerField'
+ if field_type == "BigIntegerField":
+ return "PositiveBigIntegerField"
+ elif field_type == "IntegerField":
+ return "PositiveIntegerField"
+ elif field_type == "SmallIntegerField":
+ return "PositiveSmallIntegerField"
# JSON data type is an alias for LONGTEXT in MariaDB, use check
# constraints clauses to introspect JSONField.
if description.has_json_constraint:
- return 'JSONField'
+ return "JSONField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("SHOW FULL TABLES")
- return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1]))
- for row in cursor.fetchall()]
+ return [
+ TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]))
+ for row in cursor.fetchall()
+ ]
def get_table_description(self, cursor, table_name):
"""
@@ -76,7 +80,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
interface."
"""
json_constraints = {}
- if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field:
+ if (
+ self.connection.mysql_is_mariadb
+ and self.connection.features.can_introspect_json_field
+ ):
# JSON data type is an alias for LONGTEXT in MariaDB, select
# JSON_VALID() constraints to introspect JSONField.
cursor.execute(
@@ -102,7 +109,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
[table_name],
)
row = cursor.fetchone()
- default_column_collation = row[0] if row else ''
+ default_column_collation = row[0] if row else ""
# information_schema database gives more accurate results for some figures:
# - varchar length returned by cursor.description is an internal length,
# not visible length (#5725)
@@ -128,7 +135,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
)
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
- cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
+ cursor.execute(
+ "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
+ )
def to_int(i):
return int(i) if i is not None else i
@@ -136,25 +145,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
fields = []
for line in cursor.description:
info = field_info[line[0]]
- fields.append(FieldInfo(
- *line[:3],
- to_int(info.max_len) or line[3],
- to_int(info.num_prec) or line[4],
- to_int(info.num_scale) or line[5],
- line[6],
- info.column_default,
- info.collation,
- info.extra,
- info.is_unsigned,
- line[0] in json_constraints,
- ))
+ fields.append(
+ FieldInfo(
+ *line[:3],
+ to_int(info.max_len) or line[3],
+ to_int(info.num_prec) or line[4],
+ to_int(info.num_scale) or line[5],
+ line[6],
+ info.column_default,
+ info.collation,
+ info.extra,
+ info.is_unsigned,
+ line[0] in json_constraints,
+ )
+ )
return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
- if 'auto_increment' in field_info.extra:
+ if "auto_increment" in field_info.extra:
# MySQL allows only one auto-increment column per table.
- return [{'table': table_name, 'column': field_info.name}]
+ return [{"table": table_name, "column": field_info.name}]
return []
def get_relations(self, cursor, table_name):
@@ -204,9 +215,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if (
- token.ttype == sqlparse.tokens.Name and
- self.connection.ops.quote_name(token.value) == token.value and
- token.value[1:-1] in columns
+ token.ttype == sqlparse.tokens.Name
+ and self.connection.ops.quote_name(token.value) == token.value
+ and token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1])
return check_columns
@@ -237,20 +248,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
- 'columns': OrderedSet(),
- 'primary_key': kind == 'PRIMARY KEY',
- 'unique': kind in {'PRIMARY KEY', 'UNIQUE'},
- 'index': False,
- 'check': False,
- 'foreign_key': (ref_table, ref_column) if ref_column else None,
+ "columns": OrderedSet(),
+ "primary_key": kind == "PRIMARY KEY",
+ "unique": kind in {"PRIMARY KEY", "UNIQUE"},
+ "index": False,
+ "check": False,
+ "foreign_key": (ref_table, ref_column) if ref_column else None,
}
if self.connection.features.supports_index_column_ordering:
- constraints[constraint]['orders'] = []
- constraints[constraint]['columns'].add(column)
+ constraints[constraint]["orders"] = []
+ constraints[constraint]["columns"].add(column)
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
- columns = {info.name for info in self.get_table_description(cursor, table_name)}
+ columns = {
+ info.name for info in self.get_table_description(cursor, table_name)
+ }
if self.connection.mysql_is_mariadb:
type_query = """
SELECT c.constraint_name, c.check_clause
@@ -274,42 +287,48 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
- constraint_columns = self._parse_constraint_columns(check_clause, columns)
+ constraint_columns = self._parse_constraint_columns(
+ check_clause, columns
+ )
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
- constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index
+ constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
constraints[constraint] = {
- 'columns': constraint_columns,
- 'primary_key': False,
- 'unique': False,
- 'index': False,
- 'check': True,
- 'foreign_key': None,
+ "columns": constraint_columns,
+ "primary_key": False,
+ "unique": False,
+ "index": False,
+ "check": True,
+ "foreign_key": None,
}
# Now add in the indexes
- cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
+ cursor.execute(
+ "SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
+ )
for table, non_unique, index, colseq, column, order, type_ in [
x[:6] + (x[10],) for x in cursor.fetchall()
]:
if index not in constraints:
constraints[index] = {
- 'columns': OrderedSet(),
- 'primary_key': False,
- 'unique': not non_unique,
- 'check': False,
- 'foreign_key': None,
+ "columns": OrderedSet(),
+ "primary_key": False,
+ "unique": not non_unique,
+ "check": False,
+ "foreign_key": None,
}
if self.connection.features.supports_index_column_ordering:
- constraints[index]['orders'] = []
- constraints[index]['index'] = True
- constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower()
- constraints[index]['columns'].add(column)
+ constraints[index]["orders"] = []
+ constraints[index]["index"] = True
+ constraints[index]["type"] = (
+ Index.suffix if type_ == "BTREE" else type_.lower()
+ )
+ constraints[index]["columns"].add(column)
if self.connection.features.supports_index_column_ordering:
- constraints[index]['orders'].append('DESC' if order == 'D' else 'ASC')
+ constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
# Convert the sorted sets to lists
for constraint in constraints.values():
- constraint['columns'] = list(constraint['columns'])
+ constraint["columns"] = list(constraint["columns"])
return constraints
diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py
index 7f1994e657..5bcc03a67b 100644
--- a/django/db/backends/mysql/operations.py
+++ b/django/db/backends/mysql/operations.py
@@ -15,42 +15,42 @@ class DatabaseOperations(BaseDatabaseOperations):
# MySQL stores positive fields as UNSIGNED ints.
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
- 'PositiveSmallIntegerField': (0, 65535),
- 'PositiveIntegerField': (0, 4294967295),
- 'PositiveBigIntegerField': (0, 18446744073709551615),
+ "PositiveSmallIntegerField": (0, 65535),
+ "PositiveIntegerField": (0, 4294967295),
+ "PositiveBigIntegerField": (0, 18446744073709551615),
}
cast_data_types = {
- 'AutoField': 'signed integer',
- 'BigAutoField': 'signed integer',
- 'SmallAutoField': 'signed integer',
- 'CharField': 'char(%(max_length)s)',
- 'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)',
- 'TextField': 'char',
- 'IntegerField': 'signed integer',
- 'BigIntegerField': 'signed integer',
- 'SmallIntegerField': 'signed integer',
- 'PositiveBigIntegerField': 'unsigned integer',
- 'PositiveIntegerField': 'unsigned integer',
- 'PositiveSmallIntegerField': 'unsigned integer',
- 'DurationField': 'signed integer',
+ "AutoField": "signed integer",
+ "BigAutoField": "signed integer",
+ "SmallAutoField": "signed integer",
+ "CharField": "char(%(max_length)s)",
+ "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
+ "TextField": "char",
+ "IntegerField": "signed integer",
+ "BigIntegerField": "signed integer",
+ "SmallIntegerField": "signed integer",
+ "PositiveBigIntegerField": "unsigned integer",
+ "PositiveIntegerField": "unsigned integer",
+ "PositiveSmallIntegerField": "unsigned integer",
+ "DurationField": "signed integer",
}
- cast_char_field_without_max_length = 'char'
- explain_prefix = 'EXPLAIN'
+ cast_char_field_without_max_length = "char"
+ explain_prefix = "EXPLAIN"
def date_extract_sql(self, lookup_type, field_name):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
- if lookup_type == 'week_day':
+ if lookup_type == "week_day":
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return "DAYOFWEEK(%s)" % field_name
- elif lookup_type == 'iso_week_day':
+ elif lookup_type == "iso_week_day":
# WEEKDAY() returns an integer, 0-6, Monday=0.
return "WEEKDAY(%s) + 1" % field_name
- elif lookup_type == 'week':
+ elif lookup_type == "week":
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return "WEEK(%s, 3)" % field_name
- elif lookup_type == 'iso_year':
+ elif lookup_type == "iso_year":
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
@@ -61,26 +61,25 @@ class DatabaseOperations(BaseDatabaseOperations):
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
- 'year': '%%Y-01-01',
- 'month': '%%Y-%%m-01',
+ "year": "%%Y-01-01",
+ "month": "%%Y-%%m-01",
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
- elif lookup_type == 'quarter':
- return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % (
- field_name, field_name
- )
- elif lookup_type == 'week':
- return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (
- field_name, field_name
+ elif lookup_type == "quarter":
+ return (
+ "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER"
+ % (field_name, field_name)
)
+ elif lookup_type == "week":
+ return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name)
else:
return "DATE(%s)" % (field_name)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
- return f'{sign}{offset}' if offset else tzname
+ return f"{sign}{offset}" if offset else tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
@@ -105,16 +104,23 @@ class DatabaseOperations(BaseDatabaseOperations):
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
- fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
- format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
- format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
- if lookup_type == 'quarter':
+ fields = ["year", "month", "day", "hour", "minute", "second"]
+ format = (
+ "%%Y-",
+ "%%m",
+ "-%%d",
+ " %%H:",
+ "%%i",
+ ":%%s",
+ ) # Use double percents to escape.
+ format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
+ if lookup_type == "quarter":
return (
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
- "INTERVAL QUARTER({field_name}) QUARTER - " +
- "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
+ "INTERVAL QUARTER({field_name}) QUARTER - "
+ + "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
).format(field_name=field_name)
- if lookup_type == 'week':
+ if lookup_type == "week":
return (
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
"INTERVAL WEEKDAY({field_name}) DAY), "
@@ -125,16 +131,16 @@ class DatabaseOperations(BaseDatabaseOperations):
except ValueError:
sql = field_name
else:
- format_str = ''.join(format[:i] + format_def[i:])
+ format_str = "".join(format[:i] + format_def[i:])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
- 'hour': '%%H:00:00',
- 'minute': '%%H:%%i:00',
- 'second': '%%H:%%i:%%s',
+ "hour": "%%H:00:00",
+ "minute": "%%H:%%i:00",
+ "second": "%%H:%%i:%%s",
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
@@ -150,7 +156,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
- return 'INTERVAL %s MICROSECOND' % sql
+ return "INTERVAL %s MICROSECOND" % sql
def force_no_ordering(self):
"""
@@ -168,7 +174,7 @@ class DatabaseOperations(BaseDatabaseOperations):
# attribute where the exact query sent to the database is saved.
# See MySQLdb/cursors.py in the source distribution.
# MySQLdb returns string, PyMySQL bytes.
- return force_str(getattr(cursor, '_executed', None), errors='replace')
+ return force_str(getattr(cursor, "_executed", None), errors="replace")
def no_limit_value(self):
# 2**64 - 1, as recommended by the MySQL documentation
@@ -183,50 +189,58 @@ class DatabaseOperations(BaseDatabaseOperations):
# MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
# statement.
if not fields:
- return '', ()
+ return "", ()
columns = [
- '%s.%s' % (
+ "%s.%s"
+ % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
- ) for field in fields
+ )
+ for field in fields
]
- return 'RETURNING %s' % ', '.join(columns), ()
+ return "RETURNING %s" % ", ".join(columns), ()
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
- sql = ['SET FOREIGN_KEY_CHECKS = 0;']
+ sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
if reset_sequences:
# It's faster to TRUNCATE tables that require a sequence reset
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
sql.extend(
- '%s %s;' % (
- style.SQL_KEYWORD('TRUNCATE'),
+ "%s %s;"
+ % (
+ style.SQL_KEYWORD("TRUNCATE"),
style.SQL_FIELD(self.quote_name(table_name)),
- ) for table_name in tables
+ )
+ for table_name in tables
)
else:
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
# and preserves sequences.
sql.extend(
- '%s %s %s;' % (
- style.SQL_KEYWORD('DELETE'),
- style.SQL_KEYWORD('FROM'),
+ "%s %s %s;"
+ % (
+ style.SQL_KEYWORD("DELETE"),
+ style.SQL_KEYWORD("FROM"),
style.SQL_FIELD(self.quote_name(table_name)),
- ) for table_name in tables
+ )
+ for table_name in tables
)
- sql.append('SET FOREIGN_KEY_CHECKS = 1;')
+ sql.append("SET FOREIGN_KEY_CHECKS = 1;")
return sql
def sequence_reset_by_name_sql(self, style, sequences):
return [
- '%s %s %s %s = 1;' % (
- style.SQL_KEYWORD('ALTER'),
- style.SQL_KEYWORD('TABLE'),
- style.SQL_FIELD(self.quote_name(sequence_info['table'])),
- style.SQL_FIELD('AUTO_INCREMENT'),
- ) for sequence_info in sequences
+ "%s %s %s %s = 1;"
+ % (
+ style.SQL_KEYWORD("ALTER"),
+ style.SQL_KEYWORD("TABLE"),
+ style.SQL_FIELD(self.quote_name(sequence_info["table"])),
+ style.SQL_FIELD("AUTO_INCREMENT"),
+ )
+ for sequence_info in sequences
]
def validate_autopk_value(self, value):
@@ -234,7 +248,7 @@ class DatabaseOperations(BaseDatabaseOperations):
# NO_AUTO_VALUE_ON_ZERO SQL mode.
if value == 0 and not self.connection.features.allows_auto_pk_0:
raise ValueError(
- 'The database backend does not accept 0 as a value for AutoField.'
+ "The database backend does not accept 0 as a value for AutoField."
)
return value
@@ -243,7 +257,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# MySQL doesn't support tz-aware datetimes
@@ -251,7 +265,9 @@ class DatabaseOperations(BaseDatabaseOperations):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
- raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.")
+ raise ValueError(
+ "MySQL backend does not support timezone-aware datetimes when USE_TZ is False."
+ )
return str(value)
def adapt_timefield_value(self, value):
@@ -259,20 +275,20 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")
- return value.isoformat(timespec='microseconds')
+ return value.isoformat(timespec="microseconds")
def max_name_length(self):
return 64
def pk_default_value(self):
- return 'NULL'
+ return "NULL"
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
@@ -280,27 +296,27 @@ class DatabaseOperations(BaseDatabaseOperations):
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
- if connector == '^':
- return 'POW(%s)' % ','.join(sub_expressions)
+ if connector == "^":
+ return "POW(%s)" % ",".join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
# return an unsigned integer.
- elif connector in ('&', '|', '<<', '#'):
- connector = '^' if connector == '#' else connector
- return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions)
- elif connector == '>>':
+ elif connector in ("&", "|", "<<", "#"):
+ connector = "^" if connector == "#" else connector
+ return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
+ elif connector == ">>":
lhs, rhs = sub_expressions
- return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
+ return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
return super().combine_expression(connector, sub_expressions)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
- if internal_type == 'BooleanField':
+ if internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
- elif internal_type == 'DateTimeField':
+ elif internal_type == "DateTimeField":
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
- elif internal_type == 'UUIDField':
+ elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
return converters
@@ -320,66 +336,88 @@ class DatabaseOperations(BaseDatabaseOperations):
return value
def binary_placeholder_sql(self, value):
- return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
+ return (
+ "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
+ )
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
- if internal_type == 'TimeField':
+ if internal_type == "TimeField":
if self.connection.mysql_is_mariadb:
# MariaDB includes the microsecond component in TIME_TO_SEC as
# a decimal. MySQL returns an integer without microseconds.
- return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % {
- 'lhs': lhs_sql, 'rhs': rhs_sql
- }, (*lhs_params, *rhs_params)
+ return "CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)" % {
+ "lhs": lhs_sql,
+ "rhs": rhs_sql,
+ }, (
+ *lhs_params,
+ *rhs_params,
+ )
return (
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
- ) % {'lhs': lhs_sql, 'rhs': rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2
+ ) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
+ rhs_params
+ ) * 2
params = (*rhs_params, *lhs_params)
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
def explain_query_prefix(self, format=None, **options):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
- if format and format.upper() == 'TEXT':
- format = 'TRADITIONAL'
- elif not format and 'TREE' in self.connection.features.supported_explain_formats:
+ if format and format.upper() == "TEXT":
+ format = "TRADITIONAL"
+ elif (
+ not format and "TREE" in self.connection.features.supported_explain_formats
+ ):
# Use TREE by default (if supported) as it's more informative.
- format = 'TREE'
- analyze = options.pop('analyze', False)
+ format = "TREE"
+ analyze = options.pop("analyze", False)
prefix = super().explain_query_prefix(format, **options)
if analyze and self.connection.features.supports_explain_analyze:
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
- prefix = 'ANALYZE' if self.connection.mysql_is_mariadb else prefix + ' ANALYZE'
+ prefix = (
+ "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
+ )
if format and not (analyze and not self.connection.mysql_is_mariadb):
# Only MariaDB supports the analyze option with formats.
- prefix += ' FORMAT=%s' % format
+ prefix += " FORMAT=%s" % format
return prefix
def regex_lookup(self, lookup_type):
# REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
# doesn't exist in MySQL 5.x or in MariaDB.
- if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb:
- if lookup_type == 'regex':
- return '%s REGEXP BINARY %s'
- return '%s REGEXP %s'
+ if (
+ self.connection.mysql_version < (8, 0, 0)
+ or self.connection.mysql_is_mariadb
+ ):
+ if lookup_type == "regex":
+ return "%s REGEXP BINARY %s"
+ return "%s REGEXP %s"
- match_option = 'c' if lookup_type == 'regex' else 'i'
+ match_option = "c" if lookup_type == "regex" else "i"
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
- return 'INSERT IGNORE INTO'
+ return "INSERT IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)
def lookup_cast(self, lookup_type, internal_type=None):
- lookup = '%s'
- if internal_type == 'JSONField':
+ lookup = "%s"
+ if internal_type == "JSONField":
if self.connection.mysql_is_mariadb or lookup_type in (
- 'iexact', 'contains', 'icontains', 'startswith', 'istartswith',
- 'endswith', 'iendswith', 'regex', 'iregex',
+ "iexact",
+ "contains",
+ "icontains",
+ "startswith",
+ "istartswith",
+ "endswith",
+ "iendswith",
+ "regex",
+ "iregex",
):
- lookup = 'JSON_UNQUOTE(%s)'
+ lookup = "JSON_UNQUOTE(%s)"
return lookup
def conditional_expression_supported_in_where_clause(self, expression):
@@ -388,31 +426,38 @@ class DatabaseOperations(BaseDatabaseOperations):
if isinstance(expression, (Exists, Lookup)):
return True
if isinstance(expression, ExpressionWrapper) and expression.conditional:
- return self.conditional_expression_supported_in_where_clause(expression.expression)
- if getattr(expression, 'conditional', False):
+ return self.conditional_expression_supported_in_where_clause(
+ expression.expression
+ )
+ if getattr(expression, "conditional", False):
return False
return super().conditional_expression_supported_in_where_clause(expression)
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.UPDATE:
- conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s'
- field_sql = '%(field)s = VALUES(%(field)s)'
+ conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
+ field_sql = "%(field)s = VALUES(%(field)s)"
# The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
# aliases for the new row and its columns available in MySQL
# 8.0.19+.
if not self.connection.mysql_is_mariadb:
if self.connection.mysql_version >= (8, 0, 19):
- conflict_suffix_sql = f'AS new {conflict_suffix_sql}'
- field_sql = '%(field)s = new.%(field)s'
+ conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
+ field_sql = "%(field)s = new.%(field)s"
# VALUES() was renamed to VALUE() in MariaDB 10.3.3+.
elif self.connection.mysql_version >= (10, 3, 3):
- field_sql = '%(field)s = VALUE(%(field)s)'
+ field_sql = "%(field)s = VALUE(%(field)s)"
- fields = ', '.join([
- field_sql % {'field': field}
- for field in map(self.quote_name, update_fields)
- ])
- return conflict_suffix_sql % {'fields': fields}
+ fields = ", ".join(
+ [
+ field_sql % {"field": field}
+ for field in map(self.quote_name, update_fields)
+ ]
+ )
+ return conflict_suffix_sql % {"fields": fields}
return super().on_conflict_suffix_sql(
- fields, on_conflict, update_fields, unique_fields,
+ fields,
+ on_conflict,
+ update_fields,
+ unique_fields,
)
diff --git a/django/db/backends/mysql/schema.py b/django/db/backends/mysql/schema.py
index 17827c2195..562b209eef 100644
--- a/django/db/backends/mysql/schema.py
+++ b/django/db/backends/mysql/schema.py
@@ -10,24 +10,26 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
- sql_alter_column_no_default_null = 'ALTER COLUMN %(column)s SET DEFAULT NULL'
+ sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL"
# No 'CASCADE' which works as a no-op in MySQL but is undocumented
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
sql_create_column_inline_fk = (
- ', ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) '
- 'REFERENCES %(to_table)s(%(to_column)s)'
+ ", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
+ "REFERENCES %(to_table)s(%(to_column)s)"
)
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
- sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
+ sql_create_pk = (
+ "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
+ )
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
- sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s'
+ sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
@property
def sql_delete_check(self):
@@ -35,8 +37,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding IF EXISTS clause prevents migrations
# crash. Constraint is removed during a "MODIFY" column statement.
- return 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s'
- return 'ALTER TABLE %(table)s DROP CHECK %(name)s'
+ return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s"
+ return "ALTER TABLE %(table)s DROP CHECK %(name)s"
@property
def sql_rename_column(self):
@@ -47,21 +49,26 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
return super().sql_rename_column
elif self.connection.mysql_version >= (8, 0, 4):
return super().sql_rename_column
- return 'ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s'
+ return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
def quote_value(self, value):
self.connection.ensure_connection()
if isinstance(value, str):
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
# MySQLdb escapes to string, PyMySQL to bytes.
- quoted = self.connection.connection.escape(value, self.connection.connection.encoders)
+ quoted = self.connection.connection.escape(
+ value, self.connection.connection.encoders
+ )
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted
def _is_limited_data_type(self, field):
db_type = field.db_type(self.connection)
- return db_type is not None and db_type.lower() in self.connection._limited_data_types
+ return (
+ db_type is not None
+ and db_type.lower() in self.connection._limited_data_types
+ )
def skip_default(self, field):
if not self._supports_limited_data_type_defaults:
@@ -84,13 +91,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _column_default_sql(self, field):
if (
- not self.connection.mysql_is_mariadb and
- self._supports_limited_data_type_defaults and
- self._is_limited_data_type(field)
+ not self.connection.mysql_is_mariadb
+ and self._supports_limited_data_type_defaults
+ and self._is_limited_data_type(field)
):
# MySQL supports defaults for BLOB and TEXT columns only if the
# default value is written as an expression i.e. in parentheses.
- return '(%s)'
+ return "(%s)"
return super()._column_default_sql(field)
def add_field(self, model, field):
@@ -100,10 +107,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# field.default may be unhashable, so a set isn't used for "in" check.
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
effective_default = self.effective_default(field)
- self.execute('UPDATE %(table)s SET %(column)s = %%s' % {
- 'table': self.quote_name(model._meta.db_table),
- 'column': self.quote_name(field.column),
- }, [effective_default])
+ self.execute(
+ "UPDATE %(table)s SET %(column)s = %%s"
+ % {
+ "table": self.quote_name(model._meta.db_table),
+ "column": self.quote_name(field.column),
+ },
+ [effective_default],
+ )
def _field_should_be_indexed(self, model, field):
if not super()._field_should_be_indexed(model, field):
@@ -115,9 +126,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# No need to create an index for ForeignKey fields except if
# db_constraint=False because the index from that constraint won't be
# created.
- if (storage == "InnoDB" and
- field.get_internal_type() == 'ForeignKey' and
- field.db_constraint):
+ if (
+ storage == "InnoDB"
+ and field.get_internal_type() == "ForeignKey"
+ and field.db_constraint
+ ):
return False
return not self._is_limited_data_type(field)
@@ -131,11 +144,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
recreate a FK index.
"""
first_field = model._meta.get_field(fields[0])
- if first_field.get_internal_type() == 'ForeignKey':
- constraint_names = self._constraint_names(model, [first_field.column], index=True)
+ if first_field.get_internal_type() == "ForeignKey":
+ constraint_names = self._constraint_names(
+ model, [first_field.column], index=True
+ )
if not constraint_names:
self.execute(
- self._create_index_sql(model, fields=[first_field], suffix='')
+ self._create_index_sql(model, fields=[first_field], suffix="")
)
return super()._delete_composed_index(model, fields, *args)
diff --git a/django/db/backends/mysql/validation.py b/django/db/backends/mysql/validation.py
index 41e600a856..5d61b4865a 100644
--- a/django/db/backends/mysql/validation.py
+++ b/django/db/backends/mysql/validation.py
@@ -10,24 +10,28 @@ class DatabaseValidation(BaseDatabaseValidation):
return issues
def _check_sql_mode(self, **kwargs):
- if not (self.connection.sql_mode & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}):
- return [checks.Warning(
- "%s Strict Mode is not set for database connection '%s'"
- % (self.connection.display_name, self.connection.alias),
- hint=(
- "%s's Strict Mode fixes many data integrity problems in "
- "%s, such as data truncation upon insertion, by "
- "escalating warnings into errors. It is strongly "
- "recommended you activate it. See: "
- "https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode"
- % (
- self.connection.display_name,
- self.connection.display_name,
- get_docs_version(),
+ if not (
+ self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
+ ):
+ return [
+ checks.Warning(
+ "%s Strict Mode is not set for database connection '%s'"
+ % (self.connection.display_name, self.connection.alias),
+ hint=(
+ "%s's Strict Mode fixes many data integrity problems in "
+ "%s, such as data truncation upon insertion, by "
+ "escalating warnings into errors. It is strongly "
+ "recommended you activate it. See: "
+ "https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode"
+ % (
+ self.connection.display_name,
+ self.connection.display_name,
+ get_docs_version(),
+ ),
),
- ),
- id='mysql.W002',
- )]
+ id="mysql.W002",
+ )
+ ]
return []
def check_field_type(self, field, field_type):
@@ -38,32 +42,35 @@ class DatabaseValidation(BaseDatabaseValidation):
MySQL doesn't support a database index on some data types.
"""
errors = []
- if (field_type.startswith('varchar') and field.unique and
- (field.max_length is None or int(field.max_length) > 255)):
+ if (
+ field_type.startswith("varchar")
+ and field.unique
+ and (field.max_length is None or int(field.max_length) > 255)
+ ):
errors.append(
checks.Warning(
- '%s may not allow unique CharFields to have a max_length '
- '> 255.' % self.connection.display_name,
+ "%s may not allow unique CharFields to have a max_length "
+ "> 255." % self.connection.display_name,
obj=field,
hint=(
- 'See: https://docs.djangoproject.com/en/%s/ref/'
- 'databases/#mysql-character-fields' % get_docs_version()
+ "See: https://docs.djangoproject.com/en/%s/ref/"
+ "databases/#mysql-character-fields" % get_docs_version()
),
- id='mysql.W003',
+ id="mysql.W003",
)
)
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
- '%s does not support a database index on %s columns.'
+ "%s does not support a database index on %s columns."
% (self.connection.display_name, field_type),
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
- id='fields.W162',
+ id="fields.W162",
)
)
return errors
diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
index 966eb4b6f4..b13c5f8bb2 100644
--- a/django/db/backends/oracle/base.py
+++ b/django/db/backends/oracle/base.py
@@ -21,27 +21,31 @@ from django.utils.functional import cached_property
def _setup_environment(environ):
# Cygwin requires some special voodoo to set the environment variables
# properly so that Oracle will see them.
- if platform.system().upper().startswith('CYGWIN'):
+ if platform.system().upper().startswith("CYGWIN"):
try:
import ctypes
except ImportError as e:
- raise ImproperlyConfigured("Error loading ctypes: %s; "
- "the Oracle backend requires ctypes to "
- "operate correctly under Cygwin." % e)
- kernel32 = ctypes.CDLL('kernel32')
+ raise ImproperlyConfigured(
+ "Error loading ctypes: %s; "
+ "the Oracle backend requires ctypes to "
+ "operate correctly under Cygwin." % e
+ )
+ kernel32 = ctypes.CDLL("kernel32")
for name, value in environ:
kernel32.SetEnvironmentVariableA(name, value)
else:
os.environ.update(environ)
-_setup_environment([
- # Oracle takes client-side character set encoding from the environment.
- ('NLS_LANG', '.AL32UTF8'),
- # This prevents Unicode from getting mangled by getting encoded into the
- # potentially non-Unicode database character set.
- ('ORA_NCHAR_LITERAL_REPLACE', 'TRUE'),
-])
+_setup_environment(
+ [
+ # Oracle takes client-side character set encoding from the environment.
+ ("NLS_LANG", ".AL32UTF8"),
+ # This prevents Unicode from getting mangled by getting encoded into the
+ # potentially non-Unicode database character set.
+ ("ORA_NCHAR_LITERAL_REPLACE", "TRUE"),
+ ]
+)
try:
@@ -77,17 +81,16 @@ def wrap_oracle_errors():
# Convert that case to Django's IntegrityError exception.
x = e.args[0]
if (
- hasattr(x, 'code') and
- hasattr(x, 'message') and
- x.code == 2091 and
- ('ORA-02291' in x.message or 'ORA-00001' in x.message)
+ hasattr(x, "code")
+ and hasattr(x, "message")
+ and x.code == 2091
+ and ("ORA-02291" in x.message or "ORA-00001" in x.message)
):
raise IntegrityError(*tuple(e.args))
raise
class _UninitializedOperatorsDescriptor:
-
def __get__(self, instance, cls=None):
# If connection.operators is looked up before a connection has been
# created, transparently initialize connection.operators to avert an
@@ -96,12 +99,12 @@ class _UninitializedOperatorsDescriptor:
raise AttributeError("operators not available as class attribute")
# Creating a cursor will initialize the operators.
instance.cursor().close()
- return instance.__dict__['operators']
+ return instance.__dict__["operators"]
class DatabaseWrapper(BaseDatabaseWrapper):
- vendor = 'oracle'
- display_name = 'Oracle'
+ vendor = "oracle"
+ display_name = "Oracle"
# This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
@@ -110,71 +113,71 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed.
data_types = {
- 'AutoField': 'NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY',
- 'BigAutoField': 'NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY',
- 'BinaryField': 'BLOB',
- 'BooleanField': 'NUMBER(1)',
- 'CharField': 'NVARCHAR2(%(max_length)s)',
- 'DateField': 'DATE',
- 'DateTimeField': 'TIMESTAMP',
- 'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
- 'DurationField': 'INTERVAL DAY(9) TO SECOND(6)',
- 'FileField': 'NVARCHAR2(%(max_length)s)',
- 'FilePathField': 'NVARCHAR2(%(max_length)s)',
- 'FloatField': 'DOUBLE PRECISION',
- 'IntegerField': 'NUMBER(11)',
- 'JSONField': 'NCLOB',
- 'BigIntegerField': 'NUMBER(19)',
- 'IPAddressField': 'VARCHAR2(15)',
- 'GenericIPAddressField': 'VARCHAR2(39)',
- 'OneToOneField': 'NUMBER(11)',
- 'PositiveBigIntegerField': 'NUMBER(19)',
- 'PositiveIntegerField': 'NUMBER(11)',
- 'PositiveSmallIntegerField': 'NUMBER(11)',
- 'SlugField': 'NVARCHAR2(%(max_length)s)',
- 'SmallAutoField': 'NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY',
- 'SmallIntegerField': 'NUMBER(11)',
- 'TextField': 'NCLOB',
- 'TimeField': 'TIMESTAMP',
- 'URLField': 'VARCHAR2(%(max_length)s)',
- 'UUIDField': 'VARCHAR2(32)',
+ "AutoField": "NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY",
+ "BigAutoField": "NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY",
+ "BinaryField": "BLOB",
+ "BooleanField": "NUMBER(1)",
+ "CharField": "NVARCHAR2(%(max_length)s)",
+ "DateField": "DATE",
+ "DateTimeField": "TIMESTAMP",
+ "DecimalField": "NUMBER(%(max_digits)s, %(decimal_places)s)",
+ "DurationField": "INTERVAL DAY(9) TO SECOND(6)",
+ "FileField": "NVARCHAR2(%(max_length)s)",
+ "FilePathField": "NVARCHAR2(%(max_length)s)",
+ "FloatField": "DOUBLE PRECISION",
+ "IntegerField": "NUMBER(11)",
+ "JSONField": "NCLOB",
+ "BigIntegerField": "NUMBER(19)",
+ "IPAddressField": "VARCHAR2(15)",
+ "GenericIPAddressField": "VARCHAR2(39)",
+ "OneToOneField": "NUMBER(11)",
+ "PositiveBigIntegerField": "NUMBER(19)",
+ "PositiveIntegerField": "NUMBER(11)",
+ "PositiveSmallIntegerField": "NUMBER(11)",
+ "SlugField": "NVARCHAR2(%(max_length)s)",
+ "SmallAutoField": "NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY",
+ "SmallIntegerField": "NUMBER(11)",
+ "TextField": "NCLOB",
+ "TimeField": "TIMESTAMP",
+ "URLField": "VARCHAR2(%(max_length)s)",
+ "UUIDField": "VARCHAR2(32)",
}
data_type_check_constraints = {
- 'BooleanField': '%(qn_column)s IN (0,1)',
- 'JSONField': '%(qn_column)s IS JSON',
- 'PositiveBigIntegerField': '%(qn_column)s >= 0',
- 'PositiveIntegerField': '%(qn_column)s >= 0',
- 'PositiveSmallIntegerField': '%(qn_column)s >= 0',
+ "BooleanField": "%(qn_column)s IN (0,1)",
+ "JSONField": "%(qn_column)s IS JSON",
+ "PositiveBigIntegerField": "%(qn_column)s >= 0",
+ "PositiveIntegerField": "%(qn_column)s >= 0",
+ "PositiveSmallIntegerField": "%(qn_column)s >= 0",
}
# Oracle doesn't support a database index on these columns.
- _limited_data_types = ('clob', 'nclob', 'blob')
+ _limited_data_types = ("clob", "nclob", "blob")
operators = _UninitializedOperatorsDescriptor()
_standard_operators = {
- 'exact': '= %s',
- 'iexact': '= UPPER(%s)',
- 'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
- 'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
- 'gt': '> %s',
- 'gte': '>= %s',
- 'lt': '< %s',
- 'lte': '<= %s',
- 'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
- 'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
- 'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
- 'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+ "exact": "= %s",
+ "iexact": "= UPPER(%s)",
+ "contains": "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+ "icontains": "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+ "gt": "> %s",
+ "gte": ">= %s",
+ "lt": "< %s",
+ "lte": "<= %s",
+ "startswith": "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+ "endswith": "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+ "istartswith": "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+ "iendswith": "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
}
_likec_operators = {
**_standard_operators,
- 'contains': "LIKEC %s ESCAPE '\\'",
- 'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
- 'startswith': "LIKEC %s ESCAPE '\\'",
- 'endswith': "LIKEC %s ESCAPE '\\'",
- 'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
- 'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
+ "contains": "LIKEC %s ESCAPE '\\'",
+ "icontains": "LIKEC UPPER(%s) ESCAPE '\\'",
+ "startswith": "LIKEC %s ESCAPE '\\'",
+ "endswith": "LIKEC %s ESCAPE '\\'",
+ "istartswith": "LIKEC UPPER(%s) ESCAPE '\\'",
+ "iendswith": "LIKEC UPPER(%s) ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -187,19 +190,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
_pattern_ops = {
- 'contains': "'%%' || {} || '%%'",
- 'icontains': "'%%' || UPPER({}) || '%%'",
- 'startswith': "{} || '%%'",
- 'istartswith': "UPPER({}) || '%%'",
- 'endswith': "'%%' || {}",
- 'iendswith': "'%%' || UPPER({})",
+ "contains": "'%%' || {} || '%%'",
+ "icontains": "'%%' || UPPER({}) || '%%'",
+ "startswith": "{} || '%%'",
+ "istartswith": "UPPER({}) || '%%'",
+ "endswith": "'%%' || {}",
+ "iendswith": "'%%' || UPPER({})",
}
- _standard_pattern_ops = {k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
- " ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
- for k, v in _pattern_ops.items()}
- _likec_pattern_ops = {k: "LIKEC " + v + " ESCAPE '\\'"
- for k, v in _pattern_ops.items()}
+ _standard_pattern_ops = {
+ k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
+ " ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
+ for k, v in _pattern_ops.items()
+ }
+ _likec_pattern_ops = {
+ k: "LIKEC " + v + " ESCAPE '\\'" for k, v in _pattern_ops.items()
+ }
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
@@ -213,20 +219,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
+ use_returning_into = self.settings_dict["OPTIONS"].get(
+ "use_returning_into", True
+ )
self.features.can_return_columns_from_insert = use_returning_into
def get_connection_params(self):
- conn_params = self.settings_dict['OPTIONS'].copy()
- if 'use_returning_into' in conn_params:
- del conn_params['use_returning_into']
+ conn_params = self.settings_dict["OPTIONS"].copy()
+ if "use_returning_into" in conn_params:
+ del conn_params["use_returning_into"]
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
return Database.connect(
- user=self.settings_dict['USER'],
- password=self.settings_dict['PASSWORD'],
+ user=self.settings_dict["USER"],
+ password=self.settings_dict["PASSWORD"],
dsn=dsn(self.settings_dict),
**conn_params,
)
@@ -244,11 +252,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# TO_CHAR().
cursor.execute(
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
- " NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" +
- (" TIME_ZONE = 'UTC'" if settings.USE_TZ else '')
+ " NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
+ + (" TIME_ZONE = 'UTC'" if settings.USE_TZ else "")
)
cursor.close()
- if 'operators' not in self.__dict__:
+ if "operators" not in self.__dict__:
# Ticket #14149: Check whether our LIKE implementation will
# work for this connection or we need to fall back on LIKEC.
# This check is performed only once per DatabaseWrapper
@@ -256,9 +264,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the same settings.
cursor = self.create_cursor()
try:
- cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
- % self._standard_operators['contains'],
- ['X'])
+ cursor.execute(
+ "SELECT 1 FROM DUAL WHERE DUMMY %s"
+ % self._standard_operators["contains"],
+ ["X"],
+ )
except Database.DatabaseError:
self.operators = self._likec_operators
self.pattern_ops = self._likec_pattern_ops
@@ -284,10 +294,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# logging is enabled to keep query counts consistent with other backends.
def _savepoint_commit(self, sid):
if self.queries_logged:
- self.queries_log.append({
- 'sql': '-- RELEASE SAVEPOINT %s (faked)' % self.ops.quote_name(sid),
- 'time': '0.000',
- })
+ self.queries_log.append(
+ {
+ "sql": "-- RELEASE SAVEPOINT %s (faked)" % self.ops.quote_name(sid),
+ "time": "0.000",
+ }
+ )
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
@@ -299,8 +311,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
afterward.
"""
with self.cursor() as cursor:
- cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
- cursor.execute('SET CONSTRAINTS ALL DEFERRED')
+ cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
+ cursor.execute("SET CONSTRAINTS ALL DEFERRED")
def is_usable(self):
try:
@@ -312,12 +324,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def cx_oracle_version(self):
- return tuple(int(x) for x in Database.version.split('.'))
+ return tuple(int(x) for x in Database.version.split("."))
@cached_property
def oracle_version(self):
with self.temporary_connection():
- return tuple(int(x) for x in self.connection.version.split('.'))
+ return tuple(int(x) for x in self.connection.version.split("."))
class OracleParam:
@@ -333,8 +345,10 @@ class OracleParam:
def __init__(self, param, cursor, strings_only=False):
# With raw SQL queries, datetimes can reach this function
# without being converted by DateTimeField.get_db_prep_value.
- if settings.USE_TZ and (isinstance(param, datetime.datetime) and
- not isinstance(param, Oracle_datetime)):
+ if settings.USE_TZ and (
+ isinstance(param, datetime.datetime)
+ and not isinstance(param, Oracle_datetime)
+ ):
param = Oracle_datetime.from_datetime(param)
string_size = 0
@@ -343,7 +357,7 @@ class OracleParam:
param = 1
elif param is False:
param = 0
- if hasattr(param, 'bind_parameter'):
+ if hasattr(param, "bind_parameter"):
self.force_bytes = param.bind_parameter(cursor)
elif isinstance(param, (Database.Binary, datetime.timedelta)):
self.force_bytes = param
@@ -354,7 +368,7 @@ class OracleParam:
if isinstance(self.force_bytes, str):
# We could optimize by only converting up to 4000 bytes here
string_size = len(force_bytes(param, cursor.charset, strings_only))
- if hasattr(param, 'input_size'):
+ if hasattr(param, "input_size"):
# If parameter has `input_size` attribute, use that.
self.input_size = param.input_size
elif string_size > 4000:
@@ -384,7 +398,7 @@ class VariableWrapper:
return getattr(self.var, key)
def __setattr__(self, key, value):
- if key == 'var':
+ if key == "var":
self.__dict__[key] = value
else:
setattr(self.var, key, value)
@@ -396,7 +410,8 @@ class FormatStylePlaceholderCursor:
style. This fixes it -- but note that if you want to use a literal "%s" in
a query, you'll need to use "%%s".
"""
- charset = 'utf-8'
+
+ charset = "utf-8"
def __init__(self, connection):
self.cursor = connection.cursor()
@@ -404,7 +419,7 @@ class FormatStylePlaceholderCursor:
@staticmethod
def _output_number_converter(value):
- return decimal.Decimal(value) if '.' in value else int(value)
+ return decimal.Decimal(value) if "." in value else int(value)
@staticmethod
def _get_decimal_converter(precision, scale):
@@ -434,7 +449,9 @@ class FormatStylePlaceholderCursor:
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntegerField and DecimalField columns.
- outconverter = FormatStylePlaceholderCursor._get_decimal_converter(precision, scale)
+ outconverter = FormatStylePlaceholderCursor._get_decimal_converter(
+ precision, scale
+ )
else:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
@@ -455,7 +472,7 @@ class FormatStylePlaceholderCursor:
def _guess_input_sizes(self, params_list):
# Try dict handling; if that fails, treat as sequence
- if hasattr(params_list[0], 'keys'):
+ if hasattr(params_list[0], "keys"):
sizes = {}
for params in params_list:
for k, value in params.items():
@@ -475,7 +492,7 @@ class FormatStylePlaceholderCursor:
def _param_generator(self, params):
# Try dict handling; if that fails, treat as sequence
- if hasattr(params, 'items'):
+ if hasattr(params, "items"):
return {k: v.force_bytes for k, v in params.items()}
else:
return [p.force_bytes for p in params]
@@ -485,11 +502,11 @@ class FormatStylePlaceholderCursor:
# it does want a trailing ';' but not a trailing '/'. However, these
# characters must be included in the original query in case the query
# is being passed to SQL*Plus.
- if query.endswith(';') or query.endswith('/'):
+ if query.endswith(";") or query.endswith("/"):
query = query[:-1]
if params is None:
params = []
- elif hasattr(params, 'keys'):
+ elif hasattr(params, "keys"):
# Handle params as dict
args = {k: ":%s" % k for k in params}
query = query % args
@@ -502,15 +519,14 @@ class FormatStylePlaceholderCursor:
# args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
# params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
params_dict = {
- param: ':arg%d' % i
- for i, param in enumerate(dict.fromkeys(params))
+ param: ":arg%d" % i for i, param in enumerate(dict.fromkeys(params))
}
args = [params_dict[param] for param in params]
params = {value: key for key, value in params_dict.items()}
query = query % tuple(args)
else:
# Handle params as sequence
- args = [(':arg%d' % i) for i in range(len(params))]
+ args = [(":arg%d" % i) for i in range(len(params))]
query = query % tuple(args)
return query, self._format_params(params)
@@ -532,7 +548,9 @@ class FormatStylePlaceholderCursor:
formatted = [firstparams] + [self._format_params(p) for p in params_iter]
self._guess_input_sizes(formatted)
with wrap_oracle_errors():
- return self.cursor.executemany(query, [self._param_generator(p) for p in formatted])
+ return self.cursor.executemany(
+ query, [self._param_generator(p) for p in formatted]
+ )
def close(self):
try:
diff --git a/django/db/backends/oracle/client.py b/django/db/backends/oracle/client.py
index 9920f4ca67..365b116046 100644
--- a/django/db/backends/oracle/client.py
+++ b/django/db/backends/oracle/client.py
@@ -4,22 +4,22 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
- executable_name = 'sqlplus'
- wrapper_name = 'rlwrap'
+ executable_name = "sqlplus"
+ wrapper_name = "rlwrap"
@staticmethod
def connect_string(settings_dict):
from django.db.backends.oracle.utils import dsn
return '%s/"%s"@%s' % (
- settings_dict['USER'],
- settings_dict['PASSWORD'],
+ settings_dict["USER"],
+ settings_dict["PASSWORD"],
dsn(settings_dict),
)
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
- args = [cls.executable_name, '-L', cls.connect_string(settings_dict)]
+ args = [cls.executable_name, "-L", cls.connect_string(settings_dict)]
wrapper_path = shutil.which(cls.wrapper_name)
if wrapper_path:
args = [wrapper_path, *args]
diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py
index 3ca3754e15..bdde162aa8 100644
--- a/django/db/backends/oracle/creation.py
+++ b/django/db/backends/oracle/creation.py
@@ -6,11 +6,10 @@ from django.db.backends.base.creation import BaseDatabaseCreation
from django.utils.crypto import get_random_string
from django.utils.functional import cached_property
-TEST_DATABASE_PREFIX = 'test_'
+TEST_DATABASE_PREFIX = "test_"
class DatabaseCreation(BaseDatabaseCreation):
-
@cached_property
def _maindb_connection(self):
"""
@@ -21,9 +20,9 @@ class DatabaseCreation(BaseDatabaseCreation):
is the main (non-test) connection.
"""
settings_dict = settings.DATABASES[self.connection.alias]
- user = settings_dict.get('SAVED_USER') or settings_dict['USER']
- password = settings_dict.get('SAVED_PASSWORD') or settings_dict['PASSWORD']
- settings_dict = {**settings_dict, 'USER': user, 'PASSWORD': password}
+ user = settings_dict.get("SAVED_USER") or settings_dict["USER"]
+ password = settings_dict.get("SAVED_PASSWORD") or settings_dict["PASSWORD"]
+ settings_dict = {**settings_dict, "USER": user, "PASSWORD": password}
DatabaseWrapper = type(self.connection)
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
@@ -32,72 +31,95 @@ class DatabaseCreation(BaseDatabaseCreation):
with self._maindb_connection.cursor() as cursor:
if self._test_database_create():
try:
- self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
+ self._execute_test_db_creation(
+ cursor, parameters, verbosity, keepdb
+ )
except Exception as e:
- if 'ORA-01543' not in str(e):
+ if "ORA-01543" not in str(e):
# All errors except "tablespace already exists" cancel tests
- self.log('Got an error creating the test database: %s' % e)
+ self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test database, %s, already exists. "
- "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
- if autoclobber or confirm == 'yes':
+ "Type 'yes' to delete it, or 'no' to cancel: "
+ % parameters["user"]
+ )
+ if autoclobber or confirm == "yes":
if verbosity >= 1:
- self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
+ self.log(
+ "Destroying old test database for alias '%s'..."
+ % self.connection.alias
+ )
try:
- self._execute_test_db_destruction(cursor, parameters, verbosity)
+ self._execute_test_db_destruction(
+ cursor, parameters, verbosity
+ )
except DatabaseError as e:
- if 'ORA-29857' in str(e):
- self._handle_objects_preventing_db_destruction(cursor, parameters,
- verbosity, autoclobber)
+ if "ORA-29857" in str(e):
+ self._handle_objects_preventing_db_destruction(
+ cursor, parameters, verbosity, autoclobber
+ )
else:
# Ran into a database error that isn't about leftover objects in the tablespace
- self.log('Got an error destroying the old test database: %s' % e)
+ self.log(
+ "Got an error destroying the old test database: %s"
+ % e
+ )
sys.exit(2)
except Exception as e:
- self.log('Got an error destroying the old test database: %s' % e)
+ self.log(
+ "Got an error destroying the old test database: %s" % e
+ )
sys.exit(2)
try:
- self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
+ self._execute_test_db_creation(
+ cursor, parameters, verbosity, keepdb
+ )
except Exception as e:
- self.log('Got an error recreating the test database: %s' % e)
+ self.log(
+ "Got an error recreating the test database: %s" % e
+ )
sys.exit(2)
else:
- self.log('Tests cancelled.')
+ self.log("Tests cancelled.")
sys.exit(1)
if self._test_user_create():
if verbosity >= 1:
- self.log('Creating test user...')
+ self.log("Creating test user...")
try:
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
- if 'ORA-01920' not in str(e):
+ if "ORA-01920" not in str(e):
# All errors except "user already exists" cancel tests
- self.log('Got an error creating the test user: %s' % e)
+ self.log("Got an error creating the test user: %s" % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test user, %s, already exists. Type "
- "'yes' to delete it, or 'no' to cancel: " % parameters['user'])
- if autoclobber or confirm == 'yes':
+ "'yes' to delete it, or 'no' to cancel: "
+ % parameters["user"]
+ )
+ if autoclobber or confirm == "yes":
try:
if verbosity >= 1:
- self.log('Destroying old test user...')
+ self.log("Destroying old test user...")
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
- self.log('Creating test user...')
- self._create_test_user(cursor, parameters, verbosity, keepdb)
+ self.log("Creating test user...")
+ self._create_test_user(
+ cursor, parameters, verbosity, keepdb
+ )
except Exception as e:
- self.log('Got an error recreating the test user: %s' % e)
+ self.log("Got an error recreating the test user: %s" % e)
sys.exit(2)
else:
- self.log('Tests cancelled.')
+ self.log("Tests cancelled.")
sys.exit(1)
self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters)
- return self.connection.settings_dict['NAME']
+ return self.connection.settings_dict["NAME"]
def _switch_to_test_user(self, parameters):
"""
@@ -109,59 +131,71 @@ class DatabaseCreation(BaseDatabaseCreation):
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
"""
real_settings = settings.DATABASES[self.connection.alias]
- real_settings['SAVED_USER'] = self.connection.settings_dict['SAVED_USER'] = \
- self.connection.settings_dict['USER']
- real_settings['SAVED_PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] = \
- self.connection.settings_dict['PASSWORD']
- real_test_settings = real_settings['TEST']
- test_settings = self.connection.settings_dict['TEST']
- real_test_settings['USER'] = real_settings['USER'] = test_settings['USER'] = \
- self.connection.settings_dict['USER'] = parameters['user']
- real_settings['PASSWORD'] = self.connection.settings_dict['PASSWORD'] = parameters['password']
+ real_settings["SAVED_USER"] = self.connection.settings_dict[
+ "SAVED_USER"
+ ] = self.connection.settings_dict["USER"]
+ real_settings["SAVED_PASSWORD"] = self.connection.settings_dict[
+ "SAVED_PASSWORD"
+ ] = self.connection.settings_dict["PASSWORD"]
+ real_test_settings = real_settings["TEST"]
+ test_settings = self.connection.settings_dict["TEST"]
+ real_test_settings["USER"] = real_settings["USER"] = test_settings[
+ "USER"
+ ] = self.connection.settings_dict["USER"] = parameters["user"]
+ real_settings["PASSWORD"] = self.connection.settings_dict[
+ "PASSWORD"
+ ] = parameters["password"]
def set_as_test_mirror(self, primary_settings_dict):
"""
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
- self.connection.settings_dict['USER'] = primary_settings_dict['USER']
- self.connection.settings_dict['PASSWORD'] = primary_settings_dict['PASSWORD']
+ self.connection.settings_dict["USER"] = primary_settings_dict["USER"]
+ self.connection.settings_dict["PASSWORD"] = primary_settings_dict["PASSWORD"]
- def _handle_objects_preventing_db_destruction(self, cursor, parameters, verbosity, autoclobber):
+ def _handle_objects_preventing_db_destruction(
+ self, cursor, parameters, verbosity, autoclobber
+ ):
# There are objects in the test tablespace which prevent dropping it
# The easy fix is to drop the test user -- but are we allowed to do so?
self.log(
- 'There are objects in the old test database which prevent its destruction.\n'
- 'If they belong to the test user, deleting the user will allow the test '
- 'database to be recreated.\n'
- 'Otherwise, you will need to find and remove each of these objects, '
- 'or use a different tablespace.\n'
+ "There are objects in the old test database which prevent its destruction.\n"
+ "If they belong to the test user, deleting the user will allow the test "
+ "database to be recreated.\n"
+ "Otherwise, you will need to find and remove each of these objects, "
+ "or use a different tablespace.\n"
)
if self._test_user_create():
if not autoclobber:
- confirm = input("Type 'yes' to delete user %s: " % parameters['user'])
- if autoclobber or confirm == 'yes':
+ confirm = input("Type 'yes' to delete user %s: " % parameters["user"])
+ if autoclobber or confirm == "yes":
try:
if verbosity >= 1:
- self.log('Destroying old test user...')
+ self.log("Destroying old test user...")
self._destroy_test_user(cursor, parameters, verbosity)
except Exception as e:
- self.log('Got an error destroying the test user: %s' % e)
+ self.log("Got an error destroying the test user: %s" % e)
sys.exit(2)
try:
if verbosity >= 1:
- self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
+ self.log(
+ "Destroying old test database for alias '%s'..."
+ % self.connection.alias
+ )
self._execute_test_db_destruction(cursor, parameters, verbosity)
except Exception as e:
- self.log('Got an error destroying the test database: %s' % e)
+ self.log("Got an error destroying the test database: %s" % e)
sys.exit(2)
else:
- self.log('Tests cancelled -- test database cannot be recreated.')
+ self.log("Tests cancelled -- test database cannot be recreated.")
sys.exit(1)
else:
- self.log("Django is configured to use pre-existing test user '%s',"
- " and will not attempt to delete it." % parameters['user'])
- self.log('Tests cancelled -- test database cannot be recreated.')
+ self.log(
+ "Django is configured to use pre-existing test user '%s',"
+ " and will not attempt to delete it." % parameters["user"]
+ )
+ self.log("Tests cancelled -- test database cannot be recreated.")
sys.exit(1)
def _destroy_test_db(self, test_database_name, verbosity=1):
@@ -169,24 +203,28 @@ class DatabaseCreation(BaseDatabaseCreation):
Destroy a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
"""
- self.connection.settings_dict['USER'] = self.connection.settings_dict['SAVED_USER']
- self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
+ self.connection.settings_dict["USER"] = self.connection.settings_dict[
+ "SAVED_USER"
+ ]
+ self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[
+ "SAVED_PASSWORD"
+ ]
self.connection.close()
parameters = self._get_test_db_params()
with self._maindb_connection.cursor() as cursor:
if self._test_user_create():
if verbosity >= 1:
- self.log('Destroying test user...')
+ self.log("Destroying test user...")
self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create():
if verbosity >= 1:
- self.log('Destroying test database tables...')
+ self.log("Destroying test database tables...")
self._execute_test_db_destruction(cursor, parameters, verbosity)
self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
- self.log('_create_test_db(): dbname = %s' % parameters['user'])
+ self.log("_create_test_db(): dbname = %s" % parameters["user"])
if self._test_database_oracle_managed_files():
statements = [
"""
@@ -214,12 +252,14 @@ class DatabaseCreation(BaseDatabaseCreation):
""",
]
# Ignore "tablespace already exists" error when keepdb is on.
- acceptable_ora_err = 'ORA-01543' if keepdb else None
- self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
+ acceptable_ora_err = "ORA-01543" if keepdb else None
+ self._execute_allow_fail_statements(
+ cursor, statements, parameters, verbosity, acceptable_ora_err
+ )
def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
- self.log('_create_test_user(): username = %s' % parameters['user'])
+ self.log("_create_test_user(): username = %s" % parameters["user"])
statements = [
"""CREATE USER %(user)s
IDENTIFIED BY "%(password)s"
@@ -235,40 +275,49 @@ class DatabaseCreation(BaseDatabaseCreation):
TO %(user)s""",
]
# Ignore "user already exists" error when keepdb is on
- acceptable_ora_err = 'ORA-01920' if keepdb else None
- success = self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
+ acceptable_ora_err = "ORA-01920" if keepdb else None
+ success = self._execute_allow_fail_statements(
+ cursor, statements, parameters, verbosity, acceptable_ora_err
+ )
# If the password was randomly generated, change the user accordingly.
- if not success and self._test_settings_get('PASSWORD') is None:
+ if not success and self._test_settings_get("PASSWORD") is None:
set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
self._execute_statements(cursor, [set_password], parameters, verbosity)
# Most test suites can be run without "create view" and
# "create materialized view" privileges. But some need it.
- for object_type in ('VIEW', 'MATERIALIZED VIEW'):
- extra = 'GRANT CREATE %(object_type)s TO %(user)s'
- parameters['object_type'] = object_type
- success = self._execute_allow_fail_statements(cursor, [extra], parameters, verbosity, 'ORA-01031')
+ for object_type in ("VIEW", "MATERIALIZED VIEW"):
+ extra = "GRANT CREATE %(object_type)s TO %(user)s"
+ parameters["object_type"] = object_type
+ success = self._execute_allow_fail_statements(
+ cursor, [extra], parameters, verbosity, "ORA-01031"
+ )
if not success and verbosity >= 2:
- self.log('Failed to grant CREATE %s permission to test user. This may be ok.' % object_type)
+ self.log(
+ "Failed to grant CREATE %s permission to test user. This may be ok."
+ % object_type
+ )
def _execute_test_db_destruction(self, cursor, parameters, verbosity):
if verbosity >= 2:
- self.log('_execute_test_db_destruction(): dbname=%s' % parameters['user'])
+ self.log("_execute_test_db_destruction(): dbname=%s" % parameters["user"])
statements = [
- 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
- 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
+ "DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
+ "DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(self, cursor, parameters, verbosity):
if verbosity >= 2:
- self.log('_destroy_test_user(): user=%s' % parameters['user'])
- self.log('Be patient. This can take some time...')
+ self.log("_destroy_test_user(): user=%s" % parameters["user"])
+ self.log("Be patient. This can take some time...")
statements = [
- 'DROP USER %(user)s CASCADE',
+ "DROP USER %(user)s CASCADE",
]
self._execute_statements(cursor, statements, parameters, verbosity)
- def _execute_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False):
+ def _execute_statements(
+ self, cursor, statements, parameters, verbosity, allow_quiet_fail=False
+ ):
for template in statements:
stmt = template % parameters
if verbosity >= 2:
@@ -277,10 +326,12 @@ class DatabaseCreation(BaseDatabaseCreation):
cursor.execute(stmt)
except Exception as err:
if (not allow_quiet_fail) or verbosity >= 2:
- self.log('Failed (%s)' % (err))
+ self.log("Failed (%s)" % (err))
raise
- def _execute_allow_fail_statements(self, cursor, statements, parameters, verbosity, acceptable_ora_err):
+ def _execute_allow_fail_statements(
+ self, cursor, statements, parameters, verbosity, acceptable_ora_err
+ ):
"""
Execute statements which are allowed to fail silently if the Oracle
error code given by `acceptable_ora_err` is raised. Return True if the
@@ -288,8 +339,16 @@ class DatabaseCreation(BaseDatabaseCreation):
"""
try:
# Statement can fail when acceptable_ora_err is not None
- allow_quiet_fail = acceptable_ora_err is not None and len(acceptable_ora_err) > 0
- self._execute_statements(cursor, statements, parameters, verbosity, allow_quiet_fail=allow_quiet_fail)
+ allow_quiet_fail = (
+ acceptable_ora_err is not None and len(acceptable_ora_err) > 0
+ )
+ self._execute_statements(
+ cursor,
+ statements,
+ parameters,
+ verbosity,
+ allow_quiet_fail=allow_quiet_fail,
+ )
return True
except DatabaseError as err:
description = str(err)
@@ -299,19 +358,19 @@ class DatabaseCreation(BaseDatabaseCreation):
def _get_test_db_params(self):
return {
- 'dbname': self._test_database_name(),
- 'user': self._test_database_user(),
- 'password': self._test_database_passwd(),
- 'tblspace': self._test_database_tblspace(),
- 'tblspace_temp': self._test_database_tblspace_tmp(),
- 'datafile': self._test_database_tblspace_datafile(),
- 'datafile_tmp': self._test_database_tblspace_tmp_datafile(),
- 'maxsize': self._test_database_tblspace_maxsize(),
- 'maxsize_tmp': self._test_database_tblspace_tmp_maxsize(),
- 'size': self._test_database_tblspace_size(),
- 'size_tmp': self._test_database_tblspace_tmp_size(),
- 'extsize': self._test_database_tblspace_extsize(),
- 'extsize_tmp': self._test_database_tblspace_tmp_extsize(),
+ "dbname": self._test_database_name(),
+ "user": self._test_database_user(),
+ "password": self._test_database_passwd(),
+ "tblspace": self._test_database_tblspace(),
+ "tblspace_temp": self._test_database_tblspace_tmp(),
+ "datafile": self._test_database_tblspace_datafile(),
+ "datafile_tmp": self._test_database_tblspace_tmp_datafile(),
+ "maxsize": self._test_database_tblspace_maxsize(),
+ "maxsize_tmp": self._test_database_tblspace_tmp_maxsize(),
+ "size": self._test_database_tblspace_size(),
+ "size_tmp": self._test_database_tblspace_tmp_size(),
+ "extsize": self._test_database_tblspace_extsize(),
+ "extsize_tmp": self._test_database_tblspace_tmp_extsize(),
}
def _test_settings_get(self, key, default=None, prefixed=None):
@@ -320,66 +379,67 @@ class DatabaseCreation(BaseDatabaseCreation):
prefixed entry from the main settings dict.
"""
settings_dict = self.connection.settings_dict
- val = settings_dict['TEST'].get(key, default)
+ val = settings_dict["TEST"].get(key, default)
if val is None and prefixed:
val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
return val
def _test_database_name(self):
- return self._test_settings_get('NAME', prefixed='NAME')
+ return self._test_settings_get("NAME", prefixed="NAME")
def _test_database_create(self):
- return self._test_settings_get('CREATE_DB', default=True)
+ return self._test_settings_get("CREATE_DB", default=True)
def _test_user_create(self):
- return self._test_settings_get('CREATE_USER', default=True)
+ return self._test_settings_get("CREATE_USER", default=True)
def _test_database_user(self):
- return self._test_settings_get('USER', prefixed='USER')
+ return self._test_settings_get("USER", prefixed="USER")
def _test_database_passwd(self):
- password = self._test_settings_get('PASSWORD')
+ password = self._test_settings_get("PASSWORD")
if password is None and self._test_user_create():
# Oracle passwords are limited to 30 chars and can't contain symbols.
password = get_random_string(30)
return password
def _test_database_tblspace(self):
- return self._test_settings_get('TBLSPACE', prefixed='USER')
+ return self._test_settings_get("TBLSPACE", prefixed="USER")
def _test_database_tblspace_tmp(self):
settings_dict = self.connection.settings_dict
- return settings_dict['TEST'].get('TBLSPACE_TMP',
- TEST_DATABASE_PREFIX + settings_dict['USER'] + '_temp')
+ return settings_dict["TEST"].get(
+ "TBLSPACE_TMP", TEST_DATABASE_PREFIX + settings_dict["USER"] + "_temp"
+ )
def _test_database_tblspace_datafile(self):
- tblspace = '%s.dbf' % self._test_database_tblspace()
- return self._test_settings_get('DATAFILE', default=tblspace)
+ tblspace = "%s.dbf" % self._test_database_tblspace()
+ return self._test_settings_get("DATAFILE", default=tblspace)
def _test_database_tblspace_tmp_datafile(self):
- tblspace = '%s.dbf' % self._test_database_tblspace_tmp()
- return self._test_settings_get('DATAFILE_TMP', default=tblspace)
+ tblspace = "%s.dbf" % self._test_database_tblspace_tmp()
+ return self._test_settings_get("DATAFILE_TMP", default=tblspace)
def _test_database_tblspace_maxsize(self):
- return self._test_settings_get('DATAFILE_MAXSIZE', default='500M')
+ return self._test_settings_get("DATAFILE_MAXSIZE", default="500M")
def _test_database_tblspace_tmp_maxsize(self):
- return self._test_settings_get('DATAFILE_TMP_MAXSIZE', default='500M')
+ return self._test_settings_get("DATAFILE_TMP_MAXSIZE", default="500M")
def _test_database_tblspace_size(self):
- return self._test_settings_get('DATAFILE_SIZE', default='50M')
+ return self._test_settings_get("DATAFILE_SIZE", default="50M")
def _test_database_tblspace_tmp_size(self):
- return self._test_settings_get('DATAFILE_TMP_SIZE', default='50M')
+ return self._test_settings_get("DATAFILE_TMP_SIZE", default="50M")
def _test_database_tblspace_extsize(self):
- return self._test_settings_get('DATAFILE_EXTSIZE', default='25M')
+ return self._test_settings_get("DATAFILE_EXTSIZE", default="25M")
def _test_database_tblspace_tmp_extsize(self):
- return self._test_settings_get('DATAFILE_TMP_EXTSIZE', default='25M')
+ return self._test_settings_get("DATAFILE_TMP_EXTSIZE", default="25M")
def _test_database_oracle_managed_files(self):
- return self._test_settings_get('ORACLE_MANAGED_FILES', default=False)
+ return self._test_settings_get("ORACLE_MANAGED_FILES", default=False)
def _get_test_db_name(self):
"""
@@ -387,14 +447,14 @@ class DatabaseCreation(BaseDatabaseCreation):
to work. This isn't a great deal in this case because DB names as
handled by Django don't have real counterparts in Oracle.
"""
- return self.connection.settings_dict['NAME']
+ return self.connection.settings_dict["NAME"]
def test_db_signature(self):
settings_dict = self.connection.settings_dict
return (
- settings_dict['HOST'],
- settings_dict['PORT'],
- settings_dict['ENGINE'],
- settings_dict['NAME'],
+ settings_dict["HOST"],
+ settings_dict["PORT"],
+ settings_dict["ENGINE"],
+ settings_dict["NAME"],
self._test_database_user(),
)
diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py
index 898a82e5d5..6a3c9dab79 100644
--- a/django/db/backends/oracle/features.py
+++ b/django/db/backends/oracle/features.py
@@ -65,51 +65,51 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_json_field_contains = False
supports_collation_on_textfield = False
test_collations = {
- 'ci': 'BINARY_CI',
- 'cs': 'BINARY',
- 'non_default': 'SWEDISH_CI',
- 'swedish_ci': 'SWEDISH_CI',
+ "ci": "BINARY_CI",
+ "cs": "BINARY",
+ "non_default": "SWEDISH_CI",
+ "swedish_ci": "SWEDISH_CI",
}
test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"
django_test_skips = {
"Oracle doesn't support SHA224.": {
- 'db_functions.text.test_sha224.SHA224Tests.test_basic',
- 'db_functions.text.test_sha224.SHA224Tests.test_transform',
+ "db_functions.text.test_sha224.SHA224Tests.test_basic",
+ "db_functions.text.test_sha224.SHA224Tests.test_transform",
},
"Oracle doesn't correctly calculate ISO 8601 week numbering before "
"1583 (the Gregorian calendar was introduced in 1582).": {
- 'db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_week_before_1000',
- 'db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_week_before_1000',
+ "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_week_before_1000",
+ "db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_week_before_1000",
},
"Oracle doesn't support bitwise XOR.": {
- 'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor',
- 'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null',
- 'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null',
+ "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor",
+ "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null",
+ "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
},
"Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": {
- 'expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering',
+ "expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering",
},
- 'Raises ORA-00600: internal error code.': {
- 'model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery',
+ "Raises ORA-00600: internal error code.": {
+ "model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery",
},
}
django_test_expected_failures = {
# A bug in Django/cx_Oracle with respect to string handling (#23843).
- 'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions',
- 'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions',
+ "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",
+ "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions",
}
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
- 'GenericIPAddressField': 'CharField',
- 'PositiveBigIntegerField': 'BigIntegerField',
- 'PositiveIntegerField': 'IntegerField',
- 'PositiveSmallIntegerField': 'IntegerField',
- 'SmallIntegerField': 'IntegerField',
- 'TimeField': 'DateTimeField',
+ "GenericIPAddressField": "CharField",
+ "PositiveBigIntegerField": "BigIntegerField",
+ "PositiveIntegerField": "IntegerField",
+ "PositiveSmallIntegerField": "IntegerField",
+ "SmallIntegerField": "IntegerField",
+ "TimeField": "DateTimeField",
}
@cached_property
diff --git a/django/db/backends/oracle/functions.py b/django/db/backends/oracle/functions.py
index 1aeb4597e3..936cc9e73f 100644
--- a/django/db/backends/oracle/functions.py
+++ b/django/db/backends/oracle/functions.py
@@ -2,7 +2,7 @@ from django.db.models import DecimalField, DurationField, Func
class IntervalToSeconds(Func):
- function = ''
+ function = ""
template = """
EXTRACT(day from %(expressions)s) * 86400 +
EXTRACT(hour from %(expressions)s) * 3600 +
@@ -11,12 +11,16 @@ class IntervalToSeconds(Func):
"""
def __init__(self, expression, *, output_field=None, **extra):
- super().__init__(expression, output_field=output_field or DecimalField(), **extra)
+ super().__init__(
+ expression, output_field=output_field or DecimalField(), **extra
+ )
class SecondsToInterval(Func):
- function = 'NUMTODSINTERVAL'
+ function = "NUMTODSINTERVAL"
template = "%(function)s(%(expressions)s, 'SECOND')"
def __init__(self, expression, *, output_field=None, **extra):
- super().__init__(expression, output_field=output_field or DurationField(), **extra)
+ super().__init__(
+ expression, output_field=output_field or DurationField(), **extra
+ )
diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py
index b8882e3cd8..17ffd3a99d 100644
--- a/django/db/backends/oracle/introspection.py
+++ b/django/db/backends/oracle/introspection.py
@@ -3,12 +3,12 @@ from collections import namedtuple
import cx_Oracle
from django.db import models
-from django.db.backends.base.introspection import (
- BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
-)
+from django.db.backends.base.introspection import BaseDatabaseIntrospection
+from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
+from django.db.backends.base.introspection import TableInfo
from django.utils.functional import cached_property
-FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json'))
+FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "is_json"))
class DatabaseIntrospection(BaseDatabaseIntrospection):
@@ -19,33 +19,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def data_types_reverse(self):
if self.connection.cx_oracle_version < (8,):
return {
- cx_Oracle.BLOB: 'BinaryField',
- cx_Oracle.CLOB: 'TextField',
- cx_Oracle.DATETIME: 'DateField',
- cx_Oracle.FIXED_CHAR: 'CharField',
- cx_Oracle.FIXED_NCHAR: 'CharField',
- cx_Oracle.INTERVAL: 'DurationField',
- cx_Oracle.NATIVE_FLOAT: 'FloatField',
- cx_Oracle.NCHAR: 'CharField',
- cx_Oracle.NCLOB: 'TextField',
- cx_Oracle.NUMBER: 'DecimalField',
- cx_Oracle.STRING: 'CharField',
- cx_Oracle.TIMESTAMP: 'DateTimeField',
+ cx_Oracle.BLOB: "BinaryField",
+ cx_Oracle.CLOB: "TextField",
+ cx_Oracle.DATETIME: "DateField",
+ cx_Oracle.FIXED_CHAR: "CharField",
+ cx_Oracle.FIXED_NCHAR: "CharField",
+ cx_Oracle.INTERVAL: "DurationField",
+ cx_Oracle.NATIVE_FLOAT: "FloatField",
+ cx_Oracle.NCHAR: "CharField",
+ cx_Oracle.NCLOB: "TextField",
+ cx_Oracle.NUMBER: "DecimalField",
+ cx_Oracle.STRING: "CharField",
+ cx_Oracle.TIMESTAMP: "DateTimeField",
}
else:
return {
- cx_Oracle.DB_TYPE_DATE: 'DateField',
- cx_Oracle.DB_TYPE_BINARY_DOUBLE: 'FloatField',
- cx_Oracle.DB_TYPE_BLOB: 'BinaryField',
- cx_Oracle.DB_TYPE_CHAR: 'CharField',
- cx_Oracle.DB_TYPE_CLOB: 'TextField',
- cx_Oracle.DB_TYPE_INTERVAL_DS: 'DurationField',
- cx_Oracle.DB_TYPE_NCHAR: 'CharField',
- cx_Oracle.DB_TYPE_NCLOB: 'TextField',
- cx_Oracle.DB_TYPE_NVARCHAR: 'CharField',
- cx_Oracle.DB_TYPE_NUMBER: 'DecimalField',
- cx_Oracle.DB_TYPE_TIMESTAMP: 'DateTimeField',
- cx_Oracle.DB_TYPE_VARCHAR: 'CharField',
+ cx_Oracle.DB_TYPE_DATE: "DateField",
+ cx_Oracle.DB_TYPE_BINARY_DOUBLE: "FloatField",
+ cx_Oracle.DB_TYPE_BLOB: "BinaryField",
+ cx_Oracle.DB_TYPE_CHAR: "CharField",
+ cx_Oracle.DB_TYPE_CLOB: "TextField",
+ cx_Oracle.DB_TYPE_INTERVAL_DS: "DurationField",
+ cx_Oracle.DB_TYPE_NCHAR: "CharField",
+ cx_Oracle.DB_TYPE_NCLOB: "TextField",
+ cx_Oracle.DB_TYPE_NVARCHAR: "CharField",
+ cx_Oracle.DB_TYPE_NUMBER: "DecimalField",
+ cx_Oracle.DB_TYPE_TIMESTAMP: "DateTimeField",
+ cx_Oracle.DB_TYPE_VARCHAR: "CharField",
}
def get_field_type(self, data_type, description):
@@ -53,25 +53,30 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
precision, scale = description[4:6]
if scale == 0:
if precision > 11:
- return 'BigAutoField' if description.is_autofield else 'BigIntegerField'
+ return (
+ "BigAutoField"
+ if description.is_autofield
+ else "BigIntegerField"
+ )
elif 1 < precision < 6 and description.is_autofield:
- return 'SmallAutoField'
+ return "SmallAutoField"
elif precision == 1:
- return 'BooleanField'
+ return "BooleanField"
elif description.is_autofield:
- return 'AutoField'
+ return "AutoField"
else:
- return 'IntegerField'
+ return "IntegerField"
elif scale == -127:
- return 'FloatField'
+ return "FloatField"
elif data_type == cx_Oracle.NCLOB and description.is_json:
- return 'JSONField'
+ return "JSONField"
return super().get_field_type(data_type, description)
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
- cursor.execute("""
+ cursor.execute(
+ """
SELECT table_name, 't'
FROM user_tables
WHERE
@@ -84,8 +89,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
SELECT view_name, 'v' FROM user_views
UNION ALL
SELECT mview_name, 'v' FROM user_mviews
- """)
- return [TableInfo(self.identifier_converter(row[0]), row[1]) for row in cursor.fetchall()]
+ """
+ )
+ return [
+ TableInfo(self.identifier_converter(row[0]), row[1])
+ for row in cursor.fetchall()
+ ]
def get_table_description(self, cursor, table_name):
"""
@@ -131,22 +140,40 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
[table_name],
)
field_map = {
- column: (internal_size, default if default != 'NULL' else None, collation, is_autofield, is_json)
+ column: (
+ internal_size,
+ default if default != "NULL" else None,
+ collation,
+ is_autofield,
+ is_json,
+ )
for column, default, collation, internal_size, is_autofield, is_json in cursor.fetchall()
}
self.cache_bust_counter += 1
- cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
- self.connection.ops.quote_name(table_name),
- self.cache_bust_counter))
+ cursor.execute(
+ "SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
+ self.connection.ops.quote_name(table_name), self.cache_bust_counter
+ )
+ )
description = []
for desc in cursor.description:
name = desc[0]
internal_size, default, collation, is_autofield, is_json = field_map[name]
name = name % {} # cx_Oracle, for some reason, doubles percent signs.
- description.append(FieldInfo(
- self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0,
- desc[5] or 0, *desc[6:], default, collation, is_autofield, is_json,
- ))
+ description.append(
+ FieldInfo(
+ self.identifier_converter(name),
+ *desc[1:3],
+ internal_size,
+ desc[4] or 0,
+ desc[5] or 0,
+ *desc[6:],
+ default,
+ collation,
+ is_autofield,
+ is_json,
+ )
+ )
return description
def identifier_converter(self, name):
@@ -175,16 +202,18 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Oracle allows only one identity column per table.
row = cursor.fetchone()
if row:
- return [{
- 'name': self.identifier_converter(row[0]),
- 'table': self.identifier_converter(table_name),
- 'column': self.identifier_converter(row[1]),
- }]
+ return [
+ {
+ "name": self.identifier_converter(row[0]),
+ "table": self.identifier_converter(table_name),
+ "column": self.identifier_converter(row[1]),
+ }
+ ]
# To keep backward compatibility for AutoFields that aren't Oracle
# identity columns.
for f in table_fields:
if isinstance(f, models.AutoField):
- return [{'table': table_name, 'column': f.column}]
+ return [{"table": table_name, "column": f.column}]
return []
def get_relations(self, cursor, table_name):
@@ -193,19 +222,23 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
representing all foreign keys in the given table.
"""
table_name = table_name.upper()
- cursor.execute("""
+ cursor.execute(
+ """
SELECT ca.column_name, cb.table_name, cb.column_name
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
WHERE user_constraints.table_name = %s AND
user_constraints.constraint_name = ca.constraint_name AND
user_constraints.r_constraint_name = cb.constraint_name AND
- ca.position = cb.position""", [table_name])
+ ca.position = cb.position""",
+ [table_name],
+ )
return {
self.identifier_converter(field_name): (
self.identifier_converter(rel_field_name),
self.identifier_converter(rel_table_name),
- ) for field_name, rel_table_name, rel_field_name in cursor.fetchall()
+ )
+ for field_name, rel_table_name, rel_field_name in cursor.fetchall()
}
def get_primary_key_column(self, cursor, table_name):
@@ -265,12 +298,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for constraint, columns, pk, unique, check in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
- 'columns': columns.split(','),
- 'primary_key': pk,
- 'unique': unique,
- 'foreign_key': None,
- 'check': check,
- 'index': unique, # All uniques come with an index
+ "columns": columns.split(","),
+ "primary_key": pk,
+ "unique": unique,
+ "foreign_key": None,
+ "check": check,
+ "index": unique, # All uniques come with an index
}
# Foreign key constraints
cursor.execute(
@@ -296,12 +329,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for constraint, columns, other_table, other_column in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
- 'primary_key': False,
- 'unique': False,
- 'foreign_key': (other_table, other_column),
- 'check': False,
- 'index': False,
- 'columns': columns.split(','),
+ "primary_key": False,
+ "unique": False,
+ "foreign_key": (other_table, other_column),
+ "check": False,
+ "index": False,
+ "columns": columns.split(","),
}
# Now get indexes
cursor.execute(
@@ -328,13 +361,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for constraint, type_, unique, columns, orders in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
- 'primary_key': False,
- 'unique': unique == 'unique',
- 'foreign_key': None,
- 'check': False,
- 'index': True,
- 'type': 'idx' if type_ == 'normal' else type_,
- 'columns': columns.split(','),
- 'orders': orders.split(','),
+ "primary_key": False,
+ "unique": unique == "unique",
+ "foreign_key": None,
+ "check": False,
+ "index": True,
+ "type": "idx" if type_ == "normal" else type_,
+ "columns": columns.split(","),
+ "orders": orders.split(","),
}
return constraints
diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py
index d53942b919..63f9714333 100644
--- a/django/db/backends/oracle/operations.py
+++ b/django/db/backends/oracle/operations.py
@@ -5,9 +5,7 @@ from functools import lru_cache
from django.conf import settings
from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
-from django.db.backends.utils import (
- split_tzname_delta, strip_quotes, truncate_name,
-)
+from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode
@@ -25,17 +23,17 @@ class DatabaseOperations(BaseDatabaseOperations):
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
# SmallAutoField, to preserve backward compatibility.
integer_field_ranges = {
- 'SmallIntegerField': (-99999999999, 99999999999),
- 'IntegerField': (-99999999999, 99999999999),
- 'BigIntegerField': (-9999999999999999999, 9999999999999999999),
- 'PositiveBigIntegerField': (0, 9999999999999999999),
- 'PositiveSmallIntegerField': (0, 99999999999),
- 'PositiveIntegerField': (0, 99999999999),
- 'SmallAutoField': (-99999, 99999),
- 'AutoField': (-99999999999, 99999999999),
- 'BigAutoField': (-9999999999999999999, 9999999999999999999),
+ "SmallIntegerField": (-99999999999, 99999999999),
+ "IntegerField": (-99999999999, 99999999999),
+ "BigIntegerField": (-9999999999999999999, 9999999999999999999),
+ "PositiveBigIntegerField": (0, 9999999999999999999),
+ "PositiveSmallIntegerField": (0, 99999999999),
+ "PositiveIntegerField": (0, 99999999999),
+ "SmallAutoField": (-99999, 99999),
+ "AutoField": (-99999999999, 99999999999),
+ "BigAutoField": (-9999999999999999999, 9999999999999999999),
}
- set_operators = {**BaseDatabaseOperations.set_operators, 'difference': 'MINUS'}
+ set_operators = {**BaseDatabaseOperations.set_operators, "difference": "MINUS"}
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
_sequence_reset_sql = """
@@ -63,34 +61,34 @@ END;
/"""
# Oracle doesn't support string without precision; use the max string size.
- cast_char_field_without_max_length = 'NVARCHAR2(2000)'
+ cast_char_field_without_max_length = "NVARCHAR2(2000)"
cast_data_types = {
- 'AutoField': 'NUMBER(11)',
- 'BigAutoField': 'NUMBER(19)',
- 'SmallAutoField': 'NUMBER(5)',
- 'TextField': cast_char_field_without_max_length,
+ "AutoField": "NUMBER(11)",
+ "BigAutoField": "NUMBER(19)",
+ "SmallAutoField": "NUMBER(5)",
+ "TextField": cast_char_field_without_max_length,
}
def cache_key_culling_sql(self):
- cache_key = self.quote_name('cache_key')
+ cache_key = self.quote_name("cache_key")
return (
- f'SELECT {cache_key} '
- f'FROM %s '
- f'ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY'
+ f"SELECT {cache_key} "
+ f"FROM %s "
+ f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
)
def date_extract_sql(self, lookup_type, field_name):
- if lookup_type == 'week_day':
+ if lookup_type == "week_day":
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
return "TO_CHAR(%s, 'D')" % field_name
- elif lookup_type == 'iso_week_day':
+ elif lookup_type == "iso_week_day":
return "TO_CHAR(%s - 1, 'D')" % field_name
- elif lookup_type == 'week':
+ elif lookup_type == "week":
# IW = ISO week number
return "TO_CHAR(%s, 'IW')" % field_name
- elif lookup_type == 'quarter':
+ elif lookup_type == "quarter":
return "TO_CHAR(%s, 'Q')" % field_name
- elif lookup_type == 'iso_year':
+ elif lookup_type == "iso_year":
return "TO_CHAR(%s, 'IYYY')" % field_name
else:
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
@@ -99,11 +97,11 @@ END;
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
- if lookup_type in ('year', 'month'):
+ if lookup_type in ("year", "month"):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
- elif lookup_type == 'quarter':
+ elif lookup_type == "quarter":
return "TRUNC(%s, 'Q')" % field_name
- elif lookup_type == 'week':
+ elif lookup_type == "week":
return "TRUNC(%s, 'IW')" % field_name
else:
return "TRUNC(%s)" % field_name
@@ -112,11 +110,11 @@ END;
# if the time zone name is passed in parameter. Use interpolation instead.
# https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
# This regexp matches all time zone names from the zoneinfo database.
- _tzname_re = _lazy_re_compile(r'^[\w/:+-]+$')
+ _tzname_re = _lazy_re_compile(r"^[\w/:+-]+$")
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
- return f'{sign}{offset}' if offset else tzname
+ return f"{sign}{offset}" if offset else tzname
def _convert_field_to_tz(self, field_name, tzname):
if not (settings.USE_TZ and tzname):
@@ -136,7 +134,7 @@ END;
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
- return 'TRUNC(%s)' % field_name
+ return "TRUNC(%s)" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
# Since `TimeField` values are stored as TIMESTAMP change to the
@@ -146,7 +144,8 @@ END;
"'YYYY-MM-DD HH24:MI:SS.FF')"
) % self._convert_field_to_tz(field_name, tzname)
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % (
- field_name, convert_datetime_sql,
+ field_name,
+ convert_datetime_sql,
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
@@ -156,20 +155,22 @@ END;
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
- if lookup_type in ('year', 'month'):
+ if lookup_type in ("year", "month"):
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
- elif lookup_type == 'quarter':
+ elif lookup_type == "quarter":
sql = "TRUNC(%s, 'Q')" % field_name
- elif lookup_type == 'week':
+ elif lookup_type == "week":
sql = "TRUNC(%s, 'IW')" % field_name
- elif lookup_type == 'day':
+ elif lookup_type == "day":
sql = "TRUNC(%s)" % field_name
- elif lookup_type == 'hour':
+ elif lookup_type == "hour":
sql = "TRUNC(%s, 'HH24')" % field_name
- elif lookup_type == 'minute':
+ elif lookup_type == "minute":
sql = "TRUNC(%s, 'MI')" % field_name
else:
- sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
+ sql = (
+ "CAST(%s AS DATE)" % field_name
+ ) # Cast to DATE removes sub-second precision.
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
@@ -177,31 +178,33 @@ END;
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname)
- if lookup_type == 'hour':
+ if lookup_type == "hour":
sql = "TRUNC(%s, 'HH24')" % field_name
- elif lookup_type == 'minute':
+ elif lookup_type == "minute":
sql = "TRUNC(%s, 'MI')" % field_name
- elif lookup_type == 'second':
- sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
+ elif lookup_type == "second":
+ sql = (
+ "CAST(%s AS DATE)" % field_name
+ ) # Cast to DATE removes sub-second precision.
return sql
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
- if internal_type in ['JSONField', 'TextField']:
+ if internal_type in ["JSONField", "TextField"]:
converters.append(self.convert_textfield_value)
- elif internal_type == 'BinaryField':
+ elif internal_type == "BinaryField":
converters.append(self.convert_binaryfield_value)
- elif internal_type == 'BooleanField':
+ elif internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
- elif internal_type == 'DateTimeField':
+ elif internal_type == "DateTimeField":
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
- elif internal_type == 'DateField':
+ elif internal_type == "DateField":
converters.append(self.convert_datefield_value)
- elif internal_type == 'TimeField':
+ elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
- elif internal_type == 'UUIDField':
+ elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
# Oracle stores empty strings as null. If the field accepts the empty
# string, undo this to adhere to the Django convention of using
@@ -209,8 +212,8 @@ END;
if expression.output_field.empty_strings_allowed:
converters.append(
self.convert_empty_bytes
- if internal_type == 'BinaryField' else
- self.convert_empty_string
+ if internal_type == "BinaryField"
+ else self.convert_empty_string
)
return converters
@@ -255,11 +258,11 @@ END;
@staticmethod
def convert_empty_string(value, expression, connection):
- return '' if value is None else value
+ return "" if value is None else value
@staticmethod
def convert_empty_bytes(value, expression, connection):
- return b'' if value is None else value
+ return b"" if value is None else value
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
@@ -270,16 +273,16 @@ END;
value = param.get_value()
if value == []:
raise DatabaseError(
- 'The database did not return a new row id. Probably '
+ "The database did not return a new row id. Probably "
'"ORA-1403: no data found" was raised internally but was '
- 'hidden by the Oracle OCI library (see '
- 'https://code.djangoproject.com/ticket/28859).'
+ "hidden by the Oracle OCI library (see "
+ "https://code.djangoproject.com/ticket/28859)."
)
columns.append(value[0])
return tuple(columns)
def field_cast_sql(self, db_type, internal_type):
- if db_type and db_type.endswith('LOB') and internal_type != 'JSONField':
+ if db_type and db_type.endswith("LOB") and internal_type != "JSONField":
return "DBMS_LOB.SUBSTR(%s)"
else:
return "%s"
@@ -289,10 +292,14 @@ END;
def limit_offset_sql(self, low_mark, high_mark):
fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
- return ' '.join(sql for sql in (
- ('OFFSET %d ROWS' % offset) if offset else None,
- ('FETCH FIRST %d ROWS ONLY' % fetch) if fetch else None,
- ) if sql)
+ return " ".join(
+ sql
+ for sql in (
+ ("OFFSET %d ROWS" % offset) if offset else None,
+ ("FETCH FIRST %d ROWS ONLY" % fetch) if fetch else None,
+ )
+ if sql
+ )
def last_executed_query(self, cursor, sql, params):
# https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.statement
@@ -303,10 +310,14 @@ END;
# parameters manually.
if isinstance(params, (tuple, list)):
for i, param in enumerate(params):
- statement = statement.replace(':arg%d' % i, force_str(param, errors='replace'))
+ statement = statement.replace(
+ ":arg%d" % i, force_str(param, errors="replace")
+ )
elif isinstance(params, dict):
for key, param in params.items():
- statement = statement.replace(':%s' % key, force_str(param, errors='replace'))
+ statement = statement.replace(
+ ":%s" % key, force_str(param, errors="replace")
+ )
return statement
def last_insert_id(self, cursor, table_name, pk_name):
@@ -315,10 +326,10 @@ END;
return cursor.fetchone()[0]
def lookup_cast(self, lookup_type, internal_type=None):
- if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
+ if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
return "UPPER(%s)"
- if internal_type == 'JSONField' and lookup_type == 'exact':
- return 'DBMS_LOB.SUBSTR(%s)'
+ if internal_type == "JSONField" and lookup_type == "exact":
+ return "DBMS_LOB.SUBSTR(%s)"
return "%s"
def max_in_list_size(self):
@@ -335,7 +346,7 @@ END;
def process_clob(self, value):
if value is None:
- return ''
+ return ""
return value.read()
def quote_name(self, name):
@@ -348,30 +359,33 @@ END;
# Oracle puts the query text into a (query % args) construct, so % signs
# in names need to be escaped. The '%%' will be collapsed back to '%' at
# that stage so we aren't really making the name longer here.
- name = name.replace('%', '%%')
+ name = name.replace("%", "%%")
return name.upper()
def regex_lookup(self, lookup_type):
- if lookup_type == 'regex':
+ if lookup_type == "regex":
match_option = "'c'"
else:
match_option = "'i'"
- return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
+ return "REGEXP_LIKE(%%s, %%s, %s)" % match_option
def return_insert_columns(self, fields):
if not fields:
- return '', ()
+ return "", ()
field_names = []
params = []
for field in fields:
- field_names.append('%s.%s' % (
- self.quote_name(field.model._meta.db_table),
- self.quote_name(field.column),
- ))
+ field_names.append(
+ "%s.%s"
+ % (
+ self.quote_name(field.model._meta.db_table),
+ self.quote_name(field.column),
+ )
+ )
params.append(InsertVar(field))
- return 'RETURNING %s INTO %s' % (
- ', '.join(field_names),
- ', '.join(['%s'] * len(params)),
+ return "RETURNING %s INTO %s" % (
+ ", ".join(field_names),
+ ", ".join(["%s"] * len(params)),
), tuple(params)
def __foreign_key_constraints(self, table_name, recursive):
@@ -430,42 +444,54 @@ END;
# which truncates all dependent tables by manually retrieving all
# foreign key constraints and resolving dependencies.
for table in tables:
- for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade):
+ for foreign_table, constraint in self._foreign_key_constraints(
+ table, recursive=allow_cascade
+ ):
if allow_cascade:
truncated_tables.add(foreign_table)
constraints.add((foreign_table, constraint))
- sql = [
- '%s %s %s %s %s %s %s %s;' % (
- style.SQL_KEYWORD('ALTER'),
- style.SQL_KEYWORD('TABLE'),
- style.SQL_FIELD(self.quote_name(table)),
- style.SQL_KEYWORD('DISABLE'),
- style.SQL_KEYWORD('CONSTRAINT'),
- style.SQL_FIELD(self.quote_name(constraint)),
- style.SQL_KEYWORD('KEEP'),
- style.SQL_KEYWORD('INDEX'),
- ) for table, constraint in constraints
- ] + [
- '%s %s %s;' % (
- style.SQL_KEYWORD('TRUNCATE'),
- style.SQL_KEYWORD('TABLE'),
- style.SQL_FIELD(self.quote_name(table)),
- ) for table in truncated_tables
- ] + [
- '%s %s %s %s %s %s;' % (
- style.SQL_KEYWORD('ALTER'),
- style.SQL_KEYWORD('TABLE'),
- style.SQL_FIELD(self.quote_name(table)),
- style.SQL_KEYWORD('ENABLE'),
- style.SQL_KEYWORD('CONSTRAINT'),
- style.SQL_FIELD(self.quote_name(constraint)),
- ) for table, constraint in constraints
- ]
+ sql = (
+ [
+ "%s %s %s %s %s %s %s %s;"
+ % (
+ style.SQL_KEYWORD("ALTER"),
+ style.SQL_KEYWORD("TABLE"),
+ style.SQL_FIELD(self.quote_name(table)),
+ style.SQL_KEYWORD("DISABLE"),
+ style.SQL_KEYWORD("CONSTRAINT"),
+ style.SQL_FIELD(self.quote_name(constraint)),
+ style.SQL_KEYWORD("KEEP"),
+ style.SQL_KEYWORD("INDEX"),
+ )
+ for table, constraint in constraints
+ ]
+ + [
+ "%s %s %s;"
+ % (
+ style.SQL_KEYWORD("TRUNCATE"),
+ style.SQL_KEYWORD("TABLE"),
+ style.SQL_FIELD(self.quote_name(table)),
+ )
+ for table in truncated_tables
+ ]
+ + [
+ "%s %s %s %s %s %s;"
+ % (
+ style.SQL_KEYWORD("ALTER"),
+ style.SQL_KEYWORD("TABLE"),
+ style.SQL_FIELD(self.quote_name(table)),
+ style.SQL_KEYWORD("ENABLE"),
+ style.SQL_KEYWORD("CONSTRAINT"),
+ style.SQL_FIELD(self.quote_name(constraint)),
+ )
+ for table, constraint in constraints
+ ]
+ )
if reset_sequences:
sequences = [
sequence
for sequence in self.connection.introspection.sequence_list()
- if sequence['table'].upper() in truncated_tables
+ if sequence["table"].upper() in truncated_tables
]
# Since we've just deleted all the rows, running our sequence ALTER
# code will reset the sequence to 0.
@@ -475,15 +501,17 @@ END;
def sequence_reset_by_name_sql(self, style, sequences):
sql = []
for sequence_info in sequences:
- no_autofield_sequence_name = self._get_no_autofield_sequence_name(sequence_info['table'])
- table = self.quote_name(sequence_info['table'])
- column = self.quote_name(sequence_info['column'] or 'id')
+ no_autofield_sequence_name = self._get_no_autofield_sequence_name(
+ sequence_info["table"]
+ )
+ table = self.quote_name(sequence_info["table"])
+ column = self.quote_name(sequence_info["column"] or "id")
query = self._sequence_reset_sql % {
- 'no_autofield_sequence_name': no_autofield_sequence_name,
- 'table': table,
- 'column': column,
- 'table_name': strip_quotes(table),
- 'column_name': strip_quotes(column),
+ "no_autofield_sequence_name": no_autofield_sequence_name,
+ "table": table,
+ "column": column,
+ "table_name": strip_quotes(table),
+ "column_name": strip_quotes(column),
}
sql.append(query)
return sql
@@ -494,23 +522,28 @@ END;
for model in model_list:
for f in model._meta.local_fields:
if isinstance(f, AutoField):
- no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table)
+ no_autofield_sequence_name = self._get_no_autofield_sequence_name(
+ model._meta.db_table
+ )
table = self.quote_name(model._meta.db_table)
column = self.quote_name(f.column)
- output.append(query % {
- 'no_autofield_sequence_name': no_autofield_sequence_name,
- 'table': table,
- 'column': column,
- 'table_name': strip_quotes(table),
- 'column_name': strip_quotes(column),
- })
+ output.append(
+ query
+ % {
+ "no_autofield_sequence_name": no_autofield_sequence_name,
+ "table": table,
+ "column": column,
+ "table_name": strip_quotes(table),
+ "column_name": strip_quotes(column),
+ }
+ )
# Only one AutoField is allowed per model, so don't
# continue to loop
break
return output
def start_transaction_sql(self):
- return ''
+ return ""
def tablespace_sql(self, tablespace, inline=False):
if inline:
@@ -541,7 +574,7 @@ END;
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# cx_Oracle doesn't support tz-aware datetimes
@@ -549,7 +582,9 @@ END;
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
- raise ValueError("Oracle backend does not support timezone-aware datetimes when USE_TZ is False.")
+ raise ValueError(
+ "Oracle backend does not support timezone-aware datetimes when USE_TZ is False."
+ )
return Oracle_datetime.from_datetime(value)
@@ -558,38 +593,39 @@ END;
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
if isinstance(value, str):
- return datetime.datetime.strptime(value, '%H:%M:%S')
+ return datetime.datetime.strptime(value, "%H:%M:%S")
# Oracle doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("Oracle backend does not support timezone-aware times.")
- return Oracle_datetime(1900, 1, 1, value.hour, value.minute,
- value.second, value.microsecond)
+ return Oracle_datetime(
+ 1900, 1, 1, value.hour, value.minute, value.second, value.microsecond
+ )
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
return value
def combine_expression(self, connector, sub_expressions):
lhs, rhs = sub_expressions
- if connector == '%%':
- return 'MOD(%s)' % ','.join(sub_expressions)
- elif connector == '&':
- return 'BITAND(%s)' % ','.join(sub_expressions)
- elif connector == '|':
- return 'BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s' % {'lhs': lhs, 'rhs': rhs}
- elif connector == '<<':
- return '(%(lhs)s * POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
- elif connector == '>>':
- return 'FLOOR(%(lhs)s / POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
- elif connector == '^':
- return 'POWER(%s)' % ','.join(sub_expressions)
- elif connector == '#':
- raise NotSupportedError('Bitwise XOR is not supported in Oracle.')
+ if connector == "%%":
+ return "MOD(%s)" % ",".join(sub_expressions)
+ elif connector == "&":
+ return "BITAND(%s)" % ",".join(sub_expressions)
+ elif connector == "|":
+ return "BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s" % {"lhs": lhs, "rhs": rhs}
+ elif connector == "<<":
+ return "(%(lhs)s * POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
+ elif connector == ">>":
+ return "FLOOR(%(lhs)s / POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
+ elif connector == "^":
+ return "POWER(%s)" % ",".join(sub_expressions)
+ elif connector == "#":
+ raise NotSupportedError("Bitwise XOR is not supported in Oracle.")
return super().combine_expression(connector, sub_expressions)
def _get_no_autofield_sequence_name(self, table):
@@ -598,14 +634,17 @@ END;
AutoFields that aren't Oracle identity columns.
"""
name_length = self.max_name_length() - 3
- return '%s_SQ' % truncate_name(strip_quotes(table), name_length).upper()
+ return "%s_SQ" % truncate_name(strip_quotes(table), name_length).upper()
def _get_sequence_name(self, cursor, table, pk_name):
- cursor.execute("""
+ cursor.execute(
+ """
SELECT sequence_name
FROM user_tab_identity_cols
WHERE table_name = UPPER(%s)
- AND column_name = UPPER(%s)""", [table, pk_name])
+ AND column_name = UPPER(%s)""",
+ [table, pk_name],
+ )
row = cursor.fetchone()
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
@@ -616,26 +655,33 @@ END;
for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if fields[i]:
- internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
- placeholder = BulkInsertMapper.types.get(internal_type, '%s') % placeholder
+ internal_type = getattr(
+ fields[i], "target_field", fields[i]
+ ).get_internal_type()
+ placeholder = (
+ BulkInsertMapper.types.get(internal_type, "%s") % placeholder
+ )
# Add columns aliases to the first select to avoid "ORA-00918:
# column ambiguously defined" when two or more columns in the
# first select have the same value.
if not query:
- placeholder = '%s col_%s' % (placeholder, i)
+ placeholder = "%s col_%s" % (placeholder, i)
select.append(placeholder)
- query.append('SELECT %s FROM DUAL' % ', '.join(select))
+ query.append("SELECT %s FROM DUAL" % ", ".join(select))
# Bulk insert to tables with Oracle identity columns causes Oracle to
# add sequence.nextval to it. Sequence.nextval cannot be used with the
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
- return 'SELECT * FROM (%s)' % ' UNION ALL '.join(query)
+ return "SELECT * FROM (%s)" % " UNION ALL ".join(query)
def subtract_temporals(self, internal_type, lhs, rhs):
- if internal_type == 'DateField':
+ if internal_type == "DateField":
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
- return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), params
+ return (
+ "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql),
+ params,
+ )
return super().subtract_temporals(internal_type, lhs, rhs)
def bulk_batch_size(self, fields, objs):
@@ -652,7 +698,9 @@ END;
if isinstance(expression, (Exists, Lookup, WhereNode)):
return True
if isinstance(expression, ExpressionWrapper) and expression.conditional:
- return self.conditional_expression_supported_in_where_clause(expression.expression)
+ return self.conditional_expression_supported_in_where_clause(
+ expression.expression
+ )
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False
diff --git a/django/db/backends/oracle/schema.py b/django/db/backends/oracle/schema.py
index 98e49413c9..2b1027d6b5 100644
--- a/django/db/backends/oracle/schema.py
+++ b/django/db/backends/oracle/schema.py
@@ -4,7 +4,8 @@ import re
from django.db import DatabaseError
from django.db.backends.base.schema import (
- BaseDatabaseSchemaEditor, _related_non_m2m_objects,
+ BaseDatabaseSchemaEditor,
+ _related_non_m2m_objects,
)
from django.utils.duration import duration_iso_string
@@ -21,7 +22,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
- sql_create_column_inline_fk = 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
+ sql_create_column_inline_fk = (
+ "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
+ )
sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
@@ -31,7 +34,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
elif isinstance(value, datetime.timedelta):
return "'%s'" % duration_iso_string(value)
elif isinstance(value, str):
- return "'%s'" % value.replace("\'", "\'\'").replace('%', '%%')
+ return "'%s'" % value.replace("'", "''").replace("%", "%%")
elif isinstance(value, (bytes, bytearray, memoryview)):
return "'%s'" % value.hex()
elif isinstance(value, bool):
@@ -50,7 +53,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Run superclass action
super().delete_model(model)
# Clean up manually created sequence.
- self.execute("""
+ self.execute(
+ """
DECLARE
i INTEGER;
BEGIN
@@ -60,7 +64,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
END IF;
END;
- /""" % {'sq_name': self.connection.ops._get_no_autofield_sequence_name(model._meta.db_table)})
+ /"""
+ % {
+ "sq_name": self.connection.ops._get_no_autofield_sequence_name(
+ model._meta.db_table
+ )
+ }
+ )
def alter_field(self, model, old_field, new_field, strict=False):
try:
@@ -69,16 +79,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
description = str(e)
# If we're changing type to an unsupported type we need a
# SQLite-ish workaround
- if 'ORA-22858' in description or 'ORA-22859' in description:
+ if "ORA-22858" in description or "ORA-22859" in description:
self._alter_field_type_workaround(model, old_field, new_field)
# If an identity column is changing to a non-numeric type, drop the
# identity first.
- elif 'ORA-30675' in description:
+ elif "ORA-30675" in description:
self._drop_identity(model._meta.db_table, old_field.column)
self.alter_field(model, old_field, new_field, strict)
# If a primary key column is changing to an identity column, drop
# the primary key first.
- elif 'ORA-30673' in description and old_field.primary_key:
+ elif "ORA-30673" in description and old_field.primary_key:
self._delete_primary_key(model, strict=True)
self._alter_field_type_workaround(model, old_field, new_field)
else:
@@ -98,7 +108,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Make a new field that's like the new one but with a temporary
# column name.
new_temp_field = copy.deepcopy(new_field)
- new_temp_field.null = (new_field.get_internal_type() not in ('AutoField', 'BigAutoField', 'SmallAutoField'))
+ new_temp_field.null = new_field.get_internal_type() not in (
+ "AutoField",
+ "BigAutoField",
+ "SmallAutoField",
+ )
new_temp_field.column = self._generate_temp_name(new_field.column)
# Add it
self.add_field(model, new_temp_field)
@@ -107,24 +121,30 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
new_value = self.quote_name(old_field.column)
old_type = old_field.db_type(self.connection)
- if re.match('^N?CLOB', old_type):
+ if re.match("^N?CLOB", old_type):
new_value = "TO_CHAR(%s)" % new_value
- old_type = 'VARCHAR2'
- if re.match('^N?VARCHAR2', old_type):
+ old_type = "VARCHAR2"
+ if re.match("^N?VARCHAR2", old_type):
new_internal_type = new_field.get_internal_type()
- if new_internal_type == 'DateField':
+ if new_internal_type == "DateField":
new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
- elif new_internal_type == 'DateTimeField':
+ elif new_internal_type == "DateTimeField":
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
- elif new_internal_type == 'TimeField':
+ elif new_internal_type == "TimeField":
# TimeField are stored as TIMESTAMP with a 1900-01-01 date part.
- new_value = "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
+ new_value = (
+ "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')"
+ % new_value
+ )
# Transfer values across
- self.execute("UPDATE %s set %s=%s" % (
- self.quote_name(model._meta.db_table),
- self.quote_name(new_temp_field.column),
- new_value,
- ))
+ self.execute(
+ "UPDATE %s set %s=%s"
+ % (
+ self.quote_name(model._meta.db_table),
+ self.quote_name(new_temp_field.column),
+ new_value,
+ )
+ )
# Drop the old field
self.remove_field(model, old_field)
# Rename and possibly make the new field NOT NULL
@@ -134,20 +154,22 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# new_field always match.
new_type = new_field.db_type(self.connection)
if (
- (old_field.primary_key and new_field.primary_key) or
- (old_field.unique and new_field.unique)
+ (old_field.primary_key and new_field.primary_key)
+ or (old_field.unique and new_field.unique)
) and old_type != new_type:
for _, rel in _related_non_m2m_objects(new_temp_field, new_field):
if rel.field.db_constraint:
- self.execute(self._create_fk_sql(rel.related_model, rel.field, '_fk'))
+ self.execute(
+ self._create_fk_sql(rel.related_model, rel.field, "_fk")
+ )
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
- auto_field_types = {'AutoField', 'BigAutoField', 'SmallAutoField'}
+ auto_field_types = {"AutoField", "BigAutoField", "SmallAutoField"}
# Drop the identity if migrating away from AutoField.
if (
- old_field.get_internal_type() in auto_field_types and
- new_field.get_internal_type() not in auto_field_types and
- self._is_identity_column(model._meta.db_table, new_field.column)
+ old_field.get_internal_type() in auto_field_types
+ and new_field.get_internal_type() not in auto_field_types
+ and self._is_identity_column(model._meta.db_table, new_field.column)
):
self._drop_identity(model._meta.db_table, new_field.column)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
@@ -173,7 +195,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _field_should_be_indexed(self, model, field):
create_index = super()._field_should_be_indexed(model, field)
db_type = field.db_type(self.connection)
- if db_type is not None and db_type.lower() in self.connection._limited_data_types:
+ if (
+ db_type is not None
+ and db_type.lower() in self.connection._limited_data_types
+ ):
return False
return create_index
@@ -193,10 +218,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
return row[0] if row else False
def _drop_identity(self, table_name, column_name):
- self.execute('ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY' % {
- 'table': self.quote_name(table_name),
- 'column': self.quote_name(column_name),
- })
+ self.execute(
+ "ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY"
+ % {
+ "table": self.quote_name(table_name),
+ "column": self.quote_name(column_name),
+ }
+ )
def _get_default_collation(self, table_name):
with self.connection.cursor() as cursor:
@@ -211,4 +239,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _alter_column_collation_sql(self, model, new_field, new_type, new_collation):
if new_collation is None:
new_collation = self._get_default_collation(model._meta.db_table)
- return super()._alter_column_collation_sql(model, new_field, new_type, new_collation)
+ return super()._alter_column_collation_sql(
+ model, new_field, new_type, new_collation
+ )
diff --git a/django/db/backends/oracle/utils.py b/django/db/backends/oracle/utils.py
index e3786541af..8941a85967 100644
--- a/django/db/backends/oracle/utils.py
+++ b/django/db/backends/oracle/utils.py
@@ -9,24 +9,25 @@ class InsertVar:
as a parameter, in order to receive the id of the row created by an
insert statement.
"""
+
types = {
- 'AutoField': int,
- 'BigAutoField': int,
- 'SmallAutoField': int,
- 'IntegerField': int,
- 'BigIntegerField': int,
- 'SmallIntegerField': int,
- 'PositiveBigIntegerField': int,
- 'PositiveSmallIntegerField': int,
- 'PositiveIntegerField': int,
- 'FloatField': Database.NATIVE_FLOAT,
- 'DateTimeField': Database.TIMESTAMP,
- 'DateField': Database.Date,
- 'DecimalField': Database.NUMBER,
+ "AutoField": int,
+ "BigAutoField": int,
+ "SmallAutoField": int,
+ "IntegerField": int,
+ "BigIntegerField": int,
+ "SmallIntegerField": int,
+ "PositiveBigIntegerField": int,
+ "PositiveSmallIntegerField": int,
+ "PositiveIntegerField": int,
+ "FloatField": Database.NATIVE_FLOAT,
+ "DateTimeField": Database.TIMESTAMP,
+ "DateField": Database.Date,
+ "DecimalField": Database.NUMBER,
}
def __init__(self, field):
- internal_type = getattr(field, 'target_field', field).get_internal_type()
+ internal_type = getattr(field, "target_field", field).get_internal_type()
self.db_type = self.types.get(internal_type, str)
self.bound_param = None
@@ -43,48 +44,54 @@ class Oracle_datetime(datetime.datetime):
A datetime object, with an additional class attribute
to tell cx_Oracle to save the microseconds too.
"""
+
input_size = Database.TIMESTAMP
@classmethod
def from_datetime(cls, dt):
return Oracle_datetime(
- dt.year, dt.month, dt.day,
- dt.hour, dt.minute, dt.second, dt.microsecond,
+ dt.year,
+ dt.month,
+ dt.day,
+ dt.hour,
+ dt.minute,
+ dt.second,
+ dt.microsecond,
)
class BulkInsertMapper:
- BLOB = 'TO_BLOB(%s)'
- DATE = 'TO_DATE(%s)'
- INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
- NCLOB = 'TO_NCLOB(%s)'
- NUMBER = 'TO_NUMBER(%s)'
- TIMESTAMP = 'TO_TIMESTAMP(%s)'
+ BLOB = "TO_BLOB(%s)"
+ DATE = "TO_DATE(%s)"
+ INTERVAL = "CAST(%s as INTERVAL DAY(9) TO SECOND(6))"
+ NCLOB = "TO_NCLOB(%s)"
+ NUMBER = "TO_NUMBER(%s)"
+ TIMESTAMP = "TO_TIMESTAMP(%s)"
types = {
- 'AutoField': NUMBER,
- 'BigAutoField': NUMBER,
- 'BigIntegerField': NUMBER,
- 'BinaryField': BLOB,
- 'BooleanField': NUMBER,
- 'DateField': DATE,
- 'DateTimeField': TIMESTAMP,
- 'DecimalField': NUMBER,
- 'DurationField': INTERVAL,
- 'FloatField': NUMBER,
- 'IntegerField': NUMBER,
- 'PositiveBigIntegerField': NUMBER,
- 'PositiveIntegerField': NUMBER,
- 'PositiveSmallIntegerField': NUMBER,
- 'SmallAutoField': NUMBER,
- 'SmallIntegerField': NUMBER,
- 'TextField': NCLOB,
- 'TimeField': TIMESTAMP,
+ "AutoField": NUMBER,
+ "BigAutoField": NUMBER,
+ "BigIntegerField": NUMBER,
+ "BinaryField": BLOB,
+ "BooleanField": NUMBER,
+ "DateField": DATE,
+ "DateTimeField": TIMESTAMP,
+ "DecimalField": NUMBER,
+ "DurationField": INTERVAL,
+ "FloatField": NUMBER,
+ "IntegerField": NUMBER,
+ "PositiveBigIntegerField": NUMBER,
+ "PositiveIntegerField": NUMBER,
+ "PositiveSmallIntegerField": NUMBER,
+ "SmallAutoField": NUMBER,
+ "SmallIntegerField": NUMBER,
+ "TextField": NCLOB,
+ "TimeField": TIMESTAMP,
}
def dsn(settings_dict):
- if settings_dict['PORT']:
- host = settings_dict['HOST'].strip() or 'localhost'
- return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME'])
- return settings_dict['NAME']
+ if settings_dict["PORT"]:
+ host = settings_dict["HOST"].strip() or "localhost"
+ return Database.makedsn(host, int(settings_dict["PORT"]), settings_dict["NAME"])
+ return settings_dict["NAME"]
diff --git a/django/db/backends/oracle/validation.py b/django/db/backends/oracle/validation.py
index e5a35fd3ca..4035b12085 100644
--- a/django/db/backends/oracle/validation.py
+++ b/django/db/backends/oracle/validation.py
@@ -9,14 +9,14 @@ class DatabaseValidation(BaseDatabaseValidation):
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
- 'Oracle does not support a database index on %s columns.'
+ "Oracle does not support a database index on %s columns."
% field_type,
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
- id='fields.W162',
+ id="fields.W162",
)
)
return errors
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index e49d453f94..92f393227e 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -11,11 +11,10 @@ from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
-from django.db import DatabaseError as WrappedDatabaseError, connections
+from django.db import DatabaseError as WrappedDatabaseError
+from django.db import connections
from django.db.backends.base.base import BaseDatabaseWrapper
-from django.db.backends.utils import (
- CursorDebugWrapper as BaseCursorDebugWrapper,
-)
+from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
@@ -30,14 +29,17 @@ except ImportError as e:
def psycopg2_version():
- version = psycopg2.__version__.split(' ', 1)[0]
+ version = psycopg2.__version__.split(" ", 1)[0]
return get_version_tuple(version)
PSYCOPG2_VERSION = psycopg2_version()
if PSYCOPG2_VERSION < (2, 8, 4):
- raise ImproperlyConfigured("psycopg2 version 2.8.4 or newer is required; you have %s" % psycopg2.__version__)
+ raise ImproperlyConfigured(
+ "psycopg2 version 2.8.4 or newer is required; you have %s"
+ % psycopg2.__version__
+ )
# Some of these import psycopg2, so import them after checking if it's installed.
@@ -56,68 +58,68 @@ psycopg2.extras.register_uuid()
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
- 'INETARRAY',
+ "INETARRAY",
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
- vendor = 'postgresql'
- display_name = 'PostgreSQL'
+ vendor = "postgresql"
+ display_name = "PostgreSQL"
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
- 'AutoField': 'serial',
- 'BigAutoField': 'bigserial',
- 'BinaryField': 'bytea',
- 'BooleanField': 'boolean',
- 'CharField': 'varchar(%(max_length)s)',
- 'DateField': 'date',
- 'DateTimeField': 'timestamp with time zone',
- 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
- 'DurationField': 'interval',
- 'FileField': 'varchar(%(max_length)s)',
- 'FilePathField': 'varchar(%(max_length)s)',
- 'FloatField': 'double precision',
- 'IntegerField': 'integer',
- 'BigIntegerField': 'bigint',
- 'IPAddressField': 'inet',
- 'GenericIPAddressField': 'inet',
- 'JSONField': 'jsonb',
- 'OneToOneField': 'integer',
- 'PositiveBigIntegerField': 'bigint',
- 'PositiveIntegerField': 'integer',
- 'PositiveSmallIntegerField': 'smallint',
- 'SlugField': 'varchar(%(max_length)s)',
- 'SmallAutoField': 'smallserial',
- 'SmallIntegerField': 'smallint',
- 'TextField': 'text',
- 'TimeField': 'time',
- 'UUIDField': 'uuid',
+ "AutoField": "serial",
+ "BigAutoField": "bigserial",
+ "BinaryField": "bytea",
+ "BooleanField": "boolean",
+ "CharField": "varchar(%(max_length)s)",
+ "DateField": "date",
+ "DateTimeField": "timestamp with time zone",
+ "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
+ "DurationField": "interval",
+ "FileField": "varchar(%(max_length)s)",
+ "FilePathField": "varchar(%(max_length)s)",
+ "FloatField": "double precision",
+ "IntegerField": "integer",
+ "BigIntegerField": "bigint",
+ "IPAddressField": "inet",
+ "GenericIPAddressField": "inet",
+ "JSONField": "jsonb",
+ "OneToOneField": "integer",
+ "PositiveBigIntegerField": "bigint",
+ "PositiveIntegerField": "integer",
+ "PositiveSmallIntegerField": "smallint",
+ "SlugField": "varchar(%(max_length)s)",
+ "SmallAutoField": "smallserial",
+ "SmallIntegerField": "smallint",
+ "TextField": "text",
+ "TimeField": "time",
+ "UUIDField": "uuid",
}
data_type_check_constraints = {
- 'PositiveBigIntegerField': '"%(column)s" >= 0',
- 'PositiveIntegerField': '"%(column)s" >= 0',
- 'PositiveSmallIntegerField': '"%(column)s" >= 0',
+ "PositiveBigIntegerField": '"%(column)s" >= 0',
+ "PositiveIntegerField": '"%(column)s" >= 0',
+ "PositiveSmallIntegerField": '"%(column)s" >= 0',
}
operators = {
- 'exact': '= %s',
- 'iexact': '= UPPER(%s)',
- 'contains': 'LIKE %s',
- 'icontains': 'LIKE UPPER(%s)',
- 'regex': '~ %s',
- 'iregex': '~* %s',
- 'gt': '> %s',
- 'gte': '>= %s',
- 'lt': '< %s',
- 'lte': '<= %s',
- 'startswith': 'LIKE %s',
- 'endswith': 'LIKE %s',
- 'istartswith': 'LIKE UPPER(%s)',
- 'iendswith': 'LIKE UPPER(%s)',
+ "exact": "= %s",
+ "iexact": "= UPPER(%s)",
+ "contains": "LIKE %s",
+ "icontains": "LIKE UPPER(%s)",
+ "regex": "~ %s",
+ "iregex": "~* %s",
+ "gt": "> %s",
+ "gte": ">= %s",
+ "lt": "< %s",
+ "lte": "<= %s",
+ "startswith": "LIKE %s",
+ "endswith": "LIKE %s",
+ "istartswith": "LIKE UPPER(%s)",
+ "iendswith": "LIKE UPPER(%s)",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -128,14 +130,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
- pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
+ pattern_esc = (
+ r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
+ )
pattern_ops = {
- 'contains': "LIKE '%%' || {} || '%%'",
- 'icontains': "LIKE '%%' || UPPER({}) || '%%'",
- 'startswith': "LIKE {} || '%%'",
- 'istartswith': "LIKE UPPER({}) || '%%'",
- 'endswith': "LIKE '%%' || {}",
- 'iendswith': "LIKE '%%' || UPPER({})",
+ "contains": "LIKE '%%' || {} || '%%'",
+ "icontains": "LIKE '%%' || UPPER({}) || '%%'",
+ "startswith": "LIKE {} || '%%'",
+ "istartswith": "LIKE UPPER({}) || '%%'",
+ "endswith": "LIKE '%%' || {}",
+ "iendswith": "LIKE '%%' || UPPER({})",
}
Database = Database
@@ -152,46 +156,46 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
- if (
- settings_dict['NAME'] == '' and
- not settings_dict.get('OPTIONS', {}).get('service')
+ if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
+ "service"
):
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME or OPTIONS['service'] value."
)
- if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
+ if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
- "in settings.DATABASES." % (
- settings_dict['NAME'],
- len(settings_dict['NAME']),
+ "in settings.DATABASES."
+ % (
+ settings_dict["NAME"],
+ len(settings_dict["NAME"]),
self.ops.max_name_length(),
)
)
conn_params = {}
- if settings_dict['NAME']:
+ if settings_dict["NAME"]:
conn_params = {
- 'database': settings_dict['NAME'],
- **settings_dict['OPTIONS'],
+ "database": settings_dict["NAME"],
+ **settings_dict["OPTIONS"],
}
- elif settings_dict['NAME'] is None:
+ elif settings_dict["NAME"] is None:
# Connect to the default 'postgres' db.
- settings_dict.get('OPTIONS', {}).pop('service', None)
- conn_params = {'database': 'postgres', **settings_dict['OPTIONS']}
+ settings_dict.get("OPTIONS", {}).pop("service", None)
+ conn_params = {"database": "postgres", **settings_dict["OPTIONS"]}
else:
- conn_params = {**settings_dict['OPTIONS']}
+ conn_params = {**settings_dict["OPTIONS"]}
- conn_params.pop('isolation_level', None)
- if settings_dict['USER']:
- conn_params['user'] = settings_dict['USER']
- if settings_dict['PASSWORD']:
- conn_params['password'] = settings_dict['PASSWORD']
- if settings_dict['HOST']:
- conn_params['host'] = settings_dict['HOST']
- if settings_dict['PORT']:
- conn_params['port'] = settings_dict['PORT']
+ conn_params.pop("isolation_level", None)
+ if settings_dict["USER"]:
+ conn_params["user"] = settings_dict["USER"]
+ if settings_dict["PASSWORD"]:
+ conn_params["password"] = settings_dict["PASSWORD"]
+ if settings_dict["HOST"]:
+ conn_params["host"] = settings_dict["HOST"]
+ if settings_dict["PORT"]:
+ conn_params["port"] = settings_dict["PORT"]
return conn_params
@async_unsafe
@@ -203,9 +207,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
- options = self.settings_dict['OPTIONS']
+ options = self.settings_dict["OPTIONS"]
try:
- self.isolation_level = options['isolation_level']
+ self.isolation_level = options["isolation_level"]
except KeyError:
self.isolation_level = connection.isolation_level
else:
@@ -215,13 +219,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Register dummy loads() to avoid a round trip from psycopg2's decode
# to json.dumps() to json.loads(), when using a custom decoder in
# JSONField.
- psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
+ psycopg2.extras.register_default_jsonb(
+ conn_or_curs=connection, loads=lambda x: x
+ )
return connection
def ensure_timezone(self):
if self.connection is None:
return False
- conn_timezone_name = self.connection.get_parameter_status('TimeZone')
+ conn_timezone_name = self.connection.get_parameter_status("TimeZone")
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
@@ -230,7 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
def init_connection_state(self):
- self.connection.set_client_encoding('UTF8')
+ self.connection.set_client_encoding("UTF8")
timezone_changed = self.ensure_timezone()
if timezone_changed:
@@ -243,7 +249,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if name:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
- cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
+ cursor = self.connection.cursor(
+ name, scrollable=False, withhold=self.connection.autocommit
+ )
else:
cursor = self.connection.cursor()
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
@@ -268,10 +276,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if current_task:
task_ident = str(id(current_task))
else:
- task_ident = 'sync'
+ task_ident = "sync"
# Use that and the thread ident to get a unique name
return self._cursor(
- name='_django_curs_%d_%s_%d' % (
+ name="_django_curs_%d_%s_%d"
+ % (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
@@ -289,14 +298,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
afterward.
"""
with self.cursor() as cursor:
- cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
- cursor.execute('SET CONSTRAINTS ALL DEFERRED')
+ cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
+ cursor.execute("SET CONSTRAINTS ALL DEFERRED")
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
- cursor.execute('SELECT 1')
+ cursor.execute("SELECT 1")
except Database.Error:
return False
else:
@@ -317,12 +326,18 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
- RuntimeWarning
+ RuntimeWarning,
)
for connection in connections.all():
- if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
+ if (
+ connection.vendor == "postgresql"
+ and connection.settings_dict["NAME"] != "postgres"
+ ):
conn = self.__class__(
- {**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
+ {
+ **self.settings_dict,
+ "NAME": connection.settings_dict["NAME"],
+ },
alias=self.alias,
)
try:
@@ -349,5 +364,5 @@ class CursorDebugWrapper(BaseCursorDebugWrapper):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
- with self.debug_sql(sql='COPY %s TO STDOUT' % table):
+ with self.debug_sql(sql="COPY %s TO STDOUT" % table):
return self.cursor.copy_to(file, table, *args, **kwargs)
diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py
index 0effcc44e6..4c9bd63546 100644
--- a/django/db/backends/postgresql/client.py
+++ b/django/db/backends/postgresql/client.py
@@ -4,53 +4,53 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
- executable_name = 'psql'
+ executable_name = "psql"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
- options = settings_dict.get('OPTIONS', {})
+ options = settings_dict.get("OPTIONS", {})
- host = settings_dict.get('HOST')
- port = settings_dict.get('PORT')
- dbname = settings_dict.get('NAME')
- user = settings_dict.get('USER')
- passwd = settings_dict.get('PASSWORD')
- passfile = options.get('passfile')
- service = options.get('service')
- sslmode = options.get('sslmode')
- sslrootcert = options.get('sslrootcert')
- sslcert = options.get('sslcert')
- sslkey = options.get('sslkey')
+ host = settings_dict.get("HOST")
+ port = settings_dict.get("PORT")
+ dbname = settings_dict.get("NAME")
+ user = settings_dict.get("USER")
+ passwd = settings_dict.get("PASSWORD")
+ passfile = options.get("passfile")
+ service = options.get("service")
+ sslmode = options.get("sslmode")
+ sslrootcert = options.get("sslrootcert")
+ sslcert = options.get("sslcert")
+ sslkey = options.get("sslkey")
if not dbname and not service:
# Connect to the default 'postgres' db.
- dbname = 'postgres'
+ dbname = "postgres"
if user:
- args += ['-U', user]
+ args += ["-U", user]
if host:
- args += ['-h', host]
+ args += ["-h", host]
if port:
- args += ['-p', str(port)]
+ args += ["-p", str(port)]
if dbname:
args += [dbname]
args.extend(parameters)
env = {}
if passwd:
- env['PGPASSWORD'] = str(passwd)
+ env["PGPASSWORD"] = str(passwd)
if service:
- env['PGSERVICE'] = str(service)
+ env["PGSERVICE"] = str(service)
if sslmode:
- env['PGSSLMODE'] = str(sslmode)
+ env["PGSSLMODE"] = str(sslmode)
if sslrootcert:
- env['PGSSLROOTCERT'] = str(sslrootcert)
+ env["PGSSLROOTCERT"] = str(sslrootcert)
if sslcert:
- env['PGSSLCERT'] = str(sslcert)
+ env["PGSSLCERT"] = str(sslcert)
if sslkey:
- env['PGSSLKEY'] = str(sslkey)
+ env["PGSSLKEY"] = str(sslkey)
if passfile:
- env['PGPASSFILE'] = str(passfile)
+ env["PGPASSFILE"] = str(passfile)
return args, (env or None)
def runshell(self, parameters):
diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py
index eb8ac3bcf5..70c3eda566 100644
--- a/django/db/backends/postgresql/creation.py
+++ b/django/db/backends/postgresql/creation.py
@@ -8,7 +8,6 @@ from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
-
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
@@ -21,32 +20,35 @@ class DatabaseCreation(BaseDatabaseCreation):
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
- test_settings = self.connection.settings_dict['TEST']
- if test_settings.get('COLLATION') is not None:
+ test_settings = self.connection.settings_dict["TEST"]
+ if test_settings.get("COLLATION") is not None:
raise ImproperlyConfigured(
- 'PostgreSQL does not support collation setting at database '
- 'creation time.'
+ "PostgreSQL does not support collation setting at database "
+ "creation time."
)
return self._get_database_create_suffix(
- encoding=test_settings['CHARSET'],
- template=test_settings.get('TEMPLATE'),
+ encoding=test_settings["CHARSET"],
+ template=test_settings.get("TEMPLATE"),
)
def _database_exists(self, cursor, database_name):
- cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
+ cursor.execute(
+ "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
+ [strip_quotes(database_name)],
+ )
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
- if keepdb and self._database_exists(cursor, parameters['dbname']):
+ if keepdb and self._database_exists(cursor, parameters["dbname"]):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
- if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
+ if getattr(e.__cause__, "pgcode", "") != errorcodes.DUPLICATE_DATABASE:
# All errors except "database already exists" cancel tests.
- self.log('Got an error creating the test database: %s' % e)
+ self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
@@ -58,11 +60,11 @@ class DatabaseCreation(BaseDatabaseCreation):
# to the template database.
self.connection.close()
- source_database_name = self.connection.settings_dict['NAME']
- target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
+ source_database_name = self.connection.settings_dict["NAME"]
+ target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_db_params = {
- 'dbname': self._quote_name(target_database_name),
- 'suffix': self._get_database_create_suffix(template=source_database_name),
+ "dbname": self._quote_name(target_database_name),
+ "suffix": self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_cursor() as cursor:
try:
@@ -70,11 +72,16 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception:
try:
if verbosity >= 1:
- self.log('Destroying old test database for alias %s...' % (
- self._get_database_display_str(verbosity, target_database_name),
- ))
- cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
+ self.log(
+ "Destroying old test database for alias %s..."
+ % (
+ self._get_database_display_str(
+ verbosity, target_database_name
+ ),
+ )
+ )
+ cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
- self.log('Got an error cloning the test database: %s' % e)
+ self.log("Got an error cloning the test database: %s" % e)
sys.exit(2)
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 1ce73fb0a8..182c230c75 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -52,7 +52,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_over_clause = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
- supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
+ supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
validates_explain_options = False # A query will error on invalid options.
supports_deferrable_unique_constraints = True
has_json_operators = True
@@ -60,14 +60,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_update_conflicts = True
supports_update_conflicts_with_target = True
test_collations = {
- 'non_default': 'sv-x-icu',
- 'swedish_ci': 'sv-x-icu',
+ "non_default": "sv-x-icu",
+ "swedish_ci": "sv-x-icu",
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
django_test_skips = {
- 'opclasses are PostgreSQL only.': {
- 'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses',
+ "opclasses are PostgreSQL only.": {
+ "indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses",
},
}
@@ -75,9 +75,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def introspected_field_types(self):
return {
**super().introspected_field_types,
- 'PositiveBigIntegerField': 'BigIntegerField',
- 'PositiveIntegerField': 'IntegerField',
- 'PositiveSmallIntegerField': 'SmallIntegerField',
+ "PositiveBigIntegerField": "BigIntegerField",
+ "PositiveIntegerField": "IntegerField",
+ "PositiveSmallIntegerField": "SmallIntegerField",
}
@cached_property
@@ -96,9 +96,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def is_postgresql_14(self):
return self.connection.pg_version >= 140000
- has_bit_xor = property(operator.attrgetter('is_postgresql_14'))
- has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11'))
- supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
- supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
- supports_covering_spgist_indexes = property(operator.attrgetter('is_postgresql_14'))
- supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
+ has_bit_xor = property(operator.attrgetter("is_postgresql_14"))
+ has_websearch_to_tsquery = property(operator.attrgetter("is_postgresql_11"))
+ supports_covering_indexes = property(operator.attrgetter("is_postgresql_11"))
+ supports_covering_gist_indexes = property(operator.attrgetter("is_postgresql_12"))
+ supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14"))
+ supports_non_deterministic_collations = property(
+ operator.attrgetter("is_postgresql_12")
+ )
diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py
index f31d906a2f..a7e9a13d61 100644
--- a/django/db/backends/postgresql/introspection.py
+++ b/django/db/backends/postgresql/introspection.py
@@ -1,5 +1,7 @@
from django.db.backends.base.introspection import (
- BaseDatabaseIntrospection, FieldInfo, TableInfo,
+ BaseDatabaseIntrospection,
+ FieldInfo,
+ TableInfo,
)
from django.db.models import Index
@@ -7,46 +9,47 @@ from django.db.models import Index
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
- 16: 'BooleanField',
- 17: 'BinaryField',
- 20: 'BigIntegerField',
- 21: 'SmallIntegerField',
- 23: 'IntegerField',
- 25: 'TextField',
- 700: 'FloatField',
- 701: 'FloatField',
- 869: 'GenericIPAddressField',
- 1042: 'CharField', # blank-padded
- 1043: 'CharField',
- 1082: 'DateField',
- 1083: 'TimeField',
- 1114: 'DateTimeField',
- 1184: 'DateTimeField',
- 1186: 'DurationField',
- 1266: 'TimeField',
- 1700: 'DecimalField',
- 2950: 'UUIDField',
- 3802: 'JSONField',
+ 16: "BooleanField",
+ 17: "BinaryField",
+ 20: "BigIntegerField",
+ 21: "SmallIntegerField",
+ 23: "IntegerField",
+ 25: "TextField",
+ 700: "FloatField",
+ 701: "FloatField",
+ 869: "GenericIPAddressField",
+ 1042: "CharField", # blank-padded
+ 1043: "CharField",
+ 1082: "DateField",
+ 1083: "TimeField",
+ 1114: "DateTimeField",
+ 1184: "DateTimeField",
+ 1186: "DurationField",
+ 1266: "TimeField",
+ 1700: "DecimalField",
+ 2950: "UUIDField",
+ 3802: "JSONField",
}
# A hook for subclasses.
- index_default_access_method = 'btree'
+ index_default_access_method = "btree"
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
- if description.default and 'nextval' in description.default:
- if field_type == 'IntegerField':
- return 'AutoField'
- elif field_type == 'BigIntegerField':
- return 'BigAutoField'
- elif field_type == 'SmallIntegerField':
- return 'SmallAutoField'
+ if description.default and "nextval" in description.default:
+ if field_type == "IntegerField":
+ return "AutoField"
+ elif field_type == "BigIntegerField":
+ return "BigAutoField"
+ elif field_type == "SmallIntegerField":
+ return "SmallAutoField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
- cursor.execute("""
+ cursor.execute(
+ """
SELECT c.relname,
CASE WHEN c.relispartition THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
FROM pg_catalog.pg_class c
@@ -54,8 +57,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
- """)
- return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]
+ """
+ )
+ return [
+ TableInfo(*row)
+ for row in cursor.fetchall()
+ if row[0] not in self.ignored_tables
+ ]
def get_table_description(self, cursor, table_name):
"""
@@ -65,7 +73,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
- cursor.execute("""
+ cursor.execute(
+ """
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
@@ -81,9 +90,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
- """, [table_name])
+ """,
+ [table_name],
+ )
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
- cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
+ cursor.execute(
+ "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
+ )
return [
FieldInfo(
line.name,
@@ -98,7 +111,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
]
def get_sequences(self, cursor, table_name, table_fields=()):
- cursor.execute("""
+ cursor.execute(
+ """
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
@@ -110,9 +124,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
- """, [table_name])
+ """,
+ [table_name],
+ )
return [
- {'name': row[0], 'table': table_name, 'column': row[1]}
+ {"name": row[0], "table": table_name, "column": row[1]}
for row in cursor.fetchall()
]
@@ -121,7 +137,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all foreign keys in the given table.
"""
- cursor.execute("""
+ cursor.execute(
+ """
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
@@ -133,7 +150,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
- """, [table_name])
+ """,
+ [table_name],
+ )
return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
def get_constraints(self, cursor, table_name):
@@ -146,7 +165,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
- cursor.execute("""
+ cursor.execute(
+ """
SELECT
c.conname,
array(
@@ -165,7 +185,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
- """, [table_name])
+ """,
+ [table_name],
+ )
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
@@ -178,7 +200,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"options": options,
}
# Now get indexes
- cursor.execute("""
+ cursor.execute(
+ """
SELECT
indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
@@ -207,14 +230,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
- """, [self.index_default_access_method, table_name])
- for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
+ """,
+ [self.index_default_access_method, table_name],
+ )
+ for (
+ index,
+ columns,
+ unique,
+ primary,
+ orders,
+ type_,
+ definition,
+ options,
+ ) in cursor.fetchall():
if index not in constraints:
basic_index = (
- type_ == self.index_default_access_method and
+ type_ == self.index_default_access_method
+ and
# '_btree' references
# django.contrib.postgres.indexes.BTreeIndex.suffix.
- not index.endswith('_btree') and options is None
+ not index.endswith("_btree")
+ and options is None
)
constraints[index] = {
"columns": columns if columns != [None] else [],
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 762cd8d23e..68448157ec 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -7,17 +7,22 @@ from django.db.models.constants import OnConflict
class DatabaseOperations(BaseDatabaseOperations):
- cast_char_field_without_max_length = 'varchar'
- explain_prefix = 'EXPLAIN'
+ cast_char_field_without_max_length = "varchar"
+ explain_prefix = "EXPLAIN"
cast_data_types = {
- 'AutoField': 'integer',
- 'BigAutoField': 'bigint',
- 'SmallAutoField': 'smallint',
+ "AutoField": "integer",
+ "BigAutoField": "bigint",
+ "SmallAutoField": "smallint",
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
- if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
+ if internal_type in (
+ "GenericIPAddressField",
+ "IPAddressField",
+ "TimeField",
+ "UUIDField",
+ ):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
@@ -25,17 +30,19 @@ class DatabaseOperations(BaseDatabaseOperations):
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
- return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
- return '%s'
+ return (
+ "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
+ )
+ return "%s"
def date_extract_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
- if lookup_type == 'week_day':
+ if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
- elif lookup_type == 'iso_week_day':
+ elif lookup_type == "iso_week_day":
return "EXTRACT('isodow' FROM %s)" % field_name
- elif lookup_type == 'iso_year':
+ elif lookup_type == "iso_year":
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
@@ -48,22 +55,25 @@ class DatabaseOperations(BaseDatabaseOperations):
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
- sign = '-' if sign == '+' else '+'
- return f'{tzname}{sign}{offset}'
+ sign = "-" if sign == "+" else "+"
+ return f"{tzname}{sign}{offset}"
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ:
- field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
+ field_name = "%s AT TIME ZONE '%s'" % (
+ field_name,
+ self._prepare_tzname_delta(tzname),
+ )
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
- return '(%s)::date' % field_name
+ return "(%s)::date" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
- return '(%s)::time' % field_name
+ return "(%s)::time" % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
@@ -89,21 +99,30 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
- lookup = '%s'
+ lookup = "%s"
# Cast text lookups to text to allow things like filter(x__contains=4)
- if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
- 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
- if internal_type in ('IPAddressField', 'GenericIPAddressField'):
+ if lookup_type in (
+ "iexact",
+ "contains",
+ "icontains",
+ "startswith",
+ "istartswith",
+ "endswith",
+ "iendswith",
+ "regex",
+ "iregex",
+ ):
+ if internal_type in ("IPAddressField", "GenericIPAddressField"):
lookup = "HOST(%s)"
- elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
- lookup = '%s::citext'
+ elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
+ lookup = "%s::citext"
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
- if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
- lookup = 'UPPER(%s)' % lookup
+ if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
+ lookup = "UPPER(%s)" % lookup
return lookup
@@ -128,29 +147,32 @@ class DatabaseOperations(BaseDatabaseOperations):
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
sql_parts = [
- style.SQL_KEYWORD('TRUNCATE'),
- ', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
+ style.SQL_KEYWORD("TRUNCATE"),
+ ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
]
if reset_sequences:
- sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY'))
+ sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
if allow_cascade:
- sql_parts.append(style.SQL_KEYWORD('CASCADE'))
- return ['%s;' % ' '.join(sql_parts)]
+ sql_parts.append(style.SQL_KEYWORD("CASCADE"))
+ return ["%s;" % " ".join(sql_parts)]
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
- table_name = sequence_info['table']
+ table_name = sequence_info["table"]
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
- column_name = sequence_info['column'] or 'id'
- sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
- style.SQL_KEYWORD('SELECT'),
- style.SQL_TABLE(self.quote_name(table_name)),
- style.SQL_FIELD(column_name),
- ))
+ column_name = sequence_info["column"] or "id"
+ sql.append(
+ "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
+ % (
+ style.SQL_KEYWORD("SELECT"),
+ style.SQL_TABLE(self.quote_name(table_name)),
+ style.SQL_FIELD(column_name),
+ )
+ )
return sql
def tablespace_sql(self, tablespace, inline=False):
@@ -161,6 +183,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def sequence_reset_sql(self, style, model_list):
from django.db import models
+
output = []
qn = self.quote_name
for model in model_list:
@@ -174,14 +197,15 @@ class DatabaseOperations(BaseDatabaseOperations):
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
- "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
- style.SQL_KEYWORD('SELECT'),
+ "coalesce(max(%s), 1), max(%s) %s null) %s %s;"
+ % (
+ style.SQL_KEYWORD("SELECT"),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
- style.SQL_KEYWORD('IS NOT'),
- style.SQL_KEYWORD('FROM'),
+ style.SQL_KEYWORD("IS NOT"),
+ style.SQL_KEYWORD("FROM"),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
@@ -207,9 +231,9 @@ class DatabaseOperations(BaseDatabaseOperations):
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
- return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
+ return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
else:
- return ['DISTINCT'], []
+ return ["DISTINCT"], []
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
@@ -220,14 +244,16 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_columns(self, fields):
if not fields:
- return '', ()
+ return "", ()
columns = [
- '%s.%s' % (
+ "%s.%s"
+ % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
- ) for field in fields
+ )
+ for field in fields
]
- return 'RETURNING %s' % ', '.join(columns), ()
+ return "RETURNING %s" % ", ".join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
@@ -252,7 +278,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
def subtract_temporals(self, internal_type, lhs, rhs):
- if internal_type == 'DateField':
+ if internal_type == "DateField":
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
@@ -263,27 +289,34 @@ class DatabaseOperations(BaseDatabaseOperations):
prefix = super().explain_query_prefix(format)
extra = {}
if format:
- extra['FORMAT'] = format
+ extra["FORMAT"] = format
if options:
- extra.update({
- name.upper(): 'true' if value else 'false'
- for name, value in options.items()
- })
+ extra.update(
+ {
+ name.upper(): "true" if value else "false"
+ for name, value in options.items()
+ }
+ )
if extra:
- prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
+ prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
return prefix
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.IGNORE:
- return 'ON CONFLICT DO NOTHING'
+ return "ON CONFLICT DO NOTHING"
if on_conflict == OnConflict.UPDATE:
- return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
- ', '.join(map(self.quote_name, unique_fields)),
- ', '.join([
- f'{field} = EXCLUDED.{field}'
- for field in map(self.quote_name, update_fields)
- ]),
+ return "ON CONFLICT(%s) DO UPDATE SET %s" % (
+ ", ".join(map(self.quote_name, unique_fields)),
+ ", ".join(
+ [
+ f"{field} = EXCLUDED.{field}"
+ for field in map(self.quote_name, update_fields)
+ ]
+ ),
)
return super().on_conflict_suffix_sql(
- fields, on_conflict, update_fields, unique_fields,
+ fields,
+ on_conflict,
+ update_fields,
+ unique_fields,
)
diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py
index f3b5baecbe..47e9a6a8f3 100644
--- a/django/db/backends/postgresql/schema.py
+++ b/django/db/backends/postgresql/schema.py
@@ -9,16 +9,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
- sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
- sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'
+ sql_set_sequence_max = (
+ "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
+ )
+ sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_create_index = (
- 'CREATE INDEX %(name)s ON %(table)s%(using)s '
- '(%(columns)s)%(include)s%(extra)s%(condition)s'
+ "CREATE INDEX %(name)s ON %(table)s%(using)s "
+ "(%(columns)s)%(include)s%(extra)s%(condition)s"
)
sql_create_index_concurrently = (
- 'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s '
- '(%(columns)s)%(include)s%(extra)s%(condition)s'
+ "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
+ "(%(columns)s)%(include)s%(extra)s%(condition)s"
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
@@ -26,21 +28,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
- 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
- '; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE'
+ "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
+ "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
- sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'
+ sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
def quote_value(self, value):
if isinstance(value, str):
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
adapted = psycopg2.extensions.adapt(value)
- if hasattr(adapted, 'encoding'):
- adapted.encoding = 'utf8'
+ if hasattr(adapted, "encoding"):
+ adapted.encoding = "utf8"
# getquoted() returns a quoted bytestring of the adapted value.
return adapted.getquoted().decode()
@@ -61,7 +63,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _field_base_data_types(self, field):
# Yield base data types for array fields.
- if field.base_field.get_internal_type() == 'ArrayField':
+ if field.base_field.get_internal_type() == "ArrayField":
yield from self._field_base_data_types(field.base_field)
else:
yield self._field_data_type(field.base_field)
@@ -80,45 +82,52 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
- if '[' in db_type:
+ if "[" in db_type:
return None
- if db_type.startswith('varchar'):
+ if db_type.startswith("varchar"):
return self._create_index_sql(
model,
fields=[field],
- suffix='_like',
- opclasses=['varchar_pattern_ops'],
+ suffix="_like",
+ opclasses=["varchar_pattern_ops"],
)
- elif db_type.startswith('text'):
+ elif db_type.startswith("text"):
return self._create_index_sql(
model,
fields=[field],
- suffix='_like',
- opclasses=['text_pattern_ops'],
+ suffix="_like",
+ opclasses=["text_pattern_ops"],
)
return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
- self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
+ self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
# Cast when data type changed.
- using_sql = ' USING %(column)s::%(type)s'
+ using_sql = " USING %(column)s::%(type)s"
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
- if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type:
+ if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
# Compare base data types for array fields.
- if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)):
+ if list(self._field_base_data_types(old_field)) != list(
+ self._field_base_data_types(new_field)
+ ):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense.
table = strip_quotes(model._meta.db_table)
- serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
+ serial_fields_map = {
+ "bigserial": "bigint",
+ "serial": "integer",
+ "smallserial": "smallint",
+ }
if new_type.lower() in serial_fields_map:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
- self.sql_alter_column_type % {
+ self.sql_alter_column_type
+ % {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
},
@@ -126,29 +135,35 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
),
[
(
- self.sql_delete_sequence % {
+ self.sql_delete_sequence
+ % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
- self.sql_create_sequence % {
+ self.sql_create_sequence
+ % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
- self.sql_alter_column % {
+ self.sql_alter_column
+ % {
"table": self.quote_name(table),
- "changes": self.sql_alter_column_default % {
+ "changes": self.sql_alter_column_default
+ % {
"column": self.quote_name(column),
- "default": "nextval('%s')" % self.quote_name(sequence_name),
- }
+ "default": "nextval('%s')"
+ % self.quote_name(sequence_name),
+ },
},
[],
),
(
- self.sql_set_sequence_max % {
+ self.sql_set_sequence_max
+ % {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
@@ -156,24 +171,31 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
[],
),
(
- self.sql_set_sequence_owner % {
- 'table': self.quote_name(table),
- 'column': self.quote_name(column),
- 'sequence': self.quote_name(sequence_name),
+ self.sql_set_sequence_owner
+ % {
+ "table": self.quote_name(table),
+ "column": self.quote_name(column),
+ "sequence": self.quote_name(sequence_name),
},
[],
),
],
)
- elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map:
+ elif (
+ old_field.db_parameters(connection=self.connection)["type"]
+ in serial_fields_map
+ ):
# Drop the sequence if migrating away from AutoField.
column = strip_quotes(new_field.column)
- sequence_name = '%s_%s_seq' % (table, column)
- fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type)
+ sequence_name = "%s_%s_seq" % (table, column)
+ fragment, _ = super()._alter_column_type_sql(
+ model, old_field, new_field, new_type
+ )
return fragment, [
(
- self.sql_delete_sequence % {
- 'sequence': self.quote_name(sequence_name),
+ self.sql_delete_sequence
+ % {
+ "sequence": self.quote_name(sequence_name),
},
[],
),
@@ -181,58 +203,114 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
else:
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
- def _alter_field(self, model, old_field, new_field, old_type, new_type,
- old_db_params, new_db_params, strict=False):
+ def _alter_field(
+ self,
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict=False,
+ ):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
if (old_field.db_index or old_field.unique) and (
- (old_type.startswith('varchar') and not new_type.startswith('varchar')) or
- (old_type.startswith('text') and not new_type.startswith('text')) or
- (old_type.startswith('citext') and not new_type.startswith('citext'))
+ (old_type.startswith("varchar") and not new_type.startswith("varchar"))
+ or (old_type.startswith("text") and not new_type.startswith("text"))
+ or (old_type.startswith("citext") and not new_type.startswith("citext"))
):
- index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
+ index_name = self._create_index_name(
+ model._meta.db_table, [old_field.column], suffix="_like"
+ )
self.execute(self._delete_index_sql(model, index_name))
super()._alter_field(
- model, old_field, new_field, old_type, new_type, old_db_params,
- new_db_params, strict,
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
- if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
- (not old_field.unique and new_field.unique)):
+ if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
+ not old_field.unique and new_field.unique
+ ):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
- index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
+ index_to_remove = self._create_index_name(
+ model._meta.db_table, [old_field.column], suffix="_like"
+ )
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
- return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
+ return IndexColumns(
+ table,
+ columns,
+ self.quote_name,
+ col_suffixes=col_suffixes,
+ opclasses=opclasses,
+ )
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
- self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)
+ self.execute(
+ index.create_sql(model, self, concurrently=concurrently), params=None
+ )
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
- sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
+ sql = (
+ self.sql_delete_index_concurrently
+ if concurrently
+ else self.sql_delete_index
+ )
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
- self, model, *, fields=None, name=None, suffix='', using='',
- db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
- condition=None, concurrently=False, include=None, expressions=None,
+ self,
+ model,
+ *,
+ fields=None,
+ name=None,
+ suffix="",
+ using="",
+ db_tablespace=None,
+ col_suffixes=(),
+ sql=None,
+ opclasses=(),
+ condition=None,
+ concurrently=False,
+ include=None,
+ expressions=None,
):
- sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
+ sql = (
+ self.sql_create_index
+ if not concurrently
+ else self.sql_create_index_concurrently
+ )
return super()._create_index_sql(
- model, fields=fields, name=name, suffix=suffix, using=using,
- db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql,
- opclasses=opclasses, condition=condition, include=include,
+ model,
+ fields=fields,
+ name=name,
+ suffix=suffix,
+ using=using,
+ db_tablespace=db_tablespace,
+ col_suffixes=col_suffixes,
+ sql=sql,
+ opclasses=opclasses,
+ condition=condition,
+ include=include,
expressions=expressions,
)
diff --git a/django/db/backends/sqlite3/_functions.py b/django/db/backends/sqlite3/_functions.py
index 3529a99dd6..86684c1907 100644
--- a/django/db/backends/sqlite3/_functions.py
+++ b/django/db/backends/sqlite3/_functions.py
@@ -7,14 +7,30 @@ import statistics
from datetime import timedelta
from hashlib import sha1, sha224, sha256, sha384, sha512
from math import (
- acos, asin, atan, atan2, ceil, cos, degrees, exp, floor, fmod, log, pi,
- radians, sin, sqrt, tan,
+ acos,
+ asin,
+ atan,
+ atan2,
+ ceil,
+ cos,
+ degrees,
+ exp,
+ floor,
+ fmod,
+ log,
+ pi,
+ radians,
+ sin,
+ sqrt,
+ tan,
)
from re import search as re_search
from django.db.backends.base.base import timezone_constructor
from django.db.backends.utils import (
- split_tzname_delta, typecast_time, typecast_timestamp,
+ split_tzname_delta,
+ typecast_time,
+ typecast_timestamp,
)
from django.utils import timezone
from django.utils.crypto import md5
@@ -26,56 +42,62 @@ def register(connection):
connection.create_function,
deterministic=True,
)
- create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
- create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
- create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
- create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
- create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
- create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
- create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
- create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
- create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
- create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
- create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
- create_deterministic_function('regexp', 2, _sqlite_regexp)
- create_deterministic_function('ACOS', 1, _sqlite_acos)
- create_deterministic_function('ASIN', 1, _sqlite_asin)
- create_deterministic_function('ATAN', 1, _sqlite_atan)
- create_deterministic_function('ATAN2', 2, _sqlite_atan2)
- create_deterministic_function('BITXOR', 2, _sqlite_bitxor)
- create_deterministic_function('CEILING', 1, _sqlite_ceiling)
- create_deterministic_function('COS', 1, _sqlite_cos)
- create_deterministic_function('COT', 1, _sqlite_cot)
- create_deterministic_function('DEGREES', 1, _sqlite_degrees)
- create_deterministic_function('EXP', 1, _sqlite_exp)
- create_deterministic_function('FLOOR', 1, _sqlite_floor)
- create_deterministic_function('LN', 1, _sqlite_ln)
- create_deterministic_function('LOG', 2, _sqlite_log)
- create_deterministic_function('LPAD', 3, _sqlite_lpad)
- create_deterministic_function('MD5', 1, _sqlite_md5)
- create_deterministic_function('MOD', 2, _sqlite_mod)
- create_deterministic_function('PI', 0, _sqlite_pi)
- create_deterministic_function('POWER', 2, _sqlite_power)
- create_deterministic_function('RADIANS', 1, _sqlite_radians)
- create_deterministic_function('REPEAT', 2, _sqlite_repeat)
- create_deterministic_function('REVERSE', 1, _sqlite_reverse)
- create_deterministic_function('RPAD', 3, _sqlite_rpad)
- create_deterministic_function('SHA1', 1, _sqlite_sha1)
- create_deterministic_function('SHA224', 1, _sqlite_sha224)
- create_deterministic_function('SHA256', 1, _sqlite_sha256)
- create_deterministic_function('SHA384', 1, _sqlite_sha384)
- create_deterministic_function('SHA512', 1, _sqlite_sha512)
- create_deterministic_function('SIGN', 1, _sqlite_sign)
- create_deterministic_function('SIN', 1, _sqlite_sin)
- create_deterministic_function('SQRT', 1, _sqlite_sqrt)
- create_deterministic_function('TAN', 1, _sqlite_tan)
+ create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract)
+ create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
+ create_deterministic_function(
+ "django_datetime_cast_date", 3, _sqlite_datetime_cast_date
+ )
+ create_deterministic_function(
+ "django_datetime_cast_time", 3, _sqlite_datetime_cast_time
+ )
+ create_deterministic_function(
+ "django_datetime_extract", 4, _sqlite_datetime_extract
+ )
+ create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc)
+ create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
+ create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
+ create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
+ create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
+ create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
+ create_deterministic_function("regexp", 2, _sqlite_regexp)
+ create_deterministic_function("ACOS", 1, _sqlite_acos)
+ create_deterministic_function("ASIN", 1, _sqlite_asin)
+ create_deterministic_function("ATAN", 1, _sqlite_atan)
+ create_deterministic_function("ATAN2", 2, _sqlite_atan2)
+ create_deterministic_function("BITXOR", 2, _sqlite_bitxor)
+ create_deterministic_function("CEILING", 1, _sqlite_ceiling)
+ create_deterministic_function("COS", 1, _sqlite_cos)
+ create_deterministic_function("COT", 1, _sqlite_cot)
+ create_deterministic_function("DEGREES", 1, _sqlite_degrees)
+ create_deterministic_function("EXP", 1, _sqlite_exp)
+ create_deterministic_function("FLOOR", 1, _sqlite_floor)
+ create_deterministic_function("LN", 1, _sqlite_ln)
+ create_deterministic_function("LOG", 2, _sqlite_log)
+ create_deterministic_function("LPAD", 3, _sqlite_lpad)
+ create_deterministic_function("MD5", 1, _sqlite_md5)
+ create_deterministic_function("MOD", 2, _sqlite_mod)
+ create_deterministic_function("PI", 0, _sqlite_pi)
+ create_deterministic_function("POWER", 2, _sqlite_power)
+ create_deterministic_function("RADIANS", 1, _sqlite_radians)
+ create_deterministic_function("REPEAT", 2, _sqlite_repeat)
+ create_deterministic_function("REVERSE", 1, _sqlite_reverse)
+ create_deterministic_function("RPAD", 3, _sqlite_rpad)
+ create_deterministic_function("SHA1", 1, _sqlite_sha1)
+ create_deterministic_function("SHA224", 1, _sqlite_sha224)
+ create_deterministic_function("SHA256", 1, _sqlite_sha256)
+ create_deterministic_function("SHA384", 1, _sqlite_sha384)
+ create_deterministic_function("SHA512", 1, _sqlite_sha512)
+ create_deterministic_function("SIGN", 1, _sqlite_sign)
+ create_deterministic_function("SIN", 1, _sqlite_sin)
+ create_deterministic_function("SQRT", 1, _sqlite_sqrt)
+ create_deterministic_function("TAN", 1, _sqlite_tan)
# Don't use the built-in RANDOM() function because it returns a value
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
- connection.create_function('RAND', 0, random.random)
- connection.create_aggregate('STDDEV_POP', 1, StdDevPop)
- connection.create_aggregate('STDDEV_SAMP', 1, StdDevSamp)
- connection.create_aggregate('VAR_POP', 1, VarPop)
- connection.create_aggregate('VAR_SAMP', 1, VarSamp)
+ connection.create_function("RAND", 0, random.random)
+ connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
+ connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp)
+ connection.create_aggregate("VAR_POP", 1, VarPop)
+ connection.create_aggregate("VAR_SAMP", 1, VarSamp)
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
@@ -90,9 +112,9 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
if tzname is not None and tzname != conn_tzname:
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
- hours, minutes = offset.split(':')
+ hours, minutes = offset.split(":")
offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
- dt += offset_delta if sign == '+' else -offset_delta
+ dt += offset_delta if sign == "+" else -offset_delta
dt = timezone.localtime(dt, timezone_constructor(tzname))
return dt
@@ -101,19 +123,19 @@ def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
- if lookup_type == 'year':
- return f'{dt.year:04d}-01-01'
- elif lookup_type == 'quarter':
+ if lookup_type == "year":
+ return f"{dt.year:04d}-01-01"
+ elif lookup_type == "quarter":
month_in_quarter = dt.month - (dt.month - 1) % 3
- return f'{dt.year:04d}-{month_in_quarter:02d}-01'
- elif lookup_type == 'month':
- return f'{dt.year:04d}-{dt.month:02d}-01'
- elif lookup_type == 'week':
+ return f"{dt.year:04d}-{month_in_quarter:02d}-01"
+ elif lookup_type == "month":
+ return f"{dt.year:04d}-{dt.month:02d}-01"
+ elif lookup_type == "week":
dt = dt - timedelta(days=dt.weekday())
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}'
- elif lookup_type == 'day':
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}'
- raise ValueError(f'Unsupported lookup type: {lookup_type!r}')
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
+ elif lookup_type == "day":
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
+ raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
@@ -127,13 +149,13 @@ def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
return None
else:
dt = dt_parsed
- if lookup_type == 'hour':
- return f'{dt.hour:02d}:00:00'
- elif lookup_type == 'minute':
- return f'{dt.hour:02d}:{dt.minute:02d}:00'
- elif lookup_type == 'second':
- return f'{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}'
- raise ValueError(f'Unsupported lookup type: {lookup_type!r}')
+ if lookup_type == "hour":
+ return f"{dt.hour:02d}:00:00"
+ elif lookup_type == "minute":
+ return f"{dt.hour:02d}:{dt.minute:02d}:00"
+ elif lookup_type == "second":
+ return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
+ raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
@@ -154,15 +176,15 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
- if lookup_type == 'week_day':
+ if lookup_type == "week_day":
return (dt.isoweekday() % 7) + 1
- elif lookup_type == 'iso_week_day':
+ elif lookup_type == "iso_week_day":
return dt.isoweekday()
- elif lookup_type == 'week':
+ elif lookup_type == "week":
return dt.isocalendar()[1]
- elif lookup_type == 'quarter':
+ elif lookup_type == "quarter":
return ceil(dt.month / 3)
- elif lookup_type == 'iso_year':
+ elif lookup_type == "iso_year":
return dt.isocalendar()[0]
else:
return getattr(dt, lookup_type)
@@ -172,25 +194,25 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
- if lookup_type == 'year':
- return f'{dt.year:04d}-01-01 00:00:00'
- elif lookup_type == 'quarter':
+ if lookup_type == "year":
+ return f"{dt.year:04d}-01-01 00:00:00"
+ elif lookup_type == "quarter":
month_in_quarter = dt.month - (dt.month - 1) % 3
- return f'{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00'
- elif lookup_type == 'month':
- return f'{dt.year:04d}-{dt.month:02d}-01 00:00:00'
- elif lookup_type == 'week':
+ return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00"
+ elif lookup_type == "month":
+ return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00"
+ elif lookup_type == "week":
dt = dt - timedelta(days=dt.weekday())
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00'
- elif lookup_type == 'day':
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00'
- elif lookup_type == 'hour':
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00'
- elif lookup_type == 'minute':
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00'
- elif lookup_type == 'second':
- return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}'
- raise ValueError(f'Unsupported lookup type: {lookup_type!r}')
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
+ elif lookup_type == "day":
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
+ elif lookup_type == "hour":
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00"
+ elif lookup_type == "minute":
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00"
+ elif lookup_type == "second":
+ return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
+ raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_time_extract(lookup_type, dt):
@@ -204,7 +226,7 @@ def _sqlite_time_extract(lookup_type, dt):
def _sqlite_prepare_dtdelta_param(conn, param):
- if conn in ['+', '-']:
+ if conn in ["+", "-"]:
if isinstance(param, int):
return timedelta(0, 0, param)
else:
@@ -227,13 +249,13 @@ def _sqlite_format_dtdelta(connector, lhs, rhs):
real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs)
except (ValueError, TypeError):
return None
- if connector == '+':
+ if connector == "+":
# typecast_timestamp() returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
out = str(real_lhs + real_rhs)
- elif connector == '-':
+ elif connector == "-":
out = str(real_lhs - real_rhs)
- elif connector == '*':
+ elif connector == "*":
out = real_lhs * real_rhs
else:
out = real_lhs / real_rhs
@@ -246,14 +268,14 @@ def _sqlite_time_diff(lhs, rhs):
left = typecast_time(lhs)
right = typecast_time(rhs)
return (
- (left.hour * 60 * 60 * 1000000) +
- (left.minute * 60 * 1000000) +
- (left.second * 1000000) +
- (left.microsecond) -
- (right.hour * 60 * 60 * 1000000) -
- (right.minute * 60 * 1000000) -
- (right.second * 1000000) -
- (right.microsecond)
+ (left.hour * 60 * 60 * 1000000)
+ + (left.minute * 60 * 1000000)
+ + (left.second * 1000000)
+ + (left.microsecond)
+ - (right.hour * 60 * 60 * 1000000)
+ - (right.minute * 60 * 1000000)
+ - (right.second * 1000000)
+ - (right.microsecond)
)
@@ -380,7 +402,7 @@ def _sqlite_pi():
def _sqlite_power(x, y):
if x is None or y is None:
return None
- return x ** y
+ return x**y
def _sqlite_radians(x):
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 4343ea180e..5bcd61eb96 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -32,13 +32,13 @@ def decoder(conv_func):
def check_sqlite_version():
if Database.sqlite_version_info < (3, 9, 0):
raise ImproperlyConfigured(
- 'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version
+ "SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
)
check_sqlite_version()
-Database.register_converter("bool", b'1'.__eq__)
+Database.register_converter("bool", b"1".__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))
@@ -47,69 +47,69 @@ Database.register_adapter(decimal.Decimal, str)
class DatabaseWrapper(BaseDatabaseWrapper):
- vendor = 'sqlite'
- display_name = 'SQLite'
+ vendor = "sqlite"
+ display_name = "SQLite"
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
data_types = {
- 'AutoField': 'integer',
- 'BigAutoField': 'integer',
- 'BinaryField': 'BLOB',
- 'BooleanField': 'bool',
- 'CharField': 'varchar(%(max_length)s)',
- 'DateField': 'date',
- 'DateTimeField': 'datetime',
- 'DecimalField': 'decimal',
- 'DurationField': 'bigint',
- 'FileField': 'varchar(%(max_length)s)',
- 'FilePathField': 'varchar(%(max_length)s)',
- 'FloatField': 'real',
- 'IntegerField': 'integer',
- 'BigIntegerField': 'bigint',
- 'IPAddressField': 'char(15)',
- 'GenericIPAddressField': 'char(39)',
- 'JSONField': 'text',
- 'OneToOneField': 'integer',
- 'PositiveBigIntegerField': 'bigint unsigned',
- 'PositiveIntegerField': 'integer unsigned',
- 'PositiveSmallIntegerField': 'smallint unsigned',
- 'SlugField': 'varchar(%(max_length)s)',
- 'SmallAutoField': 'integer',
- 'SmallIntegerField': 'smallint',
- 'TextField': 'text',
- 'TimeField': 'time',
- 'UUIDField': 'char(32)',
+ "AutoField": "integer",
+ "BigAutoField": "integer",
+ "BinaryField": "BLOB",
+ "BooleanField": "bool",
+ "CharField": "varchar(%(max_length)s)",
+ "DateField": "date",
+ "DateTimeField": "datetime",
+ "DecimalField": "decimal",
+ "DurationField": "bigint",
+ "FileField": "varchar(%(max_length)s)",
+ "FilePathField": "varchar(%(max_length)s)",
+ "FloatField": "real",
+ "IntegerField": "integer",
+ "BigIntegerField": "bigint",
+ "IPAddressField": "char(15)",
+ "GenericIPAddressField": "char(39)",
+ "JSONField": "text",
+ "OneToOneField": "integer",
+ "PositiveBigIntegerField": "bigint unsigned",
+ "PositiveIntegerField": "integer unsigned",
+ "PositiveSmallIntegerField": "smallint unsigned",
+ "SlugField": "varchar(%(max_length)s)",
+ "SmallAutoField": "integer",
+ "SmallIntegerField": "smallint",
+ "TextField": "text",
+ "TimeField": "time",
+ "UUIDField": "char(32)",
}
data_type_check_constraints = {
- 'PositiveBigIntegerField': '"%(column)s" >= 0',
- 'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
- 'PositiveIntegerField': '"%(column)s" >= 0',
- 'PositiveSmallIntegerField': '"%(column)s" >= 0',
+ "PositiveBigIntegerField": '"%(column)s" >= 0',
+ "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
+ "PositiveIntegerField": '"%(column)s" >= 0',
+ "PositiveSmallIntegerField": '"%(column)s" >= 0',
}
data_types_suffix = {
- 'AutoField': 'AUTOINCREMENT',
- 'BigAutoField': 'AUTOINCREMENT',
- 'SmallAutoField': 'AUTOINCREMENT',
+ "AutoField": "AUTOINCREMENT",
+ "BigAutoField": "AUTOINCREMENT",
+ "SmallAutoField": "AUTOINCREMENT",
}
# SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it.
# See https://www.sqlite.org/lang_expr.html for an explanation.
operators = {
- 'exact': '= %s',
- 'iexact': "LIKE %s ESCAPE '\\'",
- 'contains': "LIKE %s ESCAPE '\\'",
- 'icontains': "LIKE %s ESCAPE '\\'",
- 'regex': 'REGEXP %s',
- 'iregex': "REGEXP '(?i)' || %s",
- 'gt': '> %s',
- 'gte': '>= %s',
- 'lt': '< %s',
- 'lte': '<= %s',
- 'startswith': "LIKE %s ESCAPE '\\'",
- 'endswith': "LIKE %s ESCAPE '\\'",
- 'istartswith': "LIKE %s ESCAPE '\\'",
- 'iendswith': "LIKE %s ESCAPE '\\'",
+ "exact": "= %s",
+ "iexact": "LIKE %s ESCAPE '\\'",
+ "contains": "LIKE %s ESCAPE '\\'",
+ "icontains": "LIKE %s ESCAPE '\\'",
+ "regex": "REGEXP %s",
+ "iregex": "REGEXP '(?i)' || %s",
+ "gt": "> %s",
+ "gte": ">= %s",
+ "lt": "< %s",
+ "lte": "<= %s",
+ "startswith": "LIKE %s ESCAPE '\\'",
+ "endswith": "LIKE %s ESCAPE '\\'",
+ "istartswith": "LIKE %s ESCAPE '\\'",
+ "iendswith": "LIKE %s ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -122,12 +122,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
- 'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'",
- 'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
- 'startswith': r"LIKE {} || '%%' ESCAPE '\'",
- 'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'",
- 'endswith': r"LIKE '%%' || {} ESCAPE '\'",
- 'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'",
+ "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
+ "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
+ "startswith": r"LIKE {} || '%%' ESCAPE '\'",
+ "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
+ "endswith": r"LIKE '%%' || {} ESCAPE '\'",
+ "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
}
Database = Database
@@ -141,14 +141,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
- if not settings_dict['NAME']:
+ if not settings_dict["NAME"]:
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
- "Please supply the NAME value.")
+ "Please supply the NAME value."
+ )
kwargs = {
- 'database': settings_dict['NAME'],
- 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
- **settings_dict['OPTIONS'],
+ "database": settings_dict["NAME"],
+ "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
+ **settings_dict["OPTIONS"],
}
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
@@ -156,15 +157,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
- if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
+ if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
warnings.warn(
- 'The `check_same_thread` option was provided and set to '
- 'True. It will be overridden with False. Use the '
- '`DatabaseWrapper.allow_thread_sharing` property instead '
- 'for controlling thread shareability.',
- RuntimeWarning
+ "The `check_same_thread` option was provided and set to "
+ "True. It will be overridden with False. Use the "
+ "`DatabaseWrapper.allow_thread_sharing` property instead "
+ "for controlling thread shareability.",
+ RuntimeWarning,
)
- kwargs.update({'check_same_thread': False, 'uri': True})
+ kwargs.update({"check_same_thread": False, "uri": True})
return kwargs
@async_unsafe
@@ -172,10 +173,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn = Database.connect(**conn_params)
register_functions(conn)
- conn.execute('PRAGMA foreign_keys = ON')
+ conn.execute("PRAGMA foreign_keys = ON")
# The macOS bundled SQLite defaults legacy_alter_table ON, which
# prevents atomic table renames (feature supports_atomic_references_rename)
- conn.execute('PRAGMA legacy_alter_table = OFF')
+ conn.execute("PRAGMA legacy_alter_table = OFF")
return conn
def init_connection_state(self):
@@ -207,7 +208,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
# sqlite3's internal default is ''. It's different from None.
# See Modules/_sqlite/connection.c.
- level = ''
+ level = ""
# 'isolation_level' is a misleading API.
# SQLite always runs at the SERIALIZABLE isolation level.
with self.wrap_database_errors:
@@ -215,16 +216,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def disable_constraint_checking(self):
with self.cursor() as cursor:
- cursor.execute('PRAGMA foreign_keys = OFF')
+ cursor.execute("PRAGMA foreign_keys = OFF")
# Foreign key constraints cannot be turned off while in a multi-
# statement transaction. Fetch the current state of the pragma
# to determine if constraints are effectively disabled.
- enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0]
+ enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
return not bool(enabled)
def enable_constraint_checking(self):
with self.cursor() as cursor:
- cursor.execute('PRAGMA foreign_keys = ON')
+ cursor.execute("PRAGMA foreign_keys = ON")
def check_constraints(self, table_names=None):
"""
@@ -237,24 +238,32 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor:
if table_names is None:
- violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
+ violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
else:
violations = chain.from_iterable(
cursor.execute(
- 'PRAGMA foreign_key_check(%s)'
+ "PRAGMA foreign_key_check(%s)"
% self.ops.quote_name(table_name)
).fetchall()
for table_name in table_names
)
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
- for table_name, rowid, referenced_table_name, foreign_key_index in violations:
+ for (
+ table_name,
+ rowid,
+ referenced_table_name,
+ foreign_key_index,
+ ) in violations:
foreign_key = cursor.execute(
- 'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name)
+ "PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
).fetchall()[foreign_key_index]
column_name, referenced_column_name = foreign_key[3:5]
- primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+ primary_key_column_name = self.introspection.get_primary_key_column(
+ cursor, table_name
+ )
primary_key_value, bad_value = cursor.execute(
- 'SELECT %s, %s FROM %s WHERE rowid = %%s' % (
+ "SELECT %s, %s FROM %s WHERE rowid = %%s"
+ % (
self.ops.quote_name(primary_key_column_name),
self.ops.quote_name(column_name),
self.ops.quote_name(table_name),
@@ -264,9 +273,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
- "does not have a corresponding value in %s.%s." % (
- table_name, primary_key_value, table_name, column_name,
- bad_value, referenced_table_name, referenced_column_name
+ "does not have a corresponding value in %s.%s."
+ % (
+ table_name,
+ primary_key_value,
+ table_name,
+ column_name,
+ bad_value,
+ referenced_table_name,
+ referenced_column_name,
)
)
else:
@@ -274,11 +289,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
- primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+ primary_key_column_name = self.introspection.get_primary_key_column(
+ cursor, table_name
+ )
if not primary_key_column_name:
continue
relations = self.introspection.get_relations(cursor, table_name)
- for column_name, (referenced_column_name, referenced_table_name) in relations:
+ for column_name, (
+ referenced_column_name,
+ referenced_table_name,
+ ) in relations:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
@@ -287,18 +307,29 @@ class DatabaseWrapper(BaseDatabaseWrapper):
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
- primary_key_column_name, column_name, table_name,
- referenced_table_name, column_name, referenced_column_name,
- column_name, referenced_column_name,
+ primary_key_column_name,
+ column_name,
+ table_name,
+ referenced_table_name,
+ column_name,
+ referenced_column_name,
+ column_name,
+ referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
- "does not have a corresponding value in %s.%s." % (
- table_name, bad_row[0], table_name, column_name,
- bad_row[1], referenced_table_name, referenced_column_name,
+ "does not have a corresponding value in %s.%s."
+ % (
+ table_name,
+ bad_row[0],
+ table_name,
+ column_name,
+ bad_row[1],
+ referenced_table_name,
+ referenced_column_name,
)
)
@@ -315,10 +346,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.cursor().execute("BEGIN")
def is_in_memory_db(self):
- return self.creation.is_in_memory_db(self.settings_dict['NAME'])
+ return self.creation.is_in_memory_db(self.settings_dict["NAME"])
-FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s')
+FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
class SQLiteCursorWrapper(Database.Cursor):
@@ -327,6 +358,7 @@ class SQLiteCursorWrapper(Database.Cursor):
This fixes it -- but note that if you want to use a literal "%s" in a query,
you'll need to use "%%s".
"""
+
def execute(self, query, params=None):
if params is None:
return Database.Cursor.execute(self, query)
@@ -338,4 +370,4 @@ class SQLiteCursorWrapper(Database.Cursor):
return Database.Cursor.executemany(self, query, param_list)
def convert_query(self, query):
- return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
+ return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py
index 69b9568db3..7cee35dc81 100644
--- a/django/db/backends/sqlite3/client.py
+++ b/django/db/backends/sqlite3/client.py
@@ -2,9 +2,9 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
- executable_name = 'sqlite3'
+ executable_name = "sqlite3"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
- args = [cls.executable_name, settings_dict['NAME'], *parameters]
+ args = [cls.executable_name, settings_dict["NAME"], *parameters]
return args, None
diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py
index 4a4046c670..9d8d4a63ad 100644
--- a/django/db/backends/sqlite3/creation.py
+++ b/django/db/backends/sqlite3/creation.py
@@ -7,17 +7,16 @@ from django.db.backends.base.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
-
@staticmethod
def is_in_memory_db(database_name):
return not isinstance(database_name, Path) and (
- database_name == ':memory:' or 'mode=memory' in database_name
+ database_name == ":memory:" or "mode=memory" in database_name
)
def _get_test_db_name(self):
- test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:'
- if test_database_name == ':memory:':
- return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias
+ test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
+ if test_database_name == ":memory:":
+ return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
return test_database_name
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
@@ -28,38 +27,39 @@ class DatabaseCreation(BaseDatabaseCreation):
if not self.is_in_memory_db(test_database_name):
# Erase the old test database
if verbosity >= 1:
- self.log('Destroying old test database for alias %s...' % (
- self._get_database_display_str(verbosity, test_database_name),
- ))
+ self.log(
+ "Destroying old test database for alias %s..."
+ % (self._get_database_display_str(verbosity, test_database_name),)
+ )
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name
)
- if autoclobber or confirm == 'yes':
+ if autoclobber or confirm == "yes":
try:
os.remove(test_database_name)
except Exception as e:
- self.log('Got an error deleting the old test database: %s' % e)
+ self.log("Got an error deleting the old test database: %s" % e)
sys.exit(2)
else:
- self.log('Tests cancelled.')
+ self.log("Tests cancelled.")
sys.exit(1)
return test_database_name
def get_test_db_clone_settings(self, suffix):
orig_settings_dict = self.connection.settings_dict
- source_database_name = orig_settings_dict['NAME']
+ source_database_name = orig_settings_dict["NAME"]
if self.is_in_memory_db(source_database_name):
return orig_settings_dict
else:
- root, ext = os.path.splitext(orig_settings_dict['NAME'])
- return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)}
+ root, ext = os.path.splitext(orig_settings_dict["NAME"])
+ return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
- source_database_name = self.connection.settings_dict['NAME']
- target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
+ source_database_name = self.connection.settings_dict["NAME"]
+ target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
# Forking automatically makes a copy of an in-memory database.
if not self.is_in_memory_db(source_database_name):
# Erase the old test database
@@ -67,18 +67,23 @@ class DatabaseCreation(BaseDatabaseCreation):
if keepdb:
return
if verbosity >= 1:
- self.log('Destroying old test database for alias %s...' % (
- self._get_database_display_str(verbosity, target_database_name),
- ))
+ self.log(
+ "Destroying old test database for alias %s..."
+ % (
+ self._get_database_display_str(
+ verbosity, target_database_name
+ ),
+ )
+ )
try:
os.remove(target_database_name)
except Exception as e:
- self.log('Got an error deleting the old test database: %s' % e)
+ self.log("Got an error deleting the old test database: %s" % e)
sys.exit(2)
try:
shutil.copy(source_database_name, target_database_name)
except Exception as e:
- self.log('Got an error cloning the test database: %s' % e)
+ self.log("Got an error cloning the test database: %s" % e)
sys.exit(2)
def _destroy_test_db(self, test_database_name, verbosity):
@@ -95,7 +100,7 @@ class DatabaseCreation(BaseDatabaseCreation):
TEST NAME. See https://www.sqlite.org/inmemorydb.html
"""
test_database_name = self._get_test_db_name()
- sig = [self.connection.settings_dict['NAME']]
+ sig = [self.connection.settings_dict["NAME"]]
if self.is_in_memory_db(test_database_name):
sig.append(self.connection.alias)
else:
diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py
index 153ce8d1d1..c076f0121e 100644
--- a/django/db/backends/sqlite3/features.py
+++ b/django/db/backends/sqlite3/features.py
@@ -43,49 +43,53 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
supports_update_conflicts_with_target = supports_update_conflicts
test_collations = {
- 'ci': 'nocase',
- 'cs': 'binary',
- 'non_default': 'nocase',
+ "ci": "nocase",
+ "cs": "binary",
+ "non_default": "nocase",
}
django_test_expected_failures = {
# The django_format_dtdelta() function doesn't properly handle mixed
# Date/DateTime fields and timedeltas.
- 'expressions.tests.FTimeDeltaTests.test_mixed_comparisons1',
+ "expressions.tests.FTimeDeltaTests.test_mixed_comparisons1",
}
@cached_property
def django_test_skips(self):
skips = {
- 'SQLite stores values rounded to 15 significant digits.': {
- 'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding',
+ "SQLite stores values rounded to 15 significant digits.": {
+ "model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding",
},
- 'SQLite naively remakes the table on field alteration.': {
- 'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops',
- 'schema.tests.SchemaTests.test_unique_and_reverse_m2m',
- 'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries',
- 'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references',
+ "SQLite naively remakes the table on field alteration.": {
+ "schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
+ "schema.tests.SchemaTests.test_unique_and_reverse_m2m",
+ "schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries",
+ "schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references",
},
"SQLite doesn't support negative precision for ROUND().": {
- 'db_functions.math.test_round.RoundTests.test_null_with_negative_precision',
- 'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision',
- 'db_functions.math.test_round.RoundTests.test_float_with_negative_precision',
- 'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision',
+ "db_functions.math.test_round.RoundTests.test_null_with_negative_precision",
+ "db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision",
+ "db_functions.math.test_round.RoundTests.test_float_with_negative_precision",
+ "db_functions.math.test_round.RoundTests.test_integer_with_negative_precision",
},
}
if Database.sqlite_version_info < (3, 27):
- skips.update({
- 'Nondeterministic failure on SQLite < 3.27.': {
- 'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank',
- },
- })
+ skips.update(
+ {
+ "Nondeterministic failure on SQLite < 3.27.": {
+ "expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank",
+ },
+ }
+ )
if self.connection.is_in_memory_db():
- skips.update({
- "the sqlite backend's close() method is a no-op when using an "
- "in-memory database": {
- 'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections',
- 'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections',
- },
- })
+ skips.update(
+ {
+ "the sqlite backend's close() method is a no-op when using an "
+ "in-memory database": {
+ "servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections",
+ "servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections",
+ },
+ }
+ )
return skips
@cached_property
@@ -94,12 +98,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def introspected_field_types(self):
- return{
+ return {
**super().introspected_field_types,
- 'BigAutoField': 'AutoField',
- 'DurationField': 'BigIntegerField',
- 'GenericIPAddressField': 'CharField',
- 'SmallAutoField': 'AutoField',
+ "BigAutoField": "AutoField",
+ "DurationField": "BigIntegerField",
+ "GenericIPAddressField": "CharField",
+ "SmallAutoField": "AutoField",
}
@cached_property
@@ -112,11 +116,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
return True
- can_introspect_json_field = property(operator.attrgetter('supports_json_field'))
- has_json_object_function = property(operator.attrgetter('supports_json_field'))
+ can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
+ has_json_object_function = property(operator.attrgetter("supports_json_field"))
@cached_property
def can_return_columns_from_insert(self):
return Database.sqlite_version_info >= (3, 35)
- can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))
+ can_return_rows_from_bulk_insert = property(
+ operator.attrgetter("can_return_columns_from_insert")
+ )
diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py
index 81884a7951..f5a5e81e9d 100644
--- a/django/db/backends/sqlite3/introspection.py
+++ b/django/db/backends/sqlite3/introspection.py
@@ -3,19 +3,21 @@ from collections import namedtuple
import sqlparse
from django.db import DatabaseError
-from django.db.backends.base.introspection import (
- BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
-)
+from django.db.backends.base.introspection import BaseDatabaseIntrospection
+from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
+from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile
-FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint'))
+FieldInfo = namedtuple(
+ "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
+)
-field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')
+field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
def get_field_size(name):
- """ Extract the size number from a "varchar(11)" type name """
+ """Extract the size number from a "varchar(11)" type name"""
m = field_size_re.search(name)
return int(m[1]) if m else None
@@ -28,29 +30,29 @@ class FlexibleFieldLookupDict:
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
- 'bool': 'BooleanField',
- 'boolean': 'BooleanField',
- 'smallint': 'SmallIntegerField',
- 'smallint unsigned': 'PositiveSmallIntegerField',
- 'smallinteger': 'SmallIntegerField',
- 'int': 'IntegerField',
- 'integer': 'IntegerField',
- 'bigint': 'BigIntegerField',
- 'integer unsigned': 'PositiveIntegerField',
- 'bigint unsigned': 'PositiveBigIntegerField',
- 'decimal': 'DecimalField',
- 'real': 'FloatField',
- 'text': 'TextField',
- 'char': 'CharField',
- 'varchar': 'CharField',
- 'blob': 'BinaryField',
- 'date': 'DateField',
- 'datetime': 'DateTimeField',
- 'time': 'TimeField',
+ "bool": "BooleanField",
+ "boolean": "BooleanField",
+ "smallint": "SmallIntegerField",
+ "smallint unsigned": "PositiveSmallIntegerField",
+ "smallinteger": "SmallIntegerField",
+ "int": "IntegerField",
+ "integer": "IntegerField",
+ "bigint": "BigIntegerField",
+ "integer unsigned": "PositiveIntegerField",
+ "bigint unsigned": "PositiveBigIntegerField",
+ "decimal": "DecimalField",
+ "real": "FloatField",
+ "text": "TextField",
+ "char": "CharField",
+ "varchar": "CharField",
+ "blob": "BinaryField",
+ "date": "DateField",
+ "datetime": "DateTimeField",
+ "time": "TimeField",
}
def __getitem__(self, key):
- key = key.lower().split('(', 1)[0].strip()
+ key = key.lower().split("(", 1)[0].strip()
return self.base_data_types_reverse[key]
@@ -59,22 +61,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
- if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}:
+ if description.pk and field_type in {
+ "BigIntegerField",
+ "IntegerField",
+ "SmallIntegerField",
+ }:
# No support for BigAutoField or SmallAutoField as SQLite treats
# all integer primary keys as signed 64-bit integers.
- return 'AutoField'
+ return "AutoField"
if description.has_json_constraint:
- return 'JSONField'
+ return "JSONField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
- cursor.execute("""
+ cursor.execute(
+ """
SELECT name, type FROM sqlite_master
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
- ORDER BY name""")
+ ORDER BY name"""
+ )
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
@@ -82,37 +90,51 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a description of the table with the DB-API cursor.description
interface.
"""
- cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name))
+ cursor.execute(
+ "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
+ )
table_info = cursor.fetchall()
if not table_info:
- raise DatabaseError(f'Table {table_name} does not exist (empty pragma).')
+ raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
collations = self._get_column_collations(cursor, table_name)
json_columns = set()
if self.connection.features.can_introspect_json_field:
for line in table_info:
column = line[1]
json_constraint_sql = '%%json_valid("%s")%%' % column
- has_json_constraint = cursor.execute("""
+ has_json_constraint = cursor.execute(
+ """
SELECT sql
FROM sqlite_master
WHERE
type = 'table' AND
name = %s AND
sql LIKE %s
- """, [table_name, json_constraint_sql]).fetchone()
+ """,
+ [table_name, json_constraint_sql],
+ ).fetchone()
if has_json_constraint:
json_columns.add(column)
return [
FieldInfo(
- name, data_type, None, get_field_size(data_type), None, None,
- not notnull, default, collations.get(name), pk == 1, name in json_columns
+ name,
+ data_type,
+ None,
+ get_field_size(data_type),
+ None,
+ None,
+ not notnull,
+ default,
+ collations.get(name),
+ pk == 1,
+ name in json_columns,
)
for cid, name, data_type, notnull, default, pk in table_info
]
def get_sequences(self, cursor, table_name, table_fields=()):
pk_col = self.get_primary_key_column(cursor, table_name)
- return [{'table': table_name, 'column': pk_col}]
+ return [{"table": table_name, "column": pk_col}]
def get_relations(self, cursor, table_name):
"""
@@ -120,7 +142,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
representing all foreign keys in the given table.
"""
cursor.execute(
- 'PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name)
+ "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
)
return {
column_name: (ref_column_name, ref_table_name)
@@ -130,7 +152,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_primary_key_column(self, cursor, table_name):
"""Return the column name of the primary key for the given table."""
cursor.execute(
- 'PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)
+ "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
)
for _, name, *_, pk in cursor.fetchall():
if pk:
@@ -148,19 +170,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
check_columns = []
braces_deep = 0
for token in tokens:
- if token.match(sqlparse.tokens.Punctuation, '('):
+ if token.match(sqlparse.tokens.Punctuation, "("):
braces_deep += 1
- elif token.match(sqlparse.tokens.Punctuation, ')'):
+ elif token.match(sqlparse.tokens.Punctuation, ")"):
braces_deep -= 1
if braces_deep < 0:
# End of columns and constraints for table definition.
break
- elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
+ elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
# End of current column or constraint definition.
break
# Detect column or constraint definition by first token.
if is_constraint_definition is None:
- is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
+ is_constraint_definition = token.match(
+ sqlparse.tokens.Keyword, "CONSTRAINT"
+ )
if is_constraint_definition:
continue
if is_constraint_definition:
@@ -171,7 +195,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
constraint_name = token.value[1:-1]
# Start constraint columns parsing after UNIQUE keyword.
- if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
+ if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
unique = True
unique_braces_deep = braces_deep
elif unique:
@@ -191,10 +215,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
field_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
field_name = token.value[1:-1]
- if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
+ if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
unique_columns = [field_name]
# Start constraint columns parsing after CHECK keyword.
- if token.match(sqlparse.tokens.Keyword, 'CHECK'):
+ if token.match(sqlparse.tokens.Keyword, "CHECK"):
check = True
check_braces_deep = braces_deep
elif check:
@@ -209,22 +233,30 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
if token.value[1:-1] in columns:
check_columns.append(token.value[1:-1])
- unique_constraint = {
- 'unique': True,
- 'columns': unique_columns,
- 'primary_key': False,
- 'foreign_key': None,
- 'check': False,
- 'index': False,
- } if unique_columns else None
- check_constraint = {
- 'check': True,
- 'columns': check_columns,
- 'primary_key': False,
- 'unique': False,
- 'foreign_key': None,
- 'index': False,
- } if check_columns else None
+ unique_constraint = (
+ {
+ "unique": True,
+ "columns": unique_columns,
+ "primary_key": False,
+ "foreign_key": None,
+ "check": False,
+ "index": False,
+ }
+ if unique_columns
+ else None
+ )
+ check_constraint = (
+ {
+ "check": True,
+ "columns": check_columns,
+ "primary_key": False,
+ "unique": False,
+ "foreign_key": None,
+ "index": False,
+ }
+ if check_columns
+ else None
+ )
return constraint_name, unique_constraint, check_constraint, token
def _parse_table_constraints(self, sql, columns):
@@ -236,24 +268,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = (token for token in statement.flatten() if not token.is_whitespace)
# Go to columns and constraint definition
for token in tokens:
- if token.match(sqlparse.tokens.Punctuation, '('):
+ if token.match(sqlparse.tokens.Punctuation, "("):
break
# Parse columns and constraint definition
while True:
- constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
+ (
+ constraint_name,
+ unique,
+ check,
+ end_token,
+ ) = self._parse_column_or_constraint_definition(tokens, columns)
if unique:
if constraint_name:
constraints[constraint_name] = unique
else:
unnamed_constrains_index += 1
- constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
+ constraints[
+ "__unnamed_constraint_%s__" % unnamed_constrains_index
+ ] = unique
if check:
if constraint_name:
constraints[constraint_name] = check
else:
unnamed_constrains_index += 1
- constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
- if end_token.match(sqlparse.tokens.Punctuation, ')'):
+ constraints[
+ "__unnamed_constraint_%s__" % unnamed_constrains_index
+ ] = check
+ if end_token.match(sqlparse.tokens.Punctuation, ")"):
break
return constraints
@@ -266,19 +307,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Find inline check constraints.
try:
table_schema = cursor.execute(
- "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % (
- self.connection.ops.quote_name(table_name),
- )
+ "SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
+ % (self.connection.ops.quote_name(table_name),)
).fetchone()[0]
except TypeError:
# table_name is a view.
pass
else:
- columns = {info.name for info in self.get_table_description(cursor, table_name)}
+ columns = {
+ info.name for info in self.get_table_description(cursor, table_name)
+ }
constraints.update(self._parse_table_constraints(table_schema, columns))
# Get the index info
- cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
+ cursor.execute(
+ "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
+ )
for row in cursor.fetchall():
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
# columns. Discard last 2 columns if there.
@@ -288,7 +332,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
# There's at most one row.
- sql, = cursor.fetchone() or (None,)
+ (sql,) = cursor.fetchone() or (None,)
# Inline constraints are already detected in
# _parse_table_constraints(). The reasons to avoid fetching inline
# constraints from `PRAGMA index_list` are:
@@ -299,7 +343,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# An inline constraint
continue
# Get the index info for that index
- cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
+ cursor.execute(
+ "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
+ )
for index_rank, column_rank, column in cursor.fetchall():
if index not in constraints:
constraints[index] = {
@@ -310,14 +356,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"check": False,
"index": True,
}
- constraints[index]['columns'].append(column)
+ constraints[index]["columns"].append(column)
# Add type and column orders for indexes
- if constraints[index]['index']:
+ if constraints[index]["index"]:
# SQLite doesn't support any index type other than b-tree
- constraints[index]['type'] = Index.suffix
+ constraints[index]["type"] = Index.suffix
orders = self._get_index_columns_orders(sql)
if orders is not None:
- constraints[index]['orders'] = orders
+ constraints[index]["orders"] = orders
# Get the PK
pk_column = self.get_primary_key_column(cursor, table_name)
if pk_column:
@@ -334,44 +380,49 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"index": False,
}
relations = enumerate(self.get_relations(cursor, table_name).items())
- constraints.update({
- f'fk_{index}': {
- 'columns': [column_name],
- 'primary_key': False,
- 'unique': False,
- 'foreign_key': (ref_table_name, ref_column_name),
- 'check': False,
- 'index': False,
+ constraints.update(
+ {
+ f"fk_{index}": {
+ "columns": [column_name],
+ "primary_key": False,
+ "unique": False,
+ "foreign_key": (ref_table_name, ref_column_name),
+ "check": False,
+ "index": False,
+ }
+ for index, (column_name, (ref_column_name, ref_table_name)) in relations
}
- for index, (column_name, (ref_column_name, ref_table_name)) in relations
- })
+ )
return constraints
def _get_index_columns_orders(self, sql):
tokens = sqlparse.parse(sql)[0]
for token in tokens:
if isinstance(token, sqlparse.sql.Parenthesis):
- columns = str(token).strip('()').split(', ')
- return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns]
+ columns = str(token).strip("()").split(", ")
+ return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
return None
def _get_column_collations(self, cursor, table_name):
- row = cursor.execute("""
+ row = cursor.execute(
+ """
SELECT sql
FROM sqlite_master
WHERE type = 'table' AND name = %s
- """, [table_name]).fetchone()
+ """,
+ [table_name],
+ ).fetchone()
if not row:
return {}
sql = row[0]
- columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ')
+ columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
collations = {}
for column in columns:
tokens = column[1:].split()
column_name = tokens[0].strip('"')
for index, token in enumerate(tokens):
- if token == 'COLLATE':
+ if token == "COLLATE":
collation = tokens[index + 1]
break
else:
diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
index c1a6da4e5d..ef8b91c0f0 100644
--- a/django/db/backends/sqlite3/operations.py
+++ b/django/db/backends/sqlite3/operations.py
@@ -16,15 +16,15 @@ from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
- cast_char_field_without_max_length = 'text'
+ cast_char_field_without_max_length = "text"
cast_data_types = {
- 'DateField': 'TEXT',
- 'DateTimeField': 'TEXT',
+ "DateField": "TEXT",
+ "DateTimeField": "TEXT",
}
- explain_prefix = 'EXPLAIN QUERY PLAN'
+ explain_prefix = "EXPLAIN QUERY PLAN"
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
- jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
+ jsonfield_datatype_values = frozenset(["null", "false", "true"])
def bulk_batch_size(self, fields, objs):
"""
@@ -55,14 +55,14 @@ class DatabaseOperations(BaseDatabaseOperations):
else:
if isinstance(output_field, bad_fields):
raise NotSupportedError(
- 'You cannot use Sum, Avg, StdDev, and Variance '
- 'aggregations on date/time fields in sqlite3 '
- 'since date/time is saved as text.'
+ "You cannot use Sum, Avg, StdDev, and Variance "
+ "aggregations on date/time fields in sqlite3 "
+ "since date/time is saved as text."
)
if (
- isinstance(expression, models.Aggregate) and
- expression.distinct and
- len(expression.source_expressions) > 1
+ isinstance(expression, models.Aggregate)
+ and expression.distinct
+ and len(expression.source_expressions) > 1
):
raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
@@ -105,26 +105,32 @@ class DatabaseOperations(BaseDatabaseOperations):
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
- return 'NULL', 'NULL'
+ return "NULL", "NULL"
def datetime_cast_date_sql(self, field_name, tzname):
- return 'django_datetime_cast_date(%s, %s, %s)' % (
- field_name, *self._convert_tznames_to_sql(tzname),
+ return "django_datetime_cast_date(%s, %s, %s)" % (
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname):
- return 'django_datetime_cast_time(%s, %s, %s)' % (
- field_name, *self._convert_tznames_to_sql(tzname),
+ return "django_datetime_cast_time(%s, %s, %s)" % (
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % (
- lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
+ lookup_type.lower(),
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % (
- lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
+ lookup_type.lower(),
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name):
@@ -146,11 +152,11 @@ class DatabaseOperations(BaseDatabaseOperations):
if len(params) > BATCH_SIZE:
results = ()
for index in range(0, len(params), BATCH_SIZE):
- chunk = params[index:index + BATCH_SIZE]
+ chunk = params[index : index + BATCH_SIZE]
results += self._quote_params_for_last_executed_query(chunk)
return results
- sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
+ sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
# Bypass Django's wrappers and use the underlying sqlite3 connection
# to avoid logging this query - it would trigger infinite recursion.
cursor = self.connection.connection.cursor()
@@ -215,14 +221,20 @@ class DatabaseOperations(BaseDatabaseOperations):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
- tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
- sql = ['%s %s %s;' % (
- style.SQL_KEYWORD('DELETE'),
- style.SQL_KEYWORD('FROM'),
- style.SQL_FIELD(self.quote_name(table))
- ) for table in tables]
+ tables = set(
+ chain.from_iterable(self._references_graph(table) for table in tables)
+ )
+ sql = [
+ "%s %s %s;"
+ % (
+ style.SQL_KEYWORD("DELETE"),
+ style.SQL_KEYWORD("FROM"),
+ style.SQL_FIELD(self.quote_name(table)),
+ )
+ for table in tables
+ ]
if reset_sequences:
- sequences = [{'table': table} for table in tables]
+ sequences = [{"table": table} for table in tables]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
@@ -230,17 +242,18 @@ class DatabaseOperations(BaseDatabaseOperations):
if not sequences:
return []
return [
- '%s %s %s %s = 0 %s %s %s (%s);' % (
- style.SQL_KEYWORD('UPDATE'),
- style.SQL_TABLE(self.quote_name('sqlite_sequence')),
- style.SQL_KEYWORD('SET'),
- style.SQL_FIELD(self.quote_name('seq')),
- style.SQL_KEYWORD('WHERE'),
- style.SQL_FIELD(self.quote_name('name')),
- style.SQL_KEYWORD('IN'),
- ', '.join([
- "'%s'" % sequence_info['table'] for sequence_info in sequences
- ]),
+ "%s %s %s %s = 0 %s %s %s (%s);"
+ % (
+ style.SQL_KEYWORD("UPDATE"),
+ style.SQL_TABLE(self.quote_name("sqlite_sequence")),
+ style.SQL_KEYWORD("SET"),
+ style.SQL_FIELD(self.quote_name("seq")),
+ style.SQL_KEYWORD("WHERE"),
+ style.SQL_FIELD(self.quote_name("name")),
+ style.SQL_KEYWORD("IN"),
+ ", ".join(
+ ["'%s'" % sequence_info["table"] for sequence_info in sequences]
+ ),
),
]
@@ -249,7 +262,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# SQLite doesn't support tz-aware datetimes
@@ -257,7 +270,9 @@ class DatabaseOperations(BaseDatabaseOperations):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
- raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")
+ raise ValueError(
+ "SQLite backend does not support timezone-aware datetimes when USE_TZ is False."
+ )
return str(value)
@@ -266,7 +281,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# SQLite doesn't support tz-aware datetimes
@@ -278,17 +293,17 @@ class DatabaseOperations(BaseDatabaseOperations):
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
- if internal_type == 'DateTimeField':
+ if internal_type == "DateTimeField":
converters.append(self.convert_datetimefield_value)
- elif internal_type == 'DateField':
+ elif internal_type == "DateField":
converters.append(self.convert_datefield_value)
- elif internal_type == 'TimeField':
+ elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
- elif internal_type == 'DecimalField':
+ elif internal_type == "DecimalField":
converters.append(self.get_decimalfield_converter(expression))
- elif internal_type == 'UUIDField':
+ elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
- elif internal_type == 'BooleanField':
+ elif internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
return converters
@@ -317,15 +332,22 @@ class DatabaseOperations(BaseDatabaseOperations):
# float inaccuracy must be removed.
create_decimal = decimal.Context(prec=15).create_decimal_from_float
if isinstance(expression, Col):
- quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)
+ quantize_value = decimal.Decimal(1).scaleb(
+ -expression.output_field.decimal_places
+ )
def converter(value, expression, connection):
if value is not None:
- return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
+ return create_decimal(value).quantize(
+ quantize_value, context=expression.output_field.context
+ )
+
else:
+
def converter(value, expression, connection):
if value is not None:
return create_decimal(value)
+
return converter
def convert_uuidfield_value(self, value, expression, connection):
@@ -337,26 +359,26 @@ class DatabaseOperations(BaseDatabaseOperations):
return bool(value) if value in (1, 0) else value
def bulk_insert_sql(self, fields, placeholder_rows):
- placeholder_rows_sql = (', '.join(row) for row in placeholder_rows)
- values_sql = ', '.join(f'({sql})' for sql in placeholder_rows_sql)
- return f'VALUES {values_sql}'
+ placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
+ values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
+ return f"VALUES {values_sql}"
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
- if connector == '^':
- return 'POWER(%s)' % ','.join(sub_expressions)
- elif connector == '#':
- return 'BITXOR(%s)' % ','.join(sub_expressions)
+ if connector == "^":
+ return "POWER(%s)" % ",".join(sub_expressions)
+ elif connector == "#":
+ return "BITXOR(%s)" % ",".join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
- if connector not in ['+', '-', '*', '/']:
- raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
+ if connector not in ["+", "-", "*", "/"]:
+ raise DatabaseError("Invalid connector for timedelta: %s." % connector)
fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3:
- raise ValueError('Too many params for timedelta operations.')
- return "django_format_dtdelta(%s)" % ', '.join(fn_params)
+ raise ValueError("Too many params for timedelta operations.")
+ return "django_format_dtdelta(%s)" % ", ".join(fn_params)
def integer_field_range(self, internal_type):
# SQLite doesn't enforce any integer constraints
@@ -366,39 +388,46 @@ class DatabaseOperations(BaseDatabaseOperations):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
- if internal_type == 'TimeField':
- return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
- return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
+ if internal_type == "TimeField":
+ return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
+ return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
- return 'INSERT OR IGNORE INTO'
+ return "INSERT OR IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
- return '', ()
+ return "", ()
columns = [
- '%s.%s' % (
+ "%s.%s"
+ % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
- ) for field in fields
+ )
+ for field in fields
]
- return 'RETURNING %s' % ', '.join(columns), ()
+ return "RETURNING %s" % ", ".join(columns), ()
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if (
- on_conflict == OnConflict.UPDATE and
- self.connection.features.supports_update_conflicts_with_target
+ on_conflict == OnConflict.UPDATE
+ and self.connection.features.supports_update_conflicts_with_target
):
- return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
- ', '.join(map(self.quote_name, unique_fields)),
- ', '.join([
- f'{field} = EXCLUDED.{field}'
- for field in map(self.quote_name, update_fields)
- ]),
+ return "ON CONFLICT(%s) DO UPDATE SET %s" % (
+ ", ".join(map(self.quote_name, unique_fields)),
+ ", ".join(
+ [
+ f"{field} = EXCLUDED.{field}"
+ for field in map(self.quote_name, update_fields)
+ ]
+ ),
)
return super().on_conflict_suffix_sql(
- fields, on_conflict, update_fields, unique_fields,
+ fields,
+ on_conflict,
+ update_fields,
+ unique_fields,
)
diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py
index 3ff0a3f7db..c9af8088e5 100644
--- a/django/db/backends/sqlite3/schema.py
+++ b/django/db/backends/sqlite3/schema.py
@@ -14,7 +14,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
- sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
+ sql_create_inline_fk = (
+ "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
+ )
sql_create_column_inline_fk = sql_create_inline_fk
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s"
@@ -24,11 +26,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# disabled. Enforce it here for the duration of the schema edition.
if not self.connection.disable_constraint_checking():
raise NotSupportedError(
- 'SQLite schema editor cannot be used while foreign key '
- 'constraint checks are enabled. Make sure to disable them '
- 'before entering a transaction.atomic() context because '
- 'SQLite does not support disabling them in the middle of '
- 'a multi-statement transaction.'
+ "SQLite schema editor cannot be used while foreign key "
+ "constraint checks are enabled. Make sure to disable them "
+ "before entering a transaction.atomic() context because "
+ "SQLite does not support disabling them in the middle of "
+ "a multi-statement transaction."
)
return super().__enter__()
@@ -43,6 +45,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# security hardening).
try:
import sqlite3
+
value = sqlite3.adapt(value)
except ImportError:
pass
@@ -54,7 +57,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
elif isinstance(value, (Decimal, float, int)):
return str(value)
elif isinstance(value, str):
- return "'%s'" % value.replace("\'", "\'\'")
+ return "'%s'" % value.replace("'", "''")
elif value is None:
return "NULL"
elif isinstance(value, (bytes, bytearray, memoryview)):
@@ -63,12 +66,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# character.
return "X'%s'" % value.hex()
else:
- raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))
+ raise ValueError(
+ "Cannot quote parameter value %r of type %s" % (value, type(value))
+ )
def prepare_default(self, value):
return self.quote_value(value)
- def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
+ def _is_referenced_by_fk_constraint(
+ self, table_name, column_name=None, ignore_self=False
+ ):
"""
Return whether or not the provided table name is referenced by another
one. If `column_name` is specified, only references pointing to that
@@ -79,22 +86,33 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
for other_table in self.connection.introspection.get_table_list(cursor):
if ignore_self and other_table.name == table_name:
continue
- relations = self.connection.introspection.get_relations(cursor, other_table.name)
+ relations = self.connection.introspection.get_relations(
+ cursor, other_table.name
+ )
for constraint_column, constraint_table in relations.values():
- if (constraint_table == table_name and
- (column_name is None or constraint_column == column_name)):
+ if constraint_table == table_name and (
+ column_name is None or constraint_column == column_name
+ ):
return True
return False
- def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
- if (not self.connection.features.supports_atomic_references_rename and
- disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)):
+ def alter_db_table(
+ self, model, old_db_table, new_db_table, disable_constraints=True
+ ):
+ if (
+ not self.connection.features.supports_atomic_references_rename
+ and disable_constraints
+ and self._is_referenced_by_fk_constraint(old_db_table)
+ ):
if self.connection.in_atomic_block:
- raise NotSupportedError((
- 'Renaming the %r table while in a transaction is not '
- 'supported on SQLite < 3.26 because it would break referential '
- 'integrity. Try adding `atomic = False` to the Migration class.'
- ) % old_db_table)
+ raise NotSupportedError(
+ (
+ "Renaming the %r table while in a transaction is not "
+ "supported on SQLite < 3.26 because it would break referential "
+ "integrity. Try adding `atomic = False` to the Migration class."
+ )
+ % old_db_table
+ )
self.connection.enable_constraint_checking()
super().alter_db_table(model, old_db_table, new_db_table)
self.connection.disable_constraint_checking()
@@ -107,42 +125,56 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
old_field_name = old_field.name
table_name = model._meta.db_table
_, old_column_name = old_field.get_attname_column()
- if (new_field.name != old_field_name and
- not self.connection.features.supports_atomic_references_rename and
- self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
+ if (
+ new_field.name != old_field_name
+ and not self.connection.features.supports_atomic_references_rename
+ and self._is_referenced_by_fk_constraint(
+ table_name, old_column_name, ignore_self=True
+ )
+ ):
if self.connection.in_atomic_block:
- raise NotSupportedError((
- 'Renaming the %r.%r column while in a transaction is not '
- 'supported on SQLite < 3.26 because it would break referential '
- 'integrity. Try adding `atomic = False` to the Migration class.'
- ) % (model._meta.db_table, old_field_name))
+ raise NotSupportedError(
+ (
+ "Renaming the %r.%r column while in a transaction is not "
+ "supported on SQLite < 3.26 because it would break referential "
+ "integrity. Try adding `atomic = False` to the Migration class."
+ )
+ % (model._meta.db_table, old_field_name)
+ )
with atomic(self.connection.alias):
super().alter_field(model, old_field, new_field, strict=strict)
# Follow SQLite's documented procedure for performing changes
# that don't affect the on-disk content.
# https://sqlite.org/lang_altertable.html#otheralter
with self.connection.cursor() as cursor:
- schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
- cursor.execute('PRAGMA writable_schema = 1')
+ schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
+ 0
+ ]
+ cursor.execute("PRAGMA writable_schema = 1")
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
new_column_name = new_field.get_attname_column()[1]
search = references_template % old_column_name
replacement = references_template % new_column_name
- cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
- cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
- cursor.execute('PRAGMA writable_schema = 0')
+ cursor.execute(
+ "UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
+ (search, replacement),
+ )
+ cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
+ cursor.execute("PRAGMA writable_schema = 0")
# The integrity check will raise an exception and rollback
# the transaction if the sqlite_master updates corrupt the
# database.
- cursor.execute('PRAGMA integrity_check')
+ cursor.execute("PRAGMA integrity_check")
# Perform a VACUUM to refresh the database representation from
# the sqlite_master table.
with self.connection.cursor() as cursor:
- cursor.execute('VACUUM')
+ cursor.execute("VACUUM")
else:
super().alter_field(model, old_field, new_field, strict=strict)
- def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
+ def _remake_table(
+ self, model, create_field=None, delete_field=None, alter_field=None
+ ):
"""
Shortcut to transform a model from old_model into new_model
@@ -163,6 +195,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# to an altered field.
def is_self_referential(f):
return f.is_relation and f.remote_field.model is model
+
# Work out the new fields dict / mapping
body = {
f.name: f.clone() if is_self_referential(f) else f
@@ -170,14 +203,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
}
# Since mapping might mix column names and default values,
# its values must be already quoted.
- mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
+ mapping = {
+ f.column: self.quote_name(f.column)
+ for f in model._meta.local_concrete_fields
+ }
# This maps field names (not columns) for things like unique_together
rename_mapping = {}
# If any of the new or altered fields is introducing a new PK,
# remove the old one
restore_pk_field = None
- if getattr(create_field, 'primary_key', False) or (
- alter_field and getattr(alter_field[1], 'primary_key', False)):
+ if getattr(create_field, "primary_key", False) or (
+ alter_field and getattr(alter_field[1], "primary_key", False)
+ ):
for name, field in list(body.items()):
if field.primary_key:
field.primary_key = False
@@ -201,8 +238,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
body[new_field.name] = new_field
if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
- 'col': self.quote_name(old_field.column),
- 'default': self.prepare_default(self.effective_default(new_field)),
+ "col": self.quote_name(old_field.column),
+ "default": self.prepare_default(self.effective_default(new_field)),
}
mapping[new_field.column] = case_sql
else:
@@ -213,7 +250,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
del body[delete_field.name]
del mapping[delete_field.column]
# Remove any implicit M2M tables
- if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
+ if (
+ delete_field.many_to_many
+ and delete_field.remote_field.through._meta.auto_created
+ ):
return self.delete_model(delete_field.remote_field.through)
# Work inside a new app registry
apps = Apps()
@@ -235,8 +275,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
indexes = model._meta.indexes
if delete_field:
indexes = [
- index for index in indexes
- if delete_field.name not in index.fields
+ index for index in indexes if delete_field.name not in index.fields
]
constraints = list(model._meta.constraints)
@@ -252,52 +291,57 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# This wouldn't be required if the schema editor was operating on model
# states instead of rendered models.
meta_contents = {
- 'app_label': model._meta.app_label,
- 'db_table': model._meta.db_table,
- 'unique_together': unique_together,
- 'index_together': index_together,
- 'indexes': indexes,
- 'constraints': constraints,
- 'apps': apps,
+ "app_label": model._meta.app_label,
+ "db_table": model._meta.db_table,
+ "unique_together": unique_together,
+ "index_together": index_together,
+ "indexes": indexes,
+ "constraints": constraints,
+ "apps": apps,
}
meta = type("Meta", (), meta_contents)
- body_copy['Meta'] = meta
- body_copy['__module__'] = model.__module__
+ body_copy["Meta"] = meta
+ body_copy["__module__"] = model.__module__
type(model._meta.object_name, model.__bases__, body_copy)
# Construct a model with a renamed table name.
body_copy = copy.deepcopy(body)
meta_contents = {
- 'app_label': model._meta.app_label,
- 'db_table': 'new__%s' % strip_quotes(model._meta.db_table),
- 'unique_together': unique_together,
- 'index_together': index_together,
- 'indexes': indexes,
- 'constraints': constraints,
- 'apps': apps,
+ "app_label": model._meta.app_label,
+ "db_table": "new__%s" % strip_quotes(model._meta.db_table),
+ "unique_together": unique_together,
+ "index_together": index_together,
+ "indexes": indexes,
+ "constraints": constraints,
+ "apps": apps,
}
meta = type("Meta", (), meta_contents)
- body_copy['Meta'] = meta
- body_copy['__module__'] = model.__module__
- new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy)
+ body_copy["Meta"] = meta
+ body_copy["__module__"] = model.__module__
+ new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
# Create a new table with the updated schema.
self.create_model(new_model)
# Copy data from the old table into the new table
- self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
- self.quote_name(new_model._meta.db_table),
- ', '.join(self.quote_name(x) for x in mapping),
- ', '.join(mapping.values()),
- self.quote_name(model._meta.db_table),
- ))
+ self.execute(
+ "INSERT INTO %s (%s) SELECT %s FROM %s"
+ % (
+ self.quote_name(new_model._meta.db_table),
+ ", ".join(self.quote_name(x) for x in mapping),
+ ", ".join(mapping.values()),
+ self.quote_name(model._meta.db_table),
+ )
+ )
# Delete the old table to make way for the new
self.delete_model(model, handle_autom2m=False)
# Rename the new table to take way for the old
self.alter_db_table(
- new_model, new_model._meta.db_table, model._meta.db_table,
+ new_model,
+ new_model._meta.db_table,
+ model._meta.db_table,
disable_constraints=False,
)
@@ -314,12 +358,17 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
super().delete_model(model)
else:
# Delete the table (and only that)
- self.execute(self.sql_delete_table % {
- "table": self.quote_name(model._meta.db_table),
- })
+ self.execute(
+ self.sql_delete_table
+ % {
+ "table": self.quote_name(model._meta.db_table),
+ }
+ )
# Remove all deferred statements referencing the deleted table.
for sql in list(self.deferred_sql):
- if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
+ if isinstance(sql, Statement) and sql.references_table(
+ model._meta.db_table
+ ):
self.deferred_sql.remove(sql)
def add_field(self, model, field):
@@ -327,11 +376,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if (
# Primary keys and unique fields are not supported in ALTER TABLE
# ADD COLUMN.
- field.primary_key or field.unique or
+ field.primary_key
+ or field.unique
+ or
# Fields with default values cannot by handled by ALTER TABLE ADD
# COLUMN statement because DROP DEFAULT is not supported in
# ALTER TABLE.
- not field.null or self.effective_default(field) is not None
+ not field.null
+ or self.effective_default(field) is not None
):
self._remake_table(model, create_field=field)
else:
@@ -351,21 +403,40 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# For everything else, remake.
else:
# It might not actually have a column behind it
- if field.db_parameters(connection=self.connection)['type'] is None:
+ if field.db_parameters(connection=self.connection)["type"] is None:
return
self._remake_table(model, delete_field=field)
- def _alter_field(self, model, old_field, new_field, old_type, new_type,
- old_db_params, new_db_params, strict=False):
+ def _alter_field(
+ self,
+ model,
+ old_field,
+ new_field,
+ old_type,
+ new_type,
+ old_db_params,
+ new_db_params,
+ strict=False,
+ ):
"""Perform a "physical" (non-ManyToMany) field update."""
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
# changed and there aren't any constraints.
- if (self.connection.features.can_alter_table_rename_column and
- old_field.column != new_field.column and
- self.column_sql(model, old_field) == self.column_sql(model, new_field) and
- not (old_field.remote_field and old_field.db_constraint or
- new_field.remote_field and new_field.db_constraint)):
- return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
+ if (
+ self.connection.features.can_alter_table_rename_column
+ and old_field.column != new_field.column
+ and self.column_sql(model, old_field) == self.column_sql(model, new_field)
+ and not (
+ old_field.remote_field
+ and old_field.db_constraint
+ or new_field.remote_field
+ and new_field.db_constraint
+ )
+ ):
+ return self.execute(
+ self._rename_field_sql(
+ model._meta.db_table, old_field, new_field, new_type
+ )
+ )
# Alter by remaking table
self._remake_table(model, alter_field=(old_field, new_field))
# Rebuild tables with FKs pointing to this field.
@@ -393,15 +464,22 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""
- if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
+ if (
+ old_field.remote_field.through._meta.db_table
+ == new_field.remote_field.through._meta.db_table
+ ):
# The field name didn't change, but some options did; we have to propagate this altering.
self._remake_table(
old_field.remote_field.through,
alter_field=(
# We need the field that points to the target model, so we can tell alter_field to change it -
# this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
- old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
- new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
+ old_field.remote_field.through._meta.get_field(
+ old_field.m2m_reverse_field_name()
+ ),
+ new_field.remote_field.through._meta.get_field(
+ new_field.m2m_reverse_field_name()
+ ),
),
)
return
@@ -409,29 +487,36 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Make a new through table
self.create_model(new_field.remote_field.through)
# Copy the data across
- self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
- self.quote_name(new_field.remote_field.through._meta.db_table),
- ', '.join([
- "id",
- new_field.m2m_column_name(),
- new_field.m2m_reverse_name(),
- ]),
- ', '.join([
- "id",
- old_field.m2m_column_name(),
- old_field.m2m_reverse_name(),
- ]),
- self.quote_name(old_field.remote_field.through._meta.db_table),
- ))
+ self.execute(
+ "INSERT INTO %s (%s) SELECT %s FROM %s"
+ % (
+ self.quote_name(new_field.remote_field.through._meta.db_table),
+ ", ".join(
+ [
+ "id",
+ new_field.m2m_column_name(),
+ new_field.m2m_reverse_name(),
+ ]
+ ),
+ ", ".join(
+ [
+ "id",
+ old_field.m2m_column_name(),
+ old_field.m2m_reverse_name(),
+ ]
+ ),
+ self.quote_name(old_field.remote_field.through._meta.db_table),
+ )
+ )
# Delete the old through table
self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
- constraint.condition or
- constraint.contains_expressions or
- constraint.include or
- constraint.deferrable
+ constraint.condition
+ or constraint.contains_expressions
+ or constraint.include
+ or constraint.deferrable
):
super().add_constraint(model, constraint)
else:
@@ -439,14 +524,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
- constraint.condition or
- constraint.contains_expressions or
- constraint.include or
- constraint.deferrable
+ constraint.condition
+ or constraint.contains_expressions
+ or constraint.include
+ or constraint.deferrable
):
super().remove_constraint(model, constraint)
else:
self._remake_table(model)
def _collate_sql(self, collation):
- return 'COLLATE ' + collation
+ return "COLLATE " + collation
diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py
index b7318bae62..d505cd7904 100644
--- a/django/db/backends/utils.py
+++ b/django/db/backends/utils.py
@@ -9,7 +9,7 @@ from django.db import NotSupportedError
from django.utils.crypto import md5
from django.utils.dateparse import parse_time
-logger = logging.getLogger('django.db.backends')
+logger = logging.getLogger("django.db.backends")
class CursorWrapper:
@@ -17,7 +17,7 @@ class CursorWrapper:
self.cursor = cursor
self.db = db
- WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
+ WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
def __getattr__(self, attr):
cursor_attr = getattr(self.cursor, attr)
@@ -50,8 +50,8 @@ class CursorWrapper:
# database driver may support them (e.g. cx_Oracle).
if kparams is not None and not self.db.features.supports_callproc_kwargs:
raise NotSupportedError(
- 'Keyword parameters for callproc are not supported on this '
- 'database backend.'
+ "Keyword parameters for callproc are not supported on this "
+ "database backend."
)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
@@ -64,13 +64,17 @@ class CursorWrapper:
return self.cursor.callproc(procname, params, kparams)
def execute(self, sql, params=None):
- return self._execute_with_wrappers(sql, params, many=False, executor=self._execute)
+ return self._execute_with_wrappers(
+ sql, params, many=False, executor=self._execute
+ )
def executemany(self, sql, param_list):
- return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany)
+ return self._execute_with_wrappers(
+ sql, param_list, many=True, executor=self._executemany
+ )
def _execute_with_wrappers(self, sql, params, many, executor):
- context = {'connection': self.db, 'cursor': self}
+ context = {"connection": self.db, "cursor": self}
for wrapper in reversed(self.db.execute_wrappers):
executor = functools.partial(wrapper, executor)
return executor(sql, params, many, context)
@@ -103,7 +107,9 @@ class CursorDebugWrapper(CursorWrapper):
return super().executemany(sql, param_list)
@contextmanager
- def debug_sql(self, sql=None, params=None, use_last_executed_query=False, many=False):
+ def debug_sql(
+ self, sql=None, params=None, use_last_executed_query=False, many=False
+ ):
start = time.monotonic()
try:
yield
@@ -113,21 +119,28 @@ class CursorDebugWrapper(CursorWrapper):
if use_last_executed_query:
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
try:
- times = len(params) if many else ''
+ times = len(params) if many else ""
except TypeError:
# params could be an iterator.
- times = '?'
- self.db.queries_log.append({
- 'sql': '%s times: %s' % (times, sql) if many else sql,
- 'time': '%.3f' % duration,
- })
+ times = "?"
+ self.db.queries_log.append(
+ {
+ "sql": "%s times: %s" % (times, sql) if many else sql,
+ "time": "%.3f" % duration,
+ }
+ )
logger.debug(
- '(%.3f) %s; args=%s; alias=%s',
+ "(%.3f) %s; args=%s; alias=%s",
duration,
sql,
params,
self.db.alias,
- extra={'duration': duration, 'sql': sql, 'params': params, 'alias': self.db.alias},
+ extra={
+ "duration": duration,
+ "sql": sql,
+ "params": params,
+ "alias": self.db.alias,
+ },
)
@@ -135,7 +148,7 @@ def split_tzname_delta(tzname):
"""
Split a time zone name into a 3-tuple of (name, sign, offset).
"""
- for sign in ['+', '-']:
+ for sign in ["+", "-"]:
if sign in tzname:
name, offset = tzname.rsplit(sign, 1)
if offset and parse_time(offset):
@@ -147,19 +160,24 @@ def split_tzname_delta(tzname):
# Converters from database (string) to Python #
###############################################
+
def typecast_date(s):
- return datetime.date(*map(int, s.split('-'))) if s else None # return None if s is null
+ return (
+ datetime.date(*map(int, s.split("-"))) if s else None
+ ) # return None if s is null
def typecast_time(s): # does NOT store time zone information
if not s:
return None
- hour, minutes, seconds = s.split(':')
- if '.' in seconds: # check whether seconds have a fractional part
- seconds, microseconds = seconds.split('.')
+ hour, minutes, seconds = s.split(":")
+ if "." in seconds: # check whether seconds have a fractional part
+ seconds, microseconds = seconds.split(".")
else:
- microseconds = '0'
- return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6]))
+ microseconds = "0"
+ return datetime.time(
+ int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6])
+ )
def typecast_timestamp(s): # does NOT store time zone information
@@ -167,25 +185,29 @@ def typecast_timestamp(s): # does NOT store time zone information
# "2005-07-29 09:56:00-05"
if not s:
return None
- if ' ' not in s:
+ if " " not in s:
return typecast_date(s)
d, t = s.split()
# Remove timezone information.
- if '-' in t:
- t, _ = t.split('-', 1)
- elif '+' in t:
- t, _ = t.split('+', 1)
- dates = d.split('-')
- times = t.split(':')
+ if "-" in t:
+ t, _ = t.split("-", 1)
+ elif "+" in t:
+ t, _ = t.split("+", 1)
+ dates = d.split("-")
+ times = t.split(":")
seconds = times[2]
- if '.' in seconds: # check whether seconds have a fractional part
- seconds, microseconds = seconds.split('.')
+ if "." in seconds: # check whether seconds have a fractional part
+ seconds, microseconds = seconds.split(".")
else:
- microseconds = '0'
+ microseconds = "0"
return datetime.datetime(
- int(dates[0]), int(dates[1]), int(dates[2]),
- int(times[0]), int(times[1]), int(seconds),
- int((microseconds + '000000')[:6])
+ int(dates[0]),
+ int(dates[1]),
+ int(dates[2]),
+ int(times[0]),
+ int(times[1]),
+ int(seconds),
+ int((microseconds + "000000")[:6]),
)
@@ -193,6 +215,7 @@ def typecast_timestamp(s): # does NOT store time zone information
# Converters from Python to database (string) #
###############################################
+
def split_identifier(identifier):
"""
Split an SQL identifier into a two element tuple of (namespace, name).
@@ -203,7 +226,7 @@ def split_identifier(identifier):
try:
namespace, name = identifier.split('"."')
except ValueError:
- namespace, name = '', identifier
+ namespace, name = "", identifier
return namespace.strip('"'), name.strip('"')
@@ -221,7 +244,11 @@ def truncate_name(identifier, length=None, hash_len=4):
return identifier
digest = names_digest(name, length=hash_len)
- return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest)
+ return "%s%s%s" % (
+ '%s"."' % namespace if namespace else "",
+ name[: length - hash_len],
+ digest,
+ )
def names_digest(*args, length):
@@ -246,7 +273,9 @@ def format_number(value, max_digits, decimal_places):
if max_digits is not None:
context.prec = max_digits
if decimal_places is not None:
- value = value.quantize(decimal.Decimal(1).scaleb(-decimal_places), context=context)
+ value = value.quantize(
+ decimal.Decimal(1).scaleb(-decimal_places), context=context
+ )
else:
context.traps[decimal.Rounded] = 1
value = context.create_decimal(value)
diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py
index f1238a3504..f8140f1845 100644
--- a/django/db/migrations/autodetector.py
+++ b/django/db/migrations/autodetector.py
@@ -10,7 +10,9 @@ from django.db.migrations.operations.models import AlterModelOptions
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.questioner import MigrationQuestioner
from django.db.migrations.utils import (
- COMPILED_REGEX_TYPE, RegexObject, resolve_relation,
+ COMPILED_REGEX_TYPE,
+ RegexObject,
+ resolve_relation,
)
from django.utils.topological_sort import stable_topological_sort
@@ -57,19 +59,20 @@ class MigrationAutodetector:
elif isinstance(obj, tuple):
return tuple(self.deep_deconstruct(value) for value in obj)
elif isinstance(obj, dict):
- return {
- key: self.deep_deconstruct(value)
- for key, value in obj.items()
- }
+ return {key: self.deep_deconstruct(value) for key, value in obj.items()}
elif isinstance(obj, functools.partial):
- return (obj.func, self.deep_deconstruct(obj.args), self.deep_deconstruct(obj.keywords))
+ return (
+ obj.func,
+ self.deep_deconstruct(obj.args),
+ self.deep_deconstruct(obj.keywords),
+ )
elif isinstance(obj, COMPILED_REGEX_TYPE):
return RegexObject(obj)
elif isinstance(obj, type):
# If this is a type that implements 'deconstruct' as an instance method,
# avoid treating this as being deconstructible itself - see #22951
return obj
- elif hasattr(obj, 'deconstruct'):
+ elif hasattr(obj, "deconstruct"):
deconstructed = obj.deconstruct()
if isinstance(obj, models.Field):
# we have a field which also returns a name
@@ -78,10 +81,7 @@ class MigrationAutodetector:
return (
path,
[self.deep_deconstruct(value) for value in args],
- {
- key: self.deep_deconstruct(value)
- for key, value in kwargs.items()
- },
+ {key: self.deep_deconstruct(value) for key, value in kwargs.items()},
)
else:
return obj
@@ -96,7 +96,7 @@ class MigrationAutodetector:
for name, field in sorted(fields.items()):
deconstruction = self.deep_deconstruct(field)
if field.remote_field and field.remote_field.model:
- deconstruction[2].pop('to', None)
+ deconstruction[2].pop("to", None)
fields_def.append(deconstruction)
return fields_def
@@ -132,22 +132,21 @@ class MigrationAutodetector:
self.new_proxy_keys = set()
self.new_unmanaged_keys = set()
for (app_label, model_name), model_state in self.from_state.models.items():
- if not model_state.options.get('managed', True):
+ if not model_state.options.get("managed", True):
self.old_unmanaged_keys.add((app_label, model_name))
elif app_label not in self.from_state.real_apps:
- if model_state.options.get('proxy'):
+ if model_state.options.get("proxy"):
self.old_proxy_keys.add((app_label, model_name))
else:
self.old_model_keys.add((app_label, model_name))
for (app_label, model_name), model_state in self.to_state.models.items():
- if not model_state.options.get('managed', True):
+ if not model_state.options.get("managed", True):
self.new_unmanaged_keys.add((app_label, model_name))
- elif (
- app_label not in self.from_state.real_apps or
- (convert_apps and app_label in convert_apps)
+ elif app_label not in self.from_state.real_apps or (
+ convert_apps and app_label in convert_apps
):
- if model_state.options.get('proxy'):
+ if model_state.options.get("proxy"):
self.new_proxy_keys.add((app_label, model_name))
else:
self.new_model_keys.add((app_label, model_name))
@@ -214,8 +213,7 @@ class MigrationAutodetector:
(app_label, model_name, field_name)
for app_label, model_name in self.kept_model_keys
for field_name in self.from_state.models[
- app_label,
- self.renamed_models.get((app_label, model_name), model_name)
+ app_label, self.renamed_models.get((app_label, model_name), model_name)
].fields
}
self.new_field_keys = {
@@ -227,12 +225,22 @@ class MigrationAutodetector:
def _generate_through_model_map(self):
"""Through model map generation."""
for app_label, model_name in sorted(self.old_model_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
for field_name, field in old_model_state.fields.items():
- if hasattr(field, 'remote_field') and getattr(field.remote_field, 'through', None):
- through_key = resolve_relation(field.remote_field.through, app_label, model_name)
- self.through_users[through_key] = (app_label, old_model_name, field_name)
+ if hasattr(field, "remote_field") and getattr(
+ field.remote_field, "through", None
+ ):
+ through_key = resolve_relation(
+ field.remote_field.through, app_label, model_name
+ )
+ self.through_users[through_key] = (
+ app_label,
+ old_model_name,
+ field_name,
+ )
@staticmethod
def _resolve_dependency(dependency):
@@ -240,9 +248,11 @@ class MigrationAutodetector:
Return the resolved dependency and a boolean denoting whether or not
it was swappable.
"""
- if dependency[0] != '__setting__':
+ if dependency[0] != "__setting__":
return dependency, False
- resolved_app_label, resolved_object_name = getattr(settings, dependency[1]).split('.')
+ resolved_app_label, resolved_object_name = getattr(
+ settings, dependency[1]
+ ).split(".")
return (resolved_app_label, resolved_object_name.lower()) + dependency[2:], True
def _build_migration_list(self, graph=None):
@@ -282,7 +292,9 @@ class MigrationAutodetector:
if dep[0] != app_label:
# External app dependency. See if it's not yet
# satisfied.
- for other_operation in self.generated_operations.get(dep[0], []):
+ for other_operation in self.generated_operations.get(
+ dep[0], []
+ ):
if self.check_dependency(other_operation, dep):
deps_satisfied = False
break
@@ -290,9 +302,13 @@ class MigrationAutodetector:
break
else:
if is_swappable_dep:
- operation_dependencies.add((original_dep[0], original_dep[1]))
+ operation_dependencies.add(
+ (original_dep[0], original_dep[1])
+ )
elif dep[0] in self.migrations:
- operation_dependencies.add((dep[0], self.migrations[dep[0]][-1].name))
+ operation_dependencies.add(
+ (dep[0], self.migrations[dep[0]][-1].name)
+ )
else:
# If we can't find the other app, we add a first/last dependency,
# but only if we've already been through once and checked everything
@@ -301,9 +317,13 @@ class MigrationAutodetector:
# as we don't know which migration contains the target field.
# If it's not yet migrated or has no migrations, we use __first__
if graph and graph.leaf_nodes(dep[0]):
- operation_dependencies.add(graph.leaf_nodes(dep[0])[0])
+ operation_dependencies.add(
+ graph.leaf_nodes(dep[0])[0]
+ )
else:
- operation_dependencies.add((dep[0], "__first__"))
+ operation_dependencies.add(
+ (dep[0], "__first__")
+ )
else:
deps_satisfied = False
if deps_satisfied:
@@ -315,21 +335,33 @@ class MigrationAutodetector:
# Make a migration! Well, only if there's stuff to put in it
if dependencies or chopped:
if not self.generated_operations[app_label] or chop_mode:
- subclass = type("Migration", (Migration,), {"operations": [], "dependencies": []})
- instance = subclass("auto_%i" % (len(self.migrations.get(app_label, [])) + 1), app_label)
+ subclass = type(
+ "Migration",
+ (Migration,),
+ {"operations": [], "dependencies": []},
+ )
+ instance = subclass(
+ "auto_%i" % (len(self.migrations.get(app_label, [])) + 1),
+ app_label,
+ )
instance.dependencies = list(dependencies)
instance.operations = chopped
instance.initial = app_label not in self.existing_apps
self.migrations.setdefault(app_label, []).append(instance)
chop_mode = False
else:
- self.generated_operations[app_label] = chopped + self.generated_operations[app_label]
+ self.generated_operations[app_label] = (
+ chopped + self.generated_operations[app_label]
+ )
new_num_ops = sum(len(x) for x in self.generated_operations.values())
if new_num_ops == num_ops:
if not chop_mode:
chop_mode = True
else:
- raise ValueError("Cannot resolve operation dependencies: %r" % self.generated_operations)
+ raise ValueError(
+ "Cannot resolve operation dependencies: %r"
+ % self.generated_operations
+ )
num_ops = new_num_ops
def _sort_migrations(self):
@@ -351,7 +383,9 @@ class MigrationAutodetector:
dependency_graph[op].add(op2)
# we use a stable sort for deterministic tests & general behavior
- self.generated_operations[app_label] = stable_topological_sort(ops, dependency_graph)
+ self.generated_operations[app_label] = stable_topological_sort(
+ ops, dependency_graph
+ )
def _optimize_migrations(self):
# Add in internal dependencies among the migrations
@@ -367,7 +401,9 @@ class MigrationAutodetector:
# Optimize migrations
for app_label, migrations in self.migrations.items():
for migration in migrations:
- migration.operations = MigrationOptimizer().optimize(migration.operations, app_label)
+ migration.operations = MigrationOptimizer().optimize(
+ migration.operations, app_label
+ )
def check_dependency(self, operation, dependency):
"""
@@ -377,56 +413,56 @@ class MigrationAutodetector:
# Created model
if dependency[2] is None and dependency[3] is True:
return (
- isinstance(operation, operations.CreateModel) and
- operation.name_lower == dependency[1].lower()
+ isinstance(operation, operations.CreateModel)
+ and operation.name_lower == dependency[1].lower()
)
# Created field
elif dependency[2] is not None and dependency[3] is True:
return (
- (
- isinstance(operation, operations.CreateModel) and
- operation.name_lower == dependency[1].lower() and
- any(dependency[2] == x for x, y in operation.fields)
- ) or
- (
- isinstance(operation, operations.AddField) and
- operation.model_name_lower == dependency[1].lower() and
- operation.name_lower == dependency[2].lower()
- )
+ isinstance(operation, operations.CreateModel)
+ and operation.name_lower == dependency[1].lower()
+ and any(dependency[2] == x for x, y in operation.fields)
+ ) or (
+ isinstance(operation, operations.AddField)
+ and operation.model_name_lower == dependency[1].lower()
+ and operation.name_lower == dependency[2].lower()
)
# Removed field
elif dependency[2] is not None and dependency[3] is False:
return (
- isinstance(operation, operations.RemoveField) and
- operation.model_name_lower == dependency[1].lower() and
- operation.name_lower == dependency[2].lower()
+ isinstance(operation, operations.RemoveField)
+ and operation.model_name_lower == dependency[1].lower()
+ and operation.name_lower == dependency[2].lower()
)
# Removed model
elif dependency[2] is None and dependency[3] is False:
return (
- isinstance(operation, operations.DeleteModel) and
- operation.name_lower == dependency[1].lower()
+ isinstance(operation, operations.DeleteModel)
+ and operation.name_lower == dependency[1].lower()
)
# Field being altered
elif dependency[2] is not None and dependency[3] == "alter":
return (
- isinstance(operation, operations.AlterField) and
- operation.model_name_lower == dependency[1].lower() and
- operation.name_lower == dependency[2].lower()
+ isinstance(operation, operations.AlterField)
+ and operation.model_name_lower == dependency[1].lower()
+ and operation.name_lower == dependency[2].lower()
)
# order_with_respect_to being unset for a field
elif dependency[2] is not None and dependency[3] == "order_wrt_unset":
return (
- isinstance(operation, operations.AlterOrderWithRespectTo) and
- operation.name_lower == dependency[1].lower() and
- (operation.order_with_respect_to or "").lower() != dependency[2].lower()
+ isinstance(operation, operations.AlterOrderWithRespectTo)
+ and operation.name_lower == dependency[1].lower()
+ and (operation.order_with_respect_to or "").lower()
+ != dependency[2].lower()
)
# Field is removed and part of an index/unique_together
elif dependency[2] is not None and dependency[3] == "foo_together_change":
return (
- isinstance(operation, (operations.AlterUniqueTogether,
- operations.AlterIndexTogether)) and
- operation.name_lower == dependency[1].lower()
+ isinstance(
+ operation,
+ (operations.AlterUniqueTogether, operations.AlterIndexTogether),
+ )
+ and operation.name_lower == dependency[1].lower()
)
# Unknown dependency. Raise an error.
else:
@@ -453,10 +489,10 @@ class MigrationAutodetector:
}
string_version = "%s.%s" % (item[0], item[1])
if (
- model_state.options.get('swappable') or
- "AbstractUser" in base_names or
- "AbstractBaseUser" in base_names or
- settings.AUTH_USER_MODEL.lower() == string_version.lower()
+ model_state.options.get("swappable")
+ or "AbstractUser" in base_names
+ or "AbstractBaseUser" in base_names
+ or settings.AUTH_USER_MODEL.lower() == string_version.lower()
):
return ("___" + item[0], "___" + item[1])
except LookupError:
@@ -479,21 +515,32 @@ class MigrationAutodetector:
removed_models = self.old_model_keys - self.new_model_keys
for rem_app_label, rem_model_name in removed_models:
if rem_app_label == app_label:
- rem_model_state = self.from_state.models[rem_app_label, rem_model_name]
- rem_model_fields_def = self.only_relation_agnostic_fields(rem_model_state.fields)
+ rem_model_state = self.from_state.models[
+ rem_app_label, rem_model_name
+ ]
+ rem_model_fields_def = self.only_relation_agnostic_fields(
+ rem_model_state.fields
+ )
if model_fields_def == rem_model_fields_def:
- if self.questioner.ask_rename_model(rem_model_state, model_state):
+ if self.questioner.ask_rename_model(
+ rem_model_state, model_state
+ ):
dependencies = []
fields = list(model_state.fields.values()) + [
field.remote_field
- for relations in self.to_state.relations[app_label, model_name].values()
+ for relations in self.to_state.relations[
+ app_label, model_name
+ ].values()
for field in relations.values()
]
for field in fields:
if field.is_relation:
dependencies.extend(
self._get_dependencies_for_foreign_key(
- app_label, model_name, field, self.to_state,
+ app_label,
+ model_name,
+ field,
+ self.to_state,
)
)
self.add_operation(
@@ -505,11 +552,13 @@ class MigrationAutodetector:
dependencies=dependencies,
)
self.renamed_models[app_label, model_name] = rem_model_name
- renamed_models_rel_key = '%s.%s' % (
+ renamed_models_rel_key = "%s.%s" % (
rem_model_state.app_label,
rem_model_state.name_lower,
)
- self.renamed_models_rel[renamed_models_rel_key] = '%s.%s' % (
+ self.renamed_models_rel[
+ renamed_models_rel_key
+ ] = "%s.%s" % (
model_state.app_label,
model_state.name_lower,
)
@@ -532,7 +581,7 @@ class MigrationAutodetector:
added_unmanaged_models = self.new_unmanaged_keys - old_keys
all_added_models = chain(
sorted(added_models, key=self.swappable_first_key, reverse=True),
- sorted(added_unmanaged_models, key=self.swappable_first_key, reverse=True)
+ sorted(added_unmanaged_models, key=self.swappable_first_key, reverse=True),
)
for app_label, model_name in all_added_models:
model_state = self.to_state.models[app_label, model_name]
@@ -546,15 +595,17 @@ class MigrationAutodetector:
primary_key_rel = field.remote_field.model
elif not field.remote_field.parent_link:
related_fields[field_name] = field
- if getattr(field.remote_field, 'through', None):
+ if getattr(field.remote_field, "through", None):
related_fields[field_name] = field
# Are there indexes/unique|index_together to defer?
- indexes = model_state.options.pop('indexes')
- constraints = model_state.options.pop('constraints')
- unique_together = model_state.options.pop('unique_together', None)
- index_together = model_state.options.pop('index_together', None)
- order_with_respect_to = model_state.options.pop('order_with_respect_to', None)
+ indexes = model_state.options.pop("indexes")
+ constraints = model_state.options.pop("constraints")
+ unique_together = model_state.options.pop("unique_together", None)
+ index_together = model_state.options.pop("index_together", None)
+ order_with_respect_to = model_state.options.pop(
+ "order_with_respect_to", None
+ )
# Depend on the deletion of any possible proxy version of us
dependencies = [
(app_label, model_name, None, False),
@@ -566,27 +617,44 @@ class MigrationAutodetector:
dependencies.append((base_app_label, base_name, None, True))
# Depend on the removal of base fields if the new model has
# a field with the same name.
- old_base_model_state = self.from_state.models.get((base_app_label, base_name))
- new_base_model_state = self.to_state.models.get((base_app_label, base_name))
+ old_base_model_state = self.from_state.models.get(
+ (base_app_label, base_name)
+ )
+ new_base_model_state = self.to_state.models.get(
+ (base_app_label, base_name)
+ )
if old_base_model_state and new_base_model_state:
- removed_base_fields = set(old_base_model_state.fields).difference(
- new_base_model_state.fields,
- ).intersection(model_state.fields)
+ removed_base_fields = (
+ set(old_base_model_state.fields)
+ .difference(
+ new_base_model_state.fields,
+ )
+ .intersection(model_state.fields)
+ )
for removed_base_field in removed_base_fields:
- dependencies.append((base_app_label, base_name, removed_base_field, False))
+ dependencies.append(
+ (base_app_label, base_name, removed_base_field, False)
+ )
# Depend on the other end of the primary key if it's a relation
if primary_key_rel:
dependencies.append(
resolve_relation(
- primary_key_rel, app_label, model_name,
- ) + (None, True)
+ primary_key_rel,
+ app_label,
+ model_name,
+ )
+ + (None, True)
)
# Generate creation operation
self.add_operation(
app_label,
operations.CreateModel(
name=model_state.name,
- fields=[d for d in model_state.fields.items() if d[0] not in related_fields],
+ fields=[
+ d
+ for d in model_state.fields.items()
+ if d[0] not in related_fields
+ ],
options=model_state.options,
bases=model_state.bases,
managers=model_state.managers,
@@ -596,13 +664,16 @@ class MigrationAutodetector:
)
# Don't add operations which modify the database for unmanaged models
- if not model_state.options.get('managed', True):
+ if not model_state.options.get("managed", True):
continue
# Generate operations for each related field
for name, field in sorted(related_fields.items()):
dependencies = self._get_dependencies_for_foreign_key(
- app_label, model_name, field, self.to_state,
+ app_label,
+ model_name,
+ field,
+ self.to_state,
)
# Depend on our own model being created
dependencies.append((app_label, model_name, None, True))
@@ -627,11 +698,10 @@ class MigrationAutodetector:
dependencies=[
(app_label, model_name, order_with_respect_to, True),
(app_label, model_name, None, True),
- ]
+ ],
)
related_dependencies = [
- (app_label, model_name, name, True)
- for name in sorted(related_fields)
+ (app_label, model_name, name, True) for name in sorted(related_fields)
]
related_dependencies.append((app_label, model_name, None, True))
for index in indexes:
@@ -659,7 +729,7 @@ class MigrationAutodetector:
name=model_name,
unique_together=unique_together,
),
- dependencies=related_dependencies
+ dependencies=related_dependencies,
)
if index_together:
self.add_operation(
@@ -668,13 +738,15 @@ class MigrationAutodetector:
name=model_name,
index_together=index_together,
),
- dependencies=related_dependencies
+ dependencies=related_dependencies,
)
# Fix relationships if the model changed from a proxy model to a
# concrete model.
relations = self.to_state.relations
if (app_label, model_name) in self.old_proxy_keys:
- for related_model_key, related_fields in relations[app_label, model_name].items():
+ for related_model_key, related_fields in relations[
+ app_label, model_name
+ ].items():
related_model_state = self.to_state.models[related_model_key]
for related_field_name, related_field in related_fields.items():
self.add_operation(
@@ -733,7 +805,9 @@ class MigrationAutodetector:
new_keys = self.new_model_keys | self.new_unmanaged_keys
deleted_models = self.old_model_keys - new_keys
deleted_unmanaged_models = self.old_unmanaged_keys - new_keys
- all_deleted_models = chain(sorted(deleted_models), sorted(deleted_unmanaged_models))
+ all_deleted_models = chain(
+ sorted(deleted_models), sorted(deleted_unmanaged_models)
+ )
for app_label, model_name in all_deleted_models:
model_state = self.from_state.models[app_label, model_name]
# Gather related fields
@@ -742,18 +816,18 @@ class MigrationAutodetector:
if field.remote_field:
if field.remote_field.model:
related_fields[field_name] = field
- if getattr(field.remote_field, 'through', None):
+ if getattr(field.remote_field, "through", None):
related_fields[field_name] = field
# Generate option removal first
- unique_together = model_state.options.pop('unique_together', None)
- index_together = model_state.options.pop('index_together', None)
+ unique_together = model_state.options.pop("unique_together", None)
+ index_together = model_state.options.pop("index_together", None)
if unique_together:
self.add_operation(
app_label,
operations.AlterUniqueTogether(
name=model_name,
unique_together=None,
- )
+ ),
)
if index_together:
self.add_operation(
@@ -761,7 +835,7 @@ class MigrationAutodetector:
operations.AlterIndexTogether(
name=model_name,
index_together=None,
- )
+ ),
)
# Then remove each related field
for name in sorted(related_fields):
@@ -770,7 +844,7 @@ class MigrationAutodetector:
operations.RemoveField(
model_name=model_name,
name=name,
- )
+ ),
)
# Finally, remove the model.
# This depends on both the removal/alteration of all incoming fields
@@ -778,16 +852,22 @@ class MigrationAutodetector:
# a through model the field that references it.
dependencies = []
relations = self.from_state.relations
- for (related_object_app_label, object_name), relation_related_fields in (
- relations[app_label, model_name].items()
- ):
+ for (
+ related_object_app_label,
+ object_name,
+ ), relation_related_fields in relations[app_label, model_name].items():
for field_name, field in relation_related_fields.items():
dependencies.append(
(related_object_app_label, object_name, field_name, False),
)
if not field.many_to_many:
dependencies.append(
- (related_object_app_label, object_name, field_name, 'alter'),
+ (
+ related_object_app_label,
+ object_name,
+ field_name,
+ "alter",
+ ),
)
for name in sorted(related_fields):
@@ -795,7 +875,9 @@ class MigrationAutodetector:
# We're referenced in another field's through=
through_user = self.through_users.get((app_label, model_state.name_lower))
if through_user:
- dependencies.append((through_user[0], through_user[1], through_user[2], False))
+ dependencies.append(
+ (through_user[0], through_user[1], through_user[2], False)
+ )
# Finally, make the operation, deduping any dependencies
self.add_operation(
app_label,
@@ -821,29 +903,43 @@ class MigrationAutodetector:
def generate_renamed_fields(self):
"""Work out renamed fields."""
self.renamed_fields = {}
- for app_label, model_name, field_name in sorted(self.new_field_keys - self.old_field_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ for app_label, model_name, field_name in sorted(
+ self.new_field_keys - self.old_field_keys
+ ):
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
field = new_model_state.get_field(field_name)
# Scan to see if this is actually a rename!
field_dec = self.deep_deconstruct(field)
- for rem_app_label, rem_model_name, rem_field_name in sorted(self.old_field_keys - self.new_field_keys):
+ for rem_app_label, rem_model_name, rem_field_name in sorted(
+ self.old_field_keys - self.new_field_keys
+ ):
if rem_app_label == app_label and rem_model_name == model_name:
old_field = old_model_state.get_field(rem_field_name)
old_field_dec = self.deep_deconstruct(old_field)
- if field.remote_field and field.remote_field.model and 'to' in old_field_dec[2]:
- old_rel_to = old_field_dec[2]['to']
+ if (
+ field.remote_field
+ and field.remote_field.model
+ and "to" in old_field_dec[2]
+ ):
+ old_rel_to = old_field_dec[2]["to"]
if old_rel_to in self.renamed_models_rel:
- old_field_dec[2]['to'] = self.renamed_models_rel[old_rel_to]
+ old_field_dec[2]["to"] = self.renamed_models_rel[old_rel_to]
old_field.set_attributes_from_name(rem_field_name)
old_db_column = old_field.get_attname_column()[1]
- if (old_field_dec == field_dec or (
- # Was the field renamed and db_column equal to the
- # old field's column added?
- old_field_dec[0:2] == field_dec[0:2] and
- dict(old_field_dec[2], db_column=old_db_column) == field_dec[2])):
- if self.questioner.ask_rename(model_name, rem_field_name, field_name, field):
+ if old_field_dec == field_dec or (
+ # Was the field renamed and db_column equal to the
+ # old field's column added?
+ old_field_dec[0:2] == field_dec[0:2]
+ and dict(old_field_dec[2], db_column=old_db_column)
+ == field_dec[2]
+ ):
+ if self.questioner.ask_rename(
+ model_name, rem_field_name, field_name, field
+ ):
# A db_column mismatch requires a prior noop
# AlterField for the subsequent RenameField to be a
# noop on attempts at preserving the old name.
@@ -864,16 +960,22 @@ class MigrationAutodetector:
model_name=model_name,
old_name=rem_field_name,
new_name=field_name,
- )
+ ),
+ )
+ self.old_field_keys.remove(
+ (rem_app_label, rem_model_name, rem_field_name)
)
- self.old_field_keys.remove((rem_app_label, rem_model_name, rem_field_name))
self.old_field_keys.add((app_label, model_name, field_name))
- self.renamed_fields[app_label, model_name, field_name] = rem_field_name
+ self.renamed_fields[
+ app_label, model_name, field_name
+ ] = rem_field_name
break
def generate_added_fields(self):
"""Make AddField operations."""
- for app_label, model_name, field_name in sorted(self.new_field_keys - self.old_field_keys):
+ for app_label, model_name, field_name in sorted(
+ self.new_field_keys - self.old_field_keys
+ ):
self._generate_added_field(app_label, model_name, field_name)
def _generate_added_field(self, app_label, model_name, field_name):
@@ -881,27 +983,38 @@ class MigrationAutodetector:
# Fields that are foreignkeys/m2ms depend on stuff
dependencies = []
if field.remote_field and field.remote_field.model:
- dependencies.extend(self._get_dependencies_for_foreign_key(
- app_label, model_name, field, self.to_state,
- ))
+ dependencies.extend(
+ self._get_dependencies_for_foreign_key(
+ app_label,
+ model_name,
+ field,
+ self.to_state,
+ )
+ )
# You can't just add NOT NULL fields with no default or fields
# which don't allow empty strings as default.
time_fields = (models.DateField, models.DateTimeField, models.TimeField)
preserve_default = (
- field.null or field.has_default() or field.many_to_many or
- (field.blank and field.empty_strings_allowed) or
- (isinstance(field, time_fields) and field.auto_now)
+ field.null
+ or field.has_default()
+ or field.many_to_many
+ or (field.blank and field.empty_strings_allowed)
+ or (isinstance(field, time_fields) and field.auto_now)
)
if not preserve_default:
field = field.clone()
if isinstance(field, time_fields) and field.auto_now_add:
- field.default = self.questioner.ask_auto_now_add_addition(field_name, model_name)
+ field.default = self.questioner.ask_auto_now_add_addition(
+ field_name, model_name
+ )
else:
- field.default = self.questioner.ask_not_null_addition(field_name, model_name)
+ field.default = self.questioner.ask_not_null_addition(
+ field_name, model_name
+ )
if (
- field.unique and
- field.default is not models.NOT_PROVIDED and
- callable(field.default)
+ field.unique
+ and field.default is not models.NOT_PROVIDED
+ and callable(field.default)
):
self.questioner.ask_unique_callable_default_addition(field_name, model_name)
self.add_operation(
@@ -917,7 +1030,9 @@ class MigrationAutodetector:
def generate_removed_fields(self):
"""Make RemoveField operations."""
- for app_label, model_name, field_name in sorted(self.old_field_keys - self.new_field_keys):
+ for app_label, model_name, field_name in sorted(
+ self.old_field_keys - self.new_field_keys
+ ):
self._generate_removed_field(app_label, model_name, field_name)
def _generate_removed_field(self, app_label, model_name, field_name):
@@ -941,21 +1056,35 @@ class MigrationAutodetector:
Make AlterField operations, or possibly RemovedField/AddField if alter
isn't possible.
"""
- for app_label, model_name, field_name in sorted(self.old_field_keys & self.new_field_keys):
+ for app_label, model_name, field_name in sorted(
+ self.old_field_keys & self.new_field_keys
+ ):
# Did the field change?
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
- old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name)
- old_field = self.from_state.models[app_label, old_model_name].get_field(old_field_name)
- new_field = self.to_state.models[app_label, model_name].get_field(field_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
+ old_field_name = self.renamed_fields.get(
+ (app_label, model_name, field_name), field_name
+ )
+ old_field = self.from_state.models[app_label, old_model_name].get_field(
+ old_field_name
+ )
+ new_field = self.to_state.models[app_label, model_name].get_field(
+ field_name
+ )
dependencies = []
# Implement any model renames on relations; these are handled by RenameModel
# so we need to exclude them from the comparison
- if hasattr(new_field, "remote_field") and getattr(new_field.remote_field, "model", None):
- rename_key = resolve_relation(new_field.remote_field.model, app_label, model_name)
+ if hasattr(new_field, "remote_field") and getattr(
+ new_field.remote_field, "model", None
+ ):
+ rename_key = resolve_relation(
+ new_field.remote_field.model, app_label, model_name
+ )
if rename_key in self.renamed_models:
new_field.remote_field.model = old_field.remote_field.model
# Handle ForeignKey which can only have a single to_field.
- remote_field_name = getattr(new_field.remote_field, 'field_name', None)
+ remote_field_name = getattr(new_field.remote_field, "field_name", None)
if remote_field_name:
to_field_rename_key = rename_key + (remote_field_name,)
if to_field_rename_key in self.renamed_fields:
@@ -963,27 +1092,41 @@ class MigrationAutodetector:
# inclusion in ForeignKey.deconstruct() is based on
# both.
new_field.remote_field.model = old_field.remote_field.model
- new_field.remote_field.field_name = old_field.remote_field.field_name
+ new_field.remote_field.field_name = (
+ old_field.remote_field.field_name
+ )
# Handle ForeignObjects which can have multiple from_fields/to_fields.
- from_fields = getattr(new_field, 'from_fields', None)
+ from_fields = getattr(new_field, "from_fields", None)
if from_fields:
from_rename_key = (app_label, model_name)
- new_field.from_fields = tuple([
- self.renamed_fields.get(from_rename_key + (from_field,), from_field)
- for from_field in from_fields
- ])
- new_field.to_fields = tuple([
- self.renamed_fields.get(rename_key + (to_field,), to_field)
- for to_field in new_field.to_fields
- ])
- dependencies.extend(self._get_dependencies_for_foreign_key(
- app_label, model_name, new_field, self.to_state,
- ))
- if (
- hasattr(new_field, 'remote_field') and
- getattr(new_field.remote_field, 'through', None)
+ new_field.from_fields = tuple(
+ [
+ self.renamed_fields.get(
+ from_rename_key + (from_field,), from_field
+ )
+ for from_field in from_fields
+ ]
+ )
+ new_field.to_fields = tuple(
+ [
+ self.renamed_fields.get(rename_key + (to_field,), to_field)
+ for to_field in new_field.to_fields
+ ]
+ )
+ dependencies.extend(
+ self._get_dependencies_for_foreign_key(
+ app_label,
+ model_name,
+ new_field,
+ self.to_state,
+ )
+ )
+ if hasattr(new_field, "remote_field") and getattr(
+ new_field.remote_field, "through", None
):
- rename_key = resolve_relation(new_field.remote_field.through, app_label, model_name)
+ rename_key = resolve_relation(
+ new_field.remote_field.through, app_label, model_name
+ )
if rename_key in self.renamed_models:
new_field.remote_field.through = old_field.remote_field.through
old_field_dec = self.deep_deconstruct(old_field)
@@ -997,10 +1140,16 @@ class MigrationAutodetector:
if both_m2m or neither_m2m:
# Either both fields are m2m or neither is
preserve_default = True
- if (old_field.null and not new_field.null and not new_field.has_default() and
- not new_field.many_to_many):
+ if (
+ old_field.null
+ and not new_field.null
+ and not new_field.has_default()
+ and not new_field.many_to_many
+ ):
field = new_field.clone()
- new_default = self.questioner.ask_not_null_alteration(field_name, model_name)
+ new_default = self.questioner.ask_not_null_alteration(
+ field_name, model_name
+ )
if new_default is not models.NOT_PROVIDED:
field.default = new_default
preserve_default = False
@@ -1024,7 +1173,9 @@ class MigrationAutodetector:
def create_altered_indexes(self):
option_name = operations.AddIndex.option_name
for app_label, model_name in sorted(self.kept_model_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
@@ -1033,38 +1184,43 @@ class MigrationAutodetector:
add_idx = [idx for idx in new_indexes if idx not in old_indexes]
rem_idx = [idx for idx in old_indexes if idx not in new_indexes]
- self.altered_indexes.update({
- (app_label, model_name): {
- 'added_indexes': add_idx, 'removed_indexes': rem_idx,
+ self.altered_indexes.update(
+ {
+ (app_label, model_name): {
+ "added_indexes": add_idx,
+ "removed_indexes": rem_idx,
+ }
}
- })
+ )
def generate_added_indexes(self):
for (app_label, model_name), alt_indexes in self.altered_indexes.items():
- for index in alt_indexes['added_indexes']:
+ for index in alt_indexes["added_indexes"]:
self.add_operation(
app_label,
operations.AddIndex(
model_name=model_name,
index=index,
- )
+ ),
)
def generate_removed_indexes(self):
for (app_label, model_name), alt_indexes in self.altered_indexes.items():
- for index in alt_indexes['removed_indexes']:
+ for index in alt_indexes["removed_indexes"]:
self.add_operation(
app_label,
operations.RemoveIndex(
model_name=model_name,
name=index.name,
- )
+ ),
)
def create_altered_constraints(self):
option_name = operations.AddConstraint.option_name
for app_label, model_name in sorted(self.kept_model_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
@@ -1073,38 +1229,47 @@ class MigrationAutodetector:
add_constraints = [c for c in new_constraints if c not in old_constraints]
rem_constraints = [c for c in old_constraints if c not in new_constraints]
- self.altered_constraints.update({
- (app_label, model_name): {
- 'added_constraints': add_constraints, 'removed_constraints': rem_constraints,
+ self.altered_constraints.update(
+ {
+ (app_label, model_name): {
+ "added_constraints": add_constraints,
+ "removed_constraints": rem_constraints,
+ }
}
- })
+ )
def generate_added_constraints(self):
- for (app_label, model_name), alt_constraints in self.altered_constraints.items():
- for constraint in alt_constraints['added_constraints']:
+ for (
+ app_label,
+ model_name,
+ ), alt_constraints in self.altered_constraints.items():
+ for constraint in alt_constraints["added_constraints"]:
self.add_operation(
app_label,
operations.AddConstraint(
model_name=model_name,
constraint=constraint,
- )
+ ),
)
def generate_removed_constraints(self):
- for (app_label, model_name), alt_constraints in self.altered_constraints.items():
- for constraint in alt_constraints['removed_constraints']:
+ for (
+ app_label,
+ model_name,
+ ), alt_constraints in self.altered_constraints.items():
+ for constraint in alt_constraints["removed_constraints"]:
self.add_operation(
app_label,
operations.RemoveConstraint(
model_name=model_name,
name=constraint.name,
- )
+ ),
)
@staticmethod
def _get_dependencies_for_foreign_key(app_label, model_name, field, project_state):
remote_field_model = None
- if hasattr(field.remote_field, 'model'):
+ if hasattr(field.remote_field, "model"):
remote_field_model = field.remote_field.model
else:
relations = project_state.relations[app_label, model_name]
@@ -1113,40 +1278,50 @@ class MigrationAutodetector:
field == related_field.remote_field
for related_field in fields.values()
):
- remote_field_model = f'{remote_app_label}.{remote_model_name}'
+ remote_field_model = f"{remote_app_label}.{remote_model_name}"
break
# Account for FKs to swappable models
- swappable_setting = getattr(field, 'swappable_setting', None)
+ swappable_setting = getattr(field, "swappable_setting", None)
if swappable_setting is not None:
dep_app_label = "__setting__"
dep_object_name = swappable_setting
else:
dep_app_label, dep_object_name = resolve_relation(
- remote_field_model, app_label, model_name,
+ remote_field_model,
+ app_label,
+ model_name,
)
dependencies = [(dep_app_label, dep_object_name, None, True)]
- if getattr(field.remote_field, 'through', None):
+ if getattr(field.remote_field, "through", None):
through_app_label, through_object_name = resolve_relation(
- remote_field_model, app_label, model_name,
+ remote_field_model,
+ app_label,
+ model_name,
)
dependencies.append((through_app_label, through_object_name, None, True))
return dependencies
def _get_altered_foo_together_operations(self, option_name):
for app_label, model_name in sorted(self.kept_model_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
# We run the old version through the field renames to account for those
old_value = old_model_state.options.get(option_name)
- old_value = {
- tuple(
- self.renamed_fields.get((app_label, model_name, n), n)
- for n in unique
- )
- for unique in old_value
- } if old_value else set()
+ old_value = (
+ {
+ tuple(
+ self.renamed_fields.get((app_label, model_name, n), n)
+ for n in unique
+ )
+ for unique in old_value
+ }
+ if old_value
+ else set()
+ )
new_value = new_model_state.options.get(option_name)
new_value = set(new_value) if new_value else set()
@@ -1157,9 +1332,14 @@ class MigrationAutodetector:
for field_name in foo_togethers:
field = new_model_state.get_field(field_name)
if field.remote_field and field.remote_field.model:
- dependencies.extend(self._get_dependencies_for_foreign_key(
- app_label, model_name, field, self.to_state,
- ))
+ dependencies.extend(
+ self._get_dependencies_for_foreign_key(
+ app_label,
+ model_name,
+ field,
+ self.to_state,
+ )
+ )
yield (
old_value,
new_value,
@@ -1180,7 +1360,9 @@ class MigrationAutodetector:
if removal_value or old_value:
self.add_operation(
app_label,
- operation(name=model_name, **{operation.option_name: removal_value}),
+ operation(
+ name=model_name, **{operation.option_name: removal_value}
+ ),
dependencies=dependencies,
)
@@ -1213,20 +1395,24 @@ class MigrationAutodetector:
self._generate_altered_foo_together(operations.AlterIndexTogether)
def generate_altered_db_table(self):
- models_to_check = self.kept_model_keys.union(self.kept_proxy_keys, self.kept_unmanaged_keys)
+ models_to_check = self.kept_model_keys.union(
+ self.kept_proxy_keys, self.kept_unmanaged_keys
+ )
for app_label, model_name in sorted(models_to_check):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
- old_db_table_name = old_model_state.options.get('db_table')
- new_db_table_name = new_model_state.options.get('db_table')
+ old_db_table_name = old_model_state.options.get("db_table")
+ new_db_table_name = new_model_state.options.get("db_table")
if old_db_table_name != new_db_table_name:
self.add_operation(
app_label,
operations.AlterModelTable(
name=model_name,
table=new_db_table_name,
- )
+ ),
)
def generate_altered_options(self):
@@ -1245,15 +1431,19 @@ class MigrationAutodetector:
)
for app_label, model_name in sorted(models_to_check):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
old_options = {
- key: value for key, value in old_model_state.options.items()
+ key: value
+ for key, value in old_model_state.options.items()
if key in AlterModelOptions.ALTER_OPTION_KEYS
}
new_options = {
- key: value for key, value in new_model_state.options.items()
+ key: value
+ for key, value in new_model_state.options.items()
if key in AlterModelOptions.ALTER_OPTION_KEYS
}
if old_options != new_options:
@@ -1262,39 +1452,48 @@ class MigrationAutodetector:
operations.AlterModelOptions(
name=model_name,
options=new_options,
- )
+ ),
)
def generate_altered_order_with_respect_to(self):
for app_label, model_name in sorted(self.kept_model_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
- if (old_model_state.options.get("order_with_respect_to") !=
- new_model_state.options.get("order_with_respect_to")):
+ if old_model_state.options.get(
+ "order_with_respect_to"
+ ) != new_model_state.options.get("order_with_respect_to"):
# Make sure it comes second if we're adding
# (removal dependency is part of RemoveField)
dependencies = []
if new_model_state.options.get("order_with_respect_to"):
- dependencies.append((
- app_label,
- model_name,
- new_model_state.options["order_with_respect_to"],
- True,
- ))
+ dependencies.append(
+ (
+ app_label,
+ model_name,
+ new_model_state.options["order_with_respect_to"],
+ True,
+ )
+ )
# Actually generate the operation
self.add_operation(
app_label,
operations.AlterOrderWithRespectTo(
name=model_name,
- order_with_respect_to=new_model_state.options.get('order_with_respect_to'),
+ order_with_respect_to=new_model_state.options.get(
+ "order_with_respect_to"
+ ),
),
dependencies=dependencies,
)
def generate_altered_managers(self):
for app_label, model_name in sorted(self.kept_model_keys):
- old_model_name = self.renamed_models.get((app_label, model_name), model_name)
+ old_model_name = self.renamed_models.get(
+ (app_label, model_name), model_name
+ )
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
if old_model_state.managers != new_model_state.managers:
@@ -1303,7 +1502,7 @@ class MigrationAutodetector:
operations.AlterModelManagers(
name=model_name,
managers=new_model_state.managers,
- )
+ ),
)
def arrange_for_graph(self, changes, graph, migration_name=None):
@@ -1339,21 +1538,23 @@ class MigrationAutodetector:
for i, migration in enumerate(migrations):
if i == 0 and app_leaf:
migration.dependencies.append(app_leaf)
- new_name_parts = ['%04i' % next_number]
+ new_name_parts = ["%04i" % next_number]
if migration_name:
new_name_parts.append(migration_name)
elif i == 0 and not app_leaf:
- new_name_parts.append('initial')
+ new_name_parts.append("initial")
else:
new_name_parts.append(migration.suggest_name()[:100])
- new_name = '_'.join(new_name_parts)
+ new_name = "_".join(new_name_parts)
name_map[(app_label, migration.name)] = (app_label, new_name)
next_number += 1
migration.name = new_name
# Now fix dependencies
for migrations in changes.values():
for migration in migrations:
- migration.dependencies = [name_map.get(d, d) for d in migration.dependencies]
+ migration.dependencies = [
+ name_map.get(d, d) for d in migration.dependencies
+ ]
return changes
def _trim_to_apps(self, changes, app_labels):
@@ -1374,7 +1575,9 @@ class MigrationAutodetector:
old_required_apps = None
while old_required_apps != required_apps:
old_required_apps = set(required_apps)
- required_apps.update(*[app_dependencies.get(app_label, ()) for app_label in required_apps])
+ required_apps.update(
+ *[app_dependencies.get(app_label, ()) for app_label in required_apps]
+ )
# Remove all migrations that aren't needed
for app_label in list(changes):
if app_label not in required_apps:
@@ -1388,9 +1591,9 @@ class MigrationAutodetector:
it. For a squashed migration such as '0001_squashed_0004…', return the
second number. If no number is found, return None.
"""
- if squashed_match := re.search(r'.*_squashed_(\d+)', name):
+ if squashed_match := re.search(r".*_squashed_(\d+)", name):
return int(squashed_match[1])
- match = re.match(r'^\d+', name)
+ match = re.match(r"^\d+", name)
if match:
return int(match[0])
return None
diff --git a/django/db/migrations/exceptions.py b/django/db/migrations/exceptions.py
index 8def99da5b..dd556dacb5 100644
--- a/django/db/migrations/exceptions.py
+++ b/django/db/migrations/exceptions.py
@@ -3,31 +3,37 @@ from django.db import DatabaseError
class AmbiguityError(Exception):
"""More than one migration matches a name prefix."""
+
pass
class BadMigrationError(Exception):
"""There's a bad migration (unreadable/bad format/etc.)."""
+
pass
class CircularDependencyError(Exception):
"""There's an impossible-to-resolve circular dependency."""
+
pass
class InconsistentMigrationHistory(Exception):
"""An applied migration has some of its dependencies not applied."""
+
pass
class InvalidBasesError(ValueError):
"""A model's base classes can't be resolved."""
+
pass
class IrreversibleError(RuntimeError):
"""An irreversible migration is about to be reversed."""
+
pass
diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py
index 89e9344a68..ef7f9060f1 100644
--- a/django/db/migrations/executor.py
+++ b/django/db/migrations/executor.py
@@ -43,8 +43,8 @@ class MigrationExecutor:
# If the target is missing, it's likely a replaced migration.
# Reload the graph without replacements.
if (
- self.loader.replace_migrations and
- target not in self.loader.graph.node_map
+ self.loader.replace_migrations
+ and target not in self.loader.graph.node_map
):
self.loader.replace_migrations = False
self.loader.build_graph()
@@ -54,8 +54,8 @@ class MigrationExecutor:
# be rolled back); instead roll back through target's immediate
# child(ren) in the same app, and no further.
next_in_app = sorted(
- n for n in
- self.loader.graph.node_map[target].children
+ n
+ for n in self.loader.graph.node_map[target].children
if n[0] == target[0]
)
for node in next_in_app:
@@ -78,9 +78,12 @@ class MigrationExecutor:
state = ProjectState(real_apps=self.loader.unmigrated_apps)
if with_applied_migrations:
# Create the forwards plan Django would follow on an empty database
- full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
+ full_plan = self.migration_plan(
+ self.loader.graph.leaf_nodes(), clean_start=True
+ )
applied_migrations = {
- self.loader.graph.nodes[key] for key in self.loader.applied_migrations
+ self.loader.graph.nodes[key]
+ for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
for migration, _ in full_plan:
@@ -106,7 +109,9 @@ class MigrationExecutor:
if plan is None:
plan = self.migration_plan(targets)
# Create the forwards plan Django would follow on an empty database
- full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
+ full_plan = self.migration_plan(
+ self.loader.graph.leaf_nodes(), clean_start=True
+ )
all_forwards = all(not backwards for mig, backwards in plan)
all_backwards = all(backwards for mig, backwards in plan)
@@ -121,13 +126,15 @@ class MigrationExecutor:
"Migration plans with both forwards and backwards migrations "
"are not supported. Please split your migration process into "
"separate plans of only forwards OR backwards migrations.",
- plan
+ plan,
)
elif all_forwards:
if state is None:
# The resulting state should still include applied migrations.
state = self._create_project_state(with_applied_migrations=True)
- state = self._migrate_all_forwards(state, plan, full_plan, fake=fake, fake_initial=fake_initial)
+ state = self._migrate_all_forwards(
+ state, plan, full_plan, fake=fake, fake_initial=fake_initial
+ )
else:
# No need to check for `elif all_backwards` here, as that condition
# would always evaluate to true.
@@ -151,13 +158,15 @@ class MigrationExecutor:
# process.
break
if migration in migrations_to_run:
- if 'apps' not in state.__dict__:
+ if "apps" not in state.__dict__:
if self.progress_callback:
self.progress_callback("render_start")
state.apps # Render all -- performance critical
if self.progress_callback:
self.progress_callback("render_success")
- state = self.apply_migration(state, migration, fake=fake, fake_initial=fake_initial)
+ state = self.apply_migration(
+ state, migration, fake=fake, fake_initial=fake_initial
+ )
migrations_to_run.remove(migration)
return state
@@ -177,7 +186,8 @@ class MigrationExecutor:
states = {}
state = self._create_project_state()
applied_migrations = {
- self.loader.graph.nodes[key] for key in self.loader.applied_migrations
+ self.loader.graph.nodes[key]
+ for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
if self.progress_callback:
@@ -190,7 +200,7 @@ class MigrationExecutor:
# process.
break
if migration in migrations_to_run:
- if 'apps' not in state.__dict__:
+ if "apps" not in state.__dict__:
state.apps # Render all -- performance critical
# The state before this migration
states[migration] = state
@@ -236,7 +246,9 @@ class MigrationExecutor:
fake = True
if not fake:
# Alright, do it normally
- with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
+ with self.connection.schema_editor(
+ atomic=migration.atomic
+ ) as schema_editor:
state = migration.apply(state, schema_editor)
if not schema_editor.deferred_sql:
self.record_migration(migration)
@@ -261,7 +273,9 @@ class MigrationExecutor:
if self.progress_callback:
self.progress_callback("unapply_start", migration, fake)
if not fake:
- with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
+ with self.connection.schema_editor(
+ atomic=migration.atomic
+ ) as schema_editor:
state = migration.unapply(state, schema_editor)
# For replacement migrations, also record individual statuses.
if migration.replaces:
@@ -296,15 +310,18 @@ class MigrationExecutor:
tables or columns it would create exist. This is intended only for use
on initial migrations (as it only looks for CreateModel and AddField).
"""
+
def should_skip_detecting_model(migration, model):
"""
No need to detect tables for proxy models, unmanaged models, or
models that can't be migrated on the current database.
"""
return (
- model._meta.proxy or not model._meta.managed or not
- router.allow_migrate(
- self.connection.alias, migration.app_label,
+ model._meta.proxy
+ or not model._meta.managed
+ or not router.allow_migrate(
+ self.connection.alias,
+ migration.app_label,
model_name=model._meta.model_name,
)
)
@@ -318,7 +335,9 @@ class MigrationExecutor:
return False, project_state
if project_state is None:
- after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
+ after_state = self.loader.project_state(
+ (migration.app_label, migration.name), at_end=True
+ )
else:
after_state = migration.mutate_state(project_state)
apps = after_state.apps
@@ -326,9 +345,13 @@ class MigrationExecutor:
found_add_field_migration = False
fold_identifier_case = self.connection.features.ignores_table_name_case
with self.connection.cursor() as cursor:
- existing_table_names = set(self.connection.introspection.table_names(cursor))
+ existing_table_names = set(
+ self.connection.introspection.table_names(cursor)
+ )
if fold_identifier_case:
- existing_table_names = {name.casefold() for name in existing_table_names}
+ existing_table_names = {
+ name.casefold() for name in existing_table_names
+ }
# Make sure all create model and add field operations are done
for operation in migration.operations:
if isinstance(operation, migrations.CreateModel):
@@ -368,7 +391,9 @@ class MigrationExecutor:
found_add_field_migration = True
continue
with self.connection.cursor() as cursor:
- columns = self.connection.introspection.get_table_description(cursor, table)
+ columns = self.connection.introspection.get_table_description(
+ cursor, table
+ )
for column in columns:
field_column = field.column
column_name = column.name
diff --git a/django/db/migrations/graph.py b/django/db/migrations/graph.py
index 4d66822e17..dd845c13e8 100644
--- a/django/db/migrations/graph.py
+++ b/django/db/migrations/graph.py
@@ -11,6 +11,7 @@ class Node:
A single node in the migration graph. Contains direct links to adjacent
nodes in either direction.
"""
+
def __init__(self, key):
self.key = key
self.children = set()
@@ -32,7 +33,7 @@ class Node:
return str(self.key)
def __repr__(self):
- return '<%s: (%r, %r)>' % (self.__class__.__name__, self.key[0], self.key[1])
+ return "<%s: (%r, %r)>" % (self.__class__.__name__, self.key[0], self.key[1])
def add_child(self, child):
self.children.add(child)
@@ -49,6 +50,7 @@ class DummyNode(Node):
After the migration graph is processed, all dummy nodes should be removed.
If there are any left, a nonexistent dependency error is raised.
"""
+
def __init__(self, key, origin, error_message):
super().__init__(key)
self.origin = origin
@@ -133,7 +135,7 @@ class MigrationGraph:
raise NodeNotFoundError(
"Unable to find replacement node %r. It was either never added"
" to the migration graph, or has been removed." % (replacement,),
- replacement
+ replacement,
) from err
for replaced_key in replaced:
self.nodes.pop(replaced_key, None)
@@ -167,8 +169,9 @@ class MigrationGraph:
except KeyError as err:
raise NodeNotFoundError(
"Unable to remove replacement node %r. It was either never added"
- " to the migration graph, or has been removed already." % (replacement,),
- replacement
+ " to the migration graph, or has been removed already."
+ % (replacement,),
+ replacement,
) from err
replaced_nodes = set()
replaced_nodes_parents = set()
@@ -228,7 +231,10 @@ class MigrationGraph:
visited.append(node.key)
else:
stack.append((node, True))
- stack += [(n, False) for n in sorted(node.parents if forwards else node.children)]
+ stack += [
+ (n, False)
+ for n in sorted(node.parents if forwards else node.children)
+ ]
return visited
def root_nodes(self, app=None):
@@ -238,7 +244,9 @@ class MigrationGraph:
"""
roots = set()
for node in self.nodes:
- if all(key[0] != node[0] for key in self.node_map[node].parents) and (not app or app == node[0]):
+ if all(key[0] != node[0] for key in self.node_map[node].parents) and (
+ not app or app == node[0]
+ ):
roots.add(node)
return sorted(roots)
@@ -252,7 +260,9 @@ class MigrationGraph:
"""
leaves = set()
for node in self.nodes:
- if all(key[0] != node[0] for key in self.node_map[node].children) and (not app or app == node[0]):
+ if all(key[0] != node[0] for key in self.node_map[node].children) and (
+ not app or app == node[0]
+ ):
leaves.add(node)
return sorted(leaves)
@@ -270,8 +280,10 @@ class MigrationGraph:
# hashing.
node = child.key
if node in stack:
- cycle = stack[stack.index(node):]
- raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle))
+ cycle = stack[stack.index(node) :]
+ raise CircularDependencyError(
+ ", ".join("%s.%s" % n for n in cycle)
+ )
if node in todo:
stack.append(node)
todo.remove(node)
@@ -280,14 +292,16 @@ class MigrationGraph:
node = stack.pop()
def __str__(self):
- return 'Graph: %s nodes, %s edges' % self._nodes_and_edges()
+ return "Graph: %s nodes, %s edges" % self._nodes_and_edges()
def __repr__(self):
nodes, edges = self._nodes_and_edges()
- return '<%s: nodes=%s, edges=%s>' % (self.__class__.__name__, nodes, edges)
+ return "<%s: nodes=%s, edges=%s>" % (self.__class__.__name__, nodes, edges)
def _nodes_and_edges(self):
- return len(self.nodes), sum(len(node.parents) for node in self.node_map.values())
+ return len(self.nodes), sum(
+ len(node.parents) for node in self.node_map.values()
+ )
def _generate_plan(self, nodes, at_end):
plan = []
diff --git a/django/db/migrations/loader.py b/django/db/migrations/loader.py
index 93fb2c3bd5..81dcd06e04 100644
--- a/django/db/migrations/loader.py
+++ b/django/db/migrations/loader.py
@@ -8,11 +8,13 @@ from django.db.migrations.graph import MigrationGraph
from django.db.migrations.recorder import MigrationRecorder
from .exceptions import (
- AmbiguityError, BadMigrationError, InconsistentMigrationHistory,
+ AmbiguityError,
+ BadMigrationError,
+ InconsistentMigrationHistory,
NodeNotFoundError,
)
-MIGRATIONS_MODULE_NAME = 'migrations'
+MIGRATIONS_MODULE_NAME = "migrations"
class MigrationLoader:
@@ -41,7 +43,10 @@ class MigrationLoader:
"""
def __init__(
- self, connection, load=True, ignore_no_migrations=False,
+ self,
+ connection,
+ load=True,
+ ignore_no_migrations=False,
replace_migrations=True,
):
self.connection = connection
@@ -63,7 +68,7 @@ class MigrationLoader:
return settings.MIGRATION_MODULES[app_label], True
else:
app_package_name = apps.get_app_config(app_label).name
- return '%s.%s' % (app_package_name, MIGRATIONS_MODULE_NAME), False
+ return "%s.%s" % (app_package_name, MIGRATIONS_MODULE_NAME), False
def load_disk(self):
"""Load the migrations from all INSTALLED_APPS from disk."""
@@ -80,24 +85,22 @@ class MigrationLoader:
try:
module = import_module(module_name)
except ModuleNotFoundError as e:
- if (
- (explicit and self.ignore_no_migrations) or
- (not explicit and MIGRATIONS_MODULE_NAME in e.name.split('.'))
+ if (explicit and self.ignore_no_migrations) or (
+ not explicit and MIGRATIONS_MODULE_NAME in e.name.split(".")
):
self.unmigrated_apps.add(app_config.label)
continue
raise
else:
# Module is not a package (e.g. migrations.py).
- if not hasattr(module, '__path__'):
+ if not hasattr(module, "__path__"):
self.unmigrated_apps.add(app_config.label)
continue
# Empty directories are namespaces. Namespace packages have no
# __file__ and don't use a list for __path__. See
# https://docs.python.org/3/reference/import.html#namespace-packages
- if (
- getattr(module, '__file__', None) is None and
- not isinstance(module.__path__, list)
+ if getattr(module, "__file__", None) is None and not isinstance(
+ module.__path__, list
):
self.unmigrated_apps.add(app_config.label)
continue
@@ -106,16 +109,17 @@ class MigrationLoader:
reload(module)
self.migrated_apps.add(app_config.label)
migration_names = {
- name for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
- if not is_pkg and name[0] not in '_~'
+ name
+ for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
+ if not is_pkg and name[0] not in "_~"
}
# Load migrations
for migration_name in migration_names:
- migration_path = '%s.%s' % (module_name, migration_name)
+ migration_path = "%s.%s" % (module_name, migration_name)
try:
migration_module = import_module(migration_path)
except ImportError as e:
- if 'bad magic number' in str(e):
+ if "bad magic number" in str(e):
raise ImportError(
"Couldn't import %r as it appears to be a stale "
".pyc file." % migration_path
@@ -124,9 +128,12 @@ class MigrationLoader:
raise
if not hasattr(migration_module, "Migration"):
raise BadMigrationError(
- "Migration %s in app %s has no Migration class" % (migration_name, app_config.label)
+ "Migration %s in app %s has no Migration class"
+ % (migration_name, app_config.label)
)
- self.disk_migrations[app_config.label, migration_name] = migration_module.Migration(
+ self.disk_migrations[
+ app_config.label, migration_name
+ ] = migration_module.Migration(
migration_name,
app_config.label,
)
@@ -142,11 +149,14 @@ class MigrationLoader:
# Do the search
results = []
for migration_app_label, migration_name in self.disk_migrations:
- if migration_app_label == app_label and migration_name.startswith(name_prefix):
+ if migration_app_label == app_label and migration_name.startswith(
+ name_prefix
+ ):
results.append((migration_app_label, migration_name))
if len(results) > 1:
raise AmbiguityError(
- "There is more than one migration for '%s' with the prefix '%s'" % (app_label, name_prefix)
+ "There is more than one migration for '%s' with the prefix '%s'"
+ % (app_label, name_prefix)
)
elif not results:
raise KeyError(
@@ -181,7 +191,9 @@ class MigrationLoader:
if self.ignore_no_migrations:
return None
else:
- raise ValueError("Dependency on app with no migrations: %s" % key[0])
+ raise ValueError(
+ "Dependency on app with no migrations: %s" % key[0]
+ )
raise ValueError("Dependency on unknown app: %s" % key[0])
def add_internal_dependencies(self, key, migration):
@@ -191,7 +203,7 @@ class MigrationLoader:
"""
for parent in migration.dependencies:
# Ignore __first__ references to the same app.
- if parent[0] == key[0] and parent[1] != '__first__':
+ if parent[0] == key[0] and parent[1] != "__first__":
self.graph.add_dependency(migration, key, parent, skip_validation=True)
def add_external_dependencies(self, key, migration):
@@ -241,7 +253,9 @@ class MigrationLoader:
for key, migration in self.replacements.items():
# Get applied status of each of this migration's replacement
# targets.
- applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
+ applied_statuses = [
+ (target in self.applied_migrations) for target in migration.replaces
+ ]
# The replacing migration is only marked as applied if all of
# its replacement targets are.
if all(applied_statuses):
@@ -273,9 +287,11 @@ class MigrationLoader:
# Try to reraise exception with more detail.
if exc.node in reverse_replacements:
candidates = reverse_replacements.get(exc.node, set())
- is_replaced = any(candidate in self.graph.nodes for candidate in candidates)
+ is_replaced = any(
+ candidate in self.graph.nodes for candidate in candidates
+ )
if not is_replaced:
- tries = ', '.join('%s.%s' % c for c in candidates)
+ tries = ", ".join("%s.%s" % c for c in candidates)
raise NodeNotFoundError(
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
"Django tried to replace migration {1}.{2} with any of [{3}] "
@@ -283,7 +299,7 @@ class MigrationLoader:
"are already applied.".format(
exc.origin, exc.node[0], exc.node[1], tries
),
- exc.node
+ exc.node,
) from exc
raise
self.graph.ensure_not_cyclic()
@@ -304,12 +320,17 @@ class MigrationLoader:
# Skip unapplied squashed migrations that have all of their
# `replaces` applied.
if parent in self.replacements:
- if all(m in applied for m in self.replacements[parent].replaces):
+ if all(
+ m in applied for m in self.replacements[parent].replaces
+ ):
continue
raise InconsistentMigrationHistory(
"Migration {}.{} is applied before its dependency "
"{}.{} on database '{}'.".format(
- migration[0], migration[1], parent[0], parent[1],
+ migration[0],
+ migration[1],
+ parent[0],
+ parent[1],
connection.alias,
)
)
@@ -326,7 +347,9 @@ class MigrationLoader:
if app_label in seen_apps:
conflicting_apps.add(app_label)
seen_apps.setdefault(app_label, set()).add(migration_name)
- return {app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps}
+ return {
+ app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps
+ }
def project_state(self, nodes=None, at_end=True):
"""
@@ -335,7 +358,9 @@ class MigrationLoader:
See graph.make_state() for the meaning of "nodes" and "at_end".
"""
- return self.graph.make_state(nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps)
+ return self.graph.make_state(
+ nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps
+ )
def collect_sql(self, plan):
"""
@@ -345,9 +370,13 @@ class MigrationLoader:
statements = []
state = None
for migration, backwards in plan:
- with self.connection.schema_editor(collect_sql=True, atomic=migration.atomic) as schema_editor:
+ with self.connection.schema_editor(
+ collect_sql=True, atomic=migration.atomic
+ ) as schema_editor:
if state is None:
- state = self.project_state((migration.app_label, migration.name), at_end=False)
+ state = self.project_state(
+ (migration.app_label, migration.name), at_end=False
+ )
if not backwards:
state = migration.apply(state, schema_editor, collect_sql=True)
else:
diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py
index 5ee0ae5191..39278d4cc7 100644
--- a/django/db/migrations/migration.py
+++ b/django/db/migrations/migration.py
@@ -60,9 +60,9 @@ class Migration:
def __eq__(self, other):
return (
- isinstance(other, Migration) and
- self.name == other.name and
- self.app_label == other.app_label
+ isinstance(other, Migration)
+ and self.name == other.name
+ and self.app_label == other.app_label
)
def __repr__(self):
@@ -114,15 +114,21 @@ class Migration:
old_state = project_state.clone()
operation.state_forwards(self.app_label, project_state)
# Run the operation
- atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
+ atomic_operation = operation.atomic or (
+ self.atomic and operation.atomic is not False
+ )
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
- operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
+ operation.database_forwards(
+ self.app_label, schema_editor, old_state, project_state
+ )
else:
# Normal behaviour
- operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
+ operation.database_forwards(
+ self.app_label, schema_editor, old_state, project_state
+ )
return project_state
def unapply(self, project_state, schema_editor, collect_sql=False):
@@ -145,7 +151,9 @@ class Migration:
for operation in self.operations:
# If it's irreversible, error out
if not operation.reversible:
- raise IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
+ raise IrreversibleError(
+ "Operation %s in %s is not reversible" % (operation, self)
+ )
# Preserve new state from previous run to not tamper the same state
# over all operations
new_state = new_state.clone()
@@ -165,15 +173,21 @@ class Migration:
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
continue
- atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
+ atomic_operation = operation.atomic or (
+ self.atomic and operation.atomic is not False
+ )
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
- operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
+ operation.database_backwards(
+ self.app_label, schema_editor, from_state, to_state
+ )
else:
# Normal behaviour
- operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
+ operation.database_backwards(
+ self.app_label, schema_editor, from_state, to_state
+ )
return project_state
def suggest_name(self):
@@ -183,19 +197,19 @@ class Migration:
name to avoid VCS conflicts if possible.
"""
if self.initial:
- return 'initial'
+ return "initial"
raw_fragments = [op.migration_name_fragment for op in self.operations]
fragments = [name for name in raw_fragments if name]
if not fragments or len(fragments) != len(self.operations):
- return 'auto_%s' % get_migration_name_timestamp()
+ return "auto_%s" % get_migration_name_timestamp()
name = fragments[0]
for fragment in fragments[1:]:
- new_name = f'{name}_{fragment}'
+ new_name = f"{name}_{fragment}"
if len(new_name) > 52:
- name = f'{name}_and_more'
+ name = f"{name}_and_more"
break
name = new_name
return name
diff --git a/django/db/migrations/operations/__init__.py b/django/db/migrations/operations/__init__.py
index 119c955868..793969ed12 100644
--- a/django/db/migrations/operations/__init__.py
+++ b/django/db/migrations/operations/__init__.py
@@ -1,17 +1,40 @@
from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
- AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
- AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
- AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
- RemoveIndex, RenameModel,
+ AddConstraint,
+ AddIndex,
+ AlterIndexTogether,
+ AlterModelManagers,
+ AlterModelOptions,
+ AlterModelTable,
+ AlterOrderWithRespectTo,
+ AlterUniqueTogether,
+ CreateModel,
+ DeleteModel,
+ RemoveConstraint,
+ RemoveIndex,
+ RenameModel,
)
from .special import RunPython, RunSQL, SeparateDatabaseAndState
__all__ = [
- 'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
- 'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
- 'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
- 'AddConstraint', 'RemoveConstraint',
- 'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
- 'AlterOrderWithRespectTo', 'AlterModelManagers',
+ "CreateModel",
+ "DeleteModel",
+ "AlterModelTable",
+ "AlterUniqueTogether",
+ "RenameModel",
+ "AlterIndexTogether",
+ "AlterModelOptions",
+ "AddIndex",
+ "RemoveIndex",
+ "AddField",
+ "RemoveField",
+ "AlterField",
+ "RenameField",
+ "AddConstraint",
+ "RemoveConstraint",
+ "SeparateDatabaseAndState",
+ "RunSQL",
+ "RunPython",
+ "AlterOrderWithRespectTo",
+ "AlterModelManagers",
]
diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py
index 18935520f8..7d4dff2597 100644
--- a/django/db/migrations/operations/base.py
+++ b/django/db/migrations/operations/base.py
@@ -56,14 +56,18 @@ class Operation:
Take the state from the previous migration, and mutate it
so that it matches what this migration would perform.
"""
- raise NotImplementedError('subclasses of Operation must provide a state_forwards() method')
+ raise NotImplementedError(
+ "subclasses of Operation must provide a state_forwards() method"
+ )
def database_forwards(self, app_label, schema_editor, from_state, to_state):
"""
Perform the mutation on the database schema in the normal
(forwards) direction.
"""
- raise NotImplementedError('subclasses of Operation must provide a database_forwards() method')
+ raise NotImplementedError(
+ "subclasses of Operation must provide a database_forwards() method"
+ )
def database_backwards(self, app_label, schema_editor, from_state, to_state):
"""
@@ -71,7 +75,9 @@ class Operation:
direction - e.g. if this were CreateModel, it would in fact
drop the model's table.
"""
- raise NotImplementedError('subclasses of Operation must provide a database_backwards() method')
+ raise NotImplementedError(
+ "subclasses of Operation must provide a database_backwards() method"
+ )
def describe(self):
"""
diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py
index 094c3e3cda..cd3aab43ad 100644
--- a/django/db/migrations/operations/fields.py
+++ b/django/db/migrations/operations/fields.py
@@ -23,16 +23,23 @@ class FieldOperation(Operation):
return self.model_name_lower == operation.model_name_lower
def is_same_field_operation(self, operation):
- return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
+ return (
+ self.is_same_model_operation(operation)
+ and self.name_lower == operation.name_lower
+ )
def references_model(self, name, app_label):
name_lower = name.lower()
if name_lower == self.model_name_lower:
return True
if self.field:
- return bool(field_references(
- (app_label, self.model_name_lower), self.field, (app_label, name_lower)
- ))
+ return bool(
+ field_references(
+ (app_label, self.model_name_lower),
+ self.field,
+ (app_label, name_lower),
+ )
+ )
return False
def references_field(self, model_name, name, app_label):
@@ -41,22 +48,27 @@ class FieldOperation(Operation):
if model_name_lower == self.model_name_lower:
if name == self.name:
return True
- elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
+ elif (
+ self.field
+ and hasattr(self.field, "from_fields")
+ and name in self.field.from_fields
+ ):
return True
# Check if this operation remotely references the field.
if self.field is None:
return False
- return bool(field_references(
- (app_label, self.model_name_lower),
- self.field,
- (app_label, model_name_lower),
- name,
- ))
+ return bool(
+ field_references(
+ (app_label, self.model_name_lower),
+ self.field,
+ (app_label, model_name_lower),
+ name,
+ )
+ )
def reduce(self, operation, app_label):
- return (
- super().reduce(operation, app_label) or
- not operation.references_field(self.model_name, self.name, app_label)
+ return super().reduce(operation, app_label) or not operation.references_field(
+ self.model_name, self.name, app_label
)
@@ -69,17 +81,13 @@ class AddField(FieldOperation):
def deconstruct(self):
kwargs = {
- 'model_name': self.model_name,
- 'name': self.name,
- 'field': self.field,
+ "model_name": self.model_name,
+ "name": self.name,
+ "field": self.field,
}
if self.preserve_default is not True:
- kwargs['preserve_default'] = self.preserve_default
- return (
- self.__class__.__name__,
- [],
- kwargs
- )
+ kwargs["preserve_default"] = self.preserve_default
+ return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.add_field(
@@ -107,17 +115,21 @@ class AddField(FieldOperation):
def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
- schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
+ schema_editor.remove_field(
+ from_model, from_model._meta.get_field(self.name)
+ )
def describe(self):
return "Add field %s to %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
- return '%s_%s' % (self.model_name_lower, self.name_lower)
+ return "%s_%s" % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
- if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
+ if isinstance(operation, FieldOperation) and self.is_same_field_operation(
+ operation
+ ):
if isinstance(operation, AlterField):
return [
AddField(
@@ -144,14 +156,10 @@ class RemoveField(FieldOperation):
def deconstruct(self):
kwargs = {
- 'model_name': self.model_name,
- 'name': self.name,
+ "model_name": self.model_name,
+ "name": self.name,
}
- return (
- self.__class__.__name__,
- [],
- kwargs
- )
+ return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.remove_field(app_label, self.model_name_lower, self.name)
@@ -159,7 +167,9 @@ class RemoveField(FieldOperation):
def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
- schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
+ schema_editor.remove_field(
+ from_model, from_model._meta.get_field(self.name)
+ )
def database_backwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -172,11 +182,15 @@ class RemoveField(FieldOperation):
@property
def migration_name_fragment(self):
- return 'remove_%s_%s' % (self.model_name_lower, self.name_lower)
+ return "remove_%s_%s" % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
from .models import DeleteModel
- if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
+
+ if (
+ isinstance(operation, DeleteModel)
+ and operation.name_lower == self.model_name_lower
+ ):
return [operation]
return super().reduce(operation, app_label)
@@ -193,17 +207,13 @@ class AlterField(FieldOperation):
def deconstruct(self):
kwargs = {
- 'model_name': self.model_name,
- 'name': self.name,
- 'field': self.field,
+ "model_name": self.model_name,
+ "name": self.name,
+ "field": self.field,
}
if self.preserve_default is not True:
- kwargs['preserve_default'] = self.preserve_default
- return (
- self.__class__.__name__,
- [],
- kwargs
- )
+ kwargs["preserve_default"] = self.preserve_default
+ return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
state.alter_field(
@@ -234,15 +244,17 @@ class AlterField(FieldOperation):
@property
def migration_name_fragment(self):
- return 'alter_%s_%s' % (self.model_name_lower, self.name_lower)
+ return "alter_%s_%s" % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
- if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
+ if isinstance(operation, RemoveField) and self.is_same_field_operation(
+ operation
+ ):
return [operation]
elif (
- isinstance(operation, RenameField) and
- self.is_same_field_operation(operation) and
- self.field.db_column is None
+ isinstance(operation, RenameField)
+ and self.is_same_field_operation(operation)
+ and self.field.db_column is None
):
return [
operation,
@@ -273,18 +285,16 @@ class RenameField(FieldOperation):
def deconstruct(self):
kwargs = {
- 'model_name': self.model_name,
- 'old_name': self.old_name,
- 'new_name': self.new_name,
+ "model_name": self.model_name,
+ "old_name": self.old_name,
+ "new_name": self.new_name,
}
- return (
- self.__class__.__name__,
- [],
- kwargs
- )
+ return (self.__class__.__name__, [], kwargs)
def state_forwards(self, app_label, state):
- state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name)
+ state.rename_field(
+ app_label, self.model_name_lower, self.old_name, self.new_name
+ )
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -307,11 +317,15 @@ class RenameField(FieldOperation):
)
def describe(self):
- return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
+ return "Rename field %s on %s to %s" % (
+ self.old_name,
+ self.model_name,
+ self.new_name,
+ )
@property
def migration_name_fragment(self):
- return 'rename_%s_%s_%s' % (
+ return "rename_%s_%s_%s" % (
self.old_name_lower,
self.model_name_lower,
self.new_name_lower,
@@ -319,14 +333,15 @@ class RenameField(FieldOperation):
def references_field(self, model_name, name, app_label):
return self.references_model(model_name, app_label) and (
- name.lower() == self.old_name_lower or
- name.lower() == self.new_name_lower
+ name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
)
def reduce(self, operation, app_label):
- if (isinstance(operation, RenameField) and
- self.is_same_model_operation(operation) and
- self.new_name_lower == operation.old_name_lower):
+ if (
+ isinstance(operation, RenameField)
+ and self.is_same_model_operation(operation)
+ and self.new_name_lower == operation.old_name_lower
+ ):
return [
RenameField(
self.model_name,
@@ -336,10 +351,7 @@ class RenameField(FieldOperation):
]
# Skip `FieldOperation.reduce` as we want to run `references_field`
# against self.old_name and self.new_name.
- return (
- super(FieldOperation, self).reduce(operation, app_label) or
- not (
- operation.references_field(self.model_name, self.old_name, app_label) or
- operation.references_field(self.model_name, self.new_name, app_label)
- )
+ return super(FieldOperation, self).reduce(operation, app_label) or not (
+ operation.references_field(self.model_name, self.old_name, app_label)
+ or operation.references_field(self.model_name, self.new_name, app_label)
)
diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py
index 01c44a9a26..90fc31bee5 100644
--- a/django/db/migrations/operations/models.py
+++ b/django/db/migrations/operations/models.py
@@ -5,9 +5,7 @@ from django.db.migrations.utils import field_references, resolve_relation
from django.db.models.options import normalize_together
from django.utils.functional import cached_property
-from .fields import (
- AddField, AlterField, FieldOperation, RemoveField, RenameField,
-)
+from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
def _check_for_duplicates(arg_name, objs):
@@ -32,9 +30,8 @@ class ModelOperation(Operation):
return name.lower() == self.name_lower
def reduce(self, operation, app_label):
- return (
- super().reduce(operation, app_label) or
- self.can_reduce_through(operation, app_label)
+ return super().reduce(operation, app_label) or self.can_reduce_through(
+ operation, app_label
)
def can_reduce_through(self, operation, app_label):
@@ -44,7 +41,7 @@ class ModelOperation(Operation):
class CreateModel(ModelOperation):
"""Create a model's table."""
- serialization_expand_args = ['fields', 'options', 'managers']
+ serialization_expand_args = ["fields", "options", "managers"]
def __init__(self, name, fields, options=None, bases=None, managers=None):
self.fields = fields
@@ -54,40 +51,44 @@ class CreateModel(ModelOperation):
super().__init__(name)
# Sanity-check that there are no duplicated field names, bases, or
# manager names
- _check_for_duplicates('fields', (name for name, _ in self.fields))
- _check_for_duplicates('bases', (
- base._meta.label_lower if hasattr(base, '_meta') else
- base.lower() if isinstance(base, str) else base
- for base in self.bases
- ))
- _check_for_duplicates('managers', (name for name, _ in self.managers))
+ _check_for_duplicates("fields", (name for name, _ in self.fields))
+ _check_for_duplicates(
+ "bases",
+ (
+ base._meta.label_lower
+ if hasattr(base, "_meta")
+ else base.lower()
+ if isinstance(base, str)
+ else base
+ for base in self.bases
+ ),
+ )
+ _check_for_duplicates("managers", (name for name, _ in self.managers))
def deconstruct(self):
kwargs = {
- 'name': self.name,
- 'fields': self.fields,
+ "name": self.name,
+ "fields": self.fields,
}
if self.options:
- kwargs['options'] = self.options
+ kwargs["options"] = self.options
if self.bases and self.bases != (models.Model,):
- kwargs['bases'] = self.bases
- if self.managers and self.managers != [('objects', models.Manager())]:
- kwargs['managers'] = self.managers
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ kwargs["bases"] = self.bases
+ if self.managers and self.managers != [("objects", models.Manager())]:
+ kwargs["managers"] = self.managers
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
- state.add_model(ModelState(
- app_label,
- self.name,
- list(self.fields),
- dict(self.options),
- tuple(self.bases),
- list(self.managers),
- ))
+ state.add_model(
+ ModelState(
+ app_label,
+ self.name,
+ list(self.fields),
+ dict(self.options),
+ tuple(self.bases),
+ list(self.managers),
+ )
+ )
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
@@ -100,7 +101,10 @@ class CreateModel(ModelOperation):
schema_editor.delete_model(model)
def describe(self):
- return "Create %smodel %s" % ("proxy " if self.options.get("proxy", False) else "", self.name)
+ return "Create %smodel %s" % (
+ "proxy " if self.options.get("proxy", False) else "",
+ self.name,
+ )
@property
def migration_name_fragment(self):
@@ -114,22 +118,32 @@ class CreateModel(ModelOperation):
# Check we didn't inherit from the model
reference_model_tuple = (app_label, name_lower)
for base in self.bases:
- if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
- resolve_relation(base, app_label) == reference_model_tuple):
+ if (
+ base is not models.Model
+ and isinstance(base, (models.base.ModelBase, str))
+ and resolve_relation(base, app_label) == reference_model_tuple
+ ):
return True
# Check we have no FKs/M2Ms with it
for _name, field in self.fields:
- if field_references((app_label, self.name_lower), field, reference_model_tuple):
+ if field_references(
+ (app_label, self.name_lower), field, reference_model_tuple
+ ):
return True
return False
def reduce(self, operation, app_label):
- if (isinstance(operation, DeleteModel) and
- self.name_lower == operation.name_lower and
- not self.options.get("proxy", False)):
+ if (
+ isinstance(operation, DeleteModel)
+ and self.name_lower == operation.name_lower
+ and not self.options.get("proxy", False)
+ ):
return []
- elif isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower:
+ elif (
+ isinstance(operation, RenameModel)
+ and self.name_lower == operation.old_name_lower
+ ):
return [
CreateModel(
operation.new_name,
@@ -139,7 +153,10 @@ class CreateModel(ModelOperation):
managers=self.managers,
),
]
- elif isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower:
+ elif (
+ isinstance(operation, AlterModelOptions)
+ and self.name_lower == operation.name_lower
+ ):
options = {**self.options, **operation.options}
for key in operation.ALTER_OPTION_KEYS:
if key not in operation.options:
@@ -153,27 +170,42 @@ class CreateModel(ModelOperation):
managers=self.managers,
),
]
- elif isinstance(operation, AlterTogetherOptionOperation) and self.name_lower == operation.name_lower:
+ elif (
+ isinstance(operation, AlterTogetherOptionOperation)
+ and self.name_lower == operation.name_lower
+ ):
return [
CreateModel(
self.name,
fields=self.fields,
- options={**self.options, **{operation.option_name: operation.option_value}},
+ options={
+ **self.options,
+ **{operation.option_name: operation.option_value},
+ },
bases=self.bases,
managers=self.managers,
),
]
- elif isinstance(operation, AlterOrderWithRespectTo) and self.name_lower == operation.name_lower:
+ elif (
+ isinstance(operation, AlterOrderWithRespectTo)
+ and self.name_lower == operation.name_lower
+ ):
return [
CreateModel(
self.name,
fields=self.fields,
- options={**self.options, 'order_with_respect_to': operation.order_with_respect_to},
+ options={
+ **self.options,
+ "order_with_respect_to": operation.order_with_respect_to,
+ },
bases=self.bases,
managers=self.managers,
),
]
- elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower:
+ elif (
+ isinstance(operation, FieldOperation)
+ and self.name_lower == operation.model_name_lower
+ ):
if isinstance(operation, AddField):
return [
CreateModel(
@@ -199,17 +231,25 @@ class CreateModel(ModelOperation):
]
elif isinstance(operation, RemoveField):
options = self.options.copy()
- for option_name in ('unique_together', 'index_together'):
+ for option_name in ("unique_together", "index_together"):
option = options.pop(option_name, None)
if option:
- option = set(filter(bool, (
- tuple(f for f in fields if f != operation.name_lower) for fields in option
- )))
+ option = set(
+ filter(
+ bool,
+ (
+ tuple(
+ f for f in fields if f != operation.name_lower
+ )
+ for fields in option
+ ),
+ )
+ )
if option:
options[option_name] = option
- order_with_respect_to = options.get('order_with_respect_to')
+ order_with_respect_to = options.get("order_with_respect_to")
if order_with_respect_to == operation.name_lower:
- del options['order_with_respect_to']
+ del options["order_with_respect_to"]
return [
CreateModel(
self.name,
@@ -225,16 +265,19 @@ class CreateModel(ModelOperation):
]
elif isinstance(operation, RenameField):
options = self.options.copy()
- for option_name in ('unique_together', 'index_together'):
+ for option_name in ("unique_together", "index_together"):
option = options.get(option_name)
if option:
options[option_name] = {
- tuple(operation.new_name if f == operation.old_name else f for f in fields)
+ tuple(
+ operation.new_name if f == operation.old_name else f
+ for f in fields
+ )
for fields in option
}
- order_with_respect_to = options.get('order_with_respect_to')
+ order_with_respect_to = options.get("order_with_respect_to")
if order_with_respect_to == operation.old_name:
- options['order_with_respect_to'] = operation.new_name
+ options["order_with_respect_to"] = operation.new_name
return [
CreateModel(
self.name,
@@ -255,13 +298,9 @@ class DeleteModel(ModelOperation):
def deconstruct(self):
kwargs = {
- 'name': self.name,
+ "name": self.name,
}
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
state.remove_model(app_label, self.name_lower)
@@ -286,7 +325,7 @@ class DeleteModel(ModelOperation):
@property
def migration_name_fragment(self):
- return 'delete_%s' % self.name_lower
+ return "delete_%s" % self.name_lower
class RenameModel(ModelOperation):
@@ -307,14 +346,10 @@ class RenameModel(ModelOperation):
def deconstruct(self):
kwargs = {
- 'old_name': self.old_name,
- 'new_name': self.new_name,
+ "old_name": self.old_name,
+ "new_name": self.new_name,
}
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
state.rename_model(app_label, self.old_name, self.new_name)
@@ -341,19 +376,24 @@ class RenameModel(ModelOperation):
related_object.related_model._meta.app_label,
related_object.related_model._meta.model_name,
)
- to_field = to_state.apps.get_model(
- *related_key
- )._meta.get_field(related_object.field.name)
+ to_field = to_state.apps.get_model(*related_key)._meta.get_field(
+ related_object.field.name
+ )
schema_editor.alter_field(
model,
related_object.field,
to_field,
)
# Rename M2M fields whose name is based on this model's name.
- fields = zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many)
+ fields = zip(
+ old_model._meta.local_many_to_many, new_model._meta.local_many_to_many
+ )
for (old_field, new_field) in fields:
# Skip self-referential fields as these are renamed above.
- if new_field.model == new_field.related_model or not new_field.remote_field.through._meta.auto_created:
+ if (
+ new_field.model == new_field.related_model
+ or not new_field.remote_field.through._meta.auto_created
+ ):
continue
# Rename the M2M table that's based on this model's name.
old_m2m_model = old_field.remote_field.through
@@ -372,18 +412,23 @@ class RenameModel(ModelOperation):
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
- self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
+ self.new_name_lower, self.old_name_lower = (
+ self.old_name_lower,
+ self.new_name_lower,
+ )
self.new_name, self.old_name = self.old_name, self.new_name
self.database_forwards(app_label, schema_editor, from_state, to_state)
- self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
+ self.new_name_lower, self.old_name_lower = (
+ self.old_name_lower,
+ self.new_name_lower,
+ )
self.new_name, self.old_name = self.old_name, self.new_name
def references_model(self, name, app_label):
return (
- name.lower() == self.old_name_lower or
- name.lower() == self.new_name_lower
+ name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
)
def describe(self):
@@ -391,11 +436,13 @@ class RenameModel(ModelOperation):
@property
def migration_name_fragment(self):
- return 'rename_%s_%s' % (self.old_name_lower, self.new_name_lower)
+ return "rename_%s_%s" % (self.old_name_lower, self.new_name_lower)
def reduce(self, operation, app_label):
- if (isinstance(operation, RenameModel) and
- self.new_name_lower == operation.old_name_lower):
+ if (
+ isinstance(operation, RenameModel)
+ and self.new_name_lower == operation.old_name_lower
+ ):
return [
RenameModel(
self.old_name,
@@ -404,15 +451,17 @@ class RenameModel(ModelOperation):
]
# Skip `ModelOperation.reduce` as we want to run `references_model`
# against self.new_name.
- return (
- super(ModelOperation, self).reduce(operation, app_label) or
- not operation.references_model(self.new_name, app_label)
- )
+ return super(ModelOperation, self).reduce(
+ operation, app_label
+ ) or not operation.references_model(self.new_name, app_label)
class ModelOptionOperation(ModelOperation):
def reduce(self, operation, app_label):
- if isinstance(operation, (self.__class__, DeleteModel)) and self.name_lower == operation.name_lower:
+ if (
+ isinstance(operation, (self.__class__, DeleteModel))
+ and self.name_lower == operation.name_lower
+ ):
return [operation]
return super().reduce(operation, app_label)
@@ -426,17 +475,13 @@ class AlterModelTable(ModelOptionOperation):
def deconstruct(self):
kwargs = {
- 'name': self.name,
- 'table': self.table,
+ "name": self.name,
+ "table": self.table,
}
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
- state.alter_model_options(app_label, self.name_lower, {'db_table': self.table})
+ state.alter_model_options(app_label, self.name_lower, {"db_table": self.table})
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
@@ -448,7 +493,9 @@ class AlterModelTable(ModelOptionOperation):
new_model._meta.db_table,
)
# Rename M2M fields whose name is based on this model's db_table
- for (old_field, new_field) in zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many):
+ for (old_field, new_field) in zip(
+ old_model._meta.local_many_to_many, new_model._meta.local_many_to_many
+ ):
if new_field.remote_field.through._meta.auto_created:
schema_editor.alter_db_table(
new_field.remote_field.through,
@@ -462,12 +509,12 @@ class AlterModelTable(ModelOptionOperation):
def describe(self):
return "Rename table for %s to %s" % (
self.name,
- self.table if self.table is not None else "(default)"
+ self.table if self.table is not None else "(default)",
)
@property
def migration_name_fragment(self):
- return 'alter_%s_table' % self.name_lower
+ return "alter_%s_table" % self.name_lower
class AlterTogetherOptionOperation(ModelOptionOperation):
@@ -485,14 +532,10 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
def deconstruct(self):
kwargs = {
- 'name': self.name,
+ "name": self.name,
self.option_name: self.option_value,
}
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
state.alter_model_options(
@@ -505,7 +548,7 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
new_model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
old_model = from_state.apps.get_model(app_label, self.name)
- alter_together = getattr(schema_editor, 'alter_%s' % self.option_name)
+ alter_together = getattr(schema_editor, "alter_%s" % self.option_name)
alter_together(
new_model,
getattr(old_model._meta, self.option_name, set()),
@@ -516,27 +559,26 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
return self.database_forwards(app_label, schema_editor, from_state, to_state)
def references_field(self, model_name, name, app_label):
- return (
- self.references_model(model_name, app_label) and
- (
- not self.option_value or
- any((name in fields) for fields in self.option_value)
- )
+ return self.references_model(model_name, app_label) and (
+ not self.option_value
+ or any((name in fields) for fields in self.option_value)
)
def describe(self):
- return "Alter %s for %s (%s constraint(s))" % (self.option_name, self.name, len(self.option_value or ''))
+ return "Alter %s for %s (%s constraint(s))" % (
+ self.option_name,
+ self.name,
+ len(self.option_value or ""),
+ )
@property
def migration_name_fragment(self):
- return 'alter_%s_%s' % (self.name_lower, self.option_name)
+ return "alter_%s_%s" % (self.name_lower, self.option_name)
def can_reduce_through(self, operation, app_label):
- return (
- super().can_reduce_through(operation, app_label) or (
- isinstance(operation, AlterTogetherOptionOperation) and
- type(operation) is not type(self)
- )
+ return super().can_reduce_through(operation, app_label) or (
+ isinstance(operation, AlterTogetherOptionOperation)
+ and type(operation) is not type(self)
)
@@ -545,7 +587,8 @@ class AlterUniqueTogether(AlterTogetherOptionOperation):
Change the value of unique_together to the target one.
Input value of unique_together must be a set of tuples.
"""
- option_name = 'unique_together'
+
+ option_name = "unique_together"
def __init__(self, name, unique_together):
super().__init__(name, unique_together)
@@ -556,6 +599,7 @@ class AlterIndexTogether(AlterTogetherOptionOperation):
Change the value of index_together to the target one.
Input value of index_together must be a set of tuples.
"""
+
option_name = "index_together"
def __init__(self, name, index_together):
@@ -565,7 +609,7 @@ class AlterIndexTogether(AlterTogetherOptionOperation):
class AlterOrderWithRespectTo(ModelOptionOperation):
"""Represent a change with the order_with_respect_to option."""
- option_name = 'order_with_respect_to'
+ option_name = "order_with_respect_to"
def __init__(self, name, order_with_respect_to):
self.order_with_respect_to = order_with_respect_to
@@ -573,14 +617,10 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
def deconstruct(self):
kwargs = {
- 'name': self.name,
- 'order_with_respect_to': self.order_with_respect_to,
+ "name": self.name,
+ "order_with_respect_to": self.order_with_respect_to,
}
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
state.alter_model_options(
@@ -594,11 +634,19 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.name)
# Remove a field if we need to
- if from_model._meta.order_with_respect_to and not to_model._meta.order_with_respect_to:
- schema_editor.remove_field(from_model, from_model._meta.get_field("_order"))
+ if (
+ from_model._meta.order_with_respect_to
+ and not to_model._meta.order_with_respect_to
+ ):
+ schema_editor.remove_field(
+ from_model, from_model._meta.get_field("_order")
+ )
# Add a field if we need to (altering the column is untouched as
# it's likely a rename)
- elif to_model._meta.order_with_respect_to and not from_model._meta.order_with_respect_to:
+ elif (
+ to_model._meta.order_with_respect_to
+ and not from_model._meta.order_with_respect_to
+ ):
field = to_model._meta.get_field("_order")
if not field.has_default():
field.default = 0
@@ -611,20 +659,19 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
self.database_forwards(app_label, schema_editor, from_state, to_state)
def references_field(self, model_name, name, app_label):
- return (
- self.references_model(model_name, app_label) and
- (
- self.order_with_respect_to is None or
- name == self.order_with_respect_to
- )
+ return self.references_model(model_name, app_label) and (
+ self.order_with_respect_to is None or name == self.order_with_respect_to
)
def describe(self):
- return "Set order_with_respect_to on %s to %s" % (self.name, self.order_with_respect_to)
+ return "Set order_with_respect_to on %s to %s" % (
+ self.name,
+ self.order_with_respect_to,
+ )
@property
def migration_name_fragment(self):
- return 'alter_%s_order_with_respect_to' % self.name_lower
+ return "alter_%s_order_with_respect_to" % self.name_lower
class AlterModelOptions(ModelOptionOperation):
@@ -655,14 +702,10 @@ class AlterModelOptions(ModelOptionOperation):
def deconstruct(self):
kwargs = {
- 'name': self.name,
- 'options': self.options,
+ "name": self.name,
+ "options": self.options,
}
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
state.alter_model_options(
@@ -683,24 +726,20 @@ class AlterModelOptions(ModelOptionOperation):
@property
def migration_name_fragment(self):
- return 'alter_%s_options' % self.name_lower
+ return "alter_%s_options" % self.name_lower
class AlterModelManagers(ModelOptionOperation):
"""Alter the model's managers."""
- serialization_expand_args = ['managers']
+ serialization_expand_args = ["managers"]
def __init__(self, name, managers):
self.managers = managers
super().__init__(name)
def deconstruct(self):
- return (
- self.__class__.__qualname__,
- [self.name, self.managers],
- {}
- )
+ return (self.__class__.__qualname__, [self.name, self.managers], {})
def state_forwards(self, app_label, state):
state.alter_model_managers(app_label, self.name_lower, self.managers)
@@ -716,11 +755,11 @@ class AlterModelManagers(ModelOptionOperation):
@property
def migration_name_fragment(self):
- return 'alter_%s_managers' % self.name_lower
+ return "alter_%s_managers" % self.name_lower
class IndexOperation(Operation):
- option_name = 'indexes'
+ option_name = "indexes"
@cached_property
def model_name_lower(self):
@@ -754,8 +793,8 @@ class AddIndex(IndexOperation):
def deconstruct(self):
kwargs = {
- 'model_name': self.model_name,
- 'index': self.index,
+ "model_name": self.model_name,
+ "index": self.index,
}
return (
self.__class__.__qualname__,
@@ -765,20 +804,20 @@ class AddIndex(IndexOperation):
def describe(self):
if self.index.expressions:
- return 'Create index %s on %s on model %s' % (
+ return "Create index %s on %s on model %s" % (
self.index.name,
- ', '.join([str(expression) for expression in self.index.expressions]),
+ ", ".join([str(expression) for expression in self.index.expressions]),
self.model_name,
)
- return 'Create index %s on field(s) %s of model %s' % (
+ return "Create index %s on field(s) %s of model %s" % (
self.index.name,
- ', '.join(self.index.fields),
+ ", ".join(self.index.fields),
self.model_name,
)
@property
def migration_name_fragment(self):
- return '%s_%s' % (self.model_name_lower, self.index.name.lower())
+ return "%s_%s" % (self.model_name_lower, self.index.name.lower())
class RemoveIndex(IndexOperation):
@@ -807,8 +846,8 @@ class RemoveIndex(IndexOperation):
def deconstruct(self):
kwargs = {
- 'model_name': self.model_name,
- 'name': self.name,
+ "model_name": self.model_name,
+ "name": self.name,
}
return (
self.__class__.__qualname__,
@@ -817,15 +856,15 @@ class RemoveIndex(IndexOperation):
)
def describe(self):
- return 'Remove index %s from %s' % (self.name, self.model_name)
+ return "Remove index %s from %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
- return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
+ return "remove_%s_%s" % (self.model_name_lower, self.name.lower())
class AddConstraint(IndexOperation):
- option_name = 'constraints'
+ option_name = "constraints"
def __init__(self, model_name, constraint):
self.model_name = model_name
@@ -845,21 +884,28 @@ class AddConstraint(IndexOperation):
schema_editor.remove_constraint(model, self.constraint)
def deconstruct(self):
- return self.__class__.__name__, [], {
- 'model_name': self.model_name,
- 'constraint': self.constraint,
- }
+ return (
+ self.__class__.__name__,
+ [],
+ {
+ "model_name": self.model_name,
+ "constraint": self.constraint,
+ },
+ )
def describe(self):
- return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
+ return "Create constraint %s on model %s" % (
+ self.constraint.name,
+ self.model_name,
+ )
@property
def migration_name_fragment(self):
- return '%s_%s' % (self.model_name_lower, self.constraint.name.lower())
+ return "%s_%s" % (self.model_name_lower, self.constraint.name.lower())
class RemoveConstraint(IndexOperation):
- option_name = 'constraints'
+ option_name = "constraints"
def __init__(self, model_name, name):
self.model_name = model_name
@@ -883,14 +929,18 @@ class RemoveConstraint(IndexOperation):
schema_editor.add_constraint(model, constraint)
def deconstruct(self):
- return self.__class__.__name__, [], {
- 'model_name': self.model_name,
- 'name': self.name,
- }
+ return (
+ self.__class__.__name__,
+ [],
+ {
+ "model_name": self.model_name,
+ "name": self.name,
+ },
+ )
def describe(self):
- return 'Remove constraint %s from model %s' % (self.name, self.model_name)
+ return "Remove constraint %s from model %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
- return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
+ return "remove_%s_%s" % (self.model_name_lower, self.name.lower())
diff --git a/django/db/migrations/operations/special.py b/django/db/migrations/operations/special.py
index 5a8510ec02..94a6ec72de 100644
--- a/django/db/migrations/operations/special.py
+++ b/django/db/migrations/operations/special.py
@@ -11,7 +11,7 @@ class SeparateDatabaseAndState(Operation):
that affect the state or not the database, or so on.
"""
- serialization_expand_args = ['database_operations', 'state_operations']
+ serialization_expand_args = ["database_operations", "state_operations"]
def __init__(self, database_operations=None, state_operations=None):
self.database_operations = database_operations or []
@@ -20,14 +20,10 @@ class SeparateDatabaseAndState(Operation):
def deconstruct(self):
kwargs = {}
if self.database_operations:
- kwargs['database_operations'] = self.database_operations
+ kwargs["database_operations"] = self.database_operations
if self.state_operations:
- kwargs['state_operations'] = self.state_operations
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ kwargs["state_operations"] = self.state_operations
+ return (self.__class__.__qualname__, [], kwargs)
def state_forwards(self, app_label, state):
for state_operation in self.state_operations:
@@ -38,7 +34,9 @@ class SeparateDatabaseAndState(Operation):
for database_operation in self.database_operations:
to_state = from_state.clone()
database_operation.state_forwards(app_label, to_state)
- database_operation.database_forwards(app_label, schema_editor, from_state, to_state)
+ database_operation.database_forwards(
+ app_label, schema_editor, from_state, to_state
+ )
from_state = to_state
def database_backwards(self, app_label, schema_editor, from_state, to_state):
@@ -54,7 +52,9 @@ class SeparateDatabaseAndState(Operation):
for database_operation in reversed(self.database_operations):
from_state = to_state
to_state = to_states[database_operation]
- database_operation.database_backwards(app_label, schema_editor, from_state, to_state)
+ database_operation.database_backwards(
+ app_label, schema_editor, from_state, to_state
+ )
def describe(self):
return "Custom state/database change combination"
@@ -67,9 +67,12 @@ class RunSQL(Operation):
Also accept a list of operations that represent the state change effected
by this SQL change, in case it's custom column/table creation/deletion.
"""
- noop = ''
- def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False):
+ noop = ""
+
+ def __init__(
+ self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False
+ ):
self.sql = sql
self.reverse_sql = reverse_sql
self.state_operations = state_operations or []
@@ -78,19 +81,15 @@ class RunSQL(Operation):
def deconstruct(self):
kwargs = {
- 'sql': self.sql,
+ "sql": self.sql,
}
if self.reverse_sql is not None:
- kwargs['reverse_sql'] = self.reverse_sql
+ kwargs["reverse_sql"] = self.reverse_sql
if self.state_operations:
- kwargs['state_operations'] = self.state_operations
+ kwargs["state_operations"] = self.state_operations
if self.hints:
- kwargs['hints'] = self.hints
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ kwargs["hints"] = self.hints
+ return (self.__class__.__qualname__, [], kwargs)
@property
def reversible(self):
@@ -101,13 +100,17 @@ class RunSQL(Operation):
state_operation.state_forwards(app_label, state)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
- if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
+ if router.allow_migrate(
+ schema_editor.connection.alias, app_label, **self.hints
+ ):
self._run_sql(schema_editor, self.sql)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_sql is None:
raise NotImplementedError("You cannot reverse this operation")
- if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
+ if router.allow_migrate(
+ schema_editor.connection.alias, app_label, **self.hints
+ ):
self._run_sql(schema_editor, self.reverse_sql)
def describe(self):
@@ -137,7 +140,9 @@ class RunPython(Operation):
reduces_to_sql = False
- def __init__(self, code, reverse_code=None, atomic=None, hints=None, elidable=False):
+ def __init__(
+ self, code, reverse_code=None, atomic=None, hints=None, elidable=False
+ ):
self.atomic = atomic
# Forwards code
if not callable(code):
@@ -155,19 +160,15 @@ class RunPython(Operation):
def deconstruct(self):
kwargs = {
- 'code': self.code,
+ "code": self.code,
}
if self.reverse_code is not None:
- kwargs['reverse_code'] = self.reverse_code
+ kwargs["reverse_code"] = self.reverse_code
if self.atomic is not None:
- kwargs['atomic'] = self.atomic
+ kwargs["atomic"] = self.atomic
if self.hints:
- kwargs['hints'] = self.hints
- return (
- self.__class__.__qualname__,
- [],
- kwargs
- )
+ kwargs["hints"] = self.hints
+ return (self.__class__.__qualname__, [], kwargs)
@property
def reversible(self):
@@ -182,7 +183,9 @@ class RunPython(Operation):
# RunPython has access to all models. Ensure that all models are
# reloaded in case any are delayed.
from_state.clear_delayed_apps_cache()
- if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
+ if router.allow_migrate(
+ schema_editor.connection.alias, app_label, **self.hints
+ ):
# We now execute the Python code in a context that contains a 'models'
# object, representing the versioned models as an app registry.
# We could try to override the global cache, but then people will still
@@ -192,7 +195,9 @@ class RunPython(Operation):
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_code is None:
raise NotImplementedError("You cannot reverse this operation")
- if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
+ if router.allow_migrate(
+ schema_editor.connection.alias, app_label, **self.hints
+ ):
self.reverse_code(from_state.apps, schema_editor)
def describe(self):
diff --git a/django/db/migrations/optimizer.py b/django/db/migrations/optimizer.py
index ee20f62af2..7e5dea2377 100644
--- a/django/db/migrations/optimizer.py
+++ b/django/db/migrations/optimizer.py
@@ -28,7 +28,7 @@ class MigrationOptimizer:
"""
# Internal tracking variable for test assertions about # of loops
if app_label is None:
- raise TypeError('app_label must be a str.')
+ raise TypeError("app_label must be a str.")
self._iterations = 0
while True:
result = self.optimize_inner(operations, app_label)
@@ -43,10 +43,10 @@ class MigrationOptimizer:
for i, operation in enumerate(operations):
right = True # Should we reduce on the right or on the left.
# Compare it to each operation after it
- for j, other in enumerate(operations[i + 1:]):
+ for j, other in enumerate(operations[i + 1 :]):
result = operation.reduce(other, app_label)
if isinstance(result, list):
- in_between = operations[i + 1:i + j + 1]
+ in_between = operations[i + 1 : i + j + 1]
if right:
new_operations.extend(in_between)
new_operations.extend(result)
@@ -59,7 +59,7 @@ class MigrationOptimizer:
# Otherwise keep trying.
new_operations.append(operation)
break
- new_operations.extend(operations[i + j + 2:])
+ new_operations.extend(operations[i + j + 2 :])
return new_operations
elif not result:
# Can't perform a right reduction.
diff --git a/django/db/migrations/questioner.py b/django/db/migrations/questioner.py
index 3460e2b3ab..e1081ab70a 100644
--- a/django/db/migrations/questioner.py
+++ b/django/db/migrations/questioner.py
@@ -35,7 +35,7 @@ class MigrationQuestioner:
# file check will ensure we skip South ones.
try:
app_config = apps.get_app_config(app_label)
- except LookupError: # It's a fake app.
+ except LookupError: # It's a fake app.
return self.defaults.get("ask_initial", False)
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
if migrations_import_path is None:
@@ -88,25 +88,29 @@ class MigrationQuestioner:
class InteractiveMigrationQuestioner(MigrationQuestioner):
- def __init__(self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None):
- super().__init__(defaults=defaults, specified_apps=specified_apps, dry_run=dry_run)
+ def __init__(
+ self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None
+ ):
+ super().__init__(
+ defaults=defaults, specified_apps=specified_apps, dry_run=dry_run
+ )
self.prompt_output = prompt_output or OutputWrapper(sys.stdout)
def _boolean_input(self, question, default=None):
- self.prompt_output.write(f'{question} ', ending='')
+ self.prompt_output.write(f"{question} ", ending="")
result = input()
if not result and default is not None:
return default
while not result or result[0].lower() not in "yn":
- self.prompt_output.write('Please answer yes or no: ', ending='')
+ self.prompt_output.write("Please answer yes or no: ", ending="")
result = input()
return result[0].lower() == "y"
def _choice_input(self, question, choices):
- self.prompt_output.write(f'{question}')
+ self.prompt_output.write(f"{question}")
for i, choice in enumerate(choices):
- self.prompt_output.write(' %s) %s' % (i + 1, choice))
- self.prompt_output.write('Select an option: ', ending='')
+ self.prompt_output.write(" %s) %s" % (i + 1, choice))
+ self.prompt_output.write("Select an option: ", ending="")
result = input()
while True:
try:
@@ -116,10 +120,10 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
else:
if 0 < value <= len(choices):
return value
- self.prompt_output.write('Please select a valid option: ', ending='')
+ self.prompt_output.write("Please select a valid option: ", ending="")
result = input()
- def _ask_default(self, default=''):
+ def _ask_default(self, default=""):
"""
Prompt for a default value.
@@ -127,15 +131,15 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
string) which will be shown to the user and used as the return value
if the user doesn't provide any other input.
"""
- self.prompt_output.write('Please enter the default value as valid Python.')
+ self.prompt_output.write("Please enter the default value as valid Python.")
if default:
self.prompt_output.write(
f"Accept the default '{default}' by pressing 'Enter' or "
f"provide another value."
)
self.prompt_output.write(
- 'The datetime and django.utils.timezone modules are available, so '
- 'it is possible to provide e.g. timezone.now as a value.'
+ "The datetime and django.utils.timezone modules are available, so "
+ "it is possible to provide e.g. timezone.now as a value."
)
self.prompt_output.write("Type 'exit' to exit this prompt")
while True:
@@ -143,19 +147,21 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
prompt = "[default: {}] >>> ".format(default)
else:
prompt = ">>> "
- self.prompt_output.write(prompt, ending='')
+ self.prompt_output.write(prompt, ending="")
code = input()
if not code and default:
code = default
if not code:
- self.prompt_output.write("Please enter some code, or 'exit' (without quotes) to exit.")
+ self.prompt_output.write(
+ "Please enter some code, or 'exit' (without quotes) to exit."
+ )
elif code == "exit":
sys.exit(1)
else:
try:
- return eval(code, {}, {'datetime': datetime, 'timezone': timezone})
+ return eval(code, {}, {"datetime": datetime, "timezone": timezone})
except (SyntaxError, NameError) as e:
- self.prompt_output.write('Invalid input: %s' % e)
+ self.prompt_output.write("Invalid input: %s" % e)
def ask_not_null_addition(self, field_name, model_name):
"""Adding a NOT NULL field to a model."""
@@ -167,10 +173,12 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
f"rows.\n"
f"Please select a fix:",
[
- ("Provide a one-off default now (will be set on all existing "
- "rows with a null value for this column)"),
- 'Quit and manually define a default value in models.py.',
- ]
+ (
+ "Provide a one-off default now (will be set on all existing "
+ "rows with a null value for this column)"
+ ),
+ "Quit and manually define a default value in models.py.",
+ ],
)
if choice == 2:
sys.exit(3)
@@ -188,13 +196,15 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
f"populate existing rows.\n"
f"Please select a fix:",
[
- ("Provide a one-off default now (will be set on all existing "
- "rows with a null value for this column)"),
- 'Ignore for now. Existing rows that contain NULL values '
- 'will have to be handled manually, for example with a '
- 'RunPython or RunSQL operation.',
- 'Quit and manually define a default value in models.py.',
- ]
+ (
+ "Provide a one-off default now (will be set on all existing "
+ "rows with a null value for this column)"
+ ),
+ "Ignore for now. Existing rows that contain NULL values "
+ "will have to be handled manually, for example with a "
+ "RunPython or RunSQL operation.",
+ "Quit and manually define a default value in models.py.",
+ ],
)
if choice == 2:
return NOT_PROVIDED
@@ -206,21 +216,33 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
def ask_rename(self, model_name, old_name, new_name, field_instance):
"""Was this field really renamed?"""
- msg = 'Was %s.%s renamed to %s.%s (a %s)? [y/N]'
- return self._boolean_input(msg % (model_name, old_name, model_name, new_name,
- field_instance.__class__.__name__), False)
+ msg = "Was %s.%s renamed to %s.%s (a %s)? [y/N]"
+ return self._boolean_input(
+ msg
+ % (
+ model_name,
+ old_name,
+ model_name,
+ new_name,
+ field_instance.__class__.__name__,
+ ),
+ False,
+ )
def ask_rename_model(self, old_model_state, new_model_state):
"""Was this model really renamed?"""
- msg = 'Was the model %s.%s renamed to %s? [y/N]'
- return self._boolean_input(msg % (old_model_state.app_label, old_model_state.name,
- new_model_state.name), False)
+ msg = "Was the model %s.%s renamed to %s? [y/N]"
+ return self._boolean_input(
+ msg
+ % (old_model_state.app_label, old_model_state.name, new_model_state.name),
+ False,
+ )
def ask_merge(self, app_label):
return self._boolean_input(
- "\nMerging will only work if the operations printed above do not conflict\n" +
- "with each other (working on different fields or models)\n" +
- 'Should these migration branches be merged? [y/N]',
+ "\nMerging will only work if the operations printed above do not conflict\n"
+ + "with each other (working on different fields or models)\n"
+ + "Should these migration branches be merged? [y/N]",
False,
)
@@ -233,15 +255,15 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
f"default. This is because the database needs something to "
f"populate existing rows.\n",
[
- 'Provide a one-off default now which will be set on all '
- 'existing rows',
- 'Quit and manually define a default value in models.py.',
- ]
+ "Provide a one-off default now which will be set on all "
+ "existing rows",
+ "Quit and manually define a default value in models.py.",
+ ],
)
if choice == 2:
sys.exit(3)
else:
- return self._ask_default(default='timezone.now')
+ return self._ask_default(default="timezone.now")
return None
def ask_unique_callable_default_addition(self, field_name, model_name):
@@ -249,16 +271,16 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
if not self.dry_run:
version = get_docs_version()
choice = self._choice_input(
- f'Callable default on unique field {model_name}.{field_name} '
- f'will not generate unique values upon migrating.\n'
- f'Please choose how to proceed:\n',
+ f"Callable default on unique field {model_name}.{field_name} "
+ f"will not generate unique values upon migrating.\n"
+ f"Please choose how to proceed:\n",
[
- f'Continue making this migration as the first step in '
- f'writing a manual migration to generate unique values '
- f'described here: '
- f'https://docs.djangoproject.com/en/{version}/howto/'
- f'writing-migrations/#migrations-that-add-unique-fields.',
- 'Quit and edit field options in models.py.',
+ f"Continue making this migration as the first step in "
+ f"writing a manual migration to generate unique values "
+ f"described here: "
+ f"https://docs.djangoproject.com/en/{version}/howto/"
+ f"writing-migrations/#migrations-that-add-unique-fields.",
+ "Quit and edit field options in models.py.",
],
)
if choice == 2:
@@ -268,13 +290,19 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
def __init__(
- self, defaults=None, specified_apps=None, dry_run=None, verbosity=1,
+ self,
+ defaults=None,
+ specified_apps=None,
+ dry_run=None,
+ verbosity=1,
log=None,
):
self.verbosity = verbosity
self.log = log
super().__init__(
- defaults=defaults, specified_apps=specified_apps, dry_run=dry_run,
+ defaults=defaults,
+ specified_apps=specified_apps,
+ dry_run=dry_run,
)
def log_lack_of_migration(self, field_name, model_name, reason):
@@ -289,8 +317,8 @@ class NonInteractiveMigrationQuestioner(MigrationQuestioner):
self.log_lack_of_migration(
field_name,
model_name,
- 'it is impossible to add a non-nullable field without specifying '
- 'a default',
+ "it is impossible to add a non-nullable field without specifying "
+ "a default",
)
sys.exit(3)
diff --git a/django/db/migrations/recorder.py b/django/db/migrations/recorder.py
index 1a37c6b7d0..50876a9ee3 100644
--- a/django/db/migrations/recorder.py
+++ b/django/db/migrations/recorder.py
@@ -18,6 +18,7 @@ class MigrationRecorder:
If a migration is unapplied its row is removed from the table. Having
a row in the table always means a migration is applied.
"""
+
_migration_class = None
@classproperty
@@ -27,6 +28,7 @@ class MigrationRecorder:
MigrationRecorder.
"""
if cls._migration_class is None:
+
class Migration(models.Model):
app = models.CharField(max_length=255)
name = models.CharField(max_length=255)
@@ -34,11 +36,11 @@ class MigrationRecorder:
class Meta:
apps = Apps()
- app_label = 'migrations'
- db_table = 'django_migrations'
+ app_label = "migrations"
+ db_table = "django_migrations"
def __str__(self):
- return 'Migration %s for %s' % (self.name, self.app)
+ return "Migration %s for %s" % (self.name, self.app)
cls._migration_class = Migration
return cls._migration_class
@@ -67,7 +69,9 @@ class MigrationRecorder:
with self.connection.schema_editor() as editor:
editor.create_model(self.Migration)
except DatabaseError as exc:
- raise MigrationSchemaMissing("Unable to create the django_migrations table (%s)" % exc)
+ raise MigrationSchemaMissing(
+ "Unable to create the django_migrations table (%s)" % exc
+ )
def applied_migrations(self):
"""
@@ -75,7 +79,10 @@ class MigrationRecorder:
for all applied migrations.
"""
if self.has_table():
- return {(migration.app, migration.name): migration for migration in self.migration_qs}
+ return {
+ (migration.app, migration.name): migration
+ for migration in self.migration_qs
+ }
else:
# If the django_migrations table doesn't exist, then no migrations
# are applied.
diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py
index 9c58f38e28..fb4a1964d9 100644
--- a/django/db/migrations/serializer.py
+++ b/django/db/migrations/serializer.py
@@ -25,12 +25,16 @@ class BaseSerializer:
self.value = value
def serialize(self):
- raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.')
+ raise NotImplementedError(
+ "Subclasses of BaseSerializer must implement the serialize() method."
+ )
class BaseSequenceSerializer(BaseSerializer):
def _format(self):
- raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.')
+ raise NotImplementedError(
+ "Subclasses of BaseSequenceSerializer must implement the _format() method."
+ )
def serialize(self):
imports = set()
@@ -55,19 +59,21 @@ class ChoicesSerializer(BaseSerializer):
class DateTimeSerializer(BaseSerializer):
"""For datetime.*, except datetime.datetime."""
+
def serialize(self):
- return repr(self.value), {'import datetime'}
+ return repr(self.value), {"import datetime"}
class DatetimeDatetimeSerializer(BaseSerializer):
"""For datetime.datetime."""
+
def serialize(self):
if self.value.tzinfo is not None and self.value.tzinfo != utc:
self.value = self.value.astimezone(utc)
imports = ["import datetime"]
if self.value.tzinfo is not None:
imports.append("from django.utils.timezone import utc")
- return repr(self.value).replace('datetime.timezone.utc', 'utc'), set(imports)
+ return repr(self.value).replace("datetime.timezone.utc", "utc"), set(imports)
class DecimalSerializer(BaseSerializer):
@@ -123,8 +129,8 @@ class EnumSerializer(BaseSerializer):
enum_class = self.value.__class__
module = enum_class.__module__
return (
- '%s.%s[%r]' % (module, enum_class.__qualname__, self.value.name),
- {'import %s' % module},
+ "%s.%s[%r]" % (module, enum_class.__qualname__, self.value.name),
+ {"import %s" % module},
)
@@ -142,23 +148,29 @@ class FrozensetSerializer(BaseSequenceSerializer):
class FunctionTypeSerializer(BaseSerializer):
def serialize(self):
- if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type):
+ if getattr(self.value, "__self__", None) and isinstance(
+ self.value.__self__, type
+ ):
klass = self.value.__self__
module = klass.__module__
- return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module}
+ return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {
+ "import %s" % module
+ }
# Further error checking
- if self.value.__name__ == '<lambda>':
+ if self.value.__name__ == "<lambda>":
raise ValueError("Cannot serialize function: lambda")
if self.value.__module__ is None:
raise ValueError("Cannot serialize function %r: No module" % self.value)
module_name = self.value.__module__
- if '<' not in self.value.__qualname__: # Qualname can include <locals>
- return '%s.%s' % (module_name, self.value.__qualname__), {'import %s' % self.value.__module__}
+ if "<" not in self.value.__qualname__: # Qualname can include <locals>
+ return "%s.%s" % (module_name, self.value.__qualname__), {
+ "import %s" % self.value.__module__
+ }
raise ValueError(
- 'Could not find function %s in %s.\n' % (self.value.__name__, module_name)
+ "Could not find function %s in %s.\n" % (self.value.__name__, module_name)
)
@@ -167,11 +179,14 @@ class FunctoolsPartialSerializer(BaseSerializer):
# Serialize functools.partial() arguments
func_string, func_imports = serializer_factory(self.value.func).serialize()
args_string, args_imports = serializer_factory(self.value.args).serialize()
- keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize()
+ keywords_string, keywords_imports = serializer_factory(
+ self.value.keywords
+ ).serialize()
# Add any imports needed by arguments
- imports = {'import functools', *func_imports, *args_imports, *keywords_imports}
+ imports = {"import functools", *func_imports, *args_imports, *keywords_imports}
return (
- 'functools.%s(%s, *%s, **%s)' % (
+ "functools.%s(%s, *%s, **%s)"
+ % (
self.value.__class__.__name__,
func_string,
args_string,
@@ -214,9 +229,10 @@ class ModelManagerSerializer(DeconstructableSerializer):
class OperationSerializer(BaseSerializer):
def serialize(self):
from django.db.migrations.writer import OperationWriter
+
string, imports = OperationWriter(self.value, indentation=0).serialize()
# Nested operation, trailing comma is handled in upper OperationWriter._write()
- return string.rstrip(','), imports
+ return string.rstrip(","), imports
class PathLikeSerializer(BaseSerializer):
@@ -228,22 +244,24 @@ class PathSerializer(BaseSerializer):
def serialize(self):
# Convert concrete paths to pure paths to avoid issues with migrations
# generated on one platform being used on a different platform.
- prefix = 'Pure' if isinstance(self.value, pathlib.Path) else ''
- return 'pathlib.%s%r' % (prefix, self.value), {'import pathlib'}
+ prefix = "Pure" if isinstance(self.value, pathlib.Path) else ""
+ return "pathlib.%s%r" % (prefix, self.value), {"import pathlib"}
class RegexSerializer(BaseSerializer):
def serialize(self):
- regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize()
+ regex_pattern, pattern_imports = serializer_factory(
+ self.value.pattern
+ ).serialize()
# Turn off default implicit flags (e.g. re.U) because regexes with the
# same implicit and explicit flags aren't equal.
- flags = self.value.flags ^ re.compile('').flags
+ flags = self.value.flags ^ re.compile("").flags
regex_flags, flag_imports = serializer_factory(flags).serialize()
- imports = {'import re', *pattern_imports, *flag_imports}
+ imports = {"import re", *pattern_imports, *flag_imports}
args = [regex_pattern]
if flags:
args.append(regex_flags)
- return "re.compile(%s)" % ', '.join(args), imports
+ return "re.compile(%s)" % ", ".join(args), imports
class SequenceSerializer(BaseSequenceSerializer):
@@ -255,12 +273,14 @@ class SetSerializer(BaseSequenceSerializer):
def _format(self):
# Serialize as a set literal except when value is empty because {}
# is an empty dict.
- return '{%s}' if self.value else 'set(%s)'
+ return "{%s}" if self.value else "set(%s)"
class SettingsReferenceSerializer(BaseSerializer):
def serialize(self):
- return "settings.%s" % self.value.setting_name, {"from django.conf import settings"}
+ return "settings.%s" % self.value.setting_name, {
+ "from django.conf import settings"
+ }
class TupleSerializer(BaseSequenceSerializer):
@@ -273,8 +293,8 @@ class TupleSerializer(BaseSequenceSerializer):
class TypeSerializer(BaseSerializer):
def serialize(self):
special_cases = [
- (models.Model, "models.Model", ['from django.db import models']),
- (type(None), 'type(None)', []),
+ (models.Model, "models.Model", ["from django.db import models"]),
+ (type(None), "type(None)", []),
]
for case, string, imports in special_cases:
if case is self.value:
@@ -284,7 +304,9 @@ class TypeSerializer(BaseSerializer):
if module == builtins.__name__:
return self.value.__name__, set()
else:
- return "%s.%s" % (module, self.value.__qualname__), {"import %s" % module}
+ return "%s.%s" % (module, self.value.__qualname__), {
+ "import %s" % module
+ }
class UUIDSerializer(BaseSerializer):
@@ -309,7 +331,11 @@ class Serializer:
(bool, int, type(None), bytes, str, range): BaseSimpleSerializer,
decimal.Decimal: DecimalSerializer,
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
- (types.FunctionType, types.BuiltinFunctionType, types.MethodType): FunctionTypeSerializer,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ types.MethodType,
+ ): FunctionTypeSerializer,
collections.abc.Iterable: IterableSerializer,
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
uuid.UUID: UUIDSerializer,
@@ -320,7 +346,9 @@ class Serializer:
@classmethod
def register(cls, type_, serializer):
if not issubclass(serializer, BaseSerializer):
- raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
+ raise ValueError(
+ "'%s' must inherit from 'BaseSerializer'." % serializer.__name__
+ )
cls._registry[type_] = serializer
@classmethod
@@ -345,7 +373,7 @@ def serializer_factory(value):
if isinstance(value, type):
return TypeSerializer(value)
# Anything that knows how to deconstruct itself.
- if hasattr(value, 'deconstruct'):
+ if hasattr(value, "deconstruct"):
return DeconstructableSerializer(value)
for type_, serializer_cls in Serializer._registry.items():
if isinstance(value, type_):
diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py
index dfb51d579c..0c9416ae45 100644
--- a/django/db/migrations/state.py
+++ b/django/db/migrations/state.py
@@ -4,7 +4,8 @@ from contextlib import contextmanager
from functools import partial
from django.apps import AppConfig
-from django.apps.registry import Apps, apps as global_apps
+from django.apps.registry import Apps
+from django.apps.registry import apps as global_apps
from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.db import models
@@ -21,9 +22,9 @@ from .exceptions import InvalidBasesError
from .utils import resolve_relation
-def _get_app_label_and_model_name(model, app_label=''):
+def _get_app_label_and_model_name(model, app_label=""):
if isinstance(model, str):
- split = model.split('.', 1)
+ split = model.split(".", 1)
return tuple(split) if len(split) == 2 else (app_label, split[0])
else:
return model._meta.app_label, model._meta.model_name
@@ -32,12 +33,17 @@ def _get_app_label_and_model_name(model, app_label=''):
def _get_related_models(m):
"""Return all models that have a direct relationship to the given model."""
related_models = [
- subclass for subclass in m.__subclasses__()
+ subclass
+ for subclass in m.__subclasses__()
if issubclass(subclass, models.Model)
]
related_fields_models = set()
for f in m._meta.get_fields(include_parents=True, include_hidden=True):
- if f.is_relation and f.related_model is not None and not isinstance(f.related_model, str):
+ if (
+ f.is_relation
+ and f.related_model is not None
+ and not isinstance(f.related_model, str)
+ ):
related_fields_models.add(f.model)
related_models.append(f.related_model)
# Reverse accessors of foreign keys to proxy models are attached to their
@@ -73,7 +79,10 @@ def get_related_models_recursive(model):
seen = set()
queue = _get_related_models(model)
for rel_mod in queue:
- rel_app_label, rel_model_name = rel_mod._meta.app_label, rel_mod._meta.model_name
+ rel_app_label, rel_model_name = (
+ rel_mod._meta.app_label,
+ rel_mod._meta.model_name,
+ )
if (rel_app_label, rel_model_name) in seen:
continue
seen.add((rel_app_label, rel_model_name))
@@ -111,7 +120,7 @@ class ProjectState:
self.models[model_key] = model_state
if self._relations is not None:
self.resolve_model_relations(model_key)
- if 'apps' in self.__dict__: # hasattr would cache the property
+ if "apps" in self.__dict__: # hasattr would cache the property
self.reload_model(*model_key)
def remove_model(self, app_label, model_name):
@@ -124,7 +133,7 @@ class ProjectState:
model_relations.pop(model_key, None)
if not model_relations:
del self._relations[related_model_key]
- if 'apps' in self.__dict__: # hasattr would cache the property
+ if "apps" in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(*model_key)
# Need to do this explicitly since unregister_model() doesn't clear
# the cache automatically (#24513)
@@ -139,9 +148,11 @@ class ProjectState:
self.models[app_label, new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, old_name_lower)
- new_remote_model = f'{app_label}.{new_name}'
+ new_remote_model = f"{app_label}.{new_name}"
to_reload = set()
- for model_state, name, field, reference in get_references(self, old_model_tuple):
+ for model_state, name, field, reference in get_references(
+ self, old_model_tuple
+ ):
changed_field = None
if reference.to:
changed_field = field.clone()
@@ -193,16 +204,16 @@ class ProjectState:
self.reload_model(app_label, model_name, delay=True)
def add_index(self, app_label, model_name, index):
- self._append_option(app_label, model_name, 'indexes', index)
+ self._append_option(app_label, model_name, "indexes", index)
def remove_index(self, app_label, model_name, index_name):
- self._remove_option(app_label, model_name, 'indexes', index_name)
+ self._remove_option(app_label, model_name, "indexes", index_name)
def add_constraint(self, app_label, model_name, constraint):
- self._append_option(app_label, model_name, 'constraints', constraint)
+ self._append_option(app_label, model_name, "constraints", constraint)
def remove_constraint(self, app_label, model_name, constraint_name):
- self._remove_option(app_label, model_name, 'constraints', constraint_name)
+ self._remove_option(app_label, model_name, "constraints", constraint_name)
def add_field(self, app_label, model_name, name, field, preserve_default):
# If preserve default is off, don't use the default for future state.
@@ -250,9 +261,8 @@ class ProjectState:
# it's sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
- delay = (
- not field.is_relation and
- not field_is_referenced(self, model_key, (name, field))
+ delay = not field.is_relation and not field_is_referenced(
+ self, model_key, (name, field)
)
self.reload_model(*model_key, delay=delay)
@@ -270,15 +280,17 @@ class ProjectState:
fields[new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
- from_fields = getattr(field, 'from_fields', None)
+ from_fields = getattr(field, "from_fields", None)
if from_fields:
- field.from_fields = tuple([
- new_name if from_field_name == old_name else from_field_name
- for from_field_name in from_fields
- ])
+ field.from_fields = tuple(
+ [
+ new_name if from_field_name == old_name else from_field_name
+ for from_field_name in from_fields
+ ]
+ )
# Fix index/unique_together to refer to the new field.
options = model_state.options
- for option in ('index_together', 'unique_together'):
+ for option in ("index_together", "unique_together"):
if option in options:
options[option] = [
[new_name if n == old_name else n for n in together]
@@ -291,13 +303,15 @@ class ProjectState:
delay = False
if reference.to:
remote_field, to_fields = reference.to
- if getattr(remote_field, 'field_name', None) == old_name:
+ if getattr(remote_field, "field_name", None) == old_name:
remote_field.field_name = new_name
if to_fields:
- field.to_fields = tuple([
- new_name if to_field_name == old_name else to_field_name
- for to_field_name in to_fields
- ])
+ field.to_fields = tuple(
+ [
+ new_name if to_field_name == old_name else to_field_name
+ for to_field_name in to_fields
+ ]
+ )
if self._relations is not None:
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
@@ -335,7 +349,9 @@ class ProjectState:
if field.is_relation:
if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
continue
- rel_app_label, rel_model_name = _get_app_label_and_model_name(field.related_model, app_label)
+ rel_app_label, rel_model_name = _get_app_label_and_model_name(
+ field.related_model, app_label
+ )
direct_related_models.add((rel_app_label, rel_model_name.lower()))
# For all direct related models recursively get all related models.
@@ -357,15 +373,17 @@ class ProjectState:
return related_models
def reload_model(self, app_label, model_name, delay=False):
- if 'apps' in self.__dict__: # hasattr would cache the property
+ if "apps" in self.__dict__: # hasattr would cache the property
related_models = self._find_reload_model(app_label, model_name, delay)
self._reload(related_models)
def reload_models(self, models, delay=True):
- if 'apps' in self.__dict__: # hasattr would cache the property
+ if "apps" in self.__dict__: # hasattr would cache the property
related_models = set()
for app_label, model_name in models:
- related_models.update(self._find_reload_model(app_label, model_name, delay))
+ related_models.update(
+ self._find_reload_model(app_label, model_name, delay)
+ )
self._reload(related_models)
def _reload(self, related_models):
@@ -395,7 +413,12 @@ class ProjectState:
self.apps.render_multiple(states_to_be_rendered)
def update_model_field_relation(
- self, model, model_key, field_name, field, concretes,
+ self,
+ model,
+ model_key,
+ field_name,
+ field,
+ concretes,
):
remote_model_key = resolve_relation(model, *model_key)
if remote_model_key[0] not in self.real_apps and remote_model_key in concretes:
@@ -413,7 +436,11 @@ class ProjectState:
del relations_to_remote_model[model_key]
def resolve_model_field_relations(
- self, model_key, field_name, field, concretes=None,
+ self,
+ model_key,
+ field_name,
+ field,
+ concretes=None,
):
remote_field = field.remote_field
if not remote_field:
@@ -422,13 +449,19 @@ class ProjectState:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
self.update_model_field_relation(
- remote_field.model, model_key, field_name, field, concretes,
+ remote_field.model,
+ model_key,
+ field_name,
+ field,
+ concretes,
)
- through = getattr(remote_field, 'through', None)
+ through = getattr(remote_field, "through", None)
if not through:
return
- self.update_model_field_relation(through, model_key, field_name, field, concretes)
+ self.update_model_field_relation(
+ through, model_key, field_name, field, concretes
+ )
def resolve_model_relations(self, model_key, concretes=None):
if concretes is None:
@@ -455,7 +488,10 @@ class ProjectState:
self._relations[model_key] = self._relations[concretes[model_key]]
def get_concrete_model_key(self, model):
- concrete_models_mapping, _ = self._get_concrete_models_mapping_and_proxy_models()
+ (
+ concrete_models_mapping,
+ _,
+ ) = self._get_concrete_models_mapping_and_proxy_models()
model_key = make_model_tuple(model)
return concrete_models_mapping[model_key]
@@ -464,11 +500,14 @@ class ProjectState:
proxy_models = {}
# Split models to proxy and concrete models.
for model_key, model_state in self.models.items():
- if model_state.options.get('proxy'):
+ if model_state.options.get("proxy"):
proxy_models[model_key] = model_state
# Find a concrete model for the proxy.
- concrete_models_mapping[model_key] = self._find_concrete_model_from_proxy(
- proxy_models, model_state,
+ concrete_models_mapping[
+ model_key
+ ] = self._find_concrete_model_from_proxy(
+ proxy_models,
+ model_state,
)
else:
concrete_models_mapping[model_key] = model_key
@@ -491,14 +530,14 @@ class ProjectState:
models={k: v.clone() for k, v in self.models.items()},
real_apps=self.real_apps,
)
- if 'apps' in self.__dict__:
+ if "apps" in self.__dict__:
new_state.apps = self.apps.clone()
new_state.is_delayed = self.is_delayed
return new_state
def clear_delayed_apps_cache(self):
- if self.is_delayed and 'apps' in self.__dict__:
- del self.__dict__['apps']
+ if self.is_delayed and "apps" in self.__dict__:
+ del self.__dict__["apps"]
@cached_property
def apps(self):
@@ -519,6 +558,7 @@ class ProjectState:
class AppConfigStub(AppConfig):
"""Stub of an AppConfig. Only provides a label and a dict of models."""
+
def __init__(self, label):
self.apps = None
self.models = {}
@@ -537,6 +577,7 @@ class StateApps(Apps):
Subclass of the global Apps registry class to better handle dynamic model
additions and removals.
"""
+
def __init__(self, real_apps, models, ignore_swappable=False):
# Any apps in self.real_apps should have all their models included
# in the render. We don't use the original model instances as there
@@ -550,7 +591,9 @@ class StateApps(Apps):
self.real_models.append(ModelState.from_model(model, exclude_rels=True))
# Populate the app registry with a stub for each application.
app_labels = {model_state.app_label for model_state in models.values()}
- app_configs = [AppConfigStub(label) for label in sorted([*real_apps, *app_labels])]
+ app_configs = [
+ AppConfigStub(label) for label in sorted([*real_apps, *app_labels])
+ ]
super().__init__(app_configs)
# These locks get in the way of copying as implemented in clone(),
@@ -563,7 +606,10 @@ class StateApps(Apps):
# There shouldn't be any operations pending at this point.
from django.core.checks.model_checks import _check_lazy_references
- ignore = {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
+
+ ignore = (
+ {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
+ )
errors = _check_lazy_references(self, ignore=ignore)
if errors:
raise ValueError("\n".join(error.msg for error in errors))
@@ -646,34 +692,36 @@ class ModelState:
assign new ones, as these are not detached during a clone.
"""
- def __init__(self, app_label, name, fields, options=None, bases=None, managers=None):
+ def __init__(
+ self, app_label, name, fields, options=None, bases=None, managers=None
+ ):
self.app_label = app_label
self.name = name
self.fields = dict(fields)
self.options = options or {}
- self.options.setdefault('indexes', [])
- self.options.setdefault('constraints', [])
+ self.options.setdefault("indexes", [])
+ self.options.setdefault("constraints", [])
self.bases = bases or (models.Model,)
self.managers = managers or []
for name, field in self.fields.items():
# Sanity-check that fields are NOT already bound to a model.
- if hasattr(field, 'model'):
+ if hasattr(field, "model"):
raise ValueError(
'ModelState.fields cannot be bound to a model - "%s" is.' % name
)
# Sanity-check that relation fields are NOT referring to a model class.
- if field.is_relation and hasattr(field.related_model, '_meta'):
+ if field.is_relation and hasattr(field.related_model, "_meta"):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.to" does. '
- 'Use a string reference instead.' % name
+ "Use a string reference instead." % name
)
- if field.many_to_many and hasattr(field.remote_field.through, '_meta'):
+ if field.many_to_many and hasattr(field.remote_field.through, "_meta"):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.through" does. '
- 'Use a string reference instead.' % name
+ "Use a string reference instead." % name
)
# Sanity-check that indexes have their name set.
- for index in self.options['indexes']:
+ for index in self.options["indexes"]:
if not index.name:
raise ValueError(
"Indexes passed to ModelState require a name attribute. "
@@ -685,8 +733,8 @@ class ModelState:
return self.name.lower()
def get_field(self, field_name):
- if field_name == '_order':
- field_name = self.options.get('order_with_respect_to', field_name)
+ if field_name == "_order":
+ field_name = self.options.get("order_with_respect_to", field_name)
return self.fields[field_name]
@classmethod
@@ -703,22 +751,28 @@ class ModelState:
try:
fields.append((name, field.clone()))
except TypeError as e:
- raise TypeError("Couldn't reconstruct field %s on %s: %s" % (
- name,
- model._meta.label,
- e,
- ))
+ raise TypeError(
+ "Couldn't reconstruct field %s on %s: %s"
+ % (
+ name,
+ model._meta.label,
+ e,
+ )
+ )
if not exclude_rels:
for field in model._meta.local_many_to_many:
name = field.name
try:
fields.append((name, field.clone()))
except TypeError as e:
- raise TypeError("Couldn't reconstruct m2m field %s on %s: %s" % (
- name,
- model._meta.object_name,
- e,
- ))
+ raise TypeError(
+ "Couldn't reconstruct m2m field %s on %s: %s"
+ % (
+ name,
+ model._meta.object_name,
+ e,
+ )
+ )
# Extract the options
options = {}
for name in DEFAULT_NAMES:
@@ -737,9 +791,11 @@ class ModelState:
for index in indexes:
if not index.name:
index.set_name_with_model(model)
- options['indexes'] = indexes
- elif name == 'constraints':
- options['constraints'] = [con.clone() for con in model._meta.constraints]
+ options["indexes"] = indexes
+ elif name == "constraints":
+ options["constraints"] = [
+ con.clone() for con in model._meta.constraints
+ ]
else:
options[name] = model._meta.original_attrs[name]
# If we're ignoring relationships, remove all field-listing model
@@ -749,8 +805,10 @@ class ModelState:
if key in options:
del options[key]
# Private fields are ignored, so remove options that refer to them.
- elif options.get('order_with_respect_to') in {field.name for field in model._meta.private_fields}:
- del options['order_with_respect_to']
+ elif options.get("order_with_respect_to") in {
+ field.name for field in model._meta.private_fields
+ }:
+ del options["order_with_respect_to"]
def flatten_bases(model):
bases = []
@@ -766,19 +824,19 @@ class ModelState:
# __bases__ we may end up with duplicates and ordering issues, we
# therefore discard any duplicates and reorder the bases according
# to their index in the MRO.
- flattened_bases = sorted(set(flatten_bases(model)), key=lambda x: model.__mro__.index(x))
+ flattened_bases = sorted(
+ set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)
+ )
# Make our record
bases = tuple(
- (
- base._meta.label_lower
- if hasattr(base, "_meta") else
- base
- )
+ (base._meta.label_lower if hasattr(base, "_meta") else base)
for base in flattened_bases
)
# Ensure at least one base inherits from models.Model
- if not any((isinstance(base, str) or issubclass(base, models.Model)) for base in bases):
+ if not any(
+ (isinstance(base, str) or issubclass(base, models.Model)) for base in bases
+ ):
bases = (models.Model,)
managers = []
@@ -805,7 +863,7 @@ class ModelState:
managers.append((manager.name, new_manager))
# Ignore a shimmed default manager called objects if it's the only one.
- if managers == [('objects', default_manager_shim)]:
+ if managers == [("objects", default_manager_shim)]:
managers = []
# Construct the new ModelState
@@ -848,7 +906,7 @@ class ModelState:
def render(self, apps):
"""Create a Model object from our current state into the given apps."""
# First, make a Meta object
- meta_contents = {'app_label': self.app_label, 'apps': apps, **self.options}
+ meta_contents = {"app_label": self.app_label, "apps": apps, **self.options}
meta = type("Meta", (), meta_contents)
# Then, work out our bases
try:
@@ -857,11 +915,13 @@ class ModelState:
for base in self.bases
)
except LookupError:
- raise InvalidBasesError("Cannot resolve one or more bases from %r" % (self.bases,))
+ raise InvalidBasesError(
+ "Cannot resolve one or more bases from %r" % (self.bases,)
+ )
# Clone fields for the body, add other bits.
body = {name: field.clone() for name, field in self.fields.items()}
- body['Meta'] = meta
- body['__module__'] = "__fake__"
+ body["Meta"] = meta
+ body["__module__"] = "__fake__"
# Restore managers
body.update(self.construct_managers())
@@ -869,33 +929,33 @@ class ModelState:
return type(self.name, bases, body)
def get_index_by_name(self, name):
- for index in self.options['indexes']:
+ for index in self.options["indexes"]:
if index.name == name:
return index
raise ValueError("No index named %s on model %s" % (name, self.name))
def get_constraint_by_name(self, name):
- for constraint in self.options['constraints']:
+ for constraint in self.options["constraints"]:
if constraint.name == name:
return constraint
- raise ValueError('No constraint named %s on model %s' % (name, self.name))
+ raise ValueError("No constraint named %s on model %s" % (name, self.name))
def __repr__(self):
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
def __eq__(self, other):
return (
- (self.app_label == other.app_label) and
- (self.name == other.name) and
- (len(self.fields) == len(other.fields)) and
- all(
+ (self.app_label == other.app_label)
+ and (self.name == other.name)
+ and (len(self.fields) == len(other.fields))
+ and all(
k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
for (k1, f1), (k2, f2) in zip(
sorted(self.fields.items()),
sorted(other.fields.items()),
)
- ) and
- (self.options == other.options) and
- (self.bases == other.bases) and
- (self.managers == other.managers)
+ )
+ and (self.options == other.options)
+ and (self.bases == other.bases)
+ and (self.managers == other.managers)
)
diff --git a/django/db/migrations/utils.py b/django/db/migrations/utils.py
index 42a4d90340..2b45a6033b 100644
--- a/django/db/migrations/utils.py
+++ b/django/db/migrations/utils.py
@@ -4,9 +4,9 @@ from collections import namedtuple
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
-FieldReference = namedtuple('FieldReference', 'to through')
+FieldReference = namedtuple("FieldReference", "to through")
-COMPILED_REGEX_TYPE = type(re.compile(''))
+COMPILED_REGEX_TYPE = type(re.compile(""))
class RegexObject:
@@ -33,16 +33,16 @@ def resolve_relation(model, app_label=None, model_name=None):
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
if app_label is None or model_name is None:
raise TypeError(
- 'app_label and model_name must be provided to resolve '
- 'recursive relationships.'
+ "app_label and model_name must be provided to resolve "
+ "recursive relationships."
)
return app_label, model_name
- if '.' in model:
- app_label, model_name = model.split('.', 1)
+ if "." in model:
+ app_label, model_name = model.split(".", 1)
return app_label, model_name.lower()
if app_label is None:
raise TypeError(
- 'app_label must be provided to resolve unscoped model relationships.'
+ "app_label must be provided to resolve unscoped model relationships."
)
return app_label, model.lower()
return model._meta.app_label, model._meta.model_name
@@ -70,24 +70,32 @@ def field_references(
references_to = None
references_through = None
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
- to_fields = getattr(field, 'to_fields', None)
+ to_fields = getattr(field, "to_fields", None)
if (
- reference_field_name is None or
+ reference_field_name is None
+ or
# Unspecified to_field(s).
- to_fields is None or
+ to_fields is None
+ or
# Reference to primary key.
- (None in to_fields and (reference_field is None or reference_field.primary_key)) or
+ (
+ None in to_fields
+ and (reference_field is None or reference_field.primary_key)
+ )
+ or
# Reference to field.
reference_field_name in to_fields
):
references_to = (remote_field, to_fields)
- through = getattr(remote_field, 'through', None)
+ through = getattr(remote_field, "through", None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
through_fields = remote_field.through_fields
if (
- reference_field_name is None or
+ reference_field_name is None
+ or
# Unspecified through_fields.
- through_fields is None or
+ through_fields is None
+ or
# Reference to field.
reference_field_name in through_fields
):
@@ -107,7 +115,9 @@ def get_references(state, model_tuple, field_tuple=()):
"""
for state_model_tuple, model_state in state.models.items():
for name, field in model_state.fields.items():
- reference = field_references(state_model_tuple, field, model_tuple, *field_tuple)
+ reference = field_references(
+ state_model_tuple, field, model_tuple, *field_tuple
+ )
if reference:
yield model_state, name, field, reference
diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py
index 4918261fb0..a59f0c8dcb 100644
--- a/django/db/migrations/writer.py
+++ b/django/db/migrations/writer.py
@@ -1,10 +1,10 @@
-
import os
import re
from importlib import import_module
from django import get_version
from django.apps import apps
+
# SettingsReference imported for backwards compatibility in Django 2.2.
from django.conf import SettingsReference # NOQA
from django.db import migrations
@@ -22,30 +22,30 @@ class OperationWriter:
self.indentation = indentation
def serialize(self):
-
def _write(_arg_name, _arg_value):
- if (_arg_name in self.operation.serialization_expand_args and
- isinstance(_arg_value, (list, tuple, dict))):
+ if _arg_name in self.operation.serialization_expand_args and isinstance(
+ _arg_value, (list, tuple, dict)
+ ):
if isinstance(_arg_value, dict):
- self.feed('%s={' % _arg_name)
+ self.feed("%s={" % _arg_name)
self.indent()
for key, value in _arg_value.items():
key_string, key_imports = MigrationWriter.serialize(key)
arg_string, arg_imports = MigrationWriter.serialize(value)
args = arg_string.splitlines()
if len(args) > 1:
- self.feed('%s: %s' % (key_string, args[0]))
+ self.feed("%s: %s" % (key_string, args[0]))
for arg in args[1:-1]:
self.feed(arg)
- self.feed('%s,' % args[-1])
+ self.feed("%s," % args[-1])
else:
- self.feed('%s: %s,' % (key_string, arg_string))
+ self.feed("%s: %s," % (key_string, arg_string))
imports.update(key_imports)
imports.update(arg_imports)
self.unindent()
- self.feed('},')
+ self.feed("},")
else:
- self.feed('%s=[' % _arg_name)
+ self.feed("%s=[" % _arg_name)
self.indent()
for item in _arg_value:
arg_string, arg_imports = MigrationWriter.serialize(item)
@@ -53,22 +53,22 @@ class OperationWriter:
if len(args) > 1:
for arg in args[:-1]:
self.feed(arg)
- self.feed('%s,' % args[-1])
+ self.feed("%s," % args[-1])
else:
- self.feed('%s,' % arg_string)
+ self.feed("%s," % arg_string)
imports.update(arg_imports)
self.unindent()
- self.feed('],')
+ self.feed("],")
else:
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
args = arg_string.splitlines()
if len(args) > 1:
- self.feed('%s=%s' % (_arg_name, args[0]))
+ self.feed("%s=%s" % (_arg_name, args[0]))
for arg in args[1:-1]:
self.feed(arg)
- self.feed('%s,' % args[-1])
+ self.feed("%s," % args[-1])
else:
- self.feed('%s=%s,' % (_arg_name, arg_string))
+ self.feed("%s=%s," % (_arg_name, arg_string))
imports.update(arg_imports)
imports = set()
@@ -79,10 +79,10 @@ class OperationWriter:
# We can just use the fact we already have that imported,
# otherwise, we need to add an import for the operation class.
if getattr(migrations, name, None) == self.operation.__class__:
- self.feed('migrations.%s(' % name)
+ self.feed("migrations.%s(" % name)
else:
- imports.add('import %s' % (self.operation.__class__.__module__))
- self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
+ imports.add("import %s" % (self.operation.__class__.__module__))
+ self.feed("%s.%s(" % (self.operation.__class__.__module__, name))
self.indent()
@@ -99,7 +99,7 @@ class OperationWriter:
_write(arg_name, arg_value)
self.unindent()
- self.feed('),')
+ self.feed("),")
return self.render(), imports
def indent(self):
@@ -109,10 +109,10 @@ class OperationWriter:
self.indentation -= 1
def feed(self, line):
- self.buff.append(' ' * (self.indentation * 4) + line)
+ self.buff.append(" " * (self.indentation * 4) + line)
def render(self):
- return '\n'.join(self.buff)
+ return "\n".join(self.buff)
class MigrationWriter:
@@ -147,7 +147,10 @@ class MigrationWriter:
dependencies = []
for dependency in self.migration.dependencies:
if dependency[0] == "__setting__":
- dependencies.append(" migrations.swappable_dependency(settings.%s)," % dependency[1])
+ dependencies.append(
+ " migrations.swappable_dependency(settings.%s),"
+ % dependency[1]
+ )
imports.add("from django.conf import settings")
else:
dependencies.append(" %s," % self.serialize(dependency)[0])
@@ -183,24 +186,28 @@ class MigrationWriter:
) % "\n# ".join(sorted(migration_imports))
# If there's a replaces, make a string for it
if self.migration.replaces:
- items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
+ items["replaces_str"] = (
+ "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
+ )
# Hinting that goes into comment
if self.include_header:
- items['migration_header'] = MIGRATION_HEADER_TEMPLATE % {
- 'version': get_version(),
- 'timestamp': now().strftime("%Y-%m-%d %H:%M"),
+ items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
+ "version": get_version(),
+ "timestamp": now().strftime("%Y-%m-%d %H:%M"),
}
else:
- items['migration_header'] = ""
+ items["migration_header"] = ""
if self.migration.initial:
- items['initial_str'] = "\n initial = True\n"
+ items["initial_str"] = "\n initial = True\n"
return MIGRATION_TEMPLATE % items
@property
def basedir(self):
- migrations_package_name, _ = MigrationLoader.migrations_module(self.migration.app_label)
+ migrations_package_name, _ = MigrationLoader.migrations_module(
+ self.migration.app_label
+ )
if migrations_package_name is None:
raise ValueError(
@@ -222,7 +229,11 @@ class MigrationWriter:
# Alright, see if it's a direct submodule of the app
app_config = apps.get_app_config(self.migration.app_label)
- maybe_app_name, _, migrations_package_basename = migrations_package_name.rpartition(".")
+ (
+ maybe_app_name,
+ _,
+ migrations_package_basename,
+ ) = migrations_package_name.rpartition(".")
if app_config.name == maybe_app_name:
return os.path.join(app_config.path, migrations_package_basename)
@@ -246,8 +257,8 @@ class MigrationWriter:
raise ValueError(
"Could not locate an appropriate location to create "
"migrations package %s. Make sure the toplevel "
- "package exists and can be imported." %
- migrations_package_name)
+ "package exists and can be imported." % migrations_package_name
+ )
final_dir = os.path.join(base_dir, *missing_dirs)
os.makedirs(final_dir, exist_ok=True)
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
index a583af2aff..ffca81de91 100644
--- a/django/db/models/__init__.py
+++ b/django/db/models/__init__.py
@@ -5,14 +5,34 @@ from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import * # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
- CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL,
- ProtectedError, RestrictedError,
+ CASCADE,
+ DO_NOTHING,
+ PROTECT,
+ RESTRICT,
+ SET,
+ SET_DEFAULT,
+ SET_NULL,
+ ProtectedError,
+ RestrictedError,
)
from django.db.models.enums import * # NOQA
from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import (
- Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
- OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window,
+ Case,
+ Exists,
+ Expression,
+ ExpressionList,
+ ExpressionWrapper,
+ F,
+ Func,
+ OrderBy,
+ OuterRef,
+ RowRange,
+ Subquery,
+ Value,
+ ValueRange,
+ When,
+ Window,
WindowFrame,
)
from django.db.models.fields import * # NOQA
@@ -30,23 +50,66 @@ from django.db.models.query_utils import FilteredRelation, Q
# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip
from django.db.models.fields.related import ( # isort:skip
- ForeignKey, ForeignObject, OneToOneField, ManyToManyField,
- ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel,
+ ForeignKey,
+ ForeignObject,
+ OneToOneField,
+ ManyToManyField,
+ ForeignObjectRel,
+ ManyToOneRel,
+ ManyToManyRel,
+ OneToOneRel,
)
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
__all__ += [
- 'ObjectDoesNotExist', 'signals',
- 'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
- 'SET_NULL', 'ProtectedError', 'RestrictedError',
- 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
- 'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value',
- 'ValueRange', 'When',
- 'Window', 'WindowFrame',
- 'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform',
- 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects',
- 'DEFERRED', 'Model', 'FilteredRelation',
- 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
- 'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel',
+ "ObjectDoesNotExist",
+ "signals",
+ "CASCADE",
+ "DO_NOTHING",
+ "PROTECT",
+ "RESTRICT",
+ "SET",
+ "SET_DEFAULT",
+ "SET_NULL",
+ "ProtectedError",
+ "RestrictedError",
+ "Case",
+ "Exists",
+ "Expression",
+ "ExpressionList",
+ "ExpressionWrapper",
+ "F",
+ "Func",
+ "OrderBy",
+ "OuterRef",
+ "RowRange",
+ "Subquery",
+ "Value",
+ "ValueRange",
+ "When",
+ "Window",
+ "WindowFrame",
+ "FileField",
+ "ImageField",
+ "JSONField",
+ "OrderWrt",
+ "Lookup",
+ "Transform",
+ "Manager",
+ "Prefetch",
+ "Q",
+ "QuerySet",
+ "prefetch_related_objects",
+ "DEFERRED",
+ "Model",
+ "FilteredRelation",
+ "ForeignKey",
+ "ForeignObject",
+ "OneToOneField",
+ "ManyToManyField",
+ "ForeignObjectRel",
+ "ManyToOneRel",
+ "ManyToManyRel",
+ "OneToOneRel",
]
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index bc31b48d8d..2ffed7cd2c 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -6,28 +6,38 @@ from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import (
- FixDurationInputMixin, NumericOutputFieldMixin,
+ FixDurationInputMixin,
+ NumericOutputFieldMixin,
)
__all__ = [
- 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
+ "Aggregate",
+ "Avg",
+ "Count",
+ "Max",
+ "Min",
+ "StdDev",
+ "Sum",
+ "Variance",
]
class Aggregate(Func):
- template = '%(function)s(%(distinct)s%(expressions)s)'
+ template = "%(function)s(%(distinct)s%(expressions)s)"
contains_aggregate = True
name = None
- filter_template = '%s FILTER (WHERE %%(filter)s)'
+ filter_template = "%s FILTER (WHERE %%(filter)s)"
window_compatible = True
allow_distinct = False
empty_result_set_value = None
- def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
+ def __init__(
+ self, *expressions, distinct=False, filter=None, default=None, **extra
+ ):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_result_set_value is not None:
- raise TypeError(f'{self.__class__.__name__} does not allow default.')
+ raise TypeError(f"{self.__class__.__name__} does not allow default.")
self.distinct = distinct
self.filter = filter
self.default = default
@@ -47,10 +57,14 @@ class Aggregate(Func):
self.filter = self.filter and exprs.pop()
return super().set_source_expressions(exprs)
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
- c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
+ c.filter = c.filter and c.filter.resolve_expression(
+ query, allow_joins, reuse, summarize
+ )
if not summarize:
# Call Aggregate.get_source_expressions() to avoid
# returning self.filter and including that in this loop.
@@ -58,11 +72,18 @@ class Aggregate(Func):
for index, expr in enumerate(expressions):
if expr.contains_aggregate:
before_resolved = self.get_source_expressions()[index]
- name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
- raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
+ name = (
+ before_resolved.name
+ if hasattr(before_resolved, "name")
+ else repr(before_resolved)
+ )
+ raise FieldError(
+ "Cannot compute %s('%s'): '%s' is an aggregate"
+ % (c.name, name, name)
+ )
if (default := c.default) is None:
return c
- if hasattr(default, 'resolve_expression'):
+ if hasattr(default, "resolve_expression"):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
c.default = None # Reset the default argument before wrapping.
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
@@ -72,22 +93,27 @@ class Aggregate(Func):
@property
def default_alias(self):
expressions = self.get_source_expressions()
- if len(expressions) == 1 and hasattr(expressions[0], 'name'):
- return '%s__%s' % (expressions[0].name, self.name.lower())
+ if len(expressions) == 1 and hasattr(expressions[0], "name"):
+ return "%s__%s" % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self, alias=None):
return []
def as_sql(self, compiler, connection, **extra_context):
- extra_context['distinct'] = 'DISTINCT ' if self.distinct else ''
+ extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter:
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
- template = self.filter_template % extra_context.get('template', self.template)
+ template = self.filter_template % extra_context.get(
+ "template", self.template
+ )
sql, params = super().as_sql(
- compiler, connection, template=template, filter=filter_sql,
- **extra_context
+ compiler,
+ connection,
+ template=template,
+ filter=filter_sql,
+ **extra_context,
)
return sql, (*params, *filter_params)
else:
@@ -96,72 +122,74 @@ class Aggregate(Func):
source_expressions = copy.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
- return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
+ return super(Aggregate, copy).as_sql(
+ compiler, connection, **extra_context
+ )
return super().as_sql(compiler, connection, **extra_context)
def _get_repr_options(self):
options = super()._get_repr_options()
if self.distinct:
- options['distinct'] = self.distinct
+ options["distinct"] = self.distinct
if self.filter:
- options['filter'] = self.filter
+ options["filter"] = self.filter
return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
- function = 'AVG'
- name = 'Avg'
+ function = "AVG"
+ name = "Avg"
allow_distinct = True
class Count(Aggregate):
- function = 'COUNT'
- name = 'Count'
+ function = "COUNT"
+ name = "Count"
output_field = IntegerField()
allow_distinct = True
empty_result_set_value = 0
def __init__(self, expression, filter=None, **extra):
- if expression == '*':
+ if expression == "*":
expression = Star()
if isinstance(expression, Star) and filter is not None:
- raise ValueError('Star cannot be used with filter. Please specify a field.')
+ raise ValueError("Star cannot be used with filter. Please specify a field.")
super().__init__(expression, filter=filter, **extra)
class Max(Aggregate):
- function = 'MAX'
- name = 'Max'
+ function = "MAX"
+ name = "Max"
class Min(Aggregate):
- function = 'MIN'
- name = 'Min'
+ function = "MIN"
+ name = "Min"
class StdDev(NumericOutputFieldMixin, Aggregate):
- name = 'StdDev'
+ name = "StdDev"
def __init__(self, expression, sample=False, **extra):
- self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
+ self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
super().__init__(expression, **extra)
def _get_repr_options(self):
- return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
+ return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
class Sum(FixDurationInputMixin, Aggregate):
- function = 'SUM'
- name = 'Sum'
+ function = "SUM"
+ name = "Sum"
allow_distinct = True
class Variance(NumericOutputFieldMixin, Aggregate):
- name = 'Variance'
+ name = "Variance"
def __init__(self, expression, sample=False, **extra):
- self.function = 'VAR_SAMP' if sample else 'VAR_POP'
+ self.function = "VAR_SAMP" if sample else "VAR_POP"
super().__init__(expression, **extra)
def _get_repr_options(self):
- return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'}
+ return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
diff --git a/django/db/models/base.py b/django/db/models/base.py
index 37f6a3dd58..8127a9895a 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -9,28 +9,42 @@ from django.apps import apps
from django.conf import settings
from django.core import checks
from django.core.exceptions import (
- NON_FIELD_ERRORS, FieldDoesNotExist, FieldError, MultipleObjectsReturned,
- ObjectDoesNotExist, ValidationError,
+ NON_FIELD_ERRORS,
+ FieldDoesNotExist,
+ FieldError,
+ MultipleObjectsReturned,
+ ObjectDoesNotExist,
+ ValidationError,
)
from django.db import (
- DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, DatabaseError, connection,
- connections, router, transaction,
-)
-from django.db.models import (
- NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, Value,
+ DEFAULT_DB_ALIAS,
+ DJANGO_VERSION_PICKLE_KEY,
+ DatabaseError,
+ connection,
+ connections,
+ router,
+ transaction,
)
+from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, Value
from django.db.models.constants import LOOKUP_SEP
from django.db.models.constraints import CheckConstraint, UniqueConstraint
from django.db.models.deletion import CASCADE, Collector
from django.db.models.fields.related import (
- ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation,
+ ForeignObjectRel,
+ OneToOneField,
+ lazy_related_operation,
+ resolve_relation,
)
from django.db.models.functions import Coalesce
from django.db.models.manager import Manager
from django.db.models.options import Options
from django.db.models.query import F, Q
from django.db.models.signals import (
- class_prepared, post_init, post_save, pre_init, pre_save,
+ class_prepared,
+ post_init,
+ post_save,
+ pre_init,
+ pre_save,
)
from django.db.models.utils import make_model_tuple
from django.utils.encoding import force_str
@@ -41,10 +55,10 @@ from django.utils.translation import gettext_lazy as _
class Deferred:
def __repr__(self):
- return '<Deferred field>'
+ return "<Deferred field>"
def __str__(self):
- return '<Deferred field>'
+ return "<Deferred field>"
DEFERRED = Deferred()
@@ -58,19 +72,24 @@ def subclass_exception(name, bases, module, attached_to):
that the returned exception class will be added as an attribute to the
'attached_to' class.
"""
- return type(name, bases, {
- '__module__': module,
- '__qualname__': '%s.%s' % (attached_to.__qualname__, name),
- })
+ return type(
+ name,
+ bases,
+ {
+ "__module__": module,
+ "__qualname__": "%s.%s" % (attached_to.__qualname__, name),
+ },
+ )
def _has_contribute_to_class(value):
# Only call contribute_to_class() if it's bound.
- return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
+ return not inspect.isclass(value) and hasattr(value, "contribute_to_class")
class ModelBase(type):
"""Metaclass for all models."""
+
def __new__(cls, name, bases, attrs, **kwargs):
super_new = super().__new__
@@ -81,12 +100,12 @@ class ModelBase(type):
return super_new(cls, name, bases, attrs)
# Create the class.
- module = attrs.pop('__module__')
- new_attrs = {'__module__': module}
- classcell = attrs.pop('__classcell__', None)
+ module = attrs.pop("__module__")
+ new_attrs = {"__module__": module}
+ classcell = attrs.pop("__classcell__", None)
if classcell is not None:
- new_attrs['__classcell__'] = classcell
- attr_meta = attrs.pop('Meta', None)
+ new_attrs["__classcell__"] = classcell
+ attr_meta = attrs.pop("Meta", None)
# Pass all attrs without a (Django-specific) contribute_to_class()
# method to type.__new__() so that they're properly initialized
# (i.e. __set_name__()).
@@ -98,16 +117,16 @@ class ModelBase(type):
new_attrs[obj_name] = obj
new_class = super_new(cls, name, bases, new_attrs, **kwargs)
- abstract = getattr(attr_meta, 'abstract', False)
- meta = attr_meta or getattr(new_class, 'Meta', None)
- base_meta = getattr(new_class, '_meta', None)
+ abstract = getattr(attr_meta, "abstract", False)
+ meta = attr_meta or getattr(new_class, "Meta", None)
+ base_meta = getattr(new_class, "_meta", None)
app_label = None
# Look for an application configuration to attach the model to.
app_config = apps.get_containing_app_config(module)
- if getattr(meta, 'app_label', None) is None:
+ if getattr(meta, "app_label", None) is None:
if app_config is None:
if not abstract:
raise RuntimeError(
@@ -119,33 +138,43 @@ class ModelBase(type):
else:
app_label = app_config.label
- new_class.add_to_class('_meta', Options(meta, app_label))
+ new_class.add_to_class("_meta", Options(meta, app_label))
if not abstract:
new_class.add_to_class(
- 'DoesNotExist',
+ "DoesNotExist",
subclass_exception(
- 'DoesNotExist',
+ "DoesNotExist",
tuple(
- x.DoesNotExist for x in parents if hasattr(x, '_meta') and not x._meta.abstract
- ) or (ObjectDoesNotExist,),
+ x.DoesNotExist
+ for x in parents
+ if hasattr(x, "_meta") and not x._meta.abstract
+ )
+ or (ObjectDoesNotExist,),
module,
- attached_to=new_class))
+ attached_to=new_class,
+ ),
+ )
new_class.add_to_class(
- 'MultipleObjectsReturned',
+ "MultipleObjectsReturned",
subclass_exception(
- 'MultipleObjectsReturned',
+ "MultipleObjectsReturned",
tuple(
- x.MultipleObjectsReturned for x in parents if hasattr(x, '_meta') and not x._meta.abstract
- ) or (MultipleObjectsReturned,),
+ x.MultipleObjectsReturned
+ for x in parents
+ if hasattr(x, "_meta") and not x._meta.abstract
+ )
+ or (MultipleObjectsReturned,),
module,
- attached_to=new_class))
+ attached_to=new_class,
+ ),
+ )
if base_meta and not base_meta.abstract:
# Non-abstract child classes inherit some attributes from their
# non-abstract parent (unless an ABC comes before it in the
# method resolution order).
- if not hasattr(meta, 'ordering'):
+ if not hasattr(meta, "ordering"):
new_class._meta.ordering = base_meta.ordering
- if not hasattr(meta, 'get_latest_by'):
+ if not hasattr(meta, "get_latest_by"):
new_class._meta.get_latest_by = base_meta.get_latest_by
is_proxy = new_class._meta.proxy
@@ -153,7 +182,9 @@ class ModelBase(type):
# If the model is a proxy, ensure that the base class
# hasn't been swapped out.
if is_proxy and base_meta and base_meta.swapped:
- raise TypeError("%s cannot proxy the swapped model '%s'." % (name, base_meta.swapped))
+ raise TypeError(
+ "%s cannot proxy the swapped model '%s'." % (name, base_meta.swapped)
+ )
# Add remaining attributes (those with a contribute_to_class() method)
# to the class.
@@ -164,14 +195,14 @@ class ModelBase(type):
new_fields = chain(
new_class._meta.local_fields,
new_class._meta.local_many_to_many,
- new_class._meta.private_fields
+ new_class._meta.private_fields,
)
field_names = {f.name for f in new_fields}
# Basic setup for proxy models.
if is_proxy:
base = None
- for parent in [kls for kls in parents if hasattr(kls, '_meta')]:
+ for parent in [kls for kls in parents if hasattr(kls, "_meta")]:
if parent._meta.abstract:
if parent._meta.fields:
raise TypeError(
@@ -183,9 +214,14 @@ class ModelBase(type):
if base is None:
base = parent
elif parent._meta.concrete_model is not base._meta.concrete_model:
- raise TypeError("Proxy model '%s' has more than one non-abstract model base class." % name)
+ raise TypeError(
+ "Proxy model '%s' has more than one non-abstract model base class."
+ % name
+ )
if base is None:
- raise TypeError("Proxy model '%s' has no non-abstract model base class." % name)
+ raise TypeError(
+ "Proxy model '%s' has no non-abstract model base class." % name
+ )
new_class._meta.setup_proxy(base)
new_class._meta.concrete_model = base._meta.concrete_model
else:
@@ -195,7 +231,7 @@ class ModelBase(type):
parent_links = {}
for base in reversed([new_class] + parents):
# Conceptually equivalent to `if base is Model`.
- if not hasattr(base, '_meta'):
+ if not hasattr(base, "_meta"):
continue
# Skip concrete parent classes.
if base != new_class and not base._meta.abstract:
@@ -210,7 +246,7 @@ class ModelBase(type):
inherited_attributes = set()
# Do the appropriate setup for any model parents.
for base in new_class.mro():
- if base not in parents or not hasattr(base, '_meta'):
+ if base not in parents or not hasattr(base, "_meta"):
# Things without _meta aren't functional models, so they're
# uninteresting parents.
inherited_attributes.update(base.__dict__)
@@ -223,8 +259,9 @@ class ModelBase(type):
for field in parent_fields:
if field.name in field_names:
raise FieldError(
- 'Local field %r in class %r clashes with field of '
- 'the same name from base class %r.' % (
+ "Local field %r in class %r clashes with field of "
+ "the same name from base class %r."
+ % (
field.name,
name,
base.__name__,
@@ -239,7 +276,7 @@ class ModelBase(type):
if base_key in parent_links:
field = parent_links[base_key]
elif not is_proxy:
- attr_name = '%s_ptr' % base._meta.model_name
+ attr_name = "%s_ptr" % base._meta.model_name
field = OneToOneField(
base,
on_delete=CASCADE,
@@ -252,7 +289,8 @@ class ModelBase(type):
raise FieldError(
"Auto-generated field '%s' in class %r for "
"parent_link to base class %r clashes with "
- "declared field of the same name." % (
+ "declared field of the same name."
+ % (
attr_name,
name,
base.__name__,
@@ -271,9 +309,11 @@ class ModelBase(type):
# Add fields from abstract base class if it wasn't overridden.
for field in parent_fields:
- if (field.name not in field_names and
- field.name not in new_class.__dict__ and
- field.name not in inherited_attributes):
+ if (
+ field.name not in field_names
+ and field.name not in new_class.__dict__
+ and field.name not in inherited_attributes
+ ):
new_field = copy.deepcopy(field)
new_class.add_to_class(field.name, new_field)
# Replace parent links defined on this base by the new
@@ -292,8 +332,9 @@ class ModelBase(type):
if field.name in field_names:
if not base._meta.abstract:
raise FieldError(
- 'Local field %r in class %r clashes with field of '
- 'the same name from base class %r.' % (
+ "Local field %r in class %r clashes with field of "
+ "the same name from base class %r."
+ % (
field.name,
name,
base.__name__,
@@ -307,7 +348,9 @@ class ModelBase(type):
# Copy indexes so that index names are unique when models extend an
# abstract model.
- new_class._meta.indexes = [copy.deepcopy(idx) for idx in new_class._meta.indexes]
+ new_class._meta.indexes = [
+ copy.deepcopy(idx) for idx in new_class._meta.indexes
+ ]
if abstract:
# Abstract base models can't be instantiated and don't appear in
@@ -333,8 +376,12 @@ class ModelBase(type):
opts._prepare(cls)
if opts.order_with_respect_to:
- cls.get_next_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=True)
- cls.get_previous_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=False)
+ cls.get_next_in_order = partialmethod(
+ cls._get_next_or_previous_in_order, is_next=True
+ )
+ cls.get_previous_in_order = partialmethod(
+ cls._get_next_or_previous_in_order, is_next=False
+ )
# Defer creating accessors on the foreign class until it has been
# created and registered. If remote_field is None, we're ordering
@@ -348,21 +395,26 @@ class ModelBase(type):
# Give the class a docstring -- its definition.
if cls.__doc__ is None:
- cls.__doc__ = "%s(%s)" % (cls.__name__, ", ".join(f.name for f in opts.fields))
+ cls.__doc__ = "%s(%s)" % (
+ cls.__name__,
+ ", ".join(f.name for f in opts.fields),
+ )
- get_absolute_url_override = settings.ABSOLUTE_URL_OVERRIDES.get(opts.label_lower)
+ get_absolute_url_override = settings.ABSOLUTE_URL_OVERRIDES.get(
+ opts.label_lower
+ )
if get_absolute_url_override:
- setattr(cls, 'get_absolute_url', get_absolute_url_override)
+ setattr(cls, "get_absolute_url", get_absolute_url_override)
if not opts.managers:
- if any(f.name == 'objects' for f in opts.fields):
+ if any(f.name == "objects" for f in opts.fields):
raise ValueError(
"Model %s must specify a custom Manager, because it has a "
"field named 'objects'." % cls.__name__
)
manager = Manager()
manager.auto_created = True
- cls.add_to_class('objects', manager)
+ cls.add_to_class("objects", manager)
# Set the name of _meta.indexes. This can't be done in
# Options.contribute_to_class() because fields haven't been added to
@@ -399,6 +451,7 @@ class ModelStateCacheDescriptor:
class ModelState:
"""Store model instance state."""
+
db = None
# If true, uniqueness validation checks will consider this a new, unsaved
# object. Necessary for correct validation of new instances of objects with
@@ -410,19 +463,18 @@ class ModelState:
def __getstate__(self):
state = self.__dict__.copy()
- if 'fields_cache' in state:
- state['fields_cache'] = self.fields_cache.copy()
+ if "fields_cache" in state:
+ state["fields_cache"] = self.fields_cache.copy()
# Manager instances stored in related_managers_cache won't necessarily
# be deserializable if they were dynamically created via an inner
# scope, e.g. create_forward_many_to_many_manager() and
# create_generic_related_manager().
- if 'related_managers_cache' in state:
- state['related_managers_cache'] = {}
+ if "related_managers_cache" in state:
+ state["related_managers_cache"] = {}
return state
class Model(metaclass=ModelBase):
-
def __init__(self, *args, **kwargs):
# Alias some things as locals to avoid repeat global lookups
cls = self.__class__
@@ -430,7 +482,7 @@ class Model(metaclass=ModelBase):
_setattr = setattr
_DEFERRED = DEFERRED
if opts.abstract:
- raise TypeError('Abstract models cannot be instantiated.')
+ raise TypeError("Abstract models cannot be instantiated.")
pre_init.send(sender=cls, args=args, kwargs=kwargs)
@@ -529,10 +581,10 @@ class Model(metaclass=ModelBase):
if value is not _DEFERRED:
_setattr(self, prop, value)
if unexpected:
- unexpected_names = ', '.join(repr(n) for n in unexpected)
+ unexpected_names = ", ".join(repr(n) for n in unexpected)
raise TypeError(
- f'{cls.__name__}() got unexpected keyword arguments: '
- f'{unexpected_names}'
+ f"{cls.__name__}() got unexpected keyword arguments: "
+ f"{unexpected_names}"
)
super().__init__()
post_init.send(sender=cls, instance=self)
@@ -551,10 +603,10 @@ class Model(metaclass=ModelBase):
return new
def __repr__(self):
- return '<%s: %s>' % (self.__class__.__name__, self)
+ return "<%s: %s>" % (self.__class__.__name__, self)
def __str__(self):
- return '%s object (%s)' % (self.__class__.__name__, self.pk)
+ return "%s object (%s)" % (self.__class__.__name__, self.pk)
def __eq__(self, other):
if not isinstance(other, Model):
@@ -580,7 +632,7 @@ class Model(metaclass=ModelBase):
def __getstate__(self):
"""Hook to allow choosing the attributes to pickle."""
state = self.__dict__.copy()
- state['_state'] = copy.copy(state['_state'])
+ state["_state"] = copy.copy(state["_state"])
# memoryview cannot be pickled, so cast it to bytes and store
# separately.
_memoryview_attrs = []
@@ -588,7 +640,7 @@ class Model(metaclass=ModelBase):
if isinstance(value, memoryview):
_memoryview_attrs.append((attr, bytes(value)))
if _memoryview_attrs:
- state['_memoryview_attrs'] = _memoryview_attrs
+ state["_memoryview_attrs"] = _memoryview_attrs
for attr, value in _memoryview_attrs:
state.pop(attr)
return state
@@ -610,8 +662,8 @@ class Model(metaclass=ModelBase):
RuntimeWarning,
stacklevel=2,
)
- if '_memoryview_attrs' in state:
- for attr, value in state.pop('_memoryview_attrs'):
+ if "_memoryview_attrs" in state:
+ for attr, value in state.pop("_memoryview_attrs"):
state[attr] = memoryview(value)
self.__dict__.update(state)
@@ -632,7 +684,8 @@ class Model(metaclass=ModelBase):
Return a set containing names of deferred fields for this instance.
"""
return {
- f.attname for f in self._meta.concrete_fields
+ f.attname
+ for f in self._meta.concrete_fields
if f.attname not in self.__dict__
}
@@ -654,7 +707,7 @@ class Model(metaclass=ModelBase):
if fields is None:
self._prefetched_objects_cache = {}
else:
- prefetched_objects_cache = getattr(self, '_prefetched_objects_cache', ())
+ prefetched_objects_cache = getattr(self, "_prefetched_objects_cache", ())
for field in fields:
if field in prefetched_objects_cache:
del prefetched_objects_cache[field]
@@ -664,10 +717,13 @@ class Model(metaclass=ModelBase):
if any(LOOKUP_SEP in f for f in fields):
raise ValueError(
'Found "%s" in fields argument. Relations and transforms '
- 'are not allowed in fields.' % LOOKUP_SEP)
+ "are not allowed in fields." % LOOKUP_SEP
+ )
- hints = {'instance': self}
- db_instance_qs = self.__class__._base_manager.db_manager(using, hints=hints).filter(pk=self.pk)
+ hints = {"instance": self}
+ db_instance_qs = self.__class__._base_manager.db_manager(
+ using, hints=hints
+ ).filter(pk=self.pk)
# Use provided fields, if not set then reload all non-deferred fields.
deferred_fields = self.get_deferred_fields()
@@ -675,8 +731,11 @@ class Model(metaclass=ModelBase):
fields = list(fields)
db_instance_qs = db_instance_qs.only(*fields)
elif deferred_fields:
- fields = [f.attname for f in self._meta.concrete_fields
- if f.attname not in deferred_fields]
+ fields = [
+ f.attname
+ for f in self._meta.concrete_fields
+ if f.attname not in deferred_fields
+ ]
db_instance_qs = db_instance_qs.only(*fields)
db_instance = db_instance_qs.get()
@@ -714,8 +773,9 @@ class Model(metaclass=ModelBase):
return getattr(self, field_name)
return getattr(self, field.attname)
- def save(self, force_insert=False, force_update=False, using=None,
- update_fields=None):
+ def save(
+ self, force_insert=False, force_update=False, using=None, update_fields=None
+ ):
"""
Save the current instance. Override this in a subclass if you want to
control the saving process.
@@ -724,7 +784,7 @@ class Model(metaclass=ModelBase):
that the "save" must be an SQL insert or update (or equivalent for
non-SQL backends), respectively. Normally, they should not be set.
"""
- self._prepare_related_fields_for_save(operation_name='save')
+ self._prepare_related_fields_for_save(operation_name="save")
using = using or router.db_for_write(self.__class__, instance=self)
if force_insert and (force_update or update_fields):
@@ -752,9 +812,9 @@ class Model(metaclass=ModelBase):
if non_model_fields:
raise ValueError(
- 'The following fields do not exist in this model, are m2m '
- 'fields, or are non-concrete fields: %s'
- % ', '.join(non_model_fields)
+ "The following fields do not exist in this model, are m2m "
+ "fields, or are non-concrete fields: %s"
+ % ", ".join(non_model_fields)
)
# If saving to the same database, and this model is deferred, then
@@ -762,18 +822,29 @@ class Model(metaclass=ModelBase):
elif not force_insert and deferred_fields and using == self._state.db:
field_names = set()
for field in self._meta.concrete_fields:
- if not field.primary_key and not hasattr(field, 'through'):
+ if not field.primary_key and not hasattr(field, "through"):
field_names.add(field.attname)
loaded_fields = field_names.difference(deferred_fields)
if loaded_fields:
update_fields = frozenset(loaded_fields)
- self.save_base(using=using, force_insert=force_insert,
- force_update=force_update, update_fields=update_fields)
+ self.save_base(
+ using=using,
+ force_insert=force_insert,
+ force_update=force_update,
+ update_fields=update_fields,
+ )
+
save.alters_data = True
- def save_base(self, raw=False, force_insert=False,
- force_update=False, using=None, update_fields=None):
+ def save_base(
+ self,
+ raw=False,
+ force_insert=False,
+ force_update=False,
+ using=None,
+ update_fields=None,
+ ):
"""
Handle the parts of saving which should be done only once per save,
yet need to be done in raw saves, too. This includes some sanity
@@ -793,7 +864,10 @@ class Model(metaclass=ModelBase):
meta = cls._meta
if not meta.auto_created:
pre_save.send(
- sender=origin, instance=self, raw=raw, using=using,
+ sender=origin,
+ instance=self,
+ raw=raw,
+ using=using,
update_fields=update_fields,
)
# A transaction isn't needed if one query is issued.
@@ -806,8 +880,12 @@ class Model(metaclass=ModelBase):
if not raw:
parent_inserted = self._save_parents(cls, using, update_fields)
updated = self._save_table(
- raw, cls, force_insert or parent_inserted,
- force_update, using, update_fields,
+ raw,
+ cls,
+ force_insert or parent_inserted,
+ force_update,
+ using,
+ update_fields,
)
# Store the database on which the object was saved
self._state.db = using
@@ -817,8 +895,12 @@ class Model(metaclass=ModelBase):
# Signal that the save is complete
if not meta.auto_created:
post_save.send(
- sender=origin, instance=self, created=(not updated),
- update_fields=update_fields, raw=raw, using=using,
+ sender=origin,
+ instance=self,
+ created=(not updated),
+ update_fields=update_fields,
+ raw=raw,
+ using=using,
)
save_base.alters_data = True
@@ -829,12 +911,19 @@ class Model(metaclass=ModelBase):
inserted = False
for parent, field in meta.parents.items():
# Make sure the link fields are synced between parent and self.
- if (field and getattr(self, parent._meta.pk.attname) is None and
- getattr(self, field.attname) is not None):
+ if (
+ field
+ and getattr(self, parent._meta.pk.attname) is None
+ and getattr(self, field.attname) is not None
+ ):
setattr(self, parent._meta.pk.attname, getattr(self, field.attname))
- parent_inserted = self._save_parents(cls=parent, using=using, update_fields=update_fields)
+ parent_inserted = self._save_parents(
+ cls=parent, using=using, update_fields=update_fields
+ )
updated = self._save_table(
- cls=parent, using=using, update_fields=update_fields,
+ cls=parent,
+ using=using,
+ update_fields=update_fields,
force_insert=parent_inserted,
)
if not updated:
@@ -851,8 +940,15 @@ class Model(metaclass=ModelBase):
field.delete_cached_value(self)
return inserted
- def _save_table(self, raw=False, cls=None, force_insert=False,
- force_update=False, using=None, update_fields=None):
+ def _save_table(
+ self,
+ raw=False,
+ cls=None,
+ force_insert=False,
+ force_update=False,
+ using=None,
+ update_fields=None,
+ ):
"""
Do the heavy-lifting involved in saving. Update or insert the data
for a single table.
@@ -861,8 +957,11 @@ class Model(metaclass=ModelBase):
non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]
if update_fields:
- non_pks = [f for f in non_pks
- if f.name in update_fields or f.attname in update_fields]
+ non_pks = [
+ f
+ for f in non_pks
+ if f.name in update_fields or f.attname in update_fields
+ ]
pk_val = self._get_pk_val(meta)
if pk_val is None:
@@ -874,21 +973,28 @@ class Model(metaclass=ModelBase):
updated = False
# Skip an UPDATE when adding an instance and primary key has a default.
if (
- not raw and
- not force_insert and
- self._state.adding and
- meta.pk.default and
- meta.pk.default is not NOT_PROVIDED
+ not raw
+ and not force_insert
+ and self._state.adding
+ and meta.pk.default
+ and meta.pk.default is not NOT_PROVIDED
):
force_insert = True
# If possible, try an UPDATE. If that doesn't update anything, do an INSERT.
if pk_set and not force_insert:
base_qs = cls._base_manager.using(using)
- values = [(f, None, (getattr(self, f.attname) if raw else f.pre_save(self, False)))
- for f in non_pks]
+ values = [
+ (
+ f,
+ None,
+ (getattr(self, f.attname) if raw else f.pre_save(self, False)),
+ )
+ for f in non_pks
+ ]
forced_update = update_fields or force_update
- updated = self._do_update(base_qs, using, pk_val, values, update_fields,
- forced_update)
+ updated = self._do_update(
+ base_qs, using, pk_val, values, update_fields, forced_update
+ )
if force_update and not updated:
raise DatabaseError("Forced update did not affect any rows.")
if update_fields and not updated:
@@ -899,18 +1005,26 @@ class Model(metaclass=ModelBase):
# autopopulate the _order field
field = meta.order_with_respect_to
filter_args = field.get_filter_kwargs_for_object(self)
- self._order = cls._base_manager.using(using).filter(**filter_args).aggregate(
- _order__max=Coalesce(
- ExpressionWrapper(Max('_order') + Value(1), output_field=IntegerField()),
- Value(0),
- ),
- )['_order__max']
+ self._order = (
+ cls._base_manager.using(using)
+ .filter(**filter_args)
+ .aggregate(
+ _order__max=Coalesce(
+ ExpressionWrapper(
+ Max("_order") + Value(1), output_field=IntegerField()
+ ),
+ Value(0),
+ ),
+ )["_order__max"]
+ )
fields = meta.local_concrete_fields
if not pk_set:
fields = [f for f in fields if f is not meta.auto_field]
returning_fields = meta.db_returning_fields
- results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw)
+ results = self._do_insert(
+ cls._base_manager, using, fields, returning_fields, raw
+ )
if results:
for value, field in zip(results[0], returning_fields):
setattr(self, field.attname, value)
@@ -931,7 +1045,8 @@ class Model(metaclass=ModelBase):
return update_fields is not None or filtered.exists()
if self._meta.select_on_save and not forced_update:
return (
- filtered.exists() and
+ filtered.exists()
+ and
# It may happen that the object is deleted from the DB right after
# this check, causing the subsequent UPDATE to return zero matching
# rows. The same result can occur in some rare cases when the
@@ -949,8 +1064,11 @@ class Model(metaclass=ModelBase):
return the newly created data for the model.
"""
return manager._insert(
- [self], fields=fields, returning_fields=returning_fields,
- using=using, raw=raw,
+ [self],
+ fields=fields,
+ returning_fields=returning_fields,
+ using=using,
+ raw=raw,
)
def _prepare_related_fields_for_save(self, operation_name, fields=None):
@@ -986,7 +1104,9 @@ class Model(metaclass=ModelBase):
setattr(self, field.attname, obj.pk)
# If the relationship's pk/to_field was changed, clear the
# cached relationship.
- if getattr(obj, field.target_field.attname) != getattr(self, field.attname):
+ if getattr(obj, field.target_field.attname) != getattr(
+ self, field.attname
+ ):
field.delete_cached_value(self)
def delete(self, using=None, keep_parents=False):
@@ -1006,42 +1126,59 @@ class Model(metaclass=ModelBase):
value = getattr(self, field.attname)
choices_dict = dict(make_hashable(field.flatchoices))
# force_str() to coerce lazy strings.
- return force_str(choices_dict.get(make_hashable(value), value), strings_only=True)
+ return force_str(
+ choices_dict.get(make_hashable(value), value), strings_only=True
+ )
def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
if not self.pk:
raise ValueError("get_next/get_previous cannot be used on unsaved objects.")
- op = 'gt' if is_next else 'lt'
- order = '' if is_next else '-'
+ op = "gt" if is_next else "lt"
+ order = "" if is_next else "-"
param = getattr(self, field.attname)
- q = Q((field.name, param), (f'pk__{op}', self.pk), _connector=Q.AND)
- q = Q(q, (f'{field.name}__{op}', param), _connector=Q.OR)
- qs = self.__class__._default_manager.using(self._state.db).filter(**kwargs).filter(q).order_by(
- '%s%s' % (order, field.name), '%spk' % order
+ q = Q((field.name, param), (f"pk__{op}", self.pk), _connector=Q.AND)
+ q = Q(q, (f"{field.name}__{op}", param), _connector=Q.OR)
+ qs = (
+ self.__class__._default_manager.using(self._state.db)
+ .filter(**kwargs)
+ .filter(q)
+ .order_by("%s%s" % (order, field.name), "%spk" % order)
)
try:
return qs[0]
except IndexError:
- raise self.DoesNotExist("%s matching query does not exist." % self.__class__._meta.object_name)
+ raise self.DoesNotExist(
+ "%s matching query does not exist." % self.__class__._meta.object_name
+ )
def _get_next_or_previous_in_order(self, is_next):
cachename = "__%s_order_cache" % is_next
if not hasattr(self, cachename):
- op = 'gt' if is_next else 'lt'
- order = '_order' if is_next else '-_order'
+ op = "gt" if is_next else "lt"
+ order = "_order" if is_next else "-_order"
order_field = self._meta.order_with_respect_to
filter_args = order_field.get_filter_kwargs_for_object(self)
- obj = self.__class__._default_manager.filter(**filter_args).filter(**{
- '_order__%s' % op: self.__class__._default_manager.values('_order').filter(**{
- self._meta.pk.name: self.pk
- })
- }).order_by(order)[:1].get()
+ obj = (
+ self.__class__._default_manager.filter(**filter_args)
+ .filter(
+ **{
+ "_order__%s"
+ % op: self.__class__._default_manager.values("_order").filter(
+ **{self._meta.pk.name: self.pk}
+ )
+ }
+ )
+ .order_by(order)[:1]
+ .get()
+ )
setattr(self, cachename, obj)
return getattr(self, cachename)
def prepare_database_save(self, field):
if self.pk is None:
- raise ValueError("Unsaved model instance %r cannot be used in an ORM query." % self)
+ raise ValueError(
+ "Unsaved model instance %r cannot be used in an ORM query." % self
+ )
return getattr(self, field.remote_field.get_related_field().attname)
def clean(self):
@@ -1085,7 +1222,9 @@ class Model(metaclass=ModelBase):
constraints = [(self.__class__, self._meta.total_unique_constraints)]
for parent_class in self._meta.get_parent_list():
if parent_class._meta.unique_together:
- unique_togethers.append((parent_class, parent_class._meta.unique_together))
+ unique_togethers.append(
+ (parent_class, parent_class._meta.unique_together)
+ )
if parent_class._meta.total_unique_constraints:
constraints.append(
(parent_class, parent_class._meta.total_unique_constraints)
@@ -1120,11 +1259,11 @@ class Model(metaclass=ModelBase):
if f.unique:
unique_checks.append((model_class, (name,)))
if f.unique_for_date and f.unique_for_date not in exclude:
- date_checks.append((model_class, 'date', name, f.unique_for_date))
+ date_checks.append((model_class, "date", name, f.unique_for_date))
if f.unique_for_year and f.unique_for_year not in exclude:
- date_checks.append((model_class, 'year', name, f.unique_for_year))
+ date_checks.append((model_class, "year", name, f.unique_for_year))
if f.unique_for_month and f.unique_for_month not in exclude:
- date_checks.append((model_class, 'month', name, f.unique_for_month))
+ date_checks.append((model_class, "month", name, f.unique_for_month))
return unique_checks, date_checks
def _perform_unique_checks(self, unique_checks):
@@ -1139,8 +1278,10 @@ class Model(metaclass=ModelBase):
f = self._meta.get_field(field_name)
lookup_value = getattr(self, f.attname)
# TODO: Handle multiple backends with different feature flags.
- if (lookup_value is None or
- (lookup_value == '' and connection.features.interprets_empty_strings_as_nulls)):
+ if lookup_value is None or (
+ lookup_value == ""
+ and connection.features.interprets_empty_strings_as_nulls
+ ):
# no value, skip the lookup
continue
if f.primary_key and not self._state.adding:
@@ -1168,7 +1309,9 @@ class Model(metaclass=ModelBase):
key = unique_check[0]
else:
key = NON_FIELD_ERRORS
- errors.setdefault(key, []).append(self.unique_error_message(model_class, unique_check))
+ errors.setdefault(key, []).append(
+ self.unique_error_message(model_class, unique_check)
+ )
return errors
@@ -1181,12 +1324,14 @@ class Model(metaclass=ModelBase):
date = getattr(self, unique_for)
if date is None:
continue
- if lookup_type == 'date':
- lookup_kwargs['%s__day' % unique_for] = date.day
- lookup_kwargs['%s__month' % unique_for] = date.month
- lookup_kwargs['%s__year' % unique_for] = date.year
+ if lookup_type == "date":
+ lookup_kwargs["%s__day" % unique_for] = date.day
+ lookup_kwargs["%s__month" % unique_for] = date.month
+ lookup_kwargs["%s__year" % unique_for] = date.year
else:
- lookup_kwargs['%s__%s' % (unique_for, lookup_type)] = getattr(date, lookup_type)
+ lookup_kwargs["%s__%s" % (unique_for, lookup_type)] = getattr(
+ date, lookup_type
+ )
lookup_kwargs[field] = getattr(self, field)
qs = model_class._default_manager.filter(**lookup_kwargs)
@@ -1205,46 +1350,48 @@ class Model(metaclass=ModelBase):
opts = self._meta
field = opts.get_field(field_name)
return ValidationError(
- message=field.error_messages['unique_for_date'],
- code='unique_for_date',
+ message=field.error_messages["unique_for_date"],
+ code="unique_for_date",
params={
- 'model': self,
- 'model_name': capfirst(opts.verbose_name),
- 'lookup_type': lookup_type,
- 'field': field_name,
- 'field_label': capfirst(field.verbose_name),
- 'date_field': unique_for,
- 'date_field_label': capfirst(opts.get_field(unique_for).verbose_name),
- }
+ "model": self,
+ "model_name": capfirst(opts.verbose_name),
+ "lookup_type": lookup_type,
+ "field": field_name,
+ "field_label": capfirst(field.verbose_name),
+ "date_field": unique_for,
+ "date_field_label": capfirst(opts.get_field(unique_for).verbose_name),
+ },
)
def unique_error_message(self, model_class, unique_check):
opts = model_class._meta
params = {
- 'model': self,
- 'model_class': model_class,
- 'model_name': capfirst(opts.verbose_name),
- 'unique_check': unique_check,
+ "model": self,
+ "model_class": model_class,
+ "model_name": capfirst(opts.verbose_name),
+ "unique_check": unique_check,
}
# A unique field
if len(unique_check) == 1:
field = opts.get_field(unique_check[0])
- params['field_label'] = capfirst(field.verbose_name)
+ params["field_label"] = capfirst(field.verbose_name)
return ValidationError(
- message=field.error_messages['unique'],
- code='unique',
+ message=field.error_messages["unique"],
+ code="unique",
params=params,
)
# unique_together
else:
- field_labels = [capfirst(opts.get_field(f).verbose_name) for f in unique_check]
- params['field_labels'] = get_text_list(field_labels, _('and'))
+ field_labels = [
+ capfirst(opts.get_field(f).verbose_name) for f in unique_check
+ ]
+ params["field_labels"] = get_text_list(field_labels, _("and"))
return ValidationError(
message=_("%(model_name)s with this %(field_labels)s already exists."),
- code='unique_together',
+ code="unique_together",
params=params,
)
@@ -1311,9 +1458,13 @@ class Model(metaclass=ModelBase):
@classmethod
def check(cls, **kwargs):
- errors = [*cls._check_swappable(), *cls._check_model(), *cls._check_managers(**kwargs)]
+ errors = [
+ *cls._check_swappable(),
+ *cls._check_model(),
+ *cls._check_managers(**kwargs),
+ ]
if not cls._meta.swapped:
- databases = kwargs.get('databases') or []
+ databases = kwargs.get("databases") or []
errors += [
*cls._check_fields(**kwargs),
*cls._check_m2m_through_same_relationship(),
@@ -1345,16 +1496,17 @@ class Model(metaclass=ModelBase):
@classmethod
def _check_default_pk(cls):
if (
- not cls._meta.abstract and
- cls._meta.pk.auto_created and
+ not cls._meta.abstract
+ and cls._meta.pk.auto_created
+ and
# Inherited PKs are checked in parents models.
not (
- isinstance(cls._meta.pk, OneToOneField) and
- cls._meta.pk.remote_field.parent_link
- ) and
- not settings.is_overridden('DEFAULT_AUTO_FIELD') and
- cls._meta.app_config and
- not cls._meta.app_config._is_default_auto_field_overridden
+ isinstance(cls._meta.pk, OneToOneField)
+ and cls._meta.pk.remote_field.parent_link
+ )
+ and not settings.is_overridden("DEFAULT_AUTO_FIELD")
+ and cls._meta.app_config
+ and not cls._meta.app_config._is_default_auto_field_overridden
):
return [
checks.Warning(
@@ -1368,7 +1520,7 @@ class Model(metaclass=ModelBase):
f"of AutoField, e.g. 'django.db.models.BigAutoField'."
),
obj=cls,
- id='models.W042',
+ id="models.W042",
),
]
return []
@@ -1383,19 +1535,19 @@ class Model(metaclass=ModelBase):
except ValueError:
errors.append(
checks.Error(
- "'%s' is not of the form 'app_label.app_name'." % cls._meta.swappable,
- id='models.E001',
+ "'%s' is not of the form 'app_label.app_name'."
+ % cls._meta.swappable,
+ id="models.E001",
)
)
except LookupError:
- app_label, model_name = cls._meta.swapped.split('.')
+ app_label, model_name = cls._meta.swapped.split(".")
errors.append(
checks.Error(
"'%s' references '%s.%s', which has not been "
- "installed, or is abstract." % (
- cls._meta.swappable, app_label, model_name
- ),
- id='models.E002',
+ "installed, or is abstract."
+ % (cls._meta.swappable, app_label, model_name),
+ id="models.E002",
)
)
return errors
@@ -1408,7 +1560,7 @@ class Model(metaclass=ModelBase):
errors.append(
checks.Error(
"Proxy model '%s' contains model fields." % cls.__name__,
- id='models.E017',
+ id="models.E017",
)
)
return errors
@@ -1433,8 +1585,7 @@ class Model(metaclass=ModelBase):
@classmethod
def _check_m2m_through_same_relationship(cls):
- """ Check if no relationship model is used by more than one m2m field.
- """
+ """Check if no relationship model is used by more than one m2m field."""
errors = []
seen_intermediary_signatures = []
@@ -1448,15 +1599,20 @@ class Model(metaclass=ModelBase):
fields = (f for f in fields if isinstance(f.remote_field.through, ModelBase))
for f in fields:
- signature = (f.remote_field.model, cls, f.remote_field.through, f.remote_field.through_fields)
+ signature = (
+ f.remote_field.model,
+ cls,
+ f.remote_field.through,
+ f.remote_field.through_fields,
+ )
if signature in seen_intermediary_signatures:
errors.append(
checks.Error(
"The model has two identical many-to-many relations "
- "through the intermediate model '%s'." %
- f.remote_field.through._meta.label,
+ "through the intermediate model '%s'."
+ % f.remote_field.through._meta.label,
obj=cls,
- id='models.E003',
+ id="models.E003",
)
)
else:
@@ -1466,15 +1622,17 @@ class Model(metaclass=ModelBase):
@classmethod
def _check_id_field(cls):
"""Check if `id` field is a primary key."""
- fields = [f for f in cls._meta.local_fields if f.name == 'id' and f != cls._meta.pk]
+ fields = [
+ f for f in cls._meta.local_fields if f.name == "id" and f != cls._meta.pk
+ ]
# fields is empty or consists of the invalid "id" field
- if fields and not fields[0].primary_key and cls._meta.pk.name == 'id':
+ if fields and not fields[0].primary_key and cls._meta.pk.name == "id":
return [
checks.Error(
"'id' can only be used as a field name if the field also "
"sets 'primary_key=True'.",
obj=cls,
- id='models.E004',
+ id="models.E004",
)
]
else:
@@ -1495,12 +1653,10 @@ class Model(metaclass=ModelBase):
checks.Error(
"The field '%s' from parent model "
"'%s' clashes with the field '%s' "
- "from parent model '%s'." % (
- clash.name, clash.model._meta,
- f.name, f.model._meta
- ),
+ "from parent model '%s'."
+ % (clash.name, clash.model._meta, f.name, f.model._meta),
obj=cls,
- id='models.E005',
+ id="models.E005",
)
)
used_fields[f.name] = f
@@ -1520,16 +1676,16 @@ class Model(metaclass=ModelBase):
# field "id" and automatically added unique field "id", both
# defined at the same model. This special case is considered in
# _check_id_field and here we ignore it.
- id_conflict = f.name == "id" and clash and clash.name == "id" and clash.model == cls
+ id_conflict = (
+ f.name == "id" and clash and clash.name == "id" and clash.model == cls
+ )
if clash and not id_conflict:
errors.append(
checks.Error(
"The field '%s' clashes with the field '%s' "
- "from model '%s'." % (
- f.name, clash.name, clash.model._meta
- ),
+ "from model '%s'." % (f.name, clash.name, clash.model._meta),
obj=f,
- id='models.E006',
+ id="models.E006",
)
)
used_fields[f.name] = f
@@ -1554,7 +1710,7 @@ class Model(metaclass=ModelBase):
"another field." % (f.name, column_name),
hint="Specify a 'db_column' for the field.",
obj=cls,
- id='models.E007'
+ id="models.E007",
)
)
else:
@@ -1566,13 +1722,13 @@ class Model(metaclass=ModelBase):
def _check_model_name_db_lookup_clashes(cls):
errors = []
model_name = cls.__name__
- if model_name.startswith('_') or model_name.endswith('_'):
+ if model_name.startswith("_") or model_name.endswith("_"):
errors.append(
checks.Error(
"The model name '%s' cannot start or end with an underscore "
"as it collides with the query lookup syntax." % model_name,
obj=cls,
- id='models.E023'
+ id="models.E023",
)
)
elif LOOKUP_SEP in model_name:
@@ -1581,7 +1737,7 @@ class Model(metaclass=ModelBase):
"The model name '%s' cannot contain double underscores as "
"it collides with the query lookup syntax." % model_name,
obj=cls,
- id='models.E024'
+ id="models.E024",
)
)
return errors
@@ -1591,7 +1747,8 @@ class Model(metaclass=ModelBase):
errors = []
property_names = cls._meta._property_names
related_field_accessors = (
- f.get_attname() for f in cls._meta._get_fields(reverse=False)
+ f.get_attname()
+ for f in cls._meta._get_fields(reverse=False)
if f.is_relation and f.related_model is not None
)
for accessor in related_field_accessors:
@@ -1601,7 +1758,7 @@ class Model(metaclass=ModelBase):
"The property '%s' clashes with a related field "
"accessor." % accessor,
obj=cls,
- id='models.E025',
+ id="models.E025",
)
)
return errors
@@ -1615,7 +1772,7 @@ class Model(metaclass=ModelBase):
"The model cannot have more than one field with "
"'primary_key=True'.",
obj=cls,
- id='models.E026',
+ id="models.E026",
)
)
return errors
@@ -1628,16 +1785,18 @@ class Model(metaclass=ModelBase):
checks.Error(
"'index_together' must be a list or tuple.",
obj=cls,
- id='models.E008',
+ id="models.E008",
)
]
- elif any(not isinstance(fields, (tuple, list)) for fields in cls._meta.index_together):
+ elif any(
+ not isinstance(fields, (tuple, list)) for fields in cls._meta.index_together
+ ):
return [
checks.Error(
"All 'index_together' elements must be lists or tuples.",
obj=cls,
- id='models.E009',
+ id="models.E009",
)
]
@@ -1655,16 +1814,19 @@ class Model(metaclass=ModelBase):
checks.Error(
"'unique_together' must be a list or tuple.",
obj=cls,
- id='models.E010',
+ id="models.E010",
)
]
- elif any(not isinstance(fields, (tuple, list)) for fields in cls._meta.unique_together):
+ elif any(
+ not isinstance(fields, (tuple, list))
+ for fields in cls._meta.unique_together
+ ):
return [
checks.Error(
"All 'unique_together' elements must be lists or tuples.",
obj=cls,
- id='models.E011',
+ id="models.E011",
)
]
@@ -1682,13 +1844,13 @@ class Model(metaclass=ModelBase):
for index in cls._meta.indexes:
# Index name can't start with an underscore or a number, restricted
# for cross-database compatibility with Oracle.
- if index.name[0] == '_' or index.name[0].isdigit():
+ if index.name[0] == "_" or index.name[0].isdigit():
errors.append(
checks.Error(
"The index name '%s' cannot start with an underscore "
"or a number." % index.name,
obj=cls,
- id='models.E033',
+ id="models.E033",
),
)
if len(index.name) > index.max_name_length:
@@ -1697,7 +1859,7 @@ class Model(metaclass=ModelBase):
"The index name '%s' cannot be longer than %d "
"characters." % (index.name, index.max_name_length),
obj=cls,
- id='models.E034',
+ id="models.E034",
),
)
if index.contains_expressions:
@@ -1710,57 +1872,59 @@ class Model(metaclass=ModelBase):
continue
connection = connections[db]
if not (
- connection.features.supports_partial_indexes or
- 'supports_partial_indexes' in cls._meta.required_db_features
+ connection.features.supports_partial_indexes
+ or "supports_partial_indexes" in cls._meta.required_db_features
) and any(index.condition is not None for index in cls._meta.indexes):
errors.append(
checks.Warning(
- '%s does not support indexes with conditions.'
+ "%s does not support indexes with conditions."
% connection.display_name,
hint=(
"Conditions will be ignored. Silence this warning "
"if you don't care about it."
),
obj=cls,
- id='models.W037',
+ id="models.W037",
)
)
if not (
- connection.features.supports_covering_indexes or
- 'supports_covering_indexes' in cls._meta.required_db_features
+ connection.features.supports_covering_indexes
+ or "supports_covering_indexes" in cls._meta.required_db_features
) and any(index.include for index in cls._meta.indexes):
errors.append(
checks.Warning(
- '%s does not support indexes with non-key columns.'
+ "%s does not support indexes with non-key columns."
% connection.display_name,
hint=(
"Non-key columns will be ignored. Silence this "
"warning if you don't care about it."
),
obj=cls,
- id='models.W040',
+ id="models.W040",
)
)
if not (
- connection.features.supports_expression_indexes or
- 'supports_expression_indexes' in cls._meta.required_db_features
+ connection.features.supports_expression_indexes
+ or "supports_expression_indexes" in cls._meta.required_db_features
) and any(index.contains_expressions for index in cls._meta.indexes):
errors.append(
checks.Warning(
- '%s does not support indexes on expressions.'
+ "%s does not support indexes on expressions."
% connection.display_name,
hint=(
"An index won't be created. Silence this warning "
"if you don't care about it."
),
obj=cls,
- id='models.W043',
+ id="models.W043",
)
)
- fields = [field for index in cls._meta.indexes for field, _ in index.fields_orders]
+ fields = [
+ field for index in cls._meta.indexes for field, _ in index.fields_orders
+ ]
fields += [include for index in cls._meta.indexes for include in index.include]
fields += references
- errors.extend(cls._check_local_fields(fields, 'indexes'))
+ errors.extend(cls._check_local_fields(fields, "indexes"))
return errors
@classmethod
@@ -1772,7 +1936,7 @@ class Model(metaclass=ModelBase):
forward_fields_map = {}
for field in cls._meta._get_fields(reverse=False):
forward_fields_map[field.name] = field
- if hasattr(field, 'attname'):
+ if hasattr(field, "attname"):
forward_fields_map[field.attname] = field
errors = []
@@ -1782,11 +1946,13 @@ class Model(metaclass=ModelBase):
except KeyError:
errors.append(
checks.Error(
- "'%s' refers to the nonexistent field '%s'." % (
- option, field_name,
+ "'%s' refers to the nonexistent field '%s'."
+ % (
+ option,
+ field_name,
),
obj=cls,
- id='models.E012',
+ id="models.E012",
)
)
else:
@@ -1794,11 +1960,14 @@ class Model(metaclass=ModelBase):
errors.append(
checks.Error(
"'%s' refers to a ManyToManyField '%s', but "
- "ManyToManyFields are not permitted in '%s'." % (
- option, field_name, option,
+ "ManyToManyFields are not permitted in '%s'."
+ % (
+ option,
+ field_name,
+ option,
),
obj=cls,
- id='models.E013',
+ id="models.E013",
)
)
elif field not in cls._meta.local_fields:
@@ -1808,7 +1977,7 @@ class Model(metaclass=ModelBase):
% (option, field_name, cls._meta.object_name),
hint="This issue may be caused by multi-table inheritance.",
obj=cls,
- id='models.E016',
+ id="models.E016",
)
)
return errors
@@ -1824,7 +1993,7 @@ class Model(metaclass=ModelBase):
checks.Error(
"'ordering' and 'order_with_respect_to' cannot be used together.",
obj=cls,
- id='models.E021',
+ id="models.E021",
),
]
@@ -1836,7 +2005,7 @@ class Model(metaclass=ModelBase):
checks.Error(
"'ordering' must be a tuple or list (even if you want to order by only one field).",
obj=cls,
- id='models.E014',
+ id="models.E014",
)
]
@@ -1844,10 +2013,10 @@ class Model(metaclass=ModelBase):
fields = cls._meta.ordering
# Skip expressions and '?' fields.
- fields = (f for f in fields if isinstance(f, str) and f != '?')
+ fields = (f for f in fields if isinstance(f, str) and f != "?")
# Convert "-field" to "field".
- fields = ((f[1:] if f.startswith('-') else f) for f in fields)
+ fields = ((f[1:] if f.startswith("-") else f) for f in fields)
# Separate related fields and non-related fields.
_fields = []
@@ -1866,7 +2035,7 @@ class Model(metaclass=ModelBase):
for part in field.split(LOOKUP_SEP):
try:
# pk is an alias that won't be found by opts.get_field.
- if part == 'pk':
+ if part == "pk":
fld = _cls._meta.pk
else:
fld = _cls._meta.get_field(part)
@@ -1883,13 +2052,13 @@ class Model(metaclass=ModelBase):
"'ordering' refers to the nonexistent field, "
"related field, or lookup '%s'." % field,
obj=cls,
- id='models.E015',
+ id="models.E015",
)
)
# Skip ordering on pk. This is always a valid order_by field
# but is an alias and therefore won't be found by opts.get_field.
- fields = {f for f in fields if f != 'pk'}
+ fields = {f for f in fields if f != "pk"}
# Check for invalid or nonexistent fields in ordering.
invalid_fields = []
@@ -1897,10 +2066,14 @@ class Model(metaclass=ModelBase):
# Any field name that is not present in field_names does not exist.
# Also, ordering by m2m fields is not allowed.
opts = cls._meta
- valid_fields = set(chain.from_iterable(
- (f.name, f.attname) if not (f.auto_created and not f.concrete) else (f.field.related_query_name(),)
- for f in chain(opts.fields, opts.related_objects)
- ))
+ valid_fields = set(
+ chain.from_iterable(
+ (f.name, f.attname)
+ if not (f.auto_created and not f.concrete)
+ else (f.field.related_query_name(),)
+ for f in chain(opts.fields, opts.related_objects)
+ )
+ )
invalid_fields.extend(fields - valid_fields)
@@ -1910,7 +2083,7 @@ class Model(metaclass=ModelBase):
"'ordering' refers to the nonexistent field, related "
"field, or lookup '%s'." % invalid_field,
obj=cls,
- id='models.E015',
+ id="models.E015",
)
)
return errors
@@ -1952,7 +2125,11 @@ class Model(metaclass=ModelBase):
# Check if auto-generated name for the field is too long
# for the database.
- if f.db_column is None and column_name is not None and len(column_name) > allowed_len:
+ if (
+ f.db_column is None
+ and column_name is not None
+ and len(column_name) > allowed_len
+ ):
errors.append(
checks.Error(
'Autogenerated column name too long for field "%s". '
@@ -1960,7 +2137,7 @@ class Model(metaclass=ModelBase):
% (column_name, allowed_len, db_alias),
hint="Set the column name manually using 'db_column'.",
obj=cls,
- id='models.E018',
+ id="models.E018",
)
)
@@ -1973,10 +2150,14 @@ class Model(metaclass=ModelBase):
# for the database.
for m2m in f.remote_field.through._meta.local_fields:
_, rel_name = m2m.get_attname_column()
- if m2m.db_column is None and rel_name is not None and len(rel_name) > allowed_len:
+ if (
+ m2m.db_column is None
+ and rel_name is not None
+ and len(rel_name) > allowed_len
+ ):
errors.append(
checks.Error(
- 'Autogenerated column name too long for M2M field '
+ "Autogenerated column name too long for M2M field "
'"%s". Maximum length is "%s" for database "%s".'
% (rel_name, allowed_len, db_alias),
hint=(
@@ -1984,7 +2165,7 @@ class Model(metaclass=ModelBase):
"M2M and then set column_name using 'db_column'."
),
obj=cls,
- id='models.E019',
+ id="models.E019",
)
)
@@ -2002,7 +2183,7 @@ class Model(metaclass=ModelBase):
yield from cls._get_expr_references(child)
elif isinstance(expr, F):
yield tuple(expr.name.split(LOOKUP_SEP))
- elif hasattr(expr, 'get_source_expressions'):
+ elif hasattr(expr, "get_source_expressions"):
for src_expr in expr.get_source_expressions():
yield from cls._get_expr_references(src_expr)
@@ -2014,132 +2195,145 @@ class Model(metaclass=ModelBase):
continue
connection = connections[db]
if not (
- connection.features.supports_table_check_constraints or
- 'supports_table_check_constraints' in cls._meta.required_db_features
+ connection.features.supports_table_check_constraints
+ or "supports_table_check_constraints" in cls._meta.required_db_features
) and any(
isinstance(constraint, CheckConstraint)
for constraint in cls._meta.constraints
):
errors.append(
checks.Warning(
- '%s does not support check constraints.' % connection.display_name,
+ "%s does not support check constraints."
+ % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
- id='models.W027',
+ id="models.W027",
)
)
if not (
- connection.features.supports_partial_indexes or
- 'supports_partial_indexes' in cls._meta.required_db_features
+ connection.features.supports_partial_indexes
+ or "supports_partial_indexes" in cls._meta.required_db_features
) and any(
- isinstance(constraint, UniqueConstraint) and constraint.condition is not None
+ isinstance(constraint, UniqueConstraint)
+ and constraint.condition is not None
for constraint in cls._meta.constraints
):
errors.append(
checks.Warning(
- '%s does not support unique constraints with '
- 'conditions.' % connection.display_name,
+ "%s does not support unique constraints with "
+ "conditions." % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
- id='models.W036',
+ id="models.W036",
)
)
if not (
- connection.features.supports_deferrable_unique_constraints or
- 'supports_deferrable_unique_constraints' in cls._meta.required_db_features
+ connection.features.supports_deferrable_unique_constraints
+ or "supports_deferrable_unique_constraints"
+ in cls._meta.required_db_features
) and any(
- isinstance(constraint, UniqueConstraint) and constraint.deferrable is not None
+ isinstance(constraint, UniqueConstraint)
+ and constraint.deferrable is not None
for constraint in cls._meta.constraints
):
errors.append(
checks.Warning(
- '%s does not support deferrable unique constraints.'
+ "%s does not support deferrable unique constraints."
% connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
- id='models.W038',
+ id="models.W038",
)
)
if not (
- connection.features.supports_covering_indexes or
- 'supports_covering_indexes' in cls._meta.required_db_features
+ connection.features.supports_covering_indexes
+ or "supports_covering_indexes" in cls._meta.required_db_features
) and any(
isinstance(constraint, UniqueConstraint) and constraint.include
for constraint in cls._meta.constraints
):
errors.append(
checks.Warning(
- '%s does not support unique constraints with non-key '
- 'columns.' % connection.display_name,
+ "%s does not support unique constraints with non-key "
+ "columns." % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
- id='models.W039',
+ id="models.W039",
)
)
if not (
- connection.features.supports_expression_indexes or
- 'supports_expression_indexes' in cls._meta.required_db_features
+ connection.features.supports_expression_indexes
+ or "supports_expression_indexes" in cls._meta.required_db_features
) and any(
- isinstance(constraint, UniqueConstraint) and constraint.contains_expressions
+ isinstance(constraint, UniqueConstraint)
+ and constraint.contains_expressions
for constraint in cls._meta.constraints
):
errors.append(
checks.Warning(
- '%s does not support unique constraints on '
- 'expressions.' % connection.display_name,
+ "%s does not support unique constraints on "
+ "expressions." % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
- id='models.W044',
+ id="models.W044",
)
)
- fields = set(chain.from_iterable(
- (*constraint.fields, *constraint.include)
- for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint)
- ))
+ fields = set(
+ chain.from_iterable(
+ (*constraint.fields, *constraint.include)
+ for constraint in cls._meta.constraints
+ if isinstance(constraint, UniqueConstraint)
+ )
+ )
references = set()
for constraint in cls._meta.constraints:
if isinstance(constraint, UniqueConstraint):
if (
- connection.features.supports_partial_indexes or
- 'supports_partial_indexes' not in cls._meta.required_db_features
+ connection.features.supports_partial_indexes
+ or "supports_partial_indexes"
+ not in cls._meta.required_db_features
) and isinstance(constraint.condition, Q):
- references.update(cls._get_expr_references(constraint.condition))
+ references.update(
+ cls._get_expr_references(constraint.condition)
+ )
if (
- connection.features.supports_expression_indexes or
- 'supports_expression_indexes' not in cls._meta.required_db_features
+ connection.features.supports_expression_indexes
+ or "supports_expression_indexes"
+ not in cls._meta.required_db_features
) and constraint.contains_expressions:
for expression in constraint.expressions:
references.update(cls._get_expr_references(expression))
elif isinstance(constraint, CheckConstraint):
if (
- connection.features.supports_table_check_constraints or
- 'supports_table_check_constraints' not in cls._meta.required_db_features
+ connection.features.supports_table_check_constraints
+ or "supports_table_check_constraints"
+ not in cls._meta.required_db_features
) and isinstance(constraint.check, Q):
references.update(cls._get_expr_references(constraint.check))
for field_name, *lookups in references:
# pk is an alias that won't be found by opts.get_field.
- if field_name != 'pk':
+ if field_name != "pk":
fields.add(field_name)
if not lookups:
# If it has no lookups it cannot result in a JOIN.
continue
try:
- if field_name == 'pk':
+ if field_name == "pk":
field = cls._meta.pk
else:
field = cls._meta.get_field(field_name)
@@ -2150,20 +2344,20 @@ class Model(metaclass=ModelBase):
# JOIN must happen at the first lookup.
first_lookup = lookups[0]
if (
- hasattr(field, 'get_transform') and
- hasattr(field, 'get_lookup') and
- field.get_transform(first_lookup) is None and
- field.get_lookup(first_lookup) is None
+ hasattr(field, "get_transform")
+ and hasattr(field, "get_lookup")
+ and field.get_transform(first_lookup) is None
+ and field.get_lookup(first_lookup) is None
):
errors.append(
checks.Error(
"'constraints' refers to the joined field '%s'."
% LOOKUP_SEP.join([field_name] + lookups),
obj=cls,
- id='models.E041',
+ id="models.E041",
)
)
- errors.extend(cls._check_local_fields(fields, 'constraints'))
+ errors.extend(cls._check_local_fields(fields, "constraints"))
return errors
@@ -2173,14 +2367,16 @@ class Model(metaclass=ModelBase):
# ORDERING METHODS #########################
+
def method_set_order(self, ordered_obj, id_list, using=None):
if using is None:
using = DEFAULT_DB_ALIAS
order_wrt = ordered_obj._meta.order_with_respect_to
filter_args = order_wrt.get_forward_related_filter(self)
- ordered_obj.objects.db_manager(using).filter(**filter_args).bulk_update([
- ordered_obj(pk=pk, _order=order) for order, pk in enumerate(id_list)
- ], ['_order'])
+ ordered_obj.objects.db_manager(using).filter(**filter_args).bulk_update(
+ [ordered_obj(pk=pk, _order=order) for order, pk in enumerate(id_list)],
+ ["_order"],
+ )
def method_get_order(self, ordered_obj):
@@ -2193,15 +2389,16 @@ def method_get_order(self, ordered_obj):
def make_foreign_order_accessors(model, related_model):
setattr(
related_model,
- 'get_%s_order' % model.__name__.lower(),
- partialmethod(method_get_order, model)
+ "get_%s_order" % model.__name__.lower(),
+ partialmethod(method_get_order, model),
)
setattr(
related_model,
- 'set_%s_order' % model.__name__.lower(),
- partialmethod(method_set_order, model)
+ "set_%s_order" % model.__name__.lower(),
+ partialmethod(method_set_order, model),
)
+
########
# MISC #
########
diff --git a/django/db/models/constants.py b/django/db/models/constants.py
index 95addd2ab0..a0c99c95fc 100644
--- a/django/db/models/constants.py
+++ b/django/db/models/constants.py
@@ -4,9 +4,9 @@ Constants used across the ORM in general.
from enum import Enum
# Separator used to split filter strings apart.
-LOOKUP_SEP = '__'
+LOOKUP_SEP = "__"
class OnConflict(Enum):
- IGNORE = 'ignore'
- UPDATE = 'update'
+ IGNORE = "ignore"
+ UPDATE = "update"
diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py
index 5abedaf3d1..721e43ae58 100644
--- a/django/db/models/constraints.py
+++ b/django/db/models/constraints.py
@@ -5,7 +5,7 @@ from django.db.models.indexes import IndexExpression
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
-__all__ = ['CheckConstraint', 'Deferrable', 'UniqueConstraint']
+__all__ = ["CheckConstraint", "Deferrable", "UniqueConstraint"]
class BaseConstraint:
@@ -17,18 +17,18 @@ class BaseConstraint:
return False
def constraint_sql(self, model, schema_editor):
- raise NotImplementedError('This method must be implemented by a subclass.')
+ raise NotImplementedError("This method must be implemented by a subclass.")
def create_sql(self, model, schema_editor):
- raise NotImplementedError('This method must be implemented by a subclass.')
+ raise NotImplementedError("This method must be implemented by a subclass.")
def remove_sql(self, model, schema_editor):
- raise NotImplementedError('This method must be implemented by a subclass.')
+ raise NotImplementedError("This method must be implemented by a subclass.")
def deconstruct(self):
- path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
- path = path.replace('django.db.models.constraints', 'django.db.models')
- return (path, (), {'name': self.name})
+ path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
+ path = path.replace("django.db.models.constraints", "django.db.models")
+ return (path, (), {"name": self.name})
def clone(self):
_, args, kwargs = self.deconstruct()
@@ -38,9 +38,9 @@ class BaseConstraint:
class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name):
self.check = check
- if not getattr(check, 'conditional', False):
+ if not getattr(check, "conditional", False):
raise TypeError(
- 'CheckConstraint.check must be a Q instance or boolean expression.'
+ "CheckConstraint.check must be a Q instance or boolean expression."
)
super().__init__(name)
@@ -63,7 +63,7 @@ class CheckConstraint(BaseConstraint):
return schema_editor._delete_check_sql(model, self.name)
def __repr__(self):
- return '<%s: check=%s name=%s>' % (
+ return "<%s: check=%s name=%s>" % (
self.__class__.__qualname__,
self.check,
repr(self.name),
@@ -76,17 +76,17 @@ class CheckConstraint(BaseConstraint):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
- kwargs['check'] = self.check
+ kwargs["check"] = self.check
return path, args, kwargs
class Deferrable(Enum):
- DEFERRED = 'deferred'
- IMMEDIATE = 'immediate'
+ DEFERRED = "deferred"
+ IMMEDIATE = "immediate"
# A similar format was proposed for Python 3.10.
def __repr__(self):
- return f'{self.__class__.__qualname__}.{self._name_}'
+ return f"{self.__class__.__qualname__}.{self._name_}"
class UniqueConstraint(BaseConstraint):
@@ -101,51 +101,43 @@ class UniqueConstraint(BaseConstraint):
opclasses=(),
):
if not name:
- raise ValueError('A unique constraint must be named.')
+ raise ValueError("A unique constraint must be named.")
if not expressions and not fields:
raise ValueError(
- 'At least one field or expression is required to define a '
- 'unique constraint.'
+ "At least one field or expression is required to define a "
+ "unique constraint."
)
if expressions and fields:
raise ValueError(
- 'UniqueConstraint.fields and expressions are mutually exclusive.'
+ "UniqueConstraint.fields and expressions are mutually exclusive."
)
if not isinstance(condition, (type(None), Q)):
- raise ValueError('UniqueConstraint.condition must be a Q instance.')
+ raise ValueError("UniqueConstraint.condition must be a Q instance.")
if condition and deferrable:
- raise ValueError(
- 'UniqueConstraint with conditions cannot be deferred.'
- )
+ raise ValueError("UniqueConstraint with conditions cannot be deferred.")
if include and deferrable:
- raise ValueError(
- 'UniqueConstraint with include fields cannot be deferred.'
- )
+ raise ValueError("UniqueConstraint with include fields cannot be deferred.")
if opclasses and deferrable:
- raise ValueError(
- 'UniqueConstraint with opclasses cannot be deferred.'
- )
+ raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
if expressions and deferrable:
- raise ValueError(
- 'UniqueConstraint with expressions cannot be deferred.'
- )
+ raise ValueError("UniqueConstraint with expressions cannot be deferred.")
if expressions and opclasses:
raise ValueError(
- 'UniqueConstraint.opclasses cannot be used with expressions. '
- 'Use django.contrib.postgres.indexes.OpClass() instead.'
+ "UniqueConstraint.opclasses cannot be used with expressions. "
+ "Use django.contrib.postgres.indexes.OpClass() instead."
)
if not isinstance(deferrable, (type(None), Deferrable)):
raise ValueError(
- 'UniqueConstraint.deferrable must be a Deferrable instance.'
+ "UniqueConstraint.deferrable must be a Deferrable instance."
)
if not isinstance(include, (type(None), list, tuple)):
- raise ValueError('UniqueConstraint.include must be a list or tuple.')
+ raise ValueError("UniqueConstraint.include must be a list or tuple.")
if not isinstance(opclasses, (list, tuple)):
- raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
+ raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
if opclasses and len(fields) != len(opclasses):
raise ValueError(
- 'UniqueConstraint.fields and UniqueConstraint.opclasses must '
- 'have the same number of elements.'
+ "UniqueConstraint.fields and UniqueConstraint.opclasses must "
+ "have the same number of elements."
)
self.fields = tuple(fields)
self.condition = condition
@@ -185,70 +177,91 @@ class UniqueConstraint(BaseConstraint):
def constraint_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
- include = [model._meta.get_field(field_name).column for field_name in self.include]
+ include = [
+ model._meta.get_field(field_name).column for field_name in self.include
+ ]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._unique_sql(
- model, fields, self.name, condition=condition,
- deferrable=self.deferrable, include=include,
- opclasses=self.opclasses, expressions=expressions,
+ model,
+ fields,
+ self.name,
+ condition=condition,
+ deferrable=self.deferrable,
+ include=include,
+ opclasses=self.opclasses,
+ expressions=expressions,
)
def create_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
- include = [model._meta.get_field(field_name).column for field_name in self.include]
+ include = [
+ model._meta.get_field(field_name).column for field_name in self.include
+ ]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._create_unique_sql(
- model, fields, self.name, condition=condition,
- deferrable=self.deferrable, include=include,
- opclasses=self.opclasses, expressions=expressions,
+ model,
+ fields,
+ self.name,
+ condition=condition,
+ deferrable=self.deferrable,
+ include=include,
+ opclasses=self.opclasses,
+ expressions=expressions,
)
def remove_sql(self, model, schema_editor):
condition = self._get_condition_sql(model, schema_editor)
- include = [model._meta.get_field(field_name).column for field_name in self.include]
+ include = [
+ model._meta.get_field(field_name).column for field_name in self.include
+ ]
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._delete_unique_sql(
- model, self.name, condition=condition, deferrable=self.deferrable,
- include=include, opclasses=self.opclasses, expressions=expressions,
+ model,
+ self.name,
+ condition=condition,
+ deferrable=self.deferrable,
+ include=include,
+ opclasses=self.opclasses,
+ expressions=expressions,
)
def __repr__(self):
- return '<%s:%s%s%s%s%s%s%s>' % (
+ return "<%s:%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
- '' if not self.fields else ' fields=%s' % repr(self.fields),
- '' if not self.expressions else ' expressions=%s' % repr(self.expressions),
- ' name=%s' % repr(self.name),
- '' if self.condition is None else ' condition=%s' % self.condition,
- '' if self.deferrable is None else ' deferrable=%r' % self.deferrable,
- '' if not self.include else ' include=%s' % repr(self.include),
- '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
+ "" if not self.fields else " fields=%s" % repr(self.fields),
+ "" if not self.expressions else " expressions=%s" % repr(self.expressions),
+ " name=%s" % repr(self.name),
+ "" if self.condition is None else " condition=%s" % self.condition,
+ "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
+ "" if not self.include else " include=%s" % repr(self.include),
+ "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
)
def __eq__(self, other):
if isinstance(other, UniqueConstraint):
return (
- self.name == other.name and
- self.fields == other.fields and
- self.condition == other.condition and
- self.deferrable == other.deferrable and
- self.include == other.include and
- self.opclasses == other.opclasses and
- self.expressions == other.expressions
+ self.name == other.name
+ and self.fields == other.fields
+ and self.condition == other.condition
+ and self.deferrable == other.deferrable
+ and self.include == other.include
+ and self.opclasses == other.opclasses
+ and self.expressions == other.expressions
)
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fields:
- kwargs['fields'] = self.fields
+ kwargs["fields"] = self.fields
if self.condition:
- kwargs['condition'] = self.condition
+ kwargs["condition"] = self.condition
if self.deferrable:
- kwargs['deferrable'] = self.deferrable
+ kwargs["deferrable"] = self.deferrable
if self.include:
- kwargs['include'] = self.include
+ kwargs["include"] = self.include
if self.opclasses:
- kwargs['opclasses'] = self.opclasses
+ kwargs["opclasses"] = self.opclasses
return path, self.expressions, kwargs
diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py
index b99337a309..6912b49498 100644
--- a/django/db/models/deletion.py
+++ b/django/db/models/deletion.py
@@ -21,8 +21,11 @@ class RestrictedError(IntegrityError):
def CASCADE(collector, field, sub_objs, using):
collector.collect(
- sub_objs, source=field.remote_field.model, source_attr=field.name,
- nullable=field.null, fail_on_restricted=False,
+ sub_objs,
+ source=field.remote_field.model,
+ source_attr=field.name,
+ nullable=field.null,
+ fail_on_restricted=False,
)
if field.null and not connections[using].features.can_defer_constraint_checks:
collector.add_field_update(field, None, sub_objs)
@@ -31,10 +34,13 @@ def CASCADE(collector, field, sub_objs, using):
def PROTECT(collector, field, sub_objs, using):
raise ProtectedError(
"Cannot delete some instances of model '%s' because they are "
- "referenced through a protected foreign key: '%s.%s'" % (
- field.remote_field.model.__name__, sub_objs[0].__class__.__name__, field.name
+ "referenced through a protected foreign key: '%s.%s'"
+ % (
+ field.remote_field.model.__name__,
+ sub_objs[0].__class__.__name__,
+ field.name,
),
- sub_objs
+ sub_objs,
)
@@ -45,12 +51,16 @@ def RESTRICT(collector, field, sub_objs, using):
def SET(value):
if callable(value):
+
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value(), sub_objs)
+
else:
+
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value, sub_objs)
- set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {})
+
+ set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
return set_on_delete
@@ -70,7 +80,8 @@ def get_candidate_relations_to_delete(opts):
# The candidate relations are the ones that come from N-1 and 1-1 relations.
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
return (
- f for f in opts.get_fields(include_hidden=True)
+ f
+ for f in opts.get_fields(include_hidden=True)
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
)
@@ -124,7 +135,9 @@ class Collector:
def add_dependency(self, model, dependency, reverse_dependency=False):
if reverse_dependency:
model, dependency = dependency, model
- self.dependencies[model._meta.concrete_model].add(dependency._meta.concrete_model)
+ self.dependencies[model._meta.concrete_model].add(
+ dependency._meta.concrete_model
+ )
self.data.setdefault(dependency, self.data.default_factory())
def add_field_update(self, field, value, objs):
@@ -151,17 +164,21 @@ class Collector:
def clear_restricted_objects_from_queryset(self, model, qs):
if model in self.restricted_objects:
- objs = set(qs.filter(pk__in=[
- obj.pk
- for objs in self.restricted_objects[model].values() for obj in objs
- ]))
+ objs = set(
+ qs.filter(
+ pk__in=[
+ obj.pk
+ for objs in self.restricted_objects[model].values()
+ for obj in objs
+ ]
+ )
+ )
self.clear_restricted_objects_from_set(model, objs)
def _has_signal_listeners(self, model):
- return (
- signals.pre_delete.has_listeners(model) or
- signals.post_delete.has_listeners(model)
- )
+ return signals.pre_delete.has_listeners(
+ model
+ ) or signals.post_delete.has_listeners(model)
def can_fast_delete(self, objs, from_field=None):
"""
@@ -176,9 +193,9 @@ class Collector:
"""
if from_field and from_field.remote_field.on_delete is not CASCADE:
return False
- if hasattr(objs, '_meta'):
+ if hasattr(objs, "_meta"):
model = objs._meta.model
- elif hasattr(objs, 'model') and hasattr(objs, '_raw_delete'):
+ elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
model = objs.model
else:
return False
@@ -188,14 +205,22 @@ class Collector:
# parent when parent delete is cascading to child.
opts = model._meta
return (
- all(link == from_field for link in opts.concrete_model._meta.parents.values()) and
+ all(
+ link == from_field
+ for link in opts.concrete_model._meta.parents.values()
+ )
+ and
# Foreign keys pointing to this model.
all(
related.field.remote_field.on_delete is DO_NOTHING
for related in get_candidate_relations_to_delete(opts)
- ) and (
+ )
+ and (
# Something like generic foreign key.
- not any(hasattr(field, 'bulk_related_objects') for field in opts.private_fields)
+ not any(
+ hasattr(field, "bulk_related_objects")
+ for field in opts.private_fields
+ )
)
)
@@ -205,16 +230,27 @@ class Collector:
"""
field_names = [field.name for field in fields]
conn_batch_size = max(
- connections[self.using].ops.bulk_batch_size(field_names, objs), 1)
+ connections[self.using].ops.bulk_batch_size(field_names, objs), 1
+ )
if len(objs) > conn_batch_size:
- return [objs[i:i + conn_batch_size]
- for i in range(0, len(objs), conn_batch_size)]
+ return [
+ objs[i : i + conn_batch_size]
+ for i in range(0, len(objs), conn_batch_size)
+ ]
else:
return [objs]
- def collect(self, objs, source=None, nullable=False, collect_related=True,
- source_attr=None, reverse_dependency=False, keep_parents=False,
- fail_on_restricted=True):
+ def collect(
+ self,
+ objs,
+ source=None,
+ nullable=False,
+ collect_related=True,
+ source_attr=None,
+ reverse_dependency=False,
+ keep_parents=False,
+ fail_on_restricted=True,
+ ):
"""
Add 'objs' to the collection of objects to be deleted as well as all
parent instances. 'objs' must be a homogeneous iterable collection of
@@ -241,8 +277,9 @@ class Collector:
if self.can_fast_delete(objs):
self.fast_deletes.append(objs)
return
- new_objs = self.add(objs, source, nullable,
- reverse_dependency=reverse_dependency)
+ new_objs = self.add(
+ objs, source, nullable, reverse_dependency=reverse_dependency
+ )
if not new_objs:
return
@@ -255,11 +292,14 @@ class Collector:
for ptr in concrete_model._meta.parents.values():
if ptr:
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
- self.collect(parent_objs, source=model,
- source_attr=ptr.remote_field.related_name,
- collect_related=False,
- reverse_dependency=True,
- fail_on_restricted=False)
+ self.collect(
+ parent_objs,
+ source=model,
+ source_attr=ptr.remote_field.related_name,
+ collect_related=False,
+ reverse_dependency=True,
+ fail_on_restricted=False,
+ )
if not collect_related:
return
@@ -287,11 +327,18 @@ class Collector:
# relationships are select_related as interactions between both
# features are hard to get right. This should only happen in
# the rare cases where .related_objects is overridden anyway.
- if not (sub_objs.query.select_related or self._has_signal_listeners(related_model)):
- referenced_fields = set(chain.from_iterable(
- (rf.attname for rf in rel.field.foreign_related_fields)
- for rel in get_candidate_relations_to_delete(related_model._meta)
- ))
+ if not (
+ sub_objs.query.select_related
+ or self._has_signal_listeners(related_model)
+ ):
+ referenced_fields = set(
+ chain.from_iterable(
+ (rf.attname for rf in rel.field.foreign_related_fields)
+ for rel in get_candidate_relations_to_delete(
+ related_model._meta
+ )
+ )
+ )
sub_objs = sub_objs.only(*tuple(referenced_fields))
if sub_objs:
try:
@@ -301,10 +348,11 @@ class Collector:
protected_objects[key] += error.protected_objects
if protected_objects:
raise ProtectedError(
- 'Cannot delete some instances of model %r because they are '
- 'referenced through protected foreign keys: %s.' % (
+ "Cannot delete some instances of model %r because they are "
+ "referenced through protected foreign keys: %s."
+ % (
model.__name__,
- ', '.join(protected_objects),
+ ", ".join(protected_objects),
),
set(chain.from_iterable(protected_objects.values())),
)
@@ -314,10 +362,12 @@ class Collector:
sub_objs = self.related_objects(related_model, related_fields, batch)
self.fast_deletes.append(sub_objs)
for field in model._meta.private_fields:
- if hasattr(field, 'bulk_related_objects'):
+ if hasattr(field, "bulk_related_objects"):
# It's something like generic foreign key.
sub_objs = field.bulk_related_objects(new_objs, self.using)
- self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False)
+ self.collect(
+ sub_objs, source=model, nullable=True, fail_on_restricted=False
+ )
if fail_on_restricted:
# Raise an error if collected restricted objects (RESTRICT) aren't
@@ -335,11 +385,12 @@ class Collector:
restricted_objects[key] += objs
if restricted_objects:
raise RestrictedError(
- 'Cannot delete some instances of model %r because '
- 'they are referenced through restricted foreign keys: '
- '%s.' % (
+ "Cannot delete some instances of model %r because "
+ "they are referenced through restricted foreign keys: "
+ "%s."
+ % (
model.__name__,
- ', '.join(restricted_objects),
+ ", ".join(restricted_objects),
),
set(chain.from_iterable(restricted_objects.values())),
)
@@ -349,10 +400,7 @@ class Collector:
Get a QuerySet of the related model to objs via related fields.
"""
predicate = query_utils.Q(
- *(
- (f'{related_field.name}__in', objs)
- for related_field in related_fields
- ),
+ *((f"{related_field.name}__in", objs) for related_field in related_fields),
_connector=query_utils.Q.OR,
)
return related_model._base_manager.using(self.using).filter(predicate)
@@ -397,7 +445,9 @@ class Collector:
instance = list(instances)[0]
if self.can_fast_delete(instance):
with transaction.mark_for_rollback_on_error(self.using):
- count = sql.DeleteQuery(model).delete_batch([instance.pk], self.using)
+ count = sql.DeleteQuery(model).delete_batch(
+ [instance.pk], self.using
+ )
setattr(instance, model._meta.pk.attname, None)
return count, {model._meta.label: count}
@@ -406,7 +456,9 @@ class Collector:
for model, obj in self.instances_with_model():
if not model._meta.auto_created:
signals.pre_delete.send(
- sender=model, instance=obj, using=self.using,
+ sender=model,
+ instance=obj,
+ using=self.using,
origin=self.origin,
)
@@ -420,8 +472,9 @@ class Collector:
for model, instances_for_fieldvalues in self.field_updates.items():
for (field, value), instances in instances_for_fieldvalues.items():
query = sql.UpdateQuery(model)
- query.update_batch([obj.pk for obj in instances],
- {field.name: value}, self.using)
+ query.update_batch(
+ [obj.pk for obj in instances], {field.name: value}, self.using
+ )
# reverse instance collections
for instances in self.data.values():
@@ -438,7 +491,9 @@ class Collector:
if not model._meta.auto_created:
for obj in instances:
signals.post_delete.send(
- sender=model, instance=obj, using=self.using,
+ sender=model,
+ instance=obj,
+ using=self.using,
origin=self.origin,
)
diff --git a/django/db/models/enums.py b/django/db/models/enums.py
index 8474c87c94..9a7a2bb70f 100644
--- a/django/db/models/enums.py
+++ b/django/db/models/enums.py
@@ -3,7 +3,7 @@ from types import DynamicClassAttribute
from django.utils.functional import Promise
-__all__ = ['Choices', 'IntegerChoices', 'TextChoices']
+__all__ = ["Choices", "IntegerChoices", "TextChoices"]
class ChoicesMeta(enum.EnumMeta):
@@ -14,14 +14,14 @@ class ChoicesMeta(enum.EnumMeta):
for key in classdict._member_names:
value = classdict[key]
if (
- isinstance(value, (list, tuple)) and
- len(value) > 1 and
- isinstance(value[-1], (Promise, str))
+ isinstance(value, (list, tuple))
+ and len(value) > 1
+ and isinstance(value[-1], (Promise, str))
):
*value, label = value
value = tuple(value)
else:
- label = key.replace('_', ' ').title()
+ label = key.replace("_", " ").title()
labels.append(label)
# Use dict.__setitem__() to suppress defenses against double
# assignment in enum's classdict.
@@ -39,12 +39,12 @@ class ChoicesMeta(enum.EnumMeta):
@property
def names(cls):
- empty = ['__empty__'] if hasattr(cls, '__empty__') else []
+ empty = ["__empty__"] if hasattr(cls, "__empty__") else []
return empty + [member.name for member in cls]
@property
def choices(cls):
- empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else []
+ empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
return empty + [(member.value, member.label) for member in cls]
@property
@@ -76,11 +76,12 @@ class Choices(enum.Enum, metaclass=ChoicesMeta):
# A similar format was proposed for Python 3.10.
def __repr__(self):
- return f'{self.__class__.__qualname__}.{self._name_}'
+ return f"{self.__class__.__qualname__}.{self._name_}"
class IntegerChoices(int, Choices):
"""Class for creating enumerated integer choices."""
+
pass
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index f31ff4d3df..a2da1f6e38 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -20,11 +20,12 @@ class SQLiteNumericMixin:
Some expressions with output_field=DecimalField() must be cast to
numeric to be properly filtered.
"""
+
def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
try:
- if self.output_field.get_internal_type() == 'DecimalField':
- sql = 'CAST(%s AS NUMERIC)' % sql
+ if self.output_field.get_internal_type() == "DecimalField":
+ sql = "CAST(%s AS NUMERIC)" % sql
except FieldError:
pass
return sql, params
@@ -37,26 +38,26 @@ class Combinable:
"""
# Arithmetic connectors
- ADD = '+'
- SUB = '-'
- MUL = '*'
- DIV = '/'
- POW = '^'
+ ADD = "+"
+ SUB = "-"
+ MUL = "*"
+ DIV = "/"
+ POW = "^"
# The following is a quoted % operator - it is quoted because it can be
# used in strings that also have parameter substitution.
- MOD = '%%'
+ MOD = "%%"
# Bitwise operators - note that these are generated by .bitand()
# and .bitor(), the '&' and '|' are reserved for boolean operator
# usage.
- BITAND = '&'
- BITOR = '|'
- BITLEFTSHIFT = '<<'
- BITRIGHTSHIFT = '>>'
- BITXOR = '#'
+ BITAND = "&"
+ BITOR = "|"
+ BITLEFTSHIFT = "<<"
+ BITRIGHTSHIFT = ">>"
+ BITXOR = "#"
def _combine(self, other, connector, reversed):
- if not hasattr(other, 'resolve_expression'):
+ if not hasattr(other, "resolve_expression"):
# everything must be resolvable to an expression
other = Value(other)
@@ -90,7 +91,7 @@ class Combinable:
return self._combine(other, self.POW, False)
def __and__(self, other):
- if getattr(self, 'conditional', False) and getattr(other, 'conditional', False):
+ if getattr(self, "conditional", False) and getattr(other, "conditional", False):
return Q(self) & Q(other)
raise NotImplementedError(
"Use .bitand() and .bitor() for bitwise logical operations."
@@ -109,7 +110,7 @@ class Combinable:
return self._combine(other, self.BITXOR, False)
def __or__(self, other):
- if getattr(self, 'conditional', False) and getattr(other, 'conditional', False):
+ if getattr(self, "conditional", False) and getattr(other, "conditional", False):
return Q(self) | Q(other)
raise NotImplementedError(
"Use .bitand() and .bitor() for bitwise logical operations."
@@ -165,14 +166,14 @@ class BaseExpression:
def __getstate__(self):
state = self.__dict__.copy()
- state.pop('convert_value', None)
+ state.pop("convert_value", None)
return state
def get_db_converters(self, connection):
return (
[]
- if self.convert_value is self._convert_value_noop else
- [self.convert_value]
+ if self.convert_value is self._convert_value_noop
+ else [self.convert_value]
) + self.output_field.get_db_converters(connection)
def get_source_expressions(self):
@@ -183,9 +184,10 @@ class BaseExpression:
def _parse_expressions(self, *expressions):
return [
- arg if hasattr(arg, 'resolve_expression') else (
- F(arg) if isinstance(arg, str) else Value(arg)
- ) for arg in expressions
+ arg
+ if hasattr(arg, "resolve_expression")
+ else (F(arg) if isinstance(arg, str) else Value(arg))
+ for arg in expressions
]
def as_sql(self, compiler, connection):
@@ -218,17 +220,26 @@ class BaseExpression:
@cached_property
def contains_aggregate(self):
- return any(expr and expr.contains_aggregate for expr in self.get_source_expressions())
+ return any(
+ expr and expr.contains_aggregate for expr in self.get_source_expressions()
+ )
@cached_property
def contains_over_clause(self):
- return any(expr and expr.contains_over_clause for expr in self.get_source_expressions())
+ return any(
+ expr and expr.contains_over_clause for expr in self.get_source_expressions()
+ )
@cached_property
def contains_column_references(self):
- return any(expr and expr.contains_column_references for expr in self.get_source_expressions())
+ return any(
+ expr and expr.contains_column_references
+ for expr in self.get_source_expressions()
+ )
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
"""
Provide the chance to do any preprocessing or validation before being
added to the query.
@@ -245,11 +256,14 @@ class BaseExpression:
"""
c = self.copy()
c.is_summary = summarize
- c.set_source_expressions([
- expr.resolve_expression(query, allow_joins, reuse, summarize)
- if expr else None
- for expr in c.get_source_expressions()
- ])
+ c.set_source_expressions(
+ [
+ expr.resolve_expression(query, allow_joins, reuse, summarize)
+ if expr
+ else None
+ for expr in c.get_source_expressions()
+ ]
+ )
return c
@property
@@ -266,7 +280,7 @@ class BaseExpression:
output_field = self._resolve_output_field()
if output_field is None:
self._output_field_resolved_to_none = True
- raise FieldError('Cannot resolve expression type, unknown output_field')
+ raise FieldError("Cannot resolve expression type, unknown output_field")
return output_field
@cached_property
@@ -295,13 +309,16 @@ class BaseExpression:
If all sources are None, then an error is raised higher up the stack in
the output_field property.
"""
- sources_iter = (source for source in self.get_source_fields() if source is not None)
+ sources_iter = (
+ source for source in self.get_source_fields() if source is not None
+ )
for output_field in sources_iter:
for source in sources_iter:
if not isinstance(output_field, source.__class__):
raise FieldError(
- 'Expression contains mixed types: %s, %s. You must '
- 'set output_field.' % (
+ "Expression contains mixed types: %s, %s. You must "
+ "set output_field."
+ % (
output_field.__class__.__name__,
source.__class__.__name__,
)
@@ -321,12 +338,24 @@ class BaseExpression:
"""
field = self.output_field
internal_type = field.get_internal_type()
- if internal_type == 'FloatField':
- return lambda value, expression, connection: None if value is None else float(value)
- elif internal_type.endswith('IntegerField'):
- return lambda value, expression, connection: None if value is None else int(value)
- elif internal_type == 'DecimalField':
- return lambda value, expression, connection: None if value is None else Decimal(value)
+ if internal_type == "FloatField":
+ return (
+ lambda value, expression, connection: None
+ if value is None
+ else float(value)
+ )
+ elif internal_type.endswith("IntegerField"):
+ return (
+ lambda value, expression, connection: None
+ if value is None
+ else int(value)
+ )
+ elif internal_type == "DecimalField":
+ return (
+ lambda value, expression, connection: None
+ if value is None
+ else Decimal(value)
+ )
return self._convert_value_noop
def get_lookup(self, lookup):
@@ -337,10 +366,12 @@ class BaseExpression:
def relabeled_clone(self, change_map):
clone = self.copy()
- clone.set_source_expressions([
- e.relabeled_clone(change_map) if e is not None else None
- for e in self.get_source_expressions()
- ])
+ clone.set_source_expressions(
+ [
+ e.relabeled_clone(change_map) if e is not None else None
+ for e in self.get_source_expressions()
+ ]
+ )
return clone
def copy(self):
@@ -375,7 +406,7 @@ class BaseExpression:
yield self
for expr in self.get_source_expressions():
if expr:
- if hasattr(expr, 'flatten'):
+ if hasattr(expr, "flatten"):
yield from expr.flatten()
else:
yield expr
@@ -385,7 +416,7 @@ class BaseExpression:
Custom format for select clauses. For example, EXISTS expressions need
to be wrapped in CASE WHEN on Oracle.
"""
- if hasattr(self.output_field, 'select_format'):
+ if hasattr(self.output_field, "select_format"):
return self.output_field.select_format(compiler, sql, params)
return sql, params
@@ -438,12 +469,13 @@ _connector_combinators = {
def _resolve_combined_type(connector, lhs_type, rhs_type):
combinators = _connector_combinators.get(connector, ())
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
- if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
+ if issubclass(lhs_type, combinator_lhs_type) and issubclass(
+ rhs_type, combinator_rhs_type
+ ):
return combined_type
class CombinedExpression(SQLiteNumericMixin, Expression):
-
def __init__(self, lhs, connector, rhs, output_field=None):
super().__init__(output_field=output_field)
self.connector = connector
@@ -485,13 +517,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
expressions.append(sql)
expression_params.extend(params)
# order of precedence
- expression_wrapper = '(%s)'
+ expression_wrapper = "(%s)"
sql = connection.ops.combine_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
- lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
- rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
+ lhs = self.lhs.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
+ rhs = self.rhs.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
if not isinstance(self, (DurationExpression, TemporalSubtraction)):
try:
lhs_type = lhs.output_field.get_internal_type()
@@ -501,14 +539,28 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
rhs_type = rhs.output_field.get_internal_type()
except (AttributeError, FieldError):
rhs_type = None
- if 'DurationField' in {lhs_type, rhs_type} and lhs_type != rhs_type:
- return DurationExpression(self.lhs, self.connector, self.rhs).resolve_expression(
- query, allow_joins, reuse, summarize, for_save,
+ if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
+ return DurationExpression(
+ self.lhs, self.connector, self.rhs
+ ).resolve_expression(
+ query,
+ allow_joins,
+ reuse,
+ summarize,
+ for_save,
)
- datetime_fields = {'DateField', 'DateTimeField', 'TimeField'}
- if self.connector == self.SUB and lhs_type in datetime_fields and lhs_type == rhs_type:
+ datetime_fields = {"DateField", "DateTimeField", "TimeField"}
+ if (
+ self.connector == self.SUB
+ and lhs_type in datetime_fields
+ and lhs_type == rhs_type
+ ):
return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
- query, allow_joins, reuse, summarize, for_save,
+ query,
+ allow_joins,
+ reuse,
+ summarize,
+ for_save,
)
c = self.copy()
c.is_summary = summarize
@@ -524,7 +576,7 @@ class DurationExpression(CombinedExpression):
except FieldError:
pass
else:
- if output.get_internal_type() == 'DurationField':
+ if output.get_internal_type() == "DurationField":
sql, params = compiler.compile(side)
return connection.ops.format_for_duration_arithmetic(sql), params
return compiler.compile(side)
@@ -542,7 +594,7 @@ class DurationExpression(CombinedExpression):
expressions.append(sql)
expression_params.extend(params)
# order of precedence
- expression_wrapper = '(%s)'
+ expression_wrapper = "(%s)"
sql = connection.ops.combine_duration_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
@@ -556,11 +608,14 @@ class DurationExpression(CombinedExpression):
pass
else:
allowed_fields = {
- 'DecimalField', 'DurationField', 'FloatField', 'IntegerField',
+ "DecimalField",
+ "DurationField",
+ "FloatField",
+ "IntegerField",
}
if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
raise DatabaseError(
- f'Invalid arguments for operator {self.connector}.'
+ f"Invalid arguments for operator {self.connector}."
)
return sql, params
@@ -575,10 +630,12 @@ class TemporalSubtraction(CombinedExpression):
connection.ops.check_expression_support(self)
lhs = compiler.compile(self.lhs)
rhs = compiler.compile(self.rhs)
- return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs)
+ return connection.ops.subtract_temporals(
+ self.lhs.output_field.get_internal_type(), lhs, rhs
+ )
-@deconstructible(path='django.db.models.F')
+@deconstructible(path="django.db.models.F")
class F(Combinable):
"""An object capable of resolving references to existing query objects."""
@@ -592,8 +649,9 @@ class F(Combinable):
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.name)
- def resolve_expression(self, query=None, allow_joins=True, reuse=None,
- summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
def asc(self, **kwargs):
@@ -616,12 +674,13 @@ class ResolvedOuterRef(F):
In this case, the reference to the outer query has been resolved because
the inner query has been used as a subquery.
"""
+
contains_aggregate = False
def as_sql(self, *args, **kwargs):
raise ValueError(
- 'This queryset contains a reference to an outer query and may '
- 'only be used in a subquery.'
+ "This queryset contains a reference to an outer query and may "
+ "only be used in a subquery."
)
def resolve_expression(self, *args, **kwargs):
@@ -651,18 +710,20 @@ class OuterRef(F):
return self
-@deconstructible(path='django.db.models.Func')
+@deconstructible(path="django.db.models.Func")
class Func(SQLiteNumericMixin, Expression):
"""An SQL function call."""
+
function = None
- template = '%(function)s(%(expressions)s)'
- arg_joiner = ', '
+ template = "%(function)s(%(expressions)s)"
+ arg_joiner = ", "
arity = None # The number of arguments the function accepts.
def __init__(self, *expressions, output_field=None, **extra):
if self.arity is not None and len(expressions) != self.arity:
raise TypeError(
- "'%s' takes exactly %s %s (%s given)" % (
+ "'%s' takes exactly %s %s (%s given)"
+ % (
self.__class__.__name__,
self.arity,
"argument" if self.arity == 1 else "arguments",
@@ -677,7 +738,9 @@ class Func(SQLiteNumericMixin, Expression):
args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
extra = {**self.extra, **self._get_repr_options()}
if extra:
- extra = ', '.join(str(key) + '=' + str(val) for key, val in sorted(extra.items()))
+ extra = ", ".join(
+ str(key) + "=" + str(val) for key, val in sorted(extra.items())
+ )
return "{}({}, {})".format(self.__class__.__name__, args, extra)
return "{}({})".format(self.__class__.__name__, args)
@@ -691,14 +754,26 @@ class Func(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs):
self.source_expressions = exprs
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
c = self.copy()
c.is_summary = summarize
for pos, arg in enumerate(c.source_expressions):
- c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ c.source_expressions[pos] = arg.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
return c
- def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context):
+ def as_sql(
+ self,
+ compiler,
+ connection,
+ function=None,
+ template=None,
+ arg_joiner=None,
+ **extra_context,
+ ):
connection.ops.check_expression_support(self)
sql_parts = []
params = []
@@ -706,7 +781,9 @@ class Func(SQLiteNumericMixin, Expression):
try:
arg_sql, arg_params = compiler.compile(arg)
except EmptyResultSet:
- empty_result_set_value = getattr(arg, 'empty_result_set_value', NotImplemented)
+ empty_result_set_value = getattr(
+ arg, "empty_result_set_value", NotImplemented
+ )
if empty_result_set_value is NotImplemented:
raise
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
@@ -717,12 +794,12 @@ class Func(SQLiteNumericMixin, Expression):
# method, a value supplied in __init__()'s **extra (the value in
# `data`), or the value defined on the class.
if function is not None:
- data['function'] = function
+ data["function"] = function
else:
- data.setdefault('function', self.function)
- template = template or data.get('template', self.template)
- arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner)
- data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
+ data.setdefault("function", self.function)
+ template = template or data.get("template", self.template)
+ arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
+ data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
return template % data, params
def copy(self):
@@ -732,9 +809,10 @@ class Func(SQLiteNumericMixin, Expression):
return copy
-@deconstructible(path='django.db.models.Value')
+@deconstructible(path="django.db.models.Value")
class Value(SQLiteNumericMixin, Expression):
"""Represent a wrapped value as a node within an expression."""
+
# Provide a default value for `for_save` in order to allow unresolved
# instances to be compiled until a decision is taken in #25425.
for_save = False
@@ -752,7 +830,7 @@ class Value(SQLiteNumericMixin, Expression):
self.value = value
def __repr__(self):
- return f'{self.__class__.__name__}({self.value!r})'
+ return f"{self.__class__.__name__}({self.value!r})"
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
@@ -763,16 +841,18 @@ class Value(SQLiteNumericMixin, Expression):
val = output_field.get_db_prep_save(val, connection=connection)
else:
val = output_field.get_db_prep_value(val, connection=connection)
- if hasattr(output_field, 'get_placeholder'):
+ if hasattr(output_field, "get_placeholder"):
return output_field.get_placeholder(val, compiler, connection), [val]
if val is None:
# cx_Oracle does not always convert None to the appropriate
# NULL type (like in case expressions using numbers), so we
# use a literal SQL NULL
- return 'NULL', []
- return '%s', [val]
+ return "NULL", []
+ return "%s", [val]
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
c.for_save = for_save
return c
@@ -820,12 +900,14 @@ class RawSQL(Expression):
return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
def as_sql(self, compiler, connection):
- return '(%s)' % self.sql, self.params
+ return "(%s)" % self.sql, self.params
def get_group_by_cols(self, alias=None):
return [self]
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
# Resolve parents fields used in raw SQL.
for parent in query.model._meta.get_parent_list():
for parent_field in parent._meta.local_fields:
@@ -833,7 +915,9 @@ class RawSQL(Expression):
if column_name.lower() in self.sql.lower():
query.resolve_ref(parent_field.name, allow_joins, reuse, summarize)
break
- return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ return super().resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
class Star(Expression):
@@ -841,7 +925,7 @@ class Star(Expression):
return "'*'"
def as_sql(self, compiler, connection):
- return '*', []
+ return "*", []
class Col(Expression):
@@ -858,18 +942,20 @@ class Col(Expression):
def __repr__(self):
alias, target = self.alias, self.target
identifiers = (alias, str(target)) if alias else (str(target),)
- return '{}({})'.format(self.__class__.__name__, ', '.join(identifiers))
+ return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))
def as_sql(self, compiler, connection):
alias, column = self.alias, self.target.column
identifiers = (alias, column) if alias else (column,)
- sql = '.'.join(map(compiler.quote_name_unless_alias, identifiers))
+ sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
return sql, []
def relabeled_clone(self, relabels):
if self.alias is None:
return self
- return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
+ return self.__class__(
+ relabels.get(self.alias, self.alias), self.target, self.output_field
+ )
def get_group_by_cols(self, alias=None):
return [self]
@@ -877,8 +963,9 @@ class Col(Expression):
def get_db_converters(self, connection):
if self.target == self.output_field:
return self.output_field.get_db_converters(connection)
- return (self.output_field.get_db_converters(connection) +
- self.target.get_db_converters(connection))
+ return self.output_field.get_db_converters(
+ connection
+ ) + self.target.get_db_converters(connection)
class Ref(Expression):
@@ -886,6 +973,7 @@ class Ref(Expression):
Reference to column alias of the query. For example, Ref('sum_cost') in
qs.annotate(sum_cost=Sum('cost')) query.
"""
+
def __init__(self, refs, source):
super().__init__()
self.refs, self.source = refs, source
@@ -897,9 +985,11 @@ class Ref(Expression):
return [self.source]
def set_source_expressions(self, exprs):
- self.source, = exprs
+ (self.source,) = exprs
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
# The sub-expression `source` has already been resolved, as this is
# just a reference to the name of `source`.
return self
@@ -920,11 +1010,14 @@ class ExpressionList(Func):
list of expressions as an argument to another expression, like a partition
clause.
"""
- template = '%(expressions)s'
+
+ template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if not expressions:
- raise ValueError('%s requires at least one expression.' % self.__class__.__name__)
+ raise ValueError(
+ "%s requires at least one expression." % self.__class__.__name__
+ )
super().__init__(*expressions, **extra)
def __str__(self):
@@ -936,13 +1029,13 @@ class ExpressionList(Func):
class OrderByList(Func):
- template = 'ORDER BY %(expressions)s'
+ template = "ORDER BY %(expressions)s"
def __init__(self, *expressions, **extra):
expressions = (
(
OrderBy(F(expr[1:]), descending=True)
- if isinstance(expr, str) and expr[0] == '-'
+ if isinstance(expr, str) and expr[0] == "-"
else expr
)
for expr in expressions
@@ -951,11 +1044,11 @@ class OrderByList(Func):
def as_sql(self, *args, **kwargs):
if not self.source_expressions:
- return '', ()
+ return "", ()
return super().as_sql(*args, **kwargs)
-@deconstructible(path='django.db.models.ExpressionWrapper')
+@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
"""
An expression that can wrap another expression so that it can provide
@@ -988,9 +1081,9 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
return "{}({})".format(self.__class__.__name__, self.expression)
-@deconstructible(path='django.db.models.When')
+@deconstructible(path="django.db.models.When")
class When(Expression):
- template = 'WHEN %(condition)s THEN %(result)s'
+ template = "WHEN %(condition)s THEN %(result)s"
# This isn't a complete conditional expression, must be used in Case().
conditional = False
@@ -998,12 +1091,12 @@ class When(Expression):
if lookups:
if condition is None:
condition, lookups = Q(**lookups), None
- elif getattr(condition, 'conditional', False):
+ elif getattr(condition, "conditional", False):
condition, lookups = Q(condition, **lookups), None
- if condition is None or not getattr(condition, 'conditional', False) or lookups:
+ if condition is None or not getattr(condition, "conditional", False) or lookups:
raise TypeError(
- 'When() supports a Q object, a boolean expression, or lookups '
- 'as a condition.'
+ "When() supports a Q object, a boolean expression, or lookups "
+ "as a condition."
)
if isinstance(condition, Q) and not condition:
raise ValueError("An empty Q() can't be used as a When() condition.")
@@ -1027,12 +1120,18 @@ class When(Expression):
# We're only interested in the fields of the result expressions.
return [self.result._output_field_or_none]
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
c = self.copy()
c.is_summary = summarize
- if hasattr(c.condition, 'resolve_expression'):
- c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False)
- c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ if hasattr(c.condition, "resolve_expression"):
+ c.condition = c.condition.resolve_expression(
+ query, allow_joins, reuse, summarize, False
+ )
+ c.result = c.result.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
return c
def as_sql(self, compiler, connection, template=None, **extra_context):
@@ -1040,10 +1139,10 @@ class When(Expression):
template_params = extra_context
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
- template_params['condition'] = condition_sql
+ template_params["condition"] = condition_sql
sql_params.extend(condition_params)
result_sql, result_params = compiler.compile(self.result)
- template_params['result'] = result_sql
+ template_params["result"] = result_sql
sql_params.extend(result_params)
template = template or self.template
return template % template_params, sql_params
@@ -1056,7 +1155,7 @@ class When(Expression):
return cols
-@deconstructible(path='django.db.models.Case')
+@deconstructible(path="django.db.models.Case")
class Case(SQLiteNumericMixin, Expression):
"""
An SQL searched CASE expression:
@@ -1069,8 +1168,9 @@ class Case(SQLiteNumericMixin, Expression):
ELSE 'zero'
END
"""
- template = 'CASE %(cases)s ELSE %(default)s END'
- case_joiner = ' '
+
+ template = "CASE %(cases)s ELSE %(default)s END"
+ case_joiner = " "
def __init__(self, *cases, default=None, output_field=None, **extra):
if not all(isinstance(case, When) for case in cases):
@@ -1081,7 +1181,10 @@ class Case(SQLiteNumericMixin, Expression):
self.extra = extra
def __str__(self):
- return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
+ return "CASE %s, ELSE %r" % (
+ ", ".join(str(c) for c in self.cases),
+ self.default,
+ )
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self)
@@ -1092,12 +1195,18 @@ class Case(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs):
*self.cases, self.default = exprs
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
c = self.copy()
c.is_summary = summarize
for pos, case in enumerate(c.cases):
- c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
- c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ c.cases[pos] = case.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
+ c.default = c.default.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
return c
def copy(self):
@@ -1105,7 +1214,9 @@ class Case(SQLiteNumericMixin, Expression):
c.cases = c.cases[:]
return c
- def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
+ def as_sql(
+ self, compiler, connection, template=None, case_joiner=None, **extra_context
+ ):
connection.ops.check_expression_support(self)
if not self.cases:
return compiler.compile(self.default)
@@ -1123,10 +1234,10 @@ class Case(SQLiteNumericMixin, Expression):
if not case_parts:
return default_sql, default_params
case_joiner = case_joiner or self.case_joiner
- template_params['cases'] = case_joiner.join(case_parts)
- template_params['default'] = default_sql
+ template_params["cases"] = case_joiner.join(case_parts)
+ template_params["default"] = default_sql
sql_params.extend(default_params)
- template = template or template_params.get('template', self.template)
+ template = template or template_params.get("template", self.template)
sql = template % template_params
if self._output_field_or_none is not None:
sql = connection.ops.unification_cast_sql(self.output_field) % sql
@@ -1143,13 +1254,14 @@ class Subquery(BaseExpression, Combinable):
An explicit subquery. It may contain OuterRef() references to the outer
query which will be resolved when it is applied to that query.
"""
- template = '(%(subquery)s)'
+
+ template = "(%(subquery)s)"
contains_aggregate = False
empty_result_set_value = None
def __init__(self, queryset, output_field=None, **extra):
# Allow the usage of both QuerySet and sql.Query objects.
- self.query = getattr(queryset, 'query', queryset).clone()
+ self.query = getattr(queryset, "query", queryset).clone()
self.query.subquery = True
self.extra = extra
super().__init__(output_field)
@@ -1180,9 +1292,9 @@ class Subquery(BaseExpression, Combinable):
template_params = {**self.extra, **extra_context}
query = query or self.query
subquery_sql, sql_params = query.as_sql(compiler, connection)
- template_params['subquery'] = subquery_sql[1:-1]
+ template_params["subquery"] = subquery_sql[1:-1]
- template = template or template_params.get('template', self.template)
+ template = template or template_params.get("template", self.template)
sql = template % template_params
return sql, sql_params
@@ -1197,7 +1309,7 @@ class Subquery(BaseExpression, Combinable):
class Exists(Subquery):
- template = 'EXISTS(%(subquery)s)'
+ template = "EXISTS(%(subquery)s)"
output_field = fields.BooleanField()
def __init__(self, queryset, negated=False, **kwargs):
@@ -1227,7 +1339,7 @@ class Exists(Subquery):
return compiler.compile(Value(True))
raise
if self.negated:
- sql = 'NOT {}'.format(sql)
+ sql = "NOT {}".format(sql)
return sql, params
def select_format(self, compiler, sql, params):
@@ -1235,28 +1347,31 @@ class Exists(Subquery):
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
- sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
+ sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
return sql, params
-@deconstructible(path='django.db.models.OrderBy')
+@deconstructible(path="django.db.models.OrderBy")
class OrderBy(Expression):
- template = '%(expression)s %(ordering)s'
+ template = "%(expression)s %(ordering)s"
conditional = False
- def __init__(self, expression, descending=False, nulls_first=False, nulls_last=False):
+ def __init__(
+ self, expression, descending=False, nulls_first=False, nulls_last=False
+ ):
if nulls_first and nulls_last:
- raise ValueError('nulls_first and nulls_last are mutually exclusive')
+ raise ValueError("nulls_first and nulls_last are mutually exclusive")
self.nulls_first = nulls_first
self.nulls_last = nulls_last
self.descending = descending
- if not hasattr(expression, 'resolve_expression'):
- raise ValueError('expression must be an expression type')
+ if not hasattr(expression, "resolve_expression"):
+ raise ValueError("expression must be an expression type")
self.expression = expression
def __repr__(self):
return "{}({}, descending={})".format(
- self.__class__.__name__, self.expression, self.descending)
+ self.__class__.__name__, self.expression, self.descending
+ )
def set_source_expressions(self, exprs):
self.expression = exprs[0]
@@ -1268,32 +1383,34 @@ class OrderBy(Expression):
template = template or self.template
if connection.features.supports_order_by_nulls_modifier:
if self.nulls_last:
- template = '%s NULLS LAST' % template
+ template = "%s NULLS LAST" % template
elif self.nulls_first:
- template = '%s NULLS FIRST' % template
+ template = "%s NULLS FIRST" % template
else:
if self.nulls_last and not (
self.descending and connection.features.order_by_nulls_first
):
- template = '%%(expression)s IS NULL, %s' % template
+ template = "%%(expression)s IS NULL, %s" % template
elif self.nulls_first and not (
not self.descending and connection.features.order_by_nulls_first
):
- template = '%%(expression)s IS NOT NULL, %s' % template
+ template = "%%(expression)s IS NOT NULL, %s" % template
connection.ops.check_expression_support(self)
expression_sql, params = compiler.compile(self.expression)
placeholders = {
- 'expression': expression_sql,
- 'ordering': 'DESC' if self.descending else 'ASC',
+ "expression": expression_sql,
+ "ordering": "DESC" if self.descending else "ASC",
**extra_context,
}
- params *= template.count('%(expression)s')
+ params *= template.count("%(expression)s")
return (template % placeholders).rstrip(), params
def as_oracle(self, compiler, connection):
# Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped
# in a CASE WHEN.
- if connection.ops.conditional_expression_supported_in_where_clause(self.expression):
+ if connection.ops.conditional_expression_supported_in_where_clause(
+ self.expression
+ ):
copy = self.copy()
copy.expression = Case(
When(self.expression, then=True),
@@ -1323,7 +1440,7 @@ class OrderBy(Expression):
class Window(SQLiteNumericMixin, Expression):
- template = '%(expression)s OVER (%(window)s)'
+ template = "%(expression)s OVER (%(window)s)"
# Although the main expression may either be an aggregate or an
# expression with an aggregate function, the GROUP BY that will
# be introduced in the query as a result is not desired.
@@ -1331,15 +1448,22 @@ class Window(SQLiteNumericMixin, Expression):
contains_over_clause = True
filterable = False
- def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
+ def __init__(
+ self,
+ expression,
+ partition_by=None,
+ order_by=None,
+ frame=None,
+ output_field=None,
+ ):
self.partition_by = partition_by
self.order_by = order_by
self.frame = frame
- if not getattr(expression, 'window_compatible', False):
+ if not getattr(expression, "window_compatible", False):
raise ValueError(
- "Expression '%s' isn't compatible with OVER clauses." %
- expression.__class__.__name__
+ "Expression '%s' isn't compatible with OVER clauses."
+ % expression.__class__.__name__
)
if self.partition_by is not None:
@@ -1354,8 +1478,8 @@ class Window(SQLiteNumericMixin, Expression):
self.order_by = OrderByList(self.order_by)
else:
raise ValueError(
- 'Window.order_by must be either a string reference to a '
- 'field, an expression, or a list or tuple of them.'
+ "Window.order_by must be either a string reference to a "
+ "field, an expression, or a list or tuple of them."
)
super().__init__(output_field=output_field)
self.source_expression = self._parse_expressions(expression)[0]
@@ -1372,14 +1496,15 @@ class Window(SQLiteNumericMixin, Expression):
def as_sql(self, compiler, connection, template=None):
connection.ops.check_expression_support(self)
if not connection.features.supports_over_clause:
- raise NotSupportedError('This backend does not support window expressions.')
+ raise NotSupportedError("This backend does not support window expressions.")
expr_sql, params = compiler.compile(self.source_expression)
window_sql, window_params = [], []
if self.partition_by is not None:
sql_expr, sql_params = self.partition_by.as_sql(
- compiler=compiler, connection=connection,
- template='PARTITION BY %(expressions)s',
+ compiler=compiler,
+ connection=connection,
+ template="PARTITION BY %(expressions)s",
)
window_sql.append(sql_expr)
window_params.extend(sql_params)
@@ -1397,10 +1522,10 @@ class Window(SQLiteNumericMixin, Expression):
params.extend(window_params)
template = template or self.template
- return template % {
- 'expression': expr_sql,
- 'window': ' '.join(window_sql).strip()
- }, params
+ return (
+ template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
+ params,
+ )
def as_sqlite(self, compiler, connection):
if isinstance(self.output_field, fields.DecimalField):
@@ -1413,15 +1538,15 @@ class Window(SQLiteNumericMixin, Expression):
return self.as_sql(compiler, connection)
def __str__(self):
- return '{} OVER ({}{}{})'.format(
+ return "{} OVER ({}{}{})".format(
str(self.source_expression),
- 'PARTITION BY ' + str(self.partition_by) if self.partition_by else '',
- str(self.order_by or ''),
- str(self.frame or ''),
+ "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
+ str(self.order_by or ""),
+ str(self.frame or ""),
)
def __repr__(self):
- return '<%s: %s>' % (self.__class__.__name__, self)
+ return "<%s: %s>" % (self.__class__.__name__, self)
def get_group_by_cols(self, alias=None):
return []
@@ -1435,7 +1560,8 @@ class WindowFrame(Expression):
frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
row in the frame).
"""
- template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'
+
+ template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
def __init__(self, start=None, end=None):
self.start = Value(start)
@@ -1449,52 +1575,58 @@ class WindowFrame(Expression):
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
- start, end = self.window_frame_start_end(connection, self.start.value, self.end.value)
- return self.template % {
- 'frame_type': self.frame_type,
- 'start': start,
- 'end': end,
- }, []
+ start, end = self.window_frame_start_end(
+ connection, self.start.value, self.end.value
+ )
+ return (
+ self.template
+ % {
+ "frame_type": self.frame_type,
+ "start": start,
+ "end": end,
+ },
+ [],
+ )
def __repr__(self):
- return '<%s: %s>' % (self.__class__.__name__, self)
+ return "<%s: %s>" % (self.__class__.__name__, self)
def get_group_by_cols(self, alias=None):
return []
def __str__(self):
if self.start.value is not None and self.start.value < 0:
- start = '%d %s' % (abs(self.start.value), connection.ops.PRECEDING)
+ start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
elif self.start.value is not None and self.start.value == 0:
start = connection.ops.CURRENT_ROW
else:
start = connection.ops.UNBOUNDED_PRECEDING
if self.end.value is not None and self.end.value > 0:
- end = '%d %s' % (self.end.value, connection.ops.FOLLOWING)
+ end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
elif self.end.value is not None and self.end.value == 0:
end = connection.ops.CURRENT_ROW
else:
end = connection.ops.UNBOUNDED_FOLLOWING
return self.template % {
- 'frame_type': self.frame_type,
- 'start': start,
- 'end': end,
+ "frame_type": self.frame_type,
+ "start": start,
+ "end": end,
}
def window_frame_start_end(self, connection, start, end):
- raise NotImplementedError('Subclasses must implement window_frame_start_end().')
+ raise NotImplementedError("Subclasses must implement window_frame_start_end().")
class RowRange(WindowFrame):
- frame_type = 'ROWS'
+ frame_type = "ROWS"
def window_frame_start_end(self, connection, start, end):
return connection.ops.window_frame_rows_start_end(start, end)
class ValueRange(WindowFrame):
- frame_type = 'RANGE'
+ frame_type = "RANGE"
def window_frame_start_end(self, connection, start, end):
return connection.ops.window_frame_range_start_end(start, end)
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 6d6d10a483..313e31b5f5 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -19,7 +19,10 @@ from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin
from django.utils import timezone
from django.utils.datastructures import DictWrapper
from django.utils.dateparse import (
- parse_date, parse_datetime, parse_duration, parse_time,
+ parse_date,
+ parse_datetime,
+ parse_duration,
+ parse_time,
)
from django.utils.duration import duration_microseconds, duration_string
from django.utils.functional import Promise, cached_property
@@ -29,14 +32,38 @@ from django.utils.text import capfirst
from django.utils.translation import gettext_lazy as _
__all__ = [
- 'AutoField', 'BLANK_CHOICE_DASH', 'BigAutoField', 'BigIntegerField',
- 'BinaryField', 'BooleanField', 'CharField', 'CommaSeparatedIntegerField',
- 'DateField', 'DateTimeField', 'DecimalField', 'DurationField',
- 'EmailField', 'Empty', 'Field', 'FilePathField', 'FloatField',
- 'GenericIPAddressField', 'IPAddressField', 'IntegerField', 'NOT_PROVIDED',
- 'NullBooleanField', 'PositiveBigIntegerField', 'PositiveIntegerField',
- 'PositiveSmallIntegerField', 'SlugField', 'SmallAutoField',
- 'SmallIntegerField', 'TextField', 'TimeField', 'URLField', 'UUIDField',
+ "AutoField",
+ "BLANK_CHOICE_DASH",
+ "BigAutoField",
+ "BigIntegerField",
+ "BinaryField",
+ "BooleanField",
+ "CharField",
+ "CommaSeparatedIntegerField",
+ "DateField",
+ "DateTimeField",
+ "DecimalField",
+ "DurationField",
+ "EmailField",
+ "Empty",
+ "Field",
+ "FilePathField",
+ "FloatField",
+ "GenericIPAddressField",
+ "IPAddressField",
+ "IntegerField",
+ "NOT_PROVIDED",
+ "NullBooleanField",
+ "PositiveBigIntegerField",
+ "PositiveIntegerField",
+ "PositiveSmallIntegerField",
+ "SlugField",
+ "SmallAutoField",
+ "SmallIntegerField",
+ "TextField",
+ "TimeField",
+ "URLField",
+ "UUIDField",
]
@@ -72,6 +99,7 @@ def _load_field(app_label, model_name, field_name):
#
# getattr(obj, opts.pk.attname)
+
def _empty(of_cls):
new = Empty()
new.__class__ = of_cls
@@ -98,14 +126,16 @@ class Field(RegisterLookupMixin):
auto_creation_counter = -1
default_validators = [] # Default set of validators
default_error_messages = {
- 'invalid_choice': _('Value %(value)r is not a valid choice.'),
- 'null': _('This field cannot be null.'),
- 'blank': _('This field cannot be blank.'),
- 'unique': _('%(model_name)s with this %(field_label)s already exists.'),
+ "invalid_choice": _("Value %(value)r is not a valid choice."),
+ "null": _("This field cannot be null."),
+ "blank": _("This field cannot be blank."),
+ "unique": _("%(model_name)s with this %(field_label)s already exists."),
# Translators: The 'lookup_type' is one of 'date', 'year' or 'month'.
# Eg: "Title must be unique for pub_date year"
- 'unique_for_date': _("%(field_label)s must be unique for "
- "%(date_field_label)s %(lookup_type)s."),
+ "unique_for_date": _(
+ "%(field_label)s must be unique for "
+ "%(date_field_label)s %(lookup_type)s."
+ ),
}
system_check_deprecated_details = None
system_check_removed_details = None
@@ -123,18 +153,37 @@ class Field(RegisterLookupMixin):
# Generic field type description, usually overridden by subclasses
def _description(self):
- return _('Field of type: %(field_type)s') % {
- 'field_type': self.__class__.__name__
+ return _("Field of type: %(field_type)s") % {
+ "field_type": self.__class__.__name__
}
+
description = property(_description)
- def __init__(self, verbose_name=None, name=None, primary_key=False,
- max_length=None, unique=False, blank=False, null=False,
- db_index=False, rel=None, default=NOT_PROVIDED, editable=True,
- serialize=True, unique_for_date=None, unique_for_month=None,
- unique_for_year=None, choices=None, help_text='', db_column=None,
- db_tablespace=None, auto_created=False, validators=(),
- error_messages=None):
+ def __init__(
+ self,
+ verbose_name=None,
+ name=None,
+ primary_key=False,
+ max_length=None,
+ unique=False,
+ blank=False,
+ null=False,
+ db_index=False,
+ rel=None,
+ default=NOT_PROVIDED,
+ editable=True,
+ serialize=True,
+ unique_for_date=None,
+ unique_for_month=None,
+ unique_for_year=None,
+ choices=None,
+ help_text="",
+ db_column=None,
+ db_tablespace=None,
+ auto_created=False,
+ validators=(),
+ error_messages=None,
+ ):
self.name = name
self.verbose_name = verbose_name # May be set by set_attributes_from_name
self._verbose_name = verbose_name # Store original for deconstruction
@@ -170,7 +219,7 @@ class Field(RegisterLookupMixin):
messages = {}
for c in reversed(self.__class__.__mro__):
- messages.update(getattr(c, 'default_error_messages', {}))
+ messages.update(getattr(c, "default_error_messages", {}))
messages.update(error_messages or {})
self._error_messages = error_messages # Store for deconstruction later
self.error_messages = messages
@@ -180,18 +229,18 @@ class Field(RegisterLookupMixin):
Return "app_label.model_label.field_name" for fields attached to
models.
"""
- if not hasattr(self, 'model'):
+ if not hasattr(self, "model"):
return super().__str__()
model = self.model
- return '%s.%s' % (model._meta.label, self.name)
+ return "%s.%s" % (model._meta.label, self.name)
def __repr__(self):
"""Display the module, class, and name of the field."""
- path = '%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)
- name = getattr(self, 'name', None)
+ path = "%s.%s" % (self.__class__.__module__, self.__class__.__qualname__)
+ name = getattr(self, "name", None)
if name is not None:
- return '<%s: %s>' % (path, name)
- return '<%s>' % path
+ return "<%s: %s>" % (path, name)
+ return "<%s>" % path
def check(self, **kwargs):
return [
@@ -209,12 +258,12 @@ class Field(RegisterLookupMixin):
Check if field name is valid, i.e. 1) does not end with an
underscore, 2) does not contain "__" and 3) is not "pk".
"""
- if self.name.endswith('_'):
+ if self.name.endswith("_"):
return [
checks.Error(
- 'Field names must not end with an underscore.',
+ "Field names must not end with an underscore.",
obj=self,
- id='fields.E001',
+ id="fields.E001",
)
]
elif LOOKUP_SEP in self.name:
@@ -222,15 +271,15 @@ class Field(RegisterLookupMixin):
checks.Error(
'Field names must not contain "%s".' % LOOKUP_SEP,
obj=self,
- id='fields.E002',
+ id="fields.E002",
)
]
- elif self.name == 'pk':
+ elif self.name == "pk":
return [
checks.Error(
"'pk' is a reserved word that cannot be used as a field name.",
obj=self,
- id='fields.E003',
+ id="fields.E003",
)
]
else:
@@ -249,7 +298,7 @@ class Field(RegisterLookupMixin):
checks.Error(
"'choices' must be an iterable (e.g., a list or tuple).",
obj=self,
- id='fields.E004',
+ id="fields.E004",
)
]
@@ -268,14 +317,22 @@ class Field(RegisterLookupMixin):
):
break
if self.max_length is not None and group_choices:
- choice_max_length = max([
- choice_max_length,
- *(len(value) for value, _ in group_choices if isinstance(value, str)),
- ])
+ choice_max_length = max(
+ [
+ choice_max_length,
+ *(
+ len(value)
+ for value, _ in group_choices
+ if isinstance(value, str)
+ ),
+ ]
+ )
except (TypeError, ValueError):
# No groups, choices in the form [value, display]
value, human_name = group_name, group_choices
- if not self._choices_is_value(value) or not self._choices_is_value(human_name):
+ if not self._choices_is_value(value) or not self._choices_is_value(
+ human_name
+ ):
break
if self.max_length is not None and isinstance(value, str):
choice_max_length = max(choice_max_length, len(value))
@@ -290,7 +347,7 @@ class Field(RegisterLookupMixin):
"'max_length' is too small to fit the longest value "
"in 'choices' (%d characters)." % choice_max_length,
obj=self,
- id='fields.E009',
+ id="fields.E009",
),
]
return []
@@ -300,7 +357,7 @@ class Field(RegisterLookupMixin):
"'choices' must be an iterable containing "
"(actual value, human readable name) tuples.",
obj=self,
- id='fields.E005',
+ id="fields.E005",
)
]
@@ -310,25 +367,30 @@ class Field(RegisterLookupMixin):
checks.Error(
"'db_index' must be None, True or False.",
obj=self,
- id='fields.E006',
+ id="fields.E006",
)
]
else:
return []
def _check_null_allowed_for_primary_keys(self):
- if (self.primary_key and self.null and
- not connection.features.interprets_empty_strings_as_nulls):
+ if (
+ self.primary_key
+ and self.null
+ and not connection.features.interprets_empty_strings_as_nulls
+ ):
# We cannot reliably check this for backends like Oracle which
# consider NULL and '' to be equal (and thus set up
# character-based fields a little differently).
return [
checks.Error(
- 'Primary keys must not have null=True.',
- hint=('Set null=False on the field, or '
- 'remove primary_key=True argument.'),
+ "Primary keys must not have null=True.",
+ hint=(
+ "Set null=False on the field, or "
+ "remove primary_key=True argument."
+ ),
obj=self,
- id='fields.E007',
+ id="fields.E007",
)
]
else:
@@ -340,7 +402,9 @@ class Field(RegisterLookupMixin):
app_label = self.model._meta.app_label
errors = []
for alias in databases:
- if router.allow_migrate(alias, app_label, model_name=self.model._meta.model_name):
+ if router.allow_migrate(
+ alias, app_label, model_name=self.model._meta.model_name
+ ):
errors.extend(connections[alias].validation.check_field(self, **kwargs))
return errors
@@ -354,11 +418,12 @@ class Field(RegisterLookupMixin):
hint=(
"validators[{i}] ({repr}) isn't a function or "
"instance of a validator class.".format(
- i=i, repr=repr(validator),
+ i=i,
+ repr=repr(validator),
)
),
obj=self,
- id='fields.E008',
+ id="fields.E008",
)
)
return errors
@@ -368,41 +433,41 @@ class Field(RegisterLookupMixin):
return [
checks.Error(
self.system_check_removed_details.get(
- 'msg',
- '%s has been removed except for support in historical '
- 'migrations.' % self.__class__.__name__
+ "msg",
+ "%s has been removed except for support in historical "
+ "migrations." % self.__class__.__name__,
),
- hint=self.system_check_removed_details.get('hint'),
+ hint=self.system_check_removed_details.get("hint"),
obj=self,
- id=self.system_check_removed_details.get('id', 'fields.EXXX'),
+ id=self.system_check_removed_details.get("id", "fields.EXXX"),
)
]
elif self.system_check_deprecated_details is not None:
return [
checks.Warning(
self.system_check_deprecated_details.get(
- 'msg',
- '%s has been deprecated.' % self.__class__.__name__
+ "msg", "%s has been deprecated." % self.__class__.__name__
),
- hint=self.system_check_deprecated_details.get('hint'),
+ hint=self.system_check_deprecated_details.get("hint"),
obj=self,
- id=self.system_check_deprecated_details.get('id', 'fields.WXXX'),
+ id=self.system_check_deprecated_details.get("id", "fields.WXXX"),
)
]
return []
def get_col(self, alias, output_field=None):
- if (
- alias == self.model._meta.db_table and
- (output_field is None or output_field == self)
+ if alias == self.model._meta.db_table and (
+ output_field is None or output_field == self
):
return self.cached_col
from django.db.models.expressions import Col
+
return Col(alias, self, output_field)
@cached_property
def cached_col(self):
from django.db.models.expressions import Col
+
return Col(self.model._meta.db_table, self)
def select_format(self, compiler, sql, params):
@@ -462,7 +527,7 @@ class Field(RegisterLookupMixin):
"unique_for_month": None,
"unique_for_year": None,
"choices": None,
- "help_text": '',
+ "help_text": "",
"db_column": None,
"db_tablespace": None,
"auto_created": False,
@@ -495,8 +560,8 @@ class Field(RegisterLookupMixin):
path = path.replace("django.db.models.fields.related", "django.db.models")
elif path.startswith("django.db.models.fields.files"):
path = path.replace("django.db.models.fields.files", "django.db.models")
- elif path.startswith('django.db.models.fields.json'):
- path = path.replace('django.db.models.fields.json', 'django.db.models')
+ elif path.startswith("django.db.models.fields.json"):
+ path = path.replace("django.db.models.fields.json", "django.db.models")
elif path.startswith("django.db.models.fields.proxy"):
path = path.replace("django.db.models.fields.proxy", "django.db.models")
elif path.startswith("django.db.models.fields"):
@@ -515,10 +580,9 @@ class Field(RegisterLookupMixin):
def __eq__(self, other):
# Needed for @total_ordering
if isinstance(other, Field):
- return (
- self.creation_counter == other.creation_counter and
- getattr(self, 'model', None) == getattr(other, 'model', None)
- )
+ return self.creation_counter == other.creation_counter and getattr(
+ self, "model", None
+ ) == getattr(other, "model", None)
return NotImplemented
def __lt__(self, other):
@@ -526,17 +590,18 @@ class Field(RegisterLookupMixin):
# Order by creation_counter first for backward compatibility.
if isinstance(other, Field):
if (
- self.creation_counter != other.creation_counter or
- not hasattr(self, 'model') and not hasattr(other, 'model')
+ self.creation_counter != other.creation_counter
+ or not hasattr(self, "model")
+ and not hasattr(other, "model")
):
return self.creation_counter < other.creation_counter
- elif hasattr(self, 'model') != hasattr(other, 'model'):
- return not hasattr(self, 'model') # Order no-model fields first
+ elif hasattr(self, "model") != hasattr(other, "model"):
+ return not hasattr(self, "model") # Order no-model fields first
else:
# creation_counter's are equal, compare only models.
- return (
- (self.model._meta.app_label, self.model._meta.model_name) <
- (other.model._meta.app_label, other.model._meta.model_name)
+ return (self.model._meta.app_label, self.model._meta.model_name) < (
+ other.model._meta.app_label,
+ other.model._meta.model_name,
)
return NotImplemented
@@ -549,7 +614,7 @@ class Field(RegisterLookupMixin):
obj = copy.copy(self)
if self.remote_field:
obj.remote_field = copy.copy(self.remote_field)
- if hasattr(self.remote_field, 'field') and self.remote_field.field is self:
+ if hasattr(self.remote_field, "field") and self.remote_field.field is self:
obj.remote_field.field = obj
memodict[id(self)] = obj
return obj
@@ -568,7 +633,7 @@ class Field(RegisterLookupMixin):
not a new copy of that field. So, use the app registry to load the
model and then the field back.
"""
- if not hasattr(self, 'model'):
+ if not hasattr(self, "model"):
# Fields are sometimes used without attaching them to models (for
# example in aggregation). In this case give back a plain field
# instance. The code below will create a new empty instance of
@@ -577,10 +642,13 @@ class Field(RegisterLookupMixin):
state = self.__dict__.copy()
# The _get_default cached_property can't be pickled due to lambda
# usage.
- state.pop('_get_default', None)
+ state.pop("_get_default", None)
return _empty, (self.__class__,), state
- return _load_field, (self.model._meta.app_label, self.model._meta.object_name,
- self.name)
+ return _load_field, (
+ self.model._meta.app_label,
+ self.model._meta.object_name,
+ self.name,
+ )
def get_pk_value_on_save(self, instance):
"""
@@ -618,7 +686,7 @@ class Field(RegisterLookupMixin):
try:
v(value)
except exceptions.ValidationError as e:
- if hasattr(e, 'code') and e.code in self.error_messages:
+ if hasattr(e, "code") and e.code in self.error_messages:
e.message = self.error_messages[e.code]
errors.extend(e.error_list)
@@ -645,16 +713,16 @@ class Field(RegisterLookupMixin):
elif value == option_key:
return
raise exceptions.ValidationError(
- self.error_messages['invalid_choice'],
- code='invalid_choice',
- params={'value': value},
+ self.error_messages["invalid_choice"],
+ code="invalid_choice",
+ params={"value": value},
)
if value is None and not self.null:
- raise exceptions.ValidationError(self.error_messages['null'], code='null')
+ raise exceptions.ValidationError(self.error_messages["null"], code="null")
if not self.blank and value in self.empty_values:
- raise exceptions.ValidationError(self.error_messages['blank'], code='blank')
+ raise exceptions.ValidationError(self.error_messages["blank"], code="blank")
def clean(self, value, model_instance):
"""
@@ -668,7 +736,7 @@ class Field(RegisterLookupMixin):
return value
def db_type_parameters(self, connection):
- return DictWrapper(self.__dict__, connection.ops.quote_name, 'qn_')
+ return DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
def db_check(self, connection):
"""
@@ -678,7 +746,9 @@ class Field(RegisterLookupMixin):
"""
data = self.db_type_parameters(connection)
try:
- return connection.data_type_check_constraints[self.get_internal_type()] % data
+ return (
+ connection.data_type_check_constraints[self.get_internal_type()] % data
+ )
except KeyError:
return None
@@ -740,7 +810,7 @@ class Field(RegisterLookupMixin):
return connection.data_types_suffix.get(self.get_internal_type())
def get_db_converters(self, connection):
- if hasattr(self, 'from_db_value'):
+ if hasattr(self, "from_db_value"):
return [self.from_db_value]
return []
@@ -765,7 +835,7 @@ class Field(RegisterLookupMixin):
self.attname, self.column = self.get_attname_column()
self.concrete = self.column is not None
if self.verbose_name is None and self.name:
- self.verbose_name = self.name.replace('_', ' ')
+ self.verbose_name = self.name.replace("_", " ")
def contribute_to_class(self, cls, name, private_only=False):
"""
@@ -784,10 +854,10 @@ class Field(RegisterLookupMixin):
# this class, but don't check methods derived from inheritance, to
# allow overriding inherited choices. For more complex inheritance
# structures users should override contribute_to_class().
- if 'get_%s_display' % self.name not in cls.__dict__:
+ if "get_%s_display" % self.name not in cls.__dict__:
setattr(
cls,
- 'get_%s_display' % self.name,
+ "get_%s_display" % self.name,
partialmethod(cls._get_FIELD_display, field=self),
)
@@ -848,11 +918,21 @@ class Field(RegisterLookupMixin):
return self.default
return lambda: self.default
- if not self.empty_strings_allowed or self.null and not connection.features.interprets_empty_strings_as_nulls:
+ if (
+ not self.empty_strings_allowed
+ or self.null
+ and not connection.features.interprets_empty_strings_as_nulls
+ ):
return return_None
return str # return empty string
- def get_choices(self, include_blank=True, blank_choice=BLANK_CHOICE_DASH, limit_choices_to=None, ordering=()):
+ def get_choices(
+ self,
+ include_blank=True,
+ blank_choice=BLANK_CHOICE_DASH,
+ limit_choices_to=None,
+ ordering=(),
+ ):
"""
Return choices with a default blank choices included, for use
as <select> choices for this field.
@@ -860,7 +940,9 @@ class Field(RegisterLookupMixin):
if self.choices is not None:
choices = list(self.choices)
if include_blank:
- blank_defined = any(choice in ('', None) for choice, _ in self.flatchoices)
+ blank_defined = any(
+ choice in ("", None) for choice, _ in self.flatchoices
+ )
if not blank_defined:
choices = blank_choice + choices
return choices
@@ -868,8 +950,8 @@ class Field(RegisterLookupMixin):
limit_choices_to = limit_choices_to or self.get_limit_choices_to()
choice_func = operator.attrgetter(
self.remote_field.get_related_field().attname
- if hasattr(self.remote_field, 'get_related_field')
- else 'pk'
+ if hasattr(self.remote_field, "get_related_field")
+ else "pk"
)
qs = rel_model._default_manager.complex_filter(limit_choices_to)
if ordering:
@@ -896,6 +978,7 @@ class Field(RegisterLookupMixin):
else:
flat.append((choice, value))
return flat
+
flatchoices = property(_get_flatchoices)
def save_form_data(self, instance, data):
@@ -904,24 +987,25 @@ class Field(RegisterLookupMixin):
def formfield(self, form_class=None, choices_form_class=None, **kwargs):
"""Return a django.forms.Field instance for this field."""
defaults = {
- 'required': not self.blank,
- 'label': capfirst(self.verbose_name),
- 'help_text': self.help_text,
+ "required": not self.blank,
+ "label": capfirst(self.verbose_name),
+ "help_text": self.help_text,
}
if self.has_default():
if callable(self.default):
- defaults['initial'] = self.default
- defaults['show_hidden_initial'] = True
+ defaults["initial"] = self.default
+ defaults["show_hidden_initial"] = True
else:
- defaults['initial'] = self.get_default()
+ defaults["initial"] = self.get_default()
if self.choices is not None:
# Fields with choices get special treatment.
- include_blank = (self.blank or
- not (self.has_default() or 'initial' in kwargs))
- defaults['choices'] = self.get_choices(include_blank=include_blank)
- defaults['coerce'] = self.to_python
+ include_blank = self.blank or not (
+ self.has_default() or "initial" in kwargs
+ )
+ defaults["choices"] = self.get_choices(include_blank=include_blank)
+ defaults["coerce"] = self.to_python
if self.null:
- defaults['empty_value'] = None
+ defaults["empty_value"] = None
if choices_form_class is not None:
form_class = choices_form_class
else:
@@ -930,9 +1014,19 @@ class Field(RegisterLookupMixin):
# max_value) don't apply for choice fields, so be sure to only pass
# the values that TypedChoiceField will understand.
for k in list(kwargs):
- if k not in ('coerce', 'empty_value', 'choices', 'required',
- 'widget', 'label', 'initial', 'help_text',
- 'error_messages', 'show_hidden_initial', 'disabled'):
+ if k not in (
+ "coerce",
+ "empty_value",
+ "choices",
+ "required",
+ "widget",
+ "label",
+ "initial",
+ "help_text",
+ "error_messages",
+ "show_hidden_initial",
+ "disabled",
+ ):
del kwargs[k]
defaults.update(kwargs)
if form_class is None:
@@ -947,8 +1041,8 @@ class Field(RegisterLookupMixin):
class BooleanField(Field):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value must be either True or False.'),
- 'invalid_nullable': _('“%(value)s” value must be either True, False, or None.'),
+ "invalid": _("“%(value)s” value must be either True or False."),
+ "invalid_nullable": _("“%(value)s” value must be either True, False, or None."),
}
description = _("Boolean (Either True or False)")
@@ -961,14 +1055,14 @@ class BooleanField(Field):
if value in (True, False):
# 1/0 are equal to True/False. bool() converts former to latter.
return bool(value)
- if value in ('t', 'True', '1'):
+ if value in ("t", "True", "1"):
return True
- if value in ('f', 'False', '0'):
+ if value in ("f", "False", "0"):
return False
raise exceptions.ValidationError(
- self.error_messages['invalid_nullable' if self.null else 'invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid_nullable" if self.null else "invalid"],
+ code="invalid",
+ params={"value": value},
)
def get_prep_value(self, value):
@@ -979,14 +1073,14 @@ class BooleanField(Field):
def formfield(self, **kwargs):
if self.choices is not None:
- include_blank = not (self.has_default() or 'initial' in kwargs)
- defaults = {'choices': self.get_choices(include_blank=include_blank)}
+ include_blank = not (self.has_default() or "initial" in kwargs)
+ defaults = {"choices": self.get_choices(include_blank=include_blank)}
else:
form_class = forms.NullBooleanField if self.null else forms.BooleanField
# In HTML checkboxes, 'required' means "must be checked" which is
# different from the choices case ("must select some value").
# required=False allows unchecked checkboxes.
- defaults = {'form_class': form_class, 'required': False}
+ defaults = {"form_class": form_class, "required": False}
return super().formfield(**{**defaults, **kwargs})
def select_format(self, compiler, sql, params):
@@ -994,8 +1088,8 @@ class BooleanField(Field):
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in SELECT or GROUP BY list they must use a
# predicate that's always True.
- if sql == '':
- sql = '1'
+ if sql == "":
+ sql = "1"
return sql, params
@@ -1009,7 +1103,7 @@ class CharField(Field):
self.validators.append(validators.MaxLengthValidator(self.max_length))
def check(self, **kwargs):
- databases = kwargs.get('databases') or []
+ databases = kwargs.get("databases") or []
return [
*super().check(**kwargs),
*self._check_db_collation(databases),
@@ -1022,16 +1116,19 @@ class CharField(Field):
checks.Error(
"CharFields must define a 'max_length' attribute.",
obj=self,
- id='fields.E120',
+ id="fields.E120",
)
]
- elif (not isinstance(self.max_length, int) or isinstance(self.max_length, bool) or
- self.max_length <= 0):
+ elif (
+ not isinstance(self.max_length, int)
+ or isinstance(self.max_length, bool)
+ or self.max_length <= 0
+ ):
return [
checks.Error(
"'max_length' must be a positive integer.",
obj=self,
- id='fields.E121',
+ id="fields.E121",
)
]
else:
@@ -1044,16 +1141,17 @@ class CharField(Field):
continue
connection = connections[db]
if not (
- self.db_collation is None or
- 'supports_collation_on_charfield' in self.model._meta.required_db_features or
- connection.features.supports_collation_on_charfield
+ self.db_collation is None
+ or "supports_collation_on_charfield"
+ in self.model._meta.required_db_features
+ or connection.features.supports_collation_on_charfield
):
errors.append(
checks.Error(
- '%s does not support a database collation on '
- 'CharFields.' % connection.display_name,
+ "%s does not support a database collation on "
+ "CharFields." % connection.display_name,
obj=self,
- id='fields.E190',
+ id="fields.E190",
),
)
return errors
@@ -1079,17 +1177,17 @@ class CharField(Field):
# Passing max_length to forms.CharField means that the value's length
# will be validated twice. This is considered acceptable since we want
# the value in the form field (to pass into widget for example).
- defaults = {'max_length': self.max_length}
+ defaults = {"max_length": self.max_length}
# TODO: Handle multiple backends with different feature flags.
if self.null and not connection.features.interprets_empty_strings_as_nulls:
- defaults['empty_value'] = None
+ defaults["empty_value"] = None
defaults.update(kwargs)
return super().formfield(**defaults)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.db_collation:
- kwargs['db_collation'] = self.db_collation
+ kwargs["db_collation"] = self.db_collation
return name, path, args, kwargs
@@ -1097,15 +1195,15 @@ class CommaSeparatedIntegerField(CharField):
default_validators = [validators.validate_comma_separated_integer_list]
description = _("Comma-separated integers")
system_check_removed_details = {
- 'msg': (
- 'CommaSeparatedIntegerField is removed except for support in '
- 'historical migrations.'
+ "msg": (
+ "CommaSeparatedIntegerField is removed except for support in "
+ "historical migrations."
),
- 'hint': (
- 'Use CharField(validators=[validate_comma_separated_integer_list]) '
- 'instead.'
+ "hint": (
+ "Use CharField(validators=[validate_comma_separated_integer_list]) "
+ "instead."
),
- 'id': 'fields.E901',
+ "id": "fields.E901",
}
@@ -1120,7 +1218,6 @@ def _get_naive_now():
class DateTimeCheckMixin:
-
def check(self, **kwargs):
return [
*super().check(**kwargs),
@@ -1132,8 +1229,14 @@ class DateTimeCheckMixin:
# auto_now, auto_now_add, and default are mutually exclusive
# options. The use of more than one of these options together
# will trigger an Error
- mutually_exclusive_options = [self.auto_now_add, self.auto_now, self.has_default()]
- enabled_options = [option not in (None, False) for option in mutually_exclusive_options].count(True)
+ mutually_exclusive_options = [
+ self.auto_now_add,
+ self.auto_now,
+ self.has_default(),
+ ]
+ enabled_options = [
+ option not in (None, False) for option in mutually_exclusive_options
+ ].count(True)
if enabled_options > 1:
return [
checks.Error(
@@ -1141,7 +1244,7 @@ class DateTimeCheckMixin:
"are mutually exclusive. Only one of these options "
"may be present.",
obj=self,
- id='fields.E160',
+ id="fields.E160",
)
]
else:
@@ -1173,15 +1276,15 @@ class DateTimeCheckMixin:
if lower <= value <= upper:
return [
checks.Warning(
- 'Fixed default value provided.',
+ "Fixed default value provided.",
hint=(
- 'It seems you set a fixed date / time / datetime '
- 'value as default for this field. This may not be '
- 'what you want. If you want to have the current date '
- 'as default, use `django.utils.timezone.now`'
+ "It seems you set a fixed date / time / datetime "
+ "value as default for this field. This may not be "
+ "what you want. If you want to have the current date "
+ "as default, use `django.utils.timezone.now`"
),
obj=self,
- id='fields.W161',
+ id="fields.W161",
)
]
return []
@@ -1190,19 +1293,24 @@ class DateTimeCheckMixin:
class DateField(DateTimeCheckMixin, Field):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value has an invalid date format. It must be '
- 'in YYYY-MM-DD format.'),
- 'invalid_date': _('“%(value)s” value has the correct format (YYYY-MM-DD) '
- 'but it is an invalid date.'),
+ "invalid": _(
+ "“%(value)s” value has an invalid date format. It must be "
+ "in YYYY-MM-DD format."
+ ),
+ "invalid_date": _(
+ "“%(value)s” value has the correct format (YYYY-MM-DD) "
+ "but it is an invalid date."
+ ),
}
description = _("Date (without time)")
- def __init__(self, verbose_name=None, name=None, auto_now=False,
- auto_now_add=False, **kwargs):
+ def __init__(
+ self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs
+ ):
self.auto_now, self.auto_now_add = auto_now, auto_now_add
if auto_now or auto_now_add:
- kwargs['editable'] = False
- kwargs['blank'] = True
+ kwargs["editable"] = False
+ kwargs["blank"] = True
super().__init__(verbose_name, name, **kwargs)
def _check_fix_default_value(self):
@@ -1227,12 +1335,12 @@ class DateField(DateTimeCheckMixin, Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.auto_now:
- kwargs['auto_now'] = True
+ kwargs["auto_now"] = True
if self.auto_now_add:
- kwargs['auto_now_add'] = True
+ kwargs["auto_now_add"] = True
if self.auto_now or self.auto_now_add:
- del kwargs['editable']
- del kwargs['blank']
+ del kwargs["editable"]
+ del kwargs["blank"]
return name, path, args, kwargs
def get_internal_type(self):
@@ -1257,15 +1365,15 @@ class DateField(DateTimeCheckMixin, Field):
return parsed
except ValueError:
raise exceptions.ValidationError(
- self.error_messages['invalid_date'],
- code='invalid_date',
- params={'value': value},
+ self.error_messages["invalid_date"],
+ code="invalid_date",
+ params={"value": value},
)
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def pre_save(self, model_instance, add):
@@ -1280,12 +1388,18 @@ class DateField(DateTimeCheckMixin, Field):
super().contribute_to_class(cls, name, **kwargs)
if not self.null:
setattr(
- cls, 'get_next_by_%s' % self.name,
- partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=True)
+ cls,
+ "get_next_by_%s" % self.name,
+ partialmethod(
+ cls._get_next_or_previous_by_FIELD, field=self, is_next=True
+ ),
)
setattr(
- cls, 'get_previous_by_%s' % self.name,
- partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)
+ cls,
+ "get_previous_by_%s" % self.name,
+ partialmethod(
+ cls._get_next_or_previous_by_FIELD, field=self, is_next=False
+ ),
)
def get_prep_value(self, value):
@@ -1300,25 +1414,33 @@ class DateField(DateTimeCheckMixin, Field):
def value_to_string(self, obj):
val = self.value_from_object(obj)
- return '' if val is None else val.isoformat()
+ return "" if val is None else val.isoformat()
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.DateField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.DateField,
+ **kwargs,
+ }
+ )
class DateTimeField(DateField):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value has an invalid format. It must be in '
- 'YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format.'),
- 'invalid_date': _("“%(value)s” value has the correct format "
- "(YYYY-MM-DD) but it is an invalid date."),
- 'invalid_datetime': _('“%(value)s” value has the correct format '
- '(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) '
- 'but it is an invalid date/time.'),
+ "invalid": _(
+ "“%(value)s” value has an invalid format. It must be in "
+ "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."
+ ),
+ "invalid_date": _(
+ "“%(value)s” value has the correct format "
+ "(YYYY-MM-DD) but it is an invalid date."
+ ),
+ "invalid_datetime": _(
+ "“%(value)s” value has the correct format "
+ "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
+ "but it is an invalid date/time."
+ ),
}
description = _("Date (with time)")
@@ -1353,10 +1475,12 @@ class DateTimeField(DateField):
# local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the
# call stack.
- warnings.warn("DateTimeField %s.%s received a naive datetime "
- "(%s) while time zone support is active." %
- (self.model.__name__, self.name, value),
- RuntimeWarning)
+ warnings.warn(
+ "DateTimeField %s.%s received a naive datetime "
+ "(%s) while time zone support is active."
+ % (self.model.__name__, self.name, value),
+ RuntimeWarning,
+ )
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
@@ -1367,9 +1491,9 @@ class DateTimeField(DateField):
return parsed
except ValueError:
raise exceptions.ValidationError(
- self.error_messages['invalid_datetime'],
- code='invalid_datetime',
- params={'value': value},
+ self.error_messages["invalid_datetime"],
+ code="invalid_datetime",
+ params={"value": value},
)
try:
@@ -1378,15 +1502,15 @@ class DateTimeField(DateField):
return datetime.datetime(parsed.year, parsed.month, parsed.day)
except ValueError:
raise exceptions.ValidationError(
- self.error_messages['invalid_date'],
- code='invalid_date',
- params={'value': value},
+ self.error_messages["invalid_date"],
+ code="invalid_date",
+ params={"value": value},
)
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def pre_save(self, model_instance, add):
@@ -1408,13 +1532,14 @@ class DateTimeField(DateField):
# time. This won't work during DST change, but we can't do much
# about it, so we let the exceptions percolate up the call stack.
try:
- name = '%s.%s' % (self.model.__name__, self.name)
+ name = "%s.%s" % (self.model.__name__, self.name)
except AttributeError:
- name = '(unbound)'
- warnings.warn("DateTimeField %s received a naive datetime (%s)"
- " while time zone support is active." %
- (name, value),
- RuntimeWarning)
+ name = "(unbound)"
+ warnings.warn(
+ "DateTimeField %s received a naive datetime (%s)"
+ " while time zone support is active." % (name, value),
+ RuntimeWarning,
+ )
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
@@ -1427,24 +1552,32 @@ class DateTimeField(DateField):
def value_to_string(self, obj):
val = self.value_from_object(obj)
- return '' if val is None else val.isoformat()
+ return "" if val is None else val.isoformat()
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.DateTimeField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.DateTimeField,
+ **kwargs,
+ }
+ )
class DecimalField(Field):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value must be a decimal number.'),
+ "invalid": _("“%(value)s” value must be a decimal number."),
}
description = _("Decimal number")
- def __init__(self, verbose_name=None, name=None, max_digits=None,
- decimal_places=None, **kwargs):
+ def __init__(
+ self,
+ verbose_name=None,
+ name=None,
+ max_digits=None,
+ decimal_places=None,
+ **kwargs,
+ ):
self.max_digits, self.decimal_places = max_digits, decimal_places
super().__init__(verbose_name, name, **kwargs)
@@ -1471,7 +1604,7 @@ class DecimalField(Field):
checks.Error(
"DecimalFields must define a 'decimal_places' attribute.",
obj=self,
- id='fields.E130',
+ id="fields.E130",
)
]
except ValueError:
@@ -1479,7 +1612,7 @@ class DecimalField(Field):
checks.Error(
"'decimal_places' must be a non-negative integer.",
obj=self,
- id='fields.E131',
+ id="fields.E131",
)
]
else:
@@ -1495,7 +1628,7 @@ class DecimalField(Field):
checks.Error(
"DecimalFields must define a 'max_digits' attribute.",
obj=self,
- id='fields.E132',
+ id="fields.E132",
)
]
except ValueError:
@@ -1503,7 +1636,7 @@ class DecimalField(Field):
checks.Error(
"'max_digits' must be a positive integer.",
obj=self,
- id='fields.E133',
+ id="fields.E133",
)
]
else:
@@ -1515,7 +1648,7 @@ class DecimalField(Field):
checks.Error(
"'max_digits' must be greater or equal to 'decimal_places'.",
obj=self,
- id='fields.E134',
+ id="fields.E134",
)
]
return []
@@ -1533,9 +1666,9 @@ class DecimalField(Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.max_digits is not None:
- kwargs['max_digits'] = self.max_digits
+ kwargs["max_digits"] = self.max_digits
if self.decimal_places is not None:
- kwargs['decimal_places'] = self.decimal_places
+ kwargs["decimal_places"] = self.decimal_places
return name, path, args, kwargs
def get_internal_type(self):
@@ -1547,34 +1680,38 @@ class DecimalField(Field):
if isinstance(value, float):
if math.isnan(value):
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
return self.context.create_decimal_from_float(value)
try:
return decimal.Decimal(value)
except (decimal.InvalidOperation, TypeError, ValueError):
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def get_db_prep_save(self, value, connection):
- return connection.ops.adapt_decimalfield_value(self.to_python(value), self.max_digits, self.decimal_places)
+ return connection.ops.adapt_decimalfield_value(
+ self.to_python(value), self.max_digits, self.decimal_places
+ )
def get_prep_value(self, value):
value = super().get_prep_value(value)
return self.to_python(value)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'max_digits': self.max_digits,
- 'decimal_places': self.decimal_places,
- 'form_class': forms.DecimalField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "max_digits": self.max_digits,
+ "decimal_places": self.decimal_places,
+ "form_class": forms.DecimalField,
+ **kwargs,
+ }
+ )
class DurationField(Field):
@@ -1584,10 +1721,13 @@ class DurationField(Field):
Use interval on PostgreSQL, INTERVAL DAY TO SECOND on Oracle, and bigint
of microseconds on other databases.
"""
+
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value has an invalid format. It must be in '
- '[DD] [[HH:]MM:]ss[.uuuuuu] format.')
+ "invalid": _(
+ "“%(value)s” value has an invalid format. It must be in "
+ "[DD] [[HH:]MM:]ss[.uuuuuu] format."
+ )
}
description = _("Duration")
@@ -1608,9 +1748,9 @@ class DurationField(Field):
return parsed
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def get_db_prep_value(self, value, connection, prepared=False):
@@ -1628,13 +1768,15 @@ class DurationField(Field):
def value_to_string(self, obj):
val = self.value_from_object(obj)
- return '' if val is None else duration_string(val)
+ return "" if val is None else duration_string(val)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.DurationField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.DurationField,
+ **kwargs,
+ }
+ )
class EmailField(CharField):
@@ -1643,7 +1785,7 @@ class EmailField(CharField):
def __init__(self, *args, **kwargs):
# max_length=254 to be compliant with RFCs 3696 and 5321
- kwargs.setdefault('max_length', 254)
+ kwargs.setdefault("max_length", 254)
super().__init__(*args, **kwargs)
def deconstruct(self):
@@ -1655,20 +1797,31 @@ class EmailField(CharField):
def formfield(self, **kwargs):
# As with CharField, this will cause email validation to be performed
# twice.
- return super().formfield(**{
- 'form_class': forms.EmailField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.EmailField,
+ **kwargs,
+ }
+ )
class FilePathField(Field):
description = _("File path")
- def __init__(self, verbose_name=None, name=None, path='', match=None,
- recursive=False, allow_files=True, allow_folders=False, **kwargs):
+ def __init__(
+ self,
+ verbose_name=None,
+ name=None,
+ path="",
+ match=None,
+ recursive=False,
+ allow_files=True,
+ allow_folders=False,
+ **kwargs,
+ ):
self.path, self.match, self.recursive = path, match, recursive
self.allow_files, self.allow_folders = allow_files, allow_folders
- kwargs.setdefault('max_length', 100)
+ kwargs.setdefault("max_length", 100)
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
@@ -1683,23 +1836,23 @@ class FilePathField(Field):
checks.Error(
"FilePathFields must have either 'allow_files' or 'allow_folders' set to True.",
obj=self,
- id='fields.E140',
+ id="fields.E140",
)
]
return []
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- if self.path != '':
- kwargs['path'] = self.path
+ if self.path != "":
+ kwargs["path"] = self.path
if self.match is not None:
- kwargs['match'] = self.match
+ kwargs["match"] = self.match
if self.recursive is not False:
- kwargs['recursive'] = self.recursive
+ kwargs["recursive"] = self.recursive
if self.allow_files is not True:
- kwargs['allow_files'] = self.allow_files
+ kwargs["allow_files"] = self.allow_files
if self.allow_folders is not False:
- kwargs['allow_folders'] = self.allow_folders
+ kwargs["allow_folders"] = self.allow_folders
if kwargs.get("max_length") == 100:
del kwargs["max_length"]
return name, path, args, kwargs
@@ -1711,15 +1864,17 @@ class FilePathField(Field):
return str(value)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'path': self.path() if callable(self.path) else self.path,
- 'match': self.match,
- 'recursive': self.recursive,
- 'form_class': forms.FilePathField,
- 'allow_files': self.allow_files,
- 'allow_folders': self.allow_folders,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "path": self.path() if callable(self.path) else self.path,
+ "match": self.match,
+ "recursive": self.recursive,
+ "form_class": forms.FilePathField,
+ "allow_files": self.allow_files,
+ "allow_folders": self.allow_folders,
+ **kwargs,
+ }
+ )
def get_internal_type(self):
return "FilePathField"
@@ -1728,7 +1883,7 @@ class FilePathField(Field):
class FloatField(Field):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value must be a float.'),
+ "invalid": _("“%(value)s” value must be a float."),
}
description = _("Floating point number")
@@ -1753,22 +1908,24 @@ class FloatField(Field):
return float(value)
except (TypeError, ValueError):
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.FloatField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.FloatField,
+ **kwargs,
+ }
+ )
class IntegerField(Field):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value must be an integer.'),
+ "invalid": _("“%(value)s” value must be an integer."),
}
description = _("Integer")
@@ -1782,10 +1939,11 @@ class IntegerField(Field):
if self.max_length is not None:
return [
checks.Warning(
- "'max_length' is ignored when used with %s." % self.__class__.__name__,
+ "'max_length' is ignored when used with %s."
+ % self.__class__.__name__,
hint="Remove 'max_length' from field",
obj=self,
- id='fields.W122',
+ id="fields.W122",
)
]
return []
@@ -1799,22 +1957,28 @@ class IntegerField(Field):
min_value, max_value = connection.ops.integer_field_range(internal_type)
if min_value is not None and not any(
(
- isinstance(validator, validators.MinValueValidator) and (
+ isinstance(validator, validators.MinValueValidator)
+ and (
validator.limit_value()
if callable(validator.limit_value)
else validator.limit_value
- ) >= min_value
- ) for validator in validators_
+ )
+ >= min_value
+ )
+ for validator in validators_
):
validators_.append(validators.MinValueValidator(min_value))
if max_value is not None and not any(
(
- isinstance(validator, validators.MaxValueValidator) and (
+ isinstance(validator, validators.MaxValueValidator)
+ and (
validator.limit_value()
if callable(validator.limit_value)
else validator.limit_value
- ) <= max_value
- ) for validator in validators_
+ )
+ <= max_value
+ )
+ for validator in validators_
):
validators_.append(validators.MaxValueValidator(max_value))
return validators_
@@ -1840,16 +2004,18 @@ class IntegerField(Field):
return int(value)
except (TypeError, ValueError):
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.IntegerField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.IntegerField,
+ **kwargs,
+ }
+ )
class BigIntegerField(IntegerField):
@@ -1860,39 +2026,41 @@ class BigIntegerField(IntegerField):
return "BigIntegerField"
def formfield(self, **kwargs):
- return super().formfield(**{
- 'min_value': -BigIntegerField.MAX_BIGINT - 1,
- 'max_value': BigIntegerField.MAX_BIGINT,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "min_value": -BigIntegerField.MAX_BIGINT - 1,
+ "max_value": BigIntegerField.MAX_BIGINT,
+ **kwargs,
+ }
+ )
class SmallIntegerField(IntegerField):
- description = _('Small integer')
+ description = _("Small integer")
def get_internal_type(self):
- return 'SmallIntegerField'
+ return "SmallIntegerField"
class IPAddressField(Field):
empty_strings_allowed = False
description = _("IPv4 address")
system_check_removed_details = {
- 'msg': (
- 'IPAddressField has been removed except for support in '
- 'historical migrations.'
+ "msg": (
+ "IPAddressField has been removed except for support in "
+ "historical migrations."
),
- 'hint': 'Use GenericIPAddressField instead.',
- 'id': 'fields.E900',
+ "hint": "Use GenericIPAddressField instead.",
+ "id": "fields.E900",
}
def __init__(self, *args, **kwargs):
- kwargs['max_length'] = 15
+ kwargs["max_length"] = 15
super().__init__(*args, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- del kwargs['max_length']
+ del kwargs["max_length"]
return name, path, args, kwargs
def get_prep_value(self, value):
@@ -1910,14 +2078,23 @@ class GenericIPAddressField(Field):
description = _("IP address")
default_error_messages = {}
- def __init__(self, verbose_name=None, name=None, protocol='both',
- unpack_ipv4=False, *args, **kwargs):
+ def __init__(
+ self,
+ verbose_name=None,
+ name=None,
+ protocol="both",
+ unpack_ipv4=False,
+ *args,
+ **kwargs,
+ ):
self.unpack_ipv4 = unpack_ipv4
self.protocol = protocol
- self.default_validators, invalid_error_message = \
- validators.ip_address_validators(protocol, unpack_ipv4)
- self.default_error_messages['invalid'] = invalid_error_message
- kwargs['max_length'] = 39
+ (
+ self.default_validators,
+ invalid_error_message,
+ ) = validators.ip_address_validators(protocol, unpack_ipv4)
+ self.default_error_messages["invalid"] = invalid_error_message
+ kwargs["max_length"] = 39
super().__init__(verbose_name, name, *args, **kwargs)
def check(self, **kwargs):
@@ -1927,13 +2104,13 @@ class GenericIPAddressField(Field):
]
def _check_blank_and_null_values(self, **kwargs):
- if not getattr(self, 'null', False) and getattr(self, 'blank', False):
+ if not getattr(self, "null", False) and getattr(self, "blank", False):
return [
checks.Error(
- 'GenericIPAddressFields cannot have blank=True if null=False, '
- 'as blank values are stored as nulls.',
+ "GenericIPAddressFields cannot have blank=True if null=False, "
+ "as blank values are stored as nulls.",
obj=self,
- id='fields.E150',
+ id="fields.E150",
)
]
return []
@@ -1941,11 +2118,11 @@ class GenericIPAddressField(Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.unpack_ipv4 is not False:
- kwargs['unpack_ipv4'] = self.unpack_ipv4
+ kwargs["unpack_ipv4"] = self.unpack_ipv4
if self.protocol != "both":
- kwargs['protocol'] = self.protocol
+ kwargs["protocol"] = self.protocol
if kwargs.get("max_length") == 39:
- del kwargs['max_length']
+ del kwargs["max_length"]
return name, path, args, kwargs
def get_internal_type(self):
@@ -1957,8 +2134,10 @@ class GenericIPAddressField(Field):
if not isinstance(value, str):
value = str(value)
value = value.strip()
- if ':' in value:
- return clean_ipv6_address(value, self.unpack_ipv4, self.error_messages['invalid'])
+ if ":" in value:
+ return clean_ipv6_address(
+ value, self.unpack_ipv4, self.error_messages["invalid"]
+ )
return value
def get_db_prep_value(self, value, connection, prepared=False):
@@ -1970,7 +2149,7 @@ class GenericIPAddressField(Field):
value = super().get_prep_value(value)
if value is None:
return None
- if value and ':' in value:
+ if value and ":" in value:
try:
return clean_ipv6_address(value, self.unpack_ipv4)
except exceptions.ValidationError:
@@ -1978,44 +2157,46 @@ class GenericIPAddressField(Field):
return str(value)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'protocol': self.protocol,
- 'form_class': forms.GenericIPAddressField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "protocol": self.protocol,
+ "form_class": forms.GenericIPAddressField,
+ **kwargs,
+ }
+ )
class NullBooleanField(BooleanField):
default_error_messages = {
- 'invalid': _('“%(value)s” value must be either None, True or False.'),
- 'invalid_nullable': _('“%(value)s” value must be either None, True or False.'),
+ "invalid": _("“%(value)s” value must be either None, True or False."),
+ "invalid_nullable": _("“%(value)s” value must be either None, True or False."),
}
description = _("Boolean (Either True, False or None)")
system_check_removed_details = {
- 'msg': (
- 'NullBooleanField is removed except for support in historical '
- 'migrations.'
+ "msg": (
+ "NullBooleanField is removed except for support in historical "
+ "migrations."
),
- 'hint': 'Use BooleanField(null=True) instead.',
- 'id': 'fields.E903',
+ "hint": "Use BooleanField(null=True) instead.",
+ "id": "fields.E903",
}
def __init__(self, *args, **kwargs):
- kwargs['null'] = True
- kwargs['blank'] = True
+ kwargs["null"] = True
+ kwargs["blank"] = True
super().__init__(*args, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- del kwargs['null']
- del kwargs['blank']
+ del kwargs["null"]
+ del kwargs["blank"]
return name, path, args, kwargs
class PositiveIntegerRelDbTypeMixin:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
- if not hasattr(cls, 'integer_field_class'):
+ if not hasattr(cls, "integer_field_class"):
cls.integer_field_class = next(
(
parent
@@ -2041,16 +2222,18 @@ class PositiveIntegerRelDbTypeMixin:
class PositiveBigIntegerField(PositiveIntegerRelDbTypeMixin, BigIntegerField):
- description = _('Positive big integer')
+ description = _("Positive big integer")
def get_internal_type(self):
- return 'PositiveBigIntegerField'
+ return "PositiveBigIntegerField"
def formfield(self, **kwargs):
- return super().formfield(**{
- 'min_value': 0,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "min_value": 0,
+ **kwargs,
+ }
+ )
class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField):
@@ -2060,10 +2243,12 @@ class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField):
return "PositiveIntegerField"
def formfield(self, **kwargs):
- return super().formfield(**{
- 'min_value': 0,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "min_value": 0,
+ **kwargs,
+ }
+ )
class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, SmallIntegerField):
@@ -2073,17 +2258,21 @@ class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, SmallIntegerField
return "PositiveSmallIntegerField"
def formfield(self, **kwargs):
- return super().formfield(**{
- 'min_value': 0,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "min_value": 0,
+ **kwargs,
+ }
+ )
class SlugField(CharField):
default_validators = [validators.validate_slug]
description = _("Slug (up to %(max_length)s)")
- def __init__(self, *args, max_length=50, db_index=True, allow_unicode=False, **kwargs):
+ def __init__(
+ self, *args, max_length=50, db_index=True, allow_unicode=False, **kwargs
+ ):
self.allow_unicode = allow_unicode
if self.allow_unicode:
self.default_validators = [validators.validate_unicode_slug]
@@ -2092,24 +2281,26 @@ class SlugField(CharField):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 50:
- del kwargs['max_length']
+ del kwargs["max_length"]
if self.db_index is False:
- kwargs['db_index'] = False
+ kwargs["db_index"] = False
else:
- del kwargs['db_index']
+ del kwargs["db_index"]
if self.allow_unicode is not False:
- kwargs['allow_unicode'] = self.allow_unicode
+ kwargs["allow_unicode"] = self.allow_unicode
return name, path, args, kwargs
def get_internal_type(self):
return "SlugField"
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.SlugField,
- 'allow_unicode': self.allow_unicode,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.SlugField,
+ "allow_unicode": self.allow_unicode,
+ **kwargs,
+ }
+ )
class TextField(Field):
@@ -2120,7 +2311,7 @@ class TextField(Field):
self.db_collation = db_collation
def check(self, **kwargs):
- databases = kwargs.get('databases') or []
+ databases = kwargs.get("databases") or []
return [
*super().check(**kwargs),
*self._check_db_collation(databases),
@@ -2133,16 +2324,17 @@ class TextField(Field):
continue
connection = connections[db]
if not (
- self.db_collation is None or
- 'supports_collation_on_textfield' in self.model._meta.required_db_features or
- connection.features.supports_collation_on_textfield
+ self.db_collation is None
+ or "supports_collation_on_textfield"
+ in self.model._meta.required_db_features
+ or connection.features.supports_collation_on_textfield
):
errors.append(
checks.Error(
- '%s does not support a database collation on '
- 'TextFields.' % connection.display_name,
+ "%s does not support a database collation on "
+ "TextFields." % connection.display_name,
obj=self,
- id='fields.E190',
+ id="fields.E190",
),
)
return errors
@@ -2163,35 +2355,42 @@ class TextField(Field):
# Passing max_length to forms.CharField means that the value's length
# will be validated twice. This is considered acceptable since we want
# the value in the form field (to pass into widget for example).
- return super().formfield(**{
- 'max_length': self.max_length,
- **({} if self.choices is not None else {'widget': forms.Textarea}),
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "max_length": self.max_length,
+ **({} if self.choices is not None else {"widget": forms.Textarea}),
+ **kwargs,
+ }
+ )
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.db_collation:
- kwargs['db_collation'] = self.db_collation
+ kwargs["db_collation"] = self.db_collation
return name, path, args, kwargs
class TimeField(DateTimeCheckMixin, Field):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('“%(value)s” value has an invalid format. It must be in '
- 'HH:MM[:ss[.uuuuuu]] format.'),
- 'invalid_time': _('“%(value)s” value has the correct format '
- '(HH:MM[:ss[.uuuuuu]]) but it is an invalid time.'),
+ "invalid": _(
+ "“%(value)s” value has an invalid format. It must be in "
+ "HH:MM[:ss[.uuuuuu]] format."
+ ),
+ "invalid_time": _(
+ "“%(value)s” value has the correct format "
+ "(HH:MM[:ss[.uuuuuu]]) but it is an invalid time."
+ ),
}
description = _("Time")
- def __init__(self, verbose_name=None, name=None, auto_now=False,
- auto_now_add=False, **kwargs):
+ def __init__(
+ self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs
+ ):
self.auto_now, self.auto_now_add = auto_now, auto_now_add
if auto_now or auto_now_add:
- kwargs['editable'] = False
- kwargs['blank'] = True
+ kwargs["editable"] = False
+ kwargs["blank"] = True
super().__init__(verbose_name, name, **kwargs)
def _check_fix_default_value(self):
@@ -2223,8 +2422,8 @@ class TimeField(DateTimeCheckMixin, Field):
if self.auto_now_add is not False:
kwargs["auto_now_add"] = self.auto_now_add
if self.auto_now or self.auto_now_add:
- del kwargs['blank']
- del kwargs['editable']
+ del kwargs["blank"]
+ del kwargs["editable"]
return name, path, args, kwargs
def get_internal_type(self):
@@ -2247,15 +2446,15 @@ class TimeField(DateTimeCheckMixin, Field):
return parsed
except ValueError:
raise exceptions.ValidationError(
- self.error_messages['invalid_time'],
- code='invalid_time',
- params={'value': value},
+ self.error_messages["invalid_time"],
+ code="invalid_time",
+ params={"value": value},
)
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def pre_save(self, model_instance, add):
@@ -2278,13 +2477,15 @@ class TimeField(DateTimeCheckMixin, Field):
def value_to_string(self, obj):
val = self.value_from_object(obj)
- return '' if val is None else val.isoformat()
+ return "" if val is None else val.isoformat()
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.TimeField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.TimeField,
+ **kwargs,
+ }
+ )
class URLField(CharField):
@@ -2292,30 +2493,32 @@ class URLField(CharField):
description = _("URL")
def __init__(self, verbose_name=None, name=None, **kwargs):
- kwargs.setdefault('max_length', 200)
+ kwargs.setdefault("max_length", 200)
super().__init__(verbose_name, name, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 200:
- del kwargs['max_length']
+ del kwargs["max_length"]
return name, path, args, kwargs
def formfield(self, **kwargs):
# As with CharField, this will cause URL validation to be performed
# twice.
- return super().formfield(**{
- 'form_class': forms.URLField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.URLField,
+ **kwargs,
+ }
+ )
class BinaryField(Field):
description = _("Raw binary data")
- empty_values = [None, b'']
+ empty_values = [None, b""]
def __init__(self, *args, **kwargs):
- kwargs.setdefault('editable', False)
+ kwargs.setdefault("editable", False)
super().__init__(*args, **kwargs)
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
@@ -2330,7 +2533,7 @@ class BinaryField(Field):
"BinaryField's default cannot be a string. Use bytes "
"content instead.",
obj=self,
- id='fields.E170',
+ id="fields.E170",
)
]
return []
@@ -2338,9 +2541,9 @@ class BinaryField(Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.editable:
- kwargs['editable'] = True
+ kwargs["editable"] = True
else:
- del kwargs['editable']
+ del kwargs["editable"]
return name, path, args, kwargs
def get_internal_type(self):
@@ -2353,8 +2556,8 @@ class BinaryField(Field):
if self.has_default() and not callable(self.default):
return self.default
default = super().get_default()
- if default == '':
- return b''
+ if default == "":
+ return b""
return default
def get_db_prep_value(self, value, connection, prepared=False):
@@ -2365,29 +2568,29 @@ class BinaryField(Field):
def value_to_string(self, obj):
"""Binary data is serialized as base64"""
- return b64encode(self.value_from_object(obj)).decode('ascii')
+ return b64encode(self.value_from_object(obj)).decode("ascii")
def to_python(self, value):
# If it's a string, it should be base64-encoded data
if isinstance(value, str):
- return memoryview(b64decode(value.encode('ascii')))
+ return memoryview(b64decode(value.encode("ascii")))
return value
class UUIDField(Field):
default_error_messages = {
- 'invalid': _('“%(value)s” is not a valid UUID.'),
+ "invalid": _("“%(value)s” is not a valid UUID."),
}
- description = _('Universally unique identifier')
+ description = _("Universally unique identifier")
empty_strings_allowed = False
def __init__(self, verbose_name=None, **kwargs):
- kwargs['max_length'] = 32
+ kwargs["max_length"] = 32
super().__init__(verbose_name, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- del kwargs['max_length']
+ del kwargs["max_length"]
return name, path, args, kwargs
def get_internal_type(self):
@@ -2409,29 +2612,31 @@ class UUIDField(Field):
def to_python(self, value):
if value is not None and not isinstance(value, uuid.UUID):
- input_form = 'int' if isinstance(value, int) else 'hex'
+ input_form = "int" if isinstance(value, int) else "hex"
try:
return uuid.UUID(**{input_form: value})
except (AttributeError, ValueError):
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
return value
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.UUIDField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.UUIDField,
+ **kwargs,
+ }
+ )
class AutoFieldMixin:
db_returning = True
def __init__(self, *args, **kwargs):
- kwargs['blank'] = True
+ kwargs["blank"] = True
super().__init__(*args, **kwargs)
def check(self, **kwargs):
@@ -2444,9 +2649,9 @@ class AutoFieldMixin:
if not self.primary_key:
return [
checks.Error(
- 'AutoFields must set primary_key=True.',
+ "AutoFields must set primary_key=True.",
obj=self,
- id='fields.E100',
+ id="fields.E100",
),
]
else:
@@ -2454,8 +2659,8 @@ class AutoFieldMixin:
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- del kwargs['blank']
- kwargs['primary_key'] = True
+ del kwargs["blank"]
+ kwargs["primary_key"] = True
return name, path, args, kwargs
def validate(self, value, model_instance):
@@ -2502,34 +2707,35 @@ class AutoFieldMeta(type):
return (BigAutoField, SmallAutoField)
def __instancecheck__(self, instance):
- return isinstance(instance, self._subclasses) or super().__instancecheck__(instance)
+ return isinstance(instance, self._subclasses) or super().__instancecheck__(
+ instance
+ )
def __subclasscheck__(self, subclass):
- return issubclass(subclass, self._subclasses) or super().__subclasscheck__(subclass)
+ return issubclass(subclass, self._subclasses) or super().__subclasscheck__(
+ subclass
+ )
class AutoField(AutoFieldMixin, IntegerField, metaclass=AutoFieldMeta):
-
def get_internal_type(self):
- return 'AutoField'
+ return "AutoField"
def rel_db_type(self, connection):
return IntegerField().db_type(connection=connection)
class BigAutoField(AutoFieldMixin, BigIntegerField):
-
def get_internal_type(self):
- return 'BigAutoField'
+ return "BigAutoField"
def rel_db_type(self, connection):
return BigIntegerField().db_type(connection=connection)
class SmallAutoField(AutoFieldMixin, SmallIntegerField):
-
def get_internal_type(self):
- return 'SmallAutoField'
+ return "SmallAutoField"
def rel_db_type(self, connection):
return SmallIntegerField().db_type(connection=connection)
diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py
index 18900f7b85..33a1176ed6 100644
--- a/django/db/models/fields/files.py
+++ b/django/db/models/fields/files.py
@@ -24,7 +24,7 @@ class FieldFile(File):
def __eq__(self, other):
# Older code may be expecting FileField values to be simple strings.
# By overriding the == operator, it can remain backwards compatibility.
- if hasattr(other, 'name'):
+ if hasattr(other, "name"):
return self.name == other.name
return self.name == other
@@ -37,12 +37,14 @@ class FieldFile(File):
def _require_file(self):
if not self:
- raise ValueError("The '%s' attribute has no file associated with it." % self.field.name)
+ raise ValueError(
+ "The '%s' attribute has no file associated with it." % self.field.name
+ )
def _get_file(self):
self._require_file()
- if getattr(self, '_file', None) is None:
- self._file = self.storage.open(self.name, 'rb')
+ if getattr(self, "_file", None) is None:
+ self._file = self.storage.open(self.name, "rb")
return self._file
def _set_file(self, file):
@@ -70,13 +72,14 @@ class FieldFile(File):
return self.file.size
return self.storage.size(self.name)
- def open(self, mode='rb'):
+ def open(self, mode="rb"):
self._require_file()
- if getattr(self, '_file', None) is None:
+ if getattr(self, "_file", None) is None:
self.file = self.storage.open(self.name, mode)
else:
self.file.open(mode)
return self
+
# open() doesn't alter the file's contents, but it does reset the pointer
open.alters_data = True
@@ -93,6 +96,7 @@ class FieldFile(File):
# Save the object because it has changed, unless save is False
if save:
self.instance.save()
+
save.alters_data = True
def delete(self, save=True):
@@ -100,7 +104,7 @@ class FieldFile(File):
return
# Only close the file if it's already open, which we know by the
# presence of self._file
- if hasattr(self, '_file'):
+ if hasattr(self, "_file"):
self.close()
del self.file
@@ -112,15 +116,16 @@ class FieldFile(File):
if save:
self.instance.save()
+
delete.alters_data = True
@property
def closed(self):
- file = getattr(self, '_file', None)
+ file = getattr(self, "_file", None)
return file is None or file.closed
def close(self):
- file = getattr(self, '_file', None)
+ file = getattr(self, "_file", None)
if file is not None:
file.close()
@@ -129,12 +134,12 @@ class FieldFile(File):
# the file's name. Everything else will be restored later, by
# FileDescriptor below.
return {
- 'name': self.name,
- 'closed': False,
- '_committed': True,
- '_file': None,
- 'instance': self.instance,
- 'field': self.field,
+ "name": self.name,
+ "closed": False,
+ "_committed": True,
+ "_file": None,
+ "instance": self.instance,
+ "field": self.field,
}
def __setstate__(self, state):
@@ -156,6 +161,7 @@ class FileDescriptor(DeferredAttribute):
>>> with open('/path/to/hello.world') as f:
... instance.file = File(f)
"""
+
def __get__(self, instance, cls=None):
if instance is None:
return self
@@ -198,7 +204,7 @@ class FileDescriptor(DeferredAttribute):
# Finally, because of the (some would say boneheaded) way pickle works,
# the underlying FieldFile might not actually itself have an associated
# file. So we need to reset the details of the FieldFile in those cases.
- elif isinstance(file, FieldFile) and not hasattr(file, 'field'):
+ elif isinstance(file, FieldFile) and not hasattr(file, "field"):
file.instance = instance
file.field = self.field
file.storage = self.field.storage
@@ -225,8 +231,10 @@ class FileField(Field):
description = _("File")
- def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs):
- self._primary_key_set_explicitly = 'primary_key' in kwargs
+ def __init__(
+ self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
+ ):
+ self._primary_key_set_explicitly = "primary_key" in kwargs
self.storage = storage or default_storage
if callable(self.storage):
@@ -236,11 +244,15 @@ class FileField(Field):
if not isinstance(self.storage, Storage):
raise TypeError(
"%s.storage must be a subclass/instance of %s.%s"
- % (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__)
+ % (
+ self.__class__.__qualname__,
+ Storage.__module__,
+ Storage.__qualname__,
+ )
)
self.upload_to = upload_to
- kwargs.setdefault('max_length', 100)
+ kwargs.setdefault("max_length", 100)
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
@@ -254,23 +266,24 @@ class FileField(Field):
if self._primary_key_set_explicitly:
return [
checks.Error(
- "'primary_key' is not a valid argument for a %s." % self.__class__.__name__,
+ "'primary_key' is not a valid argument for a %s."
+ % self.__class__.__name__,
obj=self,
- id='fields.E201',
+ id="fields.E201",
)
]
else:
return []
def _check_upload_to(self):
- if isinstance(self.upload_to, str) and self.upload_to.startswith('/'):
+ if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
return [
checks.Error(
"%s's 'upload_to' argument must be a relative path, not an "
"absolute path." % self.__class__.__name__,
obj=self,
- id='fields.E202',
- hint='Remove the leading slash.',
+ id="fields.E202",
+ hint="Remove the leading slash.",
)
]
else:
@@ -280,9 +293,9 @@ class FileField(Field):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 100:
del kwargs["max_length"]
- kwargs['upload_to'] = self.upload_to
+ kwargs["upload_to"] = self.upload_to
if self.storage is not default_storage:
- kwargs['storage'] = getattr(self, '_storage_callable', self.storage)
+ kwargs["storage"] = getattr(self, "_storage_callable", self.storage)
return name, path, args, kwargs
def get_internal_type(self):
@@ -329,14 +342,16 @@ class FileField(Field):
if data is not None:
# This value will be converted to str and stored in the
# database, so leaving False as-is is not acceptable.
- setattr(instance, self.name, data or '')
+ setattr(instance, self.name, data or "")
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.FileField,
- 'max_length': self.max_length,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.FileField,
+ "max_length": self.max_length,
+ **kwargs,
+ }
+ )
class ImageFileDescriptor(FileDescriptor):
@@ -344,6 +359,7 @@ class ImageFileDescriptor(FileDescriptor):
Just like the FileDescriptor, but for ImageFields. The only difference is
assigning the width/height to the width_field/height_field, if appropriate.
"""
+
def __set__(self, instance, value):
previous_file = instance.__dict__.get(self.field.attname)
super().__set__(instance, value)
@@ -364,7 +380,7 @@ class ImageFileDescriptor(FileDescriptor):
class ImageFieldFile(ImageFile, FieldFile):
def delete(self, save=True):
# Clear the image dimensions cache
- if hasattr(self, '_dimensions_cache'):
+ if hasattr(self, "_dimensions_cache"):
del self._dimensions_cache
super().delete(save)
@@ -374,7 +390,14 @@ class ImageField(FileField):
descriptor_class = ImageFileDescriptor
description = _("Image")
- def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
+ def __init__(
+ self,
+ verbose_name=None,
+ name=None,
+ width_field=None,
+ height_field=None,
+ **kwargs,
+ ):
self.width_field, self.height_field = width_field, height_field
super().__init__(verbose_name, name, **kwargs)
@@ -390,11 +413,13 @@ class ImageField(FileField):
except ImportError:
return [
checks.Error(
- 'Cannot use ImageField because Pillow is not installed.',
- hint=('Get Pillow at https://pypi.org/project/Pillow/ '
- 'or run command "python -m pip install Pillow".'),
+ "Cannot use ImageField because Pillow is not installed.",
+ hint=(
+ "Get Pillow at https://pypi.org/project/Pillow/ "
+ 'or run command "python -m pip install Pillow".'
+ ),
obj=self,
- id='fields.E210',
+ id="fields.E210",
)
]
else:
@@ -403,9 +428,9 @@ class ImageField(FileField):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.width_field:
- kwargs['width_field'] = self.width_field
+ kwargs["width_field"] = self.width_field
if self.height_field:
- kwargs['height_field'] = self.height_field
+ kwargs["height_field"] = self.height_field
return name, path, args, kwargs
def contribute_to_class(self, cls, name, **kwargs):
@@ -445,9 +470,9 @@ class ImageField(FileField):
if not file and not force:
return
- dimension_fields_filled = not(
- (self.width_field and not getattr(instance, self.width_field)) or
- (self.height_field and not getattr(instance, self.height_field))
+ dimension_fields_filled = not (
+ (self.width_field and not getattr(instance, self.width_field))
+ or (self.height_field and not getattr(instance, self.height_field))
)
# When both dimension fields have values, we are most likely loading
# data from the database or updating an image field that already had
@@ -475,7 +500,9 @@ class ImageField(FileField):
setattr(instance, self.height_field, height)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.ImageField,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.ImageField,
+ **kwargs,
+ }
+ )
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py
index efb4e2f6ed..fdca700c9d 100644
--- a/django/db/models/fields/json.py
+++ b/django/db/models/fields/json.py
@@ -10,32 +10,36 @@ from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
-__all__ = ['JSONField']
+__all__ = ["JSONField"]
class JSONField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
- description = _('A JSON object')
+ description = _("A JSON object")
default_error_messages = {
- 'invalid': _('Value must be valid JSON.'),
+ "invalid": _("Value must be valid JSON."),
}
- _default_hint = ('dict', '{}')
+ _default_hint = ("dict", "{}")
def __init__(
- self, verbose_name=None, name=None, encoder=None, decoder=None,
+ self,
+ verbose_name=None,
+ name=None,
+ encoder=None,
+ decoder=None,
**kwargs,
):
if encoder and not callable(encoder):
- raise ValueError('The encoder parameter must be a callable object.')
+ raise ValueError("The encoder parameter must be a callable object.")
if decoder and not callable(decoder):
- raise ValueError('The decoder parameter must be a callable object.')
+ raise ValueError("The decoder parameter must be a callable object.")
self.encoder = encoder
self.decoder = decoder
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
errors = super().check(**kwargs)
- databases = kwargs.get('databases') or []
+ databases = kwargs.get("databases") or []
errors.extend(self._check_supported(databases))
return errors
@@ -46,20 +50,19 @@ class JSONField(CheckFieldDefaultMixin, Field):
continue
connection = connections[db]
if (
- self.model._meta.required_db_vendor and
- self.model._meta.required_db_vendor != connection.vendor
+ self.model._meta.required_db_vendor
+ and self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
- 'supports_json_field' in self.model._meta.required_db_features or
- connection.features.supports_json_field
+ "supports_json_field" in self.model._meta.required_db_features
+ or connection.features.supports_json_field
):
errors.append(
checks.Error(
- '%s does not support JSONFields.'
- % connection.display_name,
+ "%s does not support JSONFields." % connection.display_name,
obj=self.model,
- id='fields.E180',
+ id="fields.E180",
)
)
return errors
@@ -67,9 +70,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.encoder is not None:
- kwargs['encoder'] = self.encoder
+ kwargs["encoder"] = self.encoder
if self.decoder is not None:
- kwargs['decoder'] = self.decoder
+ kwargs["decoder"] = self.decoder
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
@@ -85,7 +88,7 @@ class JSONField(CheckFieldDefaultMixin, Field):
return value
def get_internal_type(self):
- return 'JSONField'
+ return "JSONField"
def get_prep_value(self, value):
if value is None:
@@ -104,64 +107,66 @@ class JSONField(CheckFieldDefaultMixin, Field):
json.dumps(value, cls=self.encoder)
except TypeError:
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
- params={'value': value},
+ self.error_messages["invalid"],
+ code="invalid",
+ params={"value": value},
)
def value_to_string(self, obj):
return self.value_from_object(obj)
def formfield(self, **kwargs):
- return super().formfield(**{
- 'form_class': forms.JSONField,
- 'encoder': self.encoder,
- 'decoder': self.decoder,
- **kwargs,
- })
+ return super().formfield(
+ **{
+ "form_class": forms.JSONField,
+ "encoder": self.encoder,
+ "decoder": self.decoder,
+ **kwargs,
+ }
+ )
def compile_json_path(key_transforms, include_root=True):
- path = ['$'] if include_root else []
+ path = ["$"] if include_root else []
for key_transform in key_transforms:
try:
num = int(key_transform)
except ValueError: # non-integer
- path.append('.')
+ path.append(".")
path.append(json.dumps(key_transform))
else:
- path.append('[%s]' % num)
- return ''.join(path)
+ path.append("[%s]" % num)
+ return "".join(path)
class DataContains(PostgresOperatorLookup):
- lookup_name = 'contains'
- postgres_operator = '@>'
+ lookup_name = "contains"
+ postgres_operator = "@>"
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
- 'contains lookup is not supported on this database backend.'
+ "contains lookup is not supported on this database backend."
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
- return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
+ return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup):
- lookup_name = 'contained_by'
- postgres_operator = '<@'
+ lookup_name = "contained_by"
+ postgres_operator = "<@"
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
- 'contained_by lookup is not supported on this database backend.'
+ "contained_by lookup is not supported on this database backend."
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(rhs_params) + tuple(lhs_params)
- return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
+ return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
class HasKeyLookup(PostgresOperatorLookup):
@@ -170,11 +175,13 @@ class HasKeyLookup(PostgresOperatorLookup):
def as_sql(self, compiler, connection, template=None):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
- lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
+ lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
+ compiler, connection
+ )
lhs_json_path = compile_json_path(lhs_key_transforms)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
- lhs_json_path = '$'
+ lhs_json_path = "$"
sql = template % lhs
# Process JSON path from the right-hand side.
rhs = self.rhs
@@ -186,20 +193,27 @@ class HasKeyLookup(PostgresOperatorLookup):
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
else:
rhs_key_transforms = [key]
- rhs_params.append('%s%s' % (
- lhs_json_path,
- compile_json_path(rhs_key_transforms, include_root=False),
- ))
+ rhs_params.append(
+ "%s%s"
+ % (
+ lhs_json_path,
+ compile_json_path(rhs_key_transforms, include_root=False),
+ )
+ )
# Add condition for each key.
if self.logical_operator:
- sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
+ sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
return sql, tuple(lhs_params) + tuple(rhs_params)
def as_mysql(self, compiler, connection):
- return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
+ return self.as_sql(
+ compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
+ )
def as_oracle(self, compiler, connection):
- sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
+ sql, params = self.as_sql(
+ compiler, connection, template="JSON_EXISTS(%s, '%%s')"
+ )
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle.
return sql % tuple(params), []
@@ -213,28 +227,30 @@ class HasKeyLookup(PostgresOperatorLookup):
return super().as_postgresql(compiler, connection)
def as_sqlite(self, compiler, connection):
- return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
+ return self.as_sql(
+ compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
+ )
class HasKey(HasKeyLookup):
- lookup_name = 'has_key'
- postgres_operator = '?'
+ lookup_name = "has_key"
+ postgres_operator = "?"
prepare_rhs = False
class HasKeys(HasKeyLookup):
- lookup_name = 'has_keys'
- postgres_operator = '?&'
- logical_operator = ' AND '
+ lookup_name = "has_keys"
+ postgres_operator = "?&"
+ logical_operator = " AND "
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
- lookup_name = 'has_any_keys'
- postgres_operator = '?|'
- logical_operator = ' OR '
+ lookup_name = "has_any_keys"
+ postgres_operator = "?|"
+ logical_operator = " OR "
class CaseInsensitiveMixin:
@@ -244,16 +260,17 @@ class CaseInsensitiveMixin:
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
+
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
- if connection.vendor == 'mysql':
- return 'LOWER(%s)' % lhs, lhs_params
+ if connection.vendor == "mysql":
+ return "LOWER(%s)" % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
- if connection.vendor == 'mysql':
- return 'LOWER(%s)' % rhs, rhs_params
+ if connection.vendor == "mysql":
+ return "LOWER(%s)" % rhs, rhs_params
return rhs, rhs_params
@@ -263,9 +280,9 @@ class JSONExact(lookups.Exact):
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
- if rhs == '%s' and rhs_params == [None]:
- rhs_params = ['null']
- if connection.vendor == 'mysql':
+ if rhs == "%s" and rhs_params == [None]:
+ rhs_params = ["null"]
+ if connection.vendor == "mysql":
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs = rhs % tuple(func)
return rhs, rhs_params
@@ -285,8 +302,8 @@ JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
- postgres_operator = '->'
- postgres_nested_operator = '#>'
+ postgres_operator = "->"
+ postgres_nested_operator = "#>"
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -299,41 +316,41 @@ class KeyTransform(Transform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs, params = compiler.compile(previous)
- if connection.vendor == 'oracle':
+ if connection.vendor == "oracle":
# Escape string-formatting.
- key_transforms = [key.replace('%', '%%') for key in key_transforms]
+ key_transforms = [key.replace("%", "%%") for key in key_transforms]
return lhs, params, key_transforms
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
- return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
+ return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return (
- "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
- ((lhs, json_path) * 2)
+ "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
+ % ((lhs, json_path) * 2)
), tuple(params) * 2
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
if len(key_transforms) > 1:
- sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
+ sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
return sql, tuple(params) + (key_transforms,)
try:
lookup = int(self.key_name)
except ValueError:
lookup = self.key_name
- return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
+ return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
- datatype_values = ','.join([
- repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
- ])
+ datatype_values = ",".join(
+ [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
+ )
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
@@ -341,8 +358,8 @@ class KeyTransform(Transform):
class KeyTextTransform(KeyTransform):
- postgres_operator = '->>'
- postgres_nested_operator = '#>>'
+ postgres_operator = "->>"
+ postgres_nested_operator = "#>>"
class KeyTransformTextLookupMixin:
@@ -352,14 +369,16 @@ class KeyTransformTextLookupMixin:
key values to text and performing the lookup on the resulting
representation.
"""
+
def __init__(self, key_transform, *args, **kwargs):
if not isinstance(key_transform, KeyTransform):
raise TypeError(
- 'Transform should be an instance of KeyTransform in order to '
- 'use this lookup.'
+ "Transform should be an instance of KeyTransform in order to "
+ "use this lookup."
)
key_text_transform = KeyTextTransform(
- key_transform.key_name, *key_transform.source_expressions,
+ key_transform.key_name,
+ *key_transform.source_expressions,
**key_transform.extra,
)
super().__init__(key_text_transform, *args, **kwargs)
@@ -376,12 +395,12 @@ class KeyTransformIsNull(lookups.IsNull):
return sql, params
# Column doesn't have a key or IS NULL.
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
- return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
+ return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
def as_sqlite(self, compiler, connection):
- template = 'JSON_TYPE(%s, %%s) IS NULL'
+ template = "JSON_TYPE(%s, %%s) IS NULL"
if not self.rhs:
- template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
+ template = "JSON_TYPE(%s, %%s) IS NOT NULL"
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
compiler,
connection,
@@ -392,26 +411,29 @@ class KeyTransformIsNull(lookups.IsNull):
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
- compiler, connection, sql, param,
+ compiler,
+ connection,
+ sql,
+ param,
)
if (
- not hasattr(param, 'as_sql') and
- not connection.features.has_native_json_field
+ not hasattr(param, "as_sql")
+ and not connection.features.has_native_json_field
):
- if connection.vendor == 'oracle':
+ if connection.vendor == "oracle":
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
- sql = sql % 'JSON_QUERY'
+ sql = sql % "JSON_QUERY"
else:
- sql = sql % 'JSON_VALUE'
- elif connection.vendor == 'mysql' or (
- connection.vendor == 'sqlite' and
- params[0] not in connection.ops.jsonfield_datatype_values
+ sql = sql % "JSON_VALUE"
+ elif connection.vendor == "mysql" or (
+ connection.vendor == "sqlite"
+ and params[0] not in connection.ops.jsonfield_datatype_values
):
sql = "JSON_EXTRACT(%s, '$')"
- if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
- sql = 'JSON_UNQUOTE(%s)' % sql
+ if connection.vendor == "mysql" and connection.mysql_is_mariadb:
+ sql = "JSON_UNQUOTE(%s)" % sql
return sql, params
@@ -420,21 +442,21 @@ class KeyTransformExact(JSONExact):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super().process_rhs(compiler, connection)
- if connection.vendor == 'oracle':
+ if connection.vendor == "oracle":
func = []
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
for value in rhs_params:
value = json.loads(value)
if isinstance(value, (list, dict)):
- func.append(sql % 'JSON_QUERY')
+ func.append(sql % "JSON_QUERY")
else:
- func.append(sql % 'JSON_VALUE')
+ func.append(sql % "JSON_VALUE")
rhs = rhs % tuple(func)
- elif connection.vendor == 'sqlite':
+ elif connection.vendor == "sqlite":
func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
- func.append('%s')
+ func.append("%s")
else:
func.append("JSON_EXTRACT(%s, '$')")
rhs = rhs % tuple(func)
@@ -442,24 +464,28 @@ class KeyTransformExact(JSONExact):
def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
- if rhs_params == ['null']:
+ if rhs_params == ["null"]:
# Field has key and it's NULL.
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
- is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
+ is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
return (
- '%s AND %s' % (has_key_sql, is_null_sql),
+ "%s AND %s" % (has_key_sql, is_null_sql),
tuple(has_key_params) + tuple(is_null_params),
)
return super().as_sql(compiler, connection)
-class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
+class KeyTransformIExact(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
+):
pass
-class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
+class KeyTransformIContains(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
+):
pass
@@ -467,7 +493,9 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
pass
-class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
+class KeyTransformIStartsWith(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
+):
pass
@@ -475,7 +503,9 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
pass
-class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
+class KeyTransformIEndsWith(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
+):
pass
@@ -483,7 +513,9 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
pass
-class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
+class KeyTransformIRegex(
+ CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
+):
pass
@@ -530,7 +562,6 @@ KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
-
def __init__(self, key_name):
self.key_name = key_name
diff --git a/django/db/models/fields/mixins.py b/django/db/models/fields/mixins.py
index 3afa8d9304..e7f282210e 100644
--- a/django/db/models/fields/mixins.py
+++ b/django/db/models/fields/mixins.py
@@ -29,22 +29,25 @@ class FieldCacheMixin:
class CheckFieldDefaultMixin:
- _default_hint = ('<valid default>', '<invalid default>')
+ _default_hint = ("<valid default>", "<invalid default>")
def _check_default(self):
- if self.has_default() and self.default is not None and not callable(self.default):
+ if (
+ self.has_default()
+ and self.default is not None
+ and not callable(self.default)
+ ):
return [
checks.Warning(
"%s default should be a callable instead of an instance "
- "so that it's not shared between all field instances." % (
- self.__class__.__name__,
- ),
+ "so that it's not shared between all field instances."
+ % (self.__class__.__name__,),
hint=(
- 'Use a callable instead, e.g., use `%s` instead of '
- '`%s`.' % self._default_hint
+ "Use a callable instead, e.g., use `%s` instead of "
+ "`%s`." % self._default_hint
),
obj=self,
- id='fields.E010',
+ id="fields.E010",
)
]
else:
diff --git a/django/db/models/fields/proxy.py b/django/db/models/fields/proxy.py
index 0ecf04a333..ac02e47a25 100644
--- a/django/db/models/fields/proxy.py
+++ b/django/db/models/fields/proxy.py
@@ -13,6 +13,6 @@ class OrderWrt(fields.IntegerField):
"""
def __init__(self, *args, **kwargs):
- kwargs['name'] = '_order'
- kwargs['editable'] = False
+ kwargs["name"] = "_order"
+ kwargs["editable"] = False
super().__init__(*args, **kwargs)
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 11407ac902..1cf447c6d4 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -19,19 +19,25 @@ from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import FieldCacheMixin
from .related_descriptors import (
- ForeignKeyDeferredAttribute, ForwardManyToOneDescriptor,
- ForwardOneToOneDescriptor, ManyToManyDescriptor,
- ReverseManyToOneDescriptor, ReverseOneToOneDescriptor,
+ ForeignKeyDeferredAttribute,
+ ForwardManyToOneDescriptor,
+ ForwardOneToOneDescriptor,
+ ManyToManyDescriptor,
+ ReverseManyToOneDescriptor,
+ ReverseOneToOneDescriptor,
)
from .related_lookups import (
- RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn,
- RelatedIsNull, RelatedLessThan, RelatedLessThanOrEqual,
-)
-from .reverse_related import (
- ForeignObjectRel, ManyToManyRel, ManyToOneRel, OneToOneRel,
+ RelatedExact,
+ RelatedGreaterThan,
+ RelatedGreaterThanOrEqual,
+ RelatedIn,
+ RelatedIsNull,
+ RelatedLessThan,
+ RelatedLessThanOrEqual,
)
+from .reverse_related import ForeignObjectRel, ManyToManyRel, ManyToOneRel, OneToOneRel
-RECURSIVE_RELATIONSHIP_CONSTANT = 'self'
+RECURSIVE_RELATIONSHIP_CONSTANT = "self"
def resolve_relation(scope_model, relation):
@@ -119,19 +125,25 @@ class RelatedField(FieldCacheMixin, Field):
def _check_related_name_is_valid(self):
import keyword
+
related_name = self.remote_field.related_name
if related_name is None:
return []
- is_valid_id = not keyword.iskeyword(related_name) and related_name.isidentifier()
- if not (is_valid_id or related_name.endswith('+')):
+ is_valid_id = (
+ not keyword.iskeyword(related_name) and related_name.isidentifier()
+ )
+ if not (is_valid_id or related_name.endswith("+")):
return [
checks.Error(
- "The name '%s' is invalid related_name for field %s.%s" %
- (self.remote_field.related_name, self.model._meta.object_name,
- self.name),
+ "The name '%s' is invalid related_name for field %s.%s"
+ % (
+ self.remote_field.related_name,
+ self.model._meta.object_name,
+ self.name,
+ ),
hint="Related name must be a valid Python identifier or end with a '+'",
obj=self,
- id='fields.E306',
+ id="fields.E306",
)
]
return []
@@ -141,15 +153,17 @@ class RelatedField(FieldCacheMixin, Field):
return []
rel_query_name = self.related_query_name()
errors = []
- if rel_query_name.endswith('_'):
+ if rel_query_name.endswith("_"):
errors.append(
checks.Error(
"Reverse query name '%s' must not end with an underscore."
% rel_query_name,
- hint=("Add or change a related_name or related_query_name "
- "argument for this field."),
+ hint=(
+ "Add or change a related_name or related_query_name "
+ "argument for this field."
+ ),
obj=self,
- id='fields.E308',
+ id="fields.E308",
)
)
if LOOKUP_SEP in rel_query_name:
@@ -157,10 +171,12 @@ class RelatedField(FieldCacheMixin, Field):
checks.Error(
"Reverse query name '%s' must not contain '%s'."
% (rel_query_name, LOOKUP_SEP),
- hint=("Add or change a related_name or related_query_name "
- "argument for this field."),
+ hint=(
+ "Add or change a related_name or related_query_name "
+ "argument for this field."
+ ),
obj=self,
- id='fields.E309',
+ id="fields.E309",
)
)
return errors
@@ -168,29 +184,38 @@ class RelatedField(FieldCacheMixin, Field):
def _check_relation_model_exists(self):
rel_is_missing = self.remote_field.model not in self.opts.apps.get_models()
rel_is_string = isinstance(self.remote_field.model, str)
- model_name = self.remote_field.model if rel_is_string else self.remote_field.model._meta.object_name
- if rel_is_missing and (rel_is_string or not self.remote_field.model._meta.swapped):
+ model_name = (
+ self.remote_field.model
+ if rel_is_string
+ else self.remote_field.model._meta.object_name
+ )
+ if rel_is_missing and (
+ rel_is_string or not self.remote_field.model._meta.swapped
+ ):
return [
checks.Error(
"Field defines a relation with model '%s', which is either "
"not installed, or is abstract." % model_name,
obj=self,
- id='fields.E300',
+ id="fields.E300",
)
]
return []
def _check_referencing_to_swapped_model(self):
- if (self.remote_field.model not in self.opts.apps.get_models() and
- not isinstance(self.remote_field.model, str) and
- self.remote_field.model._meta.swapped):
+ if (
+ self.remote_field.model not in self.opts.apps.get_models()
+ and not isinstance(self.remote_field.model, str)
+ and self.remote_field.model._meta.swapped
+ ):
return [
checks.Error(
"Field defines a relation with the model '%s', which has "
"been swapped out." % self.remote_field.model._meta.label,
- hint="Update the relation to point at 'settings.%s'." % self.remote_field.model._meta.swappable,
+ hint="Update the relation to point at 'settings.%s'."
+ % self.remote_field.model._meta.swappable,
obj=self,
- id='fields.E301',
+ id="fields.E301",
)
]
return []
@@ -227,7 +252,7 @@ class RelatedField(FieldCacheMixin, Field):
rel_name = self.remote_field.get_accessor_name() # i. e. "model_set"
rel_query_name = self.related_query_name() # i. e. "model"
# i.e. "app_label.Model.field".
- field_name = '%s.%s' % (opts.label, self.name)
+ field_name = "%s.%s" % (opts.label, self.name)
# Check clashes between accessor or reverse query name of `field`
# and any other field name -- i.e. accessor for Model.foreign is
@@ -235,28 +260,35 @@ class RelatedField(FieldCacheMixin, Field):
potential_clashes = rel_opts.fields + rel_opts.many_to_many
for clash_field in potential_clashes:
# i.e. "app_label.Target.model_set".
- clash_name = '%s.%s' % (rel_opts.label, clash_field.name)
+ clash_name = "%s.%s" % (rel_opts.label, clash_field.name)
if not rel_is_hidden and clash_field.name == rel_name:
errors.append(
checks.Error(
f"Reverse accessor '{rel_opts.object_name}.{rel_name}' "
f"for '{field_name}' clashes with field name "
f"'{clash_name}'.",
- hint=("Rename field '%s', or add/change a related_name "
- "argument to the definition for field '%s'.") % (clash_name, field_name),
+ hint=(
+ "Rename field '%s', or add/change a related_name "
+ "argument to the definition for field '%s'."
+ )
+ % (clash_name, field_name),
obj=self,
- id='fields.E302',
+ id="fields.E302",
)
)
if clash_field.name == rel_query_name:
errors.append(
checks.Error(
- "Reverse query name for '%s' clashes with field name '%s'." % (field_name, clash_name),
- hint=("Rename field '%s', or add/change a related_name "
- "argument to the definition for field '%s'.") % (clash_name, field_name),
+ "Reverse query name for '%s' clashes with field name '%s'."
+ % (field_name, clash_name),
+ hint=(
+ "Rename field '%s', or add/change a related_name "
+ "argument to the definition for field '%s'."
+ )
+ % (clash_name, field_name),
obj=self,
- id='fields.E303',
+ id="fields.E303",
)
)
@@ -266,7 +298,7 @@ class RelatedField(FieldCacheMixin, Field):
potential_clashes = (r for r in rel_opts.related_objects if r.field is not self)
for clash_field in potential_clashes:
# i.e. "app_label.Model.m2m".
- clash_name = '%s.%s' % (
+ clash_name = "%s.%s" % (
clash_field.related_model._meta.label,
clash_field.field.name,
)
@@ -276,10 +308,13 @@ class RelatedField(FieldCacheMixin, Field):
f"Reverse accessor '{rel_opts.object_name}.{rel_name}' "
f"for '{field_name}' clashes with reverse accessor for "
f"'{clash_name}'.",
- hint=("Add or change a related_name argument "
- "to the definition for '%s' or '%s'.") % (field_name, clash_name),
+ hint=(
+ "Add or change a related_name argument "
+ "to the definition for '%s' or '%s'."
+ )
+ % (field_name, clash_name),
obj=self,
- id='fields.E304',
+ id="fields.E304",
)
)
@@ -288,10 +323,13 @@ class RelatedField(FieldCacheMixin, Field):
checks.Error(
"Reverse query name for '%s' clashes with reverse query name for '%s'."
% (field_name, clash_name),
- hint=("Add or change a related_name argument "
- "to the definition for '%s' or '%s'.") % (field_name, clash_name),
+ hint=(
+ "Add or change a related_name argument "
+ "to the definition for '%s' or '%s'."
+ )
+ % (field_name, clash_name),
obj=self,
- id='fields.E305',
+ id="fields.E305",
)
)
@@ -315,32 +353,35 @@ class RelatedField(FieldCacheMixin, Field):
related_name = self.opts.default_related_name
if related_name:
related_name = related_name % {
- 'class': cls.__name__.lower(),
- 'model_name': cls._meta.model_name.lower(),
- 'app_label': cls._meta.app_label.lower()
+ "class": cls.__name__.lower(),
+ "model_name": cls._meta.model_name.lower(),
+ "app_label": cls._meta.app_label.lower(),
}
self.remote_field.related_name = related_name
if self.remote_field.related_query_name:
related_query_name = self.remote_field.related_query_name % {
- 'class': cls.__name__.lower(),
- 'app_label': cls._meta.app_label.lower(),
+ "class": cls.__name__.lower(),
+ "app_label": cls._meta.app_label.lower(),
}
self.remote_field.related_query_name = related_query_name
def resolve_related_class(model, related, field):
field.remote_field.model = related
field.do_related_class(related, model)
- lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self)
+
+ lazy_related_operation(
+ resolve_related_class, cls, self.remote_field.model, field=self
+ )
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self._limit_choices_to:
- kwargs['limit_choices_to'] = self._limit_choices_to
+ kwargs["limit_choices_to"] = self._limit_choices_to
if self._related_name is not None:
- kwargs['related_name'] = self._related_name
+ kwargs["related_name"] = self._related_name
if self._related_query_name is not None:
- kwargs['related_query_name'] = self._related_query_name
+ kwargs["related_query_name"] = self._related_query_name
return name, path, args, kwargs
def get_forward_related_filter(self, obj):
@@ -352,7 +393,7 @@ class RelatedField(FieldCacheMixin, Field):
self.related_field.model.
"""
return {
- '%s__%s' % (self.name, rh_field.name): getattr(obj, rh_field.attname)
+ "%s__%s" % (self.name, rh_field.name): getattr(obj, rh_field.attname)
for _, rh_field in self.related_fields
}
@@ -391,9 +432,10 @@ class RelatedField(FieldCacheMixin, Field):
return None
def set_attributes_from_rel(self):
- self.name = (
- self.name or
- (self.remote_field.model._meta.model_name + '_' + self.remote_field.model._meta.pk.name)
+ self.name = self.name or (
+ self.remote_field.model._meta.model_name
+ + "_"
+ + self.remote_field.model._meta.pk.name
)
if self.verbose_name is None:
self.verbose_name = self.remote_field.model._meta.verbose_name
@@ -423,14 +465,16 @@ class RelatedField(FieldCacheMixin, Field):
being constructed.
"""
defaults = {}
- if hasattr(self.remote_field, 'get_related_field'):
+ if hasattr(self.remote_field, "get_related_field"):
# If this is a callable, do not invoke it here. Just pass
# it in the defaults for when the form class will later be
# instantiated.
limit_choices_to = self.remote_field.limit_choices_to
- defaults.update({
- 'limit_choices_to': limit_choices_to,
- })
+ defaults.update(
+ {
+ "limit_choices_to": limit_choices_to,
+ }
+ )
defaults.update(kwargs)
return super().formfield(**defaults)
@@ -439,7 +483,11 @@ class RelatedField(FieldCacheMixin, Field):
Define the name that can be used to identify this related object in a
table-spanning query.
"""
- return self.remote_field.related_query_name or self.remote_field.related_name or self.opts.model_name
+ return (
+ self.remote_field.related_query_name
+ or self.remote_field.related_name
+ or self.opts.model_name
+ )
@property
def target_field(self):
@@ -450,7 +498,8 @@ class RelatedField(FieldCacheMixin, Field):
target_fields = self.path_infos[-1].target_fields
if len(target_fields) > 1:
raise exceptions.FieldError(
- "The relation has multiple target fields, but only single target field was asked for")
+ "The relation has multiple target fields, but only single target field was asked for"
+ )
return target_fields[0]
def get_cache_name(self):
@@ -473,13 +522,25 @@ class ForeignObject(RelatedField):
forward_related_accessor_class = ForwardManyToOneDescriptor
rel_class = ForeignObjectRel
- def __init__(self, to, on_delete, from_fields, to_fields, rel=None, related_name=None,
- related_query_name=None, limit_choices_to=None, parent_link=False,
- swappable=True, **kwargs):
+ def __init__(
+ self,
+ to,
+ on_delete,
+ from_fields,
+ to_fields,
+ rel=None,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ parent_link=False,
+ swappable=True,
+ **kwargs,
+ ):
if rel is None:
rel = self.rel_class(
- self, to,
+ self,
+ to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -502,8 +563,8 @@ class ForeignObject(RelatedField):
def __copy__(self):
obj = super().__copy__()
# Remove any cached PathInfo values.
- obj.__dict__.pop('path_infos', None)
- obj.__dict__.pop('reverse_path_infos', None)
+ obj.__dict__.pop("path_infos", None)
+ obj.__dict__.pop("reverse_path_infos", None)
return obj
def check(self, **kwargs):
@@ -530,7 +591,7 @@ class ForeignObject(RelatedField):
"model '%s'."
% (to_field, self.remote_field.model._meta.label),
obj=self,
- id='fields.E312',
+ id="fields.E312",
)
)
return errors
@@ -551,21 +612,22 @@ class ForeignObject(RelatedField):
unique_foreign_fields = {
frozenset([f.name])
for f in self.remote_field.model._meta.get_fields()
- if getattr(f, 'unique', False)
+ if getattr(f, "unique", False)
}
- unique_foreign_fields.update({
- frozenset(ut)
- for ut in self.remote_field.model._meta.unique_together
- })
- unique_foreign_fields.update({
- frozenset(uc.fields)
- for uc in self.remote_field.model._meta.total_unique_constraints
- })
+ unique_foreign_fields.update(
+ {frozenset(ut) for ut in self.remote_field.model._meta.unique_together}
+ )
+ unique_foreign_fields.update(
+ {
+ frozenset(uc.fields)
+ for uc in self.remote_field.model._meta.total_unique_constraints
+ }
+ )
foreign_fields = {f.name for f in self.foreign_related_fields}
has_unique_constraint = any(u <= foreign_fields for u in unique_foreign_fields)
if not has_unique_constraint and len(self.foreign_related_fields) > 1:
- field_combination = ', '.join(
+ field_combination = ", ".join(
"'%s'" % rel_field.name for rel_field in self.foreign_related_fields
)
model_name = self.remote_field.model.__name__
@@ -574,13 +636,13 @@ class ForeignObject(RelatedField):
"No subset of the fields %s on model '%s' is unique."
% (field_combination, model_name),
hint=(
- 'Mark a single field as unique=True or add a set of '
- 'fields to a unique constraint (via unique_together '
- 'or a UniqueConstraint (without condition) in the '
- 'model Meta.constraints).'
+ "Mark a single field as unique=True or add a set of "
+ "fields to a unique constraint (via unique_together "
+ "or a UniqueConstraint (without condition) in the "
+ "model Meta.constraints)."
),
obj=self,
- id='fields.E310',
+ id="fields.E310",
)
]
elif not has_unique_constraint:
@@ -591,12 +653,12 @@ class ForeignObject(RelatedField):
"'%s.%s' must be unique because it is referenced by "
"a foreign key." % (model_name, field_name),
hint=(
- 'Add unique=True to this field or add a '
- 'UniqueConstraint (without condition) in the model '
- 'Meta.constraints.'
+ "Add unique=True to this field or add a "
+ "UniqueConstraint (without condition) in the model "
+ "Meta.constraints."
),
obj=self,
- id='fields.E311',
+ id="fields.E311",
)
]
else:
@@ -604,44 +666,48 @@ class ForeignObject(RelatedField):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- kwargs['on_delete'] = self.remote_field.on_delete
- kwargs['from_fields'] = self.from_fields
- kwargs['to_fields'] = self.to_fields
+ kwargs["on_delete"] = self.remote_field.on_delete
+ kwargs["from_fields"] = self.from_fields
+ kwargs["to_fields"] = self.to_fields
if self.remote_field.parent_link:
- kwargs['parent_link'] = self.remote_field.parent_link
+ kwargs["parent_link"] = self.remote_field.parent_link
if isinstance(self.remote_field.model, str):
- if '.' in self.remote_field.model:
- app_label, model_name = self.remote_field.model.split('.')
- kwargs['to'] = '%s.%s' % (app_label, model_name.lower())
+ if "." in self.remote_field.model:
+ app_label, model_name = self.remote_field.model.split(".")
+ kwargs["to"] = "%s.%s" % (app_label, model_name.lower())
else:
- kwargs['to'] = self.remote_field.model.lower()
+ kwargs["to"] = self.remote_field.model.lower()
else:
- kwargs['to'] = self.remote_field.model._meta.label_lower
+ kwargs["to"] = self.remote_field.model._meta.label_lower
# If swappable is True, then see if we're actually pointing to the target
# of a swap.
swappable_setting = self.swappable_setting
if swappable_setting is not None:
# If it's already a settings reference, error
- if hasattr(kwargs['to'], "setting_name"):
- if kwargs['to'].setting_name != swappable_setting:
+ if hasattr(kwargs["to"], "setting_name"):
+ if kwargs["to"].setting_name != swappable_setting:
raise ValueError(
"Cannot deconstruct a ForeignKey pointing to a model "
"that is swapped in place of more than one model (%s and %s)"
- % (kwargs['to'].setting_name, swappable_setting)
+ % (kwargs["to"].setting_name, swappable_setting)
)
# Set it
- kwargs['to'] = SettingsReference(
- kwargs['to'],
+ kwargs["to"] = SettingsReference(
+ kwargs["to"],
swappable_setting,
)
return name, path, args, kwargs
def resolve_related_fields(self):
if not self.from_fields or len(self.from_fields) != len(self.to_fields):
- raise ValueError('Foreign Object from and to fields must be the same non-zero length')
+ raise ValueError(
+ "Foreign Object from and to fields must be the same non-zero length"
+ )
if isinstance(self.remote_field.model, str):
- raise ValueError('Related model %r cannot be resolved' % self.remote_field.model)
+ raise ValueError(
+ "Related model %r cannot be resolved" % self.remote_field.model
+ )
related_fields = []
for index in range(len(self.from_fields)):
from_field_name = self.from_fields[index]
@@ -651,8 +717,11 @@ class ForeignObject(RelatedField):
if from_field_name == RECURSIVE_RELATIONSHIP_CONSTANT
else self.opts.get_field(from_field_name)
)
- to_field = (self.remote_field.model._meta.pk if to_field_name is None
- else self.remote_field.model._meta.get_field(to_field_name))
+ to_field = (
+ self.remote_field.model._meta.pk
+ if to_field_name is None
+ else self.remote_field.model._meta.get_field(to_field_name)
+ )
related_fields.append((from_field, to_field))
return related_fields
@@ -670,7 +739,9 @@ class ForeignObject(RelatedField):
@cached_property
def foreign_related_fields(self):
- return tuple(rhs_field for lhs_field, rhs_field in self.related_fields if rhs_field)
+ return tuple(
+ rhs_field for lhs_field, rhs_field in self.related_fields if rhs_field
+ )
def get_local_related_value(self, instance):
return self.get_instance_value_for_fields(instance, self.local_related_fields)
@@ -688,9 +759,11 @@ class ForeignObject(RelatedField):
# instance.pk (that is, parent_ptr_id) when asked for instance.id.
if field.primary_key:
possible_parent_link = opts.get_ancestor_link(field.model)
- if (not possible_parent_link or
- possible_parent_link.primary_key or
- possible_parent_link.model._meta.abstract):
+ if (
+ not possible_parent_link
+ or possible_parent_link.primary_key
+ or possible_parent_link.model._meta.abstract
+ ):
ret.append(instance.pk)
continue
ret.append(getattr(instance, field.attname))
@@ -702,7 +775,9 @@ class ForeignObject(RelatedField):
def get_joining_columns(self, reverse_join=False):
source = self.reverse_related_fields if reverse_join else self.related_fields
- return tuple((lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source)
+ return tuple(
+ (lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source
+ )
def get_reverse_joining_columns(self):
return self.get_joining_columns(reverse_join=True)
@@ -740,15 +815,17 @@ class ForeignObject(RelatedField):
"""Get path from this field to the related model."""
opts = self.remote_field.model._meta
from_opts = self.model._meta
- return [PathInfo(
- from_opts=from_opts,
- to_opts=opts,
- target_fields=self.foreign_related_fields,
- join_field=self,
- m2m=False,
- direct=True,
- filtered_relation=filtered_relation,
- )]
+ return [
+ PathInfo(
+ from_opts=from_opts,
+ to_opts=opts,
+ target_fields=self.foreign_related_fields,
+ join_field=self,
+ m2m=False,
+ direct=True,
+ filtered_relation=filtered_relation,
+ )
+ ]
@cached_property
def path_infos(self):
@@ -758,15 +835,17 @@ class ForeignObject(RelatedField):
"""Get path from the related model to this field's model."""
opts = self.model._meta
from_opts = self.remote_field.model._meta
- return [PathInfo(
- from_opts=from_opts,
- to_opts=opts,
- target_fields=(opts.pk,),
- join_field=self.remote_field,
- m2m=not self.unique,
- direct=False,
- filtered_relation=filtered_relation,
- )]
+ return [
+ PathInfo(
+ from_opts=from_opts,
+ to_opts=opts,
+ target_fields=(opts.pk,),
+ join_field=self.remote_field,
+ m2m=not self.unique,
+ direct=False,
+ filtered_relation=filtered_relation,
+ )
+ ]
@cached_property
def reverse_path_infos(self):
@@ -776,8 +855,8 @@ class ForeignObject(RelatedField):
@functools.lru_cache(maxsize=None)
def get_lookups(cls):
bases = inspect.getmro(cls)
- bases = bases[:bases.index(ForeignObject) + 1]
- class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in bases]
+ bases = bases[: bases.index(ForeignObject) + 1]
+ class_lookups = [parent.__dict__.get("class_lookups", {}) for parent in bases]
return cls.merge_dicts(class_lookups)
def contribute_to_class(self, cls, name, private_only=False, **kwargs):
@@ -787,13 +866,22 @@ class ForeignObject(RelatedField):
def contribute_to_related_class(self, cls, related):
# Internal FK's - i.e., those with a related name ending with '+' -
# and swapped models don't get a related descriptor.
- if not self.remote_field.is_hidden() and not related.related_model._meta.swapped:
- setattr(cls._meta.concrete_model, related.get_accessor_name(), self.related_accessor_class(related))
+ if (
+ not self.remote_field.is_hidden()
+ and not related.related_model._meta.swapped
+ ):
+ setattr(
+ cls._meta.concrete_model,
+ related.get_accessor_name(),
+ self.related_accessor_class(related),
+ )
# While 'limit_choices_to' might be a callable, simply pass
# it along for later - this is too early because it's still
# model load time.
if self.remote_field.limit_choices_to:
- cls._meta.related_fkey_lookups.append(self.remote_field.limit_choices_to)
+ cls._meta.related_fkey_lookups.append(
+ self.remote_field.limit_choices_to
+ )
ForeignObject.register_lookup(RelatedIn)
@@ -813,6 +901,7 @@ class ForeignKey(ForeignObject):
By default ForeignKey will target the pk of the remote model but this
behavior can be changed by using the ``to_field`` argument.
"""
+
descriptor_class = ForeignKeyDeferredAttribute
# Field flags
many_to_many = False
@@ -824,21 +913,33 @@ class ForeignKey(ForeignObject):
empty_strings_allowed = False
default_error_messages = {
- 'invalid': _('%(model)s instance with %(field)s %(value)r does not exist.')
+ "invalid": _("%(model)s instance with %(field)s %(value)r does not exist.")
}
description = _("Foreign Key (type determined by related field)")
- def __init__(self, to, on_delete, related_name=None, related_query_name=None,
- limit_choices_to=None, parent_link=False, to_field=None,
- db_constraint=True, **kwargs):
+ def __init__(
+ self,
+ to,
+ on_delete,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ parent_link=False,
+ to_field=None,
+ db_constraint=True,
+ **kwargs,
+ ):
try:
to._meta.model_name
except AttributeError:
if not isinstance(to, str):
raise TypeError(
- '%s(%r) is invalid. First parameter to ForeignKey must be '
- 'either a model, a model name, or the string %r' % (
- self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT,
+ "%s(%r) is invalid. First parameter to ForeignKey must be "
+ "either a model, a model name, or the string %r"
+ % (
+ self.__class__.__name__,
+ to,
+ RECURSIVE_RELATIONSHIP_CONSTANT,
)
)
else:
@@ -847,17 +948,19 @@ class ForeignKey(ForeignObject):
# be correct until contribute_to_class is called. Refs #12190.
to_field = to_field or (to._meta.pk and to._meta.pk.name)
if not callable(on_delete):
- raise TypeError('on_delete must be callable.')
+ raise TypeError("on_delete must be callable.")
- kwargs['rel'] = self.rel_class(
- self, to, to_field,
+ kwargs["rel"] = self.rel_class(
+ self,
+ to,
+ to_field,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
parent_link=parent_link,
on_delete=on_delete,
)
- kwargs.setdefault('db_index', True)
+ kwargs.setdefault("db_index", True)
super().__init__(
to,
@@ -879,54 +982,60 @@ class ForeignKey(ForeignObject):
]
def _check_on_delete(self):
- on_delete = getattr(self.remote_field, 'on_delete', None)
+ on_delete = getattr(self.remote_field, "on_delete", None)
if on_delete == SET_NULL and not self.null:
return [
checks.Error(
- 'Field specifies on_delete=SET_NULL, but cannot be null.',
- hint='Set null=True argument on the field, or change the on_delete rule.',
+ "Field specifies on_delete=SET_NULL, but cannot be null.",
+ hint="Set null=True argument on the field, or change the on_delete rule.",
obj=self,
- id='fields.E320',
+ id="fields.E320",
)
]
elif on_delete == SET_DEFAULT and not self.has_default():
return [
checks.Error(
- 'Field specifies on_delete=SET_DEFAULT, but has no default value.',
- hint='Set a default value, or change the on_delete rule.',
+ "Field specifies on_delete=SET_DEFAULT, but has no default value.",
+ hint="Set a default value, or change the on_delete rule.",
obj=self,
- id='fields.E321',
+ id="fields.E321",
)
]
else:
return []
def _check_unique(self, **kwargs):
- return [
- checks.Warning(
- 'Setting unique=True on a ForeignKey has the same effect as using a OneToOneField.',
- hint='ForeignKey(unique=True) is usually better served by a OneToOneField.',
- obj=self,
- id='fields.W342',
- )
- ] if self.unique else []
+ return (
+ [
+ checks.Warning(
+ "Setting unique=True on a ForeignKey has the same effect as using a OneToOneField.",
+ hint="ForeignKey(unique=True) is usually better served by a OneToOneField.",
+ obj=self,
+ id="fields.W342",
+ )
+ ]
+ if self.unique
+ else []
+ )
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
- del kwargs['to_fields']
- del kwargs['from_fields']
+ del kwargs["to_fields"]
+ del kwargs["from_fields"]
# Handle the simpler arguments
if self.db_index:
- del kwargs['db_index']
+ del kwargs["db_index"]
else:
- kwargs['db_index'] = False
+ kwargs["db_index"] = False
if self.db_constraint is not True:
- kwargs['db_constraint'] = self.db_constraint
+ kwargs["db_constraint"] = self.db_constraint
# Rel needs more work.
to_meta = getattr(self.remote_field.model, "_meta", None)
if self.remote_field.field_name and (
- not to_meta or (to_meta.pk and self.remote_field.field_name != to_meta.pk.name)):
- kwargs['to_field'] = self.remote_field.field_name
+ not to_meta
+ or (to_meta.pk and self.remote_field.field_name != to_meta.pk.name)
+ ):
+ kwargs["to_field"] = self.remote_field.field_name
return name, path, args, kwargs
def to_python(self, value):
@@ -940,15 +1049,17 @@ class ForeignKey(ForeignObject):
"""Get path from the related model to this field's model."""
opts = self.model._meta
from_opts = self.remote_field.model._meta
- return [PathInfo(
- from_opts=from_opts,
- to_opts=opts,
- target_fields=(opts.pk,),
- join_field=self.remote_field,
- m2m=not self.unique,
- direct=False,
- filtered_relation=filtered_relation,
- )]
+ return [
+ PathInfo(
+ from_opts=from_opts,
+ to_opts=opts,
+ target_fields=(opts.pk,),
+ join_field=self.remote_field,
+ m2m=not self.unique,
+ direct=False,
+ filtered_relation=filtered_relation,
+ )
+ ]
def validate(self, value, model_instance):
if self.remote_field.parent_link:
@@ -964,21 +1075,27 @@ class ForeignKey(ForeignObject):
qs = qs.complex_filter(self.get_limit_choices_to())
if not qs.exists():
raise exceptions.ValidationError(
- self.error_messages['invalid'],
- code='invalid',
+ self.error_messages["invalid"],
+ code="invalid",
params={
- 'model': self.remote_field.model._meta.verbose_name, 'pk': value,
- 'field': self.remote_field.field_name, 'value': value,
+ "model": self.remote_field.model._meta.verbose_name,
+ "pk": value,
+ "field": self.remote_field.field_name,
+ "value": value,
}, # 'pk' is included for backwards compatibility
)
def resolve_related_fields(self):
related_fields = super().resolve_related_fields()
for from_field, to_field in related_fields:
- if to_field and to_field.model != self.remote_field.model._meta.concrete_model:
+ if (
+ to_field
+ and to_field.model != self.remote_field.model._meta.concrete_model
+ ):
raise exceptions.FieldError(
"'%s.%s' refers to field '%s' which is not local to model "
- "'%s'." % (
+ "'%s'."
+ % (
self.model._meta.label,
self.name,
to_field.name,
@@ -988,7 +1105,7 @@ class ForeignKey(ForeignObject):
return related_fields
def get_attname(self):
- return '%s_id' % self.name
+ return "%s_id" % self.name
def get_attname_column(self):
attname = self.get_attname()
@@ -1003,9 +1120,13 @@ class ForeignKey(ForeignObject):
return field_default
def get_db_prep_save(self, value, connection):
- if value is None or (value == '' and
- (not self.target_field.empty_strings_allowed or
- connection.features.interprets_empty_strings_as_nulls)):
+ if value is None or (
+ value == ""
+ and (
+ not self.target_field.empty_strings_allowed
+ or connection.features.interprets_empty_strings_as_nulls
+ )
+ ):
return None
else:
return self.target_field.get_db_prep_save(value, connection=connection)
@@ -1023,16 +1144,20 @@ class ForeignKey(ForeignObject):
def formfield(self, *, using=None, **kwargs):
if isinstance(self.remote_field.model, str):
- raise ValueError("Cannot create form field for %r yet, because "
- "its related model %r has not been loaded yet" %
- (self.name, self.remote_field.model))
- return super().formfield(**{
- 'form_class': forms.ModelChoiceField,
- 'queryset': self.remote_field.model._default_manager.using(using),
- 'to_field_name': self.remote_field.field_name,
- **kwargs,
- 'blank': self.blank,
- })
+ raise ValueError(
+ "Cannot create form field for %r yet, because "
+ "its related model %r has not been loaded yet"
+ % (self.name, self.remote_field.model)
+ )
+ return super().formfield(
+ **{
+ "form_class": forms.ModelChoiceField,
+ "queryset": self.remote_field.model._default_manager.using(using),
+ "to_field_name": self.remote_field.field_name,
+ **kwargs,
+ "blank": self.blank,
+ }
+ )
def db_check(self, connection):
return None
@@ -1060,7 +1185,7 @@ class ForeignKey(ForeignObject):
while isinstance(output_field, ForeignKey):
output_field = output_field.target_field
if output_field is self:
- raise ValueError('Cannot resolve output_field.')
+ raise ValueError("Cannot resolve output_field.")
return super().get_col(alias, output_field)
@@ -1085,13 +1210,13 @@ class OneToOneField(ForeignKey):
description = _("One-to-one relationship")
def __init__(self, to, on_delete, to_field=None, **kwargs):
- kwargs['unique'] = True
+ kwargs["unique"] = True
super().__init__(to, on_delete, to_field=to_field, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if "unique" in kwargs:
- del kwargs['unique']
+ del kwargs["unique"]
return name, path, args, kwargs
def formfield(self, **kwargs):
@@ -1121,44 +1246,54 @@ def create_many_to_many_intermediary_model(field, klass):
through._meta.managed = model._meta.managed or related._meta.managed
to_model = resolve_relation(klass, field.remote_field.model)
- name = '%s_%s' % (klass._meta.object_name, field.name)
+ name = "%s_%s" % (klass._meta.object_name, field.name)
lazy_related_operation(set_managed, klass, to_model, name)
to = make_model_tuple(to_model)[1]
from_ = klass._meta.model_name
if to == from_:
- to = 'to_%s' % to
- from_ = 'from_%s' % from_
+ to = "to_%s" % to
+ from_ = "from_%s" % from_
- meta = type('Meta', (), {
- 'db_table': field._get_m2m_db_table(klass._meta),
- 'auto_created': klass,
- 'app_label': klass._meta.app_label,
- 'db_tablespace': klass._meta.db_tablespace,
- 'unique_together': (from_, to),
- 'verbose_name': _('%(from)s-%(to)s relationship') % {'from': from_, 'to': to},
- 'verbose_name_plural': _('%(from)s-%(to)s relationships') % {'from': from_, 'to': to},
- 'apps': field.model._meta.apps,
- })
+ meta = type(
+ "Meta",
+ (),
+ {
+ "db_table": field._get_m2m_db_table(klass._meta),
+ "auto_created": klass,
+ "app_label": klass._meta.app_label,
+ "db_tablespace": klass._meta.db_tablespace,
+ "unique_together": (from_, to),
+ "verbose_name": _("%(from)s-%(to)s relationship")
+ % {"from": from_, "to": to},
+ "verbose_name_plural": _("%(from)s-%(to)s relationships")
+ % {"from": from_, "to": to},
+ "apps": field.model._meta.apps,
+ },
+ )
# Construct and return the new class.
- return type(name, (models.Model,), {
- 'Meta': meta,
- '__module__': klass.__module__,
- from_: models.ForeignKey(
- klass,
- related_name='%s+' % name,
- db_tablespace=field.db_tablespace,
- db_constraint=field.remote_field.db_constraint,
- on_delete=CASCADE,
- ),
- to: models.ForeignKey(
- to_model,
- related_name='%s+' % name,
- db_tablespace=field.db_tablespace,
- db_constraint=field.remote_field.db_constraint,
- on_delete=CASCADE,
- )
- })
+ return type(
+ name,
+ (models.Model,),
+ {
+ "Meta": meta,
+ "__module__": klass.__module__,
+ from_: models.ForeignKey(
+ klass,
+ related_name="%s+" % name,
+ db_tablespace=field.db_tablespace,
+ db_constraint=field.remote_field.db_constraint,
+ on_delete=CASCADE,
+ ),
+ to: models.ForeignKey(
+ to_model,
+ related_name="%s+" % name,
+ db_tablespace=field.db_tablespace,
+ db_constraint=field.remote_field.db_constraint,
+ on_delete=CASCADE,
+ ),
+ },
+ )
class ManyToManyField(RelatedField):
@@ -1181,31 +1316,45 @@ class ManyToManyField(RelatedField):
description = _("Many-to-many relationship")
- def __init__(self, to, related_name=None, related_query_name=None,
- limit_choices_to=None, symmetrical=None, through=None,
- through_fields=None, db_constraint=True, db_table=None,
- swappable=True, **kwargs):
+ def __init__(
+ self,
+ to,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ symmetrical=None,
+ through=None,
+ through_fields=None,
+ db_constraint=True,
+ db_table=None,
+ swappable=True,
+ **kwargs,
+ ):
try:
to._meta
except AttributeError:
if not isinstance(to, str):
raise TypeError(
- '%s(%r) is invalid. First parameter to ManyToManyField '
- 'must be either a model, a model name, or the string %r' % (
- self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT,
+ "%s(%r) is invalid. First parameter to ManyToManyField "
+ "must be either a model, a model name, or the string %r"
+ % (
+ self.__class__.__name__,
+ to,
+ RECURSIVE_RELATIONSHIP_CONSTANT,
)
)
if symmetrical is None:
- symmetrical = (to == RECURSIVE_RELATIONSHIP_CONSTANT)
+ symmetrical = to == RECURSIVE_RELATIONSHIP_CONSTANT
if through is not None and db_table is not None:
raise ValueError(
- 'Cannot specify a db_table if an intermediary model is used.'
+ "Cannot specify a db_table if an intermediary model is used."
)
- kwargs['rel'] = self.rel_class(
- self, to,
+ kwargs["rel"] = self.rel_class(
+ self,
+ to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -1214,7 +1363,7 @@ class ManyToManyField(RelatedField):
through_fields=through_fields,
db_constraint=db_constraint,
)
- self.has_null_arg = 'null' in kwargs
+ self.has_null_arg = "null" in kwargs
super().__init__(
related_name=related_name,
@@ -1239,9 +1388,9 @@ class ManyToManyField(RelatedField):
if self.unique:
return [
checks.Error(
- 'ManyToManyFields cannot be unique.',
+ "ManyToManyFields cannot be unique.",
obj=self,
- id='fields.E330',
+ id="fields.E330",
)
]
return []
@@ -1252,49 +1401,53 @@ class ManyToManyField(RelatedField):
if self.has_null_arg:
warnings.append(
checks.Warning(
- 'null has no effect on ManyToManyField.',
+ "null has no effect on ManyToManyField.",
obj=self,
- id='fields.W340',
+ id="fields.W340",
)
)
if self._validators:
warnings.append(
checks.Warning(
- 'ManyToManyField does not support validators.',
+ "ManyToManyField does not support validators.",
obj=self,
- id='fields.W341',
+ id="fields.W341",
)
)
if self.remote_field.symmetrical and self._related_name:
warnings.append(
checks.Warning(
- 'related_name has no effect on ManyToManyField '
+ "related_name has no effect on ManyToManyField "
'with a symmetrical relationship, e.g. to "self".',
obj=self,
- id='fields.W345',
+ id="fields.W345",
)
)
return warnings
def _check_relationship_model(self, from_model=None, **kwargs):
- if hasattr(self.remote_field.through, '_meta'):
+ if hasattr(self.remote_field.through, "_meta"):
qualified_model_name = "%s.%s" % (
- self.remote_field.through._meta.app_label, self.remote_field.through.__name__)
+ self.remote_field.through._meta.app_label,
+ self.remote_field.through.__name__,
+ )
else:
qualified_model_name = self.remote_field.through
errors = []
- if self.remote_field.through not in self.opts.apps.get_models(include_auto_created=True):
+ if self.remote_field.through not in self.opts.apps.get_models(
+ include_auto_created=True
+ ):
# The relationship model is not installed.
errors.append(
checks.Error(
"Field specifies a many-to-many relation through model "
"'%s', which has not been installed." % qualified_model_name,
obj=self,
- id='fields.E331',
+ id="fields.E331",
)
)
@@ -1316,7 +1469,7 @@ class ManyToManyField(RelatedField):
# Count foreign keys in intermediate model
if self_referential:
seen_self = sum(
- from_model == getattr(field.remote_field, 'model', None)
+ from_model == getattr(field.remote_field, "model", None)
for field in self.remote_field.through._meta.fields
)
@@ -1327,41 +1480,46 @@ class ManyToManyField(RelatedField):
"'%s', but it has more than two foreign keys "
"to '%s', which is ambiguous. You must specify "
"which two foreign keys Django should use via the "
- "through_fields keyword argument." % (self, from_model_name),
+ "through_fields keyword argument."
+ % (self, from_model_name),
hint="Use through_fields to specify which two foreign keys Django should use.",
obj=self.remote_field.through,
- id='fields.E333',
+ id="fields.E333",
)
)
else:
# Count foreign keys in relationship model
seen_from = sum(
- from_model == getattr(field.remote_field, 'model', None)
+ from_model == getattr(field.remote_field, "model", None)
for field in self.remote_field.through._meta.fields
)
seen_to = sum(
- to_model == getattr(field.remote_field, 'model', None)
+ to_model == getattr(field.remote_field, "model", None)
for field in self.remote_field.through._meta.fields
)
if seen_from > 1 and not self.remote_field.through_fields:
errors.append(
checks.Error(
- ("The model is used as an intermediate model by "
- "'%s', but it has more than one foreign key "
- "from '%s', which is ambiguous. You must specify "
- "which foreign key Django should use via the "
- "through_fields keyword argument.") % (self, from_model_name),
+ (
+ "The model is used as an intermediate model by "
+ "'%s', but it has more than one foreign key "
+ "from '%s', which is ambiguous. You must specify "
+ "which foreign key Django should use via the "
+ "through_fields keyword argument."
+ )
+ % (self, from_model_name),
hint=(
- 'If you want to create a recursive relationship, '
+ "If you want to create a recursive relationship, "
'use ManyToManyField("%s", through="%s").'
- ) % (
+ )
+ % (
RECURSIVE_RELATIONSHIP_CONSTANT,
relationship_model_name,
),
obj=self,
- id='fields.E334',
+ id="fields.E334",
)
)
@@ -1374,14 +1532,15 @@ class ManyToManyField(RelatedField):
"which foreign key Django should use via the "
"through_fields keyword argument." % (self, to_model_name),
hint=(
- 'If you want to create a recursive relationship, '
+ "If you want to create a recursive relationship, "
'use ManyToManyField("%s", through="%s").'
- ) % (
+ )
+ % (
RECURSIVE_RELATIONSHIP_CONSTANT,
relationship_model_name,
),
obj=self,
- id='fields.E335',
+ id="fields.E335",
)
)
@@ -1389,11 +1548,10 @@ class ManyToManyField(RelatedField):
errors.append(
checks.Error(
"The model is used as an intermediate model by "
- "'%s', but it does not have a foreign key to '%s' or '%s'." % (
- self, from_model_name, to_model_name
- ),
+ "'%s', but it does not have a foreign key to '%s' or '%s'."
+ % (self, from_model_name, to_model_name),
obj=self.remote_field.through,
- id='fields.E336',
+ id="fields.E336",
)
)
@@ -1401,8 +1559,11 @@ class ManyToManyField(RelatedField):
if self.remote_field.through_fields is not None:
# Validate that we're given an iterable of at least two items
# and that none of them is "falsy".
- if not (len(self.remote_field.through_fields) >= 2 and
- self.remote_field.through_fields[0] and self.remote_field.through_fields[1]):
+ if not (
+ len(self.remote_field.through_fields) >= 2
+ and self.remote_field.through_fields[0]
+ and self.remote_field.through_fields[1]
+ ):
errors.append(
checks.Error(
"Field specifies 'through_fields' but does not provide "
@@ -1410,7 +1571,7 @@ class ManyToManyField(RelatedField):
"for the relation through model '%s'." % qualified_model_name,
hint="Make sure you specify 'through_fields' as through_fields=('field1', 'field2')",
obj=self,
- id='fields.E337',
+ id="fields.E337",
)
)
@@ -1424,20 +1585,34 @@ class ManyToManyField(RelatedField):
"where the field is attached to."
)
- source, through, target = from_model, self.remote_field.through, self.remote_field.model
- source_field_name, target_field_name = self.remote_field.through_fields[:2]
+ source, through, target = (
+ from_model,
+ self.remote_field.through,
+ self.remote_field.model,
+ )
+ source_field_name, target_field_name = self.remote_field.through_fields[
+ :2
+ ]
- for field_name, related_model in ((source_field_name, source),
- (target_field_name, target)):
+ for field_name, related_model in (
+ (source_field_name, source),
+ (target_field_name, target),
+ ):
possible_field_names = []
for f in through._meta.fields:
- if hasattr(f, 'remote_field') and getattr(f.remote_field, 'model', None) == related_model:
+ if (
+ hasattr(f, "remote_field")
+ and getattr(f.remote_field, "model", None) == related_model
+ ):
possible_field_names.append(f.name)
if possible_field_names:
- hint = "Did you mean one of the following foreign keys to '%s': %s?" % (
- related_model._meta.object_name,
- ', '.join(possible_field_names),
+ hint = (
+ "Did you mean one of the following foreign keys to '%s': %s?"
+ % (
+ related_model._meta.object_name,
+ ", ".join(possible_field_names),
+ )
)
else:
hint = None
@@ -1451,28 +1626,36 @@ class ManyToManyField(RelatedField):
% (qualified_model_name, field_name),
hint=hint,
obj=self,
- id='fields.E338',
+ id="fields.E338",
)
)
else:
- if not (hasattr(field, 'remote_field') and
- getattr(field.remote_field, 'model', None) == related_model):
+ if not (
+ hasattr(field, "remote_field")
+ and getattr(field.remote_field, "model", None)
+ == related_model
+ ):
errors.append(
checks.Error(
- "'%s.%s' is not a foreign key to '%s'." % (
- through._meta.object_name, field_name,
+ "'%s.%s' is not a foreign key to '%s'."
+ % (
+ through._meta.object_name,
+ field_name,
related_model._meta.object_name,
),
hint=hint,
obj=self,
- id='fields.E339',
+ id="fields.E339",
)
)
return errors
def _check_table_uniqueness(self, **kwargs):
- if isinstance(self.remote_field.through, str) or not self.remote_field.through._meta.managed:
+ if (
+ isinstance(self.remote_field.through, str)
+ or not self.remote_field.through._meta.managed
+ ):
return []
registered_tables = {
model._meta.db_table: model
@@ -1483,25 +1666,31 @@ class ManyToManyField(RelatedField):
model = registered_tables.get(m2m_db_table)
# The second condition allows multiple m2m relations on a model if
# some point to a through model that proxies another through model.
- if model and model._meta.concrete_model != self.remote_field.through._meta.concrete_model:
+ if (
+ model
+ and model._meta.concrete_model
+ != self.remote_field.through._meta.concrete_model
+ ):
if model._meta.auto_created:
+
def _get_field_name(model):
for field in model._meta.auto_created._meta.many_to_many:
if field.remote_field.through is model:
return field.name
+
opts = model._meta.auto_created._meta
- clashing_obj = '%s.%s' % (opts.label, _get_field_name(model))
+ clashing_obj = "%s.%s" % (opts.label, _get_field_name(model))
else:
clashing_obj = model._meta.label
if settings.DATABASE_ROUTERS:
- error_class, error_id = checks.Warning, 'fields.W344'
+ error_class, error_id = checks.Warning, "fields.W344"
error_hint = (
- 'You have configured settings.DATABASE_ROUTERS. Verify '
- 'that the table of %r is correctly routed to a separate '
- 'database.' % clashing_obj
+ "You have configured settings.DATABASE_ROUTERS. Verify "
+ "that the table of %r is correctly routed to a separate "
+ "database." % clashing_obj
)
else:
- error_class, error_id = checks.Error, 'fields.E340'
+ error_class, error_id = checks.Error, "fields.E340"
error_hint = None
return [
error_class(
@@ -1518,34 +1707,34 @@ class ManyToManyField(RelatedField):
name, path, args, kwargs = super().deconstruct()
# Handle the simpler arguments.
if self.db_table is not None:
- kwargs['db_table'] = self.db_table
+ kwargs["db_table"] = self.db_table
if self.remote_field.db_constraint is not True:
- kwargs['db_constraint'] = self.remote_field.db_constraint
+ kwargs["db_constraint"] = self.remote_field.db_constraint
# Rel needs more work.
if isinstance(self.remote_field.model, str):
- kwargs['to'] = self.remote_field.model
+ kwargs["to"] = self.remote_field.model
else:
- kwargs['to'] = self.remote_field.model._meta.label
- if getattr(self.remote_field, 'through', None) is not None:
+ kwargs["to"] = self.remote_field.model._meta.label
+ if getattr(self.remote_field, "through", None) is not None:
if isinstance(self.remote_field.through, str):
- kwargs['through'] = self.remote_field.through
+ kwargs["through"] = self.remote_field.through
elif not self.remote_field.through._meta.auto_created:
- kwargs['through'] = self.remote_field.through._meta.label
+ kwargs["through"] = self.remote_field.through._meta.label
# If swappable is True, then see if we're actually pointing to the target
# of a swap.
swappable_setting = self.swappable_setting
if swappable_setting is not None:
# If it's already a settings reference, error.
- if hasattr(kwargs['to'], "setting_name"):
- if kwargs['to'].setting_name != swappable_setting:
+ if hasattr(kwargs["to"], "setting_name"):
+ if kwargs["to"].setting_name != swappable_setting:
raise ValueError(
"Cannot deconstruct a ManyToManyField pointing to a "
"model that is swapped in place of more than one model "
- "(%s and %s)" % (kwargs['to'].setting_name, swappable_setting)
+ "(%s and %s)" % (kwargs["to"].setting_name, swappable_setting)
)
- kwargs['to'] = SettingsReference(
- kwargs['to'],
+ kwargs["to"] = SettingsReference(
+ kwargs["to"],
swappable_setting,
)
return name, path, args, kwargs
@@ -1605,7 +1794,7 @@ class ManyToManyField(RelatedField):
elif self.db_table:
return self.db_table
else:
- m2m_table_name = '%s_%s' % (utils.strip_quotes(opts.db_table), self.name)
+ m2m_table_name = "%s_%s" % (utils.strip_quotes(opts.db_table), self.name)
return utils.truncate_name(m2m_table_name, connection.ops.max_name_length())
def _get_m2m_attr(self, related, attr):
@@ -1613,7 +1802,7 @@ class ManyToManyField(RelatedField):
Function that can be curried to provide the source accessor or DB
column name for the m2m table.
"""
- cache_attr = '_m2m_%s_cache' % attr
+ cache_attr = "_m2m_%s_cache" % attr
if hasattr(self, cache_attr):
return getattr(self, cache_attr)
if self.remote_field.through_fields is not None:
@@ -1621,8 +1810,11 @@ class ManyToManyField(RelatedField):
else:
link_field_name = None
for f in self.remote_field.through._meta.fields:
- if (f.is_relation and f.remote_field.model == related.related_model and
- (link_field_name is None or link_field_name == f.name)):
+ if (
+ f.is_relation
+ and f.remote_field.model == related.related_model
+ and (link_field_name is None or link_field_name == f.name)
+ ):
setattr(self, cache_attr, getattr(f, attr))
return getattr(self, cache_attr)
@@ -1631,7 +1823,7 @@ class ManyToManyField(RelatedField):
Function that can be curried to provide the related accessor or DB
column name for the m2m table.
"""
- cache_attr = '_m2m_reverse_%s_cache' % attr
+ cache_attr = "_m2m_reverse_%s_cache" % attr
if hasattr(self, cache_attr):
return getattr(self, cache_attr)
found = False
@@ -1664,8 +1856,8 @@ class ManyToManyField(RelatedField):
# automatically. The funky name reduces the chance of an accidental
# clash.
if self.remote_field.symmetrical and (
- self.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT or
- self.remote_field.model == cls._meta.object_name
+ self.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT
+ or self.remote_field.model == cls._meta.object_name
):
self.remote_field.related_name = "%s_rel_+" % name
elif self.remote_field.is_hidden():
@@ -1673,7 +1865,7 @@ class ManyToManyField(RelatedField):
# related_name with one generated from the m2m field name. Django
# still uses backwards relations internally and we need to avoid
# clashes between multiple m2m fields with related_name == '+'.
- self.remote_field.related_name = '_%s_%s_%s_+' % (
+ self.remote_field.related_name = "_%s_%s_%s_+" % (
cls._meta.app_label,
cls.__name__.lower(),
name,
@@ -1687,11 +1879,17 @@ class ManyToManyField(RelatedField):
# 3) The class owning the m2m field has been swapped out.
if not cls._meta.abstract:
if self.remote_field.through:
+
def resolve_through_model(_, model, field):
field.remote_field.through = model
- lazy_related_operation(resolve_through_model, cls, self.remote_field.through, field=self)
+
+ lazy_related_operation(
+ resolve_through_model, cls, self.remote_field.through, field=self
+ )
elif not cls._meta.swapped:
- self.remote_field.through = create_many_to_many_intermediary_model(self, cls)
+ self.remote_field.through = create_many_to_many_intermediary_model(
+ self, cls
+ )
# Add the descriptor for the m2m relation.
setattr(cls, self.name, ManyToManyDescriptor(self.remote_field, reverse=False))
@@ -1702,19 +1900,30 @@ class ManyToManyField(RelatedField):
def contribute_to_related_class(self, cls, related):
# Internal M2Ms (i.e., those with a related name ending with '+')
# and swapped models don't get a related descriptor.
- if not self.remote_field.is_hidden() and not related.related_model._meta.swapped:
- setattr(cls, related.get_accessor_name(), ManyToManyDescriptor(self.remote_field, reverse=True))
+ if (
+ not self.remote_field.is_hidden()
+ and not related.related_model._meta.swapped
+ ):
+ setattr(
+ cls,
+ related.get_accessor_name(),
+ ManyToManyDescriptor(self.remote_field, reverse=True),
+ )
# Set up the accessors for the column names on the m2m table.
- self.m2m_column_name = partial(self._get_m2m_attr, related, 'column')
- self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, 'column')
+ self.m2m_column_name = partial(self._get_m2m_attr, related, "column")
+ self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, "column")
- self.m2m_field_name = partial(self._get_m2m_attr, related, 'name')
- self.m2m_reverse_field_name = partial(self._get_m2m_reverse_attr, related, 'name')
+ self.m2m_field_name = partial(self._get_m2m_attr, related, "name")
+ self.m2m_reverse_field_name = partial(
+ self._get_m2m_reverse_attr, related, "name"
+ )
- get_m2m_rel = partial(self._get_m2m_attr, related, 'remote_field')
+ get_m2m_rel = partial(self._get_m2m_attr, related, "remote_field")
self.m2m_target_field_name = lambda: get_m2m_rel().field_name
- get_m2m_reverse_rel = partial(self._get_m2m_reverse_attr, related, 'remote_field')
+ get_m2m_reverse_rel = partial(
+ self._get_m2m_reverse_attr, related, "remote_field"
+ )
self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name
def set_attributes_from_rel(self):
@@ -1728,17 +1937,17 @@ class ManyToManyField(RelatedField):
def formfield(self, *, using=None, **kwargs):
defaults = {
- 'form_class': forms.ModelMultipleChoiceField,
- 'queryset': self.remote_field.model._default_manager.using(using),
+ "form_class": forms.ModelMultipleChoiceField,
+ "queryset": self.remote_field.model._default_manager.using(using),
**kwargs,
}
# If initial is passed in, it's a list of related objects, but the
# MultipleChoiceField takes a list of IDs.
- if defaults.get('initial') is not None:
- initial = defaults['initial']
+ if defaults.get("initial") is not None:
+ initial = defaults["initial"]
if callable(initial):
initial = initial()
- defaults['initial'] = [i.pk for i in initial]
+ defaults["initial"] = [i.pk for i in initial]
return super().formfield(**defaults)
def db_check(self, connection):
diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py
index 9c50ef16ce..3f67ed8166 100644
--- a/django/db/models/fields/related_descriptors.py
+++ b/django/db/models/fields/related_descriptors.py
@@ -74,7 +74,9 @@ from django.utils.functional import cached_property
class ForeignKeyDeferredAttribute(DeferredAttribute):
def __set__(self, instance, value):
- if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(instance):
+ if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(
+ instance
+ ):
self.field.delete_cached_value(instance)
instance.__dict__[self.field.attname] = value
@@ -101,14 +103,16 @@ class ForwardManyToOneDescriptor:
# related model might not be resolved yet; `self.field.model` might
# still be a string model reference.
return type(
- 'RelatedObjectDoesNotExist',
- (self.field.remote_field.model.DoesNotExist, AttributeError), {
- '__module__': self.field.model.__module__,
- '__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
+ "RelatedObjectDoesNotExist",
+ (self.field.remote_field.model.DoesNotExist, AttributeError),
+ {
+ "__module__": self.field.model.__module__,
+ "__qualname__": "%s.%s.RelatedObjectDoesNotExist"
+ % (
self.field.model.__qualname__,
self.field.name,
),
- }
+ },
)
def is_cached(self, instance):
@@ -135,9 +139,12 @@ class ForwardManyToOneDescriptor:
# The check for len(...) == 1 is a special case that allows the query
# to be join-less and smaller. Refs #21760.
if remote_field.is_hidden() or len(self.field.foreign_related_fields) == 1:
- query = {'%s__in' % related_field.name: {instance_attr(inst)[0] for inst in instances}}
+ query = {
+ "%s__in"
+ % related_field.name: {instance_attr(inst)[0] for inst in instances}
+ }
else:
- query = {'%s__in' % self.field.related_query_name(): instances}
+ query = {"%s__in" % self.field.related_query_name(): instances}
queryset = queryset.filter(**query)
# Since we're going to assign directly in the cache,
@@ -146,7 +153,14 @@ class ForwardManyToOneDescriptor:
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
remote_field.set_cached_value(rel_obj, instance)
- return queryset, rel_obj_attr, instance_attr, True, self.field.get_cache_name(), False
+ return (
+ queryset,
+ rel_obj_attr,
+ instance_attr,
+ True,
+ self.field.get_cache_name(),
+ False,
+ )
def get_object(self, instance):
qs = self.get_queryset(instance=instance)
@@ -173,7 +187,11 @@ class ForwardManyToOneDescriptor:
rel_obj = self.field.get_cached_value(instance)
except KeyError:
has_value = None not in self.field.get_local_related_value(instance)
- ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None
+ ancestor_link = (
+ instance._meta.get_ancestor_link(self.field.model)
+ if has_value
+ else None
+ )
if ancestor_link and ancestor_link.is_cached(instance):
# An ancestor link will exist if this field is defined on a
# multi-table inheritance parent of the instance's class.
@@ -211,9 +229,12 @@ class ForwardManyToOneDescriptor:
- ``value`` is the ``parent`` instance on the right of the equal sign
"""
# An object must be an instance of the related class.
- if value is not None and not isinstance(value, self.field.remote_field.model._meta.concrete_model):
+ if value is not None and not isinstance(
+ value, self.field.remote_field.model._meta.concrete_model
+ ):
raise ValueError(
- 'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
+ 'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
+ % (
value,
instance._meta.object_name,
self.field.name,
@@ -222,11 +243,18 @@ class ForwardManyToOneDescriptor:
)
elif value is not None:
if instance._state.db is None:
- instance._state.db = router.db_for_write(instance.__class__, instance=value)
+ instance._state.db = router.db_for_write(
+ instance.__class__, instance=value
+ )
if value._state.db is None:
- value._state.db = router.db_for_write(value.__class__, instance=instance)
+ value._state.db = router.db_for_write(
+ value.__class__, instance=instance
+ )
if not router.allow_relation(value, instance):
- raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
+ raise ValueError(
+ 'Cannot assign "%r": the current database router prevents this relation.'
+ % value
+ )
remote_field = self.field.remote_field
# If we're setting the value of a OneToOneField to None, we need to clear
@@ -314,12 +342,15 @@ class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
opts = instance._meta
# Inherited primary key fields from this object's base classes.
inherited_pk_fields = [
- field for field in opts.concrete_fields
+ field
+ for field in opts.concrete_fields
if field.primary_key and field.remote_field
]
for field in inherited_pk_fields:
rel_model_pk_name = field.remote_field.model._meta.pk.attname
- raw_value = getattr(value, rel_model_pk_name) if value is not None else None
+ raw_value = (
+ getattr(value, rel_model_pk_name) if value is not None else None
+ )
setattr(instance, rel_model_pk_name, raw_value)
@@ -346,13 +377,15 @@ class ReverseOneToOneDescriptor:
# The exception isn't created at initialization time for the sake of
# consistency with `ForwardManyToOneDescriptor`.
return type(
- 'RelatedObjectDoesNotExist',
- (self.related.related_model.DoesNotExist, AttributeError), {
- '__module__': self.related.model.__module__,
- '__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
+ "RelatedObjectDoesNotExist",
+ (self.related.related_model.DoesNotExist, AttributeError),
+ {
+ "__module__": self.related.model.__module__,
+ "__qualname__": "%s.%s.RelatedObjectDoesNotExist"
+ % (
self.related.model.__qualname__,
self.related.name,
- )
+ ),
},
)
@@ -370,7 +403,7 @@ class ReverseOneToOneDescriptor:
rel_obj_attr = self.related.field.get_local_related_value
instance_attr = self.related.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
- query = {'%s__in' % self.related.field.name: instances}
+ query = {"%s__in" % self.related.field.name: instances}
queryset = queryset.filter(**query)
# Since we're going to assign directly in the cache,
@@ -378,7 +411,14 @@ class ReverseOneToOneDescriptor:
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
self.related.field.set_cached_value(rel_obj, instance)
- return queryset, rel_obj_attr, instance_attr, True, self.related.get_cache_name(), False
+ return (
+ queryset,
+ rel_obj_attr,
+ instance_attr,
+ True,
+ self.related.get_cache_name(),
+ False,
+ )
def __get__(self, instance, cls=None):
"""
@@ -419,10 +459,8 @@ class ReverseOneToOneDescriptor:
if rel_obj is None:
raise self.RelatedObjectDoesNotExist(
- "%s has no %s." % (
- instance.__class__.__name__,
- self.related.get_accessor_name()
- )
+ "%s has no %s."
+ % (instance.__class__.__name__, self.related.get_accessor_name())
)
else:
return rel_obj
@@ -458,7 +496,8 @@ class ReverseOneToOneDescriptor:
elif not isinstance(value, self.related.related_model):
# An object must be an instance of the related class.
raise ValueError(
- 'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
+ 'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
+ % (
value,
instance._meta.object_name,
self.related.get_accessor_name(),
@@ -467,13 +506,23 @@ class ReverseOneToOneDescriptor:
)
else:
if instance._state.db is None:
- instance._state.db = router.db_for_write(instance.__class__, instance=value)
+ instance._state.db = router.db_for_write(
+ instance.__class__, instance=value
+ )
if value._state.db is None:
- value._state.db = router.db_for_write(value.__class__, instance=instance)
+ value._state.db = router.db_for_write(
+ value.__class__, instance=instance
+ )
if not router.allow_relation(value, instance):
- raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
+ raise ValueError(
+ 'Cannot assign "%r": the current database router prevents this relation.'
+ % value
+ )
- related_pk = tuple(getattr(instance, field.attname) for field in self.related.field.foreign_related_fields)
+ related_pk = tuple(
+ getattr(instance, field.attname)
+ for field in self.related.field.foreign_related_fields
+ )
# Set the value of the related field to the value of the related object's related field
for index, field in enumerate(self.related.field.local_related_fields):
setattr(value, field.attname, related_pk[index])
@@ -548,13 +597,13 @@ class ReverseManyToOneDescriptor:
def _get_set_deprecation_msg_params(self):
return (
- 'reverse side of a related set',
+ "reverse side of a related set",
self.rel.get_accessor_name(),
)
def __set__(self, instance, value):
raise TypeError(
- 'Direct assignment to the %s is prohibited. Use %s.set() instead.'
+ "Direct assignment to the %s is prohibited. Use %s.set() instead."
% self._get_set_deprecation_msg_params(),
)
@@ -581,6 +630,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
manager = getattr(self.model, manager)
manager_class = create_reverse_many_to_one_manager(manager.__class__, rel)
return manager_class(self.instance)
+
do_not_call_in_templates = True
def _apply_rel_filters(self, queryset):
@@ -588,7 +638,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
Filter the queryset for the instance this manager is bound to.
"""
db = self._db or router.db_for_read(self.model, instance=self.instance)
- empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
+ empty_strings_as_null = connections[
+ db
+ ].features.interprets_empty_strings_as_nulls
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
@@ -596,7 +648,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
queryset = queryset.filter(**self.core_filters)
for field in self.field.foreign_related_fields:
val = getattr(self.instance, field.attname)
- if val is None or (val == '' and empty_strings_as_null):
+ if val is None or (val == "" and empty_strings_as_null):
return queryset.none()
if self.field.many_to_one:
# Guard against field-like objects such as GenericRelation
@@ -608,24 +660,32 @@ def create_reverse_many_to_one_manager(superclass, rel):
except FieldError:
# The relationship has multiple target fields. Use a tuple
# for related object id.
- rel_obj_id = tuple([
- getattr(self.instance, target_field.attname)
- for target_field in self.field.path_infos[-1].target_fields
- ])
+ rel_obj_id = tuple(
+ [
+ getattr(self.instance, target_field.attname)
+ for target_field in self.field.path_infos[-1].target_fields
+ ]
+ )
else:
rel_obj_id = getattr(self.instance, target_field.attname)
- queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
+ queryset._known_related_objects = {
+ self.field: {rel_obj_id: self.instance}
+ }
return queryset
def _remove_prefetched_objects(self):
try:
- self.instance._prefetched_objects_cache.pop(self.field.remote_field.get_cache_name())
+ self.instance._prefetched_objects_cache.pop(
+ self.field.remote_field.get_cache_name()
+ )
except (AttributeError, KeyError):
pass # nothing to clear from cache
def get_queryset(self):
try:
- return self.instance._prefetched_objects_cache[self.field.remote_field.get_cache_name()]
+ return self.instance._prefetched_objects_cache[
+ self.field.remote_field.get_cache_name()
+ ]
except (AttributeError, KeyError):
queryset = super().get_queryset()
return self._apply_rel_filters(queryset)
@@ -640,7 +700,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
rel_obj_attr = self.field.get_local_related_value
instance_attr = self.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
- query = {'%s__in' % self.field.name: instances}
+ query = {"%s__in" % self.field.name: instances}
queryset = queryset.filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage
@@ -658,9 +718,13 @@ def create_reverse_many_to_one_manager(superclass, rel):
def check_and_update_obj(obj):
if not isinstance(obj, self.model):
- raise TypeError("'%s' instance expected, got %r" % (
- self.model._meta.object_name, obj,
- ))
+ raise TypeError(
+ "'%s' instance expected, got %r"
+ % (
+ self.model._meta.object_name,
+ obj,
+ )
+ )
setattr(obj, self.field.name, self.instance)
if bulk:
@@ -673,36 +737,43 @@ def create_reverse_many_to_one_manager(superclass, rel):
"the object first." % obj
)
pks.append(obj.pk)
- self.model._base_manager.using(db).filter(pk__in=pks).update(**{
- self.field.name: self.instance,
- })
+ self.model._base_manager.using(db).filter(pk__in=pks).update(
+ **{
+ self.field.name: self.instance,
+ }
+ )
else:
with transaction.atomic(using=db, savepoint=False):
for obj in objs:
check_and_update_obj(obj)
obj.save()
+
add.alters_data = True
def create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
+
create.alters_data = True
def get_or_create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
+
get_or_create.alters_data = True
def update_or_create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
+
update_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel.field.null:
+
def remove(self, *objs, bulk=True):
if not objs:
return
@@ -710,9 +781,13 @@ def create_reverse_many_to_one_manager(superclass, rel):
old_ids = set()
for obj in objs:
if not isinstance(obj, self.model):
- raise TypeError("'%s' instance expected, got %r" % (
- self.model._meta.object_name, obj,
- ))
+ raise TypeError(
+ "'%s' instance expected, got %r"
+ % (
+ self.model._meta.object_name,
+ obj,
+ )
+ )
# Is obj actually part of this descriptor set?
if self.field.get_local_related_value(obj) == val:
old_ids.add(obj.pk)
@@ -721,10 +796,12 @@ def create_reverse_many_to_one_manager(superclass, rel):
"%r is not related to %r." % (obj, self.instance)
)
self._clear(self.filter(pk__in=old_ids), bulk)
+
remove.alters_data = True
def clear(self, *, bulk=True):
self._clear(self, bulk)
+
clear.alters_data = True
def _clear(self, queryset, bulk):
@@ -739,6 +816,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
for obj in queryset:
setattr(obj, self.field.name, None)
obj.save(update_fields=[self.field.name])
+
_clear.alters_data = True
def set(self, objs, *, bulk=True, clear=False):
@@ -765,6 +843,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
self.add(*new_objs, bulk=bulk)
else:
self.add(*objs, bulk=bulk)
+
set.alters_data = True
return RelatedManager
@@ -822,7 +901,8 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor):
def _get_set_deprecation_msg_params(self):
return (
- '%s side of a many-to-many set' % ('reverse' if self.reverse else 'forward'),
+ "%s side of a many-to-many set"
+ % ("reverse" if self.reverse else "forward"),
self.rel.get_accessor_name() if self.reverse else self.field.name,
)
@@ -865,41 +945,51 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.core_filters = {}
self.pk_field_names = {}
for lh_field, rh_field in self.source_field.related_fields:
- core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name)
+ core_filter_key = "%s__%s" % (self.query_field_name, rh_field.name)
self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
self.pk_field_names[lh_field.name] = rh_field.name
self.related_val = self.source_field.get_foreign_related_value(instance)
if None in self.related_val:
- raise ValueError('"%r" needs to have a value for field "%s" before '
- 'this many-to-many relationship can be used.' %
- (instance, self.pk_field_names[self.source_field_name]))
+ raise ValueError(
+ '"%r" needs to have a value for field "%s" before '
+ "this many-to-many relationship can be used."
+ % (instance, self.pk_field_names[self.source_field_name])
+ )
# Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that.
if instance.pk is None:
- raise ValueError("%r instance needs to have a primary key value before "
- "a many-to-many relationship can be used." %
- instance.__class__.__name__)
+ raise ValueError(
+ "%r instance needs to have a primary key value before "
+ "a many-to-many relationship can be used."
+ % instance.__class__.__name__
+ )
def __call__(self, *, manager):
manager = getattr(self.model, manager)
- manager_class = create_forward_many_to_many_manager(manager.__class__, rel, reverse)
+ manager_class = create_forward_many_to_many_manager(
+ manager.__class__, rel, reverse
+ )
return manager_class(instance=self.instance)
+
do_not_call_in_templates = True
def _build_remove_filters(self, removed_vals):
filters = Q((self.source_field_name, self.related_val))
# No need to add a subquery condition if removed_vals is a QuerySet without
# filters.
- removed_vals_filters = (not isinstance(removed_vals, QuerySet) or
- removed_vals._has_filters())
+ removed_vals_filters = (
+ not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
+ )
if removed_vals_filters:
- filters &= Q((f'{self.target_field_name}__in', removed_vals))
+ filters &= Q((f"{self.target_field_name}__in", removed_vals))
if self.symmetrical:
symmetrical_filters = Q((self.target_field_name, self.related_val))
if removed_vals_filters:
- symmetrical_filters &= Q((f'{self.source_field_name}__in', removed_vals))
+ symmetrical_filters &= Q(
+ (f"{self.source_field_name}__in", removed_vals)
+ )
filters |= symmetrical_filters
return filters
@@ -933,7 +1023,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
- query = {'%s__in' % self.query_field_name: instances}
+ query = {"%s__in" % self.query_field_name: instances}
queryset = queryset._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
@@ -947,13 +1037,18 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
join_table = fk.model._meta.db_table
connection = connections[queryset.db]
qn = connection.ops.quote_name
- queryset = queryset.extra(select={
- '_prefetch_related_val_%s' % f.attname:
- '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
+ queryset = queryset.extra(
+ select={
+ "_prefetch_related_val_%s"
+ % f.attname: "%s.%s"
+ % (qn(join_table), qn(f.column))
+ for f in fk.local_related_fields
+ }
+ )
return (
queryset,
lambda result: tuple(
- getattr(result, '_prefetch_related_val_%s' % f.attname)
+ getattr(result, "_prefetch_related_val_%s" % f.attname)
for f in fk.local_related_fields
),
lambda inst: tuple(
@@ -970,7 +1065,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
self._add_items(
- self.source_field_name, self.target_field_name, *objs,
+ self.source_field_name,
+ self.target_field_name,
+ *objs,
through_defaults=through_defaults,
)
# If this is a symmetrical m2m relation to self, add the mirror
@@ -982,30 +1079,41 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
*objs,
through_defaults=through_defaults,
)
+
add.alters_data = True
def remove(self, *objs):
self._remove_prefetched_objects()
self._remove_items(self.source_field_name, self.target_field_name, *objs)
+
remove.alters_data = True
def clear(self):
db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
signals.m2m_changed.send(
- sender=self.through, action="pre_clear",
- instance=self.instance, reverse=self.reverse,
- model=self.model, pk_set=None, using=db,
+ sender=self.through,
+ action="pre_clear",
+ instance=self.instance,
+ reverse=self.reverse,
+ model=self.model,
+ pk_set=None,
+ using=db,
)
self._remove_prefetched_objects()
filters = self._build_remove_filters(super().get_queryset().using(db))
self.through._default_manager.using(db).filter(filters).delete()
signals.m2m_changed.send(
- sender=self.through, action="post_clear",
- instance=self.instance, reverse=self.reverse,
- model=self.model, pk_set=None, using=db,
+ sender=self.through,
+ action="post_clear",
+ instance=self.instance,
+ reverse=self.reverse,
+ model=self.model,
+ pk_set=None,
+ using=db,
)
+
clear.alters_data = True
def set(self, objs, *, clear=False, through_defaults=None):
@@ -1019,7 +1127,11 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.clear()
self.add(*objs, through_defaults=through_defaults)
else:
- old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True))
+ old_ids = set(
+ self.using(db).values_list(
+ self.target_field.target_field.attname, flat=True
+ )
+ )
new_objs = []
for obj in objs:
@@ -1035,6 +1147,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.remove(*old_ids)
self.add(*new_objs, through_defaults=through_defaults)
+
set.alters_data = True
def create(self, *, through_defaults=None, **kwargs):
@@ -1042,26 +1155,33 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
self.add(new_obj, through_defaults=through_defaults)
return new_obj
+
create.alters_data = True
def get_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
- obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
+ obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
+ **kwargs
+ )
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj, through_defaults=through_defaults)
return obj, created
+
get_or_create.alters_data = True
def update_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
- obj, created = super(ManyRelatedManager, self.db_manager(db)).update_or_create(**kwargs)
+ obj, created = super(
+ ManyRelatedManager, self.db_manager(db)
+ ).update_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj, through_defaults=through_defaults)
return obj, created
+
update_or_create.alters_data = True
def _get_target_ids(self, target_field_name, objs):
@@ -1069,6 +1189,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
Return the set of ids of `objs` that the target field references.
"""
from django.db.models import Model
+
target_ids = set()
target_field = self.through._meta.get_field(target_field_name)
for obj in objs:
@@ -1076,36 +1197,42 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
if not router.allow_relation(obj, self.instance):
raise ValueError(
'Cannot add "%r": instance is on database "%s", '
- 'value is on database "%s"' %
- (obj, self.instance._state.db, obj._state.db)
+ 'value is on database "%s"'
+ % (obj, self.instance._state.db, obj._state.db)
)
target_id = target_field.get_foreign_related_value(obj)[0]
if target_id is None:
raise ValueError(
- 'Cannot add "%r": the value for field "%s" is None' %
- (obj, target_field_name)
+ 'Cannot add "%r": the value for field "%s" is None'
+ % (obj, target_field_name)
)
target_ids.add(target_id)
elif isinstance(obj, Model):
raise TypeError(
- "'%s' instance expected, got %r" %
- (self.model._meta.object_name, obj)
+ "'%s' instance expected, got %r"
+ % (self.model._meta.object_name, obj)
)
else:
target_ids.add(target_field.get_prep_value(obj))
return target_ids
- def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids):
+ def _get_missing_target_ids(
+ self, source_field_name, target_field_name, db, target_ids
+ ):
"""
Return the subset of ids of `objs` that aren't already assigned to
this relationship.
"""
- vals = self.through._default_manager.using(db).values_list(
- target_field_name, flat=True
- ).filter(**{
- source_field_name: self.related_val[0],
- '%s__in' % target_field_name: target_ids,
- })
+ vals = (
+ self.through._default_manager.using(db)
+ .values_list(target_field_name, flat=True)
+ .filter(
+ **{
+ source_field_name: self.related_val[0],
+ "%s__in" % target_field_name: target_ids,
+ }
+ )
+ )
return target_ids.difference(vals)
def _get_add_plan(self, db, source_field_name):
@@ -1123,21 +1250,27 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# user-defined intermediary models as they could have other fields
# causing conflicts which must be surfaced.
can_ignore_conflicts = (
- self.through._meta.auto_created is not False and
- connections[db].features.supports_ignore_conflicts
+ self.through._meta.auto_created is not False
+ and connections[db].features.supports_ignore_conflicts
)
# Don't send the signal when inserting duplicate data row
# for symmetrical reverse entries.
- must_send_signals = (self.reverse or source_field_name == self.source_field_name) and (
- signals.m2m_changed.has_listeners(self.through)
- )
+ must_send_signals = (
+ self.reverse or source_field_name == self.source_field_name
+ ) and (signals.m2m_changed.has_listeners(self.through))
# Fast addition through bulk insertion can only be performed
# if no m2m_changed listeners are connected for self.through
# as they require the added set of ids to be provided via
# pk_set.
- return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals)
+ return (
+ can_ignore_conflicts,
+ must_send_signals,
+ (can_ignore_conflicts and not must_send_signals),
+ )
- def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
+ def _add_items(
+ self, source_field_name, target_field_name, *objs, through_defaults=None
+ ):
# source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object
# *objs - objects to add. Either object instances, or primary keys of object instances.
@@ -1147,15 +1280,22 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
through_defaults = dict(resolve_callables(through_defaults or {}))
target_ids = self._get_target_ids(target_field_name, objs)
db = router.db_for_write(self.through, instance=self.instance)
- can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
+ can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(
+ db, source_field_name
+ )
if can_fast_add:
- self.through._default_manager.using(db).bulk_create([
- self.through(**{
- '%s_id' % source_field_name: self.related_val[0],
- '%s_id' % target_field_name: target_id,
- })
- for target_id in target_ids
- ], ignore_conflicts=True)
+ self.through._default_manager.using(db).bulk_create(
+ [
+ self.through(
+ **{
+ "%s_id" % source_field_name: self.related_val[0],
+ "%s_id" % target_field_name: target_id,
+ }
+ )
+ for target_id in target_ids
+ ],
+ ignore_conflicts=True,
+ )
return
missing_target_ids = self._get_missing_target_ids(
@@ -1164,24 +1304,38 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
with transaction.atomic(using=db, savepoint=False):
if must_send_signals:
signals.m2m_changed.send(
- sender=self.through, action='pre_add',
- instance=self.instance, reverse=self.reverse,
- model=self.model, pk_set=missing_target_ids, using=db,
+ sender=self.through,
+ action="pre_add",
+ instance=self.instance,
+ reverse=self.reverse,
+ model=self.model,
+ pk_set=missing_target_ids,
+ using=db,
)
# Add the ones that aren't there already.
- self.through._default_manager.using(db).bulk_create([
- self.through(**through_defaults, **{
- '%s_id' % source_field_name: self.related_val[0],
- '%s_id' % target_field_name: target_id,
- })
- for target_id in missing_target_ids
- ], ignore_conflicts=can_ignore_conflicts)
+ self.through._default_manager.using(db).bulk_create(
+ [
+ self.through(
+ **through_defaults,
+ **{
+ "%s_id" % source_field_name: self.related_val[0],
+ "%s_id" % target_field_name: target_id,
+ },
+ )
+ for target_id in missing_target_ids
+ ],
+ ignore_conflicts=can_ignore_conflicts,
+ )
if must_send_signals:
signals.m2m_changed.send(
- sender=self.through, action='post_add',
- instance=self.instance, reverse=self.reverse,
- model=self.model, pk_set=missing_target_ids, using=db,
+ sender=self.through,
+ action="post_add",
+ instance=self.instance,
+ reverse=self.reverse,
+ model=self.model,
+ pk_set=missing_target_ids,
+ using=db,
)
def _remove_items(self, source_field_name, target_field_name, *objs):
@@ -1205,23 +1359,32 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
with transaction.atomic(using=db, savepoint=False):
# Send a signal to the other end if need be.
signals.m2m_changed.send(
- sender=self.through, action="pre_remove",
- instance=self.instance, reverse=self.reverse,
- model=self.model, pk_set=old_ids, using=db,
+ sender=self.through,
+ action="pre_remove",
+ instance=self.instance,
+ reverse=self.reverse,
+ model=self.model,
+ pk_set=old_ids,
+ using=db,
)
target_model_qs = super().get_queryset()
if target_model_qs._has_filters():
- old_vals = target_model_qs.using(db).filter(**{
- '%s__in' % self.target_field.target_field.attname: old_ids})
+ old_vals = target_model_qs.using(db).filter(
+ **{"%s__in" % self.target_field.target_field.attname: old_ids}
+ )
else:
old_vals = old_ids
filters = self._build_remove_filters(old_vals)
self.through._default_manager.using(db).filter(filters).delete()
signals.m2m_changed.send(
- sender=self.through, action="post_remove",
- instance=self.instance, reverse=self.reverse,
- model=self.model, pk_set=old_ids, using=db,
+ sender=self.through,
+ action="post_remove",
+ instance=self.instance,
+ reverse=self.reverse,
+ model=self.model,
+ pk_set=old_ids,
+ using=db,
)
return ManyRelatedManager
diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py
index fd97757b14..1bad1cf416 100644
--- a/django/db/models/fields/related_lookups.py
+++ b/django/db/models/fields/related_lookups.py
@@ -1,5 +1,10 @@
from django.db.models.lookups import (
- Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
+ Exact,
+ GreaterThan,
+ GreaterThanOrEqual,
+ In,
+ IsNull,
+ LessThan,
LessThanOrEqual,
)
@@ -8,16 +13,21 @@ class MultiColSource:
contains_aggregate = False
def __init__(self, alias, targets, sources, field):
- self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
+ self.targets, self.sources, self.field, self.alias = (
+ targets,
+ sources,
+ field,
+ alias,
+ )
self.output_field = self.field
def __repr__(self):
- return "{}({}, {})".format(
- self.__class__.__name__, self.alias, self.field)
+ return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
- return self.__class__(relabels.get(self.alias, self.alias),
- self.targets, self.sources, self.field)
+ return self.__class__(
+ relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
+ )
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
@@ -28,12 +38,15 @@ class MultiColSource:
def get_normalized_value(value, lhs):
from django.db.models import Model
+
if isinstance(value, Model):
value_list = []
sources = lhs.output_field.path_infos[-1].target_fields
for source in sources:
while not isinstance(value, source.model) and source.remote_field:
- source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
+ source = source.remote_field.model._meta.get_field(
+ source.remote_field.field_name
+ )
try:
value_list.append(getattr(value, source.attname))
except AttributeError:
@@ -56,20 +69,21 @@ class RelatedIn(In):
# case ForeignKey to IntegerField given value 'abc'. The
# ForeignKey itself doesn't have validation for non-integers,
# so we must run validation using the target field.
- if hasattr(self.lhs.output_field, 'path_infos'):
+ if hasattr(self.lhs.output_field, "path_infos"):
# Run the target field's get_prep_value. We can safely
# assume there is only one as we don't get to the direct
# value branch otherwise.
- target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
+ target_field = self.lhs.output_field.path_infos[-1].target_fields[
+ -1
+ ]
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
- elif (
- not getattr(self.rhs, 'has_select_fields', True) and
- not getattr(self.lhs.field.target_field, 'primary_key', False)
+ elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
+ self.lhs.field.target_field, "primary_key", False
):
self.rhs.clear_select_clause()
if (
- getattr(self.lhs.output_field, 'primary_key', False) and
- self.lhs.output_field.model == self.rhs.model
+ getattr(self.lhs.output_field, "primary_key", False)
+ and self.lhs.output_field.model == self.rhs.model
):
# A case like
# Restaurant.objects.filter(place__in=restaurant_qs), where
@@ -87,7 +101,10 @@ class RelatedIn(In):
# This clause is either a SubqueryConstraint (for values that need to be compiled to
# SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import (
- AND, OR, SubqueryConstraint, WhereNode,
+ AND,
+ OR,
+ SubqueryConstraint,
+ WhereNode,
)
root_constraint = WhereNode(connector=OR)
@@ -95,31 +112,41 @@ class RelatedIn(In):
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
- for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
- lookup_class = target.get_lookup('exact')
- lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
+ for source, target, val in zip(
+ self.lhs.sources, self.lhs.targets, value
+ ):
+ lookup_class = target.get_lookup("exact")
+ lookup = lookup_class(
+ target.get_col(self.lhs.alias, source), val
+ )
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
root_constraint.add(
SubqueryConstraint(
- self.lhs.alias, [target.column for target in self.lhs.targets],
- [source.name for source in self.lhs.sources], self.rhs),
- AND)
+ self.lhs.alias,
+ [target.column for target in self.lhs.targets],
+ [source.name for source in self.lhs.sources],
+ self.rhs,
+ ),
+ AND,
+ )
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
class RelatedLookupMixin:
def get_prep_lookup(self):
- if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
+ if not isinstance(self.lhs, MultiColSource) and not hasattr(
+ self.rhs, "resolve_expression"
+ ):
# If we get here, we are dealing with single-column relations.
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
# We need to run the related field's get_prep_value(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
- if self.prepare_rhs and hasattr(self.lhs.output_field, 'path_infos'):
+ if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
# Get the target field. We can safely assume there is only one
# as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
@@ -132,11 +159,15 @@ class RelatedLookupMixin:
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import AND, WhereNode
+
root_constraint = WhereNode()
- for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
+ for target, source, val in zip(
+ self.lhs.targets, self.lhs.sources, self.rhs
+ ):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
- lookup_class(target.get_col(self.lhs.alias, source), val), AND)
+ lookup_class(target.get_col(self.lhs.alias, source), val), AND
+ )
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py
index 6f0c788bbd..2ff66f34d0 100644
--- a/django/db/models/fields/reverse_related.py
+++ b/django/db/models/fields/reverse_related.py
@@ -36,8 +36,16 @@ class ForeignObjectRel(FieldCacheMixin):
null = True
empty_strings_allowed = False
- def __init__(self, field, to, related_name=None, related_query_name=None,
- limit_choices_to=None, parent_link=False, on_delete=None):
+ def __init__(
+ self,
+ field,
+ to,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ parent_link=False,
+ on_delete=None,
+ ):
self.field = field
self.model = to
self.related_name = related_name
@@ -73,14 +81,17 @@ class ForeignObjectRel(FieldCacheMixin):
"""
target_fields = self.path_infos[-1].target_fields
if len(target_fields) > 1:
- raise exceptions.FieldError("Can't use target_field for multicolumn relations.")
+ raise exceptions.FieldError(
+ "Can't use target_field for multicolumn relations."
+ )
return target_fields[0]
@cached_property
def related_model(self):
if not self.field.model:
raise AttributeError(
- "This property can't be accessed before self.field.contribute_to_class has been called.")
+ "This property can't be accessed before self.field.contribute_to_class has been called."
+ )
return self.field.model
@cached_property
@@ -110,7 +121,7 @@ class ForeignObjectRel(FieldCacheMixin):
return self.field.db_type
def __repr__(self):
- return '<%s: %s.%s>' % (
+ return "<%s: %s.%s>" % (
type(self).__name__,
self.related_model._meta.app_label,
self.related_model._meta.model_name,
@@ -147,12 +158,15 @@ class ForeignObjectRel(FieldCacheMixin):
# created and doesn't exist in the .models module.
# This is a reverse relation, so there is no reverse_path_infos to
# delete.
- state.pop('path_infos', None)
+ state.pop("path_infos", None)
return state
def get_choices(
- self, include_blank=True, blank_choice=BLANK_CHOICE_DASH,
- limit_choices_to=None, ordering=(),
+ self,
+ include_blank=True,
+ blank_choice=BLANK_CHOICE_DASH,
+ limit_choices_to=None,
+ ordering=(),
):
"""
Return choices with a default blank choices included, for use
@@ -165,13 +179,11 @@ class ForeignObjectRel(FieldCacheMixin):
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
if ordering:
qs = qs.order_by(*ordering)
- return (blank_choice if include_blank else []) + [
- (x.pk, str(x)) for x in qs
- ]
+ return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
def is_hidden(self):
"""Should the related object be hidden?"""
- return bool(self.related_name) and self.related_name[-1] == '+'
+ return bool(self.related_name) and self.related_name[-1] == "+"
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
@@ -204,7 +216,7 @@ class ForeignObjectRel(FieldCacheMixin):
return None
if self.related_name:
return self.related_name
- return opts.model_name + ('_set' if self.multiple else '')
+ return opts.model_name + ("_set" if self.multiple else "")
def get_path_info(self, filtered_relation=None):
if filtered_relation:
@@ -239,10 +251,20 @@ class ManyToOneRel(ForeignObjectRel):
reverse relations into actual fields.
"""
- def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
- limit_choices_to=None, parent_link=False, on_delete=None):
+ def __init__(
+ self,
+ field,
+ to,
+ field_name,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ parent_link=False,
+ on_delete=None,
+ ):
super().__init__(
- field, to,
+ field,
+ to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -254,7 +276,7 @@ class ManyToOneRel(ForeignObjectRel):
def __getstate__(self):
state = super().__getstate__()
- state.pop('related_model', None)
+ state.pop("related_model", None)
return state
@property
@@ -267,7 +289,9 @@ class ManyToOneRel(ForeignObjectRel):
"""
field = self.model._meta.get_field(self.field_name)
if not field.concrete:
- raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name)
+ raise exceptions.FieldDoesNotExist(
+ "No related field named '%s'" % self.field_name
+ )
return field
def set_field_name(self):
@@ -282,10 +306,21 @@ class OneToOneRel(ManyToOneRel):
flags for the reverse relation.
"""
- def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
- limit_choices_to=None, parent_link=False, on_delete=None):
+ def __init__(
+ self,
+ field,
+ to,
+ field_name,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ parent_link=False,
+ on_delete=None,
+ ):
super().__init__(
- field, to, field_name,
+ field,
+ to,
+ field_name,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -304,11 +339,21 @@ class ManyToManyRel(ForeignObjectRel):
flags for the reverse relation.
"""
- def __init__(self, field, to, related_name=None, related_query_name=None,
- limit_choices_to=None, symmetrical=True, through=None,
- through_fields=None, db_constraint=True):
+ def __init__(
+ self,
+ field,
+ to,
+ related_name=None,
+ related_query_name=None,
+ limit_choices_to=None,
+ symmetrical=True,
+ through=None,
+ through_fields=None,
+ db_constraint=True,
+ ):
super().__init__(
- field, to,
+ field,
+ to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -343,7 +388,7 @@ class ManyToManyRel(ForeignObjectRel):
field = opts.get_field(self.through_fields[0])
else:
for field in opts.fields:
- rel = getattr(field, 'remote_field', None)
+ rel = getattr(field, "remote_field", None)
if rel and rel.model == self.model:
break
return field.foreign_related_fields[0]
diff --git a/django/db/models/functions/__init__.py b/django/db/models/functions/__init__.py
index d687af135d..cd7c801894 100644
--- a/django/db/models/functions/__init__.py
+++ b/django/db/models/functions/__init__.py
@@ -1,46 +1,190 @@
-from .comparison import (
- Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf,
-)
+from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
from .datetime import (
- Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear,
- ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek,
- ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour,
- TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek,
+ Extract,
+ ExtractDay,
+ ExtractHour,
+ ExtractIsoWeekDay,
+ ExtractIsoYear,
+ ExtractMinute,
+ ExtractMonth,
+ ExtractQuarter,
+ ExtractSecond,
+ ExtractWeek,
+ ExtractWeekDay,
+ ExtractYear,
+ Now,
+ Trunc,
+ TruncDate,
+ TruncDay,
+ TruncHour,
+ TruncMinute,
+ TruncMonth,
+ TruncQuarter,
+ TruncSecond,
+ TruncTime,
+ TruncWeek,
TruncYear,
)
from .math import (
- Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, Exp, Floor, Ln, Log,
- Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan,
+ Abs,
+ ACos,
+ ASin,
+ ATan,
+ ATan2,
+ Ceil,
+ Cos,
+ Cot,
+ Degrees,
+ Exp,
+ Floor,
+ Ln,
+ Log,
+ Mod,
+ Pi,
+ Power,
+ Radians,
+ Random,
+ Round,
+ Sign,
+ Sin,
+ Sqrt,
+ Tan,
)
from .text import (
- MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
- Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
- RTrim, StrIndex, Substr, Trim, Upper,
+ MD5,
+ SHA1,
+ SHA224,
+ SHA256,
+ SHA384,
+ SHA512,
+ Chr,
+ Concat,
+ ConcatPair,
+ Left,
+ Length,
+ Lower,
+ LPad,
+ LTrim,
+ Ord,
+ Repeat,
+ Replace,
+ Reverse,
+ Right,
+ RPad,
+ RTrim,
+ StrIndex,
+ Substr,
+ Trim,
+ Upper,
)
from .window import (
- CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
- PercentRank, Rank, RowNumber,
+ CumeDist,
+ DenseRank,
+ FirstValue,
+ Lag,
+ LastValue,
+ Lead,
+ NthValue,
+ Ntile,
+ PercentRank,
+ Rank,
+ RowNumber,
)
__all__ = [
# comparison and conversion
- 'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf',
+ "Cast",
+ "Coalesce",
+ "Collate",
+ "Greatest",
+ "JSONObject",
+ "Least",
+ "NullIf",
# datetime
- 'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
- 'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay',
- 'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc',
- 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
- 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear',
+ "Extract",
+ "ExtractDay",
+ "ExtractHour",
+ "ExtractMinute",
+ "ExtractMonth",
+ "ExtractQuarter",
+ "ExtractSecond",
+ "ExtractWeek",
+ "ExtractIsoWeekDay",
+ "ExtractWeekDay",
+ "ExtractIsoYear",
+ "ExtractYear",
+ "Now",
+ "Trunc",
+ "TruncDate",
+ "TruncDay",
+ "TruncHour",
+ "TruncMinute",
+ "TruncMonth",
+ "TruncQuarter",
+ "TruncSecond",
+ "TruncTime",
+ "TruncWeek",
+ "TruncYear",
# math
- 'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 'Ceil', 'Cos', 'Cot', 'Degrees',
- 'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random',
- 'Round', 'Sign', 'Sin', 'Sqrt', 'Tan',
+ "Abs",
+ "ACos",
+ "ASin",
+ "ATan",
+ "ATan2",
+ "Ceil",
+ "Cos",
+ "Cot",
+ "Degrees",
+ "Exp",
+ "Floor",
+ "Ln",
+ "Log",
+ "Mod",
+ "Pi",
+ "Power",
+ "Radians",
+ "Random",
+ "Round",
+ "Sign",
+ "Sin",
+ "Sqrt",
+ "Tan",
# text
- 'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
- 'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
- 'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
- 'Trim', 'Upper',
+ "MD5",
+ "SHA1",
+ "SHA224",
+ "SHA256",
+ "SHA384",
+ "SHA512",
+ "Chr",
+ "Concat",
+ "ConcatPair",
+ "Left",
+ "Length",
+ "Lower",
+ "LPad",
+ "LTrim",
+ "Ord",
+ "Repeat",
+ "Replace",
+ "Reverse",
+ "Right",
+ "RPad",
+ "RTrim",
+ "StrIndex",
+ "Substr",
+ "Trim",
+ "Upper",
# window
- 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
- 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
+ "CumeDist",
+ "DenseRank",
+ "FirstValue",
+ "Lag",
+ "LastValue",
+ "Lead",
+ "NthValue",
+ "Ntile",
+ "PercentRank",
+ "Rank",
+ "RowNumber",
]
diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py
index e5882de9c2..cc78834f20 100644
--- a/django/db/models/functions/comparison.py
+++ b/django/db/models/functions/comparison.py
@@ -7,38 +7,43 @@ from django.utils.regex_helper import _lazy_re_compile
class Cast(Func):
"""Coerce an expression to a new field type."""
- function = 'CAST'
- template = '%(function)s(%(expressions)s AS %(db_type)s)'
+
+ function = "CAST"
+ template = "%(function)s(%(expressions)s AS %(db_type)s)"
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
- extra_context['db_type'] = self.output_field.cast_db_type(connection)
+ extra_context["db_type"] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
db_type = self.output_field.db_type(connection)
- if db_type in {'datetime', 'time'}:
+ if db_type in {"datetime", "time"}:
# Use strftime as datetime/time don't keep fractional seconds.
- template = 'strftime(%%s, %(expressions)s)'
- sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
- format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
+ template = "strftime(%%s, %(expressions)s)"
+ sql, params = super().as_sql(
+ compiler, connection, template=template, **extra_context
+ )
+ format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
params.insert(0, format_string)
return sql, params
- elif db_type == 'date':
- template = 'date(%(expressions)s)'
- return super().as_sql(compiler, connection, template=template, **extra_context)
+ elif db_type == "date":
+ template = "date(%(expressions)s)"
+ return super().as_sql(
+ compiler, connection, template=template, **extra_context
+ )
return self.as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
template = None
output_type = self.output_field.get_internal_type()
# MySQL doesn't support explicit cast to float.
- if output_type == 'FloatField':
- template = '(%(expressions)s + 0.0)'
+ if output_type == "FloatField":
+ template = "(%(expressions)s + 0.0)"
# MariaDB doesn't support explicit cast to JSON.
- elif output_type == 'JSONField' and connection.mysql_is_mariadb:
+ elif output_type == "JSONField" and connection.mysql_is_mariadb:
template = "JSON_EXTRACT(%(expressions)s, '$')"
return self.as_sql(compiler, connection, template=template, **extra_context)
@@ -46,23 +51,31 @@ class Cast(Func):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
- return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
+ return self.as_sql(
+ compiler,
+ connection,
+ template="(%(expressions)s)::%(db_type)s",
+ **extra_context,
+ )
def as_oracle(self, compiler, connection, **extra_context):
- if self.output_field.get_internal_type() == 'JSONField':
+ if self.output_field.get_internal_type() == "JSONField":
# Oracle doesn't support explicit cast to JSON.
template = "JSON_QUERY(%(expressions)s, '$')"
- return super().as_sql(compiler, connection, template=template, **extra_context)
+ return super().as_sql(
+ compiler, connection, template=template, **extra_context
+ )
return self.as_sql(compiler, connection, **extra_context)
class Coalesce(Func):
"""Return, from left to right, the first non-null expression."""
- function = 'COALESCE'
+
+ function = "COALESCE"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
- raise ValueError('Coalesce must take at least two expressions')
+ raise ValueError("Coalesce must take at least two expressions")
super().__init__(*expressions, **extra)
@property
@@ -76,29 +89,32 @@ class Coalesce(Func):
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
- if self.output_field.get_internal_type() == 'TextField':
+ if self.output_field.get_internal_type() == "TextField":
clone = self.copy()
- clone.set_source_expressions([
- Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
- ])
+ clone.set_source_expressions(
+ [
+ Func(expression, function="TO_NCLOB")
+ for expression in self.get_source_expressions()
+ ]
+ )
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Collate(Func):
- function = 'COLLATE'
- template = '%(expressions)s %(function)s %(collation)s'
+ function = "COLLATE"
+ template = "%(expressions)s %(function)s %(collation)s"
# Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
- collation_re = _lazy_re_compile(r'^[\w\-]+$')
+ collation_re = _lazy_re_compile(r"^[\w\-]+$")
def __init__(self, expression, collation):
if not (collation and self.collation_re.match(collation)):
- raise ValueError('Invalid collation name: %r.' % collation)
+ raise ValueError("Invalid collation name: %r." % collation)
self.collation = collation
super().__init__(expression)
def as_sql(self, compiler, connection, **extra_context):
- extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
+ extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
return super().as_sql(compiler, connection, **extra_context)
@@ -110,20 +126,21 @@ class Greatest(Func):
On PostgreSQL, the maximum not-null expression is returned.
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
"""
- function = 'GREATEST'
+
+ function = "GREATEST"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
- raise ValueError('Greatest must take at least two expressions')
+ raise ValueError("Greatest must take at least two expressions")
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
- return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
+ return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
class JSONObject(Func):
- function = 'JSON_OBJECT'
+ function = "JSON_OBJECT"
output_field = JSONField()
def __init__(self, **fields):
@@ -135,7 +152,7 @@ class JSONObject(Func):
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.has_json_object_function:
raise NotSupportedError(
- 'JSONObject() is not supported on this database backend.'
+ "JSONObject() is not supported on this database backend."
)
return super().as_sql(compiler, connection, **extra_context)
@@ -143,21 +160,21 @@ class JSONObject(Func):
return self.as_sql(
compiler,
connection,
- function='JSONB_BUILD_OBJECT',
+ function="JSONB_BUILD_OBJECT",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
class ArgJoiner:
def join(self, args):
- args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
- return ', '.join(args)
+ args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
+ return ", ".join(args)
return self.as_sql(
compiler,
connection,
arg_joiner=ArgJoiner(),
- template='%(function)s(%(expressions)s RETURNING CLOB)',
+ template="%(function)s(%(expressions)s RETURNING CLOB)",
**extra_context,
)
@@ -170,24 +187,25 @@ class Least(Func):
On PostgreSQL, return the minimum not-null expression.
On MySQL, Oracle, and SQLite, if any expression is null, return null.
"""
- function = 'LEAST'
+
+ function = "LEAST"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
- raise ValueError('Least must take at least two expressions')
+ raise ValueError("Least must take at least two expressions")
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
- return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
+ return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
class NullIf(Func):
- function = 'NULLIF'
+ function = "NULLIF"
arity = 2
def as_oracle(self, compiler, connection, **extra_context):
expression1 = self.get_source_expressions()[0]
if isinstance(expression1, Value) and expression1.value is None:
- raise ValueError('Oracle does not allow Value(None) for expression1.')
+ raise ValueError("Oracle does not allow Value(None) for expression1.")
return super().as_sql(compiler, connection, **extra_context)
diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py
index 07f884f78d..2d6ec7089e 100644
--- a/django/db/models/functions/datetime.py
+++ b/django/db/models/functions/datetime.py
@@ -3,10 +3,20 @@ from datetime import datetime
from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
- DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
+ DateField,
+ DateTimeField,
+ DurationField,
+ Field,
+ IntegerField,
+ TimeField,
)
from django.db.models.lookups import (
- Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
+ Transform,
+ YearExact,
+ YearGt,
+ YearGte,
+ YearLt,
+ YearLte,
)
from django.utils import timezone
@@ -36,7 +46,7 @@ class Extract(TimezoneMixin, Transform):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
- raise ValueError('lookup_name must be provided')
+ raise ValueError("lookup_name must be provided")
self.tzinfo = tzinfo
super().__init__(expression, **extra)
@@ -47,14 +57,16 @@ class Extract(TimezoneMixin, Transform):
tzname = self.get_tzname()
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
elif self.tzinfo is not None:
- raise ValueError('tzinfo can only be used with DateTimeField.')
+ raise ValueError("tzinfo can only be used with DateTimeField.")
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
- raise ValueError('Extract requires native DurationField database support.')
+ raise ValueError(
+ "Extract requires native DurationField database support."
+ )
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
# resolve_expression has already validated the output_field so this
@@ -62,24 +74,38 @@ class Extract(TimezoneMixin, Transform):
assert False, "Tried to Extract from an invalid type."
return sql, params
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
- copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
- field = getattr(copy.lhs, 'output_field', None)
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
+ copy = super().resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
+ field = getattr(copy.lhs, "output_field", None)
if field is None:
return copy
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
raise ValueError(
- 'Extract input expression must be DateField, DateTimeField, '
- 'TimeField, or DurationField.'
+ "Extract input expression must be DateField, DateTimeField, "
+ "TimeField, or DurationField."
)
# Passing dates to functions expecting datetimes is most likely a mistake.
- if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
+ if type(field) == DateField and copy.lookup_name in (
+ "hour",
+ "minute",
+ "second",
+ ):
raise ValueError(
- "Cannot extract time component '%s' from DateField '%s'." % (copy.lookup_name, field.name)
+ "Cannot extract time component '%s' from DateField '%s'."
+ % (copy.lookup_name, field.name)
)
- if (
- isinstance(field, DurationField) and
- copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
+ if isinstance(field, DurationField) and copy.lookup_name in (
+ "year",
+ "iso_year",
+ "month",
+ "week",
+ "week_day",
+ "iso_week_day",
+ "quarter",
):
raise ValueError(
"Cannot extract component '%s' from DurationField '%s'."
@@ -89,20 +115,21 @@ class Extract(TimezoneMixin, Transform):
class ExtractYear(Extract):
- lookup_name = 'year'
+ lookup_name = "year"
class ExtractIsoYear(Extract):
"""Return the ISO-8601 week-numbering year."""
- lookup_name = 'iso_year'
+
+ lookup_name = "iso_year"
class ExtractMonth(Extract):
- lookup_name = 'month'
+ lookup_name = "month"
class ExtractDay(Extract):
- lookup_name = 'day'
+ lookup_name = "day"
class ExtractWeek(Extract):
@@ -110,7 +137,8 @@ class ExtractWeek(Extract):
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
week.
"""
- lookup_name = 'week'
+
+ lookup_name = "week"
class ExtractWeekDay(Extract):
@@ -119,28 +147,30 @@ class ExtractWeekDay(Extract):
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
- lookup_name = 'week_day'
+
+ lookup_name = "week_day"
class ExtractIsoWeekDay(Extract):
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
- lookup_name = 'iso_week_day'
+
+ lookup_name = "iso_week_day"
class ExtractQuarter(Extract):
- lookup_name = 'quarter'
+ lookup_name = "quarter"
class ExtractHour(Extract):
- lookup_name = 'hour'
+ lookup_name = "hour"
class ExtractMinute(Extract):
- lookup_name = 'minute'
+ lookup_name = "minute"
class ExtractSecond(Extract):
- lookup_name = 'second'
+ lookup_name = "second"
DateField.register_lookup(ExtractYear)
@@ -174,14 +204,16 @@ ExtractIsoYear.register_lookup(YearLte)
class Now(Func):
- template = 'CURRENT_TIMESTAMP'
+ template = "CURRENT_TIMESTAMP"
output_field = DateTimeField()
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
- return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
+ return self.as_sql(
+ compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
+ )
class TruncBase(TimezoneMixin, Transform):
@@ -190,7 +222,14 @@ class TruncBase(TimezoneMixin, Transform):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
- def __init__(self, expression, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
+ def __init__(
+ self,
+ expression,
+ output_field=None,
+ tzinfo=None,
+ is_dst=timezone.NOT_PASSED,
+ **extra,
+ ):
self.tzinfo = tzinfo
self.is_dst = is_dst
super().__init__(expression, output_field=output_field, **extra)
@@ -201,7 +240,7 @@ class TruncBase(TimezoneMixin, Transform):
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
- raise ValueError('tzinfo can only be used with DateTimeField.')
+ raise ValueError("tzinfo can only be used with DateTimeField.")
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField):
@@ -209,11 +248,17 @@ class TruncBase(TimezoneMixin, Transform):
elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
else:
- raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
+ raise ValueError(
+ "Trunc only valid on DateField, TimeField, or DateTimeField."
+ )
return sql, inner_params
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
- copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
+ copy = super().resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
if not isinstance(field, (DateField, TimeField)):
@@ -223,23 +268,46 @@ class TruncBase(TimezoneMixin, Transform):
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
- raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
+ raise ValueError(
+ "output_field must be either DateField, TimeField, or DateTimeField"
+ )
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
- class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
+ class_output_field = (
+ self.__class__.output_field
+ if isinstance(self.__class__.output_field, Field)
+ else None
+ )
output_field = class_output_field or copy.output_field
- has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
+ has_explicit_output_field = (
+ class_output_field or field.__class__ is not copy.output_field.__class__
+ )
if type(field) == DateField and (
- isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
- raise ValueError("Cannot truncate DateField '%s' to %s." % (
- field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
- ))
+ isinstance(output_field, DateTimeField)
+ or copy.kind in ("hour", "minute", "second", "time")
+ ):
+ raise ValueError(
+ "Cannot truncate DateField '%s' to %s."
+ % (
+ field.name,
+ output_field.__class__.__name__
+ if has_explicit_output_field
+ else "DateTimeField",
+ )
+ )
elif isinstance(field, TimeField) and (
- isinstance(output_field, DateTimeField) or
- copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
- raise ValueError("Cannot truncate TimeField '%s' to %s." % (
- field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
- ))
+ isinstance(output_field, DateTimeField)
+ or copy.kind in ("year", "quarter", "month", "week", "day", "date")
+ ):
+ raise ValueError(
+ "Cannot truncate TimeField '%s' to %s."
+ % (
+ field.name,
+ output_field.__class__.__name__
+ if has_explicit_output_field
+ else "DateTimeField",
+ )
+ )
return copy
def convert_value(self, value, expression, connection):
@@ -251,8 +319,8 @@ class TruncBase(TimezoneMixin, Transform):
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
- 'Database returned an invalid datetime value. Are time '
- 'zone definitions for your database installed?'
+ "Database returned an invalid datetime value. Are time "
+ "zone definitions for your database installed?"
)
elif isinstance(value, datetime):
if value is None:
@@ -268,38 +336,46 @@ class Trunc(TruncBase):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
- def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
+ def __init__(
+ self,
+ expression,
+ kind,
+ output_field=None,
+ tzinfo=None,
+ is_dst=timezone.NOT_PASSED,
+ **extra,
+ ):
self.kind = kind
super().__init__(
- expression, output_field=output_field, tzinfo=tzinfo,
- is_dst=is_dst, **extra
+ expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
)
class TruncYear(TruncBase):
- kind = 'year'
+ kind = "year"
class TruncQuarter(TruncBase):
- kind = 'quarter'
+ kind = "quarter"
class TruncMonth(TruncBase):
- kind = 'month'
+ kind = "month"
class TruncWeek(TruncBase):
"""Truncate to midnight on the Monday of the week."""
- kind = 'week'
+
+ kind = "week"
class TruncDay(TruncBase):
- kind = 'day'
+ kind = "day"
class TruncDate(TruncBase):
- kind = 'date'
- lookup_name = 'date'
+ kind = "date"
+ lookup_name = "date"
output_field = DateField()
def as_sql(self, compiler, connection):
@@ -311,8 +387,8 @@ class TruncDate(TruncBase):
class TruncTime(TruncBase):
- kind = 'time'
- lookup_name = 'time'
+ kind = "time"
+ lookup_name = "time"
output_field = TimeField()
def as_sql(self, compiler, connection):
@@ -324,15 +400,15 @@ class TruncTime(TruncBase):
class TruncHour(TruncBase):
- kind = 'hour'
+ kind = "hour"
class TruncMinute(TruncBase):
- kind = 'minute'
+ kind = "minute"
class TruncSecond(TruncBase):
- kind = 'second'
+ kind = "second"
DateTimeField.register_lookup(TruncDate)
diff --git a/django/db/models/functions/math.py b/django/db/models/functions/math.py
index f939885263..8b5fd79c3a 100644
--- a/django/db/models/functions/math.py
+++ b/django/db/models/functions/math.py
@@ -4,37 +4,40 @@ from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
- FixDecimalInputMixin, NumericOutputFieldMixin,
+ FixDecimalInputMixin,
+ NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform
class Abs(Transform):
- function = 'ABS'
- lookup_name = 'abs'
+ function = "ABS"
+ lookup_name = "abs"
class ACos(NumericOutputFieldMixin, Transform):
- function = 'ACOS'
- lookup_name = 'acos'
+ function = "ACOS"
+ lookup_name = "acos"
class ASin(NumericOutputFieldMixin, Transform):
- function = 'ASIN'
- lookup_name = 'asin'
+ function = "ASIN"
+ lookup_name = "asin"
class ATan(NumericOutputFieldMixin, Transform):
- function = 'ATAN'
- lookup_name = 'atan'
+ function = "ATAN"
+ lookup_name = "atan"
class ATan2(NumericOutputFieldMixin, Func):
- function = 'ATAN2'
+ function = "ATAN2"
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
- if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0):
+ if not getattr(
+ connection.ops, "spatialite", False
+ ) or connection.ops.spatial_version >= (5, 0, 0):
return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
@@ -42,67 +45,74 @@ class ATan2(NumericOutputFieldMixin, Func):
# arguments are mixed between integer and float or decimal.
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
clone = self.copy()
- clone.set_source_expressions([
- Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
- else expression for expression in self.get_source_expressions()[::-1]
- ])
+ clone.set_source_expressions(
+ [
+ Cast(expression, FloatField())
+ if isinstance(expression.output_field, IntegerField)
+ else expression
+ for expression in self.get_source_expressions()[::-1]
+ ]
+ )
return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform):
- function = 'CEILING'
- lookup_name = 'ceil'
+ function = "CEILING"
+ lookup_name = "ceil"
def as_oracle(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='CEIL', **extra_context)
+ return super().as_sql(compiler, connection, function="CEIL", **extra_context)
class Cos(NumericOutputFieldMixin, Transform):
- function = 'COS'
- lookup_name = 'cos'
+ function = "COS"
+ lookup_name = "cos"
class Cot(NumericOutputFieldMixin, Transform):
- function = 'COT'
- lookup_name = 'cot'
+ function = "COT"
+ lookup_name = "cot"
def as_oracle(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
+ return super().as_sql(
+ compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
+ )
class Degrees(NumericOutputFieldMixin, Transform):
- function = 'DEGREES'
- lookup_name = 'degrees'
+ function = "DEGREES"
+ lookup_name = "degrees"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
- compiler, connection,
- template='((%%(expressions)s) * 180 / %s)' % math.pi,
- **extra_context
+ compiler,
+ connection,
+ template="((%%(expressions)s) * 180 / %s)" % math.pi,
+ **extra_context,
)
class Exp(NumericOutputFieldMixin, Transform):
- function = 'EXP'
- lookup_name = 'exp'
+ function = "EXP"
+ lookup_name = "exp"
class Floor(Transform):
- function = 'FLOOR'
- lookup_name = 'floor'
+ function = "FLOOR"
+ lookup_name = "floor"
class Ln(NumericOutputFieldMixin, Transform):
- function = 'LN'
- lookup_name = 'ln'
+ function = "LN"
+ lookup_name = "ln"
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
- function = 'LOG'
+ function = "LOG"
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
- if not getattr(connection.ops, 'spatialite', False):
+ if not getattr(connection.ops, "spatialite", False):
return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b).
@@ -112,55 +122,60 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
- function = 'MOD'
+ function = "MOD"
arity = 2
class Pi(NumericOutputFieldMixin, Func):
- function = 'PI'
+ function = "PI"
arity = 0
def as_oracle(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
+ return super().as_sql(
+ compiler, connection, template=str(math.pi), **extra_context
+ )
class Power(NumericOutputFieldMixin, Func):
- function = 'POWER'
+ function = "POWER"
arity = 2
class Radians(NumericOutputFieldMixin, Transform):
- function = 'RADIANS'
- lookup_name = 'radians'
+ function = "RADIANS"
+ lookup_name = "radians"
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
- compiler, connection,
- template='((%%(expressions)s) * %s / 180)' % math.pi,
- **extra_context
+ compiler,
+ connection,
+ template="((%%(expressions)s) * %s / 180)" % math.pi,
+ **extra_context,
)
class Random(NumericOutputFieldMixin, Func):
- function = 'RANDOM'
+ function = "RANDOM"
arity = 0
def as_mysql(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='RAND', **extra_context)
+ return super().as_sql(compiler, connection, function="RAND", **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context)
+ return super().as_sql(
+ compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
+ )
def as_sqlite(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='RAND', **extra_context)
+ return super().as_sql(compiler, connection, function="RAND", **extra_context)
def get_group_by_cols(self, alias=None):
return []
class Round(FixDecimalInputMixin, Transform):
- function = 'ROUND'
- lookup_name = 'round'
+ function = "ROUND"
+ lookup_name = "round"
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
@@ -169,7 +184,7 @@ class Round(FixDecimalInputMixin, Transform):
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
- raise ValueError('SQLite does not support negative precision.')
+ raise ValueError("SQLite does not support negative precision.")
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
@@ -178,20 +193,20 @@ class Round(FixDecimalInputMixin, Transform):
class Sign(Transform):
- function = 'SIGN'
- lookup_name = 'sign'
+ function = "SIGN"
+ lookup_name = "sign"
class Sin(NumericOutputFieldMixin, Transform):
- function = 'SIN'
- lookup_name = 'sin'
+ function = "SIN"
+ lookup_name = "sin"
class Sqrt(NumericOutputFieldMixin, Transform):
- function = 'SQRT'
- lookup_name = 'sqrt'
+ function = "SQRT"
+ lookup_name = "sqrt"
class Tan(NumericOutputFieldMixin, Transform):
- function = 'TAN'
- lookup_name = 'tan'
+ function = "TAN"
+ lookup_name = "tan"
diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py
index 00cfd1bc01..caf20e131d 100644
--- a/django/db/models/functions/mixins.py
+++ b/django/db/models/functions/mixins.py
@@ -5,7 +5,6 @@ from django.db.models.functions import Cast
class FixDecimalInputMixin:
-
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
@@ -13,36 +12,42 @@ class FixDecimalInputMixin:
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
- clone.set_source_expressions([
- Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
- else expression for expression in self.get_source_expressions()
- ])
+ clone.set_source_expressions(
+ [
+ Cast(expression, output_field)
+ if isinstance(expression.output_field, FloatField)
+ else expression
+ for expression in self.get_source_expressions()
+ ]
+ )
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
-
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
- if self.output_field.get_internal_type() == 'DurationField':
- sql = 'CAST(%s AS SIGNED)' % sql
+ if self.output_field.get_internal_type() == "DurationField":
+ sql = "CAST(%s AS SIGNED)" % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
- if self.output_field.get_internal_type() == 'DurationField':
+ if self.output_field.get_internal_type() == "DurationField":
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import (
- IntervalToSeconds, SecondsToInterval,
+ IntervalToSeconds,
+ SecondsToInterval,
)
+
return compiler.compile(
- SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
+ SecondsToInterval(
+ self.__class__(IntervalToSeconds(expression), **options)
+ )
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
-
def _resolve_output_field(self):
source_fields = self.get_source_fields()
if any(isinstance(s, DecimalField) for s in source_fields):
diff --git a/django/db/models/functions/text.py b/django/db/models/functions/text.py
index 4c52222ba1..a54ce8f19b 100644
--- a/django/db/models/functions/text.py
+++ b/django/db/models/functions/text.py
@@ -10,7 +10,7 @@ class MySQLSHA2Mixin:
return super().as_sql(
compiler,
connection,
- template='SHA2(%%(expressions)s, %s)' % self.function[3:],
+ template="SHA2(%%(expressions)s, %s)" % self.function[3:],
**extra_content,
)
@@ -40,25 +40,28 @@ class PostgreSQLSHAMixin:
class Chr(Transform):
- function = 'CHR'
- lookup_name = 'chr'
+ function = "CHR"
+ lookup_name = "chr"
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
- compiler, connection, function='CHAR',
- template='%(function)s(%(expressions)s USING utf16)',
- **extra_context
+ compiler,
+ connection,
+ function="CHAR",
+ template="%(function)s(%(expressions)s USING utf16)",
+ **extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
- compiler, connection,
- template='%(function)s(%(expressions)s USING NCHAR_CS)',
- **extra_context
+ compiler,
+ connection,
+ template="%(function)s(%(expressions)s USING NCHAR_CS)",
+ **extra_context,
)
def as_sqlite(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='CHAR', **extra_context)
+ return super().as_sql(compiler, connection, function="CHAR", **extra_context)
class ConcatPair(Func):
@@ -66,29 +69,38 @@ class ConcatPair(Func):
Concatenate two arguments together. This is used by `Concat` because not
all backend databases support more than two arguments.
"""
- function = 'CONCAT'
+
+ function = "CONCAT"
def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
- compiler, connection, template='%(expressions)s', arg_joiner=' || ',
- **extra_context
+ compiler,
+ connection,
+ template="%(expressions)s",
+ arg_joiner=" || ",
+ **extra_context,
)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql(
- compiler, connection, function='CONCAT_WS',
+ compiler,
+ connection,
+ function="CONCAT_WS",
template="%(function)s('', %(expressions)s)",
- **extra_context
+ **extra_context,
)
def coalesce(self):
# null on either side results in null for expression, wrap with coalesce
c = self.copy()
- c.set_source_expressions([
- Coalesce(expression, Value('')) for expression in c.get_source_expressions()
- ])
+ c.set_source_expressions(
+ [
+ Coalesce(expression, Value(""))
+ for expression in c.get_source_expressions()
+ ]
+ )
return c
@@ -98,12 +110,13 @@ class Concat(Func):
null expression when any arguments are null will wrap each argument in
coalesce functions to ensure a non-null result.
"""
+
function = None
template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
- raise ValueError('Concat must take at least two expressions')
+ raise ValueError("Concat must take at least two expressions")
paired = self._paired(expressions)
super().__init__(paired, **extra)
@@ -117,7 +130,7 @@ class Concat(Func):
class Left(Func):
- function = 'LEFT'
+ function = "LEFT"
arity = 2
output_field = CharField()
@@ -126,7 +139,7 @@ class Left(Func):
expression: the name of a field, or an expression returning a string
length: the number of characters to return from the start of the string
"""
- if not hasattr(length, 'resolve_expression'):
+ if not hasattr(length, "resolve_expression"):
if length < 1:
raise ValueError("'length' must be greater than 0.")
super().__init__(expression, length, **extra)
@@ -143,57 +156,68 @@ class Left(Func):
class Length(Transform):
"""Return the number of characters in the expression."""
- function = 'LENGTH'
- lookup_name = 'length'
+
+ function = "LENGTH"
+ lookup_name = "length"
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
+ return super().as_sql(
+ compiler, connection, function="CHAR_LENGTH", **extra_context
+ )
class Lower(Transform):
- function = 'LOWER'
- lookup_name = 'lower'
+ function = "LOWER"
+ lookup_name = "lower"
class LPad(Func):
- function = 'LPAD'
+ function = "LPAD"
output_field = CharField()
- def __init__(self, expression, length, fill_text=Value(' '), **extra):
- if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
+ def __init__(self, expression, length, fill_text=Value(" "), **extra):
+ if (
+ not hasattr(length, "resolve_expression")
+ and length is not None
+ and length < 0
+ ):
raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra)
class LTrim(Transform):
- function = 'LTRIM'
- lookup_name = 'ltrim'
+ function = "LTRIM"
+ lookup_name = "ltrim"
class MD5(OracleHashMixin, Transform):
- function = 'MD5'
- lookup_name = 'md5'
+ function = "MD5"
+ lookup_name = "md5"
class Ord(Transform):
- function = 'ASCII'
- lookup_name = 'ord'
+ function = "ASCII"
+ lookup_name = "ord"
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='ORD', **extra_context)
+ return super().as_sql(compiler, connection, function="ORD", **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
+ return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
class Repeat(Func):
- function = 'REPEAT'
+ function = "REPEAT"
output_field = CharField()
def __init__(self, expression, number, **extra):
- if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
+ if (
+ not hasattr(number, "resolve_expression")
+ and number is not None
+ and number < 0
+ ):
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
@@ -205,73 +229,76 @@ class Repeat(Func):
class Replace(Func):
- function = 'REPLACE'
+ function = "REPLACE"
- def __init__(self, expression, text, replacement=Value(''), **extra):
+ def __init__(self, expression, text, replacement=Value(""), **extra):
super().__init__(expression, text, replacement, **extra)
class Reverse(Transform):
- function = 'REVERSE'
- lookup_name = 'reverse'
+ function = "REVERSE"
+ lookup_name = "reverse"
def as_oracle(self, compiler, connection, **extra_context):
# REVERSE in Oracle is undocumented and doesn't support multi-byte
# strings. Use a special subquery instead.
return super().as_sql(
- compiler, connection,
+ compiler,
+ connection,
template=(
- '(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
- '(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
- 'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
- 'GROUP BY %(expressions)s)'
+ "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
+ "(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
+ "FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
+ "GROUP BY %(expressions)s)"
),
- **extra_context
+ **extra_context,
)
class Right(Left):
- function = 'RIGHT'
+ function = "RIGHT"
def get_substr(self):
- return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
+ return Substr(
+ self.source_expressions[0], self.source_expressions[1] * Value(-1)
+ )
class RPad(LPad):
- function = 'RPAD'
+ function = "RPAD"
class RTrim(Transform):
- function = 'RTRIM'
- lookup_name = 'rtrim'
+ function = "RTRIM"
+ lookup_name = "rtrim"
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
- function = 'SHA1'
- lookup_name = 'sha1'
+ function = "SHA1"
+ lookup_name = "sha1"
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
- function = 'SHA224'
- lookup_name = 'sha224'
+ function = "SHA224"
+ lookup_name = "sha224"
def as_oracle(self, compiler, connection, **extra_context):
- raise NotSupportedError('SHA224 is not supported on Oracle.')
+ raise NotSupportedError("SHA224 is not supported on Oracle.")
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
- function = 'SHA256'
- lookup_name = 'sha256'
+ function = "SHA256"
+ lookup_name = "sha256"
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
- function = 'SHA384'
- lookup_name = 'sha384'
+ function = "SHA384"
+ lookup_name = "sha384"
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
- function = 'SHA512'
- lookup_name = 'sha512'
+ function = "SHA512"
+ lookup_name = "sha512"
class StrIndex(Func):
@@ -280,16 +307,17 @@ class StrIndex(Func):
first occurrence of a substring inside another string, or 0 if the
substring is not found.
"""
- function = 'INSTR'
+
+ function = "INSTR"
arity = 2
output_field = IntegerField()
def as_postgresql(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
+ return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
class Substr(Func):
- function = 'SUBSTRING'
+ function = "SUBSTRING"
output_field = CharField()
def __init__(self, expression, pos, length=None, **extra):
@@ -298,7 +326,7 @@ class Substr(Func):
pos: an integer > 0, or an expression returning an integer
length: an optional number of characters to return
"""
- if not hasattr(pos, 'resolve_expression'):
+ if not hasattr(pos, "resolve_expression"):
if pos < 1:
raise ValueError("'pos' must be greater than 0")
expressions = [expression, pos]
@@ -307,17 +335,17 @@ class Substr(Func):
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
+ return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
- return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
+ return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
class Trim(Transform):
- function = 'TRIM'
- lookup_name = 'trim'
+ function = "TRIM"
+ lookup_name = "trim"
class Upper(Transform):
- function = 'UPPER'
- lookup_name = 'upper'
+ function = "UPPER"
+ lookup_name = "upper"
diff --git a/django/db/models/functions/window.py b/django/db/models/functions/window.py
index 84b2b24ffa..671017aba7 100644
--- a/django/db/models/functions/window.py
+++ b/django/db/models/functions/window.py
@@ -2,26 +2,35 @@ from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
__all__ = [
- 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
- 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
+ "CumeDist",
+ "DenseRank",
+ "FirstValue",
+ "Lag",
+ "LastValue",
+ "Lead",
+ "NthValue",
+ "Ntile",
+ "PercentRank",
+ "Rank",
+ "RowNumber",
]
class CumeDist(Func):
- function = 'CUME_DIST'
+ function = "CUME_DIST"
output_field = FloatField()
window_compatible = True
class DenseRank(Func):
- function = 'DENSE_RANK'
+ function = "DENSE_RANK"
output_field = IntegerField()
window_compatible = True
class FirstValue(Func):
arity = 1
- function = 'FIRST_VALUE'
+ function = "FIRST_VALUE"
window_compatible = True
@@ -31,13 +40,12 @@ class LagLeadFunction(Func):
def __init__(self, expression, offset=1, default=None, **extra):
if expression is None:
raise ValueError(
- '%s requires a non-null source expression.' %
- self.__class__.__name__
+ "%s requires a non-null source expression." % self.__class__.__name__
)
if offset is None or offset <= 0:
raise ValueError(
- '%s requires a positive integer for the offset.' %
- self.__class__.__name__
+ "%s requires a positive integer for the offset."
+ % self.__class__.__name__
)
args = (expression, offset)
if default is not None:
@@ -50,28 +58,32 @@ class LagLeadFunction(Func):
class Lag(LagLeadFunction):
- function = 'LAG'
+ function = "LAG"
class LastValue(Func):
arity = 1
- function = 'LAST_VALUE'
+ function = "LAST_VALUE"
window_compatible = True
class Lead(LagLeadFunction):
- function = 'LEAD'
+ function = "LEAD"
class NthValue(Func):
- function = 'NTH_VALUE'
+ function = "NTH_VALUE"
window_compatible = True
def __init__(self, expression, nth=1, **extra):
if expression is None:
- raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
+ raise ValueError(
+ "%s requires a non-null source expression." % self.__class__.__name__
+ )
if nth is None or nth <= 0:
- raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
+ raise ValueError(
+ "%s requires a positive integer as for nth." % self.__class__.__name__
+ )
super().__init__(expression, nth, **extra)
def _resolve_output_field(self):
@@ -80,29 +92,29 @@ class NthValue(Func):
class Ntile(Func):
- function = 'NTILE'
+ function = "NTILE"
output_field = IntegerField()
window_compatible = True
def __init__(self, num_buckets=1, **extra):
if num_buckets <= 0:
- raise ValueError('num_buckets must be greater than 0.')
+ raise ValueError("num_buckets must be greater than 0.")
super().__init__(num_buckets, **extra)
class PercentRank(Func):
- function = 'PERCENT_RANK'
+ function = "PERCENT_RANK"
output_field = FloatField()
window_compatible = True
class Rank(Func):
- function = 'RANK'
+ function = "RANK"
output_field = IntegerField()
window_compatible = True
class RowNumber(Func):
- function = 'ROW_NUMBER'
+ function = "ROW_NUMBER"
output_field = IntegerField()
window_compatible = True
diff --git a/django/db/models/indexes.py b/django/db/models/indexes.py
index e843f9a8cb..95b71ae5bf 100644
--- a/django/db/models/indexes.py
+++ b/django/db/models/indexes.py
@@ -5,11 +5,11 @@ from django.db.models.query_utils import Q
from django.db.models.sql import Query
from django.utils.functional import partition
-__all__ = ['Index']
+__all__ = ["Index"]
class Index:
- suffix = 'idx'
+ suffix = "idx"
# The max length of the name of the index (restricted to 30 for
# cross-database compatibility with Oracle)
max_name_length = 30
@@ -25,45 +25,47 @@ class Index:
include=None,
):
if opclasses and not name:
- raise ValueError('An index must be named to use opclasses.')
+ raise ValueError("An index must be named to use opclasses.")
if not isinstance(condition, (type(None), Q)):
- raise ValueError('Index.condition must be a Q instance.')
+ raise ValueError("Index.condition must be a Q instance.")
if condition and not name:
- raise ValueError('An index must be named to use condition.')
+ raise ValueError("An index must be named to use condition.")
if not isinstance(fields, (list, tuple)):
- raise ValueError('Index.fields must be a list or tuple.')
+ raise ValueError("Index.fields must be a list or tuple.")
if not isinstance(opclasses, (list, tuple)):
- raise ValueError('Index.opclasses must be a list or tuple.')
+ raise ValueError("Index.opclasses must be a list or tuple.")
if not expressions and not fields:
raise ValueError(
- 'At least one field or expression is required to define an index.'
+ "At least one field or expression is required to define an index."
)
if expressions and fields:
raise ValueError(
- 'Index.fields and expressions are mutually exclusive.',
+ "Index.fields and expressions are mutually exclusive.",
)
if expressions and not name:
- raise ValueError('An index must be named to use expressions.')
+ raise ValueError("An index must be named to use expressions.")
if expressions and opclasses:
raise ValueError(
- 'Index.opclasses cannot be used with expressions. Use '
- 'django.contrib.postgres.indexes.OpClass() instead.'
+ "Index.opclasses cannot be used with expressions. Use "
+ "django.contrib.postgres.indexes.OpClass() instead."
)
if opclasses and len(fields) != len(opclasses):
- raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
+ raise ValueError(
+ "Index.fields and Index.opclasses must have the same number of elements."
+ )
if fields and not all(isinstance(field, str) for field in fields):
- raise ValueError('Index.fields must contain only strings with field names.')
+ raise ValueError("Index.fields must contain only strings with field names.")
if include and not name:
- raise ValueError('A covering index must be named.')
+ raise ValueError("A covering index must be named.")
if not isinstance(include, (type(None), list, tuple)):
- raise ValueError('Index.include must be a list or tuple.')
+ raise ValueError("Index.include must be a list or tuple.")
self.fields = list(fields)
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
self.fields_orders = [
- (field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '')
+ (field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
for field_name in self.fields
]
- self.name = name or ''
+ self.name = name or ""
self.db_tablespace = db_tablespace
self.opclasses = opclasses
self.condition = condition
@@ -86,8 +88,10 @@ class Index:
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
- def create_sql(self, model, schema_editor, using='', **kwargs):
- include = [model._meta.get_field(field_name).column for field_name in self.include]
+ def create_sql(self, model, schema_editor, using="", **kwargs):
+ include = [
+ model._meta.get_field(field_name).column for field_name in self.include
+ ]
condition = self._get_condition_sql(model, schema_editor)
if self.expressions:
index_expressions = []
@@ -108,29 +112,36 @@ class Index:
col_suffixes = [order[1] for order in self.fields_orders]
expressions = None
return schema_editor._create_index_sql(
- model, fields=fields, name=self.name, using=using,
- db_tablespace=self.db_tablespace, col_suffixes=col_suffixes,
- opclasses=self.opclasses, condition=condition, include=include,
- expressions=expressions, **kwargs,
+ model,
+ fields=fields,
+ name=self.name,
+ using=using,
+ db_tablespace=self.db_tablespace,
+ col_suffixes=col_suffixes,
+ opclasses=self.opclasses,
+ condition=condition,
+ include=include,
+ expressions=expressions,
+ **kwargs,
)
def remove_sql(self, model, schema_editor, **kwargs):
return schema_editor._delete_index_sql(model, self.name, **kwargs)
def deconstruct(self):
- path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
- path = path.replace('django.db.models.indexes', 'django.db.models')
- kwargs = {'name': self.name}
+ path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
+ path = path.replace("django.db.models.indexes", "django.db.models")
+ kwargs = {"name": self.name}
if self.fields:
- kwargs['fields'] = self.fields
+ kwargs["fields"] = self.fields
if self.db_tablespace is not None:
- kwargs['db_tablespace'] = self.db_tablespace
+ kwargs["db_tablespace"] = self.db_tablespace
if self.opclasses:
- kwargs['opclasses'] = self.opclasses
+ kwargs["opclasses"] = self.opclasses
if self.condition:
- kwargs['condition'] = self.condition
+ kwargs["condition"] = self.condition
if self.include:
- kwargs['include'] = self.include
+ kwargs["include"] = self.include
return (path, self.expressions, kwargs)
def clone(self):
@@ -147,39 +158,44 @@ class Index:
fit its size by truncating the excess length.
"""
_, table_name = split_identifier(model._meta.db_table)
- column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders]
+ column_names = [
+ model._meta.get_field(field_name).column
+ for field_name, order in self.fields_orders
+ ]
column_names_with_order = [
- (('-%s' if order else '%s') % column_name)
- for column_name, (field_name, order) in zip(column_names, self.fields_orders)
+ (("-%s" if order else "%s") % column_name)
+ for column_name, (field_name, order) in zip(
+ column_names, self.fields_orders
+ )
]
# The length of the parts of the name is based on the default max
# length of 30 characters.
hash_data = [table_name] + column_names_with_order + [self.suffix]
- self.name = '%s_%s_%s' % (
+ self.name = "%s_%s_%s" % (
table_name[:11],
column_names[0][:7],
- '%s_%s' % (names_digest(*hash_data, length=6), self.suffix),
+ "%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
)
if len(self.name) > self.max_name_length:
raise ValueError(
- 'Index too long for multiple database support. Is self.suffix '
- 'longer than 3 characters?'
+ "Index too long for multiple database support. Is self.suffix "
+ "longer than 3 characters?"
)
- if self.name[0] == '_' or self.name[0].isdigit():
- self.name = 'D%s' % self.name[1:]
+ if self.name[0] == "_" or self.name[0].isdigit():
+ self.name = "D%s" % self.name[1:]
def __repr__(self):
- return '<%s:%s%s%s%s%s%s%s>' % (
+ return "<%s:%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
- '' if not self.fields else ' fields=%s' % repr(self.fields),
- '' if not self.expressions else ' expressions=%s' % repr(self.expressions),
- '' if not self.name else ' name=%s' % repr(self.name),
- ''
+ "" if not self.fields else " fields=%s" % repr(self.fields),
+ "" if not self.expressions else " expressions=%s" % repr(self.expressions),
+ "" if not self.name else " name=%s" % repr(self.name),
+ ""
if self.db_tablespace is None
- else ' db_tablespace=%s' % repr(self.db_tablespace),
- '' if self.condition is None else ' condition=%s' % self.condition,
- '' if not self.include else ' include=%s' % repr(self.include),
- '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
+ else " db_tablespace=%s" % repr(self.db_tablespace),
+ "" if self.condition is None else " condition=%s" % self.condition,
+ "" if not self.include else " include=%s" % repr(self.include),
+ "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
)
def __eq__(self, other):
@@ -190,17 +206,20 @@ class Index:
class IndexExpression(Func):
"""Order and wrap expressions for CREATE INDEX statements."""
- template = '%(expressions)s'
+
+ template = "%(expressions)s"
wrapper_classes = (OrderBy, Collate)
def set_wrapper_classes(self, connection=None):
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
if connection and connection.features.collate_as_index_expression:
- self.wrapper_classes = tuple([
- wrapper_cls
- for wrapper_cls in self.wrapper_classes
- if wrapper_cls is not Collate
- ])
+ self.wrapper_classes = tuple(
+ [
+ wrapper_cls
+ for wrapper_cls in self.wrapper_classes
+ if wrapper_cls is not Collate
+ ]
+ )
@classmethod
def register_wrappers(cls, *wrapper_classes):
@@ -224,16 +243,17 @@ class IndexExpression(Func):
if len(wrapper_types) != len(set(wrapper_types)):
raise ValueError(
"Multiple references to %s can't be used in an indexed "
- "expression." % ', '.join([
- wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
- ])
+ "expression."
+ % ", ".join(
+ [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
+ )
)
- if expressions[1:len(wrappers) + 1] != wrappers:
+ if expressions[1 : len(wrappers) + 1] != wrappers:
raise ValueError(
- '%s must be topmost expressions in an indexed expression.'
- % ', '.join([
- wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
- ])
+ "%s must be topmost expressions in an indexed expression."
+ % ", ".join(
+ [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
+ )
)
# Wrap expressions in parentheses if they are not column references.
root_expression = index_expressions[1]
@@ -245,7 +265,7 @@ class IndexExpression(Func):
for_save,
)
if not isinstance(resolve_root_expression, Col):
- root_expression = Func(root_expression, template='(%(expressions)s)')
+ root_expression = Func(root_expression, template="(%(expressions)s)")
if wrappers:
# Order wrappers and set their expressions.
@@ -262,7 +282,9 @@ class IndexExpression(Func):
else:
# Use the root expression, if there are no wrappers.
self.set_source_expressions([root_expression])
- return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ return super().resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
def as_sqlite(self, compiler, connection, **extra_context):
# Casting to numeric is unnecessary.
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index 24bfb11c06..5db549e6bf 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -4,7 +4,12 @@ import math
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.fields import (
- BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
+ BooleanField,
+ CharField,
+ DateTimeField,
+ Field,
+ IntegerField,
+ UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
@@ -21,18 +26,19 @@ class Lookup(Expression):
self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup()
self.lhs = self.get_prep_lhs()
- if hasattr(self.lhs, 'get_bilateral_transforms'):
+ if hasattr(self.lhs, "get_bilateral_transforms"):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if bilateral_transforms:
# Warn the user as soon as possible if they are trying to apply
# a bilateral transformation on a nested QuerySet: that won't work.
- from django.db.models.sql.query import ( # avoid circular import
- Query,
- )
+ from django.db.models.sql.query import Query # avoid circular import
+
if isinstance(rhs, Query):
- raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
+ raise NotImplementedError(
+ "Bilateral transformations on nested querysets are not implemented."
+ )
self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value):
@@ -41,7 +47,7 @@ class Lookup(Expression):
return value
def __repr__(self):
- return f'{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})'
+ return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
def batch_process_rhs(self, compiler, connection, rhs=None):
if rhs is None:
@@ -57,7 +63,7 @@ class Lookup(Expression):
sqls_params.extend(sql_params)
else:
_, params = self.get_db_prep_lookup(rhs, connection)
- sqls, sqls_params = ['%s'] * len(params), params
+ sqls, sqls_params = ["%s"] * len(params), params
return sqls, sqls_params
def get_source_expressions(self):
@@ -72,31 +78,31 @@ class Lookup(Expression):
self.lhs, self.rhs = new_exprs
def get_prep_lookup(self):
- if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'):
+ if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
return self.rhs
- if hasattr(self.lhs, 'output_field'):
- if hasattr(self.lhs.output_field, 'get_prep_value'):
+ if hasattr(self.lhs, "output_field"):
+ if hasattr(self.lhs.output_field, "get_prep_value"):
return self.lhs.output_field.get_prep_value(self.rhs)
elif self.rhs_is_direct_value():
return Value(self.rhs)
return self.rhs
def get_prep_lhs(self):
- if hasattr(self.lhs, 'resolve_expression'):
+ if hasattr(self.lhs, "resolve_expression"):
return self.lhs
return Value(self.lhs)
def get_db_prep_lookup(self, value, connection):
- return ('%s', [value])
+ return ("%s", [value])
def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
- if hasattr(lhs, 'resolve_expression'):
+ if hasattr(lhs, "resolve_expression"):
lhs = lhs.resolve_expression(compiler.query)
sql, params = compiler.compile(lhs)
if isinstance(lhs, Lookup):
# Wrapped in parentheses to respect operator precedence.
- sql = f'({sql})'
+ sql = f"({sql})"
return sql, params
def process_rhs(self, compiler, connection):
@@ -108,19 +114,19 @@ class Lookup(Expression):
value = Value(value, output_field=self.lhs.output_field)
value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
- if hasattr(value, 'as_sql'):
+ if hasattr(value, "as_sql"):
sql, params = compiler.compile(value)
# Ensure expression is wrapped in parentheses to respect operator
# precedence but avoid double wrapping as it can be misinterpreted
# on some backends (e.g. subqueries on SQLite).
- if sql and sql[0] != '(':
- sql = '(%s)' % sql
+ if sql and sql[0] != "(":
+ sql = "(%s)" % sql
return sql, params
else:
return self.get_db_prep_lookup(value, connection)
def rhs_is_direct_value(self):
- return not hasattr(self.rhs, 'as_sql')
+ return not hasattr(self.rhs, "as_sql")
def get_group_by_cols(self, alias=None):
cols = []
@@ -157,11 +163,17 @@ class Lookup(Expression):
def __hash__(self):
return hash(make_hashable(self.identity))
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
c = self.copy()
c.is_summary = summarize
- c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
- c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+ c.lhs = self.lhs.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
+ c.rhs = self.rhs.resolve_expression(
+ query, allow_joins, reuse, summarize, for_save
+ )
return c
def select_format(self, compiler, sql, params):
@@ -169,7 +181,7 @@ class Lookup(Expression):
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
- sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
+ sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
@@ -178,6 +190,7 @@ class Transform(RegisterLookupMixin, Func):
RegisterLookupMixin() is first so that get_lookup() and get_transform()
first examine self and then check output_field.
"""
+
bilateral = False
arity = 1
@@ -186,7 +199,7 @@ class Transform(RegisterLookupMixin, Func):
return self.get_source_expressions()[0]
def get_bilateral_transforms(self):
- if hasattr(self.lhs, 'get_bilateral_transforms'):
+ if hasattr(self.lhs, "get_bilateral_transforms"):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
@@ -200,9 +213,10 @@ class BuiltinLookup(Lookup):
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
field_internal_type = self.lhs.output_field.get_internal_type()
db_type = self.lhs.output_field.db_type(connection=connection)
- lhs_sql = connection.ops.field_cast_sql(
- db_type, field_internal_type) % lhs_sql
- lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
+ lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
+ lhs_sql = (
+ connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
+ )
return lhs_sql, list(params)
def as_sql(self, compiler, connection):
@@ -210,7 +224,7 @@ class BuiltinLookup(Lookup):
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
- return '%s %s' % (lhs_sql, rhs_sql), params
+ return "%s %s" % (lhs_sql, rhs_sql), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
@@ -221,18 +235,22 @@ class FieldGetDbPrepValueMixin:
Some lookups require Field.get_db_prep_value() to be called on their
inputs.
"""
+
get_db_prep_lookup_value_is_iterable = False
def get_db_prep_lookup(self, value, connection):
# For relational fields, use the 'target_field' attribute of the
# output_field.
- field = getattr(self.lhs.output_field, 'target_field', None)
- get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
+ field = getattr(self.lhs.output_field, "target_field", None)
+ get_db_prep_value = (
+ getattr(field, "get_db_prep_value", None)
+ or self.lhs.output_field.get_db_prep_value
+ )
return (
- '%s',
+ "%s",
[get_db_prep_value(v, connection, prepared=True) for v in value]
- if self.get_db_prep_lookup_value_is_iterable else
- [get_db_prep_value(value, connection, prepared=True)]
+ if self.get_db_prep_lookup_value_is_iterable
+ else [get_db_prep_value(value, connection, prepared=True)],
)
@@ -241,18 +259,19 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
Some lookups require Field.get_db_prep_value() to be called on each value
in an iterable.
"""
+
get_db_prep_lookup_value_is_iterable = True
def get_prep_lookup(self):
- if hasattr(self.rhs, 'resolve_expression'):
+ if hasattr(self.rhs, "resolve_expression"):
return self.rhs
prepared_values = []
for rhs_value in self.rhs:
- if hasattr(rhs_value, 'resolve_expression'):
+ if hasattr(rhs_value, "resolve_expression"):
# An expression will be handled by the database but can coexist
# alongside real values.
pass
- elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
+ elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
prepared_values.append(rhs_value)
return prepared_values
@@ -267,9 +286,9 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
def resolve_expression_parameter(self, compiler, connection, sql, param):
params = [param]
- if hasattr(param, 'resolve_expression'):
+ if hasattr(param, "resolve_expression"):
param = param.resolve_expression(compiler.query)
- if hasattr(param, 'as_sql'):
+ if hasattr(param, "as_sql"):
sql, params = compiler.compile(param)
return sql, params
@@ -279,40 +298,44 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
# sql/param pair. Zip them to get sql and param pairs that refer to the
# same argument and attempt to replace them with the result of
# compiling the param step.
- sql, params = zip(*(
- self.resolve_expression_parameter(compiler, connection, sql, param)
- for sql, param in zip(*pre_processed)
- ))
+ sql, params = zip(
+ *(
+ self.resolve_expression_parameter(compiler, connection, sql, param)
+ for sql, param in zip(*pre_processed)
+ )
+ )
params = itertools.chain.from_iterable(params)
return sql, tuple(params)
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
"""Lookup defined by operators on PostgreSQL."""
+
postgres_operator = None
def as_postgresql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
- return '%s %s %s' % (lhs, self.postgres_operator, rhs), params
+ return "%s %s %s" % (lhs, self.postgres_operator, rhs), params
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
- lookup_name = 'exact'
+ lookup_name = "exact"
def get_prep_lookup(self):
from django.db.models.sql.query import Query # avoid circular import
+
if isinstance(self.rhs, Query):
if self.rhs.has_limit_one():
if not self.rhs.has_select_fields:
self.rhs.clear_select_clause()
- self.rhs.add_fields(['pk'])
+ self.rhs.add_fields(["pk"])
else:
raise ValueError(
- 'The QuerySet value for an exact lookup must be limited to '
- 'one result using slicing.'
+ "The QuerySet value for an exact lookup must be limited to "
+ "one result using slicing."
)
return super().get_prep_lookup()
@@ -321,19 +344,21 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
# "WHERE boolean_field = True" when allowed.
if (
- isinstance(self.rhs, bool) and
- getattr(self.lhs, 'conditional', False) and
- connection.ops.conditional_expression_supported_in_where_clause(self.lhs)
+ isinstance(self.rhs, bool)
+ and getattr(self.lhs, "conditional", False)
+ and connection.ops.conditional_expression_supported_in_where_clause(
+ self.lhs
+ )
):
lhs_sql, params = self.process_lhs(compiler, connection)
- template = '%s' if self.rhs else 'NOT %s'
+ template = "%s" if self.rhs else "NOT %s"
return template % lhs_sql, params
return super().as_sql(compiler, connection)
@Field.register_lookup
class IExact(BuiltinLookup):
- lookup_name = 'iexact'
+ lookup_name = "iexact"
prepare_rhs = False
def process_rhs(self, qn, connection):
@@ -345,22 +370,22 @@ class IExact(BuiltinLookup):
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
- lookup_name = 'gt'
+ lookup_name = "gt"
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
- lookup_name = 'gte'
+ lookup_name = "gte"
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
- lookup_name = 'lt'
+ lookup_name = "lt"
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
- lookup_name = 'lte'
+ lookup_name = "lte"
class IntegerFieldFloatRounding:
@@ -368,6 +393,7 @@ class IntegerFieldFloatRounding:
Allow floats to work as query values for IntegerField. Without this, the
decimal portion of the float would always be discarded.
"""
+
def get_prep_lookup(self):
if isinstance(self.rhs, float):
self.rhs = math.ceil(self.rhs)
@@ -386,19 +412,20 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
- lookup_name = 'in'
+ lookup_name = "in"
def get_prep_lookup(self):
from django.db.models.sql.query import Query # avoid circular import
+
if isinstance(self.rhs, Query):
self.rhs.clear_ordering(clear_default=True)
if not self.rhs.has_select_fields:
self.rhs.clear_select_clause()
- self.rhs.add_fields(['pk'])
+ self.rhs.add_fields(["pk"])
return super().get_prep_lookup()
def process_rhs(self, compiler, connection):
- db_rhs = getattr(self.rhs, '_db', None)
+ db_rhs = getattr(self.rhs, "_db", None)
if db_rhs is not None and db_rhs != connection.alias:
raise ValueError(
"Subqueries aren't allowed across different databases. Force "
@@ -419,16 +446,20 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
# rhs should be an iterable; use batch_process_rhs() to
# prepare/transform those values.
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
- placeholder = '(' + ', '.join(sqls) + ')'
+ placeholder = "(" + ", ".join(sqls) + ")"
return (placeholder, sqls_params)
return super().process_rhs(compiler, connection)
def get_rhs_op(self, connection, rhs):
- return 'IN %s' % rhs
+ return "IN %s" % rhs
def as_sql(self, compiler, connection):
max_in_list_size = connection.ops.max_in_list_size()
- if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
+ if (
+ self.rhs_is_direct_value()
+ and max_in_list_size
+ and len(self.rhs) > max_in_list_size
+ ):
return self.split_parameter_list_as_sql(compiler, connection)
return super().as_sql(compiler, connection)
@@ -438,25 +469,25 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
max_in_list_size = connection.ops.max_in_list_size()
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
- in_clause_elements = ['(']
+ in_clause_elements = ["("]
params = []
for offset in range(0, len(rhs_params), max_in_list_size):
if offset > 0:
- in_clause_elements.append(' OR ')
- in_clause_elements.append('%s IN (' % lhs)
+ in_clause_elements.append(" OR ")
+ in_clause_elements.append("%s IN (" % lhs)
params.extend(lhs_params)
- sqls = rhs[offset: offset + max_in_list_size]
- sqls_params = rhs_params[offset: offset + max_in_list_size]
- param_group = ', '.join(sqls)
+ sqls = rhs[offset : offset + max_in_list_size]
+ sqls_params = rhs_params[offset : offset + max_in_list_size]
+ param_group = ", ".join(sqls)
in_clause_elements.append(param_group)
- in_clause_elements.append(')')
+ in_clause_elements.append(")")
params.extend(sqls_params)
- in_clause_elements.append(')')
- return ''.join(in_clause_elements), params
+ in_clause_elements.append(")")
+ return "".join(in_clause_elements), params
class PatternLookup(BuiltinLookup):
- param_pattern = '%%%s%%'
+ param_pattern = "%%%s%%"
prepare_rhs = False
def get_rhs_op(self, connection, rhs):
@@ -469,8 +500,10 @@ class PatternLookup(BuiltinLookup):
# So, for Python values we don't need any special pattern, but for
# SQL reference values or SQL transformations we need the correct
# pattern added.
- if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
- pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
+ if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
+ pattern = connection.pattern_ops[self.lookup_name].format(
+ connection.pattern_esc
+ )
return pattern.format(rhs)
else:
return super().get_rhs_op(connection, rhs)
@@ -478,45 +511,47 @@ class PatternLookup(BuiltinLookup):
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
- params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
+ params[0] = self.param_pattern % connection.ops.prep_for_like_query(
+ params[0]
+ )
return rhs, params
@Field.register_lookup
class Contains(PatternLookup):
- lookup_name = 'contains'
+ lookup_name = "contains"
@Field.register_lookup
class IContains(Contains):
- lookup_name = 'icontains'
+ lookup_name = "icontains"
@Field.register_lookup
class StartsWith(PatternLookup):
- lookup_name = 'startswith'
- param_pattern = '%s%%'
+ lookup_name = "startswith"
+ param_pattern = "%s%%"
@Field.register_lookup
class IStartsWith(StartsWith):
- lookup_name = 'istartswith'
+ lookup_name = "istartswith"
@Field.register_lookup
class EndsWith(PatternLookup):
- lookup_name = 'endswith'
- param_pattern = '%%%s'
+ lookup_name = "endswith"
+ param_pattern = "%%%s"
@Field.register_lookup
class IEndsWith(EndsWith):
- lookup_name = 'iendswith'
+ lookup_name = "iendswith"
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
- lookup_name = 'range'
+ lookup_name = "range"
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
@@ -524,13 +559,13 @@ class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
@Field.register_lookup
class IsNull(BuiltinLookup):
- lookup_name = 'isnull'
+ lookup_name = "isnull"
prepare_rhs = False
def as_sql(self, compiler, connection):
if not isinstance(self.rhs, bool):
raise ValueError(
- 'The QuerySet value for an isnull lookup must be True or False.'
+ "The QuerySet value for an isnull lookup must be True or False."
)
sql, params = compiler.compile(self.lhs)
if self.rhs:
@@ -541,7 +576,7 @@ class IsNull(BuiltinLookup):
@Field.register_lookup
class Regex(BuiltinLookup):
- lookup_name = 'regex'
+ lookup_name = "regex"
prepare_rhs = False
def as_sql(self, compiler, connection):
@@ -556,21 +591,24 @@ class Regex(BuiltinLookup):
@Field.register_lookup
class IRegex(Regex):
- lookup_name = 'iregex'
+ lookup_name = "iregex"
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
from django.db.models.functions import ExtractIsoYear
+
iso_year = isinstance(self.lhs, ExtractIsoYear)
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
- year, iso_year=iso_year,
+ year,
+ iso_year=iso_year,
)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(
- year, iso_year=iso_year,
+ year,
+ iso_year=iso_year,
)
return bounds
@@ -585,7 +623,7 @@ class YearLookup(Lookup):
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, self.rhs)
params.extend(self.get_bound_params(start, finish))
- return '%s %s' % (lhs_sql, rhs_sql), params
+ return "%s %s" % (lhs_sql, rhs_sql), params
return super().as_sql(compiler, connection)
def get_direct_rhs_sql(self, connection, rhs):
@@ -593,13 +631,13 @@ class YearLookup(Lookup):
def get_bound_params(self, start, finish):
raise NotImplementedError(
- 'subclasses of YearLookup must provide a get_bound_params() method'
+ "subclasses of YearLookup must provide a get_bound_params() method"
)
class YearExact(YearLookup, Exact):
def get_direct_rhs_sql(self, connection, rhs):
- return 'BETWEEN %s AND %s'
+ return "BETWEEN %s AND %s"
def get_bound_params(self, start, finish):
return (start, finish)
@@ -630,12 +668,16 @@ class UUIDTextMixin:
Strip hyphens from a value when filtering a UUIDField on backends without
a native datatype for UUID.
"""
+
def process_rhs(self, qn, connection):
if not connection.features.has_native_uuid_field:
from django.db.models.functions import Replace
+
if self.rhs_is_direct_value():
self.rhs = Value(self.rhs)
- self.rhs = Replace(self.rhs, Value('-'), Value(''), output_field=CharField())
+ self.rhs = Replace(
+ self.rhs, Value("-"), Value(""), output_field=CharField()
+ )
rhs, params = super().process_rhs(qn, connection)
return rhs, params
diff --git a/django/db/models/manager.py b/django/db/models/manager.py
index 655dfcf8e7..0c0688828e 100644
--- a/django/db/models/manager.py
+++ b/django/db/models/manager.py
@@ -33,7 +33,7 @@ class BaseManager:
def __str__(self):
"""Return "app_label.model_label.manager_name"."""
- return '%s.%s' % (self.model._meta.label, self.name)
+ return "%s.%s" % (self.model._meta.label, self.name)
def __class_getitem__(cls, *args, **kwargs):
return cls
@@ -46,12 +46,12 @@ class BaseManager:
Raise a ValueError if the manager is dynamically generated.
"""
qs_class = self._queryset_class
- if getattr(self, '_built_with_as_manager', False):
+ if getattr(self, "_built_with_as_manager", False):
# using MyQuerySet.as_manager()
return (
True, # as_manager
None, # manager_class
- '%s.%s' % (qs_class.__module__, qs_class.__name__), # qs_class
+ "%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class
None, # args
None, # kwargs
)
@@ -69,7 +69,7 @@ class BaseManager:
)
return (
False, # as_manager
- '%s.%s' % (module_name, name), # manager_class
+ "%s.%s" % (module_name, name), # manager_class
None, # qs_class
self._constructor_args[0], # args
self._constructor_args[1], # kwargs
@@ -83,18 +83,21 @@ class BaseManager:
def create_method(name, method):
def manager_method(self, *args, **kwargs):
return getattr(self.get_queryset(), name)(*args, **kwargs)
+
manager_method.__name__ = method.__name__
manager_method.__doc__ = method.__doc__
return manager_method
new_methods = {}
- for name, method in inspect.getmembers(queryset_class, predicate=inspect.isfunction):
+ for name, method in inspect.getmembers(
+ queryset_class, predicate=inspect.isfunction
+ ):
# Only copy missing methods.
if hasattr(cls, name):
continue
# Only copy public methods or methods with the attribute `queryset_only=False`.
- queryset_only = getattr(method, 'queryset_only', None)
- if queryset_only or (queryset_only is None and name.startswith('_')):
+ queryset_only = getattr(method, "queryset_only", None)
+ if queryset_only or (queryset_only is None and name.startswith("_")):
continue
# Copy the method onto the manager.
new_methods[name] = create_method(name, method)
@@ -103,11 +106,15 @@ class BaseManager:
@classmethod
def from_queryset(cls, queryset_class, class_name=None):
if class_name is None:
- class_name = '%sFrom%s' % (cls.__name__, queryset_class.__name__)
- return type(class_name, (cls,), {
- '_queryset_class': queryset_class,
- **cls._get_queryset_methods(queryset_class),
- })
+ class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__)
+ return type(
+ class_name,
+ (cls,),
+ {
+ "_queryset_class": queryset_class,
+ **cls._get_queryset_methods(queryset_class),
+ },
+ )
def contribute_to_class(self, cls, name):
self.name = self.name or name
@@ -157,8 +164,8 @@ class BaseManager:
def __eq__(self, other):
return (
- isinstance(other, self.__class__) and
- self._constructor_args == other._constructor_args
+ isinstance(other, self.__class__)
+ and self._constructor_args == other._constructor_args
)
def __hash__(self):
@@ -170,22 +177,24 @@ class Manager(BaseManager.from_queryset(QuerySet)):
class ManagerDescriptor:
-
def __init__(self, manager):
self.manager = manager
def __get__(self, instance, cls=None):
if instance is not None:
- raise AttributeError("Manager isn't accessible via %s instances" % cls.__name__)
+ raise AttributeError(
+ "Manager isn't accessible via %s instances" % cls.__name__
+ )
if cls._meta.abstract:
- raise AttributeError("Manager isn't available; %s is abstract" % (
- cls._meta.object_name,
- ))
+ raise AttributeError(
+ "Manager isn't available; %s is abstract" % (cls._meta.object_name,)
+ )
if cls._meta.swapped:
raise AttributeError(
- "Manager isn't available; '%s' has been swapped for '%s'" % (
+ "Manager isn't available; '%s' has been swapped for '%s'"
+ % (
cls._meta.label,
cls._meta.swapped,
)
diff --git a/django/db/models/options.py b/django/db/models/options.py
index 6022099e3e..b95f9871b1 100644
--- a/django/db/models/options.py
+++ b/django/db/models/options.py
@@ -25,13 +25,32 @@ IMMUTABLE_WARNING = (
)
DEFAULT_NAMES = (
- 'verbose_name', 'verbose_name_plural', 'db_table', 'ordering',
- 'unique_together', 'permissions', 'get_latest_by', 'order_with_respect_to',
- 'app_label', 'db_tablespace', 'abstract', 'managed', 'proxy', 'swappable',
- 'auto_created', 'index_together', 'apps', 'default_permissions',
- 'select_on_save', 'default_related_name', 'required_db_features',
- 'required_db_vendor', 'base_manager_name', 'default_manager_name',
- 'indexes', 'constraints',
+ "verbose_name",
+ "verbose_name_plural",
+ "db_table",
+ "ordering",
+ "unique_together",
+ "permissions",
+ "get_latest_by",
+ "order_with_respect_to",
+ "app_label",
+ "db_tablespace",
+ "abstract",
+ "managed",
+ "proxy",
+ "swappable",
+ "auto_created",
+ "index_together",
+ "apps",
+ "default_permissions",
+ "select_on_save",
+ "default_related_name",
+ "required_db_features",
+ "required_db_vendor",
+ "base_manager_name",
+ "default_manager_name",
+ "indexes",
+ "constraints",
)
@@ -63,11 +82,17 @@ def make_immutable_fields_list(name, data):
class Options:
FORWARD_PROPERTIES = {
- 'fields', 'many_to_many', 'concrete_fields', 'local_concrete_fields',
- '_forward_fields_map', 'managers', 'managers_map', 'base_manager',
- 'default_manager',
+ "fields",
+ "many_to_many",
+ "concrete_fields",
+ "local_concrete_fields",
+ "_forward_fields_map",
+ "managers",
+ "managers_map",
+ "base_manager",
+ "default_manager",
}
- REVERSE_PROPERTIES = {'related_objects', 'fields_map', '_relation_tree'}
+ REVERSE_PROPERTIES = {"related_objects", "fields_map", "_relation_tree"}
default_apps = apps
@@ -82,7 +107,7 @@ class Options:
self.model_name = None
self.verbose_name = None
self.verbose_name_plural = None
- self.db_table = ''
+ self.db_table = ""
self.ordering = []
self._ordering_clash = False
self.indexes = []
@@ -90,7 +115,7 @@ class Options:
self.unique_together = []
self.index_together = []
self.select_on_save = False
- self.default_permissions = ('add', 'change', 'delete', 'view')
+ self.default_permissions = ("add", "change", "delete", "view")
self.permissions = []
self.object_name = None
self.app_label = app_label
@@ -130,11 +155,11 @@ class Options:
@property
def label(self):
- return '%s.%s' % (self.app_label, self.object_name)
+ return "%s.%s" % (self.app_label, self.object_name)
@property
def label_lower(self):
- return '%s.%s' % (self.app_label, self.model_name)
+ return "%s.%s" % (self.app_label, self.model_name)
@property
def app_config(self):
@@ -163,7 +188,7 @@ class Options:
# Ignore any private attributes that Django doesn't care about.
# NOTE: We can't modify a dictionary's contents while looping
# over it, so we loop over the *original* dictionary instead.
- if name.startswith('_'):
+ if name.startswith("_"):
del meta_attrs[name]
for attr_name in DEFAULT_NAMES:
if attr_name in meta_attrs:
@@ -177,30 +202,34 @@ class Options:
self.index_together = normalize_together(self.index_together)
# App label/class name interpolation for names of constraints and
# indexes.
- if not getattr(cls._meta, 'abstract', False):
- for attr_name in {'constraints', 'indexes'}:
+ if not getattr(cls._meta, "abstract", False):
+ for attr_name in {"constraints", "indexes"}:
objs = getattr(self, attr_name, [])
setattr(self, attr_name, self._format_names_with_class(cls, objs))
# verbose_name_plural is a special case because it uses a 's'
# by default.
if self.verbose_name_plural is None:
- self.verbose_name_plural = format_lazy('{}s', self.verbose_name)
+ self.verbose_name_plural = format_lazy("{}s", self.verbose_name)
# order_with_respect_and ordering are mutually exclusive.
self._ordering_clash = bool(self.ordering and self.order_with_respect_to)
# Any leftover attributes must be invalid.
if meta_attrs != {}:
- raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs))
+ raise TypeError(
+ "'class Meta' got invalid attribute(s): %s" % ",".join(meta_attrs)
+ )
else:
- self.verbose_name_plural = format_lazy('{}s', self.verbose_name)
+ self.verbose_name_plural = format_lazy("{}s", self.verbose_name)
del self.meta
# If the db_table wasn't provided, use the app_label + model_name.
if not self.db_table:
self.db_table = "%s_%s" % (self.app_label, self.model_name)
- self.db_table = truncate_name(self.db_table, connection.ops.max_name_length())
+ self.db_table = truncate_name(
+ self.db_table, connection.ops.max_name_length()
+ )
def _format_names_with_class(self, cls, objs):
"""App label/class name interpolation for object names."""
@@ -208,8 +237,8 @@ class Options:
for obj in objs:
obj = obj.clone()
obj.name = obj.name % {
- 'app_label': cls._meta.app_label.lower(),
- 'class': cls.__name__.lower(),
+ "app_label": cls._meta.app_label.lower(),
+ "class": cls.__name__.lower(),
}
new_objs.append(obj)
return new_objs
@@ -217,19 +246,19 @@ class Options:
def _get_default_pk_class(self):
pk_class_path = getattr(
self.app_config,
- 'default_auto_field',
+ "default_auto_field",
settings.DEFAULT_AUTO_FIELD,
)
if self.app_config and self.app_config._is_default_auto_field_overridden:
app_config_class = type(self.app_config)
source = (
- f'{app_config_class.__module__}.'
- f'{app_config_class.__qualname__}.default_auto_field'
+ f"{app_config_class.__module__}."
+ f"{app_config_class.__qualname__}.default_auto_field"
)
else:
- source = 'DEFAULT_AUTO_FIELD'
+ source = "DEFAULT_AUTO_FIELD"
if not pk_class_path:
- raise ImproperlyConfigured(f'{source} must not be empty.')
+ raise ImproperlyConfigured(f"{source} must not be empty.")
try:
pk_class = import_string(pk_class_path)
except ImportError as e:
@@ -252,15 +281,20 @@ class Options:
query = self.order_with_respect_to
try:
self.order_with_respect_to = next(
- f for f in self._get_fields(reverse=False)
+ f
+ for f in self._get_fields(reverse=False)
if f.name == query or f.attname == query
)
except StopIteration:
- raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, query))
+ raise FieldDoesNotExist(
+ "%s has no field named '%s'" % (self.object_name, query)
+ )
- self.ordering = ('_order',)
- if not any(isinstance(field, OrderWrt) for field in model._meta.local_fields):
- model.add_to_class('_order', OrderWrt())
+ self.ordering = ("_order",)
+ if not any(
+ isinstance(field, OrderWrt) for field in model._meta.local_fields
+ ):
+ model.add_to_class("_order", OrderWrt())
else:
self.order_with_respect_to = None
@@ -272,15 +306,17 @@ class Options:
# Look for a local field with the same name as the
# first parent link. If a local field has already been
# created, use it instead of promoting the parent
- already_created = [fld for fld in self.local_fields if fld.name == field.name]
+ already_created = [
+ fld for fld in self.local_fields if fld.name == field.name
+ ]
if already_created:
field = already_created[0]
field.primary_key = True
self.setup_pk(field)
else:
pk_class = self._get_default_pk_class()
- auto = pk_class(verbose_name='ID', primary_key=True, auto_created=True)
- model.add_to_class('id', auto)
+ auto = pk_class(verbose_name="ID", primary_key=True, auto_created=True)
+ model.add_to_class("id", auto)
def add_manager(self, manager):
self.local_managers.append(manager)
@@ -307,7 +343,11 @@ class Options:
# ideally, we'd just ask for field.related_model. However, related_model
# is a cached property, and all the models haven't been loaded yet, so
# we need to make sure we don't cache a string reference.
- if field.is_relation and hasattr(field.remote_field, 'model') and field.remote_field.model:
+ if (
+ field.is_relation
+ and hasattr(field.remote_field, "model")
+ and field.remote_field.model
+ ):
try:
field.remote_field.model._meta._expire_cache(forward=False)
except AttributeError:
@@ -331,7 +371,7 @@ class Options:
self.db_table = target._meta.db_table
def __repr__(self):
- return '<Options for %s>' % self.object_name
+ return "<Options for %s>" % self.object_name
def __str__(self):
return self.label_lower
@@ -348,8 +388,10 @@ class Options:
if self.required_db_vendor:
return self.required_db_vendor == connection.vendor
if self.required_db_features:
- return all(getattr(connection.features, feat, False)
- for feat in self.required_db_features)
+ return all(
+ getattr(connection.features, feat, False)
+ for feat in self.required_db_features
+ )
return True
@property
@@ -371,7 +413,7 @@ class Options:
swapped_for = getattr(settings, self.swappable, None)
if swapped_for:
try:
- swapped_label, swapped_object = swapped_for.split('.')
+ swapped_label, swapped_object = swapped_for.split(".")
except ValueError:
# setting not in the format app_label.model_name
# raising ImproperlyConfigured here causes problems with
@@ -379,7 +421,10 @@ class Options:
# or as part of validation.
return swapped_for
- if '%s.%s' % (swapped_label, swapped_object.lower()) != self.label_lower:
+ if (
+ "%s.%s" % (swapped_label, swapped_object.lower())
+ != self.label_lower
+ ):
return swapped_for
return None
@@ -387,7 +432,7 @@ class Options:
def managers(self):
managers = []
seen_managers = set()
- bases = (b for b in self.model.mro() if hasattr(b, '_meta'))
+ bases = (b for b in self.model.mro() if hasattr(b, "_meta"))
for depth, base in enumerate(bases):
for manager in base._meta.local_managers:
if manager.name in seen_managers:
@@ -413,8 +458,8 @@ class Options:
if not base_manager_name:
# Get the first parent's base_manager_name if there's one.
for parent in self.model.mro()[1:]:
- if hasattr(parent, '_meta'):
- if parent._base_manager.name != '_base_manager':
+ if hasattr(parent, "_meta"):
+ if parent._base_manager.name != "_base_manager":
base_manager_name = parent._base_manager.name
break
@@ -423,14 +468,15 @@ class Options:
return self.managers_map[base_manager_name]
except KeyError:
raise ValueError(
- "%s has no manager named %r" % (
+ "%s has no manager named %r"
+ % (
self.object_name,
base_manager_name,
)
)
manager = Manager()
- manager.name = '_base_manager'
+ manager.name = "_base_manager"
manager.model = self.model
manager.auto_created = True
return manager
@@ -441,7 +487,7 @@ class Options:
if not default_manager_name and not self.local_managers:
# Get the first parent's default_manager_name if there's one.
for parent in self.model.mro()[1:]:
- if hasattr(parent, '_meta'):
+ if hasattr(parent, "_meta"):
default_manager_name = parent._meta.default_manager_name
break
@@ -450,7 +496,8 @@ class Options:
return self.managers_map[default_manager_name]
except KeyError:
raise ValueError(
- "%s has no manager named %r" % (
+ "%s has no manager named %r"
+ % (
self.object_name,
default_manager_name,
)
@@ -484,13 +531,20 @@ class Options:
def is_not_a_generic_foreign_key(f):
return not (
- f.is_relation and f.many_to_one and not (hasattr(f.remote_field, 'model') and f.remote_field.model)
+ f.is_relation
+ and f.many_to_one
+ and not (hasattr(f.remote_field, "model") and f.remote_field.model)
)
return make_immutable_fields_list(
"fields",
- (f for f in self._get_fields(reverse=False)
- if is_not_an_m2m_field(f) and is_not_a_generic_relation(f) and is_not_a_generic_foreign_key(f))
+ (
+ f
+ for f in self._get_fields(reverse=False)
+ if is_not_an_m2m_field(f)
+ and is_not_a_generic_relation(f)
+ and is_not_a_generic_foreign_key(f)
+ ),
)
@cached_property
@@ -530,7 +584,11 @@ class Options:
"""
return make_immutable_fields_list(
"many_to_many",
- (f for f in self._get_fields(reverse=False) if f.is_relation and f.many_to_many)
+ (
+ f
+ for f in self._get_fields(reverse=False)
+ if f.is_relation and f.many_to_many
+ ),
)
@cached_property
@@ -544,10 +602,16 @@ class Options:
combined with filtering of field properties is the public API for
obtaining this field list.
"""
- all_related_fields = self._get_fields(forward=False, reverse=True, include_hidden=True)
+ all_related_fields = self._get_fields(
+ forward=False, reverse=True, include_hidden=True
+ )
return make_immutable_fields_list(
"related_objects",
- (obj for obj in all_related_fields if not obj.hidden or obj.field.many_to_many)
+ (
+ obj
+ for obj in all_related_fields
+ if not obj.hidden or obj.field.many_to_many
+ ),
)
@cached_property
@@ -603,7 +667,9 @@ class Options:
# field map.
return self.fields_map[field_name]
except KeyError:
- raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, field_name))
+ raise FieldDoesNotExist(
+ "%s has no field named '%s'" % (self.object_name, field_name)
+ )
def get_base_chain(self, model):
"""
@@ -672,15 +738,17 @@ class Options:
final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta
- path.append(PathInfo(
- from_opts=final_field.model._meta,
- to_opts=opts,
- target_fields=targets,
- join_field=final_field,
- m2m=False,
- direct=True,
- filtered_relation=None,
- ))
+ path.append(
+ PathInfo(
+ from_opts=final_field.model._meta,
+ to_opts=opts,
+ target_fields=targets,
+ join_field=final_field,
+ m2m=False,
+ direct=True,
+ filtered_relation=None,
+ )
+ )
return path
def get_path_from_parent(self, parent):
@@ -722,7 +790,8 @@ class Options:
if opts.abstract:
continue
fields_with_relations = (
- f for f in opts._get_fields(reverse=False, include_parents=False)
+ f
+ for f in opts._get_fields(reverse=False, include_parents=False)
if f.is_relation and f.related_model is not None
)
for f in fields_with_relations:
@@ -736,11 +805,13 @@ class Options:
# __dict__ takes precedence over a data descriptor (such as
# @cached_property). This means that the _meta._relation_tree is
# only called if related_objects is not in __dict__.
- related_objects = related_objects_graph[model._meta.concrete_model._meta.label]
- model._meta.__dict__['_relation_tree'] = related_objects
+ related_objects = related_objects_graph[
+ model._meta.concrete_model._meta.label
+ ]
+ model._meta.__dict__["_relation_tree"] = related_objects
# It seems it is possible that self is not in all_models, so guard
# against that with default for get().
- return self.__dict__.get('_relation_tree', EMPTY_RELATION_TREE)
+ return self.__dict__.get("_relation_tree", EMPTY_RELATION_TREE)
@cached_property
def _relation_tree(self):
@@ -771,10 +842,18 @@ class Options:
"""
if include_parents is False:
include_parents = PROXY_PARENTS
- return self._get_fields(include_parents=include_parents, include_hidden=include_hidden)
+ return self._get_fields(
+ include_parents=include_parents, include_hidden=include_hidden
+ )
- def _get_fields(self, forward=True, reverse=True, include_parents=True, include_hidden=False,
- seen_models=None):
+ def _get_fields(
+ self,
+ forward=True,
+ reverse=True,
+ include_parents=True,
+ include_hidden=False,
+ seen_models=None,
+ ):
"""
Internal helper function to return fields of the model.
* If forward=True, then fields defined on this model are returned.
@@ -787,7 +866,9 @@ class Options:
parent chain to the model's concrete model.
"""
if include_parents not in (True, False, PROXY_PARENTS):
- raise TypeError("Invalid argument for include_parents: %s" % (include_parents,))
+ raise TypeError(
+ "Invalid argument for include_parents: %s" % (include_parents,)
+ )
# This helper function is used to allow recursion in ``get_fields()``
# implementation and to provide a fast way for Django's internals to
# access specific subsets of fields.
@@ -819,13 +900,22 @@ class Options:
# fields from the same parent again.
if parent in seen_models:
continue
- if (parent._meta.concrete_model != self.concrete_model and
- include_parents == PROXY_PARENTS):
+ if (
+ parent._meta.concrete_model != self.concrete_model
+ and include_parents == PROXY_PARENTS
+ ):
continue
for obj in parent._meta._get_fields(
- forward=forward, reverse=reverse, include_parents=include_parents,
- include_hidden=include_hidden, seen_models=seen_models):
- if not getattr(obj, 'parent_link', False) or obj.model == self.concrete_model:
+ forward=forward,
+ reverse=reverse,
+ include_parents=include_parents,
+ include_hidden=include_hidden,
+ seen_models=seen_models,
+ ):
+ if (
+ not getattr(obj, "parent_link", False)
+ or obj.model == self.concrete_model
+ ):
fields.append(obj)
if reverse and not self.proxy:
# Tree is computed once and cached until the app cache is expired.
@@ -867,9 +957,9 @@ class Options:
constraint
for constraint in self.constraints
if (
- isinstance(constraint, UniqueConstraint) and
- constraint.condition is None and
- not constraint.contains_expressions
+ isinstance(constraint, UniqueConstraint)
+ and constraint.condition is None
+ and not constraint.contains_expressions
)
]
@@ -890,6 +980,9 @@ class Options:
Fields to be returned after a database insert.
"""
return [
- field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS)
- if getattr(field, 'db_returning', False)
+ field
+ for field in self._get_fields(
+ forward=True, reverse=False, include_parents=PROXY_PARENTS
+ )
+ if getattr(field, "db_returning", False)
]
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 0bc6aec2f3..687fd8b4cd 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -11,8 +11,12 @@ import django
from django.conf import settings
from django.core import exceptions
from django.db import (
- DJANGO_VERSION_PICKLE_KEY, IntegrityError, NotSupportedError, connections,
- router, transaction,
+ DJANGO_VERSION_PICKLE_KEY,
+ IntegrityError,
+ NotSupportedError,
+ connections,
+ router,
+ transaction,
)
from django.db.models import AutoField, DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict
@@ -34,7 +38,9 @@ REPR_OUTPUT_SIZE = 20
class BaseIterable:
- def __init__(self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def __init__(
+ self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
+ ):
self.queryset = queryset
self.chunked_fetch = chunked_fetch
self.chunk_size = chunk_size
@@ -49,25 +55,40 @@ class ModelIterable(BaseIterable):
compiler = queryset.query.get_compiler(using=db)
# Execute the query. This will also fill compiler.select, klass_info,
# and annotations.
- results = compiler.execute_sql(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size)
- select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info,
- compiler.annotation_col_map)
- model_cls = klass_info['model']
- select_fields = klass_info['select_fields']
+ results = compiler.execute_sql(
+ chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
+ )
+ select, klass_info, annotation_col_map = (
+ compiler.select,
+ compiler.klass_info,
+ compiler.annotation_col_map,
+ )
+ model_cls = klass_info["model"]
+ select_fields = klass_info["select_fields"]
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
- init_list = [f[0].target.attname
- for f in select[model_fields_start:model_fields_end]]
+ init_list = [
+ f[0].target.attname for f in select[model_fields_start:model_fields_end]
+ ]
related_populators = get_related_populators(klass_info, select, db)
known_related_objects = [
- (field, related_objs, operator.attrgetter(*[
- field.attname
- if from_field == 'self' else
- queryset.model._meta.get_field(from_field).attname
- for from_field in field.from_fields
- ])) for field, related_objs in queryset._known_related_objects.items()
+ (
+ field,
+ related_objs,
+ operator.attrgetter(
+ *[
+ field.attname
+ if from_field == "self"
+ else queryset.model._meta.get_field(from_field).attname
+ for from_field in field.from_fields
+ ]
+ ),
+ )
+ for field, related_objs in queryset._known_related_objects.items()
]
for row in compiler.results_iter(results):
- obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end])
+ obj = model_cls.from_db(
+ db, init_list, row[model_fields_start:model_fields_end]
+ )
for rel_populator in related_populators:
rel_populator.populate(row, obj)
if annotation_col_map:
@@ -107,7 +128,9 @@ class ValuesIterable(BaseIterable):
*query.annotation_select,
]
indexes = range(len(names))
- for row in compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size):
+ for row in compiler.results_iter(
+ chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
+ ):
yield {names[i]: row[i] for i in indexes}
@@ -129,16 +152,25 @@ class ValuesListIterable(BaseIterable):
*query.values_select,
*query.annotation_select,
]
- fields = [*queryset._fields, *(f for f in query.annotation_select if f not in queryset._fields)]
+ fields = [
+ *queryset._fields,
+ *(f for f in query.annotation_select if f not in queryset._fields),
+ ]
if fields != names:
# Reorder according to fields.
index_map = {name: idx for idx, name in enumerate(names)}
rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
return map(
rowfactory,
- compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size)
+ compiler.results_iter(
+ chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
+ ),
)
- return compiler.results_iter(tuple_expected=True, chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size)
+ return compiler.results_iter(
+ tuple_expected=True,
+ chunked_fetch=self.chunked_fetch,
+ chunk_size=self.chunk_size,
+ )
class NamedValuesListIterable(ValuesListIterable):
@@ -153,7 +185,11 @@ class NamedValuesListIterable(ValuesListIterable):
names = queryset._fields
else:
query = queryset.query
- names = [*query.extra_select, *query.values_select, *query.annotation_select]
+ names = [
+ *query.extra_select,
+ *query.values_select,
+ *query.annotation_select,
+ ]
tuple_class = create_namedtuple_class(*names)
new = tuple.__new__
for row in super().__iter__():
@@ -169,7 +205,9 @@ class FlatValuesListIterable(BaseIterable):
def __iter__(self):
queryset = self.queryset
compiler = queryset.query.get_compiler(queryset.db)
- for row in compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size):
+ for row in compiler.results_iter(
+ chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
+ ):
yield row[0]
@@ -209,9 +247,11 @@ class QuerySet:
def as_manager(cls):
# Address the circular dependency between `Queryset` and `Manager`.
from django.db.models.manager import Manager
+
manager = Manager.from_queryset(cls)()
manager._built_with_as_manager = True
return manager
+
as_manager.queryset_only = True
as_manager = classmethod(as_manager)
@@ -223,7 +263,7 @@ class QuerySet:
"""Don't populate the QuerySet's cache."""
obj = self.__class__()
for k, v in self.__dict__.items():
- if k == '_result_cache':
+ if k == "_result_cache":
obj.__dict__[k] = None
else:
obj.__dict__[k] = copy.deepcopy(v, memo)
@@ -254,10 +294,10 @@ class QuerySet:
self.__dict__.update(state)
def __repr__(self):
- data = list(self[:REPR_OUTPUT_SIZE + 1])
+ data = list(self[: REPR_OUTPUT_SIZE + 1])
if len(data) > REPR_OUTPUT_SIZE:
data[-1] = "...(remaining elements truncated)..."
- return '<%s %r>' % (self.__class__.__name__, data)
+ return "<%s %r>" % (self.__class__.__name__, data)
def __len__(self):
self._fetch_all()
@@ -289,17 +329,17 @@ class QuerySet:
"""Retrieve an item or slice from the set of results."""
if not isinstance(k, (int, slice)):
raise TypeError(
- 'QuerySet indices must be integers or slices, not %s.'
+ "QuerySet indices must be integers or slices, not %s."
% type(k).__name__
)
- if (
- (isinstance(k, int) and k < 0) or
- (isinstance(k, slice) and (
- (k.start is not None and k.start < 0) or
- (k.stop is not None and k.stop < 0)
- ))
+ if (isinstance(k, int) and k < 0) or (
+ isinstance(k, slice)
+ and (
+ (k.start is not None and k.start < 0)
+ or (k.stop is not None and k.stop < 0)
+ )
):
- raise ValueError('Negative indexing is not supported.')
+ raise ValueError("Negative indexing is not supported.")
if self._result_cache is not None:
return self._result_cache[k]
@@ -315,7 +355,7 @@ class QuerySet:
else:
stop = None
qs.query.set_limits(start, stop)
- return list(qs)[::k.step] if k.step else qs
+ return list(qs)[:: k.step] if k.step else qs
qs = self._chain()
qs.query.set_limits(k, k + 1)
@@ -326,7 +366,7 @@ class QuerySet:
return cls
def __and__(self, other):
- self._check_operator_queryset(other, '&')
+ self._check_operator_queryset(other, "&")
self._merge_sanity_check(other)
if isinstance(other, EmptyQuerySet):
return other
@@ -338,17 +378,21 @@ class QuerySet:
return combined
def __or__(self, other):
- self._check_operator_queryset(other, '|')
+ self._check_operator_queryset(other, "|")
self._merge_sanity_check(other)
if isinstance(self, EmptyQuerySet):
return other
if isinstance(other, EmptyQuerySet):
return self
- query = self if self.query.can_filter() else self.model._base_manager.filter(pk__in=self.values('pk'))
+ query = (
+ self
+ if self.query.can_filter()
+ else self.model._base_manager.filter(pk__in=self.values("pk"))
+ )
combined = query._chain()
combined._merge_known_related_objects(other)
if not other.query.can_filter():
- other = other.model._base_manager.filter(pk__in=other.values('pk'))
+ other = other.model._base_manager.filter(pk__in=other.values("pk"))
combined.query.combine(other.query, sql.OR)
return combined
@@ -385,14 +429,16 @@ class QuerySet:
# 'QuerySet.iterator() after prefetch_related().'
# )
warnings.warn(
- 'Using QuerySet.iterator() after prefetch_related() '
- 'without specifying chunk_size is deprecated.',
+ "Using QuerySet.iterator() after prefetch_related() "
+ "without specifying chunk_size is deprecated.",
category=RemovedInDjango50Warning,
stacklevel=2,
)
elif chunk_size <= 0:
- raise ValueError('Chunk size must be strictly positive.')
- use_chunked_fetch = not connections[self.db].settings_dict.get('DISABLE_SERVER_SIDE_CURSORS')
+ raise ValueError("Chunk size must be strictly positive.")
+ use_chunked_fetch = not connections[self.db].settings_dict.get(
+ "DISABLE_SERVER_SIDE_CURSORS"
+ )
return self._iterator(use_chunked_fetch, chunk_size)
def aggregate(self, *args, **kwargs):
@@ -405,7 +451,9 @@ class QuerySet:
"""
if self.query.distinct_fields:
raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
- self._validate_values_are_expressions((*args, *kwargs.values()), method_name='aggregate')
+ self._validate_values_are_expressions(
+ (*args, *kwargs.values()), method_name="aggregate"
+ )
for arg in args:
# The default_alias property raises TypeError if default_alias
# can't be set automatically or AttributeError if it isn't an
@@ -423,7 +471,11 @@ class QuerySet:
if not annotation.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
for expr in annotation.get_source_expressions():
- if expr.contains_aggregate and isinstance(expr, Ref) and expr.refs in kwargs:
+ if (
+ expr.contains_aggregate
+ and isinstance(expr, Ref)
+ and expr.refs in kwargs
+ ):
name = expr.refs
raise exceptions.FieldError(
"Cannot compute %s('%s'): '%s' is an aggregate"
@@ -451,14 +503,17 @@ class QuerySet:
"""
if self.query.combinator and (args or kwargs):
raise NotSupportedError(
- 'Calling QuerySet.get(...) with filters after %s() is not '
- 'supported.' % self.query.combinator
+ "Calling QuerySet.get(...) with filters after %s() is not "
+ "supported." % self.query.combinator
)
clone = self._chain() if self.query.combinator else self.filter(*args, **kwargs)
if self.query.can_filter() and not self.query.distinct_fields:
clone = clone.order_by()
limit = None
- if not clone.query.select_for_update or connections[clone.db].features.supports_select_for_update_with_limit:
+ if (
+ not clone.query.select_for_update
+ or connections[clone.db].features.supports_select_for_update_with_limit
+ ):
limit = MAX_GET_RESULTS
clone.query.set_limits(high=limit)
num = len(clone)
@@ -466,13 +521,13 @@ class QuerySet:
return clone._result_cache[0]
if not num:
raise self.model.DoesNotExist(
- "%s matching query does not exist." %
- self.model._meta.object_name
+ "%s matching query does not exist." % self.model._meta.object_name
)
raise self.model.MultipleObjectsReturned(
- 'get() returned more than one %s -- it returned %s!' % (
+ "get() returned more than one %s -- it returned %s!"
+ % (
self.model._meta.object_name,
- num if not limit or num < limit else 'more than %s' % (limit - 1),
+ num if not limit or num < limit else "more than %s" % (limit - 1),
)
)
@@ -491,69 +546,77 @@ class QuerySet:
if obj.pk is None:
# Populate new PK values.
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
- obj._prepare_related_fields_for_save(operation_name='bulk_create')
+ obj._prepare_related_fields_for_save(operation_name="bulk_create")
- def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields):
+ def _check_bulk_create_options(
+ self, ignore_conflicts, update_conflicts, update_fields, unique_fields
+ ):
if ignore_conflicts and update_conflicts:
raise ValueError(
- 'ignore_conflicts and update_conflicts are mutually exclusive.'
+ "ignore_conflicts and update_conflicts are mutually exclusive."
)
db_features = connections[self.db].features
if ignore_conflicts:
if not db_features.supports_ignore_conflicts:
raise NotSupportedError(
- 'This database backend does not support ignoring conflicts.'
+ "This database backend does not support ignoring conflicts."
)
return OnConflict.IGNORE
elif update_conflicts:
if not db_features.supports_update_conflicts:
raise NotSupportedError(
- 'This database backend does not support updating conflicts.'
+ "This database backend does not support updating conflicts."
)
if not update_fields:
raise ValueError(
- 'Fields that will be updated when a row insertion fails '
- 'on conflicts must be provided.'
+ "Fields that will be updated when a row insertion fails "
+ "on conflicts must be provided."
)
if unique_fields and not db_features.supports_update_conflicts_with_target:
raise NotSupportedError(
- 'This database backend does not support updating '
- 'conflicts with specifying unique fields that can trigger '
- 'the upsert.'
+ "This database backend does not support updating "
+ "conflicts with specifying unique fields that can trigger "
+ "the upsert."
)
if not unique_fields and db_features.supports_update_conflicts_with_target:
raise ValueError(
- 'Unique fields that can trigger the upsert must be provided.'
+ "Unique fields that can trigger the upsert must be provided."
)
# Updating primary keys and non-concrete fields is forbidden.
update_fields = [self.model._meta.get_field(name) for name in update_fields]
if any(not f.concrete or f.many_to_many for f in update_fields):
raise ValueError(
- 'bulk_create() can only be used with concrete fields in '
- 'update_fields.'
+ "bulk_create() can only be used with concrete fields in "
+ "update_fields."
)
if any(f.primary_key for f in update_fields):
raise ValueError(
- 'bulk_create() cannot be used with primary keys in '
- 'update_fields.'
+ "bulk_create() cannot be used with primary keys in "
+ "update_fields."
)
if unique_fields:
# Primary key is allowed in unique_fields.
unique_fields = [
self.model._meta.get_field(name)
- for name in unique_fields if name != 'pk'
+ for name in unique_fields
+ if name != "pk"
]
if any(not f.concrete or f.many_to_many for f in unique_fields):
raise ValueError(
- 'bulk_create() can only be used with concrete fields '
- 'in unique_fields.'
+ "bulk_create() can only be used with concrete fields "
+ "in unique_fields."
)
return OnConflict.UPDATE
return None
def bulk_create(
- self, objs, batch_size=None, ignore_conflicts=False,
- update_conflicts=False, update_fields=None, unique_fields=None,
+ self,
+ objs,
+ batch_size=None,
+ ignore_conflicts=False,
+ update_conflicts=False,
+ update_fields=None,
+ unique_fields=None,
):
"""
Insert each of the instances into the database. Do *not* call
@@ -575,7 +638,7 @@ class QuerySet:
# Oracle as well, but the semantics for extracting the primary keys is
# trickier so it's not done yet.
if batch_size is not None and batch_size <= 0:
- raise ValueError('Batch size must be a positive integer.')
+ raise ValueError("Batch size must be a positive integer.")
# Check that the parents share the same concrete model with the our
# model to detect the inheritance pattern ConcreteGrandParent ->
# MultiTableParent -> ProxyChild. Simply checking self.model._meta.proxy
@@ -625,7 +688,10 @@ class QuerySet:
unique_fields=unique_fields,
)
connection = connections[self.db]
- if connection.features.can_return_rows_from_bulk_insert and on_conflict is None:
+ if (
+ connection.features.can_return_rows_from_bulk_insert
+ and on_conflict is None
+ ):
assert len(returned_columns) == len(objs_without_pk)
for obj_without_pk, results in zip(objs_without_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
@@ -640,28 +706,30 @@ class QuerySet:
Update the given fields in each of the given objects in the database.
"""
if batch_size is not None and batch_size < 0:
- raise ValueError('Batch size must be a positive integer.')
+ raise ValueError("Batch size must be a positive integer.")
if not fields:
- raise ValueError('Field names must be given to bulk_update().')
+ raise ValueError("Field names must be given to bulk_update().")
objs = tuple(objs)
if any(obj.pk is None for obj in objs):
- raise ValueError('All bulk_update() objects must have a primary key set.')
+ raise ValueError("All bulk_update() objects must have a primary key set.")
fields = [self.model._meta.get_field(name) for name in fields]
if any(not f.concrete or f.many_to_many for f in fields):
- raise ValueError('bulk_update() can only be used with concrete fields.')
+ raise ValueError("bulk_update() can only be used with concrete fields.")
if any(f.primary_key for f in fields):
- raise ValueError('bulk_update() cannot be used with primary key fields.')
+ raise ValueError("bulk_update() cannot be used with primary key fields.")
if not objs:
return 0
for obj in objs:
- obj._prepare_related_fields_for_save(operation_name='bulk_update', fields=fields)
+ obj._prepare_related_fields_for_save(
+ operation_name="bulk_update", fields=fields
+ )
# PK is used twice in the resulting update query, once in the filter
# and once in the WHEN. Each field will also have one CAST.
connection = connections[self.db]
- max_batch_size = connection.ops.bulk_batch_size(['pk', 'pk'] + fields, objs)
+ max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
requires_casting = connection.features.requires_casted_case_in_updates
- batches = (objs[i:i + batch_size] for i in range(0, len(objs), batch_size))
+ batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
updates = []
for batch_objs in batches:
update_kwargs = {}
@@ -669,7 +737,7 @@ class QuerySet:
when_statements = []
for obj in batch_objs:
attr = getattr(obj, field.attname)
- if not hasattr(attr, 'resolve_expression'):
+ if not hasattr(attr, "resolve_expression"):
attr = Value(attr, output_field=field)
when_statements.append(When(pk=obj.pk, then=attr))
case_statement = Case(*when_statements, output_field=field)
@@ -682,6 +750,7 @@ class QuerySet:
for pks, update_kwargs in updates:
rows_updated += self.filter(pk__in=pks).update(**update_kwargs)
return rows_updated
+
bulk_update.alters_data = True
def get_or_create(self, defaults=None, **kwargs):
@@ -748,10 +817,12 @@ class QuerySet:
invalid_params.append(param)
if invalid_params:
raise exceptions.FieldError(
- "Invalid field name(s) for model %s: '%s'." % (
+ "Invalid field name(s) for model %s: '%s'."
+ % (
self.model._meta.object_name,
"', '".join(sorted(invalid_params)),
- ))
+ )
+ )
return params
def _earliest(self, *fields):
@@ -762,7 +833,7 @@ class QuerySet:
if fields:
order_by = fields
else:
- order_by = getattr(self.model._meta, 'get_latest_by')
+ order_by = getattr(self.model._meta, "get_latest_by")
if order_by and not isinstance(order_by, (tuple, list)):
order_by = (order_by,)
if order_by is None:
@@ -778,25 +849,25 @@ class QuerySet:
def earliest(self, *fields):
if self.query.is_sliced:
- raise TypeError('Cannot change a query once a slice has been taken.')
+ raise TypeError("Cannot change a query once a slice has been taken.")
return self._earliest(*fields)
def latest(self, *fields):
if self.query.is_sliced:
- raise TypeError('Cannot change a query once a slice has been taken.')
+ raise TypeError("Cannot change a query once a slice has been taken.")
return self.reverse()._earliest(*fields)
def first(self):
"""Return the first object of a query or None if no match is found."""
- for obj in (self if self.ordered else self.order_by('pk'))[:1]:
+ for obj in (self if self.ordered else self.order_by("pk"))[:1]:
return obj
def last(self):
"""Return the last object of a query or None if no match is found."""
- for obj in (self.reverse() if self.ordered else self.order_by('-pk'))[:1]:
+ for obj in (self.reverse() if self.ordered else self.order_by("-pk"))[:1]:
return obj
- def in_bulk(self, id_list=None, *, field_name='pk'):
+ def in_bulk(self, id_list=None, *, field_name="pk"):
"""
Return a dictionary mapping each of the given IDs to the object with
that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
@@ -810,16 +881,19 @@ class QuerySet:
if len(constraint.fields) == 1
]
if (
- field_name != 'pk' and
- not opts.get_field(field_name).unique and
- field_name not in unique_fields and
- self.query.distinct_fields != (field_name,)
+ field_name != "pk"
+ and not opts.get_field(field_name).unique
+ and field_name not in unique_fields
+ and self.query.distinct_fields != (field_name,)
):
- raise ValueError("in_bulk()'s field_name must be a unique field but %r isn't." % field_name)
+ raise ValueError(
+ "in_bulk()'s field_name must be a unique field but %r isn't."
+ % field_name
+ )
if id_list is not None:
if not id_list:
return {}
- filter_key = '{}__in'.format(field_name)
+ filter_key = "{}__in".format(field_name)
batch_size = connections[self.db].features.max_query_params
id_list = tuple(id_list)
# If the database has a limit on the number of query parameters
@@ -827,7 +901,7 @@ class QuerySet:
if batch_size and batch_size < len(id_list):
qs = ()
for offset in range(0, len(id_list), batch_size):
- batch = id_list[offset:offset + batch_size]
+ batch = id_list[offset : offset + batch_size]
qs += tuple(self.filter(**{filter_key: batch}).order_by())
else:
qs = self.filter(**{filter_key: id_list}).order_by()
@@ -837,11 +911,11 @@ class QuerySet:
def delete(self):
"""Delete the records in the current QuerySet."""
- self._not_support_combined_queries('delete')
+ self._not_support_combined_queries("delete")
if self.query.is_sliced:
raise TypeError("Cannot use 'limit' or 'offset' with delete().")
if self.query.distinct or self.query.distinct_fields:
- raise TypeError('Cannot call delete() after .distinct().')
+ raise TypeError("Cannot call delete() after .distinct().")
if self._fields is not None:
raise TypeError("Cannot call delete() after .values() or .values_list()")
@@ -880,6 +954,7 @@ class QuerySet:
with cursor:
return cursor.rowcount
return 0
+
_raw_delete.alters_data = True
def update(self, **kwargs):
@@ -887,9 +962,9 @@ class QuerySet:
Update all elements in the current QuerySet, setting all the given
fields to the appropriate values.
"""
- self._not_support_combined_queries('update')
+ self._not_support_combined_queries("update")
if self.query.is_sliced:
- raise TypeError('Cannot update a query once a slice has been taken.')
+ raise TypeError("Cannot update a query once a slice has been taken.")
self._for_write = True
query = self.query.chain(sql.UpdateQuery)
query.add_update_values(kwargs)
@@ -899,6 +974,7 @@ class QuerySet:
rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None
return rows
+
update.alters_data = True
def _update(self, values):
@@ -909,13 +985,14 @@ class QuerySet:
useful at that level).
"""
if self.query.is_sliced:
- raise TypeError('Cannot update a query once a slice has been taken.')
+ raise TypeError("Cannot update a query once a slice has been taken.")
query = self.query.chain(sql.UpdateQuery)
query.add_update_fields(values)
# Clear any annotations so that they won't be present in subqueries.
query.annotations = {}
self._result_cache = None
return query.get_compiler(self.db).execute_sql(CURSOR)
+
_update.alters_data = True
_update.queryset_only = False
@@ -926,10 +1003,10 @@ class QuerySet:
def contains(self, obj):
"""Return True if the queryset contains an object."""
- self._not_support_combined_queries('contains')
+ self._not_support_combined_queries("contains")
if self._fields is not None:
raise TypeError(
- 'Cannot call QuerySet.contains() after .values() or .values_list().'
+ "Cannot call QuerySet.contains() after .values() or .values_list()."
)
try:
if obj._meta.concrete_model != self.model._meta.concrete_model:
@@ -937,9 +1014,7 @@ class QuerySet:
except AttributeError:
raise TypeError("'obj' must be a model instance.")
if obj.pk is None:
- raise ValueError(
- 'QuerySet.contains() cannot be used on unsaved objects.'
- )
+ raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
if self._result_cache is not None:
return obj in self._result_cache
return self.filter(pk=obj.pk).exists()
@@ -959,7 +1034,13 @@ class QuerySet:
def raw(self, raw_query, params=(), translations=None, using=None):
if using is None:
using = self.db
- qs = RawQuerySet(raw_query, model=self.model, params=params, translations=translations, using=using)
+ qs = RawQuerySet(
+ raw_query,
+ model=self.model,
+ params=params,
+ translations=translations,
+ using=using,
+ )
qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
return qs
@@ -981,15 +1062,19 @@ class QuerySet:
if flat and named:
raise TypeError("'flat' and 'named' can't be used together.")
if flat and len(fields) > 1:
- raise TypeError("'flat' is not valid when values_list is called with more than one field.")
+ raise TypeError(
+ "'flat' is not valid when values_list is called with more than one field."
+ )
- field_names = {f for f in fields if not hasattr(f, 'resolve_expression')}
+ field_names = {f for f in fields if not hasattr(f, "resolve_expression")}
_fields = []
expressions = {}
counter = 1
for field in fields:
- if hasattr(field, 'resolve_expression'):
- field_id_prefix = getattr(field, 'default_alias', field.__class__.__name__.lower())
+ if hasattr(field, "resolve_expression"):
+ field_id_prefix = getattr(
+ field, "default_alias", field.__class__.__name__.lower()
+ )
while True:
field_id = field_id_prefix + str(counter)
counter += 1
@@ -1002,59 +1087,71 @@ class QuerySet:
clone = self._values(*_fields, **expressions)
clone._iterable_class = (
- NamedValuesListIterable if named
- else FlatValuesListIterable if flat
+ NamedValuesListIterable
+ if named
+ else FlatValuesListIterable
+ if flat
else ValuesListIterable
)
return clone
- def dates(self, field_name, kind, order='ASC'):
+ def dates(self, field_name, kind, order="ASC"):
"""
Return a list of date objects representing all available dates for
the given field_name, scoped to 'kind'.
"""
- if kind not in ('year', 'month', 'week', 'day'):
+ if kind not in ("year", "month", "week", "day"):
raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.")
- if order not in ('ASC', 'DESC'):
+ if order not in ("ASC", "DESC"):
raise ValueError("'order' must be either 'ASC' or 'DESC'.")
- return self.annotate(
- datefield=Trunc(field_name, kind, output_field=DateField()),
- plain_field=F(field_name)
- ).values_list(
- 'datefield', flat=True
- ).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datefield')
+ return (
+ self.annotate(
+ datefield=Trunc(field_name, kind, output_field=DateField()),
+ plain_field=F(field_name),
+ )
+ .values_list("datefield", flat=True)
+ .distinct()
+ .filter(plain_field__isnull=False)
+ .order_by(("-" if order == "DESC" else "") + "datefield")
+ )
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
- def datetimes(self, field_name, kind, order='ASC', tzinfo=None, is_dst=timezone.NOT_PASSED):
+ def datetimes(
+ self, field_name, kind, order="ASC", tzinfo=None, is_dst=timezone.NOT_PASSED
+ ):
"""
Return a list of datetime objects representing all available
datetimes for the given field_name, scoped to 'kind'.
"""
- if kind not in ('year', 'month', 'week', 'day', 'hour', 'minute', 'second'):
+ if kind not in ("year", "month", "week", "day", "hour", "minute", "second"):
raise ValueError(
"'kind' must be one of 'year', 'month', 'week', 'day', "
"'hour', 'minute', or 'second'."
)
- if order not in ('ASC', 'DESC'):
+ if order not in ("ASC", "DESC"):
raise ValueError("'order' must be either 'ASC' or 'DESC'.")
if settings.USE_TZ:
if tzinfo is None:
tzinfo = timezone.get_current_timezone()
else:
tzinfo = None
- return self.annotate(
- datetimefield=Trunc(
- field_name,
- kind,
- output_field=DateTimeField(),
- tzinfo=tzinfo,
- is_dst=is_dst,
- ),
- plain_field=F(field_name)
- ).values_list(
- 'datetimefield', flat=True
- ).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datetimefield')
+ return (
+ self.annotate(
+ datetimefield=Trunc(
+ field_name,
+ kind,
+ output_field=DateTimeField(),
+ tzinfo=tzinfo,
+ is_dst=is_dst,
+ ),
+ plain_field=F(field_name),
+ )
+ .values_list("datetimefield", flat=True)
+ .distinct()
+ .filter(plain_field__isnull=False)
+ .order_by(("-" if order == "DESC" else "") + "datetimefield")
+ )
def none(self):
"""Return an empty QuerySet."""
@@ -1078,7 +1175,7 @@ class QuerySet:
Return a new QuerySet instance with the args ANDed to the existing
set.
"""
- self._not_support_combined_queries('filter')
+ self._not_support_combined_queries("filter")
return self._filter_or_exclude(False, args, kwargs)
def exclude(self, *args, **kwargs):
@@ -1086,12 +1183,12 @@ class QuerySet:
Return a new QuerySet instance with NOT (args) ANDed to the existing
set.
"""
- self._not_support_combined_queries('exclude')
+ self._not_support_combined_queries("exclude")
return self._filter_or_exclude(True, args, kwargs)
def _filter_or_exclude(self, negate, args, kwargs):
if (args or kwargs) and self.query.is_sliced:
- raise TypeError('Cannot filter a query once a slice has been taken.')
+ raise TypeError("Cannot filter a query once a slice has been taken.")
clone = self._chain()
if self._defer_next_filter:
self._defer_next_filter = False
@@ -1129,7 +1226,9 @@ class QuerySet:
# Clear limits and ordering so they can be reapplied
clone.query.clear_ordering(force=True)
clone.query.clear_limits()
- clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs)
+ clone.query.combined_queries = (self.query,) + tuple(
+ qs.query for qs in other_qs
+ )
clone.query.combinator = combinator
clone.query.combinator_all = all
return clone
@@ -1142,8 +1241,8 @@ class QuerySet:
return self
if len(qs) == 1:
return qs[0]
- return qs[0]._combinator_query('union', *qs[1:], all=all)
- return self._combinator_query('union', *other_qs, all=all)
+ return qs[0]._combinator_query("union", *qs[1:], all=all)
+ return self._combinator_query("union", *other_qs, all=all)
def intersection(self, *other_qs):
# If any query is an EmptyQuerySet, return it.
@@ -1152,13 +1251,13 @@ class QuerySet:
for other in other_qs:
if isinstance(other, EmptyQuerySet):
return other
- return self._combinator_query('intersection', *other_qs)
+ return self._combinator_query("intersection", *other_qs)
def difference(self, *other_qs):
# If the query is an EmptyQuerySet, return it.
if isinstance(self, EmptyQuerySet):
return self
- return self._combinator_query('difference', *other_qs)
+ return self._combinator_query("difference", *other_qs)
def select_for_update(self, nowait=False, skip_locked=False, of=(), no_key=False):
"""
@@ -1166,7 +1265,7 @@ class QuerySet:
FOR UPDATE lock.
"""
if nowait and skip_locked:
- raise ValueError('The nowait option cannot be used with skip_locked.')
+ raise ValueError("The nowait option cannot be used with skip_locked.")
obj = self._chain()
obj._for_write = True
obj.query.select_for_update = True
@@ -1185,9 +1284,11 @@ class QuerySet:
If select_related(None) is called, clear the list.
"""
- self._not_support_combined_queries('select_related')
+ self._not_support_combined_queries("select_related")
if self._fields is not None:
- raise TypeError("Cannot call select_related() after .values() or .values_list()")
+ raise TypeError(
+ "Cannot call select_related() after .values() or .values_list()"
+ )
obj = self._chain()
if fields == (None,):
@@ -1207,7 +1308,7 @@ class QuerySet:
When prefetch_related() is called more than once, append to the list of
prefetch lookups. If prefetch_related(None) is called, clear the list.
"""
- self._not_support_combined_queries('prefetch_related')
+ self._not_support_combined_queries("prefetch_related")
clone = self._chain()
if lookups == (None,):
clone._prefetch_related_lookups = ()
@@ -1217,7 +1318,9 @@ class QuerySet:
lookup = lookup.prefetch_to
lookup = lookup.split(LOOKUP_SEP, 1)[0]
if lookup in self.query._filtered_relations:
- raise ValueError('prefetch_related() is not supported with FilteredRelation.')
+ raise ValueError(
+ "prefetch_related() is not supported with FilteredRelation."
+ )
clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
return clone
@@ -1226,26 +1329,29 @@ class QuerySet:
Return a query set in which the returned objects have been annotated
with extra data or aggregations.
"""
- self._not_support_combined_queries('annotate')
+ self._not_support_combined_queries("annotate")
return self._annotate(args, kwargs, select=True)
def alias(self, *args, **kwargs):
"""
Return a query set with added aliases for extra data or aggregations.
"""
- self._not_support_combined_queries('alias')
+ self._not_support_combined_queries("alias")
return self._annotate(args, kwargs, select=False)
def _annotate(self, args, kwargs, select=True):
- self._validate_values_are_expressions(args + tuple(kwargs.values()), method_name='annotate')
+ self._validate_values_are_expressions(
+ args + tuple(kwargs.values()), method_name="annotate"
+ )
annotations = {}
for arg in args:
# The default_alias property may raise a TypeError.
try:
if arg.default_alias in kwargs:
- raise ValueError("The named annotation '%s' conflicts with the "
- "default name for another annotation."
- % arg.default_alias)
+ raise ValueError(
+ "The named annotation '%s' conflicts with the "
+ "default name for another annotation." % arg.default_alias
+ )
except TypeError:
raise TypeError("Complex annotations require an alias")
annotations[arg.default_alias] = arg
@@ -1254,20 +1360,29 @@ class QuerySet:
clone = self._chain()
names = self._fields
if names is None:
- names = set(chain.from_iterable(
- (field.name, field.attname) if hasattr(field, 'attname') else (field.name,)
- for field in self.model._meta.get_fields()
- ))
+ names = set(
+ chain.from_iterable(
+ (field.name, field.attname)
+ if hasattr(field, "attname")
+ else (field.name,)
+ for field in self.model._meta.get_fields()
+ )
+ )
for alias, annotation in annotations.items():
if alias in names:
- raise ValueError("The annotation '%s' conflicts with a field on "
- "the model." % alias)
+ raise ValueError(
+ "The annotation '%s' conflicts with a field on "
+ "the model." % alias
+ )
if isinstance(annotation, FilteredRelation):
clone.query.add_filtered_relation(annotation, alias)
else:
clone.query.add_annotation(
- annotation, alias, is_summary=False, select=select,
+ annotation,
+ alias,
+ is_summary=False,
+ select=select,
)
for alias, annotation in clone.query.annotations.items():
if alias in annotations and annotation.contains_aggregate:
@@ -1282,7 +1397,7 @@ class QuerySet:
def order_by(self, *field_names):
"""Return a new QuerySet instance with the ordering changed."""
if self.query.is_sliced:
- raise TypeError('Cannot reorder a query once a slice has been taken.')
+ raise TypeError("Cannot reorder a query once a slice has been taken.")
obj = self._chain()
obj.query.clear_ordering(force=True, clear_default=False)
obj.query.add_ordering(*field_names)
@@ -1292,19 +1407,28 @@ class QuerySet:
"""
Return a new QuerySet instance that will select only distinct results.
"""
- self._not_support_combined_queries('distinct')
+ self._not_support_combined_queries("distinct")
if self.query.is_sliced:
- raise TypeError('Cannot create distinct fields once a slice has been taken.')
+ raise TypeError(
+ "Cannot create distinct fields once a slice has been taken."
+ )
obj = self._chain()
obj.query.add_distinct_fields(*field_names)
return obj
- def extra(self, select=None, where=None, params=None, tables=None,
- order_by=None, select_params=None):
+ def extra(
+ self,
+ select=None,
+ where=None,
+ params=None,
+ tables=None,
+ order_by=None,
+ select_params=None,
+ ):
"""Add extra SQL fragments to the query."""
- self._not_support_combined_queries('extra')
+ self._not_support_combined_queries("extra")
if self.query.is_sliced:
- raise TypeError('Cannot change a query once a slice has been taken.')
+ raise TypeError("Cannot change a query once a slice has been taken.")
clone = self._chain()
clone.query.add_extra(select, select_params, where, params, tables, order_by)
return clone
@@ -1312,7 +1436,7 @@ class QuerySet:
def reverse(self):
"""Reverse the ordering of the QuerySet."""
if self.query.is_sliced:
- raise TypeError('Cannot reverse a query once a slice has been taken.')
+ raise TypeError("Cannot reverse a query once a slice has been taken.")
clone = self._chain()
clone.query.standard_ordering = not clone.query.standard_ordering
return clone
@@ -1324,7 +1448,7 @@ class QuerySet:
The only exception to this is if None is passed in as the only
parameter, in which case removal all deferrals.
"""
- self._not_support_combined_queries('defer')
+ self._not_support_combined_queries("defer")
if self._fields is not None:
raise TypeError("Cannot call defer() after .values() or .values_list()")
clone = self._chain()
@@ -1340,7 +1464,7 @@ class QuerySet:
method and that are not already specified as deferred are loaded
immediately when the queryset is evaluated.
"""
- self._not_support_combined_queries('only')
+ self._not_support_combined_queries("only")
if self._fields is not None:
raise TypeError("Cannot call only() after .values() or .values_list()")
if fields == (None,):
@@ -1350,7 +1474,7 @@ class QuerySet:
for field in fields:
field = field.split(LOOKUP_SEP, 1)[0]
if field in self.query._filtered_relations:
- raise ValueError('only() is not supported with FilteredRelation.')
+ raise ValueError("only() is not supported with FilteredRelation.")
clone = self._chain()
clone.query.add_immediate_loading(fields)
return clone
@@ -1376,8 +1500,9 @@ class QuerySet:
if self.query.extra_order_by or self.query.order_by:
return True
elif (
- self.query.default_ordering and
- self.query.get_meta().ordering and
+ self.query.default_ordering
+ and self.query.get_meta().ordering
+ and
# A default ordering doesn't affect GROUP BY queries.
not self.query.group_by
):
@@ -1397,8 +1522,15 @@ class QuerySet:
###################
def _insert(
- self, objs, fields, returning_fields=None, raw=False, using=None,
- on_conflict=None, update_fields=None, unique_fields=None,
+ self,
+ objs,
+ fields,
+ returning_fields=None,
+ raw=False,
+ using=None,
+ on_conflict=None,
+ update_fields=None,
+ unique_fields=None,
):
"""
Insert a new record for the given model. This provides an interface to
@@ -1415,11 +1547,17 @@ class QuerySet:
)
query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(returning_fields)
+
_insert.alters_data = True
_insert.queryset_only = False
def _batched_insert(
- self, objs, fields, batch_size, on_conflict=None, update_fields=None,
+ self,
+ objs,
+ fields,
+ batch_size,
+ on_conflict=None,
+ update_fields=None,
unique_fields=None,
):
"""
@@ -1431,12 +1569,16 @@ class QuerySet:
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
inserted_rows = []
bulk_return = connection.features.can_return_rows_from_bulk_insert
- for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
+ for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
if bulk_return and on_conflict is None:
- inserted_rows.extend(self._insert(
- item, fields=fields, using=self.db,
- returning_fields=self.model._meta.db_returning_fields,
- ))
+ inserted_rows.extend(
+ self._insert(
+ item,
+ fields=fields,
+ using=self.db,
+ returning_fields=self.model._meta.db_returning_fields,
+ )
+ )
else:
self._insert(
item,
@@ -1464,7 +1606,12 @@ class QuerySet:
Return a copy of the current QuerySet. A lightweight alternative
to deepcopy().
"""
- c = self.__class__(model=self.model, query=self.query.chain(), using=self._db, hints=self._hints)
+ c = self.__class__(
+ model=self.model,
+ query=self.query.chain(),
+ using=self._db,
+ hints=self._hints,
+ )
c._sticky_filter = self._sticky_filter
c._for_write = self._for_write
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
@@ -1496,9 +1643,10 @@ class QuerySet:
def _merge_sanity_check(self, other):
"""Check that two QuerySet classes may be merged."""
if self._fields is not None and (
- set(self.query.values_select) != set(other.query.values_select) or
- set(self.query.extra_select) != set(other.query.extra_select) or
- set(self.query.annotation_select) != set(other.query.annotation_select)):
+ set(self.query.values_select) != set(other.query.values_select)
+ or set(self.query.extra_select) != set(other.query.extra_select)
+ or set(self.query.annotation_select) != set(other.query.annotation_select)
+ ):
raise TypeError(
"Merging '%s' classes must involve the same values in each case."
% self.__class__.__name__
@@ -1515,10 +1663,11 @@ class QuerySet:
if self._fields and len(self._fields) > 1:
# values() queryset can only be used as nested queries
# if they are set up to select only a single field.
- raise TypeError('Cannot use multi-field values as a filter value.')
+ raise TypeError("Cannot use multi-field values as a filter value.")
query = self.query.resolve_expression(*args, **kwargs)
query._db = self._db
return query
+
resolve_expression.queryset_only = True
def _add_hints(self, **hints):
@@ -1538,25 +1687,28 @@ class QuerySet:
@staticmethod
def _validate_values_are_expressions(values, method_name):
- invalid_args = sorted(str(arg) for arg in values if not hasattr(arg, 'resolve_expression'))
+ invalid_args = sorted(
+ str(arg) for arg in values if not hasattr(arg, "resolve_expression")
+ )
if invalid_args:
raise TypeError(
- 'QuerySet.%s() received non-expression(s): %s.' % (
+ "QuerySet.%s() received non-expression(s): %s."
+ % (
method_name,
- ', '.join(invalid_args),
+ ", ".join(invalid_args),
)
)
def _not_support_combined_queries(self, operation_name):
if self.query.combinator:
raise NotSupportedError(
- 'Calling QuerySet.%s() after %s() is not supported.'
+ "Calling QuerySet.%s() after %s() is not supported."
% (operation_name, self.query.combinator)
)
def _check_operator_queryset(self, other, operator_):
if self.query.combinator or other.query.combinator:
- raise TypeError(f'Cannot use {operator_} operator with combined queryset.')
+ raise TypeError(f"Cannot use {operator_} operator with combined queryset.")
class InstanceCheckMeta(type):
@@ -1579,8 +1731,17 @@ class RawQuerySet:
Provide an iterator which converts the results of raw SQL queries into
annotated model instances.
"""
- def __init__(self, raw_query, model=None, query=None, params=(),
- translations=None, using=None, hints=None):
+
+ def __init__(
+ self,
+ raw_query,
+ model=None,
+ query=None,
+ params=(),
+ translations=None,
+ using=None,
+ hints=None,
+ ):
self.raw_query = raw_query
self.model = model
self._db = using
@@ -1595,10 +1756,17 @@ class RawQuerySet:
def resolve_model_init_order(self):
"""Resolve the init field names and value positions."""
converter = connections[self.db].introspection.identifier_converter
- model_init_fields = [f for f in self.model._meta.fields if converter(f.column) in self.columns]
- annotation_fields = [(column, pos) for pos, column in enumerate(self.columns)
- if column not in self.model_fields]
- model_init_order = [self.columns.index(converter(f.column)) for f in model_init_fields]
+ model_init_fields = [
+ f for f in self.model._meta.fields if converter(f.column) in self.columns
+ ]
+ annotation_fields = [
+ (column, pos)
+ for pos, column in enumerate(self.columns)
+ if column not in self.model_fields
+ ]
+ model_init_order = [
+ self.columns.index(converter(f.column)) for f in model_init_fields
+ ]
model_init_names = [f.attname for f in model_init_fields]
return model_init_names, model_init_order, annotation_fields
@@ -1618,8 +1786,13 @@ class RawQuerySet:
def _clone(self):
"""Same as QuerySet._clone()"""
c = self.__class__(
- self.raw_query, model=self.model, query=self.query, params=self.params,
- translations=self.translations, using=self._db, hints=self._hints
+ self.raw_query,
+ model=self.model,
+ query=self.query,
+ params=self.params,
+ translations=self.translations,
+ using=self._db,
+ hints=self._hints,
)
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
return c
@@ -1646,20 +1819,24 @@ class RawQuerySet:
# Cache some things for performance reasons outside the loop.
db = self.db
connection = connections[db]
- compiler = connection.ops.compiler('SQLCompiler')(self.query, connection, db)
+ compiler = connection.ops.compiler("SQLCompiler")(self.query, connection, db)
query = iter(self.query)
try:
- model_init_names, model_init_pos, annotation_fields = self.resolve_model_init_order()
+ (
+ model_init_names,
+ model_init_pos,
+ annotation_fields,
+ ) = self.resolve_model_init_order()
if self.model._meta.pk.attname not in model_init_names:
raise exceptions.FieldDoesNotExist(
- 'Raw query must include the primary key'
+ "Raw query must include the primary key"
)
model_cls = self.model
fields = [self.model_fields.get(c) for c in self.columns]
- converters = compiler.get_converters([
- f.get_col(f.model._meta.db_table) if f else None for f in fields
- ])
+ converters = compiler.get_converters(
+ [f.get_col(f.model._meta.db_table) if f else None for f in fields]
+ )
if converters:
query = compiler.apply_converters(query, converters)
for values in query:
@@ -1672,7 +1849,7 @@ class RawQuerySet:
yield instance
finally:
# Done iterating the Query. If it has its own cursor, close it.
- if hasattr(self.query, 'cursor') and self.query.cursor:
+ if hasattr(self.query, "cursor") and self.query.cursor:
self.query.cursor.close()
def __repr__(self):
@@ -1689,9 +1866,11 @@ class RawQuerySet:
def using(self, alias):
"""Select the database this RawQuerySet should execute against."""
return RawQuerySet(
- self.raw_query, model=self.model,
+ self.raw_query,
+ model=self.model,
query=self.query.chain(using=alias),
- params=self.params, translations=self.translations,
+ params=self.params,
+ translations=self.translations,
using=alias,
)
@@ -1731,16 +1910,19 @@ class Prefetch:
# `prefetch_to` is the path to the attribute that stores the result.
self.prefetch_to = lookup
if queryset is not None and (
- isinstance(queryset, RawQuerySet) or (
- hasattr(queryset, '_iterable_class') and
- not issubclass(queryset._iterable_class, ModelIterable)
+ isinstance(queryset, RawQuerySet)
+ or (
+ hasattr(queryset, "_iterable_class")
+ and not issubclass(queryset._iterable_class, ModelIterable)
)
):
raise ValueError(
- 'Prefetch querysets cannot use raw(), values(), and values_list().'
+ "Prefetch querysets cannot use raw(), values(), and values_list()."
)
if to_attr:
- self.prefetch_to = LOOKUP_SEP.join(lookup.split(LOOKUP_SEP)[:-1] + [to_attr])
+ self.prefetch_to = LOOKUP_SEP.join(
+ lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
+ )
self.queryset = queryset
self.to_attr = to_attr
@@ -1752,7 +1934,7 @@ class Prefetch:
# Prevent the QuerySet from being evaluated
queryset._result_cache = []
queryset._prefetch_done = True
- obj_dict['queryset'] = queryset
+ obj_dict["queryset"] = queryset
return obj_dict
def add_prefix(self, prefix):
@@ -1760,7 +1942,7 @@ class Prefetch:
self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
def get_current_prefetch_to(self, level):
- return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[:level + 1])
+ return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
def get_current_to_attr(self, level):
parts = self.prefetch_to.split(LOOKUP_SEP)
@@ -1805,7 +1987,7 @@ def prefetch_related_objects(model_instances, *related_lookups):
# We need to be able to dynamically add to the list of prefetch_related
# lookups that we look up (see below). So we need some book keeping to
# ensure we don't do duplicate work.
- done_queries = {} # dictionary of things like 'foo__bar': [results]
+ done_queries = {} # dictionary of things like 'foo__bar': [results]
auto_lookups = set() # we add to this as we go through.
followed_descriptors = set() # recursion protection
@@ -1815,8 +1997,11 @@ def prefetch_related_objects(model_instances, *related_lookups):
lookup = all_lookups.pop()
if lookup.prefetch_to in done_queries:
if lookup.queryset is not None:
- raise ValueError("'%s' lookup was already seen with a different queryset. "
- "You may need to adjust the ordering of your lookups." % lookup.prefetch_to)
+ raise ValueError(
+ "'%s' lookup was already seen with a different queryset. "
+ "You may need to adjust the ordering of your lookups."
+ % lookup.prefetch_to
+ )
continue
@@ -1842,7 +2027,7 @@ def prefetch_related_objects(model_instances, *related_lookups):
# Since prefetching can re-use instances, it is possible to have
# the same instance multiple times in obj_list, so obj might
# already be prepared.
- if not hasattr(obj, '_prefetched_objects_cache'):
+ if not hasattr(obj, "_prefetched_objects_cache"):
try:
obj._prefetched_objects_cache = {}
except (AttributeError, TypeError):
@@ -1862,20 +2047,30 @@ def prefetch_related_objects(model_instances, *related_lookups):
# of prefetch_related), so what applies to first object applies to all.
first_obj = obj_list[0]
to_attr = lookup.get_current_to_attr(level)[0]
- prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, through_attr, to_attr)
+ prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
+ first_obj, through_attr, to_attr
+ )
if not attr_found:
- raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid "
- "parameter to prefetch_related()" %
- (through_attr, first_obj.__class__.__name__, lookup.prefetch_through))
+ raise AttributeError(
+ "Cannot find '%s' on %s object, '%s' is an invalid "
+ "parameter to prefetch_related()"
+ % (
+ through_attr,
+ first_obj.__class__.__name__,
+ lookup.prefetch_through,
+ )
+ )
if level == len(through_attrs) - 1 and prefetcher is None:
# Last one, this *must* resolve to something that supports
# prefetching, otherwise there is no point adding it and the
# developer asking for it has made a mistake.
- raise ValueError("'%s' does not resolve to an item that supports "
- "prefetching - this is an invalid parameter to "
- "prefetch_related()." % lookup.prefetch_through)
+ raise ValueError(
+ "'%s' does not resolve to an item that supports "
+ "prefetching - this is an invalid parameter to "
+ "prefetch_related()." % lookup.prefetch_through
+ )
obj_to_fetch = None
if prefetcher is not None:
@@ -1892,9 +2087,15 @@ def prefetch_related_objects(model_instances, *related_lookups):
# same relationships to stop infinite recursion. So, if we
# are already on an automatically added lookup, don't add
# the new lookups from relationships we've seen already.
- if not (prefetch_to in done_queries and lookup in auto_lookups and descriptor in followed_descriptors):
+ if not (
+ prefetch_to in done_queries
+ and lookup in auto_lookups
+ and descriptor in followed_descriptors
+ ):
done_queries[prefetch_to] = obj_list
- new_lookups = normalize_prefetch_lookups(reversed(additional_lookups), prefetch_to)
+ new_lookups = normalize_prefetch_lookups(
+ reversed(additional_lookups), prefetch_to
+ )
auto_lookups.update(new_lookups)
all_lookups.extend(new_lookups)
followed_descriptors.add(descriptor)
@@ -1908,7 +2109,7 @@ def prefetch_related_objects(model_instances, *related_lookups):
# that we can continue with nullable or reverse relations.
new_obj_list = []
for obj in obj_list:
- if through_attr in getattr(obj, '_prefetched_objects_cache', ()):
+ if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
# If related objects have been prefetched, use the
# cache rather than the object's through_attr.
new_obj = list(obj._prefetched_objects_cache.get(through_attr))
@@ -1940,6 +2141,7 @@ def get_prefetcher(instance, through_attr, to_attr):
a function that takes an instance and returns a boolean that is True if
the attribute has already been fetched for that instance)
"""
+
def has_to_attr_attribute(instance):
return hasattr(instance, to_attr)
@@ -1957,7 +2159,7 @@ def get_prefetcher(instance, through_attr, to_attr):
if rel_obj_descriptor:
# singly related object, descriptor object has the
# get_prefetch_queryset() method.
- if hasattr(rel_obj_descriptor, 'get_prefetch_queryset'):
+ if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
prefetcher = rel_obj_descriptor
is_fetched = rel_obj_descriptor.is_cached
else:
@@ -1965,17 +2167,21 @@ def get_prefetcher(instance, through_attr, to_attr):
# the attribute on the instance rather than the class to
# support many related managers
rel_obj = getattr(instance, through_attr)
- if hasattr(rel_obj, 'get_prefetch_queryset'):
+ if hasattr(rel_obj, "get_prefetch_queryset"):
prefetcher = rel_obj
if through_attr != to_attr:
# Special case cached_property instances because hasattr
# triggers attribute computation and assignment.
- if isinstance(getattr(instance.__class__, to_attr, None), cached_property):
+ if isinstance(
+ getattr(instance.__class__, to_attr, None), cached_property
+ ):
+
def has_cached_property(instance):
return to_attr in instance.__dict__
is_fetched = has_cached_property
else:
+
def in_prefetched_cache(instance):
return through_attr in instance._prefetched_objects_cache
@@ -2006,8 +2212,14 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
# The 'values to be matched' must be hashable as they will be used
# in a dictionary.
- rel_qs, rel_obj_attr, instance_attr, single, cache_name, is_descriptor = (
- prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level)))
+ (
+ rel_qs,
+ rel_obj_attr,
+ instance_attr,
+ single,
+ cache_name,
+ is_descriptor,
+ ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
# We have to handle the possibility that the QuerySet we just got back
# contains some prefetch_related lookups. We don't want to trigger the
# prefetch_related functionality by evaluating the query. Rather, we need
@@ -2015,8 +2227,8 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
# Copy the lookups in case it is a Prefetch object which could be reused
# later (happens in nested prefetch_related).
additional_lookups = [
- copy.copy(additional_lookup) for additional_lookup
- in getattr(rel_qs, '_prefetch_related_lookups', ())
+ copy.copy(additional_lookup)
+ for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
]
if additional_lookups:
# Don't need to clone because the manager should have given us a fresh
@@ -2042,7 +2254,7 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
except exceptions.FieldDoesNotExist:
pass
else:
- msg = 'to_attr={} conflicts with a field on the {} model.'
+ msg = "to_attr={} conflicts with a field on the {} model."
raise ValueError(msg.format(to_attr, model.__name__))
# Whether or not we're prefetching the last part of the lookup.
@@ -2098,6 +2310,7 @@ class RelatedPopulator:
method gets row and from_obj as input and populates the select_related()
model instance.
"""
+
def __init__(self, klass_info, select, db):
self.db = db
# Pre-compute needed attributes. The attributes are:
@@ -2123,32 +2336,40 @@ class RelatedPopulator:
# - local_setter, remote_setter: Methods to set cached values on
# the object being populated and on the remote object. Usually
# these are Field.set_cached_value() methods.
- select_fields = klass_info['select_fields']
- from_parent = klass_info['from_parent']
+ select_fields = klass_info["select_fields"]
+ from_parent = klass_info["from_parent"]
if not from_parent:
self.cols_start = select_fields[0]
self.cols_end = select_fields[-1] + 1
self.init_list = [
- f[0].target.attname for f in select[self.cols_start:self.cols_end]
+ f[0].target.attname for f in select[self.cols_start : self.cols_end]
]
self.reorder_for_init = None
else:
- attname_indexes = {select[idx][0].target.attname: idx for idx in select_fields}
- model_init_attnames = (f.attname for f in klass_info['model']._meta.concrete_fields)
- self.init_list = [attname for attname in model_init_attnames if attname in attname_indexes]
- self.reorder_for_init = operator.itemgetter(*[attname_indexes[attname] for attname in self.init_list])
+ attname_indexes = {
+ select[idx][0].target.attname: idx for idx in select_fields
+ }
+ model_init_attnames = (
+ f.attname for f in klass_info["model"]._meta.concrete_fields
+ )
+ self.init_list = [
+ attname for attname in model_init_attnames if attname in attname_indexes
+ ]
+ self.reorder_for_init = operator.itemgetter(
+ *[attname_indexes[attname] for attname in self.init_list]
+ )
- self.model_cls = klass_info['model']
+ self.model_cls = klass_info["model"]
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
self.related_populators = get_related_populators(klass_info, select, self.db)
- self.local_setter = klass_info['local_setter']
- self.remote_setter = klass_info['remote_setter']
+ self.local_setter = klass_info["local_setter"]
+ self.remote_setter = klass_info["remote_setter"]
def populate(self, row, from_obj):
if self.reorder_for_init:
obj_data = self.reorder_for_init(row)
else:
- obj_data = row[self.cols_start:self.cols_end]
+ obj_data = row[self.cols_start : self.cols_end]
if obj_data[self.pk_idx] is None:
obj = None
else:
@@ -2162,7 +2383,7 @@ class RelatedPopulator:
def get_related_populators(klass_info, select, db):
iterators = []
- related_klass_infos = klass_info.get('related_klass_infos', [])
+ related_klass_infos = klass_info.get("related_klass_infos", [])
for rel_klass_info in related_klass_infos:
rel_cls = RelatedPopulator(rel_klass_info, select, db)
iterators.append(rel_cls)
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index 188b640850..6ea82b6520 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -17,7 +17,10 @@ from django.utils import tree
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
-PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
+PathInfo = namedtuple(
+ "PathInfo",
+ "from_opts to_opts target_fields join_field m2m direct filtered_relation",
+)
def subclasses(cls):
@@ -31,21 +34,26 @@ class Q(tree.Node):
Encapsulate filters as objects that can then be combined logically (using
`&` and `|`).
"""
+
# Connection types
- AND = 'AND'
- OR = 'OR'
+ AND = "AND"
+ OR = "OR"
default = AND
conditional = True
def __init__(self, *args, _connector=None, _negated=False, **kwargs):
- super().__init__(children=[*args, *sorted(kwargs.items())], connector=_connector, negated=_negated)
+ super().__init__(
+ children=[*args, *sorted(kwargs.items())],
+ connector=_connector,
+ negated=_negated,
+ )
def _combine(self, other, conn):
- if not(isinstance(other, Q) or getattr(other, 'conditional', False) is True):
+ if not (isinstance(other, Q) or getattr(other, "conditional", False) is True):
raise TypeError(other)
if not self:
- return other.copy() if hasattr(other, 'copy') else copy.copy(other)
+ return other.copy() if hasattr(other, "copy") else copy.copy(other)
elif isinstance(other, Q) and not other:
_, args, kwargs = self.deconstruct()
return type(self)(*args, **kwargs)
@@ -68,26 +76,31 @@ class Q(tree.Node):
obj.negate()
return obj
- def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+ def resolve_expression(
+ self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+ ):
# We must promote any new joins to left outer joins so that when Q is
# used as an expression, rows aren't filtered due to joins.
clause, joins = query._add_q(
- self, reuse, allow_joins=allow_joins, split_subq=False,
+ self,
+ reuse,
+ allow_joins=allow_joins,
+ split_subq=False,
check_filterable=False,
)
query.promote_joins(joins)
return clause
def deconstruct(self):
- path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
- if path.startswith('django.db.models.query_utils'):
- path = path.replace('django.db.models.query_utils', 'django.db.models')
+ path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
+ if path.startswith("django.db.models.query_utils"):
+ path = path.replace("django.db.models.query_utils", "django.db.models")
args = tuple(self.children)
kwargs = {}
if self.connector != self.default:
- kwargs['_connector'] = self.connector
+ kwargs["_connector"] = self.connector
if self.negated:
- kwargs['_negated'] = True
+ kwargs["_negated"] = True
return path, args, kwargs
@@ -96,6 +109,7 @@ class DeferredAttribute:
A wrapper for a deferred-loading field. When the value is read from this
object the first time, the query is executed.
"""
+
def __init__(self, field):
self.field = field
@@ -132,7 +146,6 @@ class DeferredAttribute:
class RegisterLookupMixin:
-
@classmethod
def _get_lookup(cls, lookup_name):
return cls.get_lookups().get(lookup_name, None)
@@ -140,13 +153,16 @@ class RegisterLookupMixin:
@classmethod
@functools.lru_cache(maxsize=None)
def get_lookups(cls):
- class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in inspect.getmro(cls)]
+ class_lookups = [
+ parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
+ ]
return cls.merge_dicts(class_lookups)
def get_lookup(self, lookup_name):
from django.db.models.lookups import Lookup
+
found = self._get_lookup(lookup_name)
- if found is None and hasattr(self, 'output_field'):
+ if found is None and hasattr(self, "output_field"):
return self.output_field.get_lookup(lookup_name)
if found is not None and not issubclass(found, Lookup):
return None
@@ -154,8 +170,9 @@ class RegisterLookupMixin:
def get_transform(self, lookup_name):
from django.db.models.lookups import Transform
+
found = self._get_lookup(lookup_name)
- if found is None and hasattr(self, 'output_field'):
+ if found is None and hasattr(self, "output_field"):
return self.output_field.get_transform(lookup_name)
if found is not None and not issubclass(found, Transform):
return None
@@ -181,7 +198,7 @@ class RegisterLookupMixin:
def register_lookup(cls, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
- if 'class_lookups' not in cls.__dict__:
+ if "class_lookups" not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup_name] = lookup
cls._clear_cached_lookups()
@@ -228,8 +245,8 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
if field.attname not in load_fields:
if restricted and field.name in requested:
msg = (
- 'Field %s.%s cannot be both deferred and traversed using '
- 'select_related at the same time.'
+ "Field %s.%s cannot be both deferred and traversed using "
+ "select_related at the same time."
) % (field.model._meta.object_name, field.name)
raise FieldError(msg)
return True
@@ -255,12 +272,14 @@ def check_rel_lookup_compatibility(model, target_opts, field):
1) model and opts match (where proxy inheritance is removed)
2) model is parent of opts' model or the other way around
"""
+
def check(opts):
return (
- model._meta.concrete_model == opts.concrete_model or
- opts.concrete_model in model._meta.get_parent_list() or
- model in opts.get_parent_list()
+ model._meta.concrete_model == opts.concrete_model
+ or opts.concrete_model in model._meta.get_parent_list()
+ or model in opts.get_parent_list()
)
+
# If the field is a primary key, then doing a query against the field's
# model is ok, too. Consider the case:
# class Restaurant(models.Model):
@@ -270,9 +289,8 @@ def check_rel_lookup_compatibility(model, target_opts, field):
# give Place's opts as the target opts, but Restaurant isn't compatible
# with that. This logic applies only to primary keys, as when doing __in=qs,
# we are going to turn this into __in=qs.values('pk') later on.
- return (
- check(target_opts) or
- (getattr(field, 'primary_key', False) and check(field.model._meta))
+ return check(target_opts) or (
+ getattr(field, "primary_key", False) and check(field.model._meta)
)
@@ -281,11 +299,11 @@ class FilteredRelation:
def __init__(self, relation_name, *, condition=Q()):
if not relation_name:
- raise ValueError('relation_name cannot be empty.')
+ raise ValueError("relation_name cannot be empty.")
self.relation_name = relation_name
self.alias = None
if not isinstance(condition, Q):
- raise ValueError('condition argument must be a Q() instance.')
+ raise ValueError("condition argument must be a Q() instance.")
self.condition = condition
self.path = []
@@ -293,9 +311,9 @@ class FilteredRelation:
if not isinstance(other, self.__class__):
return NotImplemented
return (
- self.relation_name == other.relation_name and
- self.alias == other.alias and
- self.condition == other.condition
+ self.relation_name == other.relation_name
+ and self.alias == other.alias
+ and self.condition == other.condition
)
def clone(self):
@@ -309,7 +327,7 @@ class FilteredRelation:
QuerySet.annotate() only accepts expression-like arguments
(with a resolve_expression() method).
"""
- raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
+ raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")
def as_sql(self, compiler, connection):
# Resolve the condition in Join.filtered_relation.
diff --git a/django/db/models/signals.py b/django/db/models/signals.py
index d14eaaf91d..a0720937af 100644
--- a/django/db/models/signals.py
+++ b/django/db/models/signals.py
@@ -11,6 +11,7 @@ class ModelSignal(Signal):
Signal subclass that allows the sender to be lazily specified as a string
of the `app_label.ModelName` form.
"""
+
def _lazy_method(self, method, apps, receiver, sender, **kwargs):
from django.db.models.options import Options
@@ -24,8 +25,12 @@ class ModelSignal(Signal):
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
self._lazy_method(
- super().connect, apps, receiver, sender,
- weak=weak, dispatch_uid=dispatch_uid,
+ super().connect,
+ apps,
+ receiver,
+ sender,
+ weak=weak,
+ dispatch_uid=dispatch_uid,
)
def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
diff --git a/django/db/models/sql/__init__.py b/django/db/models/sql/__init__.py
index 5fa52f6a1f..2956e047b1 100644
--- a/django/db/models/sql/__init__.py
+++ b/django/db/models/sql/__init__.py
@@ -3,4 +3,4 @@ from django.db.models.sql.query import Query
from django.db.models.sql.subqueries import * # NOQA
from django.db.models.sql.where import AND, OR
-__all__ = ['Query', 'AND', 'OR']
+__all__ = ["Query", "AND", "OR"]
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index d405a203ee..13a7ec7263 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -11,7 +11,12 @@ from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
from django.db.models.functions import Cast, Random
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (
- CURSOR, GET_ITERATOR_CHUNK_SIZE, MULTI, NO_RESULTS, ORDER_DIR, SINGLE,
+ CURSOR,
+ GET_ITERATOR_CHUNK_SIZE,
+ MULTI,
+ NO_RESULTS,
+ ORDER_DIR,
+ SINGLE,
)
from django.db.models.sql.query import Query, get_order_dir
from django.db.transaction import TransactionManagementError
@@ -23,7 +28,7 @@ from django.utils.regex_helper import _lazy_re_compile
class SQLCompiler:
# Multiline ordering SQL clause may appear from RawSQL.
ordering_parts = _lazy_re_compile(
- r'^(.*)\s(?:ASC|DESC).*',
+ r"^(.*)\s(?:ASC|DESC).*",
re.MULTILINE | re.DOTALL,
)
@@ -34,7 +39,7 @@ class SQLCompiler:
# Some queries, e.g. coalesced aggregation, need to be executed even if
# they would return an empty result set.
self.elide_empty = elide_empty
- self.quote_cache = {'*': '*'}
+ self.quote_cache = {"*": "*"}
# The select, klass_info, and annotations are needed by QuerySet.iterator()
# these are set as a side-effect of executing the query. Note that we calculate
# separately a list of extra select columns needed for grammatical correctness
@@ -46,9 +51,9 @@ class SQLCompiler:
def __repr__(self):
return (
- f'<{self.__class__.__qualname__} '
- f'model={self.query.model.__qualname__} '
- f'connection={self.connection!r} using={self.using!r}>'
+ f"<{self.__class__.__qualname__} "
+ f"model={self.query.model.__qualname__} "
+ f"connection={self.connection!r} using={self.using!r}>"
)
def setup_query(self):
@@ -118,16 +123,14 @@ class SQLCompiler:
# when we have public API way of forcing the GROUP BY clause.
# Converts string references to expressions.
for expr in self.query.group_by:
- if not hasattr(expr, 'as_sql'):
+ if not hasattr(expr, "as_sql"):
expressions.append(self.query.resolve_ref(expr))
else:
expressions.append(expr)
# Note that even if the group_by is set, it is only the minimal
# set to group by. So, we need to add cols in select, order_by, and
# having into the select in any case.
- ref_sources = {
- expr.source for expr in expressions if isinstance(expr, Ref)
- }
+ ref_sources = {expr.source for expr in expressions if isinstance(expr, Ref)}
for expr, _, _ in select:
# Skip members of the select clause that are already included
# by reference.
@@ -169,8 +172,10 @@ class SQLCompiler:
for expr in expressions:
# Is this a reference to query's base table primary key? If the
# expression isn't a Col-like, then skip the expression.
- if (getattr(expr, 'target', None) == self.query.model._meta.pk and
- getattr(expr, 'alias', None) == self.query.base_table):
+ if (
+ getattr(expr, "target", None) == self.query.model._meta.pk
+ and getattr(expr, "alias", None) == self.query.base_table
+ ):
pk = expr
break
# If the main model's primary key is in the query, group by that
@@ -178,13 +183,17 @@ class SQLCompiler:
# that don't have a primary key included in the grouped columns.
if pk:
pk_aliases = {
- expr.alias for expr in expressions
- if hasattr(expr, 'target') and expr.target.primary_key
+ expr.alias
+ for expr in expressions
+ if hasattr(expr, "target") and expr.target.primary_key
}
expressions = [pk] + [
- expr for expr in expressions
- if expr in having or (
- getattr(expr, 'alias', None) is not None and expr.alias not in pk_aliases
+ expr
+ for expr in expressions
+ if expr in having
+ or (
+ getattr(expr, "alias", None) is not None
+ and expr.alias not in pk_aliases
)
]
elif self.connection.features.allows_group_by_selected_pks:
@@ -195,16 +204,21 @@ class SQLCompiler:
# Unmanaged models are excluded because they could be representing
# database views on which the optimization might not be allowed.
pks = {
- expr for expr in expressions
+ expr
+ for expr in expressions
if (
- hasattr(expr, 'target') and
- expr.target.primary_key and
- self.connection.features.allows_group_by_selected_pks_on_model(expr.target.model)
+ hasattr(expr, "target")
+ and expr.target.primary_key
+ and self.connection.features.allows_group_by_selected_pks_on_model(
+ expr.target.model
+ )
)
}
aliases = {expr.alias for expr in pks}
expressions = [
- expr for expr in expressions if expr in pks or getattr(expr, 'alias', None) not in aliases
+ expr
+ for expr in expressions
+ if expr in pks or getattr(expr, "alias", None) not in aliases
]
return expressions
@@ -248,8 +262,8 @@ class SQLCompiler:
select.append((col, None))
select_idx += 1
klass_info = {
- 'model': self.query.model,
- 'select_fields': select_list,
+ "model": self.query.model,
+ "select_fields": select_list,
}
for alias, annotation in self.query.annotation_select.items():
annotations[alias] = select_idx
@@ -258,14 +272,16 @@ class SQLCompiler:
if self.query.select_related:
related_klass_infos = self.get_related_selections(select)
- klass_info['related_klass_infos'] = related_klass_infos
+ klass_info["related_klass_infos"] = related_klass_infos
def get_select_from_parent(klass_info):
- for ki in klass_info['related_klass_infos']:
- if ki['from_parent']:
- ki['select_fields'] = (klass_info['select_fields'] +
- ki['select_fields'])
+ for ki in klass_info["related_klass_infos"]:
+ if ki["from_parent"]:
+ ki["select_fields"] = (
+ klass_info["select_fields"] + ki["select_fields"]
+ )
get_select_from_parent(ki)
+
get_select_from_parent(klass_info)
ret = []
@@ -273,10 +289,12 @@ class SQLCompiler:
try:
sql, params = self.compile(col)
except EmptyResultSet:
- empty_result_set_value = getattr(col, 'empty_result_set_value', NotImplemented)
+ empty_result_set_value = getattr(
+ col, "empty_result_set_value", NotImplemented
+ )
if empty_result_set_value is NotImplemented:
# Select a predicate that's always False.
- sql, params = '0', ()
+ sql, params = "0", ()
else:
sql, params = self.compile(Value(empty_result_set_value))
else:
@@ -297,12 +315,12 @@ class SQLCompiler:
else:
ordering = []
if self.query.standard_ordering:
- default_order, _ = ORDER_DIR['ASC']
+ default_order, _ = ORDER_DIR["ASC"]
else:
- default_order, _ = ORDER_DIR['DESC']
+ default_order, _ = ORDER_DIR["DESC"]
for field in ordering:
- if hasattr(field, 'resolve_expression'):
+ if hasattr(field, "resolve_expression"):
if isinstance(field, Value):
# output_field must be resolved for constants.
field = Cast(field, field.output_field)
@@ -313,12 +331,12 @@ class SQLCompiler:
field.reverse_ordering()
yield field, False
continue
- if field == '?': # random
+ if field == "?": # random
yield OrderBy(Random()), False
continue
col, order = get_order_dir(field, default_order)
- descending = order == 'DESC'
+ descending = order == "DESC"
if col in self.query.annotation_select:
# Reference to expression in SELECT clause
@@ -345,13 +363,15 @@ class SQLCompiler:
yield OrderBy(expr, descending=descending), False
continue
- if '.' in field:
+ if "." in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
- table, col = col.split('.', 1)
+ table, col = col.split(".", 1)
yield (
OrderBy(
- RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []),
+ RawSQL(
+ "%s.%s" % (self.quote_name_unless_alias(table), col), []
+ ),
descending=descending,
),
False,
@@ -361,7 +381,10 @@ class SQLCompiler:
if self.query.extra and col in self.query.extra:
if col in self.query.extra_select:
yield (
- OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending),
+ OrderBy(
+ Ref(col, RawSQL(*self.query.extra[col])),
+ descending=descending,
+ ),
True,
)
else:
@@ -378,7 +401,9 @@ class SQLCompiler:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
yield from self.find_ordering_name(
- field, self.query.get_meta(), default_order=default_order,
+ field,
+ self.query.get_meta(),
+ default_order=default_order,
)
def get_order_by(self):
@@ -409,19 +434,21 @@ class SQLCompiler:
):
continue
if src == sel_expr:
- resolved.set_source_expressions([RawSQL('%d' % (idx + 1), ())])
+ resolved.set_source_expressions([RawSQL("%d" % (idx + 1), ())])
break
else:
if col_alias:
- raise DatabaseError('ORDER BY term does not match any column in the result set.')
+ raise DatabaseError(
+ "ORDER BY term does not match any column in the result set."
+ )
# Add column used in ORDER BY clause to the selected
# columns and to each combined query.
order_by_idx = len(self.query.select) + 1
- col_name = f'__orderbycol{order_by_idx}'
+ col_name = f"__orderbycol{order_by_idx}"
for q in self.query.combined_queries:
q.add_annotation(expr_src, col_name)
self.query.add_select_col(resolved, col_name)
- resolved.set_source_expressions([RawSQL(f'{order_by_idx}', ())])
+ resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())])
sql, params = self.compile(resolved)
# Don't add the same column twice, but the order direction is
# not taken into account so we strip it. When this entire method
@@ -453,9 +480,14 @@ class SQLCompiler:
"""
if name in self.quote_cache:
return self.quote_cache[name]
- if ((name in self.query.alias_map and name not in self.query.table_map) or
- name in self.query.extra_select or (
- self.query.external_aliases.get(name) and name not in self.query.table_map)):
+ if (
+ (name in self.query.alias_map and name not in self.query.table_map)
+ or name in self.query.extra_select
+ or (
+ self.query.external_aliases.get(name)
+ and name not in self.query.table_map
+ )
+ ):
self.quote_cache[name] = name
return name
r = self.connection.ops.quote_name(name)
@@ -463,7 +495,7 @@ class SQLCompiler:
return r
def compile(self, node):
- vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
+ vendor_impl = getattr(node, "as_" + self.connection.vendor, None)
if vendor_impl:
sql, params = vendor_impl(self, self.connection)
else:
@@ -474,14 +506,19 @@ class SQLCompiler:
features = self.connection.features
compilers = [
query.get_compiler(self.using, self.connection, self.elide_empty)
- for query in self.query.combined_queries if not query.is_empty()
+ for query in self.query.combined_queries
+ if not query.is_empty()
]
if not features.supports_slicing_ordering_in_compound:
for query, compiler in zip(self.query.combined_queries, compilers):
if query.low_mark or query.high_mark:
- raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
+ raise DatabaseError(
+ "LIMIT/OFFSET not allowed in subqueries of compound statements."
+ )
if compiler.get_order_by():
- raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
+ raise DatabaseError(
+ "ORDER BY not allowed in subqueries of compound statements."
+ )
parts = ()
for compiler in compilers:
try:
@@ -490,41 +527,45 @@ class SQLCompiler:
# the query on all combined queries, if not already set.
if not compiler.query.values_select and self.query.values_select:
compiler.query = compiler.query.clone()
- compiler.query.set_values((
- *self.query.extra_select,
- *self.query.values_select,
- *self.query.annotation_select,
- ))
+ compiler.query.set_values(
+ (
+ *self.query.extra_select,
+ *self.query.values_select,
+ *self.query.annotation_select,
+ )
+ )
part_sql, part_args = compiler.as_sql()
if compiler.query.combinator:
# Wrap in a subquery if wrapping in parentheses isn't
# supported.
if not features.supports_parentheses_in_compound:
- part_sql = 'SELECT * FROM ({})'.format(part_sql)
+ part_sql = "SELECT * FROM ({})".format(part_sql)
# Add parentheses when combining with compound query if not
# already added for all compound queries.
elif (
- self.query.subquery or
- not features.supports_slicing_ordering_in_compound
+ self.query.subquery
+ or not features.supports_slicing_ordering_in_compound
):
- part_sql = '({})'.format(part_sql)
+ part_sql = "({})".format(part_sql)
parts += ((part_sql, part_args),)
except EmptyResultSet:
# Omit the empty queryset with UNION and with DIFFERENCE if the
# first queryset is nonempty.
- if combinator == 'union' or (combinator == 'difference' and parts):
+ if combinator == "union" or (combinator == "difference" and parts):
continue
raise
if not parts:
raise EmptyResultSet
combinator_sql = self.connection.ops.set_operators[combinator]
- if all and combinator == 'union':
- combinator_sql += ' ALL'
- braces = '{}'
+ if all and combinator == "union":
+ combinator_sql += " ALL"
+ braces = "{}"
if not self.query.subquery and features.supports_slicing_ordering_in_compound:
- braces = '({})'
- sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
- result = [' {} '.format(combinator_sql).join(sql_parts)]
+ braces = "({})"
+ sql_parts, args_parts = zip(
+ *((braces.format(sql), args) for sql, args in parts)
+ )
+ result = [" {} ".format(combinator_sql).join(sql_parts)]
params = []
for part in args_parts:
params.extend(part)
@@ -543,27 +584,39 @@ class SQLCompiler:
extra_select, order_by, group_by = self.pre_sql_setup()
for_update_part = None
# Is a LIMIT/OFFSET clause needed?
- with_limit_offset = with_limits and (self.query.high_mark is not None or self.query.low_mark)
+ with_limit_offset = with_limits and (
+ self.query.high_mark is not None or self.query.low_mark
+ )
combinator = self.query.combinator
features = self.connection.features
if combinator:
- if not getattr(features, 'supports_select_{}'.format(combinator)):
- raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
- result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
+ if not getattr(features, "supports_select_{}".format(combinator)):
+ raise NotSupportedError(
+ "{} is not supported on this database backend.".format(
+ combinator
+ )
+ )
+ result, params = self.get_combinator_sql(
+ combinator, self.query.combinator_all
+ )
else:
distinct_fields, distinct_params = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause()
try:
- where, w_params = self.compile(self.where) if self.where is not None else ('', [])
+ where, w_params = (
+ self.compile(self.where) if self.where is not None else ("", [])
+ )
except EmptyResultSet:
if self.elide_empty:
raise
# Use a predicate that's always False.
- where, w_params = '0 = 1', []
- having, h_params = self.compile(self.having) if self.having is not None else ("", [])
- result = ['SELECT']
+ where, w_params = "0 = 1", []
+ having, h_params = (
+ self.compile(self.having) if self.having is not None else ("", [])
+ )
+ result = ["SELECT"]
params = []
if self.query.distinct:
@@ -578,27 +631,38 @@ class SQLCompiler:
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
- s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
+ s_sql = "%s AS %s" % (
+ s_sql,
+ self.connection.ops.quote_name(alias),
+ )
elif with_col_aliases:
- s_sql = '%s AS %s' % (
+ s_sql = "%s AS %s" % (
s_sql,
- self.connection.ops.quote_name('col%d' % col_idx),
+ self.connection.ops.quote_name("col%d" % col_idx),
)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)
- result += [', '.join(out_cols), 'FROM', *from_]
+ result += [", ".join(out_cols), "FROM", *from_]
params.extend(f_params)
- if self.query.select_for_update and self.connection.features.has_select_for_update:
+ if (
+ self.query.select_for_update
+ and self.connection.features.has_select_for_update
+ ):
if self.connection.get_autocommit():
- raise TransactionManagementError('select_for_update cannot be used outside of a transaction.')
+ raise TransactionManagementError(
+ "select_for_update cannot be used outside of a transaction."
+ )
- if with_limit_offset and not self.connection.features.supports_select_for_update_with_limit:
+ if (
+ with_limit_offset
+ and not self.connection.features.supports_select_for_update_with_limit
+ ):
raise NotSupportedError(
- 'LIMIT/OFFSET is not supported with '
- 'select_for_update on this database backend.'
+ "LIMIT/OFFSET is not supported with "
+ "select_for_update on this database backend."
)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
@@ -607,16 +671,31 @@ class SQLCompiler:
# If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the
# backend doesn't support it, raise NotSupportedError to
# prevent a possible deadlock.
- if nowait and not self.connection.features.has_select_for_update_nowait:
- raise NotSupportedError('NOWAIT is not supported on this database backend.')
- elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
- raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
+ if (
+ nowait
+ and not self.connection.features.has_select_for_update_nowait
+ ):
+ raise NotSupportedError(
+ "NOWAIT is not supported on this database backend."
+ )
+ elif (
+ skip_locked
+ and not self.connection.features.has_select_for_update_skip_locked
+ ):
+ raise NotSupportedError(
+ "SKIP LOCKED is not supported on this database backend."
+ )
elif of and not self.connection.features.has_select_for_update_of:
- raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
- elif no_key and not self.connection.features.has_select_for_no_key_update:
raise NotSupportedError(
- 'FOR NO KEY UPDATE is not supported on this '
- 'database backend.'
+ "FOR UPDATE OF is not supported on this database backend."
+ )
+ elif (
+ no_key
+ and not self.connection.features.has_select_for_no_key_update
+ ):
+ raise NotSupportedError(
+ "FOR NO KEY UPDATE is not supported on this "
+ "database backend."
)
for_update_part = self.connection.ops.for_update_sql(
nowait=nowait,
@@ -629,7 +708,7 @@ class SQLCompiler:
result.append(for_update_part)
if where:
- result.append('WHERE %s' % where)
+ result.append("WHERE %s" % where)
params.extend(w_params)
grouping = []
@@ -638,30 +717,39 @@ class SQLCompiler:
params.extend(g_params)
if grouping:
if distinct_fields:
- raise NotImplementedError('annotate() + distinct(fields) is not implemented.')
+ raise NotImplementedError(
+ "annotate() + distinct(fields) is not implemented."
+ )
order_by = order_by or self.connection.ops.force_no_ordering()
- result.append('GROUP BY %s' % ', '.join(grouping))
+ result.append("GROUP BY %s" % ", ".join(grouping))
if self._meta_ordering:
order_by = None
if having:
- result.append('HAVING %s' % having)
+ result.append("HAVING %s" % having)
params.extend(h_params)
if self.query.explain_info:
- result.insert(0, self.connection.ops.explain_query_prefix(
- self.query.explain_info.format,
- **self.query.explain_info.options
- ))
+ result.insert(
+ 0,
+ self.connection.ops.explain_query_prefix(
+ self.query.explain_info.format,
+ **self.query.explain_info.options,
+ ),
+ )
if order_by:
ordering = []
for _, (o_sql, o_params, _) in order_by:
ordering.append(o_sql)
params.extend(o_params)
- result.append('ORDER BY %s' % ', '.join(ordering))
+ result.append("ORDER BY %s" % ", ".join(ordering))
if with_limit_offset:
- result.append(self.connection.ops.limit_offset_sql(self.query.low_mark, self.query.high_mark))
+ result.append(
+ self.connection.ops.limit_offset_sql(
+ self.query.low_mark, self.query.high_mark
+ )
+ )
if for_update_part and not self.connection.features.for_update_after_from:
result.append(for_update_part)
@@ -677,23 +765,30 @@ class SQLCompiler:
sub_params = []
for index, (select, _, alias) in enumerate(self.select, start=1):
if not alias and with_col_aliases:
- alias = 'col%d' % index
+ alias = "col%d" % index
if alias:
- sub_selects.append("%s.%s" % (
- self.connection.ops.quote_name('subquery'),
- self.connection.ops.quote_name(alias),
- ))
+ sub_selects.append(
+ "%s.%s"
+ % (
+ self.connection.ops.quote_name("subquery"),
+ self.connection.ops.quote_name(alias),
+ )
+ )
else:
- select_clone = select.relabeled_clone({select.alias: 'subquery'})
- subselect, subparams = select_clone.as_sql(self, self.connection)
+ select_clone = select.relabeled_clone(
+ {select.alias: "subquery"}
+ )
+ subselect, subparams = select_clone.as_sql(
+ self, self.connection
+ )
sub_selects.append(subselect)
sub_params.extend(subparams)
- return 'SELECT %s FROM (%s) subquery' % (
- ', '.join(sub_selects),
- ' '.join(result),
+ return "SELECT %s FROM (%s) subquery" % (
+ ", ".join(sub_selects),
+ " ".join(result),
), tuple(sub_params + params)
- return ' '.join(result), tuple(params)
+ return " ".join(result), tuple(params)
finally:
# Finally do cleanup - get rid of the joins we created above.
self.query.reset_refcounts(refcounts_before)
@@ -726,8 +821,13 @@ class SQLCompiler:
# will assign None if the field belongs to this model.
if model == opts.model:
model = None
- if from_parent and model is not None and issubclass(
- from_parent._meta.concrete_model, model._meta.concrete_model):
+ if (
+ from_parent
+ and model is not None
+ and issubclass(
+ from_parent._meta.concrete_model, model._meta.concrete_model
+ )
+ ):
# Avoid loading data for already loaded parents.
# We end up here in the case select_related() resolution
# proceeds from parent model to child model. In that case the
@@ -736,8 +836,7 @@ class SQLCompiler:
continue
if field.model in only_load and field.attname not in only_load[field.model]:
continue
- alias = self.query.join_parent_model(opts, model, start_alias,
- seen_models)
+ alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias)
result.append(column)
return result
@@ -755,7 +854,9 @@ class SQLCompiler:
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
- _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None)
+ _, targets, alias, joins, path, _, transform_function = self._setup_joins(
+ parts, opts, None
+ )
targets, alias, _ = self.query.trim_joins(targets, joins, path)
for target in targets:
if name in self.query.annotation_select:
@@ -766,46 +867,63 @@ class SQLCompiler:
params.append(p)
return result, params
- def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
- already_seen=None):
+ def find_ordering_name(
+ self, name, opts, alias=None, default_order="ASC", already_seen=None
+ ):
"""
Return the table alias (the name might be ambiguous, the alias will
not be) and column name for ordering by the given 'name' parameter.
The 'name' is of the form 'field1__field2__...__fieldN'.
"""
name, order = get_order_dir(name, default_order)
- descending = order == 'DESC'
+ descending = order == "DESC"
pieces = name.split(LOOKUP_SEP)
- field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias)
+ (
+ field,
+ targets,
+ alias,
+ joins,
+ path,
+ opts,
+ transform_function,
+ ) = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model unless it is the pk
# shortcut or the attribute name of the field that is specified.
if (
- field.is_relation and
- opts.ordering and
- getattr(field, 'attname', None) != pieces[-1] and
- name != 'pk'
+ field.is_relation
+ and opts.ordering
+ and getattr(field, "attname", None) != pieces[-1]
+ and name != "pk"
):
# Firstly, avoid infinite loops.
already_seen = already_seen or set()
- join_tuple = tuple(getattr(self.query.alias_map[j], 'join_cols', None) for j in joins)
+ join_tuple = tuple(
+ getattr(self.query.alias_map[j], "join_cols", None) for j in joins
+ )
if join_tuple in already_seen:
- raise FieldError('Infinite loop caused by ordering.')
+ raise FieldError("Infinite loop caused by ordering.")
already_seen.add(join_tuple)
results = []
for item in opts.ordering:
- if hasattr(item, 'resolve_expression') and not isinstance(item, OrderBy):
+ if hasattr(item, "resolve_expression") and not isinstance(
+ item, OrderBy
+ ):
item = item.desc() if descending else item.asc()
if isinstance(item, OrderBy):
results.append((item, False))
continue
- results.extend(self.find_ordering_name(item, opts, alias,
- order, already_seen))
+ results.extend(
+ self.find_ordering_name(item, opts, alias, order, already_seen)
+ )
return results
targets, alias, _ = self.query.trim_joins(targets, joins, path)
- return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets]
+ return [
+ (OrderBy(transform_function(t, alias), descending=descending), False)
+ for t in targets
+ ]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -816,7 +934,9 @@ class SQLCompiler:
match. Executing SQL where this is not true is an error.
"""
alias = alias or self.query.get_initial_alias()
- field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias)
+ field, targets, opts, joins, path, transform_function = self.query.setup_joins(
+ pieces, opts, alias
+ )
alias = joins[-1]
return field, targets, alias, joins, path, opts, transform_function
@@ -850,25 +970,39 @@ class SQLCompiler:
# Only add the alias if it's not already present (the table_alias()
# call increments the refcount, so an alias refcount of one means
# this is the only reference).
- if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1:
- result.append(', %s' % self.quote_name_unless_alias(alias))
+ if (
+ alias not in self.query.alias_map
+ or self.query.alias_refcount[alias] == 1
+ ):
+ result.append(", %s" % self.quote_name_unless_alias(alias))
return result, params
- def get_related_selections(self, select, opts=None, root_alias=None, cur_depth=1,
- requested=None, restricted=None):
+ def get_related_selections(
+ self,
+ select,
+ opts=None,
+ root_alias=None,
+ cur_depth=1,
+ requested=None,
+ restricted=None,
+ ):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
(for example, cur_depth=1 means we are looking at models with direct
connections to the root model).
"""
+
def _get_field_choices():
direct_choices = (f.name for f in opts.fields if f.is_relation)
reverse_choices = (
f.field.related_query_name()
- for f in opts.related_objects if f.field.unique
+ for f in opts.related_objects
+ if f.field.unique
+ )
+ return chain(
+ direct_choices, reverse_choices, self.query._filtered_relations
)
- return chain(direct_choices, reverse_choices, self.query._filtered_relations)
related_klass_infos = []
if not restricted and cur_depth > self.query.max_depth:
@@ -889,7 +1023,7 @@ class SQLCompiler:
requested = self.query.select_related
def get_related_klass_infos(klass_info, related_klass_infos):
- klass_info['related_klass_infos'] = related_klass_infos
+ klass_info["related_klass_infos"] = related_klass_infos
for f in opts.fields:
field_model = f.model._meta.concrete_model
@@ -903,37 +1037,48 @@ class SQLCompiler:
if next or f.name in requested:
raise FieldError(
"Non-relational field given in select_related: '%s'. "
- "Choices are: %s" % (
+ "Choices are: %s"
+ % (
f.name,
- ", ".join(_get_field_choices()) or '(none)',
+ ", ".join(_get_field_choices()) or "(none)",
)
)
else:
next = False
- if not select_related_descend(f, restricted, requested,
- only_load.get(field_model)):
+ if not select_related_descend(
+ f, restricted, requested, only_load.get(field_model)
+ ):
continue
klass_info = {
- 'model': f.remote_field.model,
- 'field': f,
- 'reverse': False,
- 'local_setter': f.set_cached_value,
- 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
- 'from_parent': False,
+ "model": f.remote_field.model,
+ "field": f,
+ "reverse": False,
+ "local_setter": f.set_cached_value,
+ "remote_setter": f.remote_field.set_cached_value
+ if f.unique
+ else lambda x, y: None,
+ "from_parent": False,
}
related_klass_infos.append(klass_info)
select_fields = []
- _, _, _, joins, _, _ = self.query.setup_joins(
- [f.name], opts, root_alias)
+ _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
alias = joins[-1]
- columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta)
+ columns = self.get_default_columns(
+ start_alias=alias, opts=f.remote_field.model._meta
+ )
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next_klass_infos = self.get_related_selections(
- select, f.remote_field.model._meta, alias, cur_depth + 1, next, restricted)
+ select,
+ f.remote_field.model._meta,
+ alias,
+ cur_depth + 1,
+ next,
+ restricted,
+ )
get_related_klass_infos(klass_info, next_klass_infos)
if restricted:
@@ -943,36 +1088,40 @@ class SQLCompiler:
if o.field.unique and not o.many_to_many
]
for f, model in related_fields:
- if not select_related_descend(f, restricted, requested,
- only_load.get(model), reverse=True):
+ if not select_related_descend(
+ f, restricted, requested, only_load.get(model), reverse=True
+ ):
continue
related_field_name = f.related_query_name()
fields_found.add(related_field_name)
- join_info = self.query.setup_joins([related_field_name], opts, root_alias)
+ join_info = self.query.setup_joins(
+ [related_field_name], opts, root_alias
+ )
alias = join_info.joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
klass_info = {
- 'model': model,
- 'field': f,
- 'reverse': True,
- 'local_setter': f.remote_field.set_cached_value,
- 'remote_setter': f.set_cached_value,
- 'from_parent': from_parent,
+ "model": model,
+ "field": f,
+ "reverse": True,
+ "local_setter": f.remote_field.set_cached_value,
+ "remote_setter": f.set_cached_value,
+ "from_parent": from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta, from_parent=opts.model)
+ start_alias=alias, opts=model._meta, from_parent=opts.model
+ )
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next = requested.get(f.related_query_name(), {})
next_klass_infos = self.get_related_selections(
- select, model._meta, alias, cur_depth + 1,
- next, restricted)
+ select, model._meta, alias, cur_depth + 1, next, restricted
+ )
get_related_klass_infos(klass_info, next_klass_infos)
def local_setter(obj, from_obj):
@@ -989,32 +1138,40 @@ class SQLCompiler:
break
if name in self.query._filtered_relations:
fields_found.add(name)
- f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias)
+ f, _, join_opts, joins, _, _ = self.query.setup_joins(
+ [name], opts, root_alias
+ )
model = join_opts.model
alias = joins[-1]
- from_parent = issubclass(model, opts.model) and model is not opts.model
+ from_parent = (
+ issubclass(model, opts.model) and model is not opts.model
+ )
klass_info = {
- 'model': model,
- 'field': f,
- 'reverse': True,
- 'local_setter': local_setter,
- 'remote_setter': partial(remote_setter, name),
- 'from_parent': from_parent,
+ "model": model,
+ "field": f,
+ "reverse": True,
+ "local_setter": local_setter,
+ "remote_setter": partial(remote_setter, name),
+ "from_parent": from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
- start_alias=alias, opts=model._meta,
+ start_alias=alias,
+ opts=model._meta,
from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
select.append((col, None))
- klass_info['select_fields'] = select_fields
+ klass_info["select_fields"] = select_fields
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
- select, opts=model._meta, root_alias=alias,
- cur_depth=cur_depth + 1, requested=next_requested,
+ select,
+ opts=model._meta,
+ root_alias=alias,
+ cur_depth=cur_depth + 1,
+ requested=next_requested,
restricted=restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
@@ -1022,10 +1179,11 @@ class SQLCompiler:
if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError(
- 'Invalid field name(s) given in select_related: %s. '
- 'Choices are: %s' % (
- ', '.join(invalid_fields),
- ', '.join(_get_field_choices()) or '(none)',
+ "Invalid field name(s) given in select_related: %s. "
+ "Choices are: %s"
+ % (
+ ", ".join(invalid_fields),
+ ", ".join(_get_field_choices()) or "(none)",
)
)
return related_klass_infos
@@ -1035,21 +1193,22 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
+
def _get_parent_klass_info(klass_info):
- concrete_model = klass_info['model']._meta.concrete_model
+ concrete_model = klass_info["model"]._meta.concrete_model
for parent_model, parent_link in concrete_model._meta.parents.items():
parent_list = parent_model._meta.get_parent_list()
yield {
- 'model': parent_model,
- 'field': parent_link,
- 'reverse': False,
- 'select_fields': [
+ "model": parent_model,
+ "field": parent_link,
+ "reverse": False,
+ "select_fields": [
select_index
- for select_index in klass_info['select_fields']
+ for select_index in klass_info["select_fields"]
# Selected columns from a model or its parents.
if (
- self.select[select_index][0].target.model == parent_model or
- self.select[select_index][0].target.model in parent_list
+ self.select[select_index][0].target.model == parent_model
+ or self.select[select_index][0].target.model in parent_list
)
],
}
@@ -1062,8 +1221,8 @@ class SQLCompiler:
select_fields is filled recursively, so it also contains fields
from the parent models.
"""
- concrete_model = klass_info['model']._meta.concrete_model
- for select_index in klass_info['select_fields']:
+ concrete_model = klass_info["model"]._meta.concrete_model
+ for select_index in klass_info["select_fields"]:
if self.select[select_index][0].target.model == concrete_model:
return self.select[select_index][0]
@@ -1074,10 +1233,10 @@ class SQLCompiler:
parent_path, klass_info = queue.popleft()
if parent_path is None:
path = []
- yield 'self'
+ yield "self"
else:
- field = klass_info['field']
- if klass_info['reverse']:
+ field = klass_info["field"]
+ if klass_info["reverse"]:
field = field.remote_field
path = parent_path + [field.name]
yield LOOKUP_SEP.join(path)
@@ -1087,25 +1246,26 @@ class SQLCompiler:
)
queue.extend(
(path, klass_info)
- for klass_info in klass_info.get('related_klass_infos', [])
+ for klass_info in klass_info.get("related_klass_infos", [])
)
+
if not self.klass_info:
return []
result = []
invalid_names = []
for name in self.query.select_for_update_of:
klass_info = self.klass_info
- if name == 'self':
+ if name == "self":
col = _get_first_selected_col_from_model(klass_info)
else:
for part in name.split(LOOKUP_SEP):
klass_infos = (
- *klass_info.get('related_klass_infos', []),
+ *klass_info.get("related_klass_infos", []),
*_get_parent_klass_info(klass_info),
)
for related_klass_info in klass_infos:
- field = related_klass_info['field']
- if related_klass_info['reverse']:
+ field = related_klass_info["field"]
+ if related_klass_info["reverse"]:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
@@ -1124,11 +1284,12 @@ class SQLCompiler:
result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
- 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
- 'Only relational fields followed in the query are allowed. '
- 'Choices are: %s.' % (
- ', '.join(invalid_names),
- ', '.join(_get_field_choices()),
+ "Invalid field name(s) given in select_for_update(of=(...)): %s. "
+ "Only relational fields followed in the query are allowed. "
+ "Choices are: %s."
+ % (
+ ", ".join(invalid_names),
+ ", ".join(_get_field_choices()),
)
)
return result
@@ -1164,12 +1325,19 @@ class SQLCompiler:
row[pos] = value
yield row
- def results_iter(self, results=None, tuple_expected=False, chunked_fetch=False,
- chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def results_iter(
+ self,
+ results=None,
+ tuple_expected=False,
+ chunked_fetch=False,
+ chunk_size=GET_ITERATOR_CHUNK_SIZE,
+ ):
"""Return an iterator over the results from executing this query."""
if results is None:
- results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
- fields = [s[0] for s in self.select[0:self.col_count]]
+ results = self.execute_sql(
+ MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size
+ )
+ fields = [s[0] for s in self.select[0 : self.col_count]]
converters = self.get_converters(fields)
rows = chain.from_iterable(results)
if converters:
@@ -1185,7 +1353,9 @@ class SQLCompiler:
"""
return bool(self.execute_sql(SINGLE))
- def execute_sql(self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE):
+ def execute_sql(
+ self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
+ ):
"""
Run the query against the database and return the result(s). The
return value is a single data item if result_type is SINGLE, or an
@@ -1226,7 +1396,7 @@ class SQLCompiler:
try:
val = cursor.fetchone()
if val:
- return val[0:self.col_count]
+ return val[0 : self.col_count]
return val
finally:
# done with the cursor
@@ -1236,7 +1406,8 @@ class SQLCompiler:
return
result = cursor_iter(
- cursor, self.connection.features.empty_fetchmany_value,
+ cursor,
+ self.connection.features.empty_fetchmany_value,
self.col_count if self.has_extra_select else None,
chunk_size,
)
@@ -1254,21 +1425,22 @@ class SQLCompiler:
for index, select_col in enumerate(self.query.select):
lhs_sql, lhs_params = self.compile(select_col)
- rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
- self.query.where.add(
- RawSQL('%s = %s' % (lhs_sql, rhs), lhs_params), 'AND')
+ rhs = "%s.%s" % (qn(alias), qn2(columns[index]))
+ self.query.where.add(RawSQL("%s = %s" % (lhs_sql, rhs), lhs_params), "AND")
sql, params = self.as_sql()
- return 'EXISTS (%s)' % sql, params
+ return "EXISTS (%s)" % sql, params
def explain_query(self):
result = list(self.execute_sql())
# Some backends return 1 item tuples with strings, and others return
# tuples with integers and strings. Flatten them out into strings.
- output_formatter = json.dumps if self.query.explain_info.format == 'json' else str
+ output_formatter = (
+ json.dumps if self.query.explain_info.format == "json" else str
+ )
for row in result[0]:
if not isinstance(row, str):
- yield ' '.join(output_formatter(c) for c in row)
+ yield " ".join(output_formatter(c) for c in row)
else:
yield row
@@ -1289,16 +1461,16 @@ class SQLInsertCompiler(SQLCompiler):
if field is None:
# A field value of None means the value is raw.
sql, params = val, []
- elif hasattr(val, 'as_sql'):
+ elif hasattr(val, "as_sql"):
# This is an expression, let's compile it.
sql, params = self.compile(val)
- elif hasattr(field, 'get_placeholder'):
+ elif hasattr(field, "get_placeholder"):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
sql, params = field.get_placeholder(val, self, self.connection), [val]
else:
# Return the common case for the placeholder
- sql, params = '%s', [val]
+ sql, params = "%s", [val]
# The following hook is only used by Oracle Spatial, which sometimes
# needs to yield 'NULL' and [] as its placeholder and params instead
@@ -1314,24 +1486,26 @@ class SQLInsertCompiler(SQLCompiler):
Prepare a value to be used in a query by resolving it if it is an
expression and otherwise calling the field's get_db_prep_save().
"""
- if hasattr(value, 'resolve_expression'):
- value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
+ if hasattr(value, "resolve_expression"):
+ value = value.resolve_expression(
+ self.query, allow_joins=False, for_save=True
+ )
# Don't allow values containing Col expressions. They refer to
# existing columns on a row, but in the case of insert the row
# doesn't exist yet.
if value.contains_column_references:
raise ValueError(
'Failed to insert expression "%s" on %s. F() expressions '
- 'can only be used to update, not to insert.' % (value, field)
+ "can only be used to update, not to insert." % (value, field)
)
if value.contains_aggregate:
raise FieldError(
- 'Aggregate functions are not allowed in this query '
- '(%s=%r).' % (field.name, value)
+ "Aggregate functions are not allowed in this query "
+ "(%s=%r)." % (field.name, value)
)
if value.contains_over_clause:
raise FieldError(
- 'Window expressions are not allowed in this query (%s=%r).'
+ "Window expressions are not allowed in this query (%s=%r)."
% (field.name, value)
)
else:
@@ -1390,25 +1564,32 @@ class SQLInsertCompiler(SQLCompiler):
insert_statement = self.connection.ops.insert_statement(
on_conflict=self.query.on_conflict,
)
- result = ['%s %s' % (insert_statement, qn(opts.db_table))]
+ result = ["%s %s" % (insert_statement, qn(opts.db_table))]
fields = self.query.fields or [opts.pk]
- result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
+ result.append("(%s)" % ", ".join(qn(f.column) for f in fields))
if self.query.fields:
value_rows = [
- [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
+ [
+ self.prepare_value(field, self.pre_save_val(field, obj))
+ for field in fields
+ ]
for obj in self.query.objs
]
else:
# An empty object.
- value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
+ value_rows = [
+ [self.connection.ops.pk_default_value()] for _ in self.query.objs
+ ]
fields = [None]
# Currently the backends just accept values when generating bulk
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
- can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert)
+ can_bulk = (
+ not self.returning_fields and self.connection.features.has_bulk_insert
+ )
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
@@ -1418,9 +1599,14 @@ class SQLInsertCompiler(SQLCompiler):
self.query.update_fields,
self.query.unique_fields,
)
- if self.returning_fields and self.connection.features.can_return_columns_from_insert:
+ if (
+ self.returning_fields
+ and self.connection.features.can_return_columns_from_insert
+ ):
if self.connection.features.can_return_rows_from_bulk_insert:
- result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
+ result.append(
+ self.connection.ops.bulk_insert_sql(fields, placeholder_rows)
+ )
params = param_rows
else:
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
@@ -1429,7 +1615,9 @@ class SQLInsertCompiler(SQLCompiler):
result.append(on_conflict_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
- r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
+ r_sql, self.returning_params = self.connection.ops.return_insert_columns(
+ self.returning_fields
+ )
if r_sql:
result.append(r_sql)
params += [self.returning_params]
@@ -1450,8 +1638,9 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, returning_fields=None):
assert not (
- returning_fields and len(self.query.objs) != 1 and
- not self.connection.features.can_return_rows_from_bulk_insert
+ returning_fields
+ and len(self.query.objs) != 1
+ and not self.connection.features.can_return_rows_from_bulk_insert
)
opts = self.query.get_meta()
self.returning_fields = returning_fields
@@ -1460,17 +1649,29 @@ class SQLInsertCompiler(SQLCompiler):
cursor.execute(sql, params)
if not self.returning_fields:
return []
- if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
+ if (
+ self.connection.features.can_return_rows_from_bulk_insert
+ and len(self.query.objs) > 1
+ ):
rows = self.connection.ops.fetch_returned_insert_rows(cursor)
elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1
- rows = [self.connection.ops.fetch_returned_insert_columns(
- cursor, self.returning_params,
- )]
+ rows = [
+ self.connection.ops.fetch_returned_insert_columns(
+ cursor,
+ self.returning_params,
+ )
+ ]
else:
- rows = [(self.connection.ops.last_insert_id(
- cursor, opts.db_table, opts.pk.column,
- ),)]
+ rows = [
+ (
+ self.connection.ops.last_insert_id(
+ cursor,
+ opts.db_table,
+ opts.pk.column,
+ ),
+ )
+ ]
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
converters = self.get_converters(cols)
if converters:
@@ -1489,7 +1690,7 @@ class SQLDeleteCompiler(SQLCompiler):
def _expr_refs_base_model(cls, expr, base_model):
if isinstance(expr, Query):
return expr.model == base_model
- if not hasattr(expr, 'get_source_expressions'):
+ if not hasattr(expr, "get_source_expressions"):
return False
return any(
cls._expr_refs_base_model(source_expr, base_model)
@@ -1500,17 +1701,17 @@ class SQLDeleteCompiler(SQLCompiler):
def contains_self_reference_subquery(self):
return any(
self._expr_refs_base_model(expr, self.query.model)
- for expr in chain(self.query.annotations.values(), self.query.where.children)
+ for expr in chain(
+ self.query.annotations.values(), self.query.where.children
+ )
)
def _as_sql(self, query):
- result = [
- 'DELETE FROM %s' % self.quote_name_unless_alias(query.base_table)
- ]
+ result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
where, params = self.compile(query.where)
if where:
- result.append('WHERE %s' % where)
- return ' '.join(result), tuple(params)
+ result.append("WHERE %s" % where)
+ return " ".join(result), tuple(params)
def as_sql(self):
"""
@@ -1523,16 +1724,14 @@ class SQLDeleteCompiler(SQLCompiler):
innerq.__class__ = Query
innerq.clear_select_clause()
pk = self.query.model._meta.pk
- innerq.select = [
- pk.get_col(self.query.get_initial_alias())
- ]
+ innerq.select = [pk.get_col(self.query.get_initial_alias())]
outerq = Query(self.query.model)
if not self.connection.features.update_can_self_select:
# Force the materialization of the inner query to allow reference
# to the target table on MySQL.
sql, params = innerq.get_compiler(connection=self.connection).as_sql()
- innerq = RawSQL('SELECT * FROM (%s) subquery' % sql, params)
- outerq.add_filter('pk__in', innerq)
+ innerq = RawSQL("SELECT * FROM (%s) subquery" % sql, params)
+ outerq.add_filter("pk__in", innerq)
return self._as_sql(outerq)
@@ -1544,23 +1743,25 @@ class SQLUpdateCompiler(SQLCompiler):
"""
self.pre_sql_setup()
if not self.query.values:
- return '', ()
+ return "", ()
qn = self.quote_name_unless_alias
values, update_params = [], []
for field, model, val in self.query.values:
- if hasattr(val, 'resolve_expression'):
- val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
+ if hasattr(val, "resolve_expression"):
+ val = val.resolve_expression(
+ self.query, allow_joins=False, for_save=True
+ )
if val.contains_aggregate:
raise FieldError(
- 'Aggregate functions are not allowed in this query '
- '(%s=%r).' % (field.name, val)
+ "Aggregate functions are not allowed in this query "
+ "(%s=%r)." % (field.name, val)
)
if val.contains_over_clause:
raise FieldError(
- 'Window expressions are not allowed in this query '
- '(%s=%r).' % (field.name, val)
+ "Window expressions are not allowed in this query "
+ "(%s=%r)." % (field.name, val)
)
- elif hasattr(val, 'prepare_database_save'):
+ elif hasattr(val, "prepare_database_save"):
if field.remote_field:
val = field.get_db_prep_save(
val.prepare_database_save(field),
@@ -1576,29 +1777,29 @@ class SQLUpdateCompiler(SQLCompiler):
val = field.get_db_prep_save(val, connection=self.connection)
# Getting the placeholder for the field.
- if hasattr(field, 'get_placeholder'):
+ if hasattr(field, "get_placeholder"):
placeholder = field.get_placeholder(val, self, self.connection)
else:
- placeholder = '%s'
+ placeholder = "%s"
name = field.column
- if hasattr(val, 'as_sql'):
+ if hasattr(val, "as_sql"):
sql, params = self.compile(val)
- values.append('%s = %s' % (qn(name), placeholder % sql))
+ values.append("%s = %s" % (qn(name), placeholder % sql))
update_params.extend(params)
elif val is not None:
- values.append('%s = %s' % (qn(name), placeholder))
+ values.append("%s = %s" % (qn(name), placeholder))
update_params.append(val)
else:
- values.append('%s = NULL' % qn(name))
+ values.append("%s = NULL" % qn(name))
table = self.query.base_table
result = [
- 'UPDATE %s SET' % qn(table),
- ', '.join(values),
+ "UPDATE %s SET" % qn(table),
+ ", ".join(values),
]
where, params = self.compile(self.query.where)
if where:
- result.append('WHERE %s' % where)
- return ' '.join(result), tuple(update_params + params)
+ result.append("WHERE %s" % where)
+ return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type):
"""
@@ -1644,7 +1845,9 @@ class SQLUpdateCompiler(SQLCompiler):
query.add_fields([query.get_meta().pk.name])
super().pre_sql_setup()
- must_pre_select = count > 1 and not self.connection.features.update_can_self_select
+ must_pre_select = (
+ count > 1 and not self.connection.features.update_can_self_select
+ )
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
@@ -1656,11 +1859,11 @@ class SQLUpdateCompiler(SQLCompiler):
idents = []
for rows in query.get_compiler(self.using).execute_sql(MULTI):
idents.extend(r[0] for r in rows)
- self.query.add_filter('pk__in', idents)
+ self.query.add_filter("pk__in", idents)
self.query.related_ids = idents
else:
# The fast path. Filters and updates in one query.
- self.query.add_filter('pk__in', query)
+ self.query.add_filter("pk__in", query)
self.query.reset_refcounts(refcounts_before)
@@ -1677,13 +1880,14 @@ class SQLAggregateCompiler(SQLCompiler):
sql.append(ann_sql)
params.extend(ann_params)
self.col_count = len(self.query.annotation_select)
- sql = ', '.join(sql)
+ sql = ", ".join(sql)
params = tuple(params)
inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
- self.using, elide_empty=self.elide_empty,
+ self.using,
+ elide_empty=self.elide_empty,
).as_sql(with_col_aliases=True)
- sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
+ sql = "SELECT %s FROM (%s) subquery" % (sql, inner_query_sql)
params = params + inner_query_params
return sql, params
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
index a1db61b9ff..fdfb2ea891 100644
--- a/django/db/models/sql/constants.py
+++ b/django/db/models/sql/constants.py
@@ -9,16 +9,16 @@ GET_ITERATOR_CHUNK_SIZE = 100
# Namedtuples for sql.* internal use.
# How many results to expect from a cursor.execute call
-MULTI = 'multi'
-SINGLE = 'single'
-CURSOR = 'cursor'
-NO_RESULTS = 'no results'
+MULTI = "multi"
+SINGLE = "single"
+CURSOR = "cursor"
+NO_RESULTS = "no results"
ORDER_DIR = {
- 'ASC': ('ASC', 'DESC'),
- 'DESC': ('DESC', 'ASC'),
+ "ASC": ("ASC", "DESC"),
+ "DESC": ("DESC", "ASC"),
}
# SQL join types.
-INNER = 'INNER JOIN'
-LOUTER = 'LEFT OUTER JOIN'
+INNER = "INNER JOIN"
+LOUTER = "LEFT OUTER JOIN"
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index e08b570350..f398074bf7 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -11,6 +11,7 @@ class MultiJoin(Exception):
multi-valued join was attempted (if the caller wants to treat that
exceptionally).
"""
+
def __init__(self, names_pos, path_with_names):
self.level = names_pos
# The path travelled, this includes the path to the multijoin.
@@ -38,8 +39,17 @@ class Join:
- as_sql()
- relabeled_clone()
"""
- def __init__(self, table_name, parent_alias, table_alias, join_type,
- join_field, nullable, filtered_relation=None):
+
+ def __init__(
+ self,
+ table_name,
+ parent_alias,
+ table_alias,
+ join_type,
+ join_field,
+ nullable,
+ filtered_relation=None,
+ ):
# Join table
self.table_name = table_name
self.parent_alias = parent_alias
@@ -69,35 +79,47 @@ class Join:
# Add a join condition for each pair of joining columns.
for lhs_col, rhs_col in self.join_cols:
- join_conditions.append('%s.%s = %s.%s' % (
- qn(self.parent_alias),
- qn2(lhs_col),
- qn(self.table_alias),
- qn2(rhs_col),
- ))
+ join_conditions.append(
+ "%s.%s = %s.%s"
+ % (
+ qn(self.parent_alias),
+ qn2(lhs_col),
+ qn(self.table_alias),
+ qn2(rhs_col),
+ )
+ )
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.
- extra_cond = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
+ extra_cond = self.join_field.get_extra_restriction(
+ self.table_alias, self.parent_alias
+ )
if extra_cond:
extra_sql, extra_params = compiler.compile(extra_cond)
- join_conditions.append('(%s)' % extra_sql)
+ join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
- join_conditions.append('(%s)' % extra_sql)
+ join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if not join_conditions:
# This might be a rel on the other end of an actual declared field.
- declared_field = getattr(self.join_field, 'field', self.join_field)
+ declared_field = getattr(self.join_field, "field", self.join_field)
raise ValueError(
"Join generated an empty ON clause. %s did not yield either "
"joining columns or extra restrictions." % declared_field.__class__
)
- on_clause_sql = ' AND '.join(join_conditions)
- alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
- sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql)
+ on_clause_sql = " AND ".join(join_conditions)
+ alias_str = (
+ "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
+ )
+ sql = "%s %s%s ON (%s)" % (
+ self.join_type,
+ qn(self.table_name),
+ alias_str,
+ on_clause_sql,
+ )
return sql, params
def relabeled_clone(self, change_map):
@@ -105,12 +127,19 @@ class Join:
new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
- filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
+ filtered_relation.path = [
+ change_map.get(p, p) for p in self.filtered_relation.path
+ ]
else:
filtered_relation = None
return self.__class__(
- self.table_name, new_parent_alias, new_table_alias, self.join_type,
- self.join_field, self.nullable, filtered_relation=filtered_relation,
+ self.table_name,
+ new_parent_alias,
+ new_table_alias,
+ self.join_type,
+ self.join_field,
+ self.nullable,
+ filtered_relation=filtered_relation,
)
@property
@@ -153,6 +182,7 @@ class BaseTable:
SELECT * FROM "foo" WHERE somecond
could be generated by this class.
"""
+
join_type = None
parent_alias = None
filtered_relation = None
@@ -162,12 +192,16 @@ class BaseTable:
self.table_alias = alias
def as_sql(self, compiler, connection):
- alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
+ alias_str = (
+ "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
+ )
base_sql = compiler.quote_name_unless_alias(self.table_name)
return base_sql + alias_str, []
def relabeled_clone(self, change_map):
- return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
+ return self.__class__(
+ self.table_name, change_map.get(self.table_alias, self.table_alias)
+ )
@property
def identity(self):
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 1dc770ae3a..242b2a1f3f 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -20,32 +20,37 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import (
- BaseExpression, Col, Exists, F, OuterRef, Ref, ResolvedOuterRef,
+ BaseExpression,
+ Col,
+ Exists,
+ F,
+ OuterRef,
+ Ref,
+ ResolvedOuterRef,
)
from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
- Q, check_rel_lookup_compatibility, refs_expression,
+ Q,
+ check_rel_lookup_compatibility,
+ refs_expression,
)
from django.db.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
-from django.db.models.sql.datastructures import (
- BaseTable, Empty, Join, MultiJoin,
-)
-from django.db.models.sql.where import (
- AND, OR, ExtraWhere, NothingNode, WhereNode,
-)
+from django.db.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin
+from django.db.models.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
from django.utils.functional import cached_property
from django.utils.tree import Node
-__all__ = ['Query', 'RawQuery']
+__all__ = ["Query", "RawQuery"]
def get_field_names_from_opts(opts):
- return set(chain.from_iterable(
- (f.name, f.attname) if f.concrete else (f.name,)
- for f in opts.get_fields()
- ))
+ return set(
+ chain.from_iterable(
+ (f.name, f.attname) if f.concrete else (f.name,) for f in opts.get_fields()
+ )
+ )
def get_children_from_q(q):
@@ -57,8 +62,8 @@ def get_children_from_q(q):
JoinInfo = namedtuple(
- 'JoinInfo',
- ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
+ "JoinInfo",
+ ("final_field", "targets", "opts", "joins", "path", "transform_function"),
)
@@ -87,8 +92,7 @@ class RawQuery:
if self.cursor is None:
self._execute_query()
converter = connections[self.using].introspection.identifier_converter
- return [converter(column_meta[0])
- for column_meta in self.cursor.description]
+ return [converter(column_meta[0]) for column_meta in self.cursor.description]
def __iter__(self):
# Always execute a new query for a new iterator.
@@ -136,17 +140,17 @@ class RawQuery:
self.cursor.execute(self.sql, params)
-ExplainInfo = namedtuple('ExplainInfo', ('format', 'options'))
+ExplainInfo = namedtuple("ExplainInfo", ("format", "options"))
class Query(BaseExpression):
"""A single SQL query."""
- alias_prefix = 'T'
+ alias_prefix = "T"
empty_result_set_value = None
subq_aliases = frozenset([alias_prefix])
- compiler = 'SQLCompiler'
+ compiler = "SQLCompiler"
base_table_class = BaseTable
join_class = Join
@@ -167,7 +171,7 @@ class Query(BaseExpression):
# aliases too.
# Map external tables to whether they are aliased.
self.external_aliases = {}
- self.table_map = {} # Maps table names to list of aliases.
+ self.table_map = {} # Maps table names to list of aliases.
self.default_cols = True
self.default_ordering = True
self.standard_ordering = True
@@ -240,13 +244,15 @@ class Query(BaseExpression):
def output_field(self):
if len(self.select) == 1:
select = self.select[0]
- return getattr(select, 'target', None) or select.field
+ return getattr(select, "target", None) or select.field
elif len(self.annotation_select) == 1:
return next(iter(self.annotation_select.values())).output_field
@property
def has_select_fields(self):
- return bool(self.select or self.annotation_select_mask or self.extra_select_mask)
+ return bool(
+ self.select or self.annotation_select_mask or self.extra_select_mask
+ )
@cached_property
def base_table(self):
@@ -282,7 +288,9 @@ class Query(BaseExpression):
raise ValueError("Need either using or connection")
if using:
connection = connections[using]
- return connection.ops.compiler(self.compiler)(self, connection, using, elide_empty)
+ return connection.ops.compiler(self.compiler)(
+ self, connection, using, elide_empty
+ )
def get_meta(self):
"""
@@ -311,9 +319,9 @@ class Query(BaseExpression):
if self.annotation_select_mask is not None:
obj.annotation_select_mask = self.annotation_select_mask.copy()
if self.combined_queries:
- obj.combined_queries = tuple([
- query.clone() for query in self.combined_queries
- ])
+ obj.combined_queries = tuple(
+ [query.clone() for query in self.combined_queries]
+ )
# _annotation_select_cache cannot be copied, as doing so breaks the
# (necessary) state in which both annotations and
# _annotation_select_cache point to the same underlying objects.
@@ -329,7 +337,7 @@ class Query(BaseExpression):
# Use deepcopy because select_related stores fields in nested
# dicts.
obj.select_related = copy.deepcopy(obj.select_related)
- if 'subq_aliases' in self.__dict__:
+ if "subq_aliases" in self.__dict__:
obj.subq_aliases = self.subq_aliases.copy()
obj.used_aliases = self.used_aliases.copy()
obj._filtered_relations = self._filtered_relations.copy()
@@ -351,7 +359,7 @@ class Query(BaseExpression):
if not obj.filter_is_sticky:
obj.used_aliases = set()
obj.filter_is_sticky = False
- if hasattr(obj, '_setup_query'):
+ if hasattr(obj, "_setup_query"):
obj._setup_query()
return obj
@@ -401,11 +409,13 @@ class Query(BaseExpression):
break
else:
# An expression that is not selected the subquery.
- if isinstance(expr, Col) or (expr.contains_aggregate and not expr.is_summary):
+ if isinstance(expr, Col) or (
+ expr.contains_aggregate and not expr.is_summary
+ ):
# Reference column or another aggregate. Select it
# under a non-conflicting alias.
col_cnt += 1
- col_alias = '__col%d' % col_cnt
+ col_alias = "__col%d" % col_cnt
self.annotations[col_alias] = expr
self.append_annotation_mask([col_alias])
new_expr = Ref(col_alias, expr)
@@ -424,8 +434,8 @@ class Query(BaseExpression):
if not self.annotation_select:
return {}
existing_annotations = [
- annotation for alias, annotation
- in self.annotations.items()
+ annotation
+ for alias, annotation in self.annotations.items()
if alias not in added_aggregate_names
]
# Decide if we need to use a subquery.
@@ -439,9 +449,15 @@ class Query(BaseExpression):
# those operations must be done in a subquery so that the query
# aggregates on the limit and/or distinct results instead of applying
# the distinct and limit after the aggregation.
- if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
- self.distinct or self.combinator):
+ if (
+ isinstance(self.group_by, tuple)
+ or self.is_sliced
+ or existing_annotations
+ or self.distinct
+ or self.combinator
+ ):
from django.db.models.sql.subqueries import AggregateQuery
+
inner_query = self.clone()
inner_query.subquery = True
outer_query = AggregateQuery(self.model, inner_query)
@@ -459,15 +475,18 @@ class Query(BaseExpression):
# clearing the select clause can alter results if distinct is
# used.
has_existing_aggregate_annotations = any(
- annotation for annotation in existing_annotations
- if getattr(annotation, 'contains_aggregate', True)
+ annotation
+ for annotation in existing_annotations
+ if getattr(annotation, "contains_aggregate", True)
)
if inner_query.default_cols and has_existing_aggregate_annotations:
- inner_query.group_by = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
+ inner_query.group_by = (
+ self.model._meta.pk.get_col(inner_query.get_initial_alias()),
+ )
inner_query.default_cols = False
- relabels = {t: 'subquery' for t in inner_query.alias_map}
- relabels[None] = 'subquery'
+ relabels = {t: "subquery" for t in inner_query.alias_map}
+ relabels[None] = "subquery"
# Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery.
col_cnt = 0
@@ -475,16 +494,24 @@ class Query(BaseExpression):
annotation_select_mask = inner_query.annotation_select_mask
if expression.is_summary:
expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt)
- outer_query.annotations[alias] = expression.relabeled_clone(relabels)
+ outer_query.annotations[alias] = expression.relabeled_clone(
+ relabels
+ )
del inner_query.annotations[alias]
annotation_select_mask.remove(alias)
# Make sure the annotation_select wont use cached results.
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
- if inner_query.select == () and not inner_query.default_cols and not inner_query.annotation_select_mask:
+ if (
+ inner_query.select == ()
+ and not inner_query.default_cols
+ and not inner_query.annotation_select_mask
+ ):
# In case of Model.objects[0:3].count(), there would be no
# field selected in the inner query, yet we must use a subquery.
# So, make sure at least one field is selected.
- inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
+ inner_query.select = (
+ self.model._meta.pk.get_col(inner_query.get_initial_alias()),
+ )
else:
outer_query = self
self.select = ()
@@ -515,8 +542,8 @@ class Query(BaseExpression):
Perform a COUNT() query using the current filter constraints.
"""
obj = self.clone()
- obj.add_annotation(Count('*'), alias='__count', is_summary=True)
- return obj.get_aggregation(using, ['__count'])['__count']
+ obj.add_annotation(Count("*"), alias="__count", is_summary=True)
+ return obj.get_aggregation(using, ["__count"])["__count"]
def has_filters(self):
return self.where
@@ -525,13 +552,17 @@ class Query(BaseExpression):
q = self.clone()
if not q.distinct:
if q.group_by is True:
- q.add_fields((f.attname for f in self.model._meta.concrete_fields), False)
+ q.add_fields(
+ (f.attname for f in self.model._meta.concrete_fields), False
+ )
# Disable GROUP BY aliases to avoid orphaning references to the
# SELECT clause which is about to be cleared.
q.set_group_by(allow_aliases=False)
q.clear_select_clause()
- if q.combined_queries and q.combinator == 'union':
- limit_combined = connections[using].features.supports_slicing_ordering_in_compound
+ if q.combined_queries and q.combinator == "union":
+ limit_combined = connections[
+ using
+ ].features.supports_slicing_ordering_in_compound
q.combined_queries = tuple(
combined_query.exists(using, limit=limit_combined)
for combined_query in q.combined_queries
@@ -539,8 +570,8 @@ class Query(BaseExpression):
q.clear_ordering(force=True)
if limit:
q.set_limits(high=1)
- q.add_extra({'a': 1}, None, None, None, None, None)
- q.set_extra_mask(['a'])
+ q.add_extra({"a": 1}, None, None, None, None, None)
+ q.set_extra_mask(["a"])
return q
def has_results(self, using):
@@ -552,7 +583,7 @@ class Query(BaseExpression):
q = self.clone()
q.explain_info = ExplainInfo(format, options)
compiler = q.get_compiler(using=using)
- return '\n'.join(compiler.explain_query())
+ return "\n".join(compiler.explain_query())
def combine(self, rhs, connector):
"""
@@ -564,13 +595,13 @@ class Query(BaseExpression):
'rhs' query.
"""
if self.model != rhs.model:
- raise TypeError('Cannot combine queries on two different base models.')
+ raise TypeError("Cannot combine queries on two different base models.")
if self.is_sliced:
- raise TypeError('Cannot combine queries once a slice has been taken.')
+ raise TypeError("Cannot combine queries once a slice has been taken.")
if self.distinct != rhs.distinct:
- raise TypeError('Cannot combine a unique query with a non-unique query.')
+ raise TypeError("Cannot combine a unique query with a non-unique query.")
if self.distinct_fields != rhs.distinct_fields:
- raise TypeError('Cannot combine queries with different distinct fields.')
+ raise TypeError("Cannot combine queries with different distinct fields.")
# If lhs and rhs shares the same alias prefix, it is possible to have
# conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
@@ -583,7 +614,7 @@ class Query(BaseExpression):
# Work out how to relabel the rhs aliases, if necessary.
change_map = {}
- conjunction = (connector == AND)
+ conjunction = connector == AND
# Determine which existing joins can be reused. When combining the
# query with AND we must recreate all joins for m2m filters. When
@@ -600,7 +631,8 @@ class Query(BaseExpression):
reuse = set() if conjunction else set(self.alias_map)
joinpromoter = JoinPromoter(connector, 2, False)
joinpromoter.add_votes(
- j for j in self.alias_map if self.alias_map[j].join_type == INNER)
+ j for j in self.alias_map if self.alias_map[j].join_type == INNER
+ )
rhs_votes = set()
# Now, add the joins from rhs query into the new query (skipping base
# table).
@@ -649,7 +681,9 @@ class Query(BaseExpression):
# really make sense (or return consistent value sets). Not worth
# the extra complexity when you can write a real query instead.
if self.extra and rhs.extra:
- raise ValueError("When merging querysets using 'or', you cannot have extra(select=...) on both sides.")
+ raise ValueError(
+ "When merging querysets using 'or', you cannot have extra(select=...) on both sides."
+ )
self.extra.update(rhs.extra)
extra_select_mask = set()
if self.extra_select_mask is not None:
@@ -767,11 +801,13 @@ class Query(BaseExpression):
# Create a new alias for this table.
if alias_list:
- alias = '%s%d' % (self.alias_prefix, len(self.alias_map) + 1)
+ alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1)
alias_list.append(alias)
else:
# The first occurrence of a table uses the table name directly.
- alias = filtered_relation.alias if filtered_relation is not None else table_name
+ alias = (
+ filtered_relation.alias if filtered_relation is not None else table_name
+ )
self.table_map[table_name] = [alias]
self.alias_refcount[alias] = 1
return alias, True
@@ -806,16 +842,19 @@ class Query(BaseExpression):
# Only the first alias (skipped above) should have None join_type
assert self.alias_map[alias].join_type is not None
parent_alias = self.alias_map[alias].parent_alias
- parent_louter = parent_alias and self.alias_map[parent_alias].join_type == LOUTER
+ parent_louter = (
+ parent_alias and self.alias_map[parent_alias].join_type == LOUTER
+ )
already_louter = self.alias_map[alias].join_type == LOUTER
- if ((self.alias_map[alias].nullable or parent_louter) and
- not already_louter):
+ if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
self.alias_map[alias] = self.alias_map[alias].promote()
# Join type of 'alias' changed, so re-examine all aliases that
# refer to this one.
aliases.extend(
- join for join in self.alias_map
- if self.alias_map[join].parent_alias == alias and join not in aliases
+ join
+ for join in self.alias_map
+ if self.alias_map[join].parent_alias == alias
+ and join not in aliases
)
def demote_joins(self, aliases):
@@ -861,10 +900,13 @@ class Query(BaseExpression):
# "group by" and "where".
self.where.relabel_aliases(change_map)
if isinstance(self.group_by, tuple):
- self.group_by = tuple([col.relabeled_clone(change_map) for col in self.group_by])
+ self.group_by = tuple(
+ [col.relabeled_clone(change_map) for col in self.group_by]
+ )
self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
self.annotations = self.annotations and {
- key: col.relabeled_clone(change_map) for key, col in self.annotations.items()
+ key: col.relabeled_clone(change_map)
+ for key, col in self.annotations.items()
}
# 2. Rename the alias in the internal table/alias datastructures.
@@ -895,6 +937,7 @@ class Query(BaseExpression):
conflict. Even tables that previously had no alias will get an alias
after this call. To prevent changing aliases use the exclude parameter.
"""
+
def prefix_gen():
"""
Generate a sequence of characters in alphabetical order:
@@ -908,9 +951,9 @@ class Query(BaseExpression):
prefix = chr(ord(self.alias_prefix) + 1)
yield prefix
for n in count(1):
- seq = alphabet[alphabet.index(prefix):] if prefix else alphabet
+ seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
for s in product(seq, repeat=n):
- yield ''.join(s)
+ yield "".join(s)
prefix = None
if self.alias_prefix != other_query.alias_prefix:
@@ -928,17 +971,19 @@ class Query(BaseExpression):
break
if pos > local_recursion_limit:
raise RecursionError(
- 'Maximum recursion depth exceeded: too many subqueries.'
+ "Maximum recursion depth exceeded: too many subqueries."
)
self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
if exclude is None:
exclude = {}
- self.change_aliases({
- alias: '%s%d' % (self.alias_prefix, pos)
- for pos, alias in enumerate(self.alias_map)
- if alias not in exclude
- })
+ self.change_aliases(
+ {
+ alias: "%s%d" % (self.alias_prefix, pos)
+ for pos, alias in enumerate(self.alias_map)
+ if alias not in exclude
+ }
+ )
def get_initial_alias(self):
"""
@@ -974,7 +1019,8 @@ class Query(BaseExpression):
joins are created as LOUTER if the join is nullable.
"""
reuse_aliases = [
- a for a, j in self.alias_map.items()
+ a
+ for a, j in self.alias_map.items()
if (reuse is None or a in reuse) and j.equals(join)
]
if reuse_aliases:
@@ -988,7 +1034,9 @@ class Query(BaseExpression):
return reuse_alias
# No reuse is possible, so we need a new alias.
- alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation)
+ alias, _ = self.table_alias(
+ join.table_name, create=True, filtered_relation=join.filtered_relation
+ )
if join.join_type:
if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
join_type = LOUTER
@@ -1034,8 +1082,9 @@ class Query(BaseExpression):
def add_annotation(self, annotation, alias, is_summary=False, select=True):
"""Add a single annotation expression to the Query."""
- annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None,
- summarize=is_summary)
+ annotation = annotation.resolve_expression(
+ self, allow_joins=True, reuse=None, summarize=is_summary
+ )
if select:
self.append_annotation_mask([alias])
else:
@@ -1050,27 +1099,32 @@ class Query(BaseExpression):
clone.where.resolve_expression(query, *args, **kwargs)
# Resolve combined queries.
if clone.combinator:
- clone.combined_queries = tuple([
- combined_query.resolve_expression(query, *args, **kwargs)
- for combined_query in clone.combined_queries
- ])
+ clone.combined_queries = tuple(
+ [
+ combined_query.resolve_expression(query, *args, **kwargs)
+ for combined_query in clone.combined_queries
+ ]
+ )
for key, value in clone.annotations.items():
resolved = value.resolve_expression(query, *args, **kwargs)
- if hasattr(resolved, 'external_aliases'):
+ if hasattr(resolved, "external_aliases"):
resolved.external_aliases.update(clone.external_aliases)
clone.annotations[key] = resolved
# Outer query's aliases are considered external.
for alias, table in query.alias_map.items():
clone.external_aliases[alias] = (
- (isinstance(table, Join) and table.join_field.related_model._meta.db_table != alias) or
- (isinstance(table, BaseTable) and table.table_name != table.table_alias)
+ isinstance(table, Join)
+ and table.join_field.related_model._meta.db_table != alias
+ ) or (
+ isinstance(table, BaseTable) and table.table_name != table.table_alias
)
return clone
def get_external_cols(self):
exprs = chain(self.annotations.values(), self.where.children)
return [
- col for col in self._gen_cols(exprs, include_external=True)
+ col
+ for col in self._gen_cols(exprs, include_external=True)
if col.alias in self.external_aliases
]
@@ -1086,19 +1140,21 @@ class Query(BaseExpression):
# Some backends (e.g. Oracle) raise an error when a subquery contains
# unnecessary ORDER BY clause.
if (
- self.subquery and
- not connection.features.ignores_unnecessary_order_by_in_subqueries
+ self.subquery
+ and not connection.features.ignores_unnecessary_order_by_in_subqueries
):
self.clear_ordering(force=False)
sql, params = self.get_compiler(connection=connection).as_sql()
if self.subquery:
- sql = '(%s)' % sql
+ sql = "(%s)" % sql
return sql, params
def resolve_lookup_value(self, value, can_reuse, allow_joins):
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
value = value.resolve_expression(
- self, reuse=can_reuse, allow_joins=allow_joins,
+ self,
+ reuse=can_reuse,
+ allow_joins=allow_joins,
)
elif isinstance(value, (list, tuple)):
# The items of the iterable may be expressions and therefore need
@@ -1108,7 +1164,7 @@ class Query(BaseExpression):
for sub_value in value
)
type_ = type(value)
- if hasattr(type_, '_make'): # namedtuple
+ if hasattr(type_, "_make"): # namedtuple
return type_(*values)
return type_(values)
return value
@@ -1119,15 +1175,17 @@ class Query(BaseExpression):
"""
lookup_splitted = lookup.split(LOOKUP_SEP)
if self.annotations:
- expression, expression_lookups = refs_expression(lookup_splitted, self.annotations)
+ expression, expression_lookups = refs_expression(
+ lookup_splitted, self.annotations
+ )
if expression:
return expression_lookups, (), expression
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
- field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
+ field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
if len(lookup_parts) > 1 and not field_parts:
raise FieldError(
- 'Invalid lookup "%s" for model %s".' %
- (lookup, self.get_meta().model.__name__)
+ 'Invalid lookup "%s" for model %s".'
+ % (lookup, self.get_meta().model.__name__)
)
return lookup_parts, field_parts, False
@@ -1136,11 +1194,12 @@ class Query(BaseExpression):
Check whether the object passed while querying is of the correct type.
If not, raise a ValueError specifying the wrong object.
"""
- if hasattr(value, '_meta'):
+ if hasattr(value, "_meta"):
if not check_rel_lookup_compatibility(value._meta.model, opts, field):
raise ValueError(
- 'Cannot query "%s": Must be "%s" instance.' %
- (value, opts.object_name))
+ 'Cannot query "%s": Must be "%s" instance.'
+ % (value, opts.object_name)
+ )
def check_related_objects(self, field, value, opts):
"""Check the type of object passed to query relations."""
@@ -1150,29 +1209,31 @@ class Query(BaseExpression):
# opts would be Author's (from the author field) and value.model
# would be Author.objects.all() queryset's .model (Author also).
# The field is the related field on the lhs side.
- if (isinstance(value, Query) and not value.has_select_fields and
- not check_rel_lookup_compatibility(value.model, opts, field)):
+ if (
+ isinstance(value, Query)
+ and not value.has_select_fields
+ and not check_rel_lookup_compatibility(value.model, opts, field)
+ ):
raise ValueError(
- 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' %
- (value.model._meta.object_name, opts.object_name)
+ 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".'
+ % (value.model._meta.object_name, opts.object_name)
)
- elif hasattr(value, '_meta'):
+ elif hasattr(value, "_meta"):
self.check_query_object_type(value, opts, field)
- elif hasattr(value, '__iter__'):
+ elif hasattr(value, "__iter__"):
for v in value:
self.check_query_object_type(v, opts, field)
def check_filterable(self, expression):
"""Raise an error if expression cannot be used in a WHERE clause."""
- if (
- hasattr(expression, 'resolve_expression') and
- not getattr(expression, 'filterable', True)
+ if hasattr(expression, "resolve_expression") and not getattr(
+ expression, "filterable", True
):
raise NotSupportedError(
- expression.__class__.__name__ + ' is disallowed in the filter '
- 'clause.'
+ expression.__class__.__name__ + " is disallowed in the filter "
+ "clause."
)
- if hasattr(expression, 'get_source_expressions'):
+ if hasattr(expression, "get_source_expressions"):
for expr in expression.get_source_expressions():
self.check_filterable(expr)
@@ -1186,7 +1247,7 @@ class Query(BaseExpression):
and get_transform().
"""
# __exact is the default lookup if one isn't given.
- *transforms, lookup_name = lookups or ['exact']
+ *transforms, lookup_name = lookups or ["exact"]
for name in transforms:
lhs = self.try_transform(lhs, name)
# First try get_lookup() so that the lookup takes precedence if the lhs
@@ -1194,11 +1255,13 @@ class Query(BaseExpression):
lookup_class = lhs.get_lookup(lookup_name)
if not lookup_class:
if lhs.field.is_relation:
- raise FieldError('Related Field got invalid lookup: {}'.format(lookup_name))
+ raise FieldError(
+ "Related Field got invalid lookup: {}".format(lookup_name)
+ )
# A lookup wasn't found. Try to interpret the name as a transform
# and do an Exact lookup against it.
lhs = self.try_transform(lhs, lookup_name)
- lookup_name = 'exact'
+ lookup_name = "exact"
lookup_class = lhs.get_lookup(lookup_name)
if not lookup_class:
return
@@ -1207,20 +1270,20 @@ class Query(BaseExpression):
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value unless the lookup supports it.
if lookup.rhs is None and not lookup.can_use_none_as_rhs:
- if lookup_name not in ('exact', 'iexact'):
+ if lookup_name not in ("exact", "iexact"):
raise ValueError("Cannot use None as a query value")
- return lhs.get_lookup('isnull')(lhs, True)
+ return lhs.get_lookup("isnull")(lhs, True)
# For Oracle '' is equivalent to null. The check must be done at this
# stage because join promotion can't be done in the compiler. Using
# DEFAULT_DB_ALIAS isn't nice but it's the best that can be done here.
# A similar thing is done in is_nullable(), too.
if (
- lookup_name == 'exact' and
- lookup.rhs == '' and
- connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
+ lookup_name == "exact"
+ and lookup.rhs == ""
+ and connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
):
- return lhs.get_lookup('isnull')(lhs, True)
+ return lhs.get_lookup("isnull")(lhs, True)
return lookup
@@ -1234,19 +1297,28 @@ class Query(BaseExpression):
return transform_class(lhs)
else:
output_field = lhs.output_field.__class__
- suggested_lookups = difflib.get_close_matches(name, output_field.get_lookups())
+ suggested_lookups = difflib.get_close_matches(
+ name, output_field.get_lookups()
+ )
if suggested_lookups:
- suggestion = ', perhaps you meant %s?' % ' or '.join(suggested_lookups)
+ suggestion = ", perhaps you meant %s?" % " or ".join(suggested_lookups)
else:
- suggestion = '.'
+ suggestion = "."
raise FieldError(
"Unsupported lookup '%s' for %s or join on the field not "
"permitted%s" % (name, output_field.__name__, suggestion)
)
- def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
- can_reuse=None, allow_joins=True, split_subq=True,
- check_filterable=True):
+ def build_filter(
+ self,
+ filter_expr,
+ branch_negated=False,
+ current_negated=False,
+ can_reuse=None,
+ allow_joins=True,
+ split_subq=True,
+ check_filterable=True,
+ ):
"""
Build a WhereNode for a single filter clause but don't add it
to this Query. Query.add_q() will then add this filter to the where
@@ -1284,12 +1356,12 @@ class Query(BaseExpression):
split_subq=split_subq,
check_filterable=check_filterable,
)
- if hasattr(filter_expr, 'resolve_expression'):
- if not getattr(filter_expr, 'conditional', False):
- raise TypeError('Cannot filter against a non-conditional expression.')
+ if hasattr(filter_expr, "resolve_expression"):
+ if not getattr(filter_expr, "conditional", False):
+ raise TypeError("Cannot filter against a non-conditional expression.")
condition = filter_expr.resolve_expression(self, allow_joins=allow_joins)
if not isinstance(condition, Lookup):
- condition = self.build_lookup(['exact'], condition, True)
+ condition = self.build_lookup(["exact"], condition, True)
return WhereNode([condition], connector=AND), []
arg, value = filter_expr
if not arg:
@@ -1304,7 +1376,9 @@ class Query(BaseExpression):
pre_joins = self.alias_refcount.copy()
value = self.resolve_lookup_value(value, can_reuse, allow_joins)
- used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)}
+ used_joins = {
+ k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
+ }
if check_filterable:
self.check_filterable(value)
@@ -1319,7 +1393,11 @@ class Query(BaseExpression):
try:
join_info = self.setup_joins(
- parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many,
+ parts,
+ opts,
+ alias,
+ can_reuse=can_reuse,
+ allow_many=allow_many,
)
# Prevent iterator from being consumed by check_related_objects()
@@ -1336,7 +1414,9 @@ class Query(BaseExpression):
# Update used_joins before trimming since they are reused to determine
# which joins could be later promoted to INNER.
used_joins.update(join_info.joins)
- targets, alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
+ targets, alias, join_list = self.trim_joins(
+ join_info.targets, join_info.joins, join_info.path
+ )
if can_reuse is not None:
can_reuse.update(join_list)
@@ -1344,11 +1424,15 @@ class Query(BaseExpression):
# No support for transforms for relational fields
num_lookups = len(lookups)
if num_lookups > 1:
- raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
+ raise FieldError(
+ "Related Field got invalid lookup: {}".format(lookups[0])
+ )
if len(targets) == 1:
col = self._get_col(targets[0], join_info.final_field, alias)
else:
- col = MultiColSource(alias, targets, join_info.targets, join_info.final_field)
+ col = MultiColSource(
+ alias, targets, join_info.targets, join_info.final_field
+ )
else:
col = self._get_col(targets[0], join_info.final_field, alias)
@@ -1356,10 +1440,16 @@ class Query(BaseExpression):
lookup_type = condition.lookup_name
clause = WhereNode([condition], connector=AND)
- require_outer = lookup_type == 'isnull' and condition.rhs is True and not current_negated
- if current_negated and (lookup_type != 'isnull' or condition.rhs is False) and condition.rhs is not None:
+ require_outer = (
+ lookup_type == "isnull" and condition.rhs is True and not current_negated
+ )
+ if (
+ current_negated
+ and (lookup_type != "isnull" or condition.rhs is False)
+ and condition.rhs is not None
+ ):
require_outer = True
- if lookup_type != 'isnull':
+ if lookup_type != "isnull":
# The condition added here will be SQL like this:
# NOT (col IS NOT NULL), where the first NOT is added in
# upper layers of code. The reason for addition is that if col
@@ -1370,16 +1460,16 @@ class Query(BaseExpression):
# <=>
# NOT (col IS NOT NULL AND col = someval).
if (
- self.is_nullable(targets[0]) or
- self.alias_map[join_list[-1]].join_type == LOUTER
+ self.is_nullable(targets[0])
+ or self.alias_map[join_list[-1]].join_type == LOUTER
):
- lookup_class = targets[0].get_lookup('isnull')
+ lookup_class = targets[0].get_lookup("isnull")
col = self._get_col(targets[0], join_info.targets[0], alias)
clause.add(lookup_class(col, False), AND)
# If someval is a nullable column, someval IS NOT NULL is
# added.
if isinstance(value, Col) and self.is_nullable(value.target):
- lookup_class = value.target.get_lookup('isnull')
+ lookup_class = value.target.get_lookup("isnull")
clause.add(lookup_class(value, False), AND)
return clause, used_joins if not require_outer else ()
@@ -1397,7 +1487,9 @@ class Query(BaseExpression):
# (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
# rel_a doesn't produce any rows, then the whole condition must fail.
# So, demotion is OK.
- existing_inner = {a for a in self.alias_map if self.alias_map[a].join_type == INNER}
+ existing_inner = {
+ a for a in self.alias_map if self.alias_map[a].join_type == INNER
+ }
clause, _ = self._add_q(q_object, self.used_aliases)
if clause:
self.where.add(clause, AND)
@@ -1409,20 +1501,33 @@ class Query(BaseExpression):
def clear_where(self):
self.where = WhereNode()
- def _add_q(self, q_object, used_aliases, branch_negated=False,
- current_negated=False, allow_joins=True, split_subq=True,
- check_filterable=True):
+ def _add_q(
+ self,
+ q_object,
+ used_aliases,
+ branch_negated=False,
+ current_negated=False,
+ allow_joins=True,
+ split_subq=True,
+ check_filterable=True,
+ ):
"""Add a Q-object to the current filter."""
connector = q_object.connector
current_negated = current_negated ^ q_object.negated
branch_negated = branch_negated or q_object.negated
target_clause = WhereNode(connector=connector, negated=q_object.negated)
- joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated)
+ joinpromoter = JoinPromoter(
+ q_object.connector, len(q_object.children), current_negated
+ )
for child in q_object.children:
child_clause, needed_inner = self.build_filter(
- child, can_reuse=used_aliases, branch_negated=branch_negated,
- current_negated=current_negated, allow_joins=allow_joins,
- split_subq=split_subq, check_filterable=check_filterable,
+ child,
+ can_reuse=used_aliases,
+ branch_negated=branch_negated,
+ current_negated=current_negated,
+ allow_joins=allow_joins,
+ split_subq=split_subq,
+ check_filterable=check_filterable,
)
joinpromoter.add_votes(needed_inner)
if child_clause:
@@ -1430,7 +1535,9 @@ class Query(BaseExpression):
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
- def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False):
+ def build_filtered_relation_q(
+ self, q_object, reuse, branch_negated=False, current_negated=False
+ ):
"""Add a FilteredRelation object to the current filter."""
connector = q_object.connector
current_negated ^= q_object.negated
@@ -1439,14 +1546,19 @@ class Query(BaseExpression):
for child in q_object.children:
if isinstance(child, Node):
child_clause = self.build_filtered_relation_q(
- child, reuse=reuse, branch_negated=branch_negated,
+ child,
+ reuse=reuse,
+ branch_negated=branch_negated,
current_negated=current_negated,
)
else:
child_clause, _ = self.build_filter(
- child, can_reuse=reuse, branch_negated=branch_negated,
+ child,
+ can_reuse=reuse,
+ branch_negated=branch_negated,
current_negated=current_negated,
- allow_joins=True, split_subq=False,
+ allow_joins=True,
+ split_subq=False,
)
target_clause.add(child_clause, connector)
return target_clause
@@ -1454,7 +1566,9 @@ class Query(BaseExpression):
def add_filtered_relation(self, filtered_relation, alias):
filtered_relation.alias = alias
lookups = dict(get_children_from_q(filtered_relation.condition))
- relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(filtered_relation.relation_name)
+ relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
+ filtered_relation.relation_name
+ )
if relation_lookup_parts:
raise ValueError(
"FilteredRelation's relation_name cannot contain lookups "
@@ -1498,7 +1612,7 @@ class Query(BaseExpression):
path, names_with_path = [], []
for pos, name in enumerate(names):
cur_names_with_path = (name, [])
- if name == 'pk':
+ if name == "pk":
name = opts.pk.name
field = None
@@ -1513,7 +1627,10 @@ class Query(BaseExpression):
if LOOKUP_SEP in filtered_relation.relation_name:
parts = filtered_relation.relation_name.split(LOOKUP_SEP)
filtered_relation_path, field, _, _ = self.names_to_path(
- parts, opts, allow_many, fail_on_missing,
+ parts,
+ opts,
+ allow_many,
+ fail_on_missing,
)
path.extend(filtered_relation_path[:-1])
else:
@@ -1540,13 +1657,17 @@ class Query(BaseExpression):
# one step.
pos -= 1
if pos == -1 or fail_on_missing:
- available = sorted([
- *get_field_names_from_opts(opts),
- *self.annotation_select,
- *self._filtered_relations,
- ])
- raise FieldError("Cannot resolve keyword '%s' into field. "
- "Choices are: %s" % (name, ", ".join(available)))
+ available = sorted(
+ [
+ *get_field_names_from_opts(opts),
+ *self.annotation_select,
+ *self._filtered_relations,
+ ]
+ )
+ raise FieldError(
+ "Cannot resolve keyword '%s' into field. "
+ "Choices are: %s" % (name, ", ".join(available))
+ )
break
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
@@ -1557,7 +1678,7 @@ class Query(BaseExpression):
path.extend(path_to_parent)
cur_names_with_path[1].extend(path_to_parent)
opts = path_to_parent[-1].to_opts
- if hasattr(field, 'path_infos'):
+ if hasattr(field, "path_infos"):
if filtered_relation:
pathinfos = field.get_path_info(filtered_relation)
else:
@@ -1565,7 +1686,7 @@ class Query(BaseExpression):
if not allow_many:
for inner_pos, p in enumerate(pathinfos):
if p.m2m:
- cur_names_with_path[1].extend(pathinfos[0:inner_pos + 1])
+ cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
names_with_path.append(cur_names_with_path)
raise MultiJoin(pos + 1, names_with_path)
last = pathinfos[-1]
@@ -1582,9 +1703,10 @@ class Query(BaseExpression):
if fail_on_missing and pos + 1 != len(names):
raise FieldError(
"Cannot resolve keyword %r into field. Join on '%s'"
- " not permitted." % (names[pos + 1], name))
+ " not permitted." % (names[pos + 1], name)
+ )
break
- return path, final_field, targets, names[pos + 1:]
+ return path, final_field, targets, names[pos + 1 :]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
@@ -1631,7 +1753,10 @@ class Query(BaseExpression):
for pivot in range(len(names), 0, -1):
try:
path, final_field, targets, rest = self.names_to_path(
- names[:pivot], opts, allow_many, fail_on_missing=True,
+ names[:pivot],
+ opts,
+ allow_many,
+ fail_on_missing=True,
)
except FieldError as exc:
if pivot == 1:
@@ -1646,6 +1771,7 @@ class Query(BaseExpression):
transforms = names[pivot:]
break
for name in transforms:
+
def transform(field, alias, *, name, previous):
try:
wrapped = previous(field, alias)
@@ -1656,7 +1782,10 @@ class Query(BaseExpression):
raise last_field_exception
else:
raise
- final_transformer = functools.partial(transform, name=name, previous=final_transformer)
+
+ final_transformer = functools.partial(
+ transform, name=name, previous=final_transformer
+ )
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1673,8 +1802,13 @@ class Query(BaseExpression):
else:
nullable = True
connection = self.join_class(
- opts.db_table, alias, table_alias, INNER, join.join_field,
- nullable, filtered_relation=filtered_relation,
+ opts.db_table,
+ alias,
+ table_alias,
+ INNER,
+ join.join_field,
+ nullable,
+ filtered_relation=filtered_relation,
)
reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse)
@@ -1706,7 +1840,11 @@ class Query(BaseExpression):
cur_targets = {t.column for t in targets}
if not cur_targets.issubset(join_targets):
break
- targets_dict = {r[1].column: r[0] for r in info.join_field.related_fields if r[1].column in cur_targets}
+ targets_dict = {
+ r[1].column: r[0]
+ for r in info.join_field.related_fields
+ if r[1].column in cur_targets
+ }
targets = tuple(targets_dict[t.column] for t in targets)
self.unref_alias(joins.pop())
return targets, joins[-1], joins
@@ -1716,9 +1854,11 @@ class Query(BaseExpression):
for expr in exprs:
if isinstance(expr, Col):
yield expr
- elif include_external and callable(getattr(expr, 'get_external_cols', None)):
+ elif include_external and callable(
+ getattr(expr, "get_external_cols", None)
+ ):
yield from expr.get_external_cols()
- elif hasattr(expr, 'get_source_expressions'):
+ elif hasattr(expr, "get_source_expressions"):
yield from cls._gen_cols(
expr.get_source_expressions(),
include_external=include_external,
@@ -1735,7 +1875,7 @@ class Query(BaseExpression):
for alias in self._gen_col_aliases([annotation]):
if isinstance(self.alias_map[alias], Join):
raise FieldError(
- 'Joined field references are not permitted in this query'
+ "Joined field references are not permitted in this query"
)
if summarize:
# Summarize currently means we are doing an aggregate() query
@@ -1757,10 +1897,16 @@ class Query(BaseExpression):
for transform in field_list[1:]:
annotation = self.try_transform(annotation, transform)
return annotation
- join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse)
- targets, final_alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
+ join_info = self.setup_joins(
+ field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse
+ )
+ targets, final_alias, join_list = self.trim_joins(
+ join_info.targets, join_info.joins, join_info.path
+ )
if not allow_joins and len(join_list) > 1:
- raise FieldError('Joined field references are not permitted in this query')
+ raise FieldError(
+ "Joined field references are not permitted in this query"
+ )
if len(targets) > 1:
raise FieldError(
"Referencing multicolumn fields with F() objects isn't supported"
@@ -1813,23 +1959,25 @@ class Query(BaseExpression):
# Need to add a restriction so that outer query's filters are in effect for
# the subquery, too.
query.bump_prefix(self)
- lookup_class = select_field.get_lookup('exact')
+ lookup_class = select_field.get_lookup("exact")
# Note that the query.select[0].alias is different from alias
# due to bump_prefix above.
- lookup = lookup_class(pk.get_col(query.select[0].alias),
- pk.get_col(alias))
+ lookup = lookup_class(pk.get_col(query.select[0].alias), pk.get_col(alias))
query.where.add(lookup, AND)
query.external_aliases[alias] = True
- lookup_class = select_field.get_lookup('exact')
+ lookup_class = select_field.get_lookup("exact")
lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
query.where.add(lookup, AND)
condition, needed_inner = self.build_filter(Exists(query))
if contains_louter:
or_null_condition, _ = self.build_filter(
- ('%s__isnull' % trimmed_prefix, True),
- current_negated=True, branch_negated=True, can_reuse=can_reuse)
+ ("%s__isnull" % trimmed_prefix, True),
+ current_negated=True,
+ branch_negated=True,
+ can_reuse=can_reuse,
+ )
condition.add(or_null_condition, OR)
# Note that the end result will be:
# (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
@@ -1907,8 +2055,8 @@ class Query(BaseExpression):
self.values_select = ()
def add_select_col(self, col, name):
- self.select += col,
- self.values_select += name,
+ self.select += (col,)
+ self.values_select += (name,)
def set_select(self, cols):
self.default_cols = False
@@ -1934,7 +2082,9 @@ class Query(BaseExpression):
for name in field_names:
# Join promotion note - we must not remove any rows here, so
# if there is no existing joins, use outer join.
- join_info = self.setup_joins(name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
+ join_info = self.setup_joins(
+ name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m
+ )
targets, final_alias, joins = self.trim_joins(
join_info.targets,
join_info.joins,
@@ -1957,12 +2107,18 @@ class Query(BaseExpression):
"it." % name
)
else:
- names = sorted([
- *get_field_names_from_opts(opts), *self.extra,
- *self.annotation_select, *self._filtered_relations
- ])
- raise FieldError("Cannot resolve keyword %r into field. "
- "Choices are: %s" % (name, ", ".join(names)))
+ names = sorted(
+ [
+ *get_field_names_from_opts(opts),
+ *self.extra,
+ *self.annotation_select,
+ *self._filtered_relations,
+ ]
+ )
+ raise FieldError(
+ "Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(names))
+ )
def add_ordering(self, *ordering):
"""
@@ -1976,9 +2132,9 @@ class Query(BaseExpression):
errors = []
for item in ordering:
if isinstance(item, str):
- if item == '?':
+ if item == "?":
continue
- if item.startswith('-'):
+ if item.startswith("-"):
item = item[1:]
if item in self.annotations:
continue
@@ -1987,15 +2143,15 @@ class Query(BaseExpression):
# names_to_path() validates the lookup. A descriptive
# FieldError will be raise if it's not.
self.names_to_path(item.split(LOOKUP_SEP), self.model._meta)
- elif not hasattr(item, 'resolve_expression'):
+ elif not hasattr(item, "resolve_expression"):
errors.append(item)
- if getattr(item, 'contains_aggregate', False):
+ if getattr(item, "contains_aggregate", False):
raise FieldError(
- 'Using an aggregate in order_by() without also including '
- 'it in annotate() is not allowed: %s' % item
+ "Using an aggregate in order_by() without also including "
+ "it in annotate() is not allowed: %s" % item
)
if errors:
- raise FieldError('Invalid order_by arguments: %s' % errors)
+ raise FieldError("Invalid order_by arguments: %s" % errors)
if ordering:
self.order_by += ordering
else:
@@ -2008,7 +2164,9 @@ class Query(BaseExpression):
If 'clear_default' is True, there will be no ordering in the resulting
query (not even the model's default).
"""
- if not force and (self.is_sliced or self.distinct_fields or self.select_for_update):
+ if not force and (
+ self.is_sliced or self.distinct_fields or self.select_for_update
+ ):
return
self.order_by = ()
self.extra_order_by = ()
@@ -2031,10 +2189,9 @@ class Query(BaseExpression):
for join in list(self.alias_map.values())[1:]: # Skip base table.
model = join.join_field.related_model
if model not in seen_models:
- column_names.update({
- field.column
- for field in model._meta.local_concrete_fields
- })
+ column_names.update(
+ {field.column for field in model._meta.local_concrete_fields}
+ )
seen_models.add(model)
group_by = list(self.select)
@@ -2082,7 +2239,7 @@ class Query(BaseExpression):
entry_params = []
pos = entry.find("%s")
while pos != -1:
- if pos == 0 or entry[pos - 1] != '%':
+ if pos == 0 or entry[pos - 1] != "%":
entry_params.append(next(param_iter))
pos = entry.find("%s", pos + 2)
select_pairs[name] = (entry, entry_params)
@@ -2135,8 +2292,8 @@ class Query(BaseExpression):
"""
existing, defer = self.deferred_loading
field_names = set(field_names)
- if 'pk' in field_names:
- field_names.remove('pk')
+ if "pk" in field_names:
+ field_names.remove("pk")
field_names.add(self.get_meta().pk.name)
if defer:
@@ -2224,7 +2381,9 @@ class Query(BaseExpression):
# Selected annotations must be known before setting the GROUP BY
# clause.
if self.group_by is True:
- self.add_fields((f.attname for f in self.model._meta.concrete_fields), False)
+ self.add_fields(
+ (f.attname for f in self.model._meta.concrete_fields), False
+ )
# Disable GROUP BY aliases to avoid orphaning references to the
# SELECT clause which is about to be cleared.
self.set_group_by(allow_aliases=False)
@@ -2254,7 +2413,8 @@ class Query(BaseExpression):
return {}
elif self.annotation_select_mask is not None:
self._annotation_select_cache = {
- k: v for k, v in self.annotations.items()
+ k: v
+ for k, v in self.annotations.items()
if k in self.annotation_select_mask
}
return self._annotation_select_cache
@@ -2269,8 +2429,7 @@ class Query(BaseExpression):
return {}
elif self.extra_select_mask is not None:
self._extra_select_cache = {
- k: v for k, v in self.extra.items()
- if k in self.extra_select_mask
+ k: v for k, v in self.extra.items() if k in self.extra_select_mask
}
return self._extra_select_cache
else:
@@ -2297,8 +2456,7 @@ class Query(BaseExpression):
# the lookup part of the query. That is, avoid trimming
# joins generated for F() expressions.
lookup_tables = [
- t for t in self.alias_map
- if t in self._lookup_joins or t == self.base_table
+ t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
]
for trimmed_paths, path in enumerate(all_paths):
if path.m2m:
@@ -2317,8 +2475,7 @@ class Query(BaseExpression):
break
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
- trimmed_prefix.append(
- join_field.foreign_related_fields[0].name)
+ trimmed_prefix.append(join_field.foreign_related_fields[0].name)
trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
# Lets still see if we can trim the first join from the inner query
# (that is, self). We can't do this for:
@@ -2331,7 +2488,9 @@ class Query(BaseExpression):
select_fields = [r[0] for r in join_field.related_fields]
select_alias = lookup_tables[trimmed_paths + 1]
self.unref_alias(lookup_tables[trimmed_paths])
- extra_restriction = join_field.get_extra_restriction(None, lookup_tables[trimmed_paths + 1])
+ extra_restriction = join_field.get_extra_restriction(
+ None, lookup_tables[trimmed_paths + 1]
+ )
if extra_restriction:
self.where.add(extra_restriction, AND)
else:
@@ -2367,12 +2526,12 @@ class Query(BaseExpression):
# is_nullable() is needed to the compiler stage, but that is not easy
# to do currently.
return field.null or (
- field.empty_strings_allowed and
- connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
+ field.empty_strings_allowed
+ and connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls
)
-def get_order_dir(field, default='ASC'):
+def get_order_dir(field, default="ASC"):
"""
Return the field name and direction for an order specification. For
example, '-foo' is returned as ('foo', 'DESC').
@@ -2381,7 +2540,7 @@ def get_order_dir(field, default='ASC'):
prefix) should sort. The '-' prefix always sorts the opposite way.
"""
dirn = ORDER_DIR[default]
- if field[0] == '-':
+ if field[0] == "-":
return field[1:], dirn[1]
return field, dirn[0]
@@ -2428,8 +2587,8 @@ class JoinPromoter:
def __repr__(self):
return (
- f'{self.__class__.__qualname__}(connector={self.connector!r}, '
- f'num_children={self.num_children!r}, negated={self.negated!r})'
+ f"{self.__class__.__qualname__}(connector={self.connector!r}, "
+ f"num_children={self.num_children!r}, negated={self.negated!r})"
)
def add_votes(self, votes):
@@ -2461,7 +2620,7 @@ class JoinPromoter:
# to rel_a would remove a valid match from the query. So, we need
# to promote any existing INNER to LOUTER (it is possible this
# promotion in turn will be demoted later on).
- if self.effective_connector == 'OR' and votes < self.num_children:
+ if self.effective_connector == "OR" and votes < self.num_children:
to_promote.add(table)
# If connector is AND and there is a filter that can match only
# when there is a joinable row, then use INNER. For example, in
@@ -2473,8 +2632,9 @@ class JoinPromoter:
# (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell)
# then if rel_a doesn't produce any rows, the whole condition
# can't match. Hence we can safely use INNER join.
- if self.effective_connector == 'AND' or (
- self.effective_connector == 'OR' and votes == self.num_children):
+ if self.effective_connector == "AND" or (
+ self.effective_connector == "OR" and votes == self.num_children
+ ):
to_demote.add(table)
# Finally, what happens in cases where we have:
# (rel_a__col=1|rel_b__col=2) & rel_a__col__gte=0
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index f6a371a925..04063f73bc 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -3,18 +3,16 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.core.exceptions import FieldError
-from django.db.models.sql.constants import (
- CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS,
-)
+from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.query import Query
-__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery']
+__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
class DeleteQuery(Query):
"""A DELETE SQL query."""
- compiler = 'SQLDeleteCompiler'
+ compiler = "SQLDeleteCompiler"
def do_query(self, table, where, using):
self.alias_map = {table: self.alias_map[table]}
@@ -38,17 +36,19 @@ class DeleteQuery(Query):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
self.add_filter(
- f'{field.attname}__in',
- pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE],
+ f"{field.attname}__in",
+ pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
+ )
+ num_deleted += self.do_query(
+ self.get_meta().db_table, self.where, using=using
)
- num_deleted += self.do_query(self.get_meta().db_table, self.where, using=using)
return num_deleted
class UpdateQuery(Query):
"""An UPDATE SQL query."""
- compiler = 'SQLUpdateCompiler'
+ compiler = "SQLUpdateCompiler"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -72,7 +72,9 @@ class UpdateQuery(Query):
self.add_update_values(values)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
- self.add_filter('pk__in', pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])
+ self.add_filter(
+ "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
+ )
self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
@@ -84,12 +86,14 @@ class UpdateQuery(Query):
values_seq = []
for name, val in values.items():
field = self.get_meta().get_field(name)
- direct = not (field.auto_created and not field.concrete) or not field.concrete
+ direct = (
+ not (field.auto_created and not field.concrete) or not field.concrete
+ )
model = field.model._meta.concrete_model
if not direct or (field.is_relation and field.many_to_many):
raise FieldError(
- 'Cannot update model field %r (only non-relations and '
- 'foreign keys permitted).' % field
+ "Cannot update model field %r (only non-relations and "
+ "foreign keys permitted)." % field
)
if model is not self.get_meta().concrete_model:
self.add_related_update(model, field, val)
@@ -104,7 +108,7 @@ class UpdateQuery(Query):
called add_update_targets() to hint at the extra information here.
"""
for field, model, val in values_seq:
- if hasattr(val, 'resolve_expression'):
+ if hasattr(val, "resolve_expression"):
# Resolve expressions here so that annotations are no longer needed
val = val.resolve_expression(self, allow_joins=False, for_save=True)
self.values.append((field, model, val))
@@ -130,15 +134,17 @@ class UpdateQuery(Query):
query = UpdateQuery(model)
query.values = values
if self.related_ids is not None:
- query.add_filter('pk__in', self.related_ids)
+ query.add_filter("pk__in", self.related_ids)
result.append(query)
return result
class InsertQuery(Query):
- compiler = 'SQLInsertCompiler'
+ compiler = "SQLInsertCompiler"
- def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs):
+ def __init__(
+ self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
+ ):
super().__init__(*args, **kwargs)
self.fields = []
self.objs = []
@@ -158,7 +164,7 @@ class AggregateQuery(Query):
elements in the provided list.
"""
- compiler = 'SQLAggregateCompiler'
+ compiler = "SQLAggregateCompiler"
def __init__(self, model, inner_query):
self.inner_query = inner_query
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 50ff13be75..532780fd98 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -7,8 +7,8 @@ from django.utils import tree
from django.utils.functional import cached_property
# Connection types
-AND = 'AND'
-OR = 'OR'
+AND = "AND"
+OR = "OR"
class WhereNode(tree.Node):
@@ -25,6 +25,7 @@ class WhereNode(tree.Node):
relabeled_clone() method or relabel_aliases() and clone() methods and
contains_aggregate attribute.
"""
+
default = AND
resolved = False
conditional = True
@@ -40,15 +41,15 @@ class WhereNode(tree.Node):
in_negated = negated ^ self.negated
# If the effective connector is OR and this node contains an aggregate,
# then we need to push the whole branch to HAVING clause.
- may_need_split = (
- (in_negated and self.connector == AND) or
- (not in_negated and self.connector == OR))
+ may_need_split = (in_negated and self.connector == AND) or (
+ not in_negated and self.connector == OR
+ )
if may_need_split and self.contains_aggregate:
return None, self
where_parts = []
having_parts = []
for c in self.children:
- if hasattr(c, 'split_having'):
+ if hasattr(c, "split_having"):
where_part, having_part = c.split_having(in_negated)
if where_part is not None:
where_parts.append(where_part)
@@ -58,8 +59,16 @@ class WhereNode(tree.Node):
having_parts.append(c)
else:
where_parts.append(c)
- having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None
- where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None
+ having_node = (
+ self.__class__(having_parts, self.connector, self.negated)
+ if having_parts
+ else None
+ )
+ where_node = (
+ self.__class__(where_parts, self.connector, self.negated)
+ if where_parts
+ else None
+ )
return where_node, having_node
def as_sql(self, compiler, connection):
@@ -94,24 +103,24 @@ class WhereNode(tree.Node):
# counts.
if empty_needed == 0:
if self.negated:
- return '', []
+ return "", []
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
- return '', []
- conn = ' %s ' % self.connector
+ return "", []
+ conn = " %s " % self.connector
sql_string = conn.join(result)
if sql_string:
if self.negated:
# Some backends (Oracle at least) need parentheses
# around the inner SQL in the negated case, even if the
# inner SQL contains just a single expression.
- sql_string = 'NOT (%s)' % sql_string
+ sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved:
- sql_string = '(%s)' % sql_string
+ sql_string = "(%s)" % sql_string
return sql_string, result_params
def get_group_by_cols(self, alias=None):
@@ -133,10 +142,10 @@ class WhereNode(tree.Node):
mapping old (current) alias values to the new values.
"""
for pos, child in enumerate(self.children):
- if hasattr(child, 'relabel_aliases'):
+ if hasattr(child, "relabel_aliases"):
# For example another WhereNode
child.relabel_aliases(change_map)
- elif hasattr(child, 'relabeled_clone'):
+ elif hasattr(child, "relabeled_clone"):
self.children[pos] = child.relabeled_clone(change_map)
def clone(self):
@@ -146,10 +155,12 @@ class WhereNode(tree.Node):
value) tuples, or objects supporting .clone().
"""
clone = self.__class__._new_instance(
- children=None, connector=self.connector, negated=self.negated,
+ children=None,
+ connector=self.connector,
+ negated=self.negated,
)
for child in self.children:
- if hasattr(child, 'clone'):
+ if hasattr(child, "clone"):
clone.children.append(child.clone())
else:
clone.children.append(child)
@@ -185,18 +196,18 @@ class WhereNode(tree.Node):
@staticmethod
def _resolve_leaf(expr, query, *args, **kwargs):
- if hasattr(expr, 'resolve_expression'):
+ if hasattr(expr, "resolve_expression"):
expr = expr.resolve_expression(query, *args, **kwargs)
return expr
@classmethod
def _resolve_node(cls, node, query, *args, **kwargs):
- if hasattr(node, 'children'):
+ if hasattr(node, "children"):
for child in node.children:
cls._resolve_node(child, query, *args, **kwargs)
- if hasattr(node, 'lhs'):
+ if hasattr(node, "lhs"):
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
- if hasattr(node, 'rhs'):
+ if hasattr(node, "rhs"):
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
def resolve_expression(self, *args, **kwargs):
@@ -208,6 +219,7 @@ class WhereNode(tree.Node):
@cached_property
def output_field(self):
from django.db.models import BooleanField
+
return BooleanField()
def select_format(self, compiler, sql, params):
@@ -215,7 +227,7 @@ class WhereNode(tree.Node):
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
- sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
+ sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
def get_db_converters(self, connection):
@@ -227,6 +239,7 @@ class WhereNode(tree.Node):
class NothingNode:
"""A node that matches nothing."""
+
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):
diff --git a/django/db/models/utils.py b/django/db/models/utils.py
index 949c528469..5521f3cca5 100644
--- a/django/db/models/utils.py
+++ b/django/db/models/utils.py
@@ -46,7 +46,7 @@ def create_namedtuple_class(*names):
return unpickle_named_row, (names, tuple(self))
return type(
- 'Row',
- (namedtuple('Row', names),),
- {'__reduce__': __reduce__, '__slots__': ()},
+ "Row",
+ (namedtuple("Row", names),),
+ {"__reduce__": __reduce__, "__slots__": ()},
)
diff --git a/django/db/transaction.py b/django/db/transaction.py
index b61785754f..b3c7b4bbaa 100644
--- a/django/db/transaction.py
+++ b/django/db/transaction.py
@@ -1,12 +1,17 @@
from contextlib import ContextDecorator, contextmanager
from django.db import (
- DEFAULT_DB_ALIAS, DatabaseError, Error, ProgrammingError, connections,
+ DEFAULT_DB_ALIAS,
+ DatabaseError,
+ Error,
+ ProgrammingError,
+ connections,
)
class TransactionManagementError(ProgrammingError):
"""Transaction management is used improperly."""
+
pass
@@ -132,6 +137,7 @@ def on_commit(func, using=None):
# Decorators / context managers #
#################################
+
class Atomic(ContextDecorator):
"""
Guarantee the atomic execution of a given block.
@@ -176,13 +182,13 @@ class Atomic(ContextDecorator):
connection = get_connection(self.using)
if (
- self.durable and
- connection.atomic_blocks and
- not connection.atomic_blocks[-1]._from_testcase
+ self.durable
+ and connection.atomic_blocks
+ and not connection.atomic_blocks[-1]._from_testcase
):
raise RuntimeError(
- 'A durable atomic block cannot be nested within another '
- 'atomic block.'
+ "A durable atomic block cannot be nested within another "
+ "atomic block."
)
if not connection.in_atomic_block:
# Reset state when entering an outermost atomic block.
@@ -206,7 +212,9 @@ class Atomic(ContextDecorator):
else:
connection.savepoint_ids.append(None)
else:
- connection.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True)
+ connection.set_autocommit(
+ False, force_begin_transaction_with_broken_autocommit=True
+ )
connection.in_atomic_block = True
if connection.in_atomic_block:
diff --git a/django/db/utils.py b/django/db/utils.py
index 82498d1df6..7ef62ae5a2 100644
--- a/django/db/utils.py
+++ b/django/db/utils.py
@@ -3,14 +3,15 @@ from importlib import import_module
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
+
# For backwards compatibility with Django < 3.2
from django.utils.connection import ConnectionDoesNotExist # NOQA: F401
from django.utils.connection import BaseConnectionHandler
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
-DEFAULT_DB_ALIAS = 'default'
-DJANGO_VERSION_PICKLE_KEY = '_django_version'
+DEFAULT_DB_ALIAS = "default"
+DJANGO_VERSION_PICKLE_KEY = "_django_version"
class Error(Exception):
@@ -70,15 +71,15 @@ class DatabaseErrorWrapper:
if exc_type is None:
return
for dj_exc_type in (
- DataError,
- OperationalError,
- IntegrityError,
- InternalError,
- ProgrammingError,
- NotSupportedError,
- DatabaseError,
- InterfaceError,
- Error,
+ DataError,
+ OperationalError,
+ IntegrityError,
+ InternalError,
+ ProgrammingError,
+ NotSupportedError,
+ DatabaseError,
+ InterfaceError,
+ Error,
):
db_exc_type = getattr(self.wrapper.Database, dj_exc_type.__name__)
if issubclass(exc_type, db_exc_type):
@@ -95,6 +96,7 @@ class DatabaseErrorWrapper:
def inner(*args, **kwargs):
with self:
return func(*args, **kwargs)
+
return inner
@@ -104,20 +106,22 @@ def load_backend(backend_name):
backend name, or raise an error if it doesn't exist.
"""
# This backend was renamed in Django 1.9.
- if backend_name == 'django.db.backends.postgresql_psycopg2':
- backend_name = 'django.db.backends.postgresql'
+ if backend_name == "django.db.backends.postgresql_psycopg2":
+ backend_name = "django.db.backends.postgresql"
try:
- return import_module('%s.base' % backend_name)
+ return import_module("%s.base" % backend_name)
except ImportError as e_user:
# The database backend wasn't found. Display a helpful error message
# listing all built-in database backends.
import django.db.backends
+
builtin_backends = [
- name for _, name, ispkg in pkgutil.iter_modules(django.db.backends.__path__)
- if ispkg and name not in {'base', 'dummy'}
+ name
+ for _, name, ispkg in pkgutil.iter_modules(django.db.backends.__path__)
+ if ispkg and name not in {"base", "dummy"}
]
- if backend_name not in ['django.db.backends.%s' % b for b in builtin_backends]:
+ if backend_name not in ["django.db.backends.%s" % b for b in builtin_backends]:
backend_reprs = map(repr, sorted(builtin_backends))
raise ImproperlyConfigured(
"%r isn't an available database backend or couldn't be "
@@ -132,7 +136,7 @@ def load_backend(backend_name):
class ConnectionHandler(BaseConnectionHandler):
- settings_name = 'DATABASES'
+ settings_name = "DATABASES"
# Connections needs to still be an actual thread local, as it's truly
# thread-critical. Database backends should use @async_unsafe to protect
# their code from async contexts, but this will give those contexts
@@ -143,13 +147,13 @@ class ConnectionHandler(BaseConnectionHandler):
def configure_settings(self, databases):
databases = super().configure_settings(databases)
if databases == {}:
- databases[DEFAULT_DB_ALIAS] = {'ENGINE': 'django.db.backends.dummy'}
+ databases[DEFAULT_DB_ALIAS] = {"ENGINE": "django.db.backends.dummy"}
elif DEFAULT_DB_ALIAS not in databases:
raise ImproperlyConfigured(
f"You must define a '{DEFAULT_DB_ALIAS}' database."
)
elif databases[DEFAULT_DB_ALIAS] == {}:
- databases[DEFAULT_DB_ALIAS]['ENGINE'] = 'django.db.backends.dummy'
+ databases[DEFAULT_DB_ALIAS]["ENGINE"] = "django.db.backends.dummy"
return databases
@property
@@ -166,17 +170,17 @@ class ConnectionHandler(BaseConnectionHandler):
except KeyError:
raise self.exception_class(f"The connection '{alias}' doesn't exist.")
- conn.setdefault('ATOMIC_REQUESTS', False)
- conn.setdefault('AUTOCOMMIT', True)
- conn.setdefault('ENGINE', 'django.db.backends.dummy')
- if conn['ENGINE'] == 'django.db.backends.' or not conn['ENGINE']:
- conn['ENGINE'] = 'django.db.backends.dummy'
- conn.setdefault('CONN_MAX_AGE', 0)
- conn.setdefault('CONN_HEALTH_CHECKS', False)
- conn.setdefault('OPTIONS', {})
- conn.setdefault('TIME_ZONE', None)
- for setting in ['NAME', 'USER', 'PASSWORD', 'HOST', 'PORT']:
- conn.setdefault(setting, '')
+ conn.setdefault("ATOMIC_REQUESTS", False)
+ conn.setdefault("AUTOCOMMIT", True)
+ conn.setdefault("ENGINE", "django.db.backends.dummy")
+ if conn["ENGINE"] == "django.db.backends." or not conn["ENGINE"]:
+ conn["ENGINE"] = "django.db.backends.dummy"
+ conn.setdefault("CONN_MAX_AGE", 0)
+ conn.setdefault("CONN_HEALTH_CHECKS", False)
+ conn.setdefault("OPTIONS", {})
+ conn.setdefault("TIME_ZONE", None)
+ for setting in ["NAME", "USER", "PASSWORD", "HOST", "PORT"]:
+ conn.setdefault(setting, "")
def prepare_test_settings(self, alias):
"""
@@ -187,13 +191,13 @@ class ConnectionHandler(BaseConnectionHandler):
except KeyError:
raise self.exception_class(f"The connection '{alias}' doesn't exist.")
- test_settings = conn.setdefault('TEST', {})
+ test_settings = conn.setdefault("TEST", {})
default_test_settings = [
- ('CHARSET', None),
- ('COLLATION', None),
- ('MIGRATE', True),
- ('MIRROR', None),
- ('NAME', None),
+ ("CHARSET", None),
+ ("COLLATION", None),
+ ("MIGRATE", True),
+ ("MIRROR", None),
+ ("NAME", None),
]
for key, value in default_test_settings:
test_settings.setdefault(key, value)
@@ -202,7 +206,7 @@ class ConnectionHandler(BaseConnectionHandler):
self.ensure_defaults(alias)
self.prepare_test_settings(alias)
db = self.databases[alias]
- backend = load_backend(db['ENGINE'])
+ backend = load_backend(db["ENGINE"])
return backend.DatabaseWrapper(db, alias)
def close_all(self):
@@ -247,14 +251,15 @@ class ConnectionRouter:
chosen_db = method(model, **hints)
if chosen_db:
return chosen_db
- instance = hints.get('instance')
+ instance = hints.get("instance")
if instance is not None and instance._state.db:
return instance._state.db
return DEFAULT_DB_ALIAS
+
return _route_db
- db_for_read = _router_func('db_for_read')
- db_for_write = _router_func('db_for_write')
+ db_for_read = _router_func("db_for_read")
+ db_for_write = _router_func("db_for_write")
def allow_relation(self, obj1, obj2, **hints):
for router in self.routers: