diff options
Diffstat (limited to 'django/db/backends/sqlite3/introspection.py')
| -rw-r--r-- | django/db/backends/sqlite3/introspection.py | 241 |
1 files changed, 146 insertions, 95 deletions
diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 81884a7951..f5a5e81e9d 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -3,19 +3,21 @@ from collections import namedtuple import sqlparse from django.db import DatabaseError -from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, -) +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo from django.db.models import Index from django.utils.regex_helper import _lazy_re_compile -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint')) +FieldInfo = namedtuple( + "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint") +) -field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') +field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$") def get_field_size(name): - """ Extract the size number from a "varchar(11)" type name """ + """Extract the size number from a "varchar(11)" type name""" m = field_size_re.search(name) return int(m[1]) if m else None @@ -28,29 +30,29 @@ class FlexibleFieldLookupDict: # entries here because SQLite allows for anything and doesn't normalize the # field type; it uses whatever was given. base_data_types_reverse = { - 'bool': 'BooleanField', - 'boolean': 'BooleanField', - 'smallint': 'SmallIntegerField', - 'smallint unsigned': 'PositiveSmallIntegerField', - 'smallinteger': 'SmallIntegerField', - 'int': 'IntegerField', - 'integer': 'IntegerField', - 'bigint': 'BigIntegerField', - 'integer unsigned': 'PositiveIntegerField', - 'bigint unsigned': 'PositiveBigIntegerField', - 'decimal': 'DecimalField', - 'real': 'FloatField', - 'text': 'TextField', - 'char': 'CharField', - 'varchar': 'CharField', - 'blob': 'BinaryField', - 'date': 'DateField', - 'datetime': 'DateTimeField', - 'time': 'TimeField', + "bool": "BooleanField", + "boolean": "BooleanField", + "smallint": "SmallIntegerField", + "smallint unsigned": "PositiveSmallIntegerField", + "smallinteger": "SmallIntegerField", + "int": "IntegerField", + "integer": "IntegerField", + "bigint": "BigIntegerField", + "integer unsigned": "PositiveIntegerField", + "bigint unsigned": "PositiveBigIntegerField", + "decimal": "DecimalField", + "real": "FloatField", + "text": "TextField", + "char": "CharField", + "varchar": "CharField", + "blob": "BinaryField", + "date": "DateField", + "datetime": "DateTimeField", + "time": "TimeField", } def __getitem__(self, key): - key = key.lower().split('(', 1)[0].strip() + key = key.lower().split("(", 1)[0].strip() return self.base_data_types_reverse[key] @@ -59,22 +61,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_field_type(self, data_type, description): field_type = super().get_field_type(data_type, description) - if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}: + if description.pk and field_type in { + "BigIntegerField", + "IntegerField", + "SmallIntegerField", + }: # No support for BigAutoField or SmallAutoField as SQLite treats # all integer primary keys as signed 64-bit integers. - return 'AutoField' + return "AutoField" if description.has_json_constraint: - return 'JSONField' + return "JSONField" return field_type def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" # Skip the sqlite_sequence system table used for autoincrement key # generation. - cursor.execute(""" + cursor.execute( + """ SELECT name, type FROM sqlite_master WHERE type in ('table', 'view') AND NOT name='sqlite_sequence' - ORDER BY name""") + ORDER BY name""" + ) return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()] def get_table_description(self, cursor, table_name): @@ -82,37 +90,51 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a description of the table with the DB-API cursor.description interface. """ - cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)) + cursor.execute( + "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) + ) table_info = cursor.fetchall() if not table_info: - raise DatabaseError(f'Table {table_name} does not exist (empty pragma).') + raise DatabaseError(f"Table {table_name} does not exist (empty pragma).") collations = self._get_column_collations(cursor, table_name) json_columns = set() if self.connection.features.can_introspect_json_field: for line in table_info: column = line[1] json_constraint_sql = '%%json_valid("%s")%%' % column - has_json_constraint = cursor.execute(""" + has_json_constraint = cursor.execute( + """ SELECT sql FROM sqlite_master WHERE type = 'table' AND name = %s AND sql LIKE %s - """, [table_name, json_constraint_sql]).fetchone() + """, + [table_name, json_constraint_sql], + ).fetchone() if has_json_constraint: json_columns.add(column) return [ FieldInfo( - name, data_type, None, get_field_size(data_type), None, None, - not notnull, default, collations.get(name), pk == 1, name in json_columns + name, + data_type, + None, + get_field_size(data_type), + None, + None, + not notnull, + default, + collations.get(name), + pk == 1, + name in json_columns, ) for cid, name, data_type, notnull, default, pk in table_info ] def get_sequences(self, cursor, table_name, table_fields=()): pk_col = self.get_primary_key_column(cursor, table_name) - return [{'table': table_name, 'column': pk_col}] + return [{"table": table_name, "column": pk_col}] def get_relations(self, cursor, table_name): """ @@ -120,7 +142,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): representing all foreign keys in the given table. """ cursor.execute( - 'PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name) + "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name) ) return { column_name: (ref_column_name, ref_table_name) @@ -130,7 +152,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_primary_key_column(self, cursor, table_name): """Return the column name of the primary key for the given table.""" cursor.execute( - 'PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name) + "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) ) for _, name, *_, pk in cursor.fetchall(): if pk: @@ -148,19 +170,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): check_columns = [] braces_deep = 0 for token in tokens: - if token.match(sqlparse.tokens.Punctuation, '('): + if token.match(sqlparse.tokens.Punctuation, "("): braces_deep += 1 - elif token.match(sqlparse.tokens.Punctuation, ')'): + elif token.match(sqlparse.tokens.Punctuation, ")"): braces_deep -= 1 if braces_deep < 0: # End of columns and constraints for table definition. break - elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','): + elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","): # End of current column or constraint definition. break # Detect column or constraint definition by first token. if is_constraint_definition is None: - is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT') + is_constraint_definition = token.match( + sqlparse.tokens.Keyword, "CONSTRAINT" + ) if is_constraint_definition: continue if is_constraint_definition: @@ -171,7 +195,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): elif token.ttype == sqlparse.tokens.Literal.String.Symbol: constraint_name = token.value[1:-1] # Start constraint columns parsing after UNIQUE keyword. - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + if token.match(sqlparse.tokens.Keyword, "UNIQUE"): unique = True unique_braces_deep = braces_deep elif unique: @@ -191,10 +215,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): field_name = token.value elif token.ttype == sqlparse.tokens.Literal.String.Symbol: field_name = token.value[1:-1] - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + if token.match(sqlparse.tokens.Keyword, "UNIQUE"): unique_columns = [field_name] # Start constraint columns parsing after CHECK keyword. - if token.match(sqlparse.tokens.Keyword, 'CHECK'): + if token.match(sqlparse.tokens.Keyword, "CHECK"): check = True check_braces_deep = braces_deep elif check: @@ -209,22 +233,30 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): elif token.ttype == sqlparse.tokens.Literal.String.Symbol: if token.value[1:-1] in columns: check_columns.append(token.value[1:-1]) - unique_constraint = { - 'unique': True, - 'columns': unique_columns, - 'primary_key': False, - 'foreign_key': None, - 'check': False, - 'index': False, - } if unique_columns else None - check_constraint = { - 'check': True, - 'columns': check_columns, - 'primary_key': False, - 'unique': False, - 'foreign_key': None, - 'index': False, - } if check_columns else None + unique_constraint = ( + { + "unique": True, + "columns": unique_columns, + "primary_key": False, + "foreign_key": None, + "check": False, + "index": False, + } + if unique_columns + else None + ) + check_constraint = ( + { + "check": True, + "columns": check_columns, + "primary_key": False, + "unique": False, + "foreign_key": None, + "index": False, + } + if check_columns + else None + ) return constraint_name, unique_constraint, check_constraint, token def _parse_table_constraints(self, sql, columns): @@ -236,24 +268,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): tokens = (token for token in statement.flatten() if not token.is_whitespace) # Go to columns and constraint definition for token in tokens: - if token.match(sqlparse.tokens.Punctuation, '('): + if token.match(sqlparse.tokens.Punctuation, "("): break # Parse columns and constraint definition while True: - constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns) + ( + constraint_name, + unique, + check, + end_token, + ) = self._parse_column_or_constraint_definition(tokens, columns) if unique: if constraint_name: constraints[constraint_name] = unique else: unnamed_constrains_index += 1 - constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique + constraints[ + "__unnamed_constraint_%s__" % unnamed_constrains_index + ] = unique if check: if constraint_name: constraints[constraint_name] = check else: unnamed_constrains_index += 1 - constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check - if end_token.match(sqlparse.tokens.Punctuation, ')'): + constraints[ + "__unnamed_constraint_%s__" % unnamed_constrains_index + ] = check + if end_token.match(sqlparse.tokens.Punctuation, ")"): break return constraints @@ -266,19 +307,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Find inline check constraints. try: table_schema = cursor.execute( - "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % ( - self.connection.ops.quote_name(table_name), - ) + "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" + % (self.connection.ops.quote_name(table_name),) ).fetchone()[0] except TypeError: # table_name is a view. pass else: - columns = {info.name for info in self.get_table_description(cursor, table_name)} + columns = { + info.name for info in self.get_table_description(cursor, table_name) + } constraints.update(self._parse_table_constraints(table_schema, columns)) # Get the index info - cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name) + ) for row in cursor.fetchall(): # SQLite 3.8.9+ has 5 columns, however older versions only give 3 # columns. Discard last 2 columns if there. @@ -288,7 +332,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index) ) # There's at most one row. - sql, = cursor.fetchone() or (None,) + (sql,) = cursor.fetchone() or (None,) # Inline constraints are already detected in # _parse_table_constraints(). The reasons to avoid fetching inline # constraints from `PRAGMA index_list` are: @@ -299,7 +343,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # An inline constraint continue # Get the index info for that index - cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) + cursor.execute( + "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index) + ) for index_rank, column_rank, column in cursor.fetchall(): if index not in constraints: constraints[index] = { @@ -310,14 +356,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "check": False, "index": True, } - constraints[index]['columns'].append(column) + constraints[index]["columns"].append(column) # Add type and column orders for indexes - if constraints[index]['index']: + if constraints[index]["index"]: # SQLite doesn't support any index type other than b-tree - constraints[index]['type'] = Index.suffix + constraints[index]["type"] = Index.suffix orders = self._get_index_columns_orders(sql) if orders is not None: - constraints[index]['orders'] = orders + constraints[index]["orders"] = orders # Get the PK pk_column = self.get_primary_key_column(cursor, table_name) if pk_column: @@ -334,44 +380,49 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "index": False, } relations = enumerate(self.get_relations(cursor, table_name).items()) - constraints.update({ - f'fk_{index}': { - 'columns': [column_name], - 'primary_key': False, - 'unique': False, - 'foreign_key': (ref_table_name, ref_column_name), - 'check': False, - 'index': False, + constraints.update( + { + f"fk_{index}": { + "columns": [column_name], + "primary_key": False, + "unique": False, + "foreign_key": (ref_table_name, ref_column_name), + "check": False, + "index": False, + } + for index, (column_name, (ref_column_name, ref_table_name)) in relations } - for index, (column_name, (ref_column_name, ref_table_name)) in relations - }) + ) return constraints def _get_index_columns_orders(self, sql): tokens = sqlparse.parse(sql)[0] for token in tokens: if isinstance(token, sqlparse.sql.Parenthesis): - columns = str(token).strip('()').split(', ') - return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns] + columns = str(token).strip("()").split(", ") + return ["DESC" if info.endswith("DESC") else "ASC" for info in columns] return None def _get_column_collations(self, cursor, table_name): - row = cursor.execute(""" + row = cursor.execute( + """ SELECT sql FROM sqlite_master WHERE type = 'table' AND name = %s - """, [table_name]).fetchone() + """, + [table_name], + ).fetchone() if not row: return {} sql = row[0] - columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ') + columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ") collations = {} for column in columns: tokens = column[1:].split() column_name = tokens[0].strip('"') for index, token in enumerate(tokens): - if token == 'COLLATE': + if token == "COLLATE": collation = tokens[index + 1] break else: |
