summary refs log tree commit diff
path: root/django/db/backends/sqlite3/operations.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/backends/sqlite3/operations.py')
-rw-r--r--  django/db/backends/sqlite3/operations.py  183
1 files changed, 106 insertions, 77 deletions
diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
index c1a6da4e5d..ef8b91c0f0 100644
--- a/django/db/backends/sqlite3/operations.py
+++ b/django/db/backends/sqlite3/operations.py
@@ -16,15 +16,15 @@ from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
- cast_char_field_without_max_length = 'text'
+ cast_char_field_without_max_length = "text"
cast_data_types = {
- 'DateField': 'TEXT',
- 'DateTimeField': 'TEXT',
+ "DateField": "TEXT",
+ "DateTimeField": "TEXT",
}
- explain_prefix = 'EXPLAIN QUERY PLAN'
+ explain_prefix = "EXPLAIN QUERY PLAN"
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
- jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
+ jsonfield_datatype_values = frozenset(["null", "false", "true"])
def bulk_batch_size(self, fields, objs):
"""
@@ -55,14 +55,14 @@ class DatabaseOperations(BaseDatabaseOperations):
else:
if isinstance(output_field, bad_fields):
raise NotSupportedError(
- 'You cannot use Sum, Avg, StdDev, and Variance '
- 'aggregations on date/time fields in sqlite3 '
- 'since date/time is saved as text.'
+ "You cannot use Sum, Avg, StdDev, and Variance "
+ "aggregations on date/time fields in sqlite3 "
+ "since date/time is saved as text."
)
if (
- isinstance(expression, models.Aggregate) and
- expression.distinct and
- len(expression.source_expressions) > 1
+ isinstance(expression, models.Aggregate)
+ and expression.distinct
+ and len(expression.source_expressions) > 1
):
raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
@@ -105,26 +105,32 @@ class DatabaseOperations(BaseDatabaseOperations):
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
- return 'NULL', 'NULL'
+ return "NULL", "NULL"
def datetime_cast_date_sql(self, field_name, tzname):
- return 'django_datetime_cast_date(%s, %s, %s)' % (
- field_name, *self._convert_tznames_to_sql(tzname),
+ return "django_datetime_cast_date(%s, %s, %s)" % (
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname):
- return 'django_datetime_cast_time(%s, %s, %s)' % (
- field_name, *self._convert_tznames_to_sql(tzname),
+ return "django_datetime_cast_time(%s, %s, %s)" % (
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % (
- lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
+ lookup_type.lower(),
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % (
- lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
+ lookup_type.lower(),
+ field_name,
+ *self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name):
@@ -146,11 +152,11 @@ class DatabaseOperations(BaseDatabaseOperations):
if len(params) > BATCH_SIZE:
results = ()
for index in range(0, len(params), BATCH_SIZE):
- chunk = params[index:index + BATCH_SIZE]
+ chunk = params[index : index + BATCH_SIZE]
results += self._quote_params_for_last_executed_query(chunk)
return results
- sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
+ sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
# Bypass Django's wrappers and use the underlying sqlite3 connection
# to avoid logging this query - it would trigger infinite recursion.
cursor = self.connection.connection.cursor()
@@ -215,14 +221,20 @@ class DatabaseOperations(BaseDatabaseOperations):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
- tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
- sql = ['%s %s %s;' % (
- style.SQL_KEYWORD('DELETE'),
- style.SQL_KEYWORD('FROM'),
- style.SQL_FIELD(self.quote_name(table))
- ) for table in tables]
+ tables = set(
+ chain.from_iterable(self._references_graph(table) for table in tables)
+ )
+ sql = [
+ "%s %s %s;"
+ % (
+ style.SQL_KEYWORD("DELETE"),
+ style.SQL_KEYWORD("FROM"),
+ style.SQL_FIELD(self.quote_name(table)),
+ )
+ for table in tables
+ ]
if reset_sequences:
- sequences = [{'table': table} for table in tables]
+ sequences = [{"table": table} for table in tables]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
@@ -230,17 +242,18 @@ class DatabaseOperations(BaseDatabaseOperations):
if not sequences:
return []
return [
- '%s %s %s %s = 0 %s %s %s (%s);' % (
- style.SQL_KEYWORD('UPDATE'),
- style.SQL_TABLE(self.quote_name('sqlite_sequence')),
- style.SQL_KEYWORD('SET'),
- style.SQL_FIELD(self.quote_name('seq')),
- style.SQL_KEYWORD('WHERE'),
- style.SQL_FIELD(self.quote_name('name')),
- style.SQL_KEYWORD('IN'),
- ', '.join([
- "'%s'" % sequence_info['table'] for sequence_info in sequences
- ]),
+ "%s %s %s %s = 0 %s %s %s (%s);"
+ % (
+ style.SQL_KEYWORD("UPDATE"),
+ style.SQL_TABLE(self.quote_name("sqlite_sequence")),
+ style.SQL_KEYWORD("SET"),
+ style.SQL_FIELD(self.quote_name("seq")),
+ style.SQL_KEYWORD("WHERE"),
+ style.SQL_FIELD(self.quote_name("name")),
+ style.SQL_KEYWORD("IN"),
+ ", ".join(
+ ["'%s'" % sequence_info["table"] for sequence_info in sequences]
+ ),
),
]
@@ -249,7 +262,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# SQLite doesn't support tz-aware datetimes
@@ -257,7 +270,9 @@ class DatabaseOperations(BaseDatabaseOperations):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
- raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")
+ raise ValueError(
+ "SQLite backend does not support timezone-aware datetimes when USE_TZ is False."
+ )
return str(value)
@@ -266,7 +281,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
- if hasattr(value, 'resolve_expression'):
+ if hasattr(value, "resolve_expression"):
return value
# SQLite doesn't support tz-aware datetimes
@@ -278,17 +293,17 @@ class DatabaseOperations(BaseDatabaseOperations):
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
- if internal_type == 'DateTimeField':
+ if internal_type == "DateTimeField":
converters.append(self.convert_datetimefield_value)
- elif internal_type == 'DateField':
+ elif internal_type == "DateField":
converters.append(self.convert_datefield_value)
- elif internal_type == 'TimeField':
+ elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
- elif internal_type == 'DecimalField':
+ elif internal_type == "DecimalField":
converters.append(self.get_decimalfield_converter(expression))
- elif internal_type == 'UUIDField':
+ elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
- elif internal_type == 'BooleanField':
+ elif internal_type == "BooleanField":
converters.append(self.convert_booleanfield_value)
return converters
@@ -317,15 +332,22 @@ class DatabaseOperations(BaseDatabaseOperations):
# float inaccuracy must be removed.
create_decimal = decimal.Context(prec=15).create_decimal_from_float
if isinstance(expression, Col):
- quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)
+ quantize_value = decimal.Decimal(1).scaleb(
+ -expression.output_field.decimal_places
+ )
def converter(value, expression, connection):
if value is not None:
- return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
+ return create_decimal(value).quantize(
+ quantize_value, context=expression.output_field.context
+ )
+
else:
+
def converter(value, expression, connection):
if value is not None:
return create_decimal(value)
+
return converter
def convert_uuidfield_value(self, value, expression, connection):
@@ -337,26 +359,26 @@ class DatabaseOperations(BaseDatabaseOperations):
return bool(value) if value in (1, 0) else value
def bulk_insert_sql(self, fields, placeholder_rows):
- placeholder_rows_sql = (', '.join(row) for row in placeholder_rows)
- values_sql = ', '.join(f'({sql})' for sql in placeholder_rows_sql)
- return f'VALUES {values_sql}'
+ placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
+ values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
+ return f"VALUES {values_sql}"
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
- if connector == '^':
- return 'POWER(%s)' % ','.join(sub_expressions)
- elif connector == '#':
- return 'BITXOR(%s)' % ','.join(sub_expressions)
+ if connector == "^":
+ return "POWER(%s)" % ",".join(sub_expressions)
+ elif connector == "#":
+ return "BITXOR(%s)" % ",".join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
- if connector not in ['+', '-', '*', '/']:
- raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
+ if connector not in ["+", "-", "*", "/"]:
+ raise DatabaseError("Invalid connector for timedelta: %s." % connector)
fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3:
- raise ValueError('Too many params for timedelta operations.')
- return "django_format_dtdelta(%s)" % ', '.join(fn_params)
+ raise ValueError("Too many params for timedelta operations.")
+ return "django_format_dtdelta(%s)" % ", ".join(fn_params)
def integer_field_range(self, internal_type):
# SQLite doesn't enforce any integer constraints
@@ -366,39 +388,46 @@ class DatabaseOperations(BaseDatabaseOperations):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
- if internal_type == 'TimeField':
- return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
- return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
+ if internal_type == "TimeField":
+ return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
+ return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
- return 'INSERT OR IGNORE INTO'
+ return "INSERT OR IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
- return '', ()
+ return "", ()
columns = [
- '%s.%s' % (
+ "%s.%s"
+ % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
- ) for field in fields
+ )
+ for field in fields
]
- return 'RETURNING %s' % ', '.join(columns), ()
+ return "RETURNING %s" % ", ".join(columns), ()
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if (
- on_conflict == OnConflict.UPDATE and
- self.connection.features.supports_update_conflicts_with_target
+ on_conflict == OnConflict.UPDATE
+ and self.connection.features.supports_update_conflicts_with_target
):
- return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
- ', '.join(map(self.quote_name, unique_fields)),
- ', '.join([
- f'{field} = EXCLUDED.{field}'
- for field in map(self.quote_name, update_fields)
- ]),
+ return "ON CONFLICT(%s) DO UPDATE SET %s" % (
+ ", ".join(map(self.quote_name, unique_fields)),
+ ", ".join(
+ [
+ f"{field} = EXCLUDED.{field}"
+ for field in map(self.quote_name, update_fields)
+ ]
+ ),
)
return super().on_conflict_suffix_sql(
- fields, on_conflict, update_fields, unique_fields,
+ fields,
+ on_conflict,
+ update_fields,
+ unique_fields,
)