From 2d991ff661f72195f3a57be66d2bbc761c923f7e Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Fri, 17 Dec 2021 07:16:31 +0000 Subject: Refs #33355 -- Moved SQLite functions to separate module. Co-Authored-By: Nick Pope --- django/db/backends/sqlite3/base.py | 293 +------------------------------------ 1 file changed, 4 insertions(+), 289 deletions(-) (limited to 'django/db/backends/sqlite3/base.py') diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 7ec6466b1b..4343ea180e 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -1,32 +1,19 @@ """ SQLite backend for the sqlite3 module in the standard library. """ -import datetime import decimal -import functools -import hashlib -import math -import operator -import random -import re -import statistics import warnings from itertools import chain from sqlite3 import dbapi2 as Database from django.core.exceptions import ImproperlyConfigured from django.db import IntegrityError -from django.db.backends import utils as backend_utils -from django.db.backends.base.base import ( - BaseDatabaseWrapper, timezone_constructor, -) -from django.utils import timezone +from django.db.backends.base.base import BaseDatabaseWrapper from django.utils.asyncio import async_unsafe -from django.utils.crypto import md5 from django.utils.dateparse import parse_datetime, parse_time -from django.utils.duration import duration_microseconds from django.utils.regex_helper import _lazy_re_compile +from ._functions import register as register_functions from .client import DatabaseClient from .creation import DatabaseCreation from .features import DatabaseFeatures @@ -42,27 +29,6 @@ def decoder(conv_func): return lambda s: conv_func(s.decode()) -def none_guard(func): - """ - Decorator that returns None if any of the arguments to the decorated - function are None. Many SQL functions return NULL if any of their arguments - are NULL. This decorator simplifies the implementation of this for the - custom functions registered below. - """ - @functools.wraps(func) - def wrapper(*args, **kwargs): - return None if None in args else func(*args, **kwargs) - return wrapper - - -def list_aggregate(function): - """ - Return an aggregate class that accumulates values in a list and applies - the provided function to the data. - """ - return type('ListAggregate', (list,), {'finalize': function, 'step': list.append}) - - def check_sqlite_version(): if Database.sqlite_version_info < (3, 9, 0): raise ImproperlyConfigured( @@ -204,60 +170,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): @async_unsafe def get_new_connection(self, conn_params): conn = Database.connect(**conn_params) - create_deterministic_function = functools.partial( - conn.create_function, - deterministic=True, - ) - create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract) - create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc) - create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date) - create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time) - create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract) - create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc) - create_deterministic_function('django_time_extract', 2, _sqlite_time_extract) - create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc) - create_deterministic_function('django_time_diff', 2, _sqlite_time_diff) - create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff) - create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta) - create_deterministic_function('regexp', 2, _sqlite_regexp) - create_deterministic_function('ACOS', 1, none_guard(math.acos)) - create_deterministic_function('ASIN', 1, none_guard(math.asin)) - create_deterministic_function('ATAN', 1, none_guard(math.atan)) - create_deterministic_function('ATAN2', 2, none_guard(math.atan2)) - create_deterministic_function('BITXOR', 2, none_guard(operator.xor)) - create_deterministic_function('CEILING', 1, none_guard(math.ceil)) - create_deterministic_function('COS', 1, none_guard(math.cos)) - create_deterministic_function('COT', 1, none_guard(lambda x: 1 / math.tan(x))) - create_deterministic_function('DEGREES', 1, none_guard(math.degrees)) - create_deterministic_function('EXP', 1, none_guard(math.exp)) - create_deterministic_function('FLOOR', 1, none_guard(math.floor)) - create_deterministic_function('LN', 1, none_guard(math.log)) - create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x))) - create_deterministic_function('LPAD', 3, _sqlite_lpad) - create_deterministic_function('MD5', 1, none_guard(lambda x: md5(x.encode()).hexdigest())) - create_deterministic_function('MOD', 2, none_guard(math.fmod)) - create_deterministic_function('PI', 0, lambda: math.pi) - create_deterministic_function('POWER', 2, none_guard(operator.pow)) - create_deterministic_function('RADIANS', 1, none_guard(math.radians)) - create_deterministic_function('REPEAT', 2, none_guard(operator.mul)) - create_deterministic_function('REVERSE', 1, none_guard(lambda x: x[::-1])) - create_deterministic_function('RPAD', 3, _sqlite_rpad) - create_deterministic_function('SHA1', 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest())) - create_deterministic_function('SHA224', 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest())) - create_deterministic_function('SHA256', 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest())) - create_deterministic_function('SHA384', 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest())) - create_deterministic_function('SHA512', 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest())) - create_deterministic_function('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0))) - create_deterministic_function('SIN', 1, none_guard(math.sin)) - create_deterministic_function('SQRT', 1, none_guard(math.sqrt)) - create_deterministic_function('TAN', 1, none_guard(math.tan)) - # Don't use the built-in RANDOM() function because it returns a value - # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1). - conn.create_function('RAND', 0, random.random) - conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev)) - conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev)) - conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance)) - conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance)) + register_functions(conn) + conn.execute('PRAGMA foreign_keys = ON') # The macOS bundled SQLite defaults legacy_alter_table ON, which # prevents atomic table renames (feature supports_atomic_references_rename) @@ -425,202 +339,3 @@ class SQLiteCursorWrapper(Database.Cursor): def convert_query(self, query): return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%') - - -def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): - if dt is None: - return None - try: - dt = backend_utils.typecast_timestamp(dt) - except (TypeError, ValueError): - return None - if conn_tzname: - dt = dt.replace(tzinfo=timezone_constructor(conn_tzname)) - if tzname is not None and tzname != conn_tzname: - tzname, sign, offset = backend_utils.split_tzname_delta(tzname) - if offset: - hours, minutes = offset.split(':') - offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes)) - dt += offset_delta if sign == '+' else -offset_delta - dt = timezone.localtime(dt, timezone_constructor(tzname)) - return dt - - -def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname): - dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) - if dt is None: - return None - if lookup_type == 'year': - return f'{dt.year:04d}-01-01' - elif lookup_type == 'quarter': - month_in_quarter = dt.month - (dt.month - 1) % 3 - return f'{dt.year:04d}-{month_in_quarter:02d}-01' - elif lookup_type == 'month': - return f'{dt.year:04d}-{dt.month:02d}-01' - elif lookup_type == 'week': - dt = dt - datetime.timedelta(days=dt.weekday()) - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}' - elif lookup_type == 'day': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}' - - -def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): - if dt is None: - return None - dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname) - if dt_parsed is None: - try: - dt = backend_utils.typecast_time(dt) - except (ValueError, TypeError): - return None - else: - dt = dt_parsed - if lookup_type == 'hour': - return f'{dt.hour:02d}:00:00' - elif lookup_type == 'minute': - return f'{dt.hour:02d}:{dt.minute:02d}:00' - elif lookup_type == 'second': - return f'{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}' - - -def _sqlite_datetime_cast_date(dt, tzname, conn_tzname): - dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) - if dt is None: - return None - return dt.date().isoformat() - - -def _sqlite_datetime_cast_time(dt, tzname, conn_tzname): - dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) - if dt is None: - return None - return dt.time().isoformat() - - -def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None): - dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) - if dt is None: - return None - if lookup_type == 'week_day': - return (dt.isoweekday() % 7) + 1 - elif lookup_type == 'iso_week_day': - return dt.isoweekday() - elif lookup_type == 'week': - return dt.isocalendar()[1] - elif lookup_type == 'quarter': - return math.ceil(dt.month / 3) - elif lookup_type == 'iso_year': - return dt.isocalendar()[0] - else: - return getattr(dt, lookup_type) - - -def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname): - dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) - if dt is None: - return None - if lookup_type == 'year': - return f'{dt.year:04d}-01-01 00:00:00' - elif lookup_type == 'quarter': - month_in_quarter = dt.month - (dt.month - 1) % 3 - return f'{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00' - elif lookup_type == 'month': - return f'{dt.year:04d}-{dt.month:02d}-01 00:00:00' - elif lookup_type == 'week': - dt = dt - datetime.timedelta(days=dt.weekday()) - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00' - elif lookup_type == 'day': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00' - elif lookup_type == 'hour': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00' - elif lookup_type == 'minute': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00' - elif lookup_type == 'second': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}' - - -def _sqlite_time_extract(lookup_type, dt): - if dt is None: - return None - try: - dt = backend_utils.typecast_time(dt) - except (ValueError, TypeError): - return None - return getattr(dt, lookup_type) - - -def _sqlite_prepare_dtdelta_param(conn, param): - if conn in ['+', '-']: - if isinstance(param, int): - return datetime.timedelta(0, 0, param) - else: - return backend_utils.typecast_timestamp(param) - return param - - -@none_guard -def _sqlite_format_dtdelta(conn, lhs, rhs): - """ - LHS and RHS can be either: - - An integer number of microseconds - - A string representing a datetime - - A scalar value, e.g. float - """ - conn = conn.strip() - try: - real_lhs = _sqlite_prepare_dtdelta_param(conn, lhs) - real_rhs = _sqlite_prepare_dtdelta_param(conn, rhs) - except (ValueError, TypeError): - return None - if conn == '+': - # typecast_timestamp returns a date or a datetime without timezone. - # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" - out = str(real_lhs + real_rhs) - elif conn == '-': - out = str(real_lhs - real_rhs) - elif conn == '*': - out = real_lhs * real_rhs - else: - out = real_lhs / real_rhs - return out - - -@none_guard -def _sqlite_time_diff(lhs, rhs): - left = backend_utils.typecast_time(lhs) - right = backend_utils.typecast_time(rhs) - return ( - (left.hour * 60 * 60 * 1000000) + - (left.minute * 60 * 1000000) + - (left.second * 1000000) + - (left.microsecond) - - (right.hour * 60 * 60 * 1000000) - - (right.minute * 60 * 1000000) - - (right.second * 1000000) - - (right.microsecond) - ) - - -@none_guard -def _sqlite_timestamp_diff(lhs, rhs): - left = backend_utils.typecast_timestamp(lhs) - right = backend_utils.typecast_timestamp(rhs) - return duration_microseconds(left - right) - - -@none_guard -def _sqlite_regexp(re_pattern, re_string): - return bool(re.search(re_pattern, str(re_string))) - - -@none_guard -def _sqlite_lpad(text, length, fill_text): - delta = length - len(text) - if delta <= 0: - return text[:length] - return (fill_text * length)[:delta] + text - - -@none_guard -def _sqlite_rpad(text, length, fill_text): - return (text + fill_text * length)[:length] -- cgit v1.3