summaryrefslogtreecommitdiff
path: root/django/db/backends/postgresql_psycopg2/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/backends/postgresql_psycopg2/base.py')
-rw-r--r--django/db/backends/postgresql_psycopg2/base.py102
1 files changed, 99 insertions, 3 deletions
diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
index 04322332dc..d9ad363ac1 100644
--- a/django/db/backends/postgresql_psycopg2/base.py
+++ b/django/db/backends/postgresql_psycopg2/base.py
@@ -12,6 +12,7 @@ except ImportError, e:
raise ImproperlyConfigured, "Error loading psycopg2 module: %s" % e
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
try:
# Only exists in Python 2.4+
@@ -20,6 +21,8 @@ except ImportError:
# Import copy of _thread_local.py from Python 2.4
from django.utils._threading_local import local
+postgres_version = None
+
class DatabaseWrapper(local):
def __init__(self, **kwargs):
self.connection = None
@@ -28,7 +31,9 @@ class DatabaseWrapper(local):
def cursor(self):
from django.conf import settings
+ set_tz = False
if self.connection is None:
+ set_tz = True
if settings.DATABASE_NAME == '':
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file."
@@ -45,16 +50,22 @@ class DatabaseWrapper(local):
self.connection.set_isolation_level(1) # make transactions transparent to all cursors
cursor = self.connection.cursor()
cursor.tzinfo_factory = None
- cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ if set_tz:
+ cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ global postgres_version
+ if not postgres_version:
+ cursor.execute("SELECT version()")
+ postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')]
if settings.DEBUG:
return util.CursorDebugWrapper(cursor, self)
return cursor
def _commit(self):
- return self.connection.commit()
+ if self.connection is not None:
+ return self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
return self.connection.rollback()
def close(self):
@@ -96,6 +107,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RANDOM()"
+def get_deferrable_sql():
+ return " DEFERRABLE INITIALLY DEFERRED"
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -105,6 +119,88 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+ """
+ if tables:
+ if postgres_version[0] >= 8 and postgres_version[1] >= 1:
+ # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to
+ # truncate tables referenced by a foreign key in any other table. The result is a
+ # single SQL TRUNCATE statement
+ sql = ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(', '.join([quote_name(table) for table in tables]))
+ )]
+ else:
+ sql = ['%s %s %s;' % \
+ (style.SQL_KEYWORD('DELETE'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+
+ # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
+ # to reset sequence indices
+ for sequence in sequences:
+ table_name = sequence['table']
+ column_name = sequence['column']
+ if column_name and len(column_name) > 0:
+ # sequence name in this case will be <table>_<column>_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ else:
+ # sequence name in this case will be <table>_id_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % table_name)),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ return sql
+ else:
+ return []
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ from django.db import models
+ output = []
+ for model in model_list:
+ # Use `coalesce` to set the sequence for each model to the max pk value if there are records,
+ # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
+ # if there are records (as the max pk value is already in use), otherwise set it to false.
+ for f in model._meta.fields:
+ if isinstance(f, models.AutoField):
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(quote_name(model._meta.db_table))))
+ break # Only one AutoField is allowed per model, so don't bother continuing.
+ for f in model._meta.many_to_many:
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(f.m2m_db_table())))
+ return output
+
OPERATOR_MAPPING = {
'exact': '= %s',
'iexact': 'ILIKE %s',