diff options
Diffstat (limited to 'django/db/backends/postgresql')
| -rw-r--r-- | django/db/backends/postgresql/base.py | 139 | ||||
| -rw-r--r-- | django/db/backends/postgresql/creation.py | 4 | ||||
| -rw-r--r-- | django/db/backends/postgresql/introspection.py | 3 |
3 files changed, 140 insertions, 6 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index e44bc0b560..fedbb6b7f1 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -12,6 +12,7 @@ except ImportError, e: raise ImproperlyConfigured, "Error loading psycopg module: %s" % e DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError try: # Only exists in Python 2.4+ @@ -20,6 +21,40 @@ except ImportError: # Import copy of _thread_local.py from Python 2.4 from django.utils._threading_local import local +def smart_basestring(s, charset): + if isinstance(s, unicode): + return s.encode(charset) + return s + +class UnicodeCursorWrapper(object): + """ + A thin wrapper around psycopg cursors that allows them to accept Unicode + strings as params. + + This is necessary because psycopg doesn't apply any DB quoting to + parameters that are Unicode strings. If a param is Unicode, this will + convert it to a bytestring using DEFAULT_CHARSET before passing it to + psycopg. + """ + def __init__(self, cursor, charset): + self.cursor = cursor + self.charset = charset + + def execute(self, sql, params=()): + return self.cursor.execute(sql, [smart_basestring(p, self.charset) for p in params]) + + def executemany(self, sql, param_list): + new_param_list = [tuple([smart_basestring(p, self.charset) for p in params]) for params in param_list] + return self.cursor.executemany(sql, new_param_list) + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + else: + return getattr(self.cursor, attr) + +postgres_version = None + class DatabaseWrapper(local): def __init__(self, **kwargs): self.connection = None @@ -28,7 +63,9 @@ class DatabaseWrapper(local): def cursor(self): from django.conf import settings + set_tz = False if self.connection is None: + set_tz = True if settings.DATABASE_NAME == '': from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file." @@ -44,16 +81,23 @@ class DatabaseWrapper(local): self.connection = Database.connect(conn_string, **self.options) self.connection.set_isolation_level(1) # make transactions transparent to all cursors cursor = self.connection.cursor() - cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + if set_tz: + cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET) + global postgres_version + if not postgres_version: + cursor.execute("SELECT version()") + postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')] if settings.DEBUG: return util.CursorDebugWrapper(cursor, self) return cursor def _commit(self): - return self.connection.commit() + if self.connection is not None: + return self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: return self.connection.rollback() def close(self): @@ -103,6 +147,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "RANDOM()" +def get_deferrable_sql(): + return " DEFERRABLE INITIALLY DEFERRED" + def get_fulltext_search_sql(field_name): raise NotImplementedError @@ -112,6 +159,91 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "DEFAULT" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + + """ + if tables: + if postgres_version[0] >= 8 and postgres_version[1] >= 1: + # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to + # truncate tables referenced by a foreign key in any other table. The result is a + # single SQL TRUNCATE statement. + sql = ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(', '.join([quote_name(table) for table in tables])) + )] + else: + # Older versions of Postgres can't do TRUNCATE in a single call, so they must use + # a simple delete. + sql = ['%s %s %s;' % \ + (style.SQL_KEYWORD('DELETE'), + style.SQL_KEYWORD('FROM'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + + # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements + # to reset sequence indices + for sequence_info in sequences: + table_name = sequence_info['table'] + column_name = sequence_info['column'] + if column_name and len(column_name)>0: + # sequence name in this case will be <table>_<column>_seq + sql.append("%s %s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('SEQUENCE'), + style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))), + style.SQL_KEYWORD('RESTART'), + style.SQL_KEYWORD('WITH'), + style.SQL_FIELD('1') + ) + ) + else: + # sequence name in this case will be <table>_id_seq + sql.append("%s %s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('SEQUENCE'), + style.SQL_FIELD(quote_name('%s_id_seq' % table_name)), + style.SQL_KEYWORD('RESTART'), + style.SQL_KEYWORD('WITH'), + style.SQL_FIELD('1') + ) + ) + return sql + else: + return [] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + from django.db import models + output = [] + for model in model_list: + # Use `coalesce` to set the sequence for each model to the max pk value if there are records, + # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true + # if there are records (as the max pk value is already in use), otherwise set it to false. + for f in model._meta.fields: + if isinstance(f, models.AutoField): + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))), + style.SQL_FIELD(quote_name(f.column)), + style.SQL_FIELD(quote_name(f.column)), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(quote_name(model._meta.db_table)))) + break # Only one AutoField is allowed per model, so don't bother continuing. + for f in model._meta.many_to_many: + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())), + style.SQL_FIELD(quote_name('id')), + style.SQL_FIELD(quote_name('id')), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(f.m2m_db_table()))) + return output + # Register these custom typecasts, because Django expects dates/times to be # in Python's native (standard-library) datetime/time format, whereas psycopg # use mx.DateTime by default. @@ -122,6 +254,7 @@ except AttributeError: Database.register_type(Database.new_type((1083,1266), "TIME", util.typecast_time)) Database.register_type(Database.new_type((1114,1184), "TIMESTAMP", util.typecast_timestamp)) Database.register_type(Database.new_type((16,), "BOOLEAN", util.typecast_boolean)) +Database.register_type(Database.new_type((1700,), "NUMERIC", util.typecast_decimal)) OPERATOR_MAPPING = { 'exact': '= %s', diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py index 65a804ec40..4646b68ab8 100644 --- a/django/db/backends/postgresql/creation.py +++ b/django/db/backends/postgresql/creation.py @@ -9,9 +9,10 @@ DATA_TYPES = { 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)', 'DateField': 'date', 'DateTimeField': 'timestamp with time zone', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', 'FileField': 'varchar(100)', 'FilePathField': 'varchar(100)', - 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FloatField': 'double precision', 'ImageField': 'varchar(100)', 'IntegerField': 'integer', 'IPAddressField': 'inet', @@ -25,6 +26,5 @@ DATA_TYPES = { 'SmallIntegerField': 'smallint', 'TextField': 'text', 'TimeField': 'time', - 'URLField': 'varchar(200)', 'USStateField': 'varchar(2)', } diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index 6e1d60c4ff..2605490afd 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -72,6 +72,7 @@ DATA_TYPES_REVERSE = { 21: 'SmallIntegerField', 23: 'IntegerField', 25: 'TextField', + 701: 'FloatField', 869: 'IPAddressField', 1043: 'CharField', 1082: 'DateField', @@ -79,5 +80,5 @@ DATA_TYPES_REVERSE = { 1114: 'DateTimeField', 1184: 'DateTimeField', 1266: 'TimeField', - 1700: 'FloatField', + 1700: 'DecimalField', } |
