diff options
| author | django-bot <ops@djangoproject.com> | 2022-02-03 20:24:19 +0100 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-02-07 20:37:05 +0100 |
| commit | 9c19aff7c7561e3a82978a272ecdaad40dda5c00 (patch) | |
| tree | f0506b668a013d0063e5fba3dbf4863b466713ba /django/db/backends/sqlite3 | |
| parent | f68fa8b45dfac545cfc4111d4e52804c86db68d3 (diff) | |
Refs #33476 -- Reformatted code with Black.
Diffstat (limited to 'django/db/backends/sqlite3')
| -rw-r--r-- | django/db/backends/sqlite3/_functions.py | 234 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/base.py | 224 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/client.py | 4 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/creation.py | 49 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/features.py | 76 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/introspection.py | 241 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/operations.py | 183 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/schema.py | 307 |
8 files changed, 774 insertions, 544 deletions
diff --git a/django/db/backends/sqlite3/_functions.py b/django/db/backends/sqlite3/_functions.py index 3529a99dd6..86684c1907 100644 --- a/django/db/backends/sqlite3/_functions.py +++ b/django/db/backends/sqlite3/_functions.py @@ -7,14 +7,30 @@ import statistics from datetime import timedelta from hashlib import sha1, sha224, sha256, sha384, sha512 from math import ( - acos, asin, atan, atan2, ceil, cos, degrees, exp, floor, fmod, log, pi, - radians, sin, sqrt, tan, + acos, + asin, + atan, + atan2, + ceil, + cos, + degrees, + exp, + floor, + fmod, + log, + pi, + radians, + sin, + sqrt, + tan, ) from re import search as re_search from django.db.backends.base.base import timezone_constructor from django.db.backends.utils import ( - split_tzname_delta, typecast_time, typecast_timestamp, + split_tzname_delta, + typecast_time, + typecast_timestamp, ) from django.utils import timezone from django.utils.crypto import md5 @@ -26,56 +42,62 @@ def register(connection): connection.create_function, deterministic=True, ) - create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract) - create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc) - create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date) - create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time) - create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract) - create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc) - create_deterministic_function('django_time_extract', 2, _sqlite_time_extract) - create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc) - create_deterministic_function('django_time_diff', 2, _sqlite_time_diff) - create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff) - create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta) - create_deterministic_function('regexp', 2, _sqlite_regexp) - create_deterministic_function('ACOS', 1, _sqlite_acos) - create_deterministic_function('ASIN', 1, _sqlite_asin) - create_deterministic_function('ATAN', 1, _sqlite_atan) - create_deterministic_function('ATAN2', 2, _sqlite_atan2) - create_deterministic_function('BITXOR', 2, _sqlite_bitxor) - create_deterministic_function('CEILING', 1, _sqlite_ceiling) - create_deterministic_function('COS', 1, _sqlite_cos) - create_deterministic_function('COT', 1, _sqlite_cot) - create_deterministic_function('DEGREES', 1, _sqlite_degrees) - create_deterministic_function('EXP', 1, _sqlite_exp) - create_deterministic_function('FLOOR', 1, _sqlite_floor) - create_deterministic_function('LN', 1, _sqlite_ln) - create_deterministic_function('LOG', 2, _sqlite_log) - create_deterministic_function('LPAD', 3, _sqlite_lpad) - create_deterministic_function('MD5', 1, _sqlite_md5) - create_deterministic_function('MOD', 2, _sqlite_mod) - create_deterministic_function('PI', 0, _sqlite_pi) - create_deterministic_function('POWER', 2, _sqlite_power) - create_deterministic_function('RADIANS', 1, _sqlite_radians) - create_deterministic_function('REPEAT', 2, _sqlite_repeat) - create_deterministic_function('REVERSE', 1, _sqlite_reverse) - create_deterministic_function('RPAD', 3, _sqlite_rpad) - create_deterministic_function('SHA1', 1, _sqlite_sha1) - create_deterministic_function('SHA224', 1, _sqlite_sha224) - create_deterministic_function('SHA256', 1, _sqlite_sha256) - create_deterministic_function('SHA384', 1, _sqlite_sha384) - create_deterministic_function('SHA512', 1, _sqlite_sha512) - create_deterministic_function('SIGN', 1, _sqlite_sign) - create_deterministic_function('SIN', 1, _sqlite_sin) - create_deterministic_function('SQRT', 1, _sqlite_sqrt) - create_deterministic_function('TAN', 1, _sqlite_tan) + create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract) + create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc) + create_deterministic_function( + "django_datetime_cast_date", 3, _sqlite_datetime_cast_date + ) + create_deterministic_function( + "django_datetime_cast_time", 3, _sqlite_datetime_cast_time + ) + create_deterministic_function( + "django_datetime_extract", 4, _sqlite_datetime_extract + ) + create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc) + create_deterministic_function("django_time_extract", 2, _sqlite_time_extract) + create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc) + create_deterministic_function("django_time_diff", 2, _sqlite_time_diff) + create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff) + create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta) + create_deterministic_function("regexp", 2, _sqlite_regexp) + create_deterministic_function("ACOS", 1, _sqlite_acos) + create_deterministic_function("ASIN", 1, _sqlite_asin) + create_deterministic_function("ATAN", 1, _sqlite_atan) + create_deterministic_function("ATAN2", 2, _sqlite_atan2) + create_deterministic_function("BITXOR", 2, _sqlite_bitxor) + create_deterministic_function("CEILING", 1, _sqlite_ceiling) + create_deterministic_function("COS", 1, _sqlite_cos) + create_deterministic_function("COT", 1, _sqlite_cot) + create_deterministic_function("DEGREES", 1, _sqlite_degrees) + create_deterministic_function("EXP", 1, _sqlite_exp) + create_deterministic_function("FLOOR", 1, _sqlite_floor) + create_deterministic_function("LN", 1, _sqlite_ln) + create_deterministic_function("LOG", 2, _sqlite_log) + create_deterministic_function("LPAD", 3, _sqlite_lpad) + create_deterministic_function("MD5", 1, _sqlite_md5) + create_deterministic_function("MOD", 2, _sqlite_mod) + create_deterministic_function("PI", 0, _sqlite_pi) + create_deterministic_function("POWER", 2, _sqlite_power) + create_deterministic_function("RADIANS", 1, _sqlite_radians) + create_deterministic_function("REPEAT", 2, _sqlite_repeat) + create_deterministic_function("REVERSE", 1, _sqlite_reverse) + create_deterministic_function("RPAD", 3, _sqlite_rpad) + create_deterministic_function("SHA1", 1, _sqlite_sha1) + create_deterministic_function("SHA224", 1, _sqlite_sha224) + create_deterministic_function("SHA256", 1, _sqlite_sha256) + create_deterministic_function("SHA384", 1, _sqlite_sha384) + create_deterministic_function("SHA512", 1, _sqlite_sha512) + create_deterministic_function("SIGN", 1, _sqlite_sign) + create_deterministic_function("SIN", 1, _sqlite_sin) + create_deterministic_function("SQRT", 1, _sqlite_sqrt) + create_deterministic_function("TAN", 1, _sqlite_tan) # Don't use the built-in RANDOM() function because it returns a value # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1). - connection.create_function('RAND', 0, random.random) - connection.create_aggregate('STDDEV_POP', 1, StdDevPop) - connection.create_aggregate('STDDEV_SAMP', 1, StdDevSamp) - connection.create_aggregate('VAR_POP', 1, VarPop) - connection.create_aggregate('VAR_SAMP', 1, VarSamp) + connection.create_function("RAND", 0, random.random) + connection.create_aggregate("STDDEV_POP", 1, StdDevPop) + connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp) + connection.create_aggregate("VAR_POP", 1, VarPop) + connection.create_aggregate("VAR_SAMP", 1, VarSamp) def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): @@ -90,9 +112,9 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): if tzname is not None and tzname != conn_tzname: tzname, sign, offset = split_tzname_delta(tzname) if offset: - hours, minutes = offset.split(':') + hours, minutes = offset.split(":") offset_delta = timedelta(hours=int(hours), minutes=int(minutes)) - dt += offset_delta if sign == '+' else -offset_delta + dt += offset_delta if sign == "+" else -offset_delta dt = timezone.localtime(dt, timezone_constructor(tzname)) return dt @@ -101,19 +123,19 @@ def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname): dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None - if lookup_type == 'year': - return f'{dt.year:04d}-01-01' - elif lookup_type == 'quarter': + if lookup_type == "year": + return f"{dt.year:04d}-01-01" + elif lookup_type == "quarter": month_in_quarter = dt.month - (dt.month - 1) % 3 - return f'{dt.year:04d}-{month_in_quarter:02d}-01' - elif lookup_type == 'month': - return f'{dt.year:04d}-{dt.month:02d}-01' - elif lookup_type == 'week': + return f"{dt.year:04d}-{month_in_quarter:02d}-01" + elif lookup_type == "month": + return f"{dt.year:04d}-{dt.month:02d}-01" + elif lookup_type == "week": dt = dt - timedelta(days=dt.weekday()) - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}' - elif lookup_type == 'day': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}' - raise ValueError(f'Unsupported lookup type: {lookup_type!r}') + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}" + elif lookup_type == "day": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}" + raise ValueError(f"Unsupported lookup type: {lookup_type!r}") def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): @@ -127,13 +149,13 @@ def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): return None else: dt = dt_parsed - if lookup_type == 'hour': - return f'{dt.hour:02d}:00:00' - elif lookup_type == 'minute': - return f'{dt.hour:02d}:{dt.minute:02d}:00' - elif lookup_type == 'second': - return f'{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}' - raise ValueError(f'Unsupported lookup type: {lookup_type!r}') + if lookup_type == "hour": + return f"{dt.hour:02d}:00:00" + elif lookup_type == "minute": + return f"{dt.hour:02d}:{dt.minute:02d}:00" + elif lookup_type == "second": + return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}" + raise ValueError(f"Unsupported lookup type: {lookup_type!r}") def _sqlite_datetime_cast_date(dt, tzname, conn_tzname): @@ -154,15 +176,15 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None): dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None - if lookup_type == 'week_day': + if lookup_type == "week_day": return (dt.isoweekday() % 7) + 1 - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": return dt.isoweekday() - elif lookup_type == 'week': + elif lookup_type == "week": return dt.isocalendar()[1] - elif lookup_type == 'quarter': + elif lookup_type == "quarter": return ceil(dt.month / 3) - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": return dt.isocalendar()[0] else: return getattr(dt, lookup_type) @@ -172,25 +194,25 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname): dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None - if lookup_type == 'year': - return f'{dt.year:04d}-01-01 00:00:00' - elif lookup_type == 'quarter': + if lookup_type == "year": + return f"{dt.year:04d}-01-01 00:00:00" + elif lookup_type == "quarter": month_in_quarter = dt.month - (dt.month - 1) % 3 - return f'{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00' - elif lookup_type == 'month': - return f'{dt.year:04d}-{dt.month:02d}-01 00:00:00' - elif lookup_type == 'week': + return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00" + elif lookup_type == "month": + return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00" + elif lookup_type == "week": dt = dt - timedelta(days=dt.weekday()) - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00' - elif lookup_type == 'day': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00' - elif lookup_type == 'hour': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00' - elif lookup_type == 'minute': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00' - elif lookup_type == 'second': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}' - raise ValueError(f'Unsupported lookup type: {lookup_type!r}') + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00" + elif lookup_type == "day": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00" + elif lookup_type == "hour": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00" + elif lookup_type == "minute": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00" + elif lookup_type == "second": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}" + raise ValueError(f"Unsupported lookup type: {lookup_type!r}") def _sqlite_time_extract(lookup_type, dt): @@ -204,7 +226,7 @@ def _sqlite_time_extract(lookup_type, dt): def _sqlite_prepare_dtdelta_param(conn, param): - if conn in ['+', '-']: + if conn in ["+", "-"]: if isinstance(param, int): return timedelta(0, 0, param) else: @@ -227,13 +249,13 @@ def _sqlite_format_dtdelta(connector, lhs, rhs): real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs) except (ValueError, TypeError): return None - if connector == '+': + if connector == "+": # typecast_timestamp() returns a date or a datetime without timezone. # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" out = str(real_lhs + real_rhs) - elif connector == '-': + elif connector == "-": out = str(real_lhs - real_rhs) - elif connector == '*': + elif connector == "*": out = real_lhs * real_rhs else: out = real_lhs / real_rhs @@ -246,14 +268,14 @@ def _sqlite_time_diff(lhs, rhs): left = typecast_time(lhs) right = typecast_time(rhs) return ( - (left.hour * 60 * 60 * 1000000) + - (left.minute * 60 * 1000000) + - (left.second * 1000000) + - (left.microsecond) - - (right.hour * 60 * 60 * 1000000) - - (right.minute * 60 * 1000000) - - (right.second * 1000000) - - (right.microsecond) + (left.hour * 60 * 60 * 1000000) + + (left.minute * 60 * 1000000) + + (left.second * 1000000) + + (left.microsecond) + - (right.hour * 60 * 60 * 1000000) + - (right.minute * 60 * 1000000) + - (right.second * 1000000) + - (right.microsecond) ) @@ -380,7 +402,7 @@ def _sqlite_pi(): def _sqlite_power(x, y): if x is None or y is None: return None - return x ** y + return x**y def _sqlite_radians(x): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 4343ea180e..5bcd61eb96 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -32,13 +32,13 @@ def decoder(conv_func): def check_sqlite_version(): if Database.sqlite_version_info < (3, 9, 0): raise ImproperlyConfigured( - 'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version + "SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version ) check_sqlite_version() -Database.register_converter("bool", b'1'.__eq__) +Database.register_converter("bool", b"1".__eq__) Database.register_converter("time", decoder(parse_time)) Database.register_converter("datetime", decoder(parse_datetime)) Database.register_converter("timestamp", decoder(parse_datetime)) @@ -47,69 +47,69 @@ Database.register_adapter(decimal.Decimal, str) class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'sqlite' - display_name = 'SQLite' + vendor = "sqlite" + display_name = "SQLite" # SQLite doesn't actually support most of these types, but it "does the right # thing" given more verbose field definitions, so leave them as is so that # schema inspection is more useful. data_types = { - 'AutoField': 'integer', - 'BigAutoField': 'integer', - 'BinaryField': 'BLOB', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime', - 'DecimalField': 'decimal', - 'DurationField': 'bigint', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'real', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'char(15)', - 'GenericIPAddressField': 'char(39)', - 'JSONField': 'text', - 'OneToOneField': 'integer', - 'PositiveBigIntegerField': 'bigint unsigned', - 'PositiveIntegerField': 'integer unsigned', - 'PositiveSmallIntegerField': 'smallint unsigned', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallAutoField': 'integer', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', - 'UUIDField': 'char(32)', + "AutoField": "integer", + "BigAutoField": "integer", + "BinaryField": "BLOB", + "BooleanField": "bool", + "CharField": "varchar(%(max_length)s)", + "DateField": "date", + "DateTimeField": "datetime", + "DecimalField": "decimal", + "DurationField": "bigint", + "FileField": "varchar(%(max_length)s)", + "FilePathField": "varchar(%(max_length)s)", + "FloatField": "real", + "IntegerField": "integer", + "BigIntegerField": "bigint", + "IPAddressField": "char(15)", + "GenericIPAddressField": "char(39)", + "JSONField": "text", + "OneToOneField": "integer", + "PositiveBigIntegerField": "bigint unsigned", + "PositiveIntegerField": "integer unsigned", + "PositiveSmallIntegerField": "smallint unsigned", + "SlugField": "varchar(%(max_length)s)", + "SmallAutoField": "integer", + "SmallIntegerField": "smallint", + "TextField": "text", + "TimeField": "time", + "UUIDField": "char(32)", } data_type_check_constraints = { - 'PositiveBigIntegerField': '"%(column)s" >= 0', - 'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', - 'PositiveIntegerField': '"%(column)s" >= 0', - 'PositiveSmallIntegerField': '"%(column)s" >= 0', + "PositiveBigIntegerField": '"%(column)s" >= 0', + "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', + "PositiveIntegerField": '"%(column)s" >= 0', + "PositiveSmallIntegerField": '"%(column)s" >= 0', } data_types_suffix = { - 'AutoField': 'AUTOINCREMENT', - 'BigAutoField': 'AUTOINCREMENT', - 'SmallAutoField': 'AUTOINCREMENT', + "AutoField": "AUTOINCREMENT", + "BigAutoField": "AUTOINCREMENT", + "SmallAutoField": "AUTOINCREMENT", } # SQLite requires LIKE statements to include an ESCAPE clause if the value # being escaped has a percent or underscore in it. # See https://www.sqlite.org/lang_expr.html for an explanation. operators = { - 'exact': '= %s', - 'iexact': "LIKE %s ESCAPE '\\'", - 'contains': "LIKE %s ESCAPE '\\'", - 'icontains': "LIKE %s ESCAPE '\\'", - 'regex': 'REGEXP %s', - 'iregex': "REGEXP '(?i)' || %s", - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', - 'startswith': "LIKE %s ESCAPE '\\'", - 'endswith': "LIKE %s ESCAPE '\\'", - 'istartswith': "LIKE %s ESCAPE '\\'", - 'iendswith': "LIKE %s ESCAPE '\\'", + "exact": "= %s", + "iexact": "LIKE %s ESCAPE '\\'", + "contains": "LIKE %s ESCAPE '\\'", + "icontains": "LIKE %s ESCAPE '\\'", + "regex": "REGEXP %s", + "iregex": "REGEXP '(?i)' || %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE %s ESCAPE '\\'", + "endswith": "LIKE %s ESCAPE '\\'", + "istartswith": "LIKE %s ESCAPE '\\'", + "iendswith": "LIKE %s ESCAPE '\\'", } # The patterns below are used to generate SQL pattern lookup clauses when @@ -122,12 +122,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): # the LIKE operator. pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')" pattern_ops = { - 'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'", - 'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'", - 'startswith': r"LIKE {} || '%%' ESCAPE '\'", - 'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'", - 'endswith': r"LIKE '%%' || {} ESCAPE '\'", - 'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'", + "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'", + "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'", + "startswith": r"LIKE {} || '%%' ESCAPE '\'", + "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'", + "endswith": r"LIKE '%%' || {} ESCAPE '\'", + "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'", } Database = Database @@ -141,14 +141,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): def get_connection_params(self): settings_dict = self.settings_dict - if not settings_dict['NAME']: + if not settings_dict["NAME"]: raise ImproperlyConfigured( "settings.DATABASES is improperly configured. " - "Please supply the NAME value.") + "Please supply the NAME value." + ) kwargs = { - 'database': settings_dict['NAME'], - 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, - **settings_dict['OPTIONS'], + "database": settings_dict["NAME"], + "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, + **settings_dict["OPTIONS"], } # Always allow the underlying SQLite connection to be shareable # between multiple threads. The safe-guarding will be handled at a @@ -156,15 +157,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): # property. This is necessary as the shareability is disabled by # default in pysqlite and it cannot be changed once a connection is # opened. - if 'check_same_thread' in kwargs and kwargs['check_same_thread']: + if "check_same_thread" in kwargs and kwargs["check_same_thread"]: warnings.warn( - 'The `check_same_thread` option was provided and set to ' - 'True. It will be overridden with False. Use the ' - '`DatabaseWrapper.allow_thread_sharing` property instead ' - 'for controlling thread shareability.', - RuntimeWarning + "The `check_same_thread` option was provided and set to " + "True. It will be overridden with False. Use the " + "`DatabaseWrapper.allow_thread_sharing` property instead " + "for controlling thread shareability.", + RuntimeWarning, ) - kwargs.update({'check_same_thread': False, 'uri': True}) + kwargs.update({"check_same_thread": False, "uri": True}) return kwargs @async_unsafe @@ -172,10 +173,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): conn = Database.connect(**conn_params) register_functions(conn) - conn.execute('PRAGMA foreign_keys = ON') + conn.execute("PRAGMA foreign_keys = ON") # The macOS bundled SQLite defaults legacy_alter_table ON, which # prevents atomic table renames (feature supports_atomic_references_rename) - conn.execute('PRAGMA legacy_alter_table = OFF') + conn.execute("PRAGMA legacy_alter_table = OFF") return conn def init_connection_state(self): @@ -207,7 +208,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): else: # sqlite3's internal default is ''. It's different from None. # See Modules/_sqlite/connection.c. - level = '' + level = "" # 'isolation_level' is a misleading API. # SQLite always runs at the SERIALIZABLE isolation level. with self.wrap_database_errors: @@ -215,16 +216,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): def disable_constraint_checking(self): with self.cursor() as cursor: - cursor.execute('PRAGMA foreign_keys = OFF') + cursor.execute("PRAGMA foreign_keys = OFF") # Foreign key constraints cannot be turned off while in a multi- # statement transaction. Fetch the current state of the pragma # to determine if constraints are effectively disabled. - enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0] + enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0] return not bool(enabled) def enable_constraint_checking(self): with self.cursor() as cursor: - cursor.execute('PRAGMA foreign_keys = ON') + cursor.execute("PRAGMA foreign_keys = ON") def check_constraints(self, table_names=None): """ @@ -237,24 +238,32 @@ class DatabaseWrapper(BaseDatabaseWrapper): if self.features.supports_pragma_foreign_key_check: with self.cursor() as cursor: if table_names is None: - violations = cursor.execute('PRAGMA foreign_key_check').fetchall() + violations = cursor.execute("PRAGMA foreign_key_check").fetchall() else: violations = chain.from_iterable( cursor.execute( - 'PRAGMA foreign_key_check(%s)' + "PRAGMA foreign_key_check(%s)" % self.ops.quote_name(table_name) ).fetchall() for table_name in table_names ) # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check - for table_name, rowid, referenced_table_name, foreign_key_index in violations: + for ( + table_name, + rowid, + referenced_table_name, + foreign_key_index, + ) in violations: foreign_key = cursor.execute( - 'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name) + "PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name) ).fetchall()[foreign_key_index] column_name, referenced_column_name = foreign_key[3:5] - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + primary_key_column_name = self.introspection.get_primary_key_column( + cursor, table_name + ) primary_key_value, bad_value = cursor.execute( - 'SELECT %s, %s FROM %s WHERE rowid = %%s' % ( + "SELECT %s, %s FROM %s WHERE rowid = %%s" + % ( self.ops.quote_name(primary_key_column_name), self.ops.quote_name(column_name), self.ops.quote_name(table_name), @@ -264,9 +273,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): raise IntegrityError( "The row in table '%s' with primary key '%s' has an " "invalid foreign key: %s.%s contains a value '%s' that " - "does not have a corresponding value in %s.%s." % ( - table_name, primary_key_value, table_name, column_name, - bad_value, referenced_table_name, referenced_column_name + "does not have a corresponding value in %s.%s." + % ( + table_name, + primary_key_value, + table_name, + column_name, + bad_value, + referenced_table_name, + referenced_column_name, ) ) else: @@ -274,11 +289,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): if table_names is None: table_names = self.introspection.table_names(cursor) for table_name in table_names: - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + primary_key_column_name = self.introspection.get_primary_key_column( + cursor, table_name + ) if not primary_key_column_name: continue relations = self.introspection.get_relations(cursor, table_name) - for column_name, (referenced_column_name, referenced_table_name) in relations: + for column_name, ( + referenced_column_name, + referenced_table_name, + ) in relations: cursor.execute( """ SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING @@ -287,18 +307,29 @@ class DatabaseWrapper(BaseDatabaseWrapper): WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL """ % ( - primary_key_column_name, column_name, table_name, - referenced_table_name, column_name, referenced_column_name, - column_name, referenced_column_name, + primary_key_column_name, + column_name, + table_name, + referenced_table_name, + column_name, + referenced_column_name, + column_name, + referenced_column_name, ) ) for bad_row in cursor.fetchall(): raise IntegrityError( "The row in table '%s' with primary key '%s' has an " "invalid foreign key: %s.%s contains a value '%s' that " - "does not have a corresponding value in %s.%s." % ( - table_name, bad_row[0], table_name, column_name, - bad_row[1], referenced_table_name, referenced_column_name, + "does not have a corresponding value in %s.%s." + % ( + table_name, + bad_row[0], + table_name, + column_name, + bad_row[1], + referenced_table_name, + referenced_column_name, ) ) @@ -315,10 +346,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.cursor().execute("BEGIN") def is_in_memory_db(self): - return self.creation.is_in_memory_db(self.settings_dict['NAME']) + return self.creation.is_in_memory_db(self.settings_dict["NAME"]) -FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s') +FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s") class SQLiteCursorWrapper(Database.Cursor): @@ -327,6 +358,7 @@ class SQLiteCursorWrapper(Database.Cursor): This fixes it -- but note that if you want to use a literal "%s" in a query, you'll need to use "%%s". """ + def execute(self, query, params=None): if params is None: return Database.Cursor.execute(self, query) @@ -338,4 +370,4 @@ class SQLiteCursorWrapper(Database.Cursor): return Database.Cursor.executemany(self, query, param_list) def convert_query(self, query): - return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%') + return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%") diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py index 69b9568db3..7cee35dc81 100644 --- a/django/db/backends/sqlite3/client.py +++ b/django/db/backends/sqlite3/client.py @@ -2,9 +2,9 @@ from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): - executable_name = 'sqlite3' + executable_name = "sqlite3" @classmethod def settings_to_cmd_args_env(cls, settings_dict, parameters): - args = [cls.executable_name, settings_dict['NAME'], *parameters] + args = [cls.executable_name, settings_dict["NAME"], *parameters] return args, None diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py index 4a4046c670..9d8d4a63ad 100644 --- a/django/db/backends/sqlite3/creation.py +++ b/django/db/backends/sqlite3/creation.py @@ -7,17 +7,16 @@ from django.db.backends.base.creation import BaseDatabaseCreation class DatabaseCreation(BaseDatabaseCreation): - @staticmethod def is_in_memory_db(database_name): return not isinstance(database_name, Path) and ( - database_name == ':memory:' or 'mode=memory' in database_name + database_name == ":memory:" or "mode=memory" in database_name ) def _get_test_db_name(self): - test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:' - if test_database_name == ':memory:': - return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias + test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:" + if test_database_name == ":memory:": + return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias return test_database_name def _create_test_db(self, verbosity, autoclobber, keepdb=False): @@ -28,38 +27,39 @@ class DatabaseCreation(BaseDatabaseCreation): if not self.is_in_memory_db(test_database_name): # Erase the old test database if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, test_database_name), - )) + self.log( + "Destroying old test database for alias %s..." + % (self._get_database_display_str(verbosity, test_database_name),) + ) if os.access(test_database_name, os.F_OK): if not autoclobber: confirm = input( "Type 'yes' if you would like to try deleting the test " "database '%s', or 'no' to cancel: " % test_database_name ) - if autoclobber or confirm == 'yes': + if autoclobber or confirm == "yes": try: os.remove(test_database_name) except Exception as e: - self.log('Got an error deleting the old test database: %s' % e) + self.log("Got an error deleting the old test database: %s" % e) sys.exit(2) else: - self.log('Tests cancelled.') + self.log("Tests cancelled.") sys.exit(1) return test_database_name def get_test_db_clone_settings(self, suffix): orig_settings_dict = self.connection.settings_dict - source_database_name = orig_settings_dict['NAME'] + source_database_name = orig_settings_dict["NAME"] if self.is_in_memory_db(source_database_name): return orig_settings_dict else: - root, ext = os.path.splitext(orig_settings_dict['NAME']) - return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)} + root, ext = os.path.splitext(orig_settings_dict["NAME"]) + return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)} def _clone_test_db(self, suffix, verbosity, keepdb=False): - source_database_name = self.connection.settings_dict['NAME'] - target_database_name = self.get_test_db_clone_settings(suffix)['NAME'] + source_database_name = self.connection.settings_dict["NAME"] + target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] # Forking automatically makes a copy of an in-memory database. if not self.is_in_memory_db(source_database_name): # Erase the old test database @@ -67,18 +67,23 @@ class DatabaseCreation(BaseDatabaseCreation): if keepdb: return if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, target_database_name), - )) + self.log( + "Destroying old test database for alias %s..." + % ( + self._get_database_display_str( + verbosity, target_database_name + ), + ) + ) try: os.remove(target_database_name) except Exception as e: - self.log('Got an error deleting the old test database: %s' % e) + self.log("Got an error deleting the old test database: %s" % e) sys.exit(2) try: shutil.copy(source_database_name, target_database_name) except Exception as e: - self.log('Got an error cloning the test database: %s' % e) + self.log("Got an error cloning the test database: %s" % e) sys.exit(2) def _destroy_test_db(self, test_database_name, verbosity): @@ -95,7 +100,7 @@ class DatabaseCreation(BaseDatabaseCreation): TEST NAME. See https://www.sqlite.org/inmemorydb.html """ test_database_name = self._get_test_db_name() - sig = [self.connection.settings_dict['NAME']] + sig = [self.connection.settings_dict["NAME"]] if self.is_in_memory_db(test_database_name): sig.append(self.connection.alias) else: diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 153ce8d1d1..c076f0121e 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -43,49 +43,53 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0) supports_update_conflicts_with_target = supports_update_conflicts test_collations = { - 'ci': 'nocase', - 'cs': 'binary', - 'non_default': 'nocase', + "ci": "nocase", + "cs": "binary", + "non_default": "nocase", } django_test_expected_failures = { # The django_format_dtdelta() function doesn't properly handle mixed # Date/DateTime fields and timedeltas. - 'expressions.tests.FTimeDeltaTests.test_mixed_comparisons1', + "expressions.tests.FTimeDeltaTests.test_mixed_comparisons1", } @cached_property def django_test_skips(self): skips = { - 'SQLite stores values rounded to 15 significant digits.': { - 'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding', + "SQLite stores values rounded to 15 significant digits.": { + "model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding", }, - 'SQLite naively remakes the table on field alteration.': { - 'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops', - 'schema.tests.SchemaTests.test_unique_and_reverse_m2m', - 'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries', - 'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references', + "SQLite naively remakes the table on field alteration.": { + "schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops", + "schema.tests.SchemaTests.test_unique_and_reverse_m2m", + "schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries", + "schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references", }, "SQLite doesn't support negative precision for ROUND().": { - 'db_functions.math.test_round.RoundTests.test_null_with_negative_precision', - 'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision', - 'db_functions.math.test_round.RoundTests.test_float_with_negative_precision', - 'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision', + "db_functions.math.test_round.RoundTests.test_null_with_negative_precision", + "db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision", + "db_functions.math.test_round.RoundTests.test_float_with_negative_precision", + "db_functions.math.test_round.RoundTests.test_integer_with_negative_precision", }, } if Database.sqlite_version_info < (3, 27): - skips.update({ - 'Nondeterministic failure on SQLite < 3.27.': { - 'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank', - }, - }) + skips.update( + { + "Nondeterministic failure on SQLite < 3.27.": { + "expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank", + }, + } + ) if self.connection.is_in_memory_db(): - skips.update({ - "the sqlite backend's close() method is a no-op when using an " - "in-memory database": { - 'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections', - 'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections', - }, - }) + skips.update( + { + "the sqlite backend's close() method is a no-op when using an " + "in-memory database": { + "servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections", + "servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections", + }, + } + ) return skips @cached_property @@ -94,12 +98,12 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def introspected_field_types(self): - return{ + return { **super().introspected_field_types, - 'BigAutoField': 'AutoField', - 'DurationField': 'BigIntegerField', - 'GenericIPAddressField': 'CharField', - 'SmallAutoField': 'AutoField', + "BigAutoField": "AutoField", + "DurationField": "BigIntegerField", + "GenericIPAddressField": "CharField", + "SmallAutoField": "AutoField", } @cached_property @@ -112,11 +116,13 @@ class DatabaseFeatures(BaseDatabaseFeatures): return False return True - can_introspect_json_field = property(operator.attrgetter('supports_json_field')) - has_json_object_function = property(operator.attrgetter('supports_json_field')) + can_introspect_json_field = property(operator.attrgetter("supports_json_field")) + has_json_object_function = property(operator.attrgetter("supports_json_field")) @cached_property def can_return_columns_from_insert(self): return Database.sqlite_version_info >= (3, 35) - can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert')) + can_return_rows_from_bulk_insert = property( + operator.attrgetter("can_return_columns_from_insert") + ) diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 81884a7951..f5a5e81e9d 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -3,19 +3,21 @@ from collections import namedtuple import sqlparse from django.db import DatabaseError -from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, -) +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo from django.db.models import Index from django.utils.regex_helper import _lazy_re_compile -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint')) +FieldInfo = namedtuple( + "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint") +) -field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') +field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$") def get_field_size(name): - """ Extract the size number from a "varchar(11)" type name """ + """Extract the size number from a "varchar(11)" type name""" m = field_size_re.search(name) return int(m[1]) if m else None @@ -28,29 +30,29 @@ class FlexibleFieldLookupDict: # entries here because SQLite allows for anything and doesn't normalize the # field type; it uses whatever was given. base_data_types_reverse = { - 'bool': 'BooleanField', - 'boolean': 'BooleanField', - 'smallint': 'SmallIntegerField', - 'smallint unsigned': 'PositiveSmallIntegerField', - 'smallinteger': 'SmallIntegerField', - 'int': 'IntegerField', - 'integer': 'IntegerField', - 'bigint': 'BigIntegerField', - 'integer unsigned': 'PositiveIntegerField', - 'bigint unsigned': 'PositiveBigIntegerField', - 'decimal': 'DecimalField', - 'real': 'FloatField', - 'text': 'TextField', - 'char': 'CharField', - 'varchar': 'CharField', - 'blob': 'BinaryField', - 'date': 'DateField', - 'datetime': 'DateTimeField', - 'time': 'TimeField', + "bool": "BooleanField", + "boolean": "BooleanField", + "smallint": "SmallIntegerField", + "smallint unsigned": "PositiveSmallIntegerField", + "smallinteger": "SmallIntegerField", + "int": "IntegerField", + "integer": "IntegerField", + "bigint": "BigIntegerField", + "integer unsigned": "PositiveIntegerField", + "bigint unsigned": "PositiveBigIntegerField", + "decimal": "DecimalField", + "real": "FloatField", + "text": "TextField", + "char": "CharField", + "varchar": "CharField", + "blob": "BinaryField", + "date": "DateField", + "datetime": "DateTimeField", + "time": "TimeField", } def __getitem__(self, key): - key = key.lower().split('(', 1)[0].strip() + key = key.lower().split("(", 1)[0].strip() return self.base_data_types_reverse[key] @@ -59,22 +61,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_field_type(self, data_type, description): field_type = super().get_field_type(data_type, description) - if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}: + if description.pk and field_type in { + "BigIntegerField", + "IntegerField", + "SmallIntegerField", + }: # No support for BigAutoField or SmallAutoField as SQLite treats # all integer primary keys as signed 64-bit integers. - return 'AutoField' + return "AutoField" if description.has_json_constraint: - return 'JSONField' + return "JSONField" return field_type def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" # Skip the sqlite_sequence system table used for autoincrement key # generation. - cursor.execute(""" + cursor.execute( + """ SELECT name, type FROM sqlite_master WHERE type in ('table', 'view') AND NOT name='sqlite_sequence' - ORDER BY name""") + ORDER BY name""" + ) return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()] def get_table_description(self, cursor, table_name): @@ -82,37 +90,51 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a description of the table with the DB-API cursor.description interface. """ - cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)) + cursor.execute( + "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) + ) table_info = cursor.fetchall() if not table_info: - raise DatabaseError(f'Table {table_name} does not exist (empty pragma).') + raise DatabaseError(f"Table {table_name} does not exist (empty pragma).") collations = self._get_column_collations(cursor, table_name) json_columns = set() if self.connection.features.can_introspect_json_field: for line in table_info: column = line[1] json_constraint_sql = '%%json_valid("%s")%%' % column - has_json_constraint = cursor.execute(""" + has_json_constraint = cursor.execute( + """ SELECT sql FROM sqlite_master WHERE type = 'table' AND name = %s AND sql LIKE %s - """, [table_name, json_constraint_sql]).fetchone() + """, + [table_name, json_constraint_sql], + ).fetchone() if has_json_constraint: json_columns.add(column) return [ FieldInfo( - name, data_type, None, get_field_size(data_type), None, None, - not notnull, default, collations.get(name), pk == 1, name in json_columns + name, + data_type, + None, + get_field_size(data_type), + None, + None, + not notnull, + default, + collations.get(name), + pk == 1, + name in json_columns, ) for cid, name, data_type, notnull, default, pk in table_info ] def get_sequences(self, cursor, table_name, table_fields=()): pk_col = self.get_primary_key_column(cursor, table_name) - return [{'table': table_name, 'column': pk_col}] + return [{"table": table_name, "column": pk_col}] def get_relations(self, cursor, table_name): """ @@ -120,7 +142,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): representing all foreign keys in the given table. """ cursor.execute( - 'PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name) + "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name) ) return { column_name: (ref_column_name, ref_table_name) @@ -130,7 +152,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_primary_key_column(self, cursor, table_name): """Return the column name of the primary key for the given table.""" cursor.execute( - 'PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name) + "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) ) for _, name, *_, pk in cursor.fetchall(): if pk: @@ -148,19 +170,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): check_columns = [] braces_deep = 0 for token in tokens: - if token.match(sqlparse.tokens.Punctuation, '('): + if token.match(sqlparse.tokens.Punctuation, "("): braces_deep += 1 - elif token.match(sqlparse.tokens.Punctuation, ')'): + elif token.match(sqlparse.tokens.Punctuation, ")"): braces_deep -= 1 if braces_deep < 0: # End of columns and constraints for table definition. break - elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','): + elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","): # End of current column or constraint definition. break # Detect column or constraint definition by first token. if is_constraint_definition is None: - is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT') + is_constraint_definition = token.match( + sqlparse.tokens.Keyword, "CONSTRAINT" + ) if is_constraint_definition: continue if is_constraint_definition: @@ -171,7 +195,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): elif token.ttype == sqlparse.tokens.Literal.String.Symbol: constraint_name = token.value[1:-1] # Start constraint columns parsing after UNIQUE keyword. - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + if token.match(sqlparse.tokens.Keyword, "UNIQUE"): unique = True unique_braces_deep = braces_deep elif unique: @@ -191,10 +215,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): field_name = token.value elif token.ttype == sqlparse.tokens.Literal.String.Symbol: field_name = token.value[1:-1] - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + if token.match(sqlparse.tokens.Keyword, "UNIQUE"): unique_columns = [field_name] # Start constraint columns parsing after CHECK keyword. - if token.match(sqlparse.tokens.Keyword, 'CHECK'): + if token.match(sqlparse.tokens.Keyword, "CHECK"): check = True check_braces_deep = braces_deep elif check: @@ -209,22 +233,30 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): elif token.ttype == sqlparse.tokens.Literal.String.Symbol: if token.value[1:-1] in columns: check_columns.append(token.value[1:-1]) - unique_constraint = { - 'unique': True, - 'columns': unique_columns, - 'primary_key': False, - 'foreign_key': None, - 'check': False, - 'index': False, - } if unique_columns else None - check_constraint = { - 'check': True, - 'columns': check_columns, - 'primary_key': False, - 'unique': False, - 'foreign_key': None, - 'index': False, - } if check_columns else None + unique_constraint = ( + { + "unique": True, + "columns": unique_columns, + "primary_key": False, + "foreign_key": None, + "check": False, + "index": False, + } + if unique_columns + else None + ) + check_constraint = ( + { + "check": True, + "columns": check_columns, + "primary_key": False, + "unique": False, + "foreign_key": None, + "index": False, + } + if check_columns + else None + ) return constraint_name, unique_constraint, check_constraint, token def _parse_table_constraints(self, sql, columns): @@ -236,24 +268,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): tokens = (token for token in statement.flatten() if not token.is_whitespace) # Go to columns and constraint definition for token in tokens: - if token.match(sqlparse.tokens.Punctuation, '('): + if token.match(sqlparse.tokens.Punctuation, "("): break # Parse columns and constraint definition while True: - constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns) + ( + constraint_name, + unique, + check, + end_token, + ) = self._parse_column_or_constraint_definition(tokens, columns) if unique: if constraint_name: constraints[constraint_name] = unique else: unnamed_constrains_index += 1 - constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique + constraints[ + "__unnamed_constraint_%s__" % unnamed_constrains_index + ] = unique if check: if constraint_name: constraints[constraint_name] = check else: unnamed_constrains_index += 1 - constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check - if end_token.match(sqlparse.tokens.Punctuation, ')'): + constraints[ + "__unnamed_constraint_%s__" % unnamed_constrains_index + ] = check + if end_token.match(sqlparse.tokens.Punctuation, ")"): break return constraints @@ -266,19 +307,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Find inline check constraints. try: table_schema = cursor.execute( - "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % ( - self.connection.ops.quote_name(table_name), - ) + "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" + % (self.connection.ops.quote_name(table_name),) ).fetchone()[0] except TypeError: # table_name is a view. pass else: - columns = {info.name for info in self.get_table_description(cursor, table_name)} + columns = { + info.name for info in self.get_table_description(cursor, table_name) + } constraints.update(self._parse_table_constraints(table_schema, columns)) # Get the index info - cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name) + ) for row in cursor.fetchall(): # SQLite 3.8.9+ has 5 columns, however older versions only give 3 # columns. Discard last 2 columns if there. @@ -288,7 +332,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index) ) # There's at most one row. - sql, = cursor.fetchone() or (None,) + (sql,) = cursor.fetchone() or (None,) # Inline constraints are already detected in # _parse_table_constraints(). The reasons to avoid fetching inline # constraints from `PRAGMA index_list` are: @@ -299,7 +343,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # An inline constraint continue # Get the index info for that index - cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) + cursor.execute( + "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index) + ) for index_rank, column_rank, column in cursor.fetchall(): if index not in constraints: constraints[index] = { @@ -310,14 +356,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "check": False, "index": True, } - constraints[index]['columns'].append(column) + constraints[index]["columns"].append(column) # Add type and column orders for indexes - if constraints[index]['index']: + if constraints[index]["index"]: # SQLite doesn't support any index type other than b-tree - constraints[index]['type'] = Index.suffix + constraints[index]["type"] = Index.suffix orders = self._get_index_columns_orders(sql) if orders is not None: - constraints[index]['orders'] = orders + constraints[index]["orders"] = orders # Get the PK pk_column = self.get_primary_key_column(cursor, table_name) if pk_column: @@ -334,44 +380,49 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "index": False, } relations = enumerate(self.get_relations(cursor, table_name).items()) - constraints.update({ - f'fk_{index}': { - 'columns': [column_name], - 'primary_key': False, - 'unique': False, - 'foreign_key': (ref_table_name, ref_column_name), - 'check': False, - 'index': False, + constraints.update( + { + f"fk_{index}": { + "columns": [column_name], + "primary_key": False, + "unique": False, + "foreign_key": (ref_table_name, ref_column_name), + "check": False, + "index": False, + } + for index, (column_name, (ref_column_name, ref_table_name)) in relations } - for index, (column_name, (ref_column_name, ref_table_name)) in relations - }) + ) return constraints def _get_index_columns_orders(self, sql): tokens = sqlparse.parse(sql)[0] for token in tokens: if isinstance(token, sqlparse.sql.Parenthesis): - columns = str(token).strip('()').split(', ') - return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns] + columns = str(token).strip("()").split(", ") + return ["DESC" if info.endswith("DESC") else "ASC" for info in columns] return None def _get_column_collations(self, cursor, table_name): - row = cursor.execute(""" + row = cursor.execute( + """ SELECT sql FROM sqlite_master WHERE type = 'table' AND name = %s - """, [table_name]).fetchone() + """, + [table_name], + ).fetchone() if not row: return {} sql = row[0] - columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ') + columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ") collations = {} for column in columns: tokens = column[1:].split() column_name = tokens[0].strip('"') for index, token in enumerate(tokens): - if token == 'COLLATE': + if token == "COLLATE": collation = tokens[index + 1] break else: diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index c1a6da4e5d..ef8b91c0f0 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -16,15 +16,15 @@ from django.utils.functional import cached_property class DatabaseOperations(BaseDatabaseOperations): - cast_char_field_without_max_length = 'text' + cast_char_field_without_max_length = "text" cast_data_types = { - 'DateField': 'TEXT', - 'DateTimeField': 'TEXT', + "DateField": "TEXT", + "DateTimeField": "TEXT", } - explain_prefix = 'EXPLAIN QUERY PLAN' + explain_prefix = "EXPLAIN QUERY PLAN" # List of datatypes to that cannot be extracted with JSON_EXTRACT() on # SQLite. Use JSON_TYPE() instead. - jsonfield_datatype_values = frozenset(['null', 'false', 'true']) + jsonfield_datatype_values = frozenset(["null", "false", "true"]) def bulk_batch_size(self, fields, objs): """ @@ -55,14 +55,14 @@ class DatabaseOperations(BaseDatabaseOperations): else: if isinstance(output_field, bad_fields): raise NotSupportedError( - 'You cannot use Sum, Avg, StdDev, and Variance ' - 'aggregations on date/time fields in sqlite3 ' - 'since date/time is saved as text.' + "You cannot use Sum, Avg, StdDev, and Variance " + "aggregations on date/time fields in sqlite3 " + "since date/time is saved as text." ) if ( - isinstance(expression, models.Aggregate) and - expression.distinct and - len(expression.source_expressions) > 1 + isinstance(expression, models.Aggregate) + and expression.distinct + and len(expression.source_expressions) > 1 ): raise NotSupportedError( "SQLite doesn't support DISTINCT on aggregate functions " @@ -105,26 +105,32 @@ class DatabaseOperations(BaseDatabaseOperations): def _convert_tznames_to_sql(self, tzname): if tzname and settings.USE_TZ: return "'%s'" % tzname, "'%s'" % self.connection.timezone_name - return 'NULL', 'NULL' + return "NULL", "NULL" def datetime_cast_date_sql(self, field_name, tzname): - return 'django_datetime_cast_date(%s, %s, %s)' % ( - field_name, *self._convert_tznames_to_sql(tzname), + return "django_datetime_cast_date(%s, %s, %s)" % ( + field_name, + *self._convert_tznames_to_sql(tzname), ) def datetime_cast_time_sql(self, field_name, tzname): - return 'django_datetime_cast_time(%s, %s, %s)' % ( - field_name, *self._convert_tznames_to_sql(tzname), + return "django_datetime_cast_time(%s, %s, %s)" % ( + field_name, + *self._convert_tznames_to_sql(tzname), ) def datetime_extract_sql(self, lookup_type, field_name, tzname): return "django_datetime_extract('%s', %s, %s, %s)" % ( - lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname), + lookup_type.lower(), + field_name, + *self._convert_tznames_to_sql(tzname), ) def datetime_trunc_sql(self, lookup_type, field_name, tzname): return "django_datetime_trunc('%s', %s, %s, %s)" % ( - lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname), + lookup_type.lower(), + field_name, + *self._convert_tznames_to_sql(tzname), ) def time_extract_sql(self, lookup_type, field_name): @@ -146,11 +152,11 @@ class DatabaseOperations(BaseDatabaseOperations): if len(params) > BATCH_SIZE: results = () for index in range(0, len(params), BATCH_SIZE): - chunk = params[index:index + BATCH_SIZE] + chunk = params[index : index + BATCH_SIZE] results += self._quote_params_for_last_executed_query(chunk) return results - sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params)) + sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params)) # Bypass Django's wrappers and use the underlying sqlite3 connection # to avoid logging this query - it would trigger infinite recursion. cursor = self.connection.connection.cursor() @@ -215,14 +221,20 @@ class DatabaseOperations(BaseDatabaseOperations): if tables and allow_cascade: # Simulate TRUNCATE CASCADE by recursively collecting the tables # referencing the tables to be flushed. - tables = set(chain.from_iterable(self._references_graph(table) for table in tables)) - sql = ['%s %s %s;' % ( - style.SQL_KEYWORD('DELETE'), - style.SQL_KEYWORD('FROM'), - style.SQL_FIELD(self.quote_name(table)) - ) for table in tables] + tables = set( + chain.from_iterable(self._references_graph(table) for table in tables) + ) + sql = [ + "%s %s %s;" + % ( + style.SQL_KEYWORD("DELETE"), + style.SQL_KEYWORD("FROM"), + style.SQL_FIELD(self.quote_name(table)), + ) + for table in tables + ] if reset_sequences: - sequences = [{'table': table} for table in tables] + sequences = [{"table": table} for table in tables] sql.extend(self.sequence_reset_by_name_sql(style, sequences)) return sql @@ -230,17 +242,18 @@ class DatabaseOperations(BaseDatabaseOperations): if not sequences: return [] return [ - '%s %s %s %s = 0 %s %s %s (%s);' % ( - style.SQL_KEYWORD('UPDATE'), - style.SQL_TABLE(self.quote_name('sqlite_sequence')), - style.SQL_KEYWORD('SET'), - style.SQL_FIELD(self.quote_name('seq')), - style.SQL_KEYWORD('WHERE'), - style.SQL_FIELD(self.quote_name('name')), - style.SQL_KEYWORD('IN'), - ', '.join([ - "'%s'" % sequence_info['table'] for sequence_info in sequences - ]), + "%s %s %s %s = 0 %s %s %s (%s);" + % ( + style.SQL_KEYWORD("UPDATE"), + style.SQL_TABLE(self.quote_name("sqlite_sequence")), + style.SQL_KEYWORD("SET"), + style.SQL_FIELD(self.quote_name("seq")), + style.SQL_KEYWORD("WHERE"), + style.SQL_FIELD(self.quote_name("name")), + style.SQL_KEYWORD("IN"), + ", ".join( + ["'%s'" % sequence_info["table"] for sequence_info in sequences] + ), ), ] @@ -249,7 +262,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # SQLite doesn't support tz-aware datetimes @@ -257,7 +270,9 @@ class DatabaseOperations(BaseDatabaseOperations): if settings.USE_TZ: value = timezone.make_naive(value, self.connection.timezone) else: - raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.") + raise ValueError( + "SQLite backend does not support timezone-aware datetimes when USE_TZ is False." + ) return str(value) @@ -266,7 +281,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # SQLite doesn't support tz-aware datetimes @@ -278,17 +293,17 @@ class DatabaseOperations(BaseDatabaseOperations): def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type == 'DateTimeField': + if internal_type == "DateTimeField": converters.append(self.convert_datetimefield_value) - elif internal_type == 'DateField': + elif internal_type == "DateField": converters.append(self.convert_datefield_value) - elif internal_type == 'TimeField': + elif internal_type == "TimeField": converters.append(self.convert_timefield_value) - elif internal_type == 'DecimalField': + elif internal_type == "DecimalField": converters.append(self.get_decimalfield_converter(expression)) - elif internal_type == 'UUIDField': + elif internal_type == "UUIDField": converters.append(self.convert_uuidfield_value) - elif internal_type == 'BooleanField': + elif internal_type == "BooleanField": converters.append(self.convert_booleanfield_value) return converters @@ -317,15 +332,22 @@ class DatabaseOperations(BaseDatabaseOperations): # float inaccuracy must be removed. create_decimal = decimal.Context(prec=15).create_decimal_from_float if isinstance(expression, Col): - quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places) + quantize_value = decimal.Decimal(1).scaleb( + -expression.output_field.decimal_places + ) def converter(value, expression, connection): if value is not None: - return create_decimal(value).quantize(quantize_value, context=expression.output_field.context) + return create_decimal(value).quantize( + quantize_value, context=expression.output_field.context + ) + else: + def converter(value, expression, connection): if value is not None: return create_decimal(value) + return converter def convert_uuidfield_value(self, value, expression, connection): @@ -337,26 +359,26 @@ class DatabaseOperations(BaseDatabaseOperations): return bool(value) if value in (1, 0) else value def bulk_insert_sql(self, fields, placeholder_rows): - placeholder_rows_sql = (', '.join(row) for row in placeholder_rows) - values_sql = ', '.join(f'({sql})' for sql in placeholder_rows_sql) - return f'VALUES {values_sql}' + placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) + values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql) + return f"VALUES {values_sql}" def combine_expression(self, connector, sub_expressions): # SQLite doesn't have a ^ operator, so use the user-defined POWER # function that's registered in connect(). - if connector == '^': - return 'POWER(%s)' % ','.join(sub_expressions) - elif connector == '#': - return 'BITXOR(%s)' % ','.join(sub_expressions) + if connector == "^": + return "POWER(%s)" % ",".join(sub_expressions) + elif connector == "#": + return "BITXOR(%s)" % ",".join(sub_expressions) return super().combine_expression(connector, sub_expressions) def combine_duration_expression(self, connector, sub_expressions): - if connector not in ['+', '-', '*', '/']: - raise DatabaseError('Invalid connector for timedelta: %s.' % connector) + if connector not in ["+", "-", "*", "/"]: + raise DatabaseError("Invalid connector for timedelta: %s." % connector) fn_params = ["'%s'" % connector] + sub_expressions if len(fn_params) > 3: - raise ValueError('Too many params for timedelta operations.') - return "django_format_dtdelta(%s)" % ', '.join(fn_params) + raise ValueError("Too many params for timedelta operations.") + return "django_format_dtdelta(%s)" % ", ".join(fn_params) def integer_field_range(self, internal_type): # SQLite doesn't enforce any integer constraints @@ -366,39 +388,46 @@ class DatabaseOperations(BaseDatabaseOperations): lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs params = (*lhs_params, *rhs_params) - if internal_type == 'TimeField': - return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params - return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params + if internal_type == "TimeField": + return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params + return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params def insert_statement(self, on_conflict=None): if on_conflict == OnConflict.IGNORE: - return 'INSERT OR IGNORE INTO' + return "INSERT OR IGNORE INTO" return super().insert_statement(on_conflict=on_conflict) def return_insert_columns(self, fields): # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. if not fields: - return '', () + return "", () columns = [ - '%s.%s' % ( + "%s.%s" + % ( self.quote_name(field.model._meta.db_table), self.quote_name(field.column), - ) for field in fields + ) + for field in fields ] - return 'RETURNING %s' % ', '.join(columns), () + return "RETURNING %s" % ", ".join(columns), () def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): if ( - on_conflict == OnConflict.UPDATE and - self.connection.features.supports_update_conflicts_with_target + on_conflict == OnConflict.UPDATE + and self.connection.features.supports_update_conflicts_with_target ): - return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( - ', '.join(map(self.quote_name, unique_fields)), - ', '.join([ - f'{field} = EXCLUDED.{field}' - for field in map(self.quote_name, update_fields) - ]), + return "ON CONFLICT(%s) DO UPDATE SET %s" % ( + ", ".join(map(self.quote_name, unique_fields)), + ", ".join( + [ + f"{field} = EXCLUDED.{field}" + for field in map(self.quote_name, update_fields) + ] + ), ) return super().on_conflict_suffix_sql( - fields, on_conflict, update_fields, unique_fields, + fields, + on_conflict, + update_fields, + unique_fields, ) diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index 3ff0a3f7db..c9af8088e5 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -14,7 +14,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_delete_table = "DROP TABLE %(table)s" sql_create_fk = None - sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" + sql_create_inline_fk = ( + "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" + ) sql_create_column_inline_fk = sql_create_inline_fk sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)" sql_delete_unique = "DROP INDEX %(name)s" @@ -24,11 +26,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # disabled. Enforce it here for the duration of the schema edition. if not self.connection.disable_constraint_checking(): raise NotSupportedError( - 'SQLite schema editor cannot be used while foreign key ' - 'constraint checks are enabled. Make sure to disable them ' - 'before entering a transaction.atomic() context because ' - 'SQLite does not support disabling them in the middle of ' - 'a multi-statement transaction.' + "SQLite schema editor cannot be used while foreign key " + "constraint checks are enabled. Make sure to disable them " + "before entering a transaction.atomic() context because " + "SQLite does not support disabling them in the middle of " + "a multi-statement transaction." ) return super().__enter__() @@ -43,6 +45,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # security hardening). try: import sqlite3 + value = sqlite3.adapt(value) except ImportError: pass @@ -54,7 +57,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): elif isinstance(value, (Decimal, float, int)): return str(value) elif isinstance(value, str): - return "'%s'" % value.replace("\'", "\'\'") + return "'%s'" % value.replace("'", "''") elif value is None: return "NULL" elif isinstance(value, (bytes, bytearray, memoryview)): @@ -63,12 +66,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # character. return "X'%s'" % value.hex() else: - raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value))) + raise ValueError( + "Cannot quote parameter value %r of type %s" % (value, type(value)) + ) def prepare_default(self, value): return self.quote_value(value) - def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False): + def _is_referenced_by_fk_constraint( + self, table_name, column_name=None, ignore_self=False + ): """ Return whether or not the provided table name is referenced by another one. If `column_name` is specified, only references pointing to that @@ -79,22 +86,33 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): for other_table in self.connection.introspection.get_table_list(cursor): if ignore_self and other_table.name == table_name: continue - relations = self.connection.introspection.get_relations(cursor, other_table.name) + relations = self.connection.introspection.get_relations( + cursor, other_table.name + ) for constraint_column, constraint_table in relations.values(): - if (constraint_table == table_name and - (column_name is None or constraint_column == column_name)): + if constraint_table == table_name and ( + column_name is None or constraint_column == column_name + ): return True return False - def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True): - if (not self.connection.features.supports_atomic_references_rename and - disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)): + def alter_db_table( + self, model, old_db_table, new_db_table, disable_constraints=True + ): + if ( + not self.connection.features.supports_atomic_references_rename + and disable_constraints + and self._is_referenced_by_fk_constraint(old_db_table) + ): if self.connection.in_atomic_block: - raise NotSupportedError(( - 'Renaming the %r table while in a transaction is not ' - 'supported on SQLite < 3.26 because it would break referential ' - 'integrity. Try adding `atomic = False` to the Migration class.' - ) % old_db_table) + raise NotSupportedError( + ( + "Renaming the %r table while in a transaction is not " + "supported on SQLite < 3.26 because it would break referential " + "integrity. Try adding `atomic = False` to the Migration class." + ) + % old_db_table + ) self.connection.enable_constraint_checking() super().alter_db_table(model, old_db_table, new_db_table) self.connection.disable_constraint_checking() @@ -107,42 +125,56 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): old_field_name = old_field.name table_name = model._meta.db_table _, old_column_name = old_field.get_attname_column() - if (new_field.name != old_field_name and - not self.connection.features.supports_atomic_references_rename and - self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)): + if ( + new_field.name != old_field_name + and not self.connection.features.supports_atomic_references_rename + and self._is_referenced_by_fk_constraint( + table_name, old_column_name, ignore_self=True + ) + ): if self.connection.in_atomic_block: - raise NotSupportedError(( - 'Renaming the %r.%r column while in a transaction is not ' - 'supported on SQLite < 3.26 because it would break referential ' - 'integrity. Try adding `atomic = False` to the Migration class.' - ) % (model._meta.db_table, old_field_name)) + raise NotSupportedError( + ( + "Renaming the %r.%r column while in a transaction is not " + "supported on SQLite < 3.26 because it would break referential " + "integrity. Try adding `atomic = False` to the Migration class." + ) + % (model._meta.db_table, old_field_name) + ) with atomic(self.connection.alias): super().alter_field(model, old_field, new_field, strict=strict) # Follow SQLite's documented procedure for performing changes # that don't affect the on-disk content. # https://sqlite.org/lang_altertable.html#otheralter with self.connection.cursor() as cursor: - schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0] - cursor.execute('PRAGMA writable_schema = 1') + schema_version = cursor.execute("PRAGMA schema_version").fetchone()[ + 0 + ] + cursor.execute("PRAGMA writable_schema = 1") references_template = ' REFERENCES "%s" ("%%s") ' % table_name new_column_name = new_field.get_attname_column()[1] search = references_template % old_column_name replacement = references_template % new_column_name - cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement)) - cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1)) - cursor.execute('PRAGMA writable_schema = 0') + cursor.execute( + "UPDATE sqlite_master SET sql = replace(sql, %s, %s)", + (search, replacement), + ) + cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1)) + cursor.execute("PRAGMA writable_schema = 0") # The integrity check will raise an exception and rollback # the transaction if the sqlite_master updates corrupt the # database. - cursor.execute('PRAGMA integrity_check') + cursor.execute("PRAGMA integrity_check") # Perform a VACUUM to refresh the database representation from # the sqlite_master table. with self.connection.cursor() as cursor: - cursor.execute('VACUUM') + cursor.execute("VACUUM") else: super().alter_field(model, old_field, new_field, strict=strict) - def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None): + def _remake_table( + self, model, create_field=None, delete_field=None, alter_field=None + ): """ Shortcut to transform a model from old_model into new_model @@ -163,6 +195,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # to an altered field. def is_self_referential(f): return f.is_relation and f.remote_field.model is model + # Work out the new fields dict / mapping body = { f.name: f.clone() if is_self_referential(f) else f @@ -170,14 +203,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): } # Since mapping might mix column names and default values, # its values must be already quoted. - mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields} + mapping = { + f.column: self.quote_name(f.column) + for f in model._meta.local_concrete_fields + } # This maps field names (not columns) for things like unique_together rename_mapping = {} # If any of the new or altered fields is introducing a new PK, # remove the old one restore_pk_field = None - if getattr(create_field, 'primary_key', False) or ( - alter_field and getattr(alter_field[1], 'primary_key', False)): + if getattr(create_field, "primary_key", False) or ( + alter_field and getattr(alter_field[1], "primary_key", False) + ): for name, field in list(body.items()): if field.primary_key: field.primary_key = False @@ -201,8 +238,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): body[new_field.name] = new_field if old_field.null and not new_field.null: case_sql = "coalesce(%(col)s, %(default)s)" % { - 'col': self.quote_name(old_field.column), - 'default': self.prepare_default(self.effective_default(new_field)), + "col": self.quote_name(old_field.column), + "default": self.prepare_default(self.effective_default(new_field)), } mapping[new_field.column] = case_sql else: @@ -213,7 +250,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): del body[delete_field.name] del mapping[delete_field.column] # Remove any implicit M2M tables - if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created: + if ( + delete_field.many_to_many + and delete_field.remote_field.through._meta.auto_created + ): return self.delete_model(delete_field.remote_field.through) # Work inside a new app registry apps = Apps() @@ -235,8 +275,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): indexes = model._meta.indexes if delete_field: indexes = [ - index for index in indexes - if delete_field.name not in index.fields + index for index in indexes if delete_field.name not in index.fields ] constraints = list(model._meta.constraints) @@ -252,52 +291,57 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # This wouldn't be required if the schema editor was operating on model # states instead of rendered models. meta_contents = { - 'app_label': model._meta.app_label, - 'db_table': model._meta.db_table, - 'unique_together': unique_together, - 'index_together': index_together, - 'indexes': indexes, - 'constraints': constraints, - 'apps': apps, + "app_label": model._meta.app_label, + "db_table": model._meta.db_table, + "unique_together": unique_together, + "index_together": index_together, + "indexes": indexes, + "constraints": constraints, + "apps": apps, } meta = type("Meta", (), meta_contents) - body_copy['Meta'] = meta - body_copy['__module__'] = model.__module__ + body_copy["Meta"] = meta + body_copy["__module__"] = model.__module__ type(model._meta.object_name, model.__bases__, body_copy) # Construct a model with a renamed table name. body_copy = copy.deepcopy(body) meta_contents = { - 'app_label': model._meta.app_label, - 'db_table': 'new__%s' % strip_quotes(model._meta.db_table), - 'unique_together': unique_together, - 'index_together': index_together, - 'indexes': indexes, - 'constraints': constraints, - 'apps': apps, + "app_label": model._meta.app_label, + "db_table": "new__%s" % strip_quotes(model._meta.db_table), + "unique_together": unique_together, + "index_together": index_together, + "indexes": indexes, + "constraints": constraints, + "apps": apps, } meta = type("Meta", (), meta_contents) - body_copy['Meta'] = meta - body_copy['__module__'] = model.__module__ - new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy) + body_copy["Meta"] = meta + body_copy["__module__"] = model.__module__ + new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy) # Create a new table with the updated schema. self.create_model(new_model) # Copy data from the old table into the new table - self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % ( - self.quote_name(new_model._meta.db_table), - ', '.join(self.quote_name(x) for x in mapping), - ', '.join(mapping.values()), - self.quote_name(model._meta.db_table), - )) + self.execute( + "INSERT INTO %s (%s) SELECT %s FROM %s" + % ( + self.quote_name(new_model._meta.db_table), + ", ".join(self.quote_name(x) for x in mapping), + ", ".join(mapping.values()), + self.quote_name(model._meta.db_table), + ) + ) # Delete the old table to make way for the new self.delete_model(model, handle_autom2m=False) # Rename the new table to take way for the old self.alter_db_table( - new_model, new_model._meta.db_table, model._meta.db_table, + new_model, + new_model._meta.db_table, + model._meta.db_table, disable_constraints=False, ) @@ -314,12 +358,17 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): super().delete_model(model) else: # Delete the table (and only that) - self.execute(self.sql_delete_table % { - "table": self.quote_name(model._meta.db_table), - }) + self.execute( + self.sql_delete_table + % { + "table": self.quote_name(model._meta.db_table), + } + ) # Remove all deferred statements referencing the deleted table. for sql in list(self.deferred_sql): - if isinstance(sql, Statement) and sql.references_table(model._meta.db_table): + if isinstance(sql, Statement) and sql.references_table( + model._meta.db_table + ): self.deferred_sql.remove(sql) def add_field(self, model, field): @@ -327,11 +376,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if ( # Primary keys and unique fields are not supported in ALTER TABLE # ADD COLUMN. - field.primary_key or field.unique or + field.primary_key + or field.unique + or # Fields with default values cannot by handled by ALTER TABLE ADD # COLUMN statement because DROP DEFAULT is not supported in # ALTER TABLE. - not field.null or self.effective_default(field) is not None + not field.null + or self.effective_default(field) is not None ): self._remake_table(model, create_field=field) else: @@ -351,21 +403,40 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # For everything else, remake. else: # It might not actually have a column behind it - if field.db_parameters(connection=self.connection)['type'] is None: + if field.db_parameters(connection=self.connection)["type"] is None: return self._remake_table(model, delete_field=field) - def _alter_field(self, model, old_field, new_field, old_type, new_type, - old_db_params, new_db_params, strict=False): + def _alter_field( + self, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict=False, + ): """Perform a "physical" (non-ManyToMany) field update.""" # Use "ALTER TABLE ... RENAME COLUMN" if only the column name # changed and there aren't any constraints. - if (self.connection.features.can_alter_table_rename_column and - old_field.column != new_field.column and - self.column_sql(model, old_field) == self.column_sql(model, new_field) and - not (old_field.remote_field and old_field.db_constraint or - new_field.remote_field and new_field.db_constraint)): - return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) + if ( + self.connection.features.can_alter_table_rename_column + and old_field.column != new_field.column + and self.column_sql(model, old_field) == self.column_sql(model, new_field) + and not ( + old_field.remote_field + and old_field.db_constraint + or new_field.remote_field + and new_field.db_constraint + ) + ): + return self.execute( + self._rename_field_sql( + model._meta.db_table, old_field, new_field, new_type + ) + ) # Alter by remaking table self._remake_table(model, alter_field=(old_field, new_field)) # Rebuild tables with FKs pointing to this field. @@ -393,15 +464,22 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _alter_many_to_many(self, model, old_field, new_field, strict): """Alter M2Ms to repoint their to= endpoints.""" - if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table: + if ( + old_field.remote_field.through._meta.db_table + == new_field.remote_field.through._meta.db_table + ): # The field name didn't change, but some options did; we have to propagate this altering. self._remake_table( old_field.remote_field.through, alter_field=( # We need the field that points to the target model, so we can tell alter_field to change it - # this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model) - old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()), - new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()), + old_field.remote_field.through._meta.get_field( + old_field.m2m_reverse_field_name() + ), + new_field.remote_field.through._meta.get_field( + new_field.m2m_reverse_field_name() + ), ), ) return @@ -409,29 +487,36 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Make a new through table self.create_model(new_field.remote_field.through) # Copy the data across - self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % ( - self.quote_name(new_field.remote_field.through._meta.db_table), - ', '.join([ - "id", - new_field.m2m_column_name(), - new_field.m2m_reverse_name(), - ]), - ', '.join([ - "id", - old_field.m2m_column_name(), - old_field.m2m_reverse_name(), - ]), - self.quote_name(old_field.remote_field.through._meta.db_table), - )) + self.execute( + "INSERT INTO %s (%s) SELECT %s FROM %s" + % ( + self.quote_name(new_field.remote_field.through._meta.db_table), + ", ".join( + [ + "id", + new_field.m2m_column_name(), + new_field.m2m_reverse_name(), + ] + ), + ", ".join( + [ + "id", + old_field.m2m_column_name(), + old_field.m2m_reverse_name(), + ] + ), + self.quote_name(old_field.remote_field.through._meta.db_table), + ) + ) # Delete the old through table self.delete_model(old_field.remote_field.through) def add_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and ( - constraint.condition or - constraint.contains_expressions or - constraint.include or - constraint.deferrable + constraint.condition + or constraint.contains_expressions + or constraint.include + or constraint.deferrable ): super().add_constraint(model, constraint) else: @@ -439,14 +524,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def remove_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and ( - constraint.condition or - constraint.contains_expressions or - constraint.include or - constraint.deferrable + constraint.condition + or constraint.contains_expressions + or constraint.include + or constraint.deferrable ): super().remove_constraint(model, constraint) else: self._remake_table(model) def _collate_sql(self, collation): - return 'COLLATE ' + collation + return "COLLATE " + collation |
