summaryrefslogtreecommitdiff
path: root/django/db/backends/postgresql
diff options
context:
space:
mode:
authorChristopher Long <indirecthit@gmail.com>2007-06-17 22:18:54 +0000
committerChristopher Long <indirecthit@gmail.com>2007-06-17 22:18:54 +0000
commitae22b6d403dcf25098c77f0dfcf59ae58b186461 (patch)
treec37fc631e99a7e4d909d6b6d236f495003731ea7 /django/db/backends/postgresql
parent0cf7bc439129c66df8d64601e885f83b256b4f25 (diff)
per-object-permissions: Merged to trunk [5486] NOTE: Not fully tested, will be working on this over the next few weeks.
git-svn-id: http://code.djangoproject.com/svn/django/branches/per-object-permissions@5488 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/backends/postgresql')
-rw-r--r--django/db/backends/postgresql/base.py139
-rw-r--r--django/db/backends/postgresql/creation.py4
-rw-r--r--django/db/backends/postgresql/introspection.py3
3 files changed, 140 insertions, 6 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index e44bc0b560..fedbb6b7f1 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -12,6 +12,7 @@ except ImportError, e:
raise ImproperlyConfigured, "Error loading psycopg module: %s" % e
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
try:
# Only exists in Python 2.4+
@@ -20,6 +21,40 @@ except ImportError:
# Import copy of _thread_local.py from Python 2.4
from django.utils._threading_local import local
+def smart_basestring(s, charset):
+ if isinstance(s, unicode):
+ return s.encode(charset)
+ return s
+
+class UnicodeCursorWrapper(object):
+ """
+ A thin wrapper around psycopg cursors that allows them to accept Unicode
+ strings as params.
+
+ This is necessary because psycopg doesn't apply any DB quoting to
+ parameters that are Unicode strings. If a param is Unicode, this will
+ convert it to a bytestring using DEFAULT_CHARSET before passing it to
+ psycopg.
+ """
+ def __init__(self, cursor, charset):
+ self.cursor = cursor
+ self.charset = charset
+
+ def execute(self, sql, params=()):
+ return self.cursor.execute(sql, [smart_basestring(p, self.charset) for p in params])
+
+ def executemany(self, sql, param_list):
+ new_param_list = [tuple([smart_basestring(p, self.charset) for p in params]) for params in param_list]
+ return self.cursor.executemany(sql, new_param_list)
+
+ def __getattr__(self, attr):
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ else:
+ return getattr(self.cursor, attr)
+
+postgres_version = None
+
class DatabaseWrapper(local):
def __init__(self, **kwargs):
self.connection = None
@@ -28,7 +63,9 @@ class DatabaseWrapper(local):
def cursor(self):
from django.conf import settings
+ set_tz = False
if self.connection is None:
+ set_tz = True
if settings.DATABASE_NAME == '':
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file."
@@ -44,16 +81,23 @@ class DatabaseWrapper(local):
self.connection = Database.connect(conn_string, **self.options)
self.connection.set_isolation_level(1) # make transactions transparent to all cursors
cursor = self.connection.cursor()
- cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ if set_tz:
+ cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET)
+ global postgres_version
+ if not postgres_version:
+ cursor.execute("SELECT version()")
+ postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')]
if settings.DEBUG:
return util.CursorDebugWrapper(cursor, self)
return cursor
def _commit(self):
- return self.connection.commit()
+ if self.connection is not None:
+ return self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
return self.connection.rollback()
def close(self):
@@ -103,6 +147,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RANDOM()"
+def get_deferrable_sql():
+ return " DEFERRABLE INITIALLY DEFERRED"
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -112,6 +159,91 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+
+ """
+ if tables:
+ if postgres_version[0] >= 8 and postgres_version[1] >= 1:
+ # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to
+ # truncate tables referenced by a foreign key in any other table. The result is a
+ # single SQL TRUNCATE statement.
+ sql = ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(', '.join([quote_name(table) for table in tables]))
+ )]
+ else:
+ # Older versions of Postgres can't do TRUNCATE in a single call, so they must use
+ # a simple delete.
+ sql = ['%s %s %s;' % \
+ (style.SQL_KEYWORD('DELETE'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+
+ # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
+ # to reset sequence indices
+ for sequence_info in sequences:
+ table_name = sequence_info['table']
+ column_name = sequence_info['column']
+ if column_name and len(column_name)>0:
+ # sequence name in this case will be <table>_<column>_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ else:
+ # sequence name in this case will be <table>_id_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % table_name)),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ return sql
+ else:
+ return []
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ from django.db import models
+ output = []
+ for model in model_list:
+ # Use `coalesce` to set the sequence for each model to the max pk value if there are records,
+ # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
+ # if there are records (as the max pk value is already in use), otherwise set it to false.
+ for f in model._meta.fields:
+ if isinstance(f, models.AutoField):
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(quote_name(model._meta.db_table))))
+ break # Only one AutoField is allowed per model, so don't bother continuing.
+ for f in model._meta.many_to_many:
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(f.m2m_db_table())))
+ return output
+
# Register these custom typecasts, because Django expects dates/times to be
# in Python's native (standard-library) datetime/time format, whereas psycopg
# use mx.DateTime by default.
@@ -122,6 +254,7 @@ except AttributeError:
Database.register_type(Database.new_type((1083,1266), "TIME", util.typecast_time))
Database.register_type(Database.new_type((1114,1184), "TIMESTAMP", util.typecast_timestamp))
Database.register_type(Database.new_type((16,), "BOOLEAN", util.typecast_boolean))
+Database.register_type(Database.new_type((1700,), "NUMERIC", util.typecast_decimal))
OPERATOR_MAPPING = {
'exact': '= %s',
diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py
index 65a804ec40..4646b68ab8 100644
--- a/django/db/backends/postgresql/creation.py
+++ b/django/db/backends/postgresql/creation.py
@@ -9,9 +9,10 @@ DATA_TYPES = {
'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
+ 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(100)',
'FilePathField': 'varchar(100)',
- 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)',
+ 'FloatField': 'double precision',
'ImageField': 'varchar(100)',
'IntegerField': 'integer',
'IPAddressField': 'inet',
@@ -25,6 +26,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py
index 6e1d60c4ff..2605490afd 100644
--- a/django/db/backends/postgresql/introspection.py
+++ b/django/db/backends/postgresql/introspection.py
@@ -72,6 +72,7 @@ DATA_TYPES_REVERSE = {
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
+ 701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
@@ -79,5 +80,5 @@ DATA_TYPES_REVERSE = {
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
- 1700: 'FloatField',
+ 1700: 'DecimalField',
}