summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
Diffstat (limited to 'django/db')
-rw-r--r--django/db/__init__.py3
-rw-r--r--django/db/backends/ado_mssql/base.py27
-rw-r--r--django/db/backends/ado_mssql/creation.py4
-rw-r--r--django/db/backends/dummy/base.py12
-rw-r--r--django/db/backends/mysql/base.py118
-rw-r--r--django/db/backends/mysql/client.py29
-rw-r--r--django/db/backends/mysql/creation.py4
-rw-r--r--django/db/backends/mysql/introspection.py4
-rw-r--r--django/db/backends/mysql_old/__init__.py0
-rw-r--r--django/db/backends/mysql_old/base.py240
-rw-r--r--django/db/backends/mysql_old/client.py14
-rw-r--r--django/db/backends/mysql_old/creation.py30
-rw-r--r--django/db/backends/mysql_old/introspection.py95
-rw-r--r--django/db/backends/oracle/base.py27
-rw-r--r--django/db/backends/oracle/creation.py4
-rw-r--r--django/db/backends/oracle/introspection.py2
-rw-r--r--django/db/backends/postgresql/base.py139
-rw-r--r--django/db/backends/postgresql/creation.py4
-rw-r--r--django/db/backends/postgresql/introspection.py3
-rw-r--r--django/db/backends/postgresql_psycopg2/base.py102
-rw-r--r--django/db/backends/postgresql_psycopg2/introspection.py3
-rw-r--r--django/db/backends/sqlite3/base.py39
-rw-r--r--django/db/backends/sqlite3/creation.py4
-rw-r--r--django/db/backends/util.py17
-rw-r--r--django/db/models/__init__.py25
-rw-r--r--django/db/models/base.py114
-rw-r--r--django/db/models/fields/__init__.py205
-rw-r--r--django/db/models/fields/generic.py259
-rw-r--r--django/db/models/fields/related.py55
-rw-r--r--django/db/models/loading.py2
-rw-r--r--django/db/models/manager.py8
-rw-r--r--django/db/models/manipulators.py10
-rw-r--r--django/db/models/options.py5
-rw-r--r--django/db/models/query.py229
-rw-r--r--django/db/models/related.py17
-rw-r--r--django/db/transaction.py12
36 files changed, 1351 insertions, 514 deletions
diff --git a/django/db/__init__.py b/django/db/__init__.py
index 4176b5aa79..33223d200a 100644
--- a/django/db/__init__.py
+++ b/django/db/__init__.py
@@ -2,7 +2,7 @@ from django.conf import settings
from django.core import signals
from django.dispatch import dispatcher
-__all__ = ('backend', 'connection', 'DatabaseError')
+__all__ = ('backend', 'connection', 'DatabaseError', 'IntegrityError')
if not settings.DATABASE_ENGINE:
settings.DATABASE_ENGINE = 'dummy'
@@ -29,6 +29,7 @@ runshell = lambda: __import__('django.db.backends.%s.client' % settings.DATABASE
connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS)
DatabaseError = backend.DatabaseError
+IntegrityError = backend.IntegrityError
# Register an event that closes the database connection
# when a Django request is finished.
diff --git a/django/db/backends/ado_mssql/base.py b/django/db/backends/ado_mssql/base.py
index 72d2fe083e..52363ed705 100644
--- a/django/db/backends/ado_mssql/base.py
+++ b/django/db/backends/ado_mssql/base.py
@@ -17,6 +17,7 @@ except ImportError:
mx = None
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
# We need to use a special Cursor class because adodbapi expects question-mark
# param style, but Django expects "%s". This cursor converts question marks to
@@ -76,10 +77,11 @@ class DatabaseWrapper(local):
return cursor
def _commit(self):
- return self.connection.commit()
+ if self.connection is not None:
+ return self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
return self.connection.rollback()
def close(self):
@@ -125,6 +127,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RAND()"
+def get_deferrable_sql():
+ return " DEFERRABLE INITIALLY DEFERRED"
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -134,6 +139,24 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+ """
+ # Return a list of 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements
+ # TODO - SQL not actually tested against ADO MSSQL yet!
+ # TODO - autoincrement indices reset required? See other get_sql_flush() implementations
+ sql_list = ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ # No sequence reset required
+ return []
+
OPERATOR_MAPPING = {
'exact': '= %s',
'iexact': 'LIKE %s',
diff --git a/django/db/backends/ado_mssql/creation.py b/django/db/backends/ado_mssql/creation.py
index 4d85d27ea5..a1098ea43e 100644
--- a/django/db/backends/ado_mssql/creation.py
+++ b/django/db/backends/ado_mssql/creation.py
@@ -5,9 +5,10 @@ DATA_TYPES = {
'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)',
'DateField': 'smalldatetime',
'DateTimeField': 'smalldatetime',
+ 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(100)',
'FilePathField': 'varchar(100)',
- 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)',
+ 'FloatField': 'double precision',
'ImageField': 'varchar(100)',
'IntegerField': 'int',
'IPAddressField': 'char(15)',
@@ -21,6 +22,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py
index f98afc48bb..d0ec897407 100644
--- a/django/db/backends/dummy/base.py
+++ b/django/db/backends/dummy/base.py
@@ -12,13 +12,19 @@ from django.core.exceptions import ImproperlyConfigured
def complain(*args, **kwargs):
raise ImproperlyConfigured, "You haven't set the DATABASE_ENGINE setting yet."
+def ignore(*args, **kwargs):
+ pass
+
class DatabaseError(Exception):
pass
+class IntegrityError(DatabaseError):
+ pass
+
class DatabaseWrapper:
cursor = complain
_commit = complain
- _rollback = complain
+ _rollback = ignore
def __init__(self, **kwargs):
pass
@@ -36,6 +42,10 @@ get_date_extract_sql = complain
get_date_trunc_sql = complain
get_limit_offset_sql = complain
get_random_function_sql = complain
+get_deferrable_sql = complain
get_fulltext_search_sql = complain
get_drop_foreignkey_sql = complain
+get_sql_flush = complain
+get_sql_sequence_reset = complain
+
OPERATOR_MAPPING = {}
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index e7e060e6c2..d4cb1fa964 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -10,19 +10,34 @@ try:
except ImportError, e:
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured, "Error loading MySQLdb module: %s" % e
+
+# We want version (1, 2, 1, 'final', 2) or later. We can't just use
+# lexicographic ordering in this check because then (1, 2, 1, 'gamma')
+# inadvertently passes the version test.
+version = Database.version_info
+if (version < (1,2,1) or (version[:3] == (1, 2, 1) and
+ (len(version) < 5 or version[3] != 'final' or version[4] < 2))):
+ raise ImportError, "MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__
+
from MySQLdb.converters import conversions
from MySQLdb.constants import FIELD_TYPE
import types
import re
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
+# MySQLdb-1.2.1 supports the Python boolean type, and only uses datetime
+# module for time-related columns; older versions could have used mx.DateTime
+# or strings if there were no datetime module. However, MySQLdb still returns
+# TIME columns as timedelta -- they are more like timedelta in terms of actual
+# behavior as they are signed and include days -- and Django expects time, so
+# we still need to override that.
django_conversions = conversions.copy()
django_conversions.update({
- types.BooleanType: util.rev_typecast_boolean,
- FIELD_TYPE.DATETIME: util.typecast_timestamp,
- FIELD_TYPE.DATE: util.typecast_date,
FIELD_TYPE.TIME: util.typecast_time,
+ FIELD_TYPE.DECIMAL: util.typecast_decimal,
+ FIELD_TYPE.NEWDECIMAL: util.typecast_decimal,
})
# This should match the numerical portion of the version numbers (we can treat
@@ -31,31 +46,12 @@ django_conversions.update({
# http://dev.mysql.com/doc/refman/5.0/en/news.html .
server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
-# This is an extra debug layer over MySQL queries, to display warnings.
-# It's only used when DEBUG=True.
-class MysqlDebugWrapper:
- def __init__(self, cursor):
- self.cursor = cursor
-
- def execute(self, sql, params=()):
- try:
- return self.cursor.execute(sql, params)
- except Database.Warning, w:
- self.cursor.execute("SHOW WARNINGS")
- raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall())
-
- def executemany(self, sql, param_list):
- try:
- return self.cursor.executemany(sql, param_list)
- except Database.Warning, w:
- self.cursor.execute("SHOW WARNINGS")
- raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall())
-
- def __getattr__(self, attr):
- if self.__dict__.has_key(attr):
- return self.__dict__[attr]
- else:
- return getattr(self.cursor, attr)
+# MySQLdb-1.2.1 and newer automatically makes use of SHOW WARNINGS on
+# MySQL-4.1 and newer, so the MysqlDebugWrapper is unnecessary. Since the
+# point is to raise Warnings as exceptions, this can be done with the Python
+# warning module, and this is setup when the connection is created, and the
+# standard util.CursorDebugWrapper can be used. Also, using sql_mode
+# TRADITIONAL will automatically cause most warnings to be treated as errors.
try:
# Only exists in Python 2.4+
@@ -83,33 +79,41 @@ class DatabaseWrapper(local):
def cursor(self):
from django.conf import settings
+ from warnings import filterwarnings
if not self._valid_connection():
kwargs = {
- 'user': settings.DATABASE_USER,
- 'db': settings.DATABASE_NAME,
- 'passwd': settings.DATABASE_PASSWORD,
'conv': django_conversions,
+ 'charset': 'utf8',
+ 'use_unicode': False,
}
+ if settings.DATABASE_USER:
+ kwargs['user'] = settings.DATABASE_USER
+ if settings.DATABASE_NAME:
+ kwargs['db'] = settings.DATABASE_NAME
+ if settings.DATABASE_PASSWORD:
+ kwargs['passwd'] = settings.DATABASE_PASSWORD
if settings.DATABASE_HOST.startswith('/'):
kwargs['unix_socket'] = settings.DATABASE_HOST
- else:
+ elif settings.DATABASE_HOST:
kwargs['host'] = settings.DATABASE_HOST
if settings.DATABASE_PORT:
kwargs['port'] = int(settings.DATABASE_PORT)
kwargs.update(self.options)
self.connection = Database.connect(**kwargs)
- cursor = self.connection.cursor()
- if self.connection.get_server_info() >= '4.1':
- cursor.execute("SET NAMES 'utf8'")
+ cursor = self.connection.cursor()
+ else:
+ cursor = self.connection.cursor()
if settings.DEBUG:
- return util.CursorDebugWrapper(MysqlDebugWrapper(cursor), self)
+ filterwarnings("error", category=Database.Warning)
+ return util.CursorDebugWrapper(cursor, self)
return cursor
def _commit(self):
- self.connection.commit()
+ if self.connection is not None:
+ self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
try:
self.connection.rollback()
except Database.NotSupportedError:
@@ -172,6 +176,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RAND()"
+def get_deferrable_sql():
+ return ""
+
def get_fulltext_search_sql(field_name):
return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
@@ -181,6 +188,41 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+
+ """
+ # NB: The generated SQL below is specific to MySQL
+ # 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements
+ # to clear all tables of all data
+ if tables:
+ sql = ['SET FOREIGN_KEY_CHECKS = 0;'] + \
+ ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables] + \
+ ['SET FOREIGN_KEY_CHECKS = 1;']
+
+ # 'ALTER TABLE table AUTO_INCREMENT = 1;'... style SQL statements
+ # to reset sequence indices
+ sql.extend(["%s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('TABLE'),
+ style.SQL_TABLE(quote_name(sequence['table'])),
+ style.SQL_KEYWORD('AUTO_INCREMENT'),
+ style.SQL_FIELD('= 1'),
+ ) for sequence in sequences])
+ return sql
+ else:
+ return []
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ # No sequence reset required
+ return []
+
OPERATOR_MAPPING = {
'exact': '= %s',
'iexact': 'LIKE %s',
diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py
index f9d6297b8e..116074a9ce 100644
--- a/django/db/backends/mysql/client.py
+++ b/django/db/backends/mysql/client.py
@@ -3,12 +3,25 @@ import os
def runshell():
args = ['']
- args += ["--user=%s" % settings.DATABASE_USER]
- if settings.DATABASE_PASSWORD:
- args += ["--password=%s" % settings.DATABASE_PASSWORD]
- if settings.DATABASE_HOST:
- args += ["--host=%s" % settings.DATABASE_HOST]
- if settings.DATABASE_PORT:
- args += ["--port=%s" % settings.DATABASE_PORT]
- args += [settings.DATABASE_NAME]
+ db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME)
+ user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER)
+ passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD)
+ host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST)
+ port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT)
+ defaults_file = settings.DATABASE_OPTIONS.get('read_default_file')
+ # Seems to be no good way to set sql_mode with CLI
+
+ if defaults_file:
+ args += ["--defaults-file=%s" % defaults_file]
+ if user:
+ args += ["--user=%s" % user]
+ if passwd:
+ args += ["--password=%s" % passwd]
+ if host:
+ args += ["--host=%s" % host]
+ if port:
+ args += ["--port=%s" % port]
+ if db:
+ args += [db]
+
os.execvp('mysql', args)
diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py
index 116b490124..1b23fbff6e 100644
--- a/django/db/backends/mysql/creation.py
+++ b/django/db/backends/mysql/creation.py
@@ -9,9 +9,10 @@ DATA_TYPES = {
'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
+ 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(100)',
'FilePathField': 'varchar(100)',
- 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)',
+ 'FloatField': 'double precision',
'ImageField': 'varchar(100)',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
@@ -25,6 +26,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py
index 7829457fa9..39733311c5 100644
--- a/django/db/backends/mysql/introspection.py
+++ b/django/db/backends/mysql/introspection.py
@@ -76,7 +76,7 @@ def get_indexes(cursor, table_name):
DATA_TYPES_REVERSE = {
FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField',
- FIELD_TYPE.DECIMAL: 'FloatField',
+ FIELD_TYPE.DECIMAL: 'DecimalField',
FIELD_TYPE.DATE: 'DateField',
FIELD_TYPE.DATETIME: 'DateTimeField',
FIELD_TYPE.DOUBLE: 'FloatField',
@@ -85,7 +85,7 @@ DATA_TYPES_REVERSE = {
FIELD_TYPE.LONG: 'IntegerField',
FIELD_TYPE.LONGLONG: 'IntegerField',
FIELD_TYPE.SHORT: 'IntegerField',
- FIELD_TYPE.STRING: 'TextField',
+ FIELD_TYPE.STRING: 'CharField',
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
FIELD_TYPE.TINY: 'IntegerField',
FIELD_TYPE.TINY_BLOB: 'TextField',
diff --git a/django/db/backends/mysql_old/__init__.py b/django/db/backends/mysql_old/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/django/db/backends/mysql_old/__init__.py
diff --git a/django/db/backends/mysql_old/base.py b/django/db/backends/mysql_old/base.py
new file mode 100644
index 0000000000..ac3b75efde
--- /dev/null
+++ b/django/db/backends/mysql_old/base.py
@@ -0,0 +1,240 @@
+"""
+MySQL database backend for Django.
+
+Requires MySQLdb: http://sourceforge.net/projects/mysql-python
+"""
+
+from django.db.backends import util
+try:
+ import MySQLdb as Database
+except ImportError, e:
+ from django.core.exceptions import ImproperlyConfigured
+ raise ImproperlyConfigured, "Error loading MySQLdb module: %s" % e
+from MySQLdb.converters import conversions
+from MySQLdb.constants import FIELD_TYPE
+import types
+import re
+
+DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
+
+django_conversions = conversions.copy()
+django_conversions.update({
+ types.BooleanType: util.rev_typecast_boolean,
+ FIELD_TYPE.DATETIME: util.typecast_timestamp,
+ FIELD_TYPE.DATE: util.typecast_date,
+ FIELD_TYPE.TIME: util.typecast_time,
+ FIELD_TYPE.DECIMAL: util.typecast_decimal,
+})
+
+# This should match the numerical portion of the version numbers (we can treat
+# versions like 5.0.24 and 5.0.24a as the same). Based on the list of version
+# at http://dev.mysql.com/doc/refman/4.1/en/news.html and
+# http://dev.mysql.com/doc/refman/5.0/en/news.html .
+server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
+
+# This is an extra debug layer over MySQL queries, to display warnings.
+# It's only used when DEBUG=True.
+class MysqlDebugWrapper:
+ def __init__(self, cursor):
+ self.cursor = cursor
+
+ def execute(self, sql, params=()):
+ try:
+ return self.cursor.execute(sql, params)
+ except Database.Warning, w:
+ self.cursor.execute("SHOW WARNINGS")
+ raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall())
+
+ def executemany(self, sql, param_list):
+ try:
+ return self.cursor.executemany(sql, param_list)
+ except Database.Warning, w:
+ self.cursor.execute("SHOW WARNINGS")
+ raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall())
+
+ def __getattr__(self, attr):
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ else:
+ return getattr(self.cursor, attr)
+
+try:
+ # Only exists in Python 2.4+
+ from threading import local
+except ImportError:
+ # Import copy of _thread_local.py from Python 2.4
+ from django.utils._threading_local import local
+
+class DatabaseWrapper(local):
+ def __init__(self, **kwargs):
+ self.connection = None
+ self.queries = []
+ self.server_version = None
+ self.options = kwargs
+
+ def _valid_connection(self):
+ if self.connection is not None:
+ try:
+ self.connection.ping()
+ return True
+ except DatabaseError:
+ self.connection.close()
+ self.connection = None
+ return False
+
+ def cursor(self):
+ from django.conf import settings
+ if not self._valid_connection():
+ kwargs = {
+ 'user': settings.DATABASE_USER,
+ 'db': settings.DATABASE_NAME,
+ 'passwd': settings.DATABASE_PASSWORD,
+ 'conv': django_conversions,
+ }
+ if settings.DATABASE_HOST.startswith('/'):
+ kwargs['unix_socket'] = settings.DATABASE_HOST
+ else:
+ kwargs['host'] = settings.DATABASE_HOST
+ if settings.DATABASE_PORT:
+ kwargs['port'] = int(settings.DATABASE_PORT)
+ kwargs.update(self.options)
+ self.connection = Database.connect(**kwargs)
+ cursor = self.connection.cursor()
+ if self.connection.get_server_info() >= '4.1':
+ cursor.execute("SET NAMES 'utf8'")
+ else:
+ cursor = self.connection.cursor()
+ if settings.DEBUG:
+ return util.CursorDebugWrapper(MysqlDebugWrapper(cursor), self)
+ return cursor
+
+ def _commit(self):
+ if self.connection is not None:
+ self.connection.commit()
+
+ def _rollback(self):
+ if self.connection is not None:
+ try:
+ self.connection.rollback()
+ except Database.NotSupportedError:
+ pass
+
+ def close(self):
+ if self.connection is not None:
+ self.connection.close()
+ self.connection = None
+
+ def get_server_version(self):
+ if not self.server_version:
+ if not self._valid_connection():
+ self.cursor()
+ m = server_version_re.match(self.connection.get_server_info())
+ if not m:
+ raise Exception('Unable to determine MySQL version from version string %r' % self.connection.get_server_info())
+ self.server_version = tuple([int(x) for x in m.groups()])
+ return self.server_version
+
+supports_constraints = True
+
+def quote_name(name):
+ if name.startswith("`") and name.endswith("`"):
+ return name # Quoting once is enough.
+ return "`%s`" % name
+
+dictfetchone = util.dictfetchone
+dictfetchmany = util.dictfetchmany
+dictfetchall = util.dictfetchall
+
+def get_last_insert_id(cursor, table_name, pk_name):
+ return cursor.lastrowid
+
+def get_date_extract_sql(lookup_type, table_name):
+ # lookup_type is 'year', 'month', 'day'
+ # http://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
+ return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), table_name)
+
+def get_date_trunc_sql(lookup_type, field_name):
+ # lookup_type is 'year', 'month', 'day'
+ fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
+ format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
+ format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
+ try:
+ i = fields.index(lookup_type) + 1
+ except ValueError:
+ sql = field_name
+ else:
+ format_str = ''.join([f for f in format[:i]] + [f for f in format_def[i:]])
+ sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
+ return sql
+
+def get_limit_offset_sql(limit, offset=None):
+ sql = "LIMIT "
+ if offset and offset != 0:
+ sql += "%s," % offset
+ return sql + str(limit)
+
+def get_random_function_sql():
+ return "RAND()"
+
+def get_deferrable_sql():
+ return ""
+
+def get_fulltext_search_sql(field_name):
+ return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
+
+def get_drop_foreignkey_sql():
+ return "DROP FOREIGN KEY"
+
+def get_pk_default_value():
+ return "DEFAULT"
+
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+
+ """
+ # NB: The generated SQL below is specific to MySQL
+ # 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements
+ # to clear all tables of all data
+ if tables:
+ sql = ['SET FOREIGN_KEY_CHECKS = 0;'] + \
+ ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables] + \
+ ['SET FOREIGN_KEY_CHECKS = 1;']
+
+ # 'ALTER TABLE table AUTO_INCREMENT = 1;'... style SQL statements
+ # to reset sequence indices
+ sql.extend(["%s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('TABLE'),
+ style.SQL_TABLE(quote_name(sequence['table'])),
+ style.SQL_KEYWORD('AUTO_INCREMENT'),
+ style.SQL_FIELD('= 1'),
+ ) for sequence in sequences])
+ return sql
+ else:
+ return []
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ # No sequence reset required
+ return []
+
+OPERATOR_MAPPING = {
+ 'exact': '= %s',
+ 'iexact': 'LIKE %s',
+ 'contains': 'LIKE BINARY %s',
+ 'icontains': 'LIKE %s',
+ 'gt': '> %s',
+ 'gte': '>= %s',
+ 'lt': '< %s',
+ 'lte': '<= %s',
+ 'startswith': 'LIKE BINARY %s',
+ 'endswith': 'LIKE BINARY %s',
+ 'istartswith': 'LIKE %s',
+ 'iendswith': 'LIKE %s',
+}
diff --git a/django/db/backends/mysql_old/client.py b/django/db/backends/mysql_old/client.py
new file mode 100644
index 0000000000..f9d6297b8e
--- /dev/null
+++ b/django/db/backends/mysql_old/client.py
@@ -0,0 +1,14 @@
+from django.conf import settings
+import os
+
+def runshell():
+ args = ['']
+ args += ["--user=%s" % settings.DATABASE_USER]
+ if settings.DATABASE_PASSWORD:
+ args += ["--password=%s" % settings.DATABASE_PASSWORD]
+ if settings.DATABASE_HOST:
+ args += ["--host=%s" % settings.DATABASE_HOST]
+ if settings.DATABASE_PORT:
+ args += ["--port=%s" % settings.DATABASE_PORT]
+ args += [settings.DATABASE_NAME]
+ os.execvp('mysql', args)
diff --git a/django/db/backends/mysql_old/creation.py b/django/db/backends/mysql_old/creation.py
new file mode 100644
index 0000000000..1b23fbff6e
--- /dev/null
+++ b/django/db/backends/mysql_old/creation.py
@@ -0,0 +1,30 @@
+# This dictionary maps Field objects to their associated MySQL column
+# types, as strings. Column-type strings can contain format strings; they'll
+# be interpolated against the values of Field.__dict__ before being output.
+# If a column type is set to None, it won't be included in the output.
+DATA_TYPES = {
+ 'AutoField': 'integer AUTO_INCREMENT',
+ 'BooleanField': 'bool',
+ 'CharField': 'varchar(%(maxlength)s)',
+ 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)',
+ 'DateField': 'date',
+ 'DateTimeField': 'datetime',
+ 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
+ 'FileField': 'varchar(100)',
+ 'FilePathField': 'varchar(100)',
+ 'FloatField': 'double precision',
+ 'ImageField': 'varchar(100)',
+ 'IntegerField': 'integer',
+ 'IPAddressField': 'char(15)',
+ 'ManyToManyField': None,
+ 'NullBooleanField': 'bool',
+ 'OneToOneField': 'integer',
+ 'PhoneNumberField': 'varchar(20)',
+ 'PositiveIntegerField': 'integer UNSIGNED',
+ 'PositiveSmallIntegerField': 'smallint UNSIGNED',
+ 'SlugField': 'varchar(%(maxlength)s)',
+ 'SmallIntegerField': 'smallint',
+ 'TextField': 'longtext',
+ 'TimeField': 'time',
+ 'USStateField': 'varchar(2)',
+}
diff --git a/django/db/backends/mysql_old/introspection.py b/django/db/backends/mysql_old/introspection.py
new file mode 100644
index 0000000000..cb5b8320d9
--- /dev/null
+++ b/django/db/backends/mysql_old/introspection.py
@@ -0,0 +1,95 @@
+from django.db.backends.mysql_old.base import quote_name
+from MySQLdb import ProgrammingError, OperationalError
+from MySQLdb.constants import FIELD_TYPE
+import re
+
+foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")
+
+def get_table_list(cursor):
+ "Returns a list of table names in the current database."
+ cursor.execute("SHOW TABLES")
+ return [row[0] for row in cursor.fetchall()]
+
+def get_table_description(cursor, table_name):
+ "Returns a description of the table, with the DB-API cursor.description interface."
+ cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
+ return cursor.description
+
+def _name_to_index(cursor, table_name):
+ """
+ Returns a dictionary of {field_name: field_index} for the given table.
+ Indexes are 0-based.
+ """
+ return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))])
+
+def get_relations(cursor, table_name):
+ """
+ Returns a dictionary of {field_index: (field_index_other_table, other_table)}
+ representing all relationships to the given table. Indexes are 0-based.
+ """
+ my_field_dict = _name_to_index(cursor, table_name)
+ constraints = []
+ relations = {}
+ try:
+ # This should work for MySQL 5.0.
+ cursor.execute("""
+ SELECT column_name, referenced_table_name, referenced_column_name
+ FROM information_schema.key_column_usage
+ WHERE table_name = %s
+ AND table_schema = DATABASE()
+ AND referenced_table_name IS NOT NULL
+ AND referenced_column_name IS NOT NULL""", [table_name])
+ constraints.extend(cursor.fetchall())
+ except (ProgrammingError, OperationalError):
+ # Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
+ # Go through all constraints and save the equal matches.
+ cursor.execute("SHOW CREATE TABLE %s" % quote_name(table_name))
+ for row in cursor.fetchall():
+ pos = 0
+ while True:
+ match = foreign_key_re.search(row[1], pos)
+ if match == None:
+ break
+ pos = match.end()
+ constraints.append(match.groups())
+
+ for my_fieldname, other_table, other_field in constraints:
+ other_field_index = _name_to_index(cursor, other_table)[other_field]
+ my_field_index = my_field_dict[my_fieldname]
+ relations[my_field_index] = (other_field_index, other_table)
+
+ return relations
+
+def get_indexes(cursor, table_name):
+ """
+ Returns a dictionary of fieldname -> infodict for the given table,
+ where each infodict is in the format:
+ {'primary_key': boolean representing whether it's the primary key,
+ 'unique': boolean representing whether it's a unique index}
+ """
+ cursor.execute("SHOW INDEX FROM %s" % quote_name(table_name))
+ indexes = {}
+ for row in cursor.fetchall():
+ indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
+ return indexes
+
+DATA_TYPES_REVERSE = {
+ FIELD_TYPE.BLOB: 'TextField',
+ FIELD_TYPE.CHAR: 'CharField',
+ FIELD_TYPE.DECIMAL: 'DecimalField',
+ FIELD_TYPE.DATE: 'DateField',
+ FIELD_TYPE.DATETIME: 'DateTimeField',
+ FIELD_TYPE.DOUBLE: 'FloatField',
+ FIELD_TYPE.FLOAT: 'FloatField',
+ FIELD_TYPE.INT24: 'IntegerField',
+ FIELD_TYPE.LONG: 'IntegerField',
+ FIELD_TYPE.LONGLONG: 'IntegerField',
+ FIELD_TYPE.SHORT: 'IntegerField',
+ FIELD_TYPE.STRING: 'TextField',
+ FIELD_TYPE.TIMESTAMP: 'DateTimeField',
+ FIELD_TYPE.TINY: 'IntegerField',
+ FIELD_TYPE.TINY_BLOB: 'TextField',
+ FIELD_TYPE.MEDIUM_BLOB: 'TextField',
+ FIELD_TYPE.LONG_BLOB: 'TextField',
+ FIELD_TYPE.VAR_STRING: 'CharField',
+}
diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
index 3a13f39546..2bc88bb7b9 100644
--- a/django/db/backends/oracle/base.py
+++ b/django/db/backends/oracle/base.py
@@ -12,6 +12,7 @@ except ImportError, e:
raise ImproperlyConfigured, "Error loading cx_Oracle module: %s" % e
DatabaseError = Database.Error
+IntegrityError = Database.IntegrityError
try:
# Only exists in Python 2.4+
@@ -43,10 +44,11 @@ class DatabaseWrapper(local):
return FormatStylePlaceholderCursor(self.connection)
def _commit(self):
- self.connection.commit()
+ if self.connection is not None:
+ self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
try:
self.connection.rollback()
except Database.NotSupportedError:
@@ -108,6 +110,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "DBMS_RANDOM.RANDOM"
+def get_deferrable_sql():
+ return " DEFERRABLE INITIALLY DEFERRED"
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -117,6 +122,24 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+ """
+ # Return a list of 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements
+ # TODO - SQL not actually tested against Oracle yet!
+ # TODO - autoincrement indices reset required? See other get_sql_flush() implementations
+ sql = ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ # No sequence reset required
+ return []
+
OPERATOR_MAPPING = {
'exact': '= %s',
'iexact': 'LIKE %s',
diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py
index d45ceb64f5..14a864ac28 100644
--- a/django/db/backends/oracle/creation.py
+++ b/django/db/backends/oracle/creation.py
@@ -5,9 +5,10 @@ DATA_TYPES = {
'CommaSeparatedIntegerField': 'varchar2(%(maxlength)s)',
'DateField': 'date',
'DateTimeField': 'date',
+ 'DecimalField': 'number(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar2(100)',
'FilePathField': 'varchar2(100)',
- 'FloatField': 'number(%(max_digits)s, %(decimal_places)s)',
+ 'FloatField': 'double precision',
'ImageField': 'varchar2(100)',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
@@ -21,6 +22,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'long',
'TimeField': 'timestamp',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py
index ecc8f372a8..7634206178 100644
--- a/django/db/backends/oracle/introspection.py
+++ b/django/db/backends/oracle/introspection.py
@@ -46,5 +46,5 @@ DATA_TYPES_REVERSE = {
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
- 1700: 'FloatField',
+ 1700: 'DecimalField',
}
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index e44bc0b560..fedbb6b7f1 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -12,6 +12,7 @@ except ImportError, e:
raise ImproperlyConfigured, "Error loading psycopg module: %s" % e
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
try:
# Only exists in Python 2.4+
@@ -20,6 +21,40 @@ except ImportError:
# Import copy of _thread_local.py from Python 2.4
from django.utils._threading_local import local
+def smart_basestring(s, charset):
+ if isinstance(s, unicode):
+ return s.encode(charset)
+ return s
+
+class UnicodeCursorWrapper(object):
+ """
+ A thin wrapper around psycopg cursors that allows them to accept Unicode
+ strings as params.
+
+ This is necessary because psycopg doesn't apply any DB quoting to
+ parameters that are Unicode strings. If a param is Unicode, this will
+ convert it to a bytestring using DEFAULT_CHARSET before passing it to
+ psycopg.
+ """
+ def __init__(self, cursor, charset):
+ self.cursor = cursor
+ self.charset = charset
+
+ def execute(self, sql, params=()):
+ return self.cursor.execute(sql, [smart_basestring(p, self.charset) for p in params])
+
+ def executemany(self, sql, param_list):
+ new_param_list = [tuple([smart_basestring(p, self.charset) for p in params]) for params in param_list]
+ return self.cursor.executemany(sql, new_param_list)
+
+ def __getattr__(self, attr):
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ else:
+ return getattr(self.cursor, attr)
+
+postgres_version = None
+
class DatabaseWrapper(local):
def __init__(self, **kwargs):
self.connection = None
@@ -28,7 +63,9 @@ class DatabaseWrapper(local):
def cursor(self):
from django.conf import settings
+ set_tz = False
if self.connection is None:
+ set_tz = True
if settings.DATABASE_NAME == '':
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file."
@@ -44,16 +81,23 @@ class DatabaseWrapper(local):
self.connection = Database.connect(conn_string, **self.options)
self.connection.set_isolation_level(1) # make transactions transparent to all cursors
cursor = self.connection.cursor()
- cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ if set_tz:
+ cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET)
+ global postgres_version
+ if not postgres_version:
+ cursor.execute("SELECT version()")
+ postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')]
if settings.DEBUG:
return util.CursorDebugWrapper(cursor, self)
return cursor
def _commit(self):
- return self.connection.commit()
+ if self.connection is not None:
+ return self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
return self.connection.rollback()
def close(self):
@@ -103,6 +147,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RANDOM()"
+def get_deferrable_sql():
+ return " DEFERRABLE INITIALLY DEFERRED"
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -112,6 +159,91 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+
+ """
+ if tables:
+ if postgres_version[0] >= 8 and postgres_version[1] >= 1:
+ # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to
+ # truncate tables referenced by a foreign key in any other table. The result is a
+ # single SQL TRUNCATE statement.
+ sql = ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(', '.join([quote_name(table) for table in tables]))
+ )]
+ else:
+ # Older versions of Postgres can't do TRUNCATE in a single call, so they must use
+ # a simple delete.
+ sql = ['%s %s %s;' % \
+ (style.SQL_KEYWORD('DELETE'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+
+ # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
+ # to reset sequence indices
+ for sequence_info in sequences:
+ table_name = sequence_info['table']
+ column_name = sequence_info['column']
+ if column_name and len(column_name)>0:
+ # sequence name in this case will be <table>_<column>_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ else:
+ # sequence name in this case will be <table>_id_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % table_name)),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ return sql
+ else:
+ return []
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ from django.db import models
+ output = []
+ for model in model_list:
+ # Use `coalesce` to set the sequence for each model to the max pk value if there are records,
+ # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
+ # if there are records (as the max pk value is already in use), otherwise set it to false.
+ for f in model._meta.fields:
+ if isinstance(f, models.AutoField):
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(quote_name(model._meta.db_table))))
+ break # Only one AutoField is allowed per model, so don't bother continuing.
+ for f in model._meta.many_to_many:
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(f.m2m_db_table())))
+ return output
+
# Register these custom typecasts, because Django expects dates/times to be
# in Python's native (standard-library) datetime/time format, whereas psycopg
# use mx.DateTime by default.
@@ -122,6 +254,7 @@ except AttributeError:
Database.register_type(Database.new_type((1083,1266), "TIME", util.typecast_time))
Database.register_type(Database.new_type((1114,1184), "TIMESTAMP", util.typecast_timestamp))
Database.register_type(Database.new_type((16,), "BOOLEAN", util.typecast_boolean))
+Database.register_type(Database.new_type((1700,), "NUMERIC", util.typecast_decimal))
OPERATOR_MAPPING = {
'exact': '= %s',
diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py
index 65a804ec40..4646b68ab8 100644
--- a/django/db/backends/postgresql/creation.py
+++ b/django/db/backends/postgresql/creation.py
@@ -9,9 +9,10 @@ DATA_TYPES = {
'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
+ 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(100)',
'FilePathField': 'varchar(100)',
- 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)',
+ 'FloatField': 'double precision',
'ImageField': 'varchar(100)',
'IntegerField': 'integer',
'IPAddressField': 'inet',
@@ -25,6 +26,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py
index 6e1d60c4ff..2605490afd 100644
--- a/django/db/backends/postgresql/introspection.py
+++ b/django/db/backends/postgresql/introspection.py
@@ -72,6 +72,7 @@ DATA_TYPES_REVERSE = {
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
+ 701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
@@ -79,5 +80,5 @@ DATA_TYPES_REVERSE = {
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
- 1700: 'FloatField',
+ 1700: 'DecimalField',
}
diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
index 04322332dc..d9ad363ac1 100644
--- a/django/db/backends/postgresql_psycopg2/base.py
+++ b/django/db/backends/postgresql_psycopg2/base.py
@@ -12,6 +12,7 @@ except ImportError, e:
raise ImproperlyConfigured, "Error loading psycopg2 module: %s" % e
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
try:
# Only exists in Python 2.4+
@@ -20,6 +21,8 @@ except ImportError:
# Import copy of _thread_local.py from Python 2.4
from django.utils._threading_local import local
+postgres_version = None
+
class DatabaseWrapper(local):
def __init__(self, **kwargs):
self.connection = None
@@ -28,7 +31,9 @@ class DatabaseWrapper(local):
def cursor(self):
from django.conf import settings
+ set_tz = False
if self.connection is None:
+ set_tz = True
if settings.DATABASE_NAME == '':
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file."
@@ -45,16 +50,22 @@ class DatabaseWrapper(local):
self.connection.set_isolation_level(1) # make transactions transparent to all cursors
cursor = self.connection.cursor()
cursor.tzinfo_factory = None
- cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ if set_tz:
+ cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ global postgres_version
+ if not postgres_version:
+ cursor.execute("SELECT version()")
+ postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')]
if settings.DEBUG:
return util.CursorDebugWrapper(cursor, self)
return cursor
def _commit(self):
- return self.connection.commit()
+ if self.connection is not None:
+ return self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
return self.connection.rollback()
def close(self):
@@ -96,6 +107,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RANDOM()"
+def get_deferrable_sql():
+ return " DEFERRABLE INITIALLY DEFERRED"
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -105,6 +119,88 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "DEFAULT"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+ """
+ if tables:
+ if postgres_version[0] >= 8 and postgres_version[1] >= 1:
+ # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to
+ # truncate tables referenced by a foreign key in any other table. The result is a
+ # single SQL TRUNCATE statement
+ sql = ['%s %s;' % \
+ (style.SQL_KEYWORD('TRUNCATE'),
+ style.SQL_FIELD(', '.join([quote_name(table) for table in tables]))
+ )]
+ else:
+ sql = ['%s %s %s;' % \
+ (style.SQL_KEYWORD('DELETE'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+
+ # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
+ # to reset sequence indices
+ for sequence in sequences:
+ table_name = sequence['table']
+ column_name = sequence['column']
+ if column_name and len(column_name) > 0:
+ # sequence name in this case will be <table>_<column>_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ else:
+ # sequence name in this case will be <table>_id_seq
+ sql.append("%s %s %s %s %s %s;" % \
+ (style.SQL_KEYWORD('ALTER'),
+ style.SQL_KEYWORD('SEQUENCE'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % table_name)),
+ style.SQL_KEYWORD('RESTART'),
+ style.SQL_KEYWORD('WITH'),
+ style.SQL_FIELD('1')
+ )
+ )
+ return sql
+ else:
+ return []
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ from django.db import models
+ output = []
+ for model in model_list:
+ # Use `coalesce` to set the sequence for each model to the max pk value if there are records,
+ # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
+ # if there are records (as the max pk value is already in use), otherwise set it to false.
+ for f in model._meta.fields:
+ if isinstance(f, models.AutoField):
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_FIELD(quote_name(f.column)),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(quote_name(model._meta.db_table))))
+ break # Only one AutoField is allowed per model, so don't bother continuing.
+ for f in model._meta.many_to_many:
+ output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
+ (style.SQL_KEYWORD('SELECT'),
+ style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_FIELD(quote_name('id')),
+ style.SQL_KEYWORD('IS NOT'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_TABLE(f.m2m_db_table())))
+ return output
+
OPERATOR_MAPPING = {
'exact': '= %s',
'iexact': 'ILIKE %s',
diff --git a/django/db/backends/postgresql_psycopg2/introspection.py b/django/db/backends/postgresql_psycopg2/introspection.py
index a546da8c45..aa45fe7db7 100644
--- a/django/db/backends/postgresql_psycopg2/introspection.py
+++ b/django/db/backends/postgresql_psycopg2/introspection.py
@@ -72,6 +72,7 @@ DATA_TYPES_REVERSE = {
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
+ 701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
@@ -79,5 +80,5 @@ DATA_TYPES_REVERSE = {
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
- 1700: 'FloatField',
+ 1700: 'DecimalField',
}
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 891320160f..5cd67a32f5 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -17,7 +17,13 @@ except ImportError, e:
module = 'sqlite3'
raise ImproperlyConfigured, "Error loading %s module: %s" % (module, e)
+try:
+ import decimal
+except ImportError:
+ from django.utils import _decimal as decimal # for Python 2.3
+
DatabaseError = Database.DatabaseError
+IntegrityError = Database.IntegrityError
Database.register_converter("bool", lambda s: str(s) == '1')
Database.register_converter("time", util.typecast_time)
@@ -25,6 +31,8 @@ Database.register_converter("date", util.typecast_date)
Database.register_converter("datetime", util.typecast_timestamp)
Database.register_converter("timestamp", util.typecast_timestamp)
Database.register_converter("TIMESTAMP", util.typecast_timestamp)
+Database.register_converter("decimal", util.typecast_decimal)
+Database.register_adapter(decimal.Decimal, util.rev_typecast_decimal)
def utf8rowFactory(cursor, row):
def utf8(s):
@@ -67,10 +75,11 @@ class DatabaseWrapper(local):
return cursor
def _commit(self):
- self.connection.commit()
+ if self.connection is not None:
+ self.connection.commit()
def _rollback(self):
- if self.connection:
+ if self.connection is not None:
self.connection.rollback()
def close(self):
@@ -139,6 +148,9 @@ def get_limit_offset_sql(limit, offset=None):
def get_random_function_sql():
return "RANDOM()"
+def get_deferrable_sql():
+ return ""
+
def get_fulltext_search_sql(field_name):
raise NotImplementedError
@@ -148,6 +160,29 @@ def get_drop_foreignkey_sql():
def get_pk_default_value():
return "NULL"
+def get_sql_flush(style, tables, sequences):
+ """Return a list of SQL statements required to remove all data from
+ all tables in the database (without actually removing the tables
+ themselves) and put the database in an empty 'initial' state
+
+ """
+ # NB: The generated SQL below is specific to SQLite
+ # Note: The DELETE FROM... SQL generated below works for SQLite databases
+ # because constraints don't exist
+ sql = ['%s %s %s;' % \
+ (style.SQL_KEYWORD('DELETE'),
+ style.SQL_KEYWORD('FROM'),
+ style.SQL_FIELD(quote_name(table))
+ ) for table in tables]
+ # Note: No requirement for reset of auto-incremented indices (cf. other
+ # get_sql_flush() implementations). Just return SQL at this point
+ return sql
+
+def get_sql_sequence_reset(style, model_list):
+ "Returns a list of the SQL statements to reset sequences for the given models."
+ # No sequence reset required
+ return []
+
def _sqlite_date_trunc(lookup_type, dt):
try:
dt = util.typecast_timestamp(dt)
diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py
index e845179e64..e63046ab7d 100644
--- a/django/db/backends/sqlite3/creation.py
+++ b/django/db/backends/sqlite3/creation.py
@@ -8,9 +8,10 @@ DATA_TYPES = {
'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
+ 'DecimalField': 'decimal',
'FileField': 'varchar(100)',
'FilePathField': 'varchar(100)',
- 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)',
+ 'FloatField': 'real',
'ImageField': 'varchar(100)',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
@@ -24,6 +25,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/util.py b/django/db/backends/util.py
index d8f86fef4f..81c752e664 100644
--- a/django/db/backends/util.py
+++ b/django/db/backends/util.py
@@ -1,6 +1,11 @@
import datetime
from time import time
+try:
+ import decimal
+except ImportError:
+ from django.utils import _decimal as decimal # for Python 2.3
+
class CursorDebugWrapper(object):
def __init__(self, cursor, db):
self.cursor = cursor
@@ -33,7 +38,7 @@ class CursorDebugWrapper(object):
})
def __getattr__(self, attr):
- if self.__dict__.has_key(attr):
+ if attr in self.__dict__:
return self.__dict__[attr]
else:
return getattr(self.cursor, attr)
@@ -85,6 +90,11 @@ def typecast_boolean(s):
if not s: return False
return str(s)[0].lower() == 't'
+def typecast_decimal(s):
+ if s is None or s == '':
+ return None
+ return decimal.Decimal(s)
+
###############################################
# Converters from Python to database (string) #
###############################################
@@ -92,6 +102,11 @@ def typecast_boolean(s):
def rev_typecast_boolean(obj, d):
return obj and '1' or '0'
+def rev_typecast_decimal(d):
+ if d is None:
+ return None
+ return str(d)
+
##################################################################################
# Helper functions for dictfetch* for databases that don't natively support them #
##################################################################################
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
index 0308dd047a..6c3abb6b59 100644
--- a/django/db/models/__init__.py
+++ b/django/db/models/__init__.py
@@ -8,7 +8,6 @@ from django.db.models.manager import Manager
from django.db.models.base import Model, AdminOptions
from django.db.models.fields import *
from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel, TABULAR, STACKED
-from django.db.models.fields.generic import GenericRelation, GenericRel, GenericForeignKey
from django.db.models import signals
from django.utils.functional import curry
from django.utils.text import capfirst
@@ -27,27 +26,3 @@ def permalink(func):
viewname = bits[0]
return reverse(bits[0], None, *bits[1:3])
return inner
-
-class LazyDate(object):
- """
- Use in limit_choices_to to compare the field to dates calculated at run time
- instead of when the model is loaded. For example::
-
- ... limit_choices_to = {'date__gt' : models.LazyDate(days=-3)} ...
-
- which will limit the choices to dates greater than three days ago.
- """
- def __init__(self, **kwargs):
- self.delta = datetime.timedelta(**kwargs)
-
- def __str__(self):
- return str(self.__get_value__())
-
- def __repr__(self):
- return "<LazyDate: %s>" % self.delta
-
- def __get_value__(self):
- return (datetime.datetime.now() + self.delta).date()
-
- def __getattr__(self, attr):
- return getattr(self.__get_value__(), attr)
diff --git a/django/db/models/base.py b/django/db/models/base.py
index 70569a2561..a8e6303e1c 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -13,6 +13,7 @@ from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
from django.utils.functional import curry
from django.conf import settings
+from itertools import izip
import types
import sys
import os
@@ -21,8 +22,13 @@ class ModelBase(type):
"Metaclass for all models"
def __new__(cls, name, bases, attrs):
# If this isn't a subclass of Model, don't do anything special.
- if not bases or bases == (object,):
- return type.__new__(cls, name, bases, attrs)
+ try:
+ if not filter(lambda b: issubclass(b, Model), bases):
+ return super(ModelBase, cls).__new__(cls, name, bases, attrs)
+ except NameError:
+ # 'Model' isn't defined yet, meaning we're looking at Django's own
+ # Model class, defined below.
+ return super(ModelBase, cls).__new__(cls, name, bases, attrs)
# Create the class.
new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')})
@@ -36,11 +42,11 @@ class ModelBase(type):
new_class._meta.parents.append(base)
new_class._meta.parents.extend(base._meta.parents)
- model_module = sys.modules[new_class.__module__]
if getattr(new_class._meta, 'app_label', None) is None:
# Figure out the app_label by looking one level up.
# For 'django.contrib.sites.models', this would be 'sites'.
+ model_module = sys.modules[new_class.__module__]
new_class._meta.app_label = model_module.__name__.split('.')[-2]
# Bail out early if we have already created this class.
@@ -63,7 +69,7 @@ class ModelBase(type):
if getattr(new_class._meta, 'row_level_permissions', False):
from django.contrib.auth.models import RowLevelPermission
- gen_rel = django.db.models.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct")
+ gen_rel = django.contrib.contenttypes.generic.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct")
new_class.add_to_class("row_level_permissions", gen_rel)
new_class._prepare()
@@ -95,41 +101,74 @@ class Model(object):
def __init__(self, *args, **kwargs):
dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs)
- for f in self._meta.fields:
- if isinstance(f.rel, ManyToOneRel):
- try:
- # Assume object instance was passed in.
- rel_obj = kwargs.pop(f.name)
- except KeyError:
+
+ # There is a rather weird disparity here; if kwargs, it's set, then args
+ # overrides it. It should be one or the other; don't duplicate the work
+ # The reason for the kwargs check is that standard iterator passes in by
+ # args, and nstantiation for iteration is 33% faster.
+ args_len = len(args)
+ if args_len > len(self._meta.fields):
+ # Daft, but matches old exception sans the err msg.
+ raise IndexError("Number of args exceeds number of fields")
+
+ fields_iter = iter(self._meta.fields)
+ if not kwargs:
+ # The ordering of the izip calls matter - izip throws StopIteration
+ # when an iter throws it. So if the first iter throws it, the second
+ # is *not* consumed. We rely on this, so don't change the order
+ # without changing the logic.
+ for val, field in izip(args, fields_iter):
+ setattr(self, field.attname, val)
+ else:
+ # Slower, kwargs-ready version.
+ for val, field in izip(args, fields_iter):
+ setattr(self, field.attname, val)
+ kwargs.pop(field.name, None)
+ # Maintain compatibility with existing calls.
+ if isinstance(field.rel, ManyToOneRel):
+ kwargs.pop(field.attname, None)
+
+ # Now we're left with the unprocessed fields that *must* come from
+ # keywords, or default.
+
+ for field in fields_iter:
+ if kwargs:
+ if isinstance(field.rel, ManyToOneRel):
try:
- # Object instance wasn't passed in -- must be an ID.
- val = kwargs.pop(f.attname)
+ # Assume object instance was passed in.
+ rel_obj = kwargs.pop(field.name)
except KeyError:
- val = f.get_default()
- else:
- # Object instance was passed in.
- # Special case: You can pass in "None" for related objects if it's allowed.
- if rel_obj is None and f.null:
- val = None
- else:
try:
- val = getattr(rel_obj, f.rel.get_related_field().attname)
- except AttributeError:
- raise TypeError, "Invalid value: %r should be a %s instance, not a %s" % (f.name, f.rel.to, type(rel_obj))
- setattr(self, f.attname, val)
+ # Object instance wasn't passed in -- must be an ID.
+ val = kwargs.pop(field.attname)
+ except KeyError:
+ val = field.get_default()
+ else:
+ # Object instance was passed in. Special case: You can
+ # pass in "None" for related objects if it's allowed.
+ if rel_obj is None and field.null:
+ val = None
+ else:
+ try:
+ val = getattr(rel_obj, field.rel.get_related_field().attname)
+ except AttributeError:
+ raise TypeError("Invalid value: %r should be a %s instance, not a %s" %
+ (field.name, field.rel.to, type(rel_obj)))
+ else:
+ val = kwargs.pop(field.attname, field.get_default())
else:
- val = kwargs.pop(f.attname, f.get_default())
- setattr(self, f.attname, val)
- for prop in kwargs.keys():
- try:
- if isinstance(getattr(self.__class__, prop), property):
- setattr(self, prop, kwargs.pop(prop))
- except AttributeError:
- pass
+ val = field.get_default()
+ setattr(self, field.attname, val)
+
if kwargs:
- raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
- for i, arg in enumerate(args):
- setattr(self, self._meta.fields[i].attname, arg)
+ for prop in kwargs.keys():
+ try:
+ if isinstance(getattr(self.__class__, prop), property):
+ setattr(self, prop, kwargs.pop(prop))
+ except AttributeError:
+ pass
+ if kwargs:
+ raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self)
def add_to_class(cls, name, value):
@@ -327,7 +366,7 @@ class Model(object):
def _get_FIELD_size(self, field):
return os.path.getsize(self._get_FIELD_filename(field))
- def _save_FIELD_file(self, field, filename, raw_contents):
+ def _save_FIELD_file(self, field, filename, raw_contents, save=True):
directory = field.get_directory_name()
try: # Create the date-based directory if it doesn't exist.
os.makedirs(os.path.join(settings.MEDIA_ROOT, directory))
@@ -362,8 +401,9 @@ class Model(object):
if field.height_field:
setattr(self, field.height_field, height)
- # Save the object, because it has changed.
- self.save()
+ # Save the object because it has changed unless save is False
+ if save:
+ self.save()
_save_FIELD_file.alters_data = True
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index fe317ac24f..136ce31b8b 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -10,6 +10,10 @@ from django.utils.itercompat import tee
from django.utils.text import capfirst
from django.utils.translation import gettext, gettext_lazy
import datetime, os, time
+try:
+ import decimal
+except ImportError:
+ from django.utils import _decimal as decimal # for Python 2.3
class NOT_PROVIDED:
pass
@@ -67,7 +71,7 @@ class Field(object):
def __init__(self, verbose_name=None, name=None, primary_key=False,
maxlength=None, unique=False, blank=False, null=False, db_index=False,
- core=False, rel=None, default=NOT_PROVIDED, editable=True,
+ core=False, rel=None, default=NOT_PROVIDED, editable=True, serialize=True,
prepopulate_from=None, unique_for_date=None, unique_for_month=None,
unique_for_year=None, validator_list=None, choices=None, radio_admin=None,
help_text='', db_column=None):
@@ -78,6 +82,7 @@ class Field(object):
self.blank, self.null = blank, null
self.core, self.rel, self.default = core, rel, default
self.editable = editable
+ self.serialize = serialize
self.validator_list = validator_list or []
self.prepopulate_from = prepopulate_from
self.unique_for_date, self.unique_for_month = unique_for_date, unique_for_month
@@ -164,7 +169,7 @@ class Field(object):
def get_db_prep_lookup(self, lookup_type, value):
"Returns field's value prepared for database lookup."
- if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'year', 'month', 'day', 'search'):
+ if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'month', 'day', 'search'):
return [value]
elif lookup_type in ('range', 'in'):
return value
@@ -178,7 +183,13 @@ class Field(object):
return ["%%%s" % prep_for_like_query(value)]
elif lookup_type == 'isnull':
return []
- raise TypeError, "Field has invalid lookup: %s" % lookup_type
+ elif lookup_type == 'year':
+ try:
+ value = int(value)
+ except ValueError:
+ raise ValueError("The __year lookup type requires an integer argument")
+ return ['%s-01-01 00:00:00' % value, '%s-12-31 23:59:59.999999' % value]
+ raise TypeError("Field has invalid lookup: %s" % lookup_type)
def has_default(self):
"Returns a boolean of whether this field has a default value."
@@ -334,10 +345,17 @@ class Field(object):
return self._choices
choices = property(_get_choices)
- def formfield(self):
+ def formfield(self, form_class=forms.CharField, **kwargs):
"Returns a django.newforms.Field instance for this database Field."
- # TODO: This is just a temporary default during development.
- return forms.CharField(required=not self.blank, label=capfirst(self.verbose_name))
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ if self.choices:
+ defaults['widget'] = forms.Select(choices=self.get_choices())
+ defaults.update(kwargs)
+ return form_class(**defaults)
+
+ def value_from_object(self, obj):
+ "Returns the value of this field in the given model instance."
+ return getattr(obj, self.attname)
class AutoField(Field):
empty_strings_allowed = False
@@ -375,7 +393,7 @@ class AutoField(Field):
super(AutoField, self).contribute_to_class(cls, name)
cls._meta.has_auto_field = True
- def formfield(self):
+ def formfield(self, **kwargs):
return None
class BooleanField(Field):
@@ -392,8 +410,10 @@ class BooleanField(Field):
def get_manipulator_field_objs(self):
return [oldforms.CheckboxField]
- def formfield(self):
- return forms.BooleanField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.BooleanField}
+ defaults.update(kwargs)
+ return super(BooleanField, self).formfield(**defaults)
class CharField(Field):
def get_manipulator_field_objs(self):
@@ -409,8 +429,10 @@ class CharField(Field):
raise validators.ValidationError, gettext_lazy("This field cannot be null.")
return str(value)
- def formfield(self):
- return forms.CharField(max_length=self.maxlength, required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'max_length': self.maxlength}
+ defaults.update(kwargs)
+ return super(CharField, self).formfield(**defaults)
# TODO: Maybe move this into contrib, because it's specialized.
class CommaSeparatedIntegerField(CharField):
@@ -428,6 +450,8 @@ class DateField(Field):
Field.__init__(self, verbose_name, name, **kwargs)
def to_python(self, value):
+ if value is None:
+ return value
if isinstance(value, datetime.datetime):
return value.date()
if isinstance(value, datetime.date):
@@ -479,15 +503,19 @@ class DateField(Field):
def get_manipulator_field_objs(self):
return [oldforms.DateField]
- def flatten_data(self, follow, obj = None):
+ def flatten_data(self, follow, obj=None):
val = self._get_val_from_obj(obj)
return {self.attname: (val is not None and val.strftime("%Y-%m-%d") or '')}
- def formfield(self):
- return forms.DateField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.DateField}
+ defaults.update(kwargs)
+ return super(DateField, self).formfield(**defaults)
class DateTimeField(DateField):
def to_python(self, value):
+ if value is None:
+ return value
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
@@ -544,8 +572,69 @@ class DateTimeField(DateField):
return {date_field: (val is not None and val.strftime("%Y-%m-%d") or ''),
time_field: (val is not None and val.strftime("%H:%M:%S") or '')}
- def formfield(self):
- return forms.DateTimeField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.DateTimeField}
+ defaults.update(kwargs)
+ return super(DateTimeField, self).formfield(**defaults)
+
+class DecimalField(Field):
+ empty_strings_allowed = False
+ def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs):
+ self.max_digits, self.decimal_places = max_digits, decimal_places
+ Field.__init__(self, verbose_name, name, **kwargs)
+
+ def to_python(self, value):
+ if value is None:
+ return value
+ try:
+ return decimal.Decimal(value)
+ except decimal.InvalidOperation:
+ raise validators.ValidationError, gettext("This value must be a decimal number.")
+
+ def _format(self, value):
+ if isinstance(value, basestring):
+ return value
+ else:
+ return self.format_number(value)
+
+ def format_number(self, value):
+ """
+ Formats a number into a string with the requisite number of digits and
+ decimal places.
+ """
+ num_chars = self.max_digits
+ # Allow for a decimal point
+ if self.decimal_places > 0:
+ num_chars += 1
+ # Allow for a minus sign
+ if value < 0:
+ num_chars += 1
+
+ return "%.*f" % (self.decimal_places, value)
+
+ def get_db_prep_save(self, value):
+ if value is not None:
+ value = self._format(value)
+ return super(DecimalField, self).get_db_prep_save(value)
+
+ def get_db_prep_lookup(self, lookup_type, value):
+ if lookup_type == 'range':
+ value = [self._format(v) for v in value]
+ else:
+ value = self._format(value)
+ return super(DecimalField, self).get_db_prep_lookup(lookup_type, value)
+
+ def get_manipulator_field_objs(self):
+ return [curry(oldforms.DecimalField, max_digits=self.max_digits, decimal_places=self.decimal_places)]
+
+ def formfield(self, **kwargs):
+ defaults = {
+ 'max_digits': self.max_digits,
+ 'decimal_places': self.decimal_places,
+ 'form_class': forms.DecimalField,
+ }
+ defaults.update(kwargs)
+ return super(DecimalField, self).formfield(**defaults)
class EmailField(CharField):
def __init__(self, *args, **kwargs):
@@ -561,8 +650,10 @@ class EmailField(CharField):
def validate(self, field_data, all_data):
validators.isValidEmail(field_data, all_data)
- def formfield(self):
- return forms.EmailField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.EmailField}
+ defaults.update(kwargs)
+ return super(EmailField, self).formfield(**defaults)
class FileField(Field):
def __init__(self, verbose_name=None, name=None, upload_to='', **kwargs):
@@ -610,7 +701,7 @@ class FileField(Field):
setattr(cls, 'get_%s_filename' % self.name, curry(cls._get_FIELD_filename, field=self))
setattr(cls, 'get_%s_url' % self.name, curry(cls._get_FIELD_url, field=self))
setattr(cls, 'get_%s_size' % self.name, curry(cls._get_FIELD_size, field=self))
- setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents: instance._save_FIELD_file(self, filename, raw_contents))
+ setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents, save=True: instance._save_FIELD_file(self, filename, raw_contents, save))
dispatcher.connect(self.delete_file, signal=signals.post_delete, sender=cls)
def delete_file(self, instance):
@@ -628,14 +719,14 @@ class FileField(Field):
def get_manipulator_field_names(self, name_prefix):
return [name_prefix + self.name + '_file', name_prefix + self.name]
- def save_file(self, new_data, new_object, original_object, change, rel):
+ def save_file(self, new_data, new_object, original_object, change, rel, save=True):
upload_field_name = self.get_manipulator_field_names('')[0]
if new_data.get(upload_field_name, False):
func = getattr(new_object, 'save_%s_file' % self.name)
if rel:
- func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"])
+ func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"], save)
else:
- func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"])
+ func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"], save)
def get_directory_name(self):
return os.path.normpath(datetime.datetime.now().strftime(self.upload_to))
@@ -655,12 +746,14 @@ class FilePathField(Field):
class FloatField(Field):
empty_strings_allowed = False
- def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs):
- self.max_digits, self.decimal_places = max_digits, decimal_places
- Field.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
- return [curry(oldforms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)]
+ return [oldforms.FloatField]
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.FloatField}
+ defaults.update(kwargs)
+ return super(FloatField, self).formfield(**defaults)
class ImageField(FileField):
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
@@ -679,12 +772,12 @@ class ImageField(FileField):
if not self.height_field:
setattr(cls, 'get_%s_height' % self.name, curry(cls._get_FIELD_height, field=self))
- def save_file(self, new_data, new_object, original_object, change, rel):
- FileField.save_file(self, new_data, new_object, original_object, change, rel)
+ def save_file(self, new_data, new_object, original_object, change, rel, save=True):
+ FileField.save_file(self, new_data, new_object, original_object, change, rel, save)
# If the image has height and/or width field(s) and they haven't
# changed, set the width and/or height field(s) back to their original
# values.
- if change and (self.width_field or self.height_field):
+ if change and (self.width_field or self.height_field) and save:
if self.width_field:
setattr(new_object, self.width_field, getattr(original_object, self.width_field))
if self.height_field:
@@ -696,8 +789,10 @@ class IntegerField(Field):
def get_manipulator_field_objs(self):
return [oldforms.IntegerField]
- def formfield(self):
- return forms.IntegerField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.IntegerField}
+ defaults.update(kwargs)
+ return super(IntegerField, self).formfield(**defaults)
class IPAddressField(Field):
def __init__(self, *args, **kwargs):
@@ -715,6 +810,13 @@ class NullBooleanField(Field):
kwargs['null'] = True
Field.__init__(self, *args, **kwargs)
+ def to_python(self, value):
+ if value in (None, True, False): return value
+ if value in ('None'): return None
+ if value in ('t', 'True', '1'): return True
+ if value in ('f', 'False', '0'): return False
+ raise validators.ValidationError, gettext("This value must be either None, True or False.")
+
def get_manipulator_field_objs(self):
return [oldforms.NullBooleanField]
@@ -725,6 +827,12 @@ class PhoneNumberField(IntegerField):
def validate(self, field_data, all_data):
validators.isValidPhone(field_data, all_data)
+ def formfield(self, **kwargs):
+ from django.contrib.localflavor.us.forms import USPhoneNumberField
+ defaults = {'form_class': USPhoneNumberField}
+ defaults.update(kwargs)
+ return super(PhoneNumberField, self).formfield(**defaults)
+
class PositiveIntegerField(IntegerField):
def get_manipulator_field_objs(self):
return [oldforms.PositiveIntegerField]
@@ -738,7 +846,7 @@ class SlugField(Field):
kwargs['maxlength'] = kwargs.get('maxlength', 50)
kwargs.setdefault('validator_list', []).append(validators.isSlug)
# Set db_index=True unless it's been set manually.
- if not kwargs.has_key('db_index'):
+ if 'db_index' not in kwargs:
kwargs['db_index'] = True
Field.__init__(self, *args, **kwargs)
@@ -753,6 +861,11 @@ class TextField(Field):
def get_manipulator_field_objs(self):
return [oldforms.LargeTextField]
+ def formfield(self, **kwargs):
+ defaults = {'widget': forms.Textarea}
+ defaults.update(kwargs)
+ return super(TextField, self).formfield(**defaults)
+
class TimeField(Field):
empty_strings_allowed = False
def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs):
@@ -781,7 +894,7 @@ class TimeField(Field):
if value is not None:
# MySQL will throw a warning if microseconds are given, because it
# doesn't support microseconds.
- if settings.DATABASE_ENGINE == 'mysql':
+ if settings.DATABASE_ENGINE == 'mysql' and hasattr(value, 'microsecond'):
value = value.replace(microsecond=0)
value = str(value)
return Field.get_db_prep_save(self, value)
@@ -793,26 +906,40 @@ class TimeField(Field):
val = self._get_val_from_obj(obj)
return {self.attname: (val is not None and val.strftime("%H:%M:%S") or '')}
- def formfield(self):
- return forms.TimeField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.TimeField}
+ defaults.update(kwargs)
+ return super(TimeField, self).formfield(**defaults)
-class URLField(Field):
+class URLField(CharField):
def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs):
+ kwargs['maxlength'] = kwargs.get('maxlength', 200)
if verify_exists:
kwargs.setdefault('validator_list', []).append(validators.isExistingURL)
self.verify_exists = verify_exists
- Field.__init__(self, verbose_name, name, **kwargs)
+ CharField.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
return [oldforms.URLField]
- def formfield(self):
- return forms.URLField(required=not self.blank, verify_exists=self.verify_exists, label=capfirst(self.verbose_name))
+ def get_internal_type(self):
+ return "CharField"
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.URLField, 'verify_exists': self.verify_exists}
+ defaults.update(kwargs)
+ return super(URLField, self).formfield(**defaults)
class USStateField(Field):
def get_manipulator_field_objs(self):
return [oldforms.USStateField]
+ def formfield(self, **kwargs):
+ from django.contrib.localflavor.us.forms import USStateSelect
+ defaults = {'widget': USStateSelect}
+ defaults.update(kwargs)
+ return super(USStateField, self).formfield(**defaults)
+
class XMLField(TextField):
def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs):
self.schema_path = schema_path
diff --git a/django/db/models/fields/generic.py b/django/db/models/fields/generic.py
deleted file mode 100644
index 1ad8346e42..0000000000
--- a/django/db/models/fields/generic.py
+++ /dev/null
@@ -1,259 +0,0 @@
-"""
-Classes allowing "generic" relations through ContentType and object-id fields.
-"""
-
-from django import oldforms
-from django.core.exceptions import ObjectDoesNotExist
-from django.db import backend
-from django.db.models import signals
-from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
-from django.db.models.loading import get_model
-from django.dispatch import dispatcher
-from django.utils.functional import curry
-
-class GenericForeignKey(object):
- """
- Provides a generic relation to any object through content-type/object-id
- fields.
- """
-
- def __init__(self, ct_field="content_type", fk_field="object_id"):
- self.ct_field = ct_field
- self.fk_field = fk_field
-
- def contribute_to_class(self, cls, name):
- # Make sure the fields exist (these raise FieldDoesNotExist,
- # which is a fine error to raise here)
- self.name = name
- self.model = cls
- self.cache_attr = "_%s_cache" % name
-
- # For some reason I don't totally understand, using weakrefs here doesn't work.
- dispatcher.connect(self.instance_pre_init, signal=signals.pre_init, sender=cls, weak=False)
-
- # Connect myself as the descriptor for this field
- setattr(cls, name, self)
-
- def instance_pre_init(self, signal, sender, args, kwargs):
- # Handle initalizing an object with the generic FK instaed of
- # content-type/object-id fields.
- if kwargs.has_key(self.name):
- value = kwargs.pop(self.name)
- kwargs[self.ct_field] = self.get_content_type(value)
- kwargs[self.fk_field] = value._get_pk_val()
-
- def get_content_type(self, obj):
- # Convenience function using get_model avoids a circular import when using this model
- ContentType = get_model("contenttypes", "contenttype")
- return ContentType.objects.get_for_model(obj)
-
- def __get__(self, instance, instance_type=None):
- if instance is None:
- raise AttributeError, "%s must be accessed via instance" % self.name
-
- try:
- return getattr(instance, self.cache_attr)
- except AttributeError:
- rel_obj = None
- ct = getattr(instance, self.ct_field)
- if ct:
- try:
- rel_obj = ct.get_object_for_this_type(pk=getattr(instance, self.fk_field))
- except ObjectDoesNotExist:
- pass
- setattr(instance, self.cache_attr, rel_obj)
- return rel_obj
-
- def __set__(self, instance, value):
- if instance is None:
- raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name
-
- ct = None
- fk = None
- if value is not None:
- ct = self.get_content_type(value)
- fk = value._get_pk_val()
-
- setattr(instance, self.ct_field, ct)
- setattr(instance, self.fk_field, fk)
- setattr(instance, self.cache_attr, value)
-
-class GenericRelation(RelatedField, Field):
- """Provides an accessor to generic related objects (i.e. comments)"""
-
- def __init__(self, to, **kwargs):
- kwargs['verbose_name'] = kwargs.get('verbose_name', None)
- kwargs['rel'] = GenericRel(to,
- related_name=kwargs.pop('related_name', None),
- limit_choices_to=kwargs.pop('limit_choices_to', None),
- symmetrical=kwargs.pop('symmetrical', True))
-
- # Override content-type/object-id field names on the related class
- self.object_id_field_name = kwargs.pop("object_id_field", "object_id")
- self.content_type_field_name = kwargs.pop("content_type_field", "content_type")
-
- kwargs['blank'] = True
- kwargs['editable'] = False
- Field.__init__(self, **kwargs)
-
- def get_manipulator_field_objs(self):
- choices = self.get_choices_default()
- return [curry(oldforms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)]
-
- def get_choices_default(self):
- return Field.get_choices(self, include_blank=False)
-
- def flatten_data(self, follow, obj = None):
- new_data = {}
- if obj:
- instance_ids = [instance._get_pk_val() for instance in getattr(obj, self.name).all()]
- new_data[self.name] = instance_ids
- return new_data
-
- def m2m_db_table(self):
- return self.rel.to._meta.db_table
-
- def m2m_column_name(self):
- return self.object_id_field_name
-
- def m2m_reverse_name(self):
- return self.object_id_field_name
-
- def contribute_to_class(self, cls, name):
- super(GenericRelation, self).contribute_to_class(cls, name)
-
- # Save a reference to which model this class is on for future use
- self.model = cls
-
- # Add the descriptor for the m2m relation
- setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self))
-
- def contribute_to_related_class(self, cls, related):
- pass
-
- def set_attributes_from_rel(self):
- pass
-
- def get_internal_type(self):
- return "ManyToManyField"
-
-class ReverseGenericRelatedObjectsDescriptor(object):
- """
- This class provides the functionality that makes the related-object
- managers available as attributes on a model class, for fields that have
- multiple "remote" values and have a GenericRelation defined in their model
- (rather than having another model pointed *at* them). In the example
- "article.publications", the publications attribute is a
- ReverseGenericRelatedObjectsDescriptor instance.
- """
- def __init__(self, field):
- self.field = field
-
- def __get__(self, instance, instance_type=None):
- if instance is None:
- raise AttributeError, "Manager must be accessed via instance"
-
- # This import is done here to avoid circular import importing this module
- from django.contrib.contenttypes.models import ContentType
-
- # Dynamically create a class that subclasses the related model's
- # default manager.
- rel_model = self.field.rel.to
- superclass = rel_model._default_manager.__class__
- RelatedManager = create_generic_related_manager(superclass)
-
- manager = RelatedManager(
- model = rel_model,
- instance = instance,
- symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model),
- join_table = backend.quote_name(self.field.m2m_db_table()),
- source_col_name = backend.quote_name(self.field.m2m_column_name()),
- target_col_name = backend.quote_name(self.field.m2m_reverse_name()),
- content_type = ContentType.objects.get_for_model(self.field.model),
- content_type_field_name = self.field.content_type_field_name,
- object_id_field_name = self.field.object_id_field_name
- )
-
- return manager
-
- def __set__(self, instance, value):
- if instance is None:
- raise AttributeError, "Manager must be accessed via instance"
-
- manager = self.__get__(instance)
- manager.clear()
- for obj in value:
- manager.add(obj)
-
-def create_generic_related_manager(superclass):
- """
- Factory function for a manager that subclasses 'superclass' (which is a
- Manager) and adds behavior for generic related objects.
- """
-
- class GenericRelatedObjectManager(superclass):
- def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
- join_table=None, source_col_name=None, target_col_name=None, content_type=None,
- content_type_field_name=None, object_id_field_name=None):
-
- super(GenericRelatedObjectManager, self).__init__()
- self.core_filters = core_filters or {}
- self.model = model
- self.content_type = content_type
- self.symmetrical = symmetrical
- self.instance = instance
- self.join_table = join_table
- self.join_table = model._meta.db_table
- self.source_col_name = source_col_name
- self.target_col_name = target_col_name
- self.content_type_field_name = content_type_field_name
- self.object_id_field_name = object_id_field_name
- self.pk_val = self.instance._get_pk_val()
-
- def get_query_set(self):
- query = {
- '%s__pk' % self.content_type_field_name : self.content_type.id,
- '%s__exact' % self.object_id_field_name : self.pk_val,
- }
- return superclass.get_query_set(self).filter(**query)
-
- def add(self, *objs):
- for obj in objs:
- setattr(obj, self.content_type_field_name, self.content_type)
- setattr(obj, self.object_id_field_name, self.pk_val)
- obj.save()
- add.alters_data = True
-
- def remove(self, *objs):
- for obj in objs:
- obj.delete()
- remove.alters_data = True
-
- def clear(self):
- for obj in self.all():
- obj.delete()
- clear.alters_data = True
-
- def create(self, **kwargs):
- kwargs[self.content_type_field_name] = self.content_type
- kwargs[self.object_id_field_name] = self.pk_val
- obj = self.model(**kwargs)
- obj.save()
- return obj
- create.alters_data = True
-
- return GenericRelatedObjectManager
-
-class GenericRel(ManyToManyRel):
- def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True):
- self.to = to
- self.num_in_admin = 0
- self.related_name = related_name
- self.filter_interface = None
- self.limit_choices_to = limit_choices_to or {}
- self.edit_inline = False
- self.raw_id_admin = False
- self.symmetrical = symmetrical
- self.multiple = True
- assert not (self.raw_id_admin and self.filter_interface), \
- "Generic relations may not use both raw_id_admin and filter_interface"
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 797ef05be1..0739d0461a 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -2,10 +2,12 @@ from django.db import backend, transaction
from django.db.models import signals, get_model
from django.db.models.fields import AutoField, Field, IntegerField, get_ul_class
from django.db.models.related import RelatedObject
+from django.utils.text import capfirst
from django.utils.translation import gettext_lazy, string_concat, ngettext
from django.utils.functional import curry
from django.core import validators
from django import oldforms
+from django import newforms as forms
from django.dispatch import dispatcher
# For Python 2.3
@@ -314,18 +316,20 @@ def create_many_related_manager(superclass):
# join_table: name of the m2m link table
# source_col_name: the PK colname in join_table for the source object
# target_col_name: the PK colname in join_table for the target object
- # *objs - objects to add
+ # *objs - objects to add. Either object instances, or primary keys of object instances.
from django.db import connection
# If there aren't any objects, there is nothing to do.
if objs:
# Check that all the objects are of the right type
+ new_ids = set()
for obj in objs:
- if not isinstance(obj, self.model):
- raise ValueError, "objects to add() must be %s instances" % self.model._meta.object_name
+ if isinstance(obj, self.model):
+ new_ids.add(obj._get_pk_val())
+ else:
+ new_ids.add(obj)
# Add the newly created or already existing objects to the join table.
# First find out which items are already added, to avoid adding them twice
- new_ids = set([obj._get_pk_val() for obj in objs])
cursor = connection.cursor()
cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \
(target_col_name, self.join_table, source_col_name,
@@ -352,14 +356,16 @@ def create_many_related_manager(superclass):
# If there aren't any objects, there is nothing to do.
if objs:
# Check that all the objects are of the right type
+ old_ids = set()
for obj in objs:
- if not isinstance(obj, self.model):
- raise ValueError, "objects to remove() must be %s instances" % self.model._meta.object_name
+ if isinstance(obj, self.model):
+ old_ids.add(obj._get_pk_val())
+ else:
+ old_ids.add(obj)
# Remove the specified objects from the join table
- old_ids = set([obj._get_pk_val() for obj in objs])
cursor = connection.cursor()
cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \
- (self.join_table, source_col_name,
+ (self.join_table, source_col_name,
target_col_name, ",".join(['%s'] * len(old_ids))),
[self._pk_val] + list(old_ids))
transaction.commit_unless_managed()
@@ -468,7 +474,7 @@ class ForeignKey(RelatedField, Field):
to_field = to_field or to._meta.pk.name
kwargs['verbose_name'] = kwargs.get('verbose_name', '')
- if kwargs.has_key('edit_inline_type'):
+ if 'edit_inline_type' in kwargs:
import warnings
warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.")
kwargs['edit_inline'] = kwargs.pop('edit_inline_type')
@@ -546,6 +552,11 @@ class ForeignKey(RelatedField, Field):
def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()}
+ defaults.update(kwargs)
+ return super(ForeignKey, self).formfield(**defaults)
+
class OneToOneField(RelatedField, IntegerField):
def __init__(self, to, to_field=None, **kwargs):
try:
@@ -556,7 +567,7 @@ class OneToOneField(RelatedField, IntegerField):
to_field = to_field or to._meta.pk.name
kwargs['verbose_name'] = kwargs.get('verbose_name', '')
- if kwargs.has_key('edit_inline_type'):
+ if 'edit_inline_type' in kwargs:
import warnings
warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.")
kwargs['edit_inline'] = kwargs.pop('edit_inline_type')
@@ -607,6 +618,11 @@ class OneToOneField(RelatedField, IntegerField):
if not cls._meta.one_to_one_field:
cls._meta.one_to_one_field = self
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()}
+ defaults.update(kwargs)
+ return super(OneToOneField, self).formfield(**defaults)
+
class ManyToManyField(RelatedField, Field):
def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
@@ -617,6 +633,7 @@ class ManyToManyField(RelatedField, Field):
limit_choices_to=kwargs.pop('limit_choices_to', None),
raw_id_admin=kwargs.pop('raw_id_admin', False),
symmetrical=kwargs.pop('symmetrical', True))
+ self.db_table = kwargs.pop('db_table', None)
if kwargs["rel"].raw_id_admin:
kwargs.setdefault("validator_list", []).append(self.isValidIDList)
Field.__init__(self, **kwargs)
@@ -639,7 +656,10 @@ class ManyToManyField(RelatedField, Field):
def _get_m2m_db_table(self, opts):
"Function that can be curried to provide the m2m table name for this relation"
- return '%s_%s' % (opts.db_table, self.name)
+ if self.db_table:
+ return self.db_table
+ else:
+ return '%s_%s' % (opts.db_table, self.name)
def _get_m2m_column_name(self, related):
"Function that can be curried to provide the source column name for the m2m table"
@@ -713,6 +733,19 @@ class ManyToManyField(RelatedField, Field):
def set_attributes_from_rel(self):
pass
+ def value_from_object(self, obj):
+ "Returns the value of this field in the given model instance."
+ return getattr(obj, self.attname).all()
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.all()}
+ defaults.update(kwargs)
+ # If initial is passed in, it's a list of related objects, but the
+ # MultipleChoiceField takes a list of IDs.
+ if defaults.get('initial') is not None:
+ defaults['initial'] = [i._get_pk_val() for i in defaults['initial']]
+ return super(ManyToManyField, self).formfield(**defaults)
+
class ManyToOneRel(object):
def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None,
max_num_in_admin=None, num_extra_on_change=1, edit_inline=False,
diff --git a/django/db/models/loading.py b/django/db/models/loading.py
index f4aff2438b..224f5e8451 100644
--- a/django/db/models/loading.py
+++ b/django/db/models/loading.py
@@ -103,7 +103,7 @@ def register_models(app_label, *models):
# in the _app_models dictionary
model_name = model._meta.object_name.lower()
model_dict = _app_models.setdefault(app_label, {})
- if model_dict.has_key(model_name):
+ if model_name in model_dict:
# The same model may be imported via different paths (e.g.
# appname.models and project.appname.models). We use the source
# filename as a means to detect identity.
diff --git a/django/db/models/manager.py b/django/db/models/manager.py
index 6005874516..b60eed262a 100644
--- a/django/db/models/manager.py
+++ b/django/db/models/manager.py
@@ -1,4 +1,4 @@
-from django.db.models.query import QuerySet
+from django.db.models.query import QuerySet, EmptyQuerySet
from django.dispatch import dispatcher
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
@@ -41,12 +41,18 @@ class Manager(object):
#######################
# PROXIES TO QUERYSET #
#######################
+
+ def get_empty_query_set(self):
+ return EmptyQuerySet(self.model)
def get_query_set(self):
"""Returns a new QuerySet object. Subclasses can override this method
to easily customise the behaviour of the Manager.
"""
return QuerySet(self.model)
+
+ def none(self):
+ return self.get_empty_query_set()
def all(self):
return self.get_query_set()
diff --git a/django/db/models/manipulators.py b/django/db/models/manipulators.py
index e9dfa7037c..d5fc5f725e 100644
--- a/django/db/models/manipulators.py
+++ b/django/db/models/manipulators.py
@@ -96,14 +96,16 @@ class AutomaticManipulator(oldforms.Manipulator):
if self.change:
params[self.opts.pk.attname] = self.obj_key
- # First, save the basic object itself.
+ # First, create the basic object itself.
new_object = self.model(**params)
- new_object.save()
- # Now that the object's been saved, save any uploaded files.
+ # Now that the object's been created, save any uploaded files.
for f in self.opts.fields:
if isinstance(f, FileField):
- f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False)
+ f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False, save=False)
+
+ # Now save the object
+ new_object.save()
# Calculate which primary fields have changed.
if self.change:
diff --git a/django/db/models/options.py b/django/db/models/options.py
index ee253ff451..556168e7d0 100644
--- a/django/db/models/options.py
+++ b/django/db/models/options.py
@@ -85,6 +85,7 @@ class Options(object):
self.fields.insert(bisect(self.fields, field), field)
if not self.pk and field.primary_key:
self.pk = field
+ field.serialize = False
def __repr__(self):
return '<Options for %s>' % self.object_name
@@ -140,7 +141,7 @@ class Options(object):
def get_follow(self, override=None):
follow = {}
for f in self.fields + self.many_to_many + self.get_all_related_objects():
- if override and override.has_key(f.name):
+ if override and f.name in override:
child_override = override[f.name]
else:
child_override = None
@@ -182,7 +183,7 @@ class Options(object):
# TODO: follow
if not hasattr(self, '_field_types'):
self._field_types = {}
- if not self._field_types.has_key(field_type):
+ if field_type not in self._field_types:
try:
# First check self.fields.
for f in self.fields:
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 53ed63ae5b..a6e702be18 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1,8 +1,9 @@
from django.db import backend, connection, transaction
from django.db.models.fields import DateField, FieldDoesNotExist
-from django.db.models import signals
+from django.db.models import signals, loading
from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
+from django.contrib.contenttypes import generic
import operator
import re
@@ -25,6 +26,9 @@ QUERY_TERMS = (
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
+class EmptyResultSet(Exception):
+ pass
+
####################
# HELPER FUNCTIONS #
####################
@@ -80,6 +84,7 @@ class QuerySet(object):
self._filters = Q()
self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
self._select_related = False # Whether to fill cache for related objects.
+ self._max_related_depth = 0 # Maximum "depth" for select_related
self._distinct = False # Whether the query should use SELECT DISTINCT.
self._select = {} # Dictionary of attname -> SQL.
self._where = [] # List of extra WHERE clauses to use.
@@ -104,6 +109,8 @@ class QuerySet(object):
def __getitem__(self, k):
"Retrieve an item or slice from the set of results."
+ if not isinstance(k, (slice, int)):
+ raise TypeError
assert (not isinstance(k, slice) and (k >= 0)) \
or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \
"Negative indexing is not supported."
@@ -163,12 +170,16 @@ class QuerySet(object):
def iterator(self):
"Performs the SELECT database lookup of this QuerySet."
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
# self._select is a dictionary, and dictionaries' key order is
# undefined, so we convert it to a list of tuples.
extra_select = self._select.items()
cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
fill_cache = self._select_related
index_end = len(self.model._meta.fields)
@@ -178,7 +189,8 @@ class QuerySet(object):
raise StopIteration
for row in rows:
if fill_cache:
- obj, index_end = get_cached_row(self.model, row, 0)
+ obj, index_end = get_cached_row(klass=self.model, row=row,
+ index_start=0, max_depth=self._max_related_depth)
else:
obj = self.model(*row[:index_end])
for i, k in enumerate(extra_select):
@@ -186,13 +198,31 @@ class QuerySet(object):
yield obj
def count(self):
- "Performs a SELECT COUNT() and returns the number of records as an integer."
+ """
+ Performs a SELECT COUNT() and returns the number of records as an
+ integer.
+
+ If the queryset is already cached (i.e. self._result_cache is set) this
+ simply returns the length of the cached results set to avoid multiple
+ SELECT COUNT(*) calls.
+ """
+ if self._result_cache is not None:
+ return len(self._result_cache)
+
counter = self._clone()
counter._order_by = ()
+ counter._select_related = False
+
+ offset = counter._offset
+ limit = counter._limit
counter._offset = None
counter._limit = None
- counter._select_related = False
- select, sql, params = counter._get_sql_clause()
+
+ try:
+ select, sql, params = counter._get_sql_clause()
+ except EmptyResultSet:
+ return 0
+
cursor = connection.cursor()
if self._distinct:
id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
@@ -200,7 +230,16 @@ class QuerySet(object):
cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
else:
cursor.execute("SELECT COUNT(*)" + sql, params)
- return cursor.fetchone()[0]
+ count = cursor.fetchone()[0]
+
+ # Apply any offset and limit constraints manually, since using LIMIT or
+ # OFFSET in SQL doesn't change the output of COUNT.
+ if offset:
+ count = max(0, count - offset)
+ if limit:
+ count = min(limit, count)
+
+ return count
def get(self, *args, **kwargs):
"Performs the SELECT and returns a single object matching the given keyword arguments."
@@ -359,9 +398,9 @@ class QuerySet(object):
else:
return self._filter_or_exclude(None, **filter_obj)
- def select_related(self, true_or_false=True):
+ def select_related(self, true_or_false=True, depth=0):
"Returns a new QuerySet instance with '_select_related' modified."
- return self._clone(_select_related=true_or_false)
+ return self._clone(_select_related=true_or_false, _max_related_depth=depth)
def order_by(self, *field_names):
"Returns a new QuerySet instance with the ordering changed."
@@ -395,6 +434,7 @@ class QuerySet(object):
c._filters = self._filters
c._order_by = self._order_by
c._select_related = self._select_related
+ c._max_related_depth = self._max_related_depth
c._distinct = self._distinct
c._select = self._select.copy()
c._where = self._where[:]
@@ -448,7 +488,10 @@ class QuerySet(object):
# Add additional tables and WHERE clauses based on select_related.
if self._select_related:
- fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
+ fill_table_cache(opts, select, tables, where,
+ old_prefix=opts.db_table,
+ cache_tables_seen=[opts.db_table],
+ max_depth=self._max_related_depth)
# Add any additional SELECTs.
if self._select:
@@ -509,22 +552,42 @@ class QuerySet(object):
return select, " ".join(sql), params
class ValuesQuerySet(QuerySet):
- def iterator(self):
- # select_related and select aren't supported in values().
+ def __init__(self, *args, **kwargs):
+ super(ValuesQuerySet, self).__init__(*args, **kwargs)
+ # select_related isn't supported in values().
self._select_related = False
- self._select = {}
+
+ def iterator(self):
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
# self._fields is a list of field names to fetch.
if self._fields:
- columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ #columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ if not self._select:
+ columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ else:
+ columns = []
+ for f in self._fields:
+ if f in [field.name for field in self.model._meta.fields]:
+ columns.append( self.model._meta.get_field(f, many_to_many=False).column )
+ elif not self._select.has_key( f ):
+ raise FieldDoesNotExist, '%s has no field named %r' % ( self.model._meta.object_name, f )
+
field_names = self._fields
else: # Default to all fields.
columns = [f.column for f in self.model._meta.fields]
field_names = [f.attname for f in self.model._meta.fields]
- cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns]
+
+ # Add any additional SELECTs.
+ if self._select:
+ select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()])
+
+ cursor = connection.cursor()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
@@ -545,7 +608,12 @@ class DateQuerySet(QuerySet):
if self._field.null:
self._where.append('%s.%s IS NOT NULL' % \
(backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column)))
- select, sql, params = self._get_sql_clause()
+
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
(backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table),
backend.quote_name(self._field.column))), sql, self._order)
@@ -563,6 +631,25 @@ class DateQuerySet(QuerySet):
c._order = self._order
return c
+class EmptyQuerySet(QuerySet):
+ def __init__(self, model=None):
+ super(EmptyQuerySet, self).__init__(model)
+ self._result_cache = []
+
+ def count(self):
+ return 0
+
+ def delete(self):
+ pass
+
+ def _clone(self, klass=None, **kwargs):
+ c = super(EmptyQuerySet, self)._clone(klass, **kwargs)
+ c._result_cache = []
+ return c
+
+ def _get_sql_clause(self):
+ raise EmptyResultSet
+
class QOperator(object):
"Base class for QAnd and QOr"
def __init__(self, *args):
@@ -571,10 +658,14 @@ class QOperator(object):
def get_sql(self, opts):
joins, where, params = SortedDict(), [], []
for val in self.args:
- joins2, where2, params2 = val.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
+ try:
+ joins2, where2, params2 = val.get_sql(opts)
+ joins.update(joins2)
+ where.extend(where2)
+ params.extend(params2)
+ except EmptyResultSet:
+ if not isinstance(self, QOr):
+ raise EmptyResultSet
if where:
return joins, ['(%s)' % self.operator.join(where)], params
return joins, [], params
@@ -628,8 +719,11 @@ class QNot(Q):
self.q = q
def get_sql(self, opts):
- joins, where, params = self.q.get_sql(opts)
- where2 = ['(NOT (%s))' % " AND ".join(where)]
+ try:
+ joins, where, params = self.q.get_sql(opts)
+ where2 = ['(NOT (%s))' % " AND ".join(where)]
+ except EmptyResultSet:
+ return SortedDict(), [], []
return joins, where2, params
def get_where_clause(lookup_type, table_prefix, field_name, value):
@@ -641,10 +735,14 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
except KeyError:
pass
if lookup_type == 'in':
- return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value]))
- elif lookup_type == 'range':
+ in_string = ','.join(['%s' for id in value])
+ if in_string:
+ return '%s%s IN (%s)' % (table_prefix, field_name, in_string)
+ else:
+ raise EmptyResultSet
+ elif lookup_type in ('range', 'year'):
return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name)
- elif lookup_type in ('year', 'month', 'day'):
+ elif lookup_type in ('month', 'day'):
return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name)
elif lookup_type == 'isnull':
return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or ''))
@@ -652,21 +750,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
return backend.get_fulltext_search_sql(table_prefix + field_name)
raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
-def get_cached_row(klass, row, index_start):
- "Helper function that recursively returns an object with cache filled"
+def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0):
+ """Helper function that recursively returns an object with cache filled"""
+
+ # If we've got a max_depth set and we've exceeded that depth, bail now.
+ if max_depth and cur_depth > max_depth:
+ return None
+
index_end = index_start + len(klass._meta.fields)
obj = klass(*row[index_start:index_end])
for f in klass._meta.fields:
if f.rel and not f.null:
- rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
- setattr(obj, f.get_cache_name(), rel_obj)
+ cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1)
+ if cached_row:
+ rel_obj, index_end = cached_row
+ setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
-def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
+def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0):
"""
Helper function that recursively populates the select, tables and where (in
place) for select_related queries.
"""
+
+ # If we've got a max_depth set and we've exceeded that depth, bail now.
+ if max_depth and cur_depth > max_depth:
+ return None
+
qn = backend.quote_name
for f in opts.fields:
if f.rel and not f.null:
@@ -681,12 +791,12 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen)
where.append('%s.%s = %s.%s' % \
(qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
- fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
+ fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1)
def parse_lookup(kwarg_items, opts):
# Helper function that handles converting API kwargs
# (e.g. "name__exact": "tom") to SQL.
- # Returns a tuple of (tables, joins, where, params).
+ # Returns a tuple of (joins, where, params).
# 'joins' is a sorted dictionary describing the tables that must be joined
# to complete the query. The dictionary is sorted because creation order
@@ -725,12 +835,14 @@ def parse_lookup(kwarg_items, opts):
if len(path) < 1:
raise TypeError, "Cannot parse keyword query %r" % kwarg
-
+
if value is None:
# Interpret '__exact=None' as the sql '= NULL'; otherwise, reject
# all uses of None as a query value.
if lookup_type != 'exact':
raise ValueError, "Cannot use None as a query value"
+ elif callable(value):
+ value = value()
joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None)
joins.update(joins2)
@@ -755,6 +867,13 @@ def find_field(name, field_list, related_query):
return None
return matches[0]
+def field_choices(field_list, related_query):
+ if related_query:
+ choices = [f.field.related_query_name() for f in field_list]
+ else:
+ choices = [f.name for f in field_list]
+ return choices
+
def lookup_inner(path, lookup_type, value, opts, table, column):
qn = backend.quote_name
joins, where, params = SortedDict(), [], []
@@ -827,13 +946,23 @@ def lookup_inner(path, lookup_type, value, opts, table, column):
new_opts = field.rel.to._meta
new_column = new_opts.pk.column
join_column = field.column
-
- raise FieldFound
+ raise FieldFound
+ elif path:
+ # For regular fields, if there are still items on the path,
+ # an error has been made. We munge "name" so that the error
+ # properly identifies the cause of the problem.
+ name += LOOKUP_SEPARATOR + path[0]
+ else:
+ raise FieldFound
except FieldFound: # Match found, loop has been shortcut.
pass
else: # No match found.
- raise TypeError, "Cannot resolve keyword '%s' into field" % name
+ choices = field_choices(current_opts.many_to_many, False) + \
+ field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \
+ field_choices(current_opts.get_all_related_objects(), True) + \
+ field_choices(current_opts.fields, False)
+ raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices))
# Check whether an intermediate join is required between current_table
# and new_table.
@@ -926,18 +1055,26 @@ def delete_objects(seen_objs):
pk_list = [pk for pk,instance in seen_objs[cls]]
for related in cls._meta.get_all_related_many_to_many_objects():
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(related.field.m2m_db_table()),
- qn(related.field.m2m_reverse_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ if not isinstance(related.field, generic.GenericRelation):
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ (qn(related.field.m2m_db_table()),
+ qn(related.field.m2m_reverse_name()),
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
for f in cls._meta.many_to_many:
+ if isinstance(f, generic.GenericRelation):
+ from django.contrib.contenttypes.models import ContentType
+ query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column
+ args_extra = [ContentType.objects.get_for_model(cls).id]
+ else:
+ query_extra = ''
+ args_extra = []
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \
(qn(f.m2m_db_table()), qn(f.m2m_column_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra,
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra)
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
diff --git a/django/db/models/related.py b/django/db/models/related.py
index ac1ec50ca2..2c1dc5c516 100644
--- a/django/db/models/related.py
+++ b/django/db/models/related.py
@@ -1,7 +1,7 @@
class BoundRelatedObject(object):
def __init__(self, related_object, field_mapping, original):
self.relation = related_object
- self.field_mappings = field_mapping[related_object.opts.module_name]
+ self.field_mappings = field_mapping[related_object.name]
def template_name(self):
raise NotImplementedError
@@ -16,7 +16,7 @@ class RelatedObject(object):
self.opts = model._meta
self.field = field
self.edit_inline = field.rel.edit_inline
- self.name = self.opts.module_name
+ self.name = '%s:%s' % (self.opts.app_label, self.opts.module_name)
self.var_name = self.opts.object_name.lower()
def flatten_data(self, follow, obj=None):
@@ -68,7 +68,10 @@ class RelatedObject(object):
# object
return [attr]
else:
- return [None] * self.field.rel.num_in_admin
+ if self.field.rel.min_num_in_admin:
+ return [None] * max(self.field.rel.num_in_admin, self.field.rel.min_num_in_admin)
+ else:
+ return [None] * self.field.rel.num_in_admin
def get_db_prep_lookup(self, lookup_type, value):
# Defer to the actual field definition for db prep
@@ -101,12 +104,12 @@ class RelatedObject(object):
attr = getattr(manipulator.original_object, self.get_accessor_name())
count = attr.count()
count += self.field.rel.num_extra_on_change
- if self.field.rel.min_num_in_admin:
- count = max(count, self.field.rel.min_num_in_admin)
- if self.field.rel.max_num_in_admin:
- count = min(count, self.field.rel.max_num_in_admin)
else:
count = self.field.rel.num_in_admin
+ if self.field.rel.min_num_in_admin:
+ count = max(count, self.field.rel.min_num_in_admin)
+ if self.field.rel.max_num_in_admin:
+ count = min(count, self.field.rel.max_num_in_admin)
else:
count = 1
diff --git a/django/db/transaction.py b/django/db/transaction.py
index 4a0658e1c3..bb90713525 100644
--- a/django/db/transaction.py
+++ b/django/db/transaction.py
@@ -46,12 +46,12 @@ def enter_transaction_management():
when no current block is running).
"""
thread_ident = thread.get_ident()
- if state.has_key(thread_ident) and state[thread_ident]:
+ if thread_ident in state and state[thread_ident]:
state[thread_ident].append(state[thread_ident][-1])
else:
state[thread_ident] = []
state[thread_ident].append(settings.TRANSACTIONS_MANAGED)
- if not dirty.has_key(thread_ident):
+ if thread_ident not in dirty:
dirty[thread_ident] = False
def leave_transaction_management():
@@ -61,7 +61,7 @@ def leave_transaction_management():
those from outside. (Commits are on connection level.)
"""
thread_ident = thread.get_ident()
- if state.has_key(thread_ident) and state[thread_ident]:
+ if thread_ident in state and state[thread_ident]:
del state[thread_ident][-1]
else:
raise TransactionManagementError("This code isn't under transaction management")
@@ -84,7 +84,7 @@ def set_dirty():
changes waiting for commit.
"""
thread_ident = thread.get_ident()
- if dirty.has_key(thread_ident):
+ if thread_ident in dirty:
dirty[thread_ident] = True
else:
raise TransactionManagementError("This code isn't under transaction management")
@@ -96,7 +96,7 @@ def set_clean():
should happen.
"""
thread_ident = thread.get_ident()
- if dirty.has_key(thread_ident):
+ if thread_ident in dirty:
dirty[thread_ident] = False
else:
raise TransactionManagementError("This code isn't under transaction management")
@@ -106,7 +106,7 @@ def is_managed():
Checks whether the transaction manager is in manual or in auto state.
"""
thread_ident = thread.get_ident()
- if state.has_key(thread_ident):
+ if thread_ident in state:
if state[thread_ident]:
return state[thread_ident][-1]
return settings.TRANSACTIONS_MANAGED