diff options
| author | Christopher Long <indirecthit@gmail.com> | 2007-06-17 22:18:54 +0000 |
|---|---|---|
| committer | Christopher Long <indirecthit@gmail.com> | 2007-06-17 22:18:54 +0000 |
| commit | ae22b6d403dcf25098c77f0dfcf59ae58b186461 (patch) | |
| tree | c37fc631e99a7e4d909d6b6d236f495003731ea7 /django/db | |
| parent | 0cf7bc439129c66df8d64601e885f83b256b4f25 (diff) | |
per-object-permissions: Merged to trunk [5486] NOTE: Not fully tested, will be working on this over the next few weeks.
git-svn-id: http://code.djangoproject.com/svn/django/branches/per-object-permissions@5488 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db')
36 files changed, 1351 insertions, 514 deletions
diff --git a/django/db/__init__.py b/django/db/__init__.py index 4176b5aa79..33223d200a 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -2,7 +2,7 @@ from django.conf import settings from django.core import signals from django.dispatch import dispatcher -__all__ = ('backend', 'connection', 'DatabaseError') +__all__ = ('backend', 'connection', 'DatabaseError', 'IntegrityError') if not settings.DATABASE_ENGINE: settings.DATABASE_ENGINE = 'dummy' @@ -29,6 +29,7 @@ runshell = lambda: __import__('django.db.backends.%s.client' % settings.DATABASE connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS) DatabaseError = backend.DatabaseError +IntegrityError = backend.IntegrityError # Register an event that closes the database connection # when a Django request is finished. diff --git a/django/db/backends/ado_mssql/base.py b/django/db/backends/ado_mssql/base.py index 72d2fe083e..52363ed705 100644 --- a/django/db/backends/ado_mssql/base.py +++ b/django/db/backends/ado_mssql/base.py @@ -17,6 +17,7 @@ except ImportError: mx = None DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError # We need to use a special Cursor class because adodbapi expects question-mark # param style, but Django expects "%s". This cursor converts question marks to @@ -76,10 +77,11 @@ class DatabaseWrapper(local): return cursor def _commit(self): - return self.connection.commit() + if self.connection is not None: + return self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: return self.connection.rollback() def close(self): @@ -125,6 +127,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "RAND()" +def get_deferrable_sql(): + return " DEFERRABLE INITIALLY DEFERRED" + def get_fulltext_search_sql(field_name): raise NotImplementedError @@ -134,6 +139,24 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "DEFAULT" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + """ + # Return a list of 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements + # TODO - SQL not actually tested against ADO MSSQL yet! + # TODO - autoincrement indices reset required? See other get_sql_flush() implementations + sql_list = ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + # No sequence reset required + return [] + OPERATOR_MAPPING = { 'exact': '= %s', 'iexact': 'LIKE %s', diff --git a/django/db/backends/ado_mssql/creation.py b/django/db/backends/ado_mssql/creation.py index 4d85d27ea5..a1098ea43e 100644 --- a/django/db/backends/ado_mssql/creation.py +++ b/django/db/backends/ado_mssql/creation.py @@ -5,9 +5,10 @@ DATA_TYPES = { 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)', 'DateField': 'smalldatetime', 'DateTimeField': 'smalldatetime', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', 'FileField': 'varchar(100)', 'FilePathField': 'varchar(100)', - 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FloatField': 'double precision', 'ImageField': 'varchar(100)', 'IntegerField': 'int', 'IPAddressField': 'char(15)', @@ -21,6 +22,5 @@ DATA_TYPES = { 'SmallIntegerField': 'smallint', 'TextField': 'text', 'TimeField': 'time', - 'URLField': 'varchar(200)', 'USStateField': 'varchar(2)', } diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index f98afc48bb..d0ec897407 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -12,13 +12,19 @@ from django.core.exceptions import ImproperlyConfigured def complain(*args, **kwargs): raise ImproperlyConfigured, "You haven't set the DATABASE_ENGINE setting yet." +def ignore(*args, **kwargs): + pass + class DatabaseError(Exception): pass +class IntegrityError(DatabaseError): + pass + class DatabaseWrapper: cursor = complain _commit = complain - _rollback = complain + _rollback = ignore def __init__(self, **kwargs): pass @@ -36,6 +42,10 @@ get_date_extract_sql = complain get_date_trunc_sql = complain get_limit_offset_sql = complain get_random_function_sql = complain +get_deferrable_sql = complain get_fulltext_search_sql = complain get_drop_foreignkey_sql = complain +get_sql_flush = complain +get_sql_sequence_reset = complain + OPERATOR_MAPPING = {} diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index e7e060e6c2..d4cb1fa964 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -10,19 +10,34 @@ try: except ImportError, e: from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured, "Error loading MySQLdb module: %s" % e + +# We want version (1, 2, 1, 'final', 2) or later. We can't just use +# lexicographic ordering in this check because then (1, 2, 1, 'gamma') +# inadvertently passes the version test. +version = Database.version_info +if (version < (1,2,1) or (version[:3] == (1, 2, 1) and + (len(version) < 5 or version[3] != 'final' or version[4] < 2))): + raise ImportError, "MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__ + from MySQLdb.converters import conversions from MySQLdb.constants import FIELD_TYPE import types import re DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError +# MySQLdb-1.2.1 supports the Python boolean type, and only uses datetime +# module for time-related columns; older versions could have used mx.DateTime +# or strings if there were no datetime module. However, MySQLdb still returns +# TIME columns as timedelta -- they are more like timedelta in terms of actual +# behavior as they are signed and include days -- and Django expects time, so +# we still need to override that. django_conversions = conversions.copy() django_conversions.update({ - types.BooleanType: util.rev_typecast_boolean, - FIELD_TYPE.DATETIME: util.typecast_timestamp, - FIELD_TYPE.DATE: util.typecast_date, FIELD_TYPE.TIME: util.typecast_time, + FIELD_TYPE.DECIMAL: util.typecast_decimal, + FIELD_TYPE.NEWDECIMAL: util.typecast_decimal, }) # This should match the numerical portion of the version numbers (we can treat @@ -31,31 +46,12 @@ django_conversions.update({ # http://dev.mysql.com/doc/refman/5.0/en/news.html . server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') -# This is an extra debug layer over MySQL queries, to display warnings. -# It's only used when DEBUG=True. -class MysqlDebugWrapper: - def __init__(self, cursor): - self.cursor = cursor - - def execute(self, sql, params=()): - try: - return self.cursor.execute(sql, params) - except Database.Warning, w: - self.cursor.execute("SHOW WARNINGS") - raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall()) - - def executemany(self, sql, param_list): - try: - return self.cursor.executemany(sql, param_list) - except Database.Warning, w: - self.cursor.execute("SHOW WARNINGS") - raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall()) - - def __getattr__(self, attr): - if self.__dict__.has_key(attr): - return self.__dict__[attr] - else: - return getattr(self.cursor, attr) +# MySQLdb-1.2.1 and newer automatically makes use of SHOW WARNINGS on +# MySQL-4.1 and newer, so the MysqlDebugWrapper is unnecessary. Since the +# point is to raise Warnings as exceptions, this can be done with the Python +# warning module, and this is setup when the connection is created, and the +# standard util.CursorDebugWrapper can be used. Also, using sql_mode +# TRADITIONAL will automatically cause most warnings to be treated as errors. try: # Only exists in Python 2.4+ @@ -83,33 +79,41 @@ class DatabaseWrapper(local): def cursor(self): from django.conf import settings + from warnings import filterwarnings if not self._valid_connection(): kwargs = { - 'user': settings.DATABASE_USER, - 'db': settings.DATABASE_NAME, - 'passwd': settings.DATABASE_PASSWORD, 'conv': django_conversions, + 'charset': 'utf8', + 'use_unicode': False, } + if settings.DATABASE_USER: + kwargs['user'] = settings.DATABASE_USER + if settings.DATABASE_NAME: + kwargs['db'] = settings.DATABASE_NAME + if settings.DATABASE_PASSWORD: + kwargs['passwd'] = settings.DATABASE_PASSWORD if settings.DATABASE_HOST.startswith('/'): kwargs['unix_socket'] = settings.DATABASE_HOST - else: + elif settings.DATABASE_HOST: kwargs['host'] = settings.DATABASE_HOST if settings.DATABASE_PORT: kwargs['port'] = int(settings.DATABASE_PORT) kwargs.update(self.options) self.connection = Database.connect(**kwargs) - cursor = self.connection.cursor() - if self.connection.get_server_info() >= '4.1': - cursor.execute("SET NAMES 'utf8'") + cursor = self.connection.cursor() + else: + cursor = self.connection.cursor() if settings.DEBUG: - return util.CursorDebugWrapper(MysqlDebugWrapper(cursor), self) + filterwarnings("error", category=Database.Warning) + return util.CursorDebugWrapper(cursor, self) return cursor def _commit(self): - self.connection.commit() + if self.connection is not None: + self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: try: self.connection.rollback() except Database.NotSupportedError: @@ -172,6 +176,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "RAND()" +def get_deferrable_sql(): + return "" + def get_fulltext_search_sql(field_name): return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name @@ -181,6 +188,41 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "DEFAULT" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + + """ + # NB: The generated SQL below is specific to MySQL + # 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements + # to clear all tables of all data + if tables: + sql = ['SET FOREIGN_KEY_CHECKS = 0;'] + \ + ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + \ + ['SET FOREIGN_KEY_CHECKS = 1;'] + + # 'ALTER TABLE table AUTO_INCREMENT = 1;'... style SQL statements + # to reset sequence indices + sql.extend(["%s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('TABLE'), + style.SQL_TABLE(quote_name(sequence['table'])), + style.SQL_KEYWORD('AUTO_INCREMENT'), + style.SQL_FIELD('= 1'), + ) for sequence in sequences]) + return sql + else: + return [] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + # No sequence reset required + return [] + OPERATOR_MAPPING = { 'exact': '= %s', 'iexact': 'LIKE %s', diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index f9d6297b8e..116074a9ce 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -3,12 +3,25 @@ import os def runshell(): args = [''] - args += ["--user=%s" % settings.DATABASE_USER] - if settings.DATABASE_PASSWORD: - args += ["--password=%s" % settings.DATABASE_PASSWORD] - if settings.DATABASE_HOST: - args += ["--host=%s" % settings.DATABASE_HOST] - if settings.DATABASE_PORT: - args += ["--port=%s" % settings.DATABASE_PORT] - args += [settings.DATABASE_NAME] + db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME) + user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER) + passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD) + host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST) + port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT) + defaults_file = settings.DATABASE_OPTIONS.get('read_default_file') + # Seems to be no good way to set sql_mode with CLI + + if defaults_file: + args += ["--defaults-file=%s" % defaults_file] + if user: + args += ["--user=%s" % user] + if passwd: + args += ["--password=%s" % passwd] + if host: + args += ["--host=%s" % host] + if port: + args += ["--port=%s" % port] + if db: + args += [db] + os.execvp('mysql', args) diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py index 116b490124..1b23fbff6e 100644 --- a/django/db/backends/mysql/creation.py +++ b/django/db/backends/mysql/creation.py @@ -9,9 +9,10 @@ DATA_TYPES = { 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)', 'DateField': 'date', 'DateTimeField': 'datetime', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', 'FileField': 'varchar(100)', 'FilePathField': 'varchar(100)', - 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FloatField': 'double precision', 'ImageField': 'varchar(100)', 'IntegerField': 'integer', 'IPAddressField': 'char(15)', @@ -25,6 +26,5 @@ DATA_TYPES = { 'SmallIntegerField': 'smallint', 'TextField': 'longtext', 'TimeField': 'time', - 'URLField': 'varchar(200)', 'USStateField': 'varchar(2)', } diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 7829457fa9..39733311c5 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -76,7 +76,7 @@ def get_indexes(cursor, table_name): DATA_TYPES_REVERSE = { FIELD_TYPE.BLOB: 'TextField', FIELD_TYPE.CHAR: 'CharField', - FIELD_TYPE.DECIMAL: 'FloatField', + FIELD_TYPE.DECIMAL: 'DecimalField', FIELD_TYPE.DATE: 'DateField', FIELD_TYPE.DATETIME: 'DateTimeField', FIELD_TYPE.DOUBLE: 'FloatField', @@ -85,7 +85,7 @@ DATA_TYPES_REVERSE = { FIELD_TYPE.LONG: 'IntegerField', FIELD_TYPE.LONGLONG: 'IntegerField', FIELD_TYPE.SHORT: 'IntegerField', - FIELD_TYPE.STRING: 'TextField', + FIELD_TYPE.STRING: 'CharField', FIELD_TYPE.TIMESTAMP: 'DateTimeField', FIELD_TYPE.TINY: 'IntegerField', FIELD_TYPE.TINY_BLOB: 'TextField', diff --git a/django/db/backends/mysql_old/__init__.py b/django/db/backends/mysql_old/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/django/db/backends/mysql_old/__init__.py diff --git a/django/db/backends/mysql_old/base.py b/django/db/backends/mysql_old/base.py new file mode 100644 index 0000000000..ac3b75efde --- /dev/null +++ b/django/db/backends/mysql_old/base.py @@ -0,0 +1,240 @@ +""" +MySQL database backend for Django. + +Requires MySQLdb: http://sourceforge.net/projects/mysql-python +""" + +from django.db.backends import util +try: + import MySQLdb as Database +except ImportError, e: + from django.core.exceptions import ImproperlyConfigured + raise ImproperlyConfigured, "Error loading MySQLdb module: %s" % e +from MySQLdb.converters import conversions +from MySQLdb.constants import FIELD_TYPE +import types +import re + +DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError + +django_conversions = conversions.copy() +django_conversions.update({ + types.BooleanType: util.rev_typecast_boolean, + FIELD_TYPE.DATETIME: util.typecast_timestamp, + FIELD_TYPE.DATE: util.typecast_date, + FIELD_TYPE.TIME: util.typecast_time, + FIELD_TYPE.DECIMAL: util.typecast_decimal, +}) + +# This should match the numerical portion of the version numbers (we can treat +# versions like 5.0.24 and 5.0.24a as the same). Based on the list of version +# at http://dev.mysql.com/doc/refman/4.1/en/news.html and +# http://dev.mysql.com/doc/refman/5.0/en/news.html . +server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') + +# This is an extra debug layer over MySQL queries, to display warnings. +# It's only used when DEBUG=True. +class MysqlDebugWrapper: + def __init__(self, cursor): + self.cursor = cursor + + def execute(self, sql, params=()): + try: + return self.cursor.execute(sql, params) + except Database.Warning, w: + self.cursor.execute("SHOW WARNINGS") + raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall()) + + def executemany(self, sql, param_list): + try: + return self.cursor.executemany(sql, param_list) + except Database.Warning, w: + self.cursor.execute("SHOW WARNINGS") + raise Database.Warning, "%s: %s" % (w, self.cursor.fetchall()) + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + else: + return getattr(self.cursor, attr) + +try: + # Only exists in Python 2.4+ + from threading import local +except ImportError: + # Import copy of _thread_local.py from Python 2.4 + from django.utils._threading_local import local + +class DatabaseWrapper(local): + def __init__(self, **kwargs): + self.connection = None + self.queries = [] + self.server_version = None + self.options = kwargs + + def _valid_connection(self): + if self.connection is not None: + try: + self.connection.ping() + return True + except DatabaseError: + self.connection.close() + self.connection = None + return False + + def cursor(self): + from django.conf import settings + if not self._valid_connection(): + kwargs = { + 'user': settings.DATABASE_USER, + 'db': settings.DATABASE_NAME, + 'passwd': settings.DATABASE_PASSWORD, + 'conv': django_conversions, + } + if settings.DATABASE_HOST.startswith('/'): + kwargs['unix_socket'] = settings.DATABASE_HOST + else: + kwargs['host'] = settings.DATABASE_HOST + if settings.DATABASE_PORT: + kwargs['port'] = int(settings.DATABASE_PORT) + kwargs.update(self.options) + self.connection = Database.connect(**kwargs) + cursor = self.connection.cursor() + if self.connection.get_server_info() >= '4.1': + cursor.execute("SET NAMES 'utf8'") + else: + cursor = self.connection.cursor() + if settings.DEBUG: + return util.CursorDebugWrapper(MysqlDebugWrapper(cursor), self) + return cursor + + def _commit(self): + if self.connection is not None: + self.connection.commit() + + def _rollback(self): + if self.connection is not None: + try: + self.connection.rollback() + except Database.NotSupportedError: + pass + + def close(self): + if self.connection is not None: + self.connection.close() + self.connection = None + + def get_server_version(self): + if not self.server_version: + if not self._valid_connection(): + self.cursor() + m = server_version_re.match(self.connection.get_server_info()) + if not m: + raise Exception('Unable to determine MySQL version from version string %r' % self.connection.get_server_info()) + self.server_version = tuple([int(x) for x in m.groups()]) + return self.server_version + +supports_constraints = True + +def quote_name(name): + if name.startswith("`") and name.endswith("`"): + return name # Quoting once is enough. + return "`%s`" % name + +dictfetchone = util.dictfetchone +dictfetchmany = util.dictfetchmany +dictfetchall = util.dictfetchall + +def get_last_insert_id(cursor, table_name, pk_name): + return cursor.lastrowid + +def get_date_extract_sql(lookup_type, table_name): + # lookup_type is 'year', 'month', 'day' + # http://dev.mysql.com/doc/mysql/en/date-and-time-functions.html + return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), table_name) + +def get_date_trunc_sql(lookup_type, field_name): + # lookup_type is 'year', 'month', 'day' + fields = ['year', 'month', 'day', 'hour', 'minute', 'second'] + format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape. + format_def = ('0000-', '01', '-01', ' 00:', '00', ':00') + try: + i = fields.index(lookup_type) + 1 + except ValueError: + sql = field_name + else: + format_str = ''.join([f for f in format[:i]] + [f for f in format_def[i:]]) + sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) + return sql + +def get_limit_offset_sql(limit, offset=None): + sql = "LIMIT " + if offset and offset != 0: + sql += "%s," % offset + return sql + str(limit) + +def get_random_function_sql(): + return "RAND()" + +def get_deferrable_sql(): + return "" + +def get_fulltext_search_sql(field_name): + return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name + +def get_drop_foreignkey_sql(): + return "DROP FOREIGN KEY" + +def get_pk_default_value(): + return "DEFAULT" + +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + + """ + # NB: The generated SQL below is specific to MySQL + # 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements + # to clear all tables of all data + if tables: + sql = ['SET FOREIGN_KEY_CHECKS = 0;'] + \ + ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + \ + ['SET FOREIGN_KEY_CHECKS = 1;'] + + # 'ALTER TABLE table AUTO_INCREMENT = 1;'... style SQL statements + # to reset sequence indices + sql.extend(["%s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('TABLE'), + style.SQL_TABLE(quote_name(sequence['table'])), + style.SQL_KEYWORD('AUTO_INCREMENT'), + style.SQL_FIELD('= 1'), + ) for sequence in sequences]) + return sql + else: + return [] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + # No sequence reset required + return [] + +OPERATOR_MAPPING = { + 'exact': '= %s', + 'iexact': 'LIKE %s', + 'contains': 'LIKE BINARY %s', + 'icontains': 'LIKE %s', + 'gt': '> %s', + 'gte': '>= %s', + 'lt': '< %s', + 'lte': '<= %s', + 'startswith': 'LIKE BINARY %s', + 'endswith': 'LIKE BINARY %s', + 'istartswith': 'LIKE %s', + 'iendswith': 'LIKE %s', +} diff --git a/django/db/backends/mysql_old/client.py b/django/db/backends/mysql_old/client.py new file mode 100644 index 0000000000..f9d6297b8e --- /dev/null +++ b/django/db/backends/mysql_old/client.py @@ -0,0 +1,14 @@ +from django.conf import settings +import os + +def runshell(): + args = [''] + args += ["--user=%s" % settings.DATABASE_USER] + if settings.DATABASE_PASSWORD: + args += ["--password=%s" % settings.DATABASE_PASSWORD] + if settings.DATABASE_HOST: + args += ["--host=%s" % settings.DATABASE_HOST] + if settings.DATABASE_PORT: + args += ["--port=%s" % settings.DATABASE_PORT] + args += [settings.DATABASE_NAME] + os.execvp('mysql', args) diff --git a/django/db/backends/mysql_old/creation.py b/django/db/backends/mysql_old/creation.py new file mode 100644 index 0000000000..1b23fbff6e --- /dev/null +++ b/django/db/backends/mysql_old/creation.py @@ -0,0 +1,30 @@ +# This dictionary maps Field objects to their associated MySQL column +# types, as strings. Column-type strings can contain format strings; they'll +# be interpolated against the values of Field.__dict__ before being output. +# If a column type is set to None, it won't be included in the output. +DATA_TYPES = { + 'AutoField': 'integer AUTO_INCREMENT', + 'BooleanField': 'bool', + 'CharField': 'varchar(%(maxlength)s)', + 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)', + 'DateField': 'date', + 'DateTimeField': 'datetime', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'varchar(100)', + 'FilePathField': 'varchar(100)', + 'FloatField': 'double precision', + 'ImageField': 'varchar(100)', + 'IntegerField': 'integer', + 'IPAddressField': 'char(15)', + 'ManyToManyField': None, + 'NullBooleanField': 'bool', + 'OneToOneField': 'integer', + 'PhoneNumberField': 'varchar(20)', + 'PositiveIntegerField': 'integer UNSIGNED', + 'PositiveSmallIntegerField': 'smallint UNSIGNED', + 'SlugField': 'varchar(%(maxlength)s)', + 'SmallIntegerField': 'smallint', + 'TextField': 'longtext', + 'TimeField': 'time', + 'USStateField': 'varchar(2)', +} diff --git a/django/db/backends/mysql_old/introspection.py b/django/db/backends/mysql_old/introspection.py new file mode 100644 index 0000000000..cb5b8320d9 --- /dev/null +++ b/django/db/backends/mysql_old/introspection.py @@ -0,0 +1,95 @@ +from django.db.backends.mysql_old.base import quote_name +from MySQLdb import ProgrammingError, OperationalError +from MySQLdb.constants import FIELD_TYPE +import re + +foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") + +def get_table_list(cursor): + "Returns a list of table names in the current database." + cursor.execute("SHOW TABLES") + return [row[0] for row in cursor.fetchall()] + +def get_table_description(cursor, table_name): + "Returns a description of the table, with the DB-API cursor.description interface." + cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name)) + return cursor.description + +def _name_to_index(cursor, table_name): + """ + Returns a dictionary of {field_name: field_index} for the given table. + Indexes are 0-based. + """ + return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))]) + +def get_relations(cursor, table_name): + """ + Returns a dictionary of {field_index: (field_index_other_table, other_table)} + representing all relationships to the given table. Indexes are 0-based. + """ + my_field_dict = _name_to_index(cursor, table_name) + constraints = [] + relations = {} + try: + # This should work for MySQL 5.0. + cursor.execute(""" + SELECT column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_name = %s + AND table_schema = DATABASE() + AND referenced_table_name IS NOT NULL + AND referenced_column_name IS NOT NULL""", [table_name]) + constraints.extend(cursor.fetchall()) + except (ProgrammingError, OperationalError): + # Fall back to "SHOW CREATE TABLE", for previous MySQL versions. + # Go through all constraints and save the equal matches. + cursor.execute("SHOW CREATE TABLE %s" % quote_name(table_name)) + for row in cursor.fetchall(): + pos = 0 + while True: + match = foreign_key_re.search(row[1], pos) + if match == None: + break + pos = match.end() + constraints.append(match.groups()) + + for my_fieldname, other_table, other_field in constraints: + other_field_index = _name_to_index(cursor, other_table)[other_field] + my_field_index = my_field_dict[my_fieldname] + relations[my_field_index] = (other_field_index, other_table) + + return relations + +def get_indexes(cursor, table_name): + """ + Returns a dictionary of fieldname -> infodict for the given table, + where each infodict is in the format: + {'primary_key': boolean representing whether it's the primary key, + 'unique': boolean representing whether it's a unique index} + """ + cursor.execute("SHOW INDEX FROM %s" % quote_name(table_name)) + indexes = {} + for row in cursor.fetchall(): + indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])} + return indexes + +DATA_TYPES_REVERSE = { + FIELD_TYPE.BLOB: 'TextField', + FIELD_TYPE.CHAR: 'CharField', + FIELD_TYPE.DECIMAL: 'DecimalField', + FIELD_TYPE.DATE: 'DateField', + FIELD_TYPE.DATETIME: 'DateTimeField', + FIELD_TYPE.DOUBLE: 'FloatField', + FIELD_TYPE.FLOAT: 'FloatField', + FIELD_TYPE.INT24: 'IntegerField', + FIELD_TYPE.LONG: 'IntegerField', + FIELD_TYPE.LONGLONG: 'IntegerField', + FIELD_TYPE.SHORT: 'IntegerField', + FIELD_TYPE.STRING: 'TextField', + FIELD_TYPE.TIMESTAMP: 'DateTimeField', + FIELD_TYPE.TINY: 'IntegerField', + FIELD_TYPE.TINY_BLOB: 'TextField', + FIELD_TYPE.MEDIUM_BLOB: 'TextField', + FIELD_TYPE.LONG_BLOB: 'TextField', + FIELD_TYPE.VAR_STRING: 'CharField', +} diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 3a13f39546..2bc88bb7b9 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -12,6 +12,7 @@ except ImportError, e: raise ImproperlyConfigured, "Error loading cx_Oracle module: %s" % e DatabaseError = Database.Error +IntegrityError = Database.IntegrityError try: # Only exists in Python 2.4+ @@ -43,10 +44,11 @@ class DatabaseWrapper(local): return FormatStylePlaceholderCursor(self.connection) def _commit(self): - self.connection.commit() + if self.connection is not None: + self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: try: self.connection.rollback() except Database.NotSupportedError: @@ -108,6 +110,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "DBMS_RANDOM.RANDOM" +def get_deferrable_sql(): + return " DEFERRABLE INITIALLY DEFERRED" + def get_fulltext_search_sql(field_name): raise NotImplementedError @@ -117,6 +122,24 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "DEFAULT" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + """ + # Return a list of 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements + # TODO - SQL not actually tested against Oracle yet! + # TODO - autoincrement indices reset required? See other get_sql_flush() implementations + sql = ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + # No sequence reset required + return [] + OPERATOR_MAPPING = { 'exact': '= %s', 'iexact': 'LIKE %s', diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py index d45ceb64f5..14a864ac28 100644 --- a/django/db/backends/oracle/creation.py +++ b/django/db/backends/oracle/creation.py @@ -5,9 +5,10 @@ DATA_TYPES = { 'CommaSeparatedIntegerField': 'varchar2(%(maxlength)s)', 'DateField': 'date', 'DateTimeField': 'date', + 'DecimalField': 'number(%(max_digits)s, %(decimal_places)s)', 'FileField': 'varchar2(100)', 'FilePathField': 'varchar2(100)', - 'FloatField': 'number(%(max_digits)s, %(decimal_places)s)', + 'FloatField': 'double precision', 'ImageField': 'varchar2(100)', 'IntegerField': 'integer', 'IPAddressField': 'char(15)', @@ -21,6 +22,5 @@ DATA_TYPES = { 'SmallIntegerField': 'smallint', 'TextField': 'long', 'TimeField': 'timestamp', - 'URLField': 'varchar(200)', 'USStateField': 'varchar(2)', } diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index ecc8f372a8..7634206178 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -46,5 +46,5 @@ DATA_TYPES_REVERSE = { 1114: 'DateTimeField', 1184: 'DateTimeField', 1266: 'TimeField', - 1700: 'FloatField', + 1700: 'DecimalField', } diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index e44bc0b560..fedbb6b7f1 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -12,6 +12,7 @@ except ImportError, e: raise ImproperlyConfigured, "Error loading psycopg module: %s" % e DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError try: # Only exists in Python 2.4+ @@ -20,6 +21,40 @@ except ImportError: # Import copy of _thread_local.py from Python 2.4 from django.utils._threading_local import local +def smart_basestring(s, charset): + if isinstance(s, unicode): + return s.encode(charset) + return s + +class UnicodeCursorWrapper(object): + """ + A thin wrapper around psycopg cursors that allows them to accept Unicode + strings as params. + + This is necessary because psycopg doesn't apply any DB quoting to + parameters that are Unicode strings. If a param is Unicode, this will + convert it to a bytestring using DEFAULT_CHARSET before passing it to + psycopg. + """ + def __init__(self, cursor, charset): + self.cursor = cursor + self.charset = charset + + def execute(self, sql, params=()): + return self.cursor.execute(sql, [smart_basestring(p, self.charset) for p in params]) + + def executemany(self, sql, param_list): + new_param_list = [tuple([smart_basestring(p, self.charset) for p in params]) for params in param_list] + return self.cursor.executemany(sql, new_param_list) + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + else: + return getattr(self.cursor, attr) + +postgres_version = None + class DatabaseWrapper(local): def __init__(self, **kwargs): self.connection = None @@ -28,7 +63,9 @@ class DatabaseWrapper(local): def cursor(self): from django.conf import settings + set_tz = False if self.connection is None: + set_tz = True if settings.DATABASE_NAME == '': from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file." @@ -44,16 +81,23 @@ class DatabaseWrapper(local): self.connection = Database.connect(conn_string, **self.options) self.connection.set_isolation_level(1) # make transactions transparent to all cursors cursor = self.connection.cursor() - cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + if set_tz: + cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET) + global postgres_version + if not postgres_version: + cursor.execute("SELECT version()") + postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')] if settings.DEBUG: return util.CursorDebugWrapper(cursor, self) return cursor def _commit(self): - return self.connection.commit() + if self.connection is not None: + return self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: return self.connection.rollback() def close(self): @@ -103,6 +147,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "RANDOM()" +def get_deferrable_sql(): + return " DEFERRABLE INITIALLY DEFERRED" + def get_fulltext_search_sql(field_name): raise NotImplementedError @@ -112,6 +159,91 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "DEFAULT" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + + """ + if tables: + if postgres_version[0] >= 8 and postgres_version[1] >= 1: + # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to + # truncate tables referenced by a foreign key in any other table. The result is a + # single SQL TRUNCATE statement. + sql = ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(', '.join([quote_name(table) for table in tables])) + )] + else: + # Older versions of Postgres can't do TRUNCATE in a single call, so they must use + # a simple delete. + sql = ['%s %s %s;' % \ + (style.SQL_KEYWORD('DELETE'), + style.SQL_KEYWORD('FROM'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + + # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements + # to reset sequence indices + for sequence_info in sequences: + table_name = sequence_info['table'] + column_name = sequence_info['column'] + if column_name and len(column_name)>0: + # sequence name in this case will be <table>_<column>_seq + sql.append("%s %s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('SEQUENCE'), + style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))), + style.SQL_KEYWORD('RESTART'), + style.SQL_KEYWORD('WITH'), + style.SQL_FIELD('1') + ) + ) + else: + # sequence name in this case will be <table>_id_seq + sql.append("%s %s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('SEQUENCE'), + style.SQL_FIELD(quote_name('%s_id_seq' % table_name)), + style.SQL_KEYWORD('RESTART'), + style.SQL_KEYWORD('WITH'), + style.SQL_FIELD('1') + ) + ) + return sql + else: + return [] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + from django.db import models + output = [] + for model in model_list: + # Use `coalesce` to set the sequence for each model to the max pk value if there are records, + # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true + # if there are records (as the max pk value is already in use), otherwise set it to false. + for f in model._meta.fields: + if isinstance(f, models.AutoField): + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))), + style.SQL_FIELD(quote_name(f.column)), + style.SQL_FIELD(quote_name(f.column)), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(quote_name(model._meta.db_table)))) + break # Only one AutoField is allowed per model, so don't bother continuing. + for f in model._meta.many_to_many: + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())), + style.SQL_FIELD(quote_name('id')), + style.SQL_FIELD(quote_name('id')), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(f.m2m_db_table()))) + return output + # Register these custom typecasts, because Django expects dates/times to be # in Python's native (standard-library) datetime/time format, whereas psycopg # use mx.DateTime by default. @@ -122,6 +254,7 @@ except AttributeError: Database.register_type(Database.new_type((1083,1266), "TIME", util.typecast_time)) Database.register_type(Database.new_type((1114,1184), "TIMESTAMP", util.typecast_timestamp)) Database.register_type(Database.new_type((16,), "BOOLEAN", util.typecast_boolean)) +Database.register_type(Database.new_type((1700,), "NUMERIC", util.typecast_decimal)) OPERATOR_MAPPING = { 'exact': '= %s', diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py index 65a804ec40..4646b68ab8 100644 --- a/django/db/backends/postgresql/creation.py +++ b/django/db/backends/postgresql/creation.py @@ -9,9 +9,10 @@ DATA_TYPES = { 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)', 'DateField': 'date', 'DateTimeField': 'timestamp with time zone', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', 'FileField': 'varchar(100)', 'FilePathField': 'varchar(100)', - 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FloatField': 'double precision', 'ImageField': 'varchar(100)', 'IntegerField': 'integer', 'IPAddressField': 'inet', @@ -25,6 +26,5 @@ DATA_TYPES = { 'SmallIntegerField': 'smallint', 'TextField': 'text', 'TimeField': 'time', - 'URLField': 'varchar(200)', 'USStateField': 'varchar(2)', } diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index 6e1d60c4ff..2605490afd 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -72,6 +72,7 @@ DATA_TYPES_REVERSE = { 21: 'SmallIntegerField', 23: 'IntegerField', 25: 'TextField', + 701: 'FloatField', 869: 'IPAddressField', 1043: 'CharField', 1082: 'DateField', @@ -79,5 +80,5 @@ DATA_TYPES_REVERSE = { 1114: 'DateTimeField', 1184: 'DateTimeField', 1266: 'TimeField', - 1700: 'FloatField', + 1700: 'DecimalField', } diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 04322332dc..d9ad363ac1 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -12,6 +12,7 @@ except ImportError, e: raise ImproperlyConfigured, "Error loading psycopg2 module: %s" % e DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError try: # Only exists in Python 2.4+ @@ -20,6 +21,8 @@ except ImportError: # Import copy of _thread_local.py from Python 2.4 from django.utils._threading_local import local +postgres_version = None + class DatabaseWrapper(local): def __init__(self, **kwargs): self.connection = None @@ -28,7 +31,9 @@ class DatabaseWrapper(local): def cursor(self): from django.conf import settings + set_tz = False if self.connection is None: + set_tz = True if settings.DATABASE_NAME == '': from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured, "You need to specify DATABASE_NAME in your Django settings file." @@ -45,16 +50,22 @@ class DatabaseWrapper(local): self.connection.set_isolation_level(1) # make transactions transparent to all cursors cursor = self.connection.cursor() cursor.tzinfo_factory = None - cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + if set_tz: + cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + global postgres_version + if not postgres_version: + cursor.execute("SELECT version()") + postgres_version = [int(val) for val in cursor.fetchone()[0].split()[1].split('.')] if settings.DEBUG: return util.CursorDebugWrapper(cursor, self) return cursor def _commit(self): - return self.connection.commit() + if self.connection is not None: + return self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: return self.connection.rollback() def close(self): @@ -96,6 +107,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "RANDOM()" +def get_deferrable_sql(): + return " DEFERRABLE INITIALLY DEFERRED" + def get_fulltext_search_sql(field_name): raise NotImplementedError @@ -105,6 +119,88 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "DEFAULT" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + """ + if tables: + if postgres_version[0] >= 8 and postgres_version[1] >= 1: + # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in order to be able to + # truncate tables referenced by a foreign key in any other table. The result is a + # single SQL TRUNCATE statement + sql = ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(', '.join([quote_name(table) for table in tables])) + )] + else: + sql = ['%s %s %s;' % \ + (style.SQL_KEYWORD('DELETE'), + style.SQL_KEYWORD('FROM'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + + # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements + # to reset sequence indices + for sequence in sequences: + table_name = sequence['table'] + column_name = sequence['column'] + if column_name and len(column_name) > 0: + # sequence name in this case will be <table>_<column>_seq + sql.append("%s %s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('SEQUENCE'), + style.SQL_FIELD(quote_name('%s_%s_seq' % (table_name, column_name))), + style.SQL_KEYWORD('RESTART'), + style.SQL_KEYWORD('WITH'), + style.SQL_FIELD('1') + ) + ) + else: + # sequence name in this case will be <table>_id_seq + sql.append("%s %s %s %s %s %s;" % \ + (style.SQL_KEYWORD('ALTER'), + style.SQL_KEYWORD('SEQUENCE'), + style.SQL_FIELD(quote_name('%s_id_seq' % table_name)), + style.SQL_KEYWORD('RESTART'), + style.SQL_KEYWORD('WITH'), + style.SQL_FIELD('1') + ) + ) + return sql + else: + return [] + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + from django.db import models + output = [] + for model in model_list: + # Use `coalesce` to set the sequence for each model to the max pk value if there are records, + # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true + # if there are records (as the max pk value is already in use), otherwise set it to false. + for f in model._meta.fields: + if isinstance(f, models.AutoField): + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(quote_name('%s_%s_seq' % (model._meta.db_table, f.column))), + style.SQL_FIELD(quote_name(f.column)), + style.SQL_FIELD(quote_name(f.column)), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(quote_name(model._meta.db_table)))) + break # Only one AutoField is allowed per model, so don't bother continuing. + for f in model._meta.many_to_many: + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(quote_name('%s_id_seq' % f.m2m_db_table())), + style.SQL_FIELD(quote_name('id')), + style.SQL_FIELD(quote_name('id')), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(f.m2m_db_table()))) + return output + OPERATOR_MAPPING = { 'exact': '= %s', 'iexact': 'ILIKE %s', diff --git a/django/db/backends/postgresql_psycopg2/introspection.py b/django/db/backends/postgresql_psycopg2/introspection.py index a546da8c45..aa45fe7db7 100644 --- a/django/db/backends/postgresql_psycopg2/introspection.py +++ b/django/db/backends/postgresql_psycopg2/introspection.py @@ -72,6 +72,7 @@ DATA_TYPES_REVERSE = { 21: 'SmallIntegerField', 23: 'IntegerField', 25: 'TextField', + 701: 'FloatField', 869: 'IPAddressField', 1043: 'CharField', 1082: 'DateField', @@ -79,5 +80,5 @@ DATA_TYPES_REVERSE = { 1114: 'DateTimeField', 1184: 'DateTimeField', 1266: 'TimeField', - 1700: 'FloatField', + 1700: 'DecimalField', } diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 891320160f..5cd67a32f5 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -17,7 +17,13 @@ except ImportError, e: module = 'sqlite3' raise ImproperlyConfigured, "Error loading %s module: %s" % (module, e) +try: + import decimal +except ImportError: + from django.utils import _decimal as decimal # for Python 2.3 + DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError Database.register_converter("bool", lambda s: str(s) == '1') Database.register_converter("time", util.typecast_time) @@ -25,6 +31,8 @@ Database.register_converter("date", util.typecast_date) Database.register_converter("datetime", util.typecast_timestamp) Database.register_converter("timestamp", util.typecast_timestamp) Database.register_converter("TIMESTAMP", util.typecast_timestamp) +Database.register_converter("decimal", util.typecast_decimal) +Database.register_adapter(decimal.Decimal, util.rev_typecast_decimal) def utf8rowFactory(cursor, row): def utf8(s): @@ -67,10 +75,11 @@ class DatabaseWrapper(local): return cursor def _commit(self): - self.connection.commit() + if self.connection is not None: + self.connection.commit() def _rollback(self): - if self.connection: + if self.connection is not None: self.connection.rollback() def close(self): @@ -139,6 +148,9 @@ def get_limit_offset_sql(limit, offset=None): def get_random_function_sql(): return "RANDOM()" +def get_deferrable_sql(): + return "" + def get_fulltext_search_sql(field_name): raise NotImplementedError @@ -148,6 +160,29 @@ def get_drop_foreignkey_sql(): def get_pk_default_value(): return "NULL" +def get_sql_flush(style, tables, sequences): + """Return a list of SQL statements required to remove all data from + all tables in the database (without actually removing the tables + themselves) and put the database in an empty 'initial' state + + """ + # NB: The generated SQL below is specific to SQLite + # Note: The DELETE FROM... SQL generated below works for SQLite databases + # because constraints don't exist + sql = ['%s %s %s;' % \ + (style.SQL_KEYWORD('DELETE'), + style.SQL_KEYWORD('FROM'), + style.SQL_FIELD(quote_name(table)) + ) for table in tables] + # Note: No requirement for reset of auto-incremented indices (cf. other + # get_sql_flush() implementations). Just return SQL at this point + return sql + +def get_sql_sequence_reset(style, model_list): + "Returns a list of the SQL statements to reset sequences for the given models." + # No sequence reset required + return [] + def _sqlite_date_trunc(lookup_type, dt): try: dt = util.typecast_timestamp(dt) diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py index e845179e64..e63046ab7d 100644 --- a/django/db/backends/sqlite3/creation.py +++ b/django/db/backends/sqlite3/creation.py @@ -8,9 +8,10 @@ DATA_TYPES = { 'CommaSeparatedIntegerField': 'varchar(%(maxlength)s)', 'DateField': 'date', 'DateTimeField': 'datetime', + 'DecimalField': 'decimal', 'FileField': 'varchar(100)', 'FilePathField': 'varchar(100)', - 'FloatField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FloatField': 'real', 'ImageField': 'varchar(100)', 'IntegerField': 'integer', 'IPAddressField': 'char(15)', @@ -24,6 +25,5 @@ DATA_TYPES = { 'SmallIntegerField': 'smallint', 'TextField': 'text', 'TimeField': 'time', - 'URLField': 'varchar(200)', 'USStateField': 'varchar(2)', } diff --git a/django/db/backends/util.py b/django/db/backends/util.py index d8f86fef4f..81c752e664 100644 --- a/django/db/backends/util.py +++ b/django/db/backends/util.py @@ -1,6 +1,11 @@ import datetime from time import time +try: + import decimal +except ImportError: + from django.utils import _decimal as decimal # for Python 2.3 + class CursorDebugWrapper(object): def __init__(self, cursor, db): self.cursor = cursor @@ -33,7 +38,7 @@ class CursorDebugWrapper(object): }) def __getattr__(self, attr): - if self.__dict__.has_key(attr): + if attr in self.__dict__: return self.__dict__[attr] else: return getattr(self.cursor, attr) @@ -85,6 +90,11 @@ def typecast_boolean(s): if not s: return False return str(s)[0].lower() == 't' +def typecast_decimal(s): + if s is None or s == '': + return None + return decimal.Decimal(s) + ############################################### # Converters from Python to database (string) # ############################################### @@ -92,6 +102,11 @@ def typecast_boolean(s): def rev_typecast_boolean(obj, d): return obj and '1' or '0' +def rev_typecast_decimal(d): + if d is None: + return None + return str(d) + ################################################################################## # Helper functions for dictfetch* for databases that don't natively support them # ################################################################################## diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 0308dd047a..6c3abb6b59 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -8,7 +8,6 @@ from django.db.models.manager import Manager from django.db.models.base import Model, AdminOptions from django.db.models.fields import * from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel, TABULAR, STACKED -from django.db.models.fields.generic import GenericRelation, GenericRel, GenericForeignKey from django.db.models import signals from django.utils.functional import curry from django.utils.text import capfirst @@ -27,27 +26,3 @@ def permalink(func): viewname = bits[0] return reverse(bits[0], None, *bits[1:3]) return inner - -class LazyDate(object): - """ - Use in limit_choices_to to compare the field to dates calculated at run time - instead of when the model is loaded. For example:: - - ... limit_choices_to = {'date__gt' : models.LazyDate(days=-3)} ... - - which will limit the choices to dates greater than three days ago. - """ - def __init__(self, **kwargs): - self.delta = datetime.timedelta(**kwargs) - - def __str__(self): - return str(self.__get_value__()) - - def __repr__(self): - return "<LazyDate: %s>" % self.delta - - def __get_value__(self): - return (datetime.datetime.now() + self.delta).date() - - def __getattr__(self, attr): - return getattr(self.__get_value__(), attr) diff --git a/django/db/models/base.py b/django/db/models/base.py index 70569a2561..a8e6303e1c 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -13,6 +13,7 @@ from django.dispatch import dispatcher from django.utils.datastructures import SortedDict from django.utils.functional import curry from django.conf import settings +from itertools import izip import types import sys import os @@ -21,8 +22,13 @@ class ModelBase(type): "Metaclass for all models" def __new__(cls, name, bases, attrs): # If this isn't a subclass of Model, don't do anything special. - if not bases or bases == (object,): - return type.__new__(cls, name, bases, attrs) + try: + if not filter(lambda b: issubclass(b, Model), bases): + return super(ModelBase, cls).__new__(cls, name, bases, attrs) + except NameError: + # 'Model' isn't defined yet, meaning we're looking at Django's own + # Model class, defined below. + return super(ModelBase, cls).__new__(cls, name, bases, attrs) # Create the class. new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')}) @@ -36,11 +42,11 @@ class ModelBase(type): new_class._meta.parents.append(base) new_class._meta.parents.extend(base._meta.parents) - model_module = sys.modules[new_class.__module__] if getattr(new_class._meta, 'app_label', None) is None: # Figure out the app_label by looking one level up. # For 'django.contrib.sites.models', this would be 'sites'. + model_module = sys.modules[new_class.__module__] new_class._meta.app_label = model_module.__name__.split('.')[-2] # Bail out early if we have already created this class. @@ -63,7 +69,7 @@ class ModelBase(type): if getattr(new_class._meta, 'row_level_permissions', False): from django.contrib.auth.models import RowLevelPermission - gen_rel = django.db.models.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct") + gen_rel = django.contrib.contenttypes.generic.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct") new_class.add_to_class("row_level_permissions", gen_rel) new_class._prepare() @@ -95,41 +101,74 @@ class Model(object): def __init__(self, *args, **kwargs): dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs) - for f in self._meta.fields: - if isinstance(f.rel, ManyToOneRel): - try: - # Assume object instance was passed in. - rel_obj = kwargs.pop(f.name) - except KeyError: + + # There is a rather weird disparity here; if kwargs, it's set, then args + # overrides it. It should be one or the other; don't duplicate the work + # The reason for the kwargs check is that standard iterator passes in by + # args, and nstantiation for iteration is 33% faster. + args_len = len(args) + if args_len > len(self._meta.fields): + # Daft, but matches old exception sans the err msg. + raise IndexError("Number of args exceeds number of fields") + + fields_iter = iter(self._meta.fields) + if not kwargs: + # The ordering of the izip calls matter - izip throws StopIteration + # when an iter throws it. So if the first iter throws it, the second + # is *not* consumed. We rely on this, so don't change the order + # without changing the logic. + for val, field in izip(args, fields_iter): + setattr(self, field.attname, val) + else: + # Slower, kwargs-ready version. + for val, field in izip(args, fields_iter): + setattr(self, field.attname, val) + kwargs.pop(field.name, None) + # Maintain compatibility with existing calls. + if isinstance(field.rel, ManyToOneRel): + kwargs.pop(field.attname, None) + + # Now we're left with the unprocessed fields that *must* come from + # keywords, or default. + + for field in fields_iter: + if kwargs: + if isinstance(field.rel, ManyToOneRel): try: - # Object instance wasn't passed in -- must be an ID. - val = kwargs.pop(f.attname) + # Assume object instance was passed in. + rel_obj = kwargs.pop(field.name) except KeyError: - val = f.get_default() - else: - # Object instance was passed in. - # Special case: You can pass in "None" for related objects if it's allowed. - if rel_obj is None and f.null: - val = None - else: try: - val = getattr(rel_obj, f.rel.get_related_field().attname) - except AttributeError: - raise TypeError, "Invalid value: %r should be a %s instance, not a %s" % (f.name, f.rel.to, type(rel_obj)) - setattr(self, f.attname, val) + # Object instance wasn't passed in -- must be an ID. + val = kwargs.pop(field.attname) + except KeyError: + val = field.get_default() + else: + # Object instance was passed in. Special case: You can + # pass in "None" for related objects if it's allowed. + if rel_obj is None and field.null: + val = None + else: + try: + val = getattr(rel_obj, field.rel.get_related_field().attname) + except AttributeError: + raise TypeError("Invalid value: %r should be a %s instance, not a %s" % + (field.name, field.rel.to, type(rel_obj))) + else: + val = kwargs.pop(field.attname, field.get_default()) else: - val = kwargs.pop(f.attname, f.get_default()) - setattr(self, f.attname, val) - for prop in kwargs.keys(): - try: - if isinstance(getattr(self.__class__, prop), property): - setattr(self, prop, kwargs.pop(prop)) - except AttributeError: - pass + val = field.get_default() + setattr(self, field.attname, val) + if kwargs: - raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0] - for i, arg in enumerate(args): - setattr(self, self._meta.fields[i].attname, arg) + for prop in kwargs.keys(): + try: + if isinstance(getattr(self.__class__, prop), property): + setattr(self, prop, kwargs.pop(prop)) + except AttributeError: + pass + if kwargs: + raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0] dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self) def add_to_class(cls, name, value): @@ -327,7 +366,7 @@ class Model(object): def _get_FIELD_size(self, field): return os.path.getsize(self._get_FIELD_filename(field)) - def _save_FIELD_file(self, field, filename, raw_contents): + def _save_FIELD_file(self, field, filename, raw_contents, save=True): directory = field.get_directory_name() try: # Create the date-based directory if it doesn't exist. os.makedirs(os.path.join(settings.MEDIA_ROOT, directory)) @@ -362,8 +401,9 @@ class Model(object): if field.height_field: setattr(self, field.height_field, height) - # Save the object, because it has changed. - self.save() + # Save the object because it has changed unless save is False + if save: + self.save() _save_FIELD_file.alters_data = True diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index fe317ac24f..136ce31b8b 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -10,6 +10,10 @@ from django.utils.itercompat import tee from django.utils.text import capfirst from django.utils.translation import gettext, gettext_lazy import datetime, os, time +try: + import decimal +except ImportError: + from django.utils import _decimal as decimal # for Python 2.3 class NOT_PROVIDED: pass @@ -67,7 +71,7 @@ class Field(object): def __init__(self, verbose_name=None, name=None, primary_key=False, maxlength=None, unique=False, blank=False, null=False, db_index=False, - core=False, rel=None, default=NOT_PROVIDED, editable=True, + core=False, rel=None, default=NOT_PROVIDED, editable=True, serialize=True, prepopulate_from=None, unique_for_date=None, unique_for_month=None, unique_for_year=None, validator_list=None, choices=None, radio_admin=None, help_text='', db_column=None): @@ -78,6 +82,7 @@ class Field(object): self.blank, self.null = blank, null self.core, self.rel, self.default = core, rel, default self.editable = editable + self.serialize = serialize self.validator_list = validator_list or [] self.prepopulate_from = prepopulate_from self.unique_for_date, self.unique_for_month = unique_for_date, unique_for_month @@ -164,7 +169,7 @@ class Field(object): def get_db_prep_lookup(self, lookup_type, value): "Returns field's value prepared for database lookup." - if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'year', 'month', 'day', 'search'): + if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'month', 'day', 'search'): return [value] elif lookup_type in ('range', 'in'): return value @@ -178,7 +183,13 @@ class Field(object): return ["%%%s" % prep_for_like_query(value)] elif lookup_type == 'isnull': return [] - raise TypeError, "Field has invalid lookup: %s" % lookup_type + elif lookup_type == 'year': + try: + value = int(value) + except ValueError: + raise ValueError("The __year lookup type requires an integer argument") + return ['%s-01-01 00:00:00' % value, '%s-12-31 23:59:59.999999' % value] + raise TypeError("Field has invalid lookup: %s" % lookup_type) def has_default(self): "Returns a boolean of whether this field has a default value." @@ -334,10 +345,17 @@ class Field(object): return self._choices choices = property(_get_choices) - def formfield(self): + def formfield(self, form_class=forms.CharField, **kwargs): "Returns a django.newforms.Field instance for this database Field." - # TODO: This is just a temporary default during development. - return forms.CharField(required=not self.blank, label=capfirst(self.verbose_name)) + defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text} + if self.choices: + defaults['widget'] = forms.Select(choices=self.get_choices()) + defaults.update(kwargs) + return form_class(**defaults) + + def value_from_object(self, obj): + "Returns the value of this field in the given model instance." + return getattr(obj, self.attname) class AutoField(Field): empty_strings_allowed = False @@ -375,7 +393,7 @@ class AutoField(Field): super(AutoField, self).contribute_to_class(cls, name) cls._meta.has_auto_field = True - def formfield(self): + def formfield(self, **kwargs): return None class BooleanField(Field): @@ -392,8 +410,10 @@ class BooleanField(Field): def get_manipulator_field_objs(self): return [oldforms.CheckboxField] - def formfield(self): - return forms.BooleanField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.BooleanField} + defaults.update(kwargs) + return super(BooleanField, self).formfield(**defaults) class CharField(Field): def get_manipulator_field_objs(self): @@ -409,8 +429,10 @@ class CharField(Field): raise validators.ValidationError, gettext_lazy("This field cannot be null.") return str(value) - def formfield(self): - return forms.CharField(max_length=self.maxlength, required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'max_length': self.maxlength} + defaults.update(kwargs) + return super(CharField, self).formfield(**defaults) # TODO: Maybe move this into contrib, because it's specialized. class CommaSeparatedIntegerField(CharField): @@ -428,6 +450,8 @@ class DateField(Field): Field.__init__(self, verbose_name, name, **kwargs) def to_python(self, value): + if value is None: + return value if isinstance(value, datetime.datetime): return value.date() if isinstance(value, datetime.date): @@ -479,15 +503,19 @@ class DateField(Field): def get_manipulator_field_objs(self): return [oldforms.DateField] - def flatten_data(self, follow, obj = None): + def flatten_data(self, follow, obj=None): val = self._get_val_from_obj(obj) return {self.attname: (val is not None and val.strftime("%Y-%m-%d") or '')} - def formfield(self): - return forms.DateField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.DateField} + defaults.update(kwargs) + return super(DateField, self).formfield(**defaults) class DateTimeField(DateField): def to_python(self, value): + if value is None: + return value if isinstance(value, datetime.datetime): return value if isinstance(value, datetime.date): @@ -544,8 +572,69 @@ class DateTimeField(DateField): return {date_field: (val is not None and val.strftime("%Y-%m-%d") or ''), time_field: (val is not None and val.strftime("%H:%M:%S") or '')} - def formfield(self): - return forms.DateTimeField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.DateTimeField} + defaults.update(kwargs) + return super(DateTimeField, self).formfield(**defaults) + +class DecimalField(Field): + empty_strings_allowed = False + def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs): + self.max_digits, self.decimal_places = max_digits, decimal_places + Field.__init__(self, verbose_name, name, **kwargs) + + def to_python(self, value): + if value is None: + return value + try: + return decimal.Decimal(value) + except decimal.InvalidOperation: + raise validators.ValidationError, gettext("This value must be a decimal number.") + + def _format(self, value): + if isinstance(value, basestring): + return value + else: + return self.format_number(value) + + def format_number(self, value): + """ + Formats a number into a string with the requisite number of digits and + decimal places. + """ + num_chars = self.max_digits + # Allow for a decimal point + if self.decimal_places > 0: + num_chars += 1 + # Allow for a minus sign + if value < 0: + num_chars += 1 + + return "%.*f" % (self.decimal_places, value) + + def get_db_prep_save(self, value): + if value is not None: + value = self._format(value) + return super(DecimalField, self).get_db_prep_save(value) + + def get_db_prep_lookup(self, lookup_type, value): + if lookup_type == 'range': + value = [self._format(v) for v in value] + else: + value = self._format(value) + return super(DecimalField, self).get_db_prep_lookup(lookup_type, value) + + def get_manipulator_field_objs(self): + return [curry(oldforms.DecimalField, max_digits=self.max_digits, decimal_places=self.decimal_places)] + + def formfield(self, **kwargs): + defaults = { + 'max_digits': self.max_digits, + 'decimal_places': self.decimal_places, + 'form_class': forms.DecimalField, + } + defaults.update(kwargs) + return super(DecimalField, self).formfield(**defaults) class EmailField(CharField): def __init__(self, *args, **kwargs): @@ -561,8 +650,10 @@ class EmailField(CharField): def validate(self, field_data, all_data): validators.isValidEmail(field_data, all_data) - def formfield(self): - return forms.EmailField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.EmailField} + defaults.update(kwargs) + return super(EmailField, self).formfield(**defaults) class FileField(Field): def __init__(self, verbose_name=None, name=None, upload_to='', **kwargs): @@ -610,7 +701,7 @@ class FileField(Field): setattr(cls, 'get_%s_filename' % self.name, curry(cls._get_FIELD_filename, field=self)) setattr(cls, 'get_%s_url' % self.name, curry(cls._get_FIELD_url, field=self)) setattr(cls, 'get_%s_size' % self.name, curry(cls._get_FIELD_size, field=self)) - setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents: instance._save_FIELD_file(self, filename, raw_contents)) + setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents, save=True: instance._save_FIELD_file(self, filename, raw_contents, save)) dispatcher.connect(self.delete_file, signal=signals.post_delete, sender=cls) def delete_file(self, instance): @@ -628,14 +719,14 @@ class FileField(Field): def get_manipulator_field_names(self, name_prefix): return [name_prefix + self.name + '_file', name_prefix + self.name] - def save_file(self, new_data, new_object, original_object, change, rel): + def save_file(self, new_data, new_object, original_object, change, rel, save=True): upload_field_name = self.get_manipulator_field_names('')[0] if new_data.get(upload_field_name, False): func = getattr(new_object, 'save_%s_file' % self.name) if rel: - func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"]) + func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"], save) else: - func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"]) + func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"], save) def get_directory_name(self): return os.path.normpath(datetime.datetime.now().strftime(self.upload_to)) @@ -655,12 +746,14 @@ class FilePathField(Field): class FloatField(Field): empty_strings_allowed = False - def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs): - self.max_digits, self.decimal_places = max_digits, decimal_places - Field.__init__(self, verbose_name, name, **kwargs) def get_manipulator_field_objs(self): - return [curry(oldforms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)] + return [oldforms.FloatField] + + def formfield(self, **kwargs): + defaults = {'form_class': forms.FloatField} + defaults.update(kwargs) + return super(FloatField, self).formfield(**defaults) class ImageField(FileField): def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs): @@ -679,12 +772,12 @@ class ImageField(FileField): if not self.height_field: setattr(cls, 'get_%s_height' % self.name, curry(cls._get_FIELD_height, field=self)) - def save_file(self, new_data, new_object, original_object, change, rel): - FileField.save_file(self, new_data, new_object, original_object, change, rel) + def save_file(self, new_data, new_object, original_object, change, rel, save=True): + FileField.save_file(self, new_data, new_object, original_object, change, rel, save) # If the image has height and/or width field(s) and they haven't # changed, set the width and/or height field(s) back to their original # values. - if change and (self.width_field or self.height_field): + if change and (self.width_field or self.height_field) and save: if self.width_field: setattr(new_object, self.width_field, getattr(original_object, self.width_field)) if self.height_field: @@ -696,8 +789,10 @@ class IntegerField(Field): def get_manipulator_field_objs(self): return [oldforms.IntegerField] - def formfield(self): - return forms.IntegerField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.IntegerField} + defaults.update(kwargs) + return super(IntegerField, self).formfield(**defaults) class IPAddressField(Field): def __init__(self, *args, **kwargs): @@ -715,6 +810,13 @@ class NullBooleanField(Field): kwargs['null'] = True Field.__init__(self, *args, **kwargs) + def to_python(self, value): + if value in (None, True, False): return value + if value in ('None'): return None + if value in ('t', 'True', '1'): return True + if value in ('f', 'False', '0'): return False + raise validators.ValidationError, gettext("This value must be either None, True or False.") + def get_manipulator_field_objs(self): return [oldforms.NullBooleanField] @@ -725,6 +827,12 @@ class PhoneNumberField(IntegerField): def validate(self, field_data, all_data): validators.isValidPhone(field_data, all_data) + def formfield(self, **kwargs): + from django.contrib.localflavor.us.forms import USPhoneNumberField + defaults = {'form_class': USPhoneNumberField} + defaults.update(kwargs) + return super(PhoneNumberField, self).formfield(**defaults) + class PositiveIntegerField(IntegerField): def get_manipulator_field_objs(self): return [oldforms.PositiveIntegerField] @@ -738,7 +846,7 @@ class SlugField(Field): kwargs['maxlength'] = kwargs.get('maxlength', 50) kwargs.setdefault('validator_list', []).append(validators.isSlug) # Set db_index=True unless it's been set manually. - if not kwargs.has_key('db_index'): + if 'db_index' not in kwargs: kwargs['db_index'] = True Field.__init__(self, *args, **kwargs) @@ -753,6 +861,11 @@ class TextField(Field): def get_manipulator_field_objs(self): return [oldforms.LargeTextField] + def formfield(self, **kwargs): + defaults = {'widget': forms.Textarea} + defaults.update(kwargs) + return super(TextField, self).formfield(**defaults) + class TimeField(Field): empty_strings_allowed = False def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): @@ -781,7 +894,7 @@ class TimeField(Field): if value is not None: # MySQL will throw a warning if microseconds are given, because it # doesn't support microseconds. - if settings.DATABASE_ENGINE == 'mysql': + if settings.DATABASE_ENGINE == 'mysql' and hasattr(value, 'microsecond'): value = value.replace(microsecond=0) value = str(value) return Field.get_db_prep_save(self, value) @@ -793,26 +906,40 @@ class TimeField(Field): val = self._get_val_from_obj(obj) return {self.attname: (val is not None and val.strftime("%H:%M:%S") or '')} - def formfield(self): - return forms.TimeField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.TimeField} + defaults.update(kwargs) + return super(TimeField, self).formfield(**defaults) -class URLField(Field): +class URLField(CharField): def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs): + kwargs['maxlength'] = kwargs.get('maxlength', 200) if verify_exists: kwargs.setdefault('validator_list', []).append(validators.isExistingURL) self.verify_exists = verify_exists - Field.__init__(self, verbose_name, name, **kwargs) + CharField.__init__(self, verbose_name, name, **kwargs) def get_manipulator_field_objs(self): return [oldforms.URLField] - def formfield(self): - return forms.URLField(required=not self.blank, verify_exists=self.verify_exists, label=capfirst(self.verbose_name)) + def get_internal_type(self): + return "CharField" + + def formfield(self, **kwargs): + defaults = {'form_class': forms.URLField, 'verify_exists': self.verify_exists} + defaults.update(kwargs) + return super(URLField, self).formfield(**defaults) class USStateField(Field): def get_manipulator_field_objs(self): return [oldforms.USStateField] + def formfield(self, **kwargs): + from django.contrib.localflavor.us.forms import USStateSelect + defaults = {'widget': USStateSelect} + defaults.update(kwargs) + return super(USStateField, self).formfield(**defaults) + class XMLField(TextField): def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs): self.schema_path = schema_path diff --git a/django/db/models/fields/generic.py b/django/db/models/fields/generic.py deleted file mode 100644 index 1ad8346e42..0000000000 --- a/django/db/models/fields/generic.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Classes allowing "generic" relations through ContentType and object-id fields. -""" - -from django import oldforms -from django.core.exceptions import ObjectDoesNotExist -from django.db import backend -from django.db.models import signals -from django.db.models.fields.related import RelatedField, Field, ManyToManyRel -from django.db.models.loading import get_model -from django.dispatch import dispatcher -from django.utils.functional import curry - -class GenericForeignKey(object): - """ - Provides a generic relation to any object through content-type/object-id - fields. - """ - - def __init__(self, ct_field="content_type", fk_field="object_id"): - self.ct_field = ct_field - self.fk_field = fk_field - - def contribute_to_class(self, cls, name): - # Make sure the fields exist (these raise FieldDoesNotExist, - # which is a fine error to raise here) - self.name = name - self.model = cls - self.cache_attr = "_%s_cache" % name - - # For some reason I don't totally understand, using weakrefs here doesn't work. - dispatcher.connect(self.instance_pre_init, signal=signals.pre_init, sender=cls, weak=False) - - # Connect myself as the descriptor for this field - setattr(cls, name, self) - - def instance_pre_init(self, signal, sender, args, kwargs): - # Handle initalizing an object with the generic FK instaed of - # content-type/object-id fields. - if kwargs.has_key(self.name): - value = kwargs.pop(self.name) - kwargs[self.ct_field] = self.get_content_type(value) - kwargs[self.fk_field] = value._get_pk_val() - - def get_content_type(self, obj): - # Convenience function using get_model avoids a circular import when using this model - ContentType = get_model("contenttypes", "contenttype") - return ContentType.objects.get_for_model(obj) - - def __get__(self, instance, instance_type=None): - if instance is None: - raise AttributeError, "%s must be accessed via instance" % self.name - - try: - return getattr(instance, self.cache_attr) - except AttributeError: - rel_obj = None - ct = getattr(instance, self.ct_field) - if ct: - try: - rel_obj = ct.get_object_for_this_type(pk=getattr(instance, self.fk_field)) - except ObjectDoesNotExist: - pass - setattr(instance, self.cache_attr, rel_obj) - return rel_obj - - def __set__(self, instance, value): - if instance is None: - raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name - - ct = None - fk = None - if value is not None: - ct = self.get_content_type(value) - fk = value._get_pk_val() - - setattr(instance, self.ct_field, ct) - setattr(instance, self.fk_field, fk) - setattr(instance, self.cache_attr, value) - -class GenericRelation(RelatedField, Field): - """Provides an accessor to generic related objects (i.e. comments)""" - - def __init__(self, to, **kwargs): - kwargs['verbose_name'] = kwargs.get('verbose_name', None) - kwargs['rel'] = GenericRel(to, - related_name=kwargs.pop('related_name', None), - limit_choices_to=kwargs.pop('limit_choices_to', None), - symmetrical=kwargs.pop('symmetrical', True)) - - # Override content-type/object-id field names on the related class - self.object_id_field_name = kwargs.pop("object_id_field", "object_id") - self.content_type_field_name = kwargs.pop("content_type_field", "content_type") - - kwargs['blank'] = True - kwargs['editable'] = False - Field.__init__(self, **kwargs) - - def get_manipulator_field_objs(self): - choices = self.get_choices_default() - return [curry(oldforms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)] - - def get_choices_default(self): - return Field.get_choices(self, include_blank=False) - - def flatten_data(self, follow, obj = None): - new_data = {} - if obj: - instance_ids = [instance._get_pk_val() for instance in getattr(obj, self.name).all()] - new_data[self.name] = instance_ids - return new_data - - def m2m_db_table(self): - return self.rel.to._meta.db_table - - def m2m_column_name(self): - return self.object_id_field_name - - def m2m_reverse_name(self): - return self.object_id_field_name - - def contribute_to_class(self, cls, name): - super(GenericRelation, self).contribute_to_class(cls, name) - - # Save a reference to which model this class is on for future use - self.model = cls - - # Add the descriptor for the m2m relation - setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self)) - - def contribute_to_related_class(self, cls, related): - pass - - def set_attributes_from_rel(self): - pass - - def get_internal_type(self): - return "ManyToManyField" - -class ReverseGenericRelatedObjectsDescriptor(object): - """ - This class provides the functionality that makes the related-object - managers available as attributes on a model class, for fields that have - multiple "remote" values and have a GenericRelation defined in their model - (rather than having another model pointed *at* them). In the example - "article.publications", the publications attribute is a - ReverseGenericRelatedObjectsDescriptor instance. - """ - def __init__(self, field): - self.field = field - - def __get__(self, instance, instance_type=None): - if instance is None: - raise AttributeError, "Manager must be accessed via instance" - - # This import is done here to avoid circular import importing this module - from django.contrib.contenttypes.models import ContentType - - # Dynamically create a class that subclasses the related model's - # default manager. - rel_model = self.field.rel.to - superclass = rel_model._default_manager.__class__ - RelatedManager = create_generic_related_manager(superclass) - - manager = RelatedManager( - model = rel_model, - instance = instance, - symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model), - join_table = backend.quote_name(self.field.m2m_db_table()), - source_col_name = backend.quote_name(self.field.m2m_column_name()), - target_col_name = backend.quote_name(self.field.m2m_reverse_name()), - content_type = ContentType.objects.get_for_model(self.field.model), - content_type_field_name = self.field.content_type_field_name, - object_id_field_name = self.field.object_id_field_name - ) - - return manager - - def __set__(self, instance, value): - if instance is None: - raise AttributeError, "Manager must be accessed via instance" - - manager = self.__get__(instance) - manager.clear() - for obj in value: - manager.add(obj) - -def create_generic_related_manager(superclass): - """ - Factory function for a manager that subclasses 'superclass' (which is a - Manager) and adds behavior for generic related objects. - """ - - class GenericRelatedObjectManager(superclass): - def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None, - join_table=None, source_col_name=None, target_col_name=None, content_type=None, - content_type_field_name=None, object_id_field_name=None): - - super(GenericRelatedObjectManager, self).__init__() - self.core_filters = core_filters or {} - self.model = model - self.content_type = content_type - self.symmetrical = symmetrical - self.instance = instance - self.join_table = join_table - self.join_table = model._meta.db_table - self.source_col_name = source_col_name - self.target_col_name = target_col_name - self.content_type_field_name = content_type_field_name - self.object_id_field_name = object_id_field_name - self.pk_val = self.instance._get_pk_val() - - def get_query_set(self): - query = { - '%s__pk' % self.content_type_field_name : self.content_type.id, - '%s__exact' % self.object_id_field_name : self.pk_val, - } - return superclass.get_query_set(self).filter(**query) - - def add(self, *objs): - for obj in objs: - setattr(obj, self.content_type_field_name, self.content_type) - setattr(obj, self.object_id_field_name, self.pk_val) - obj.save() - add.alters_data = True - - def remove(self, *objs): - for obj in objs: - obj.delete() - remove.alters_data = True - - def clear(self): - for obj in self.all(): - obj.delete() - clear.alters_data = True - - def create(self, **kwargs): - kwargs[self.content_type_field_name] = self.content_type - kwargs[self.object_id_field_name] = self.pk_val - obj = self.model(**kwargs) - obj.save() - return obj - create.alters_data = True - - return GenericRelatedObjectManager - -class GenericRel(ManyToManyRel): - def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True): - self.to = to - self.num_in_admin = 0 - self.related_name = related_name - self.filter_interface = None - self.limit_choices_to = limit_choices_to or {} - self.edit_inline = False - self.raw_id_admin = False - self.symmetrical = symmetrical - self.multiple = True - assert not (self.raw_id_admin and self.filter_interface), \ - "Generic relations may not use both raw_id_admin and filter_interface" diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 797ef05be1..0739d0461a 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -2,10 +2,12 @@ from django.db import backend, transaction from django.db.models import signals, get_model from django.db.models.fields import AutoField, Field, IntegerField, get_ul_class from django.db.models.related import RelatedObject +from django.utils.text import capfirst from django.utils.translation import gettext_lazy, string_concat, ngettext from django.utils.functional import curry from django.core import validators from django import oldforms +from django import newforms as forms from django.dispatch import dispatcher # For Python 2.3 @@ -314,18 +316,20 @@ def create_many_related_manager(superclass): # join_table: name of the m2m link table # source_col_name: the PK colname in join_table for the source object # target_col_name: the PK colname in join_table for the target object - # *objs - objects to add + # *objs - objects to add. Either object instances, or primary keys of object instances. from django.db import connection # If there aren't any objects, there is nothing to do. if objs: # Check that all the objects are of the right type + new_ids = set() for obj in objs: - if not isinstance(obj, self.model): - raise ValueError, "objects to add() must be %s instances" % self.model._meta.object_name + if isinstance(obj, self.model): + new_ids.add(obj._get_pk_val()) + else: + new_ids.add(obj) # Add the newly created or already existing objects to the join table. # First find out which items are already added, to avoid adding them twice - new_ids = set([obj._get_pk_val() for obj in objs]) cursor = connection.cursor() cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \ (target_col_name, self.join_table, source_col_name, @@ -352,14 +356,16 @@ def create_many_related_manager(superclass): # If there aren't any objects, there is nothing to do. if objs: # Check that all the objects are of the right type + old_ids = set() for obj in objs: - if not isinstance(obj, self.model): - raise ValueError, "objects to remove() must be %s instances" % self.model._meta.object_name + if isinstance(obj, self.model): + old_ids.add(obj._get_pk_val()) + else: + old_ids.add(obj) # Remove the specified objects from the join table - old_ids = set([obj._get_pk_val() for obj in objs]) cursor = connection.cursor() cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \ - (self.join_table, source_col_name, + (self.join_table, source_col_name, target_col_name, ",".join(['%s'] * len(old_ids))), [self._pk_val] + list(old_ids)) transaction.commit_unless_managed() @@ -468,7 +474,7 @@ class ForeignKey(RelatedField, Field): to_field = to_field or to._meta.pk.name kwargs['verbose_name'] = kwargs.get('verbose_name', '') - if kwargs.has_key('edit_inline_type'): + if 'edit_inline_type' in kwargs: import warnings warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.") kwargs['edit_inline'] = kwargs.pop('edit_inline_type') @@ -546,6 +552,11 @@ class ForeignKey(RelatedField, Field): def contribute_to_related_class(self, cls, related): setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()} + defaults.update(kwargs) + return super(ForeignKey, self).formfield(**defaults) + class OneToOneField(RelatedField, IntegerField): def __init__(self, to, to_field=None, **kwargs): try: @@ -556,7 +567,7 @@ class OneToOneField(RelatedField, IntegerField): to_field = to_field or to._meta.pk.name kwargs['verbose_name'] = kwargs.get('verbose_name', '') - if kwargs.has_key('edit_inline_type'): + if 'edit_inline_type' in kwargs: import warnings warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.") kwargs['edit_inline'] = kwargs.pop('edit_inline_type') @@ -607,6 +618,11 @@ class OneToOneField(RelatedField, IntegerField): if not cls._meta.one_to_one_field: cls._meta.one_to_one_field = self + def formfield(self, **kwargs): + defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()} + defaults.update(kwargs) + return super(OneToOneField, self).formfield(**defaults) + class ManyToManyField(RelatedField, Field): def __init__(self, to, **kwargs): kwargs['verbose_name'] = kwargs.get('verbose_name', None) @@ -617,6 +633,7 @@ class ManyToManyField(RelatedField, Field): limit_choices_to=kwargs.pop('limit_choices_to', None), raw_id_admin=kwargs.pop('raw_id_admin', False), symmetrical=kwargs.pop('symmetrical', True)) + self.db_table = kwargs.pop('db_table', None) if kwargs["rel"].raw_id_admin: kwargs.setdefault("validator_list", []).append(self.isValidIDList) Field.__init__(self, **kwargs) @@ -639,7 +656,10 @@ class ManyToManyField(RelatedField, Field): def _get_m2m_db_table(self, opts): "Function that can be curried to provide the m2m table name for this relation" - return '%s_%s' % (opts.db_table, self.name) + if self.db_table: + return self.db_table + else: + return '%s_%s' % (opts.db_table, self.name) def _get_m2m_column_name(self, related): "Function that can be curried to provide the source column name for the m2m table" @@ -713,6 +733,19 @@ class ManyToManyField(RelatedField, Field): def set_attributes_from_rel(self): pass + def value_from_object(self, obj): + "Returns the value of this field in the given model instance." + return getattr(obj, self.attname).all() + + def formfield(self, **kwargs): + defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.all()} + defaults.update(kwargs) + # If initial is passed in, it's a list of related objects, but the + # MultipleChoiceField takes a list of IDs. + if defaults.get('initial') is not None: + defaults['initial'] = [i._get_pk_val() for i in defaults['initial']] + return super(ManyToManyField, self).formfield(**defaults) + class ManyToOneRel(object): def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None, max_num_in_admin=None, num_extra_on_change=1, edit_inline=False, diff --git a/django/db/models/loading.py b/django/db/models/loading.py index f4aff2438b..224f5e8451 100644 --- a/django/db/models/loading.py +++ b/django/db/models/loading.py @@ -103,7 +103,7 @@ def register_models(app_label, *models): # in the _app_models dictionary model_name = model._meta.object_name.lower() model_dict = _app_models.setdefault(app_label, {}) - if model_dict.has_key(model_name): + if model_name in model_dict: # The same model may be imported via different paths (e.g. # appname.models and project.appname.models). We use the source # filename as a means to detect identity. diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 6005874516..b60eed262a 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -1,4 +1,4 @@ -from django.db.models.query import QuerySet +from django.db.models.query import QuerySet, EmptyQuerySet from django.dispatch import dispatcher from django.db.models import signals from django.db.models.fields import FieldDoesNotExist @@ -41,12 +41,18 @@ class Manager(object): ####################### # PROXIES TO QUERYSET # ####################### + + def get_empty_query_set(self): + return EmptyQuerySet(self.model) def get_query_set(self): """Returns a new QuerySet object. Subclasses can override this method to easily customise the behaviour of the Manager. """ return QuerySet(self.model) + + def none(self): + return self.get_empty_query_set() def all(self): return self.get_query_set() diff --git a/django/db/models/manipulators.py b/django/db/models/manipulators.py index e9dfa7037c..d5fc5f725e 100644 --- a/django/db/models/manipulators.py +++ b/django/db/models/manipulators.py @@ -96,14 +96,16 @@ class AutomaticManipulator(oldforms.Manipulator): if self.change: params[self.opts.pk.attname] = self.obj_key - # First, save the basic object itself. + # First, create the basic object itself. new_object = self.model(**params) - new_object.save() - # Now that the object's been saved, save any uploaded files. + # Now that the object's been created, save any uploaded files. for f in self.opts.fields: if isinstance(f, FileField): - f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False) + f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False, save=False) + + # Now save the object + new_object.save() # Calculate which primary fields have changed. if self.change: diff --git a/django/db/models/options.py b/django/db/models/options.py index ee253ff451..556168e7d0 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -85,6 +85,7 @@ class Options(object): self.fields.insert(bisect(self.fields, field), field) if not self.pk and field.primary_key: self.pk = field + field.serialize = False def __repr__(self): return '<Options for %s>' % self.object_name @@ -140,7 +141,7 @@ class Options(object): def get_follow(self, override=None): follow = {} for f in self.fields + self.many_to_many + self.get_all_related_objects(): - if override and override.has_key(f.name): + if override and f.name in override: child_override = override[f.name] else: child_override = None @@ -182,7 +183,7 @@ class Options(object): # TODO: follow if not hasattr(self, '_field_types'): self._field_types = {} - if not self._field_types.has_key(field_type): + if field_type not in self._field_types: try: # First check self.fields. for f in self.fields: diff --git a/django/db/models/query.py b/django/db/models/query.py index 53ed63ae5b..a6e702be18 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1,8 +1,9 @@ from django.db import backend, connection, transaction from django.db.models.fields import DateField, FieldDoesNotExist -from django.db.models import signals +from django.db.models import signals, loading from django.dispatch import dispatcher from django.utils.datastructures import SortedDict +from django.contrib.contenttypes import generic import operator import re @@ -25,6 +26,9 @@ QUERY_TERMS = ( # Larger values are slightly faster at the expense of more storage space. GET_ITERATOR_CHUNK_SIZE = 100 +class EmptyResultSet(Exception): + pass + #################### # HELPER FUNCTIONS # #################### @@ -80,6 +84,7 @@ class QuerySet(object): self._filters = Q() self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering. self._select_related = False # Whether to fill cache for related objects. + self._max_related_depth = 0 # Maximum "depth" for select_related self._distinct = False # Whether the query should use SELECT DISTINCT. self._select = {} # Dictionary of attname -> SQL. self._where = [] # List of extra WHERE clauses to use. @@ -104,6 +109,8 @@ class QuerySet(object): def __getitem__(self, k): "Retrieve an item or slice from the set of results." + if not isinstance(k, (slice, int)): + raise TypeError assert (not isinstance(k, slice) and (k >= 0)) \ or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \ "Negative indexing is not supported." @@ -163,12 +170,16 @@ class QuerySet(object): def iterator(self): "Performs the SELECT database lookup of this QuerySet." + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + # self._select is a dictionary, and dictionaries' key order is # undefined, so we convert it to a list of tuples. extra_select = self._select.items() cursor = connection.cursor() - select, sql, params = self._get_sql_clause() cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) fill_cache = self._select_related index_end = len(self.model._meta.fields) @@ -178,7 +189,8 @@ class QuerySet(object): raise StopIteration for row in rows: if fill_cache: - obj, index_end = get_cached_row(self.model, row, 0) + obj, index_end = get_cached_row(klass=self.model, row=row, + index_start=0, max_depth=self._max_related_depth) else: obj = self.model(*row[:index_end]) for i, k in enumerate(extra_select): @@ -186,13 +198,31 @@ class QuerySet(object): yield obj def count(self): - "Performs a SELECT COUNT() and returns the number of records as an integer." + """ + Performs a SELECT COUNT() and returns the number of records as an + integer. + + If the queryset is already cached (i.e. self._result_cache is set) this + simply returns the length of the cached results set to avoid multiple + SELECT COUNT(*) calls. + """ + if self._result_cache is not None: + return len(self._result_cache) + counter = self._clone() counter._order_by = () + counter._select_related = False + + offset = counter._offset + limit = counter._limit counter._offset = None counter._limit = None - counter._select_related = False - select, sql, params = counter._get_sql_clause() + + try: + select, sql, params = counter._get_sql_clause() + except EmptyResultSet: + return 0 + cursor = connection.cursor() if self._distinct: id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table), @@ -200,7 +230,16 @@ class QuerySet(object): cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params) else: cursor.execute("SELECT COUNT(*)" + sql, params) - return cursor.fetchone()[0] + count = cursor.fetchone()[0] + + # Apply any offset and limit constraints manually, since using LIMIT or + # OFFSET in SQL doesn't change the output of COUNT. + if offset: + count = max(0, count - offset) + if limit: + count = min(limit, count) + + return count def get(self, *args, **kwargs): "Performs the SELECT and returns a single object matching the given keyword arguments." @@ -359,9 +398,9 @@ class QuerySet(object): else: return self._filter_or_exclude(None, **filter_obj) - def select_related(self, true_or_false=True): + def select_related(self, true_or_false=True, depth=0): "Returns a new QuerySet instance with '_select_related' modified." - return self._clone(_select_related=true_or_false) + return self._clone(_select_related=true_or_false, _max_related_depth=depth) def order_by(self, *field_names): "Returns a new QuerySet instance with the ordering changed." @@ -395,6 +434,7 @@ class QuerySet(object): c._filters = self._filters c._order_by = self._order_by c._select_related = self._select_related + c._max_related_depth = self._max_related_depth c._distinct = self._distinct c._select = self._select.copy() c._where = self._where[:] @@ -448,7 +488,10 @@ class QuerySet(object): # Add additional tables and WHERE clauses based on select_related. if self._select_related: - fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) + fill_table_cache(opts, select, tables, where, + old_prefix=opts.db_table, + cache_tables_seen=[opts.db_table], + max_depth=self._max_related_depth) # Add any additional SELECTs. if self._select: @@ -509,22 +552,42 @@ class QuerySet(object): return select, " ".join(sql), params class ValuesQuerySet(QuerySet): - def iterator(self): - # select_related and select aren't supported in values(). + def __init__(self, *args, **kwargs): + super(ValuesQuerySet, self).__init__(*args, **kwargs) + # select_related isn't supported in values(). self._select_related = False - self._select = {} + + def iterator(self): + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration # self._fields is a list of field names to fetch. if self._fields: - columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + #columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + if not self._select: + columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + else: + columns = [] + for f in self._fields: + if f in [field.name for field in self.model._meta.fields]: + columns.append( self.model._meta.get_field(f, many_to_many=False).column ) + elif not self._select.has_key( f ): + raise FieldDoesNotExist, '%s has no field named %r' % ( self.model._meta.object_name, f ) + field_names = self._fields else: # Default to all fields. columns = [f.column for f in self.model._meta.fields] field_names = [f.attname for f in self.model._meta.fields] - cursor = connection.cursor() - select, sql, params = self._get_sql_clause() select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] + + # Add any additional SELECTs. + if self._select: + select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()]) + + cursor = connection.cursor() cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) while 1: rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) @@ -545,7 +608,12 @@ class DateQuerySet(QuerySet): if self._field.null: self._where.append('%s.%s IS NOT NULL' % \ (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))) - select, sql, params = self._get_sql_clause() + + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ (backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))), sql, self._order) @@ -563,6 +631,25 @@ class DateQuerySet(QuerySet): c._order = self._order return c +class EmptyQuerySet(QuerySet): + def __init__(self, model=None): + super(EmptyQuerySet, self).__init__(model) + self._result_cache = [] + + def count(self): + return 0 + + def delete(self): + pass + + def _clone(self, klass=None, **kwargs): + c = super(EmptyQuerySet, self)._clone(klass, **kwargs) + c._result_cache = [] + return c + + def _get_sql_clause(self): + raise EmptyResultSet + class QOperator(object): "Base class for QAnd and QOr" def __init__(self, *args): @@ -571,10 +658,14 @@ class QOperator(object): def get_sql(self, opts): joins, where, params = SortedDict(), [], [] for val in self.args: - joins2, where2, params2 = val.get_sql(opts) - joins.update(joins2) - where.extend(where2) - params.extend(params2) + try: + joins2, where2, params2 = val.get_sql(opts) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + except EmptyResultSet: + if not isinstance(self, QOr): + raise EmptyResultSet if where: return joins, ['(%s)' % self.operator.join(where)], params return joins, [], params @@ -628,8 +719,11 @@ class QNot(Q): self.q = q def get_sql(self, opts): - joins, where, params = self.q.get_sql(opts) - where2 = ['(NOT (%s))' % " AND ".join(where)] + try: + joins, where, params = self.q.get_sql(opts) + where2 = ['(NOT (%s))' % " AND ".join(where)] + except EmptyResultSet: + return SortedDict(), [], [] return joins, where2, params def get_where_clause(lookup_type, table_prefix, field_name, value): @@ -641,10 +735,14 @@ def get_where_clause(lookup_type, table_prefix, field_name, value): except KeyError: pass if lookup_type == 'in': - return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value])) - elif lookup_type == 'range': + in_string = ','.join(['%s' for id in value]) + if in_string: + return '%s%s IN (%s)' % (table_prefix, field_name, in_string) + else: + raise EmptyResultSet + elif lookup_type in ('range', 'year'): return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name) - elif lookup_type in ('year', 'month', 'day'): + elif lookup_type in ('month', 'day'): return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name) elif lookup_type == 'isnull': return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or '')) @@ -652,21 +750,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value): return backend.get_fulltext_search_sql(table_prefix + field_name) raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) -def get_cached_row(klass, row, index_start): - "Helper function that recursively returns an object with cache filled" +def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): + """Helper function that recursively returns an object with cache filled""" + + # If we've got a max_depth set and we've exceeded that depth, bail now. + if max_depth and cur_depth > max_depth: + return None + index_end = index_start + len(klass._meta.fields) obj = klass(*row[index_start:index_end]) for f in klass._meta.fields: if f.rel and not f.null: - rel_obj, index_end = get_cached_row(f.rel.to, row, index_end) - setattr(obj, f.get_cache_name(), rel_obj) + cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1) + if cached_row: + rel_obj, index_end = cached_row + setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen): +def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0): """ Helper function that recursively populates the select, tables and where (in place) for select_related queries. """ + + # If we've got a max_depth set and we've exceeded that depth, bail now. + if max_depth and cur_depth > max_depth: + return None + qn = backend.quote_name for f in opts.fields: if f.rel and not f.null: @@ -681,12 +791,12 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen) where.append('%s.%s = %s.%s' % \ (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) - fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen) + fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1) def parse_lookup(kwarg_items, opts): # Helper function that handles converting API kwargs # (e.g. "name__exact": "tom") to SQL. - # Returns a tuple of (tables, joins, where, params). + # Returns a tuple of (joins, where, params). # 'joins' is a sorted dictionary describing the tables that must be joined # to complete the query. The dictionary is sorted because creation order @@ -725,12 +835,14 @@ def parse_lookup(kwarg_items, opts): if len(path) < 1: raise TypeError, "Cannot parse keyword query %r" % kwarg - + if value is None: # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject # all uses of None as a query value. if lookup_type != 'exact': raise ValueError, "Cannot use None as a query value" + elif callable(value): + value = value() joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None) joins.update(joins2) @@ -755,6 +867,13 @@ def find_field(name, field_list, related_query): return None return matches[0] +def field_choices(field_list, related_query): + if related_query: + choices = [f.field.related_query_name() for f in field_list] + else: + choices = [f.name for f in field_list] + return choices + def lookup_inner(path, lookup_type, value, opts, table, column): qn = backend.quote_name joins, where, params = SortedDict(), [], [] @@ -827,13 +946,23 @@ def lookup_inner(path, lookup_type, value, opts, table, column): new_opts = field.rel.to._meta new_column = new_opts.pk.column join_column = field.column - - raise FieldFound + raise FieldFound + elif path: + # For regular fields, if there are still items on the path, + # an error has been made. We munge "name" so that the error + # properly identifies the cause of the problem. + name += LOOKUP_SEPARATOR + path[0] + else: + raise FieldFound except FieldFound: # Match found, loop has been shortcut. pass else: # No match found. - raise TypeError, "Cannot resolve keyword '%s' into field" % name + choices = field_choices(current_opts.many_to_many, False) + \ + field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \ + field_choices(current_opts.get_all_related_objects(), True) + \ + field_choices(current_opts.fields, False) + raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices)) # Check whether an intermediate join is required between current_table # and new_table. @@ -926,18 +1055,26 @@ def delete_objects(seen_objs): pk_list = [pk for pk,instance in seen_objs[cls]] for related in cls._meta.get_all_related_many_to_many_objects(): - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (qn(related.field.m2m_db_table()), - qn(related.field.m2m_reverse_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + if not isinstance(related.field, generic.GenericRelation): + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + (qn(related.field.m2m_db_table()), + qn(related.field.m2m_reverse_name()), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) for f in cls._meta.many_to_many: + if isinstance(f, generic.GenericRelation): + from django.contrib.contenttypes.models import ContentType + query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column + args_extra = [ContentType.objects.get_for_model(cls).id] + else: + query_extra = '' + args_extra = [] for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \ (qn(f.m2m_db_table()), qn(f.m2m_column_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra, + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra) for field in cls._meta.fields: if field.rel and field.null and field.rel.to in seen_objs: for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): diff --git a/django/db/models/related.py b/django/db/models/related.py index ac1ec50ca2..2c1dc5c516 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -1,7 +1,7 @@ class BoundRelatedObject(object): def __init__(self, related_object, field_mapping, original): self.relation = related_object - self.field_mappings = field_mapping[related_object.opts.module_name] + self.field_mappings = field_mapping[related_object.name] def template_name(self): raise NotImplementedError @@ -16,7 +16,7 @@ class RelatedObject(object): self.opts = model._meta self.field = field self.edit_inline = field.rel.edit_inline - self.name = self.opts.module_name + self.name = '%s:%s' % (self.opts.app_label, self.opts.module_name) self.var_name = self.opts.object_name.lower() def flatten_data(self, follow, obj=None): @@ -68,7 +68,10 @@ class RelatedObject(object): # object return [attr] else: - return [None] * self.field.rel.num_in_admin + if self.field.rel.min_num_in_admin: + return [None] * max(self.field.rel.num_in_admin, self.field.rel.min_num_in_admin) + else: + return [None] * self.field.rel.num_in_admin def get_db_prep_lookup(self, lookup_type, value): # Defer to the actual field definition for db prep @@ -101,12 +104,12 @@ class RelatedObject(object): attr = getattr(manipulator.original_object, self.get_accessor_name()) count = attr.count() count += self.field.rel.num_extra_on_change - if self.field.rel.min_num_in_admin: - count = max(count, self.field.rel.min_num_in_admin) - if self.field.rel.max_num_in_admin: - count = min(count, self.field.rel.max_num_in_admin) else: count = self.field.rel.num_in_admin + if self.field.rel.min_num_in_admin: + count = max(count, self.field.rel.min_num_in_admin) + if self.field.rel.max_num_in_admin: + count = min(count, self.field.rel.max_num_in_admin) else: count = 1 diff --git a/django/db/transaction.py b/django/db/transaction.py index 4a0658e1c3..bb90713525 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -46,12 +46,12 @@ def enter_transaction_management(): when no current block is running). """ thread_ident = thread.get_ident() - if state.has_key(thread_ident) and state[thread_ident]: + if thread_ident in state and state[thread_ident]: state[thread_ident].append(state[thread_ident][-1]) else: state[thread_ident] = [] state[thread_ident].append(settings.TRANSACTIONS_MANAGED) - if not dirty.has_key(thread_ident): + if thread_ident not in dirty: dirty[thread_ident] = False def leave_transaction_management(): @@ -61,7 +61,7 @@ def leave_transaction_management(): those from outside. (Commits are on connection level.) """ thread_ident = thread.get_ident() - if state.has_key(thread_ident) and state[thread_ident]: + if thread_ident in state and state[thread_ident]: del state[thread_ident][-1] else: raise TransactionManagementError("This code isn't under transaction management") @@ -84,7 +84,7 @@ def set_dirty(): changes waiting for commit. """ thread_ident = thread.get_ident() - if dirty.has_key(thread_ident): + if thread_ident in dirty: dirty[thread_ident] = True else: raise TransactionManagementError("This code isn't under transaction management") @@ -96,7 +96,7 @@ def set_clean(): should happen. """ thread_ident = thread.get_ident() - if dirty.has_key(thread_ident): + if thread_ident in dirty: dirty[thread_ident] = False else: raise TransactionManagementError("This code isn't under transaction management") @@ -106,7 +106,7 @@ def is_managed(): Checks whether the transaction manager is in manual or in auto state. """ thread_ident = thread.get_ident() - if state.has_key(thread_ident): + if thread_ident in state: if state[thread_ident]: return state[thread_ident][-1] return settings.TRANSACTIONS_MANAGED |
