summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
Diffstat (limited to 'django/db')
-rw-r--r--django/db/backends/ado_mssql/creation.py1
-rw-r--r--django/db/backends/mysql/base.py8
-rw-r--r--django/db/backends/mysql/creation.py1
-rw-r--r--django/db/backends/oracle/creation.py1
-rw-r--r--django/db/backends/postgresql/base.py35
-rw-r--r--django/db/backends/postgresql/creation.py1
-rw-r--r--django/db/backends/sqlite3/creation.py1
-rw-r--r--django/db/models/fields/__init__.py128
-rw-r--r--django/db/models/fields/generic.py4
-rw-r--r--django/db/models/fields/related.py132
-rw-r--r--django/db/models/manager.py8
-rw-r--r--django/db/models/manipulators.py8
-rw-r--r--django/db/models/query.py102
13 files changed, 316 insertions, 114 deletions
diff --git a/django/db/backends/ado_mssql/creation.py b/django/db/backends/ado_mssql/creation.py
index 4d85d27ea5..5158ba02f9 100644
--- a/django/db/backends/ado_mssql/creation.py
+++ b/django/db/backends/ado_mssql/creation.py
@@ -21,6 +21,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index e7e060e6c2..28c5b1c683 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -98,9 +98,11 @@ class DatabaseWrapper(local):
kwargs['port'] = int(settings.DATABASE_PORT)
kwargs.update(self.options)
self.connection = Database.connect(**kwargs)
- cursor = self.connection.cursor()
- if self.connection.get_server_info() >= '4.1':
- cursor.execute("SET NAMES 'utf8'")
+ cursor = self.connection.cursor()
+ if self.connection.get_server_info() >= '4.1':
+ cursor.execute("SET NAMES 'utf8'")
+ else:
+ cursor = self.connection.cursor()
if settings.DEBUG:
return util.CursorDebugWrapper(MysqlDebugWrapper(cursor), self)
return cursor
diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py
index 116b490124..22ed901653 100644
--- a/django/db/backends/mysql/creation.py
+++ b/django/db/backends/mysql/creation.py
@@ -25,6 +25,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py
index d45ceb64f5..da65df172e 100644
--- a/django/db/backends/oracle/creation.py
+++ b/django/db/backends/oracle/creation.py
@@ -21,6 +21,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'long',
'TimeField': 'timestamp',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index 1e48e9c3fa..054f74a0d3 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -20,6 +20,38 @@ except ImportError:
# Import copy of _thread_local.py from Python 2.4
from django.utils._threading_local import local
+def smart_basestring(s, charset):
+ if isinstance(s, unicode):
+ return s.encode(charset)
+ return s
+
+class UnicodeCursorWrapper(object):
+ """
+ A thin wrapper around psycopg cursors that allows them to accept Unicode
+ strings as params.
+
+ This is necessary because psycopg doesn't apply any DB quoting to
+ parameters that are Unicode strings. If a param is Unicode, this will
+ convert it to a bytestring using DEFAULT_CHARSET before passing it to
+ psycopg.
+ """
+ def __init__(self, cursor, charset):
+ self.cursor = cursor
+ self.charset = charset
+
+ def execute(self, sql, params=()):
+ return self.cursor.execute(sql, [smart_basestring(p, self.charset) for p in params])
+
+ def executemany(self, sql, param_list):
+ new_param_list = [tuple([smart_basestring(p, self.charset) for p in params]) for params in param_list]
+ return self.cursor.executemany(sql, new_param_list)
+
+ def __getattr__(self, attr):
+ if self.__dict__.has_key(attr):
+ return self.__dict__[attr]
+ else:
+ return getattr(self.cursor, attr)
+
class DatabaseWrapper(local):
def __init__(self, **kwargs):
self.connection = None
@@ -45,6 +77,7 @@ class DatabaseWrapper(local):
self.connection.set_isolation_level(1) # make transactions transparent to all cursors
cursor = self.connection.cursor()
cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
+ cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET)
if settings.DEBUG:
return util.CursorDebugWrapper(cursor, self)
return cursor
@@ -118,7 +151,7 @@ def get_pk_default_value():
try:
Database.register_type(Database.new_type((1082,), "DATE", util.typecast_date))
except AttributeError:
- raise Exception, "You appear to be using psycopg version 2, which isn't supported yet, because it's still in beta. Use psycopg version 1 instead: http://initd.org/projects/psycopg1"
+ raise Exception, "You appear to be using psycopg version 2. Set your DATABASE_ENGINE to 'postgresql_psycopg2' instead of 'postgresql'."
Database.register_type(Database.new_type((1083,1266), "TIME", util.typecast_time))
Database.register_type(Database.new_type((1114,1184), "TIMESTAMP", util.typecast_timestamp))
Database.register_type(Database.new_type((16,), "BOOLEAN", util.typecast_boolean))
diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py
index 65a804ec40..6c130f368e 100644
--- a/django/db/backends/postgresql/creation.py
+++ b/django/db/backends/postgresql/creation.py
@@ -25,6 +25,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py
index e845179e64..77f570b2e8 100644
--- a/django/db/backends/sqlite3/creation.py
+++ b/django/db/backends/sqlite3/creation.py
@@ -24,6 +24,5 @@ DATA_TYPES = {
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
- 'URLField': 'varchar(200)',
'USStateField': 'varchar(2)',
}
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 8e8d68aad5..024fa95b8e 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -2,7 +2,8 @@ from django.db.models import signals
from django.dispatch import dispatcher
from django.conf import settings
from django.core import validators
-from django import forms
+from django import oldforms
+from django import newforms as forms
from django.core.exceptions import ObjectDoesNotExist
from django.utils.functional import curry
from django.utils.itercompat import tee
@@ -206,10 +207,10 @@ class Field(object):
if self.choices:
if self.radio_admin:
- field_objs = [forms.RadioSelectField]
+ field_objs = [oldforms.RadioSelectField]
params['ul_class'] = get_ul_class(self.radio_admin)
else:
- field_objs = [forms.SelectField]
+ field_objs = [oldforms.SelectField]
params['choices'] = self.get_choices_default()
else:
@@ -218,7 +219,7 @@ class Field(object):
def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False, follow=True):
"""
- Returns a list of forms.FormField instances for this field. It
+ Returns a list of oldforms.FormField instances for this field. It
calculates the choices at runtime, not at compile time.
name_prefix is a prefix to prepend to the "field_name" argument.
@@ -333,6 +334,16 @@ class Field(object):
return self._choices
choices = property(_get_choices)
+ def formfield(self, **kwargs):
+ "Returns a django.newforms.Field instance for this database Field."
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.CharField(**defaults)
+
+ def value_from_object(self, obj):
+ "Returns the value of this field in the given model instance."
+ return getattr(obj, self.attname)
+
class AutoField(Field):
empty_strings_allowed = False
def __init__(self, *args, **kwargs):
@@ -354,7 +365,7 @@ class AutoField(Field):
return Field.get_manipulator_fields(self, opts, manipulator, change, name_prefix, rel, follow)
def get_manipulator_field_objs(self):
- return [forms.HiddenField]
+ return [oldforms.HiddenField]
def get_manipulator_new_data(self, new_data, rel=False):
# Never going to be called
@@ -369,6 +380,9 @@ class AutoField(Field):
super(AutoField, self).contribute_to_class(cls, name)
cls._meta.has_auto_field = True
+ def formfield(self, **kwargs):
+ return None
+
class BooleanField(Field):
def __init__(self, *args, **kwargs):
kwargs['blank'] = True
@@ -381,11 +395,16 @@ class BooleanField(Field):
raise validators.ValidationError, gettext("This value must be either True or False.")
def get_manipulator_field_objs(self):
- return [forms.CheckboxField]
+ return [oldforms.CheckboxField]
+
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.BooleanField(**defaults)
class CharField(Field):
def get_manipulator_field_objs(self):
- return [forms.TextField]
+ return [oldforms.TextField]
def to_python(self, value):
if isinstance(value, basestring):
@@ -397,10 +416,15 @@ class CharField(Field):
raise validators.ValidationError, gettext_lazy("This field cannot be null.")
return str(value)
+ def formfield(self, **kwargs):
+ defaults = {'max_length': self.maxlength, 'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.CharField(**defaults)
+
# TODO: Maybe move this into contrib, because it's specialized.
class CommaSeparatedIntegerField(CharField):
def get_manipulator_field_objs(self):
- return [forms.CommaSeparatedIntegerField]
+ return [oldforms.CommaSeparatedIntegerField]
class DateField(Field):
empty_strings_allowed = False
@@ -462,12 +486,17 @@ class DateField(Field):
return Field.get_db_prep_save(self, value)
def get_manipulator_field_objs(self):
- return [forms.DateField]
+ return [oldforms.DateField]
- def flatten_data(self, follow, obj = None):
+ def flatten_data(self, follow, obj=None):
val = self._get_val_from_obj(obj)
return {self.attname: (val is not None and val.strftime("%Y-%m-%d") or '')}
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.DateField(**defaults)
+
class DateTimeField(DateField):
def to_python(self, value):
if isinstance(value, datetime.datetime):
@@ -503,7 +532,7 @@ class DateTimeField(DateField):
return Field.get_db_prep_lookup(self, lookup_type, value)
def get_manipulator_field_objs(self):
- return [forms.DateField, forms.TimeField]
+ return [oldforms.DateField, oldforms.TimeField]
def get_manipulator_field_names(self, name_prefix):
return [name_prefix + self.name + '_date', name_prefix + self.name + '_time']
@@ -526,6 +555,11 @@ class DateTimeField(DateField):
return {date_field: (val is not None and val.strftime("%Y-%m-%d") or ''),
time_field: (val is not None and val.strftime("%H:%M:%S") or '')}
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.DateTimeField(**defaults)
+
class EmailField(CharField):
def __init__(self, *args, **kwargs):
kwargs['maxlength'] = 75
@@ -535,11 +569,16 @@ class EmailField(CharField):
return "CharField"
def get_manipulator_field_objs(self):
- return [forms.EmailField]
+ return [oldforms.EmailField]
def validate(self, field_data, all_data):
validators.isValidEmail(field_data, all_data)
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.EmailField(**defaults)
+
class FileField(Field):
def __init__(self, verbose_name=None, name=None, upload_to='', **kwargs):
self.upload_to = upload_to
@@ -599,7 +638,7 @@ class FileField(Field):
os.remove(file_name)
def get_manipulator_field_objs(self):
- return [forms.FileUploadField, forms.HiddenField]
+ return [oldforms.FileUploadField, oldforms.HiddenField]
def get_manipulator_field_names(self, name_prefix):
return [name_prefix + self.name + '_file', name_prefix + self.name]
@@ -627,7 +666,7 @@ class FilePathField(Field):
Field.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
- return [curry(forms.FilePathField, path=self.path, match=self.match, recursive=self.recursive)]
+ return [curry(oldforms.FilePathField, path=self.path, match=self.match, recursive=self.recursive)]
class FloatField(Field):
empty_strings_allowed = False
@@ -636,7 +675,7 @@ class FloatField(Field):
Field.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
- return [curry(forms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)]
+ return [curry(oldforms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)]
class ImageField(FileField):
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
@@ -644,7 +683,7 @@ class ImageField(FileField):
FileField.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
- return [forms.ImageUploadField, forms.HiddenField]
+ return [oldforms.ImageUploadField, oldforms.HiddenField]
def contribute_to_class(self, cls, name):
super(ImageField, self).contribute_to_class(cls, name)
@@ -670,7 +709,12 @@ class ImageField(FileField):
class IntegerField(Field):
empty_strings_allowed = False
def get_manipulator_field_objs(self):
- return [forms.IntegerField]
+ return [oldforms.IntegerField]
+
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.IntegerField(**defaults)
class IPAddressField(Field):
def __init__(self, *args, **kwargs):
@@ -678,7 +722,7 @@ class IPAddressField(Field):
Field.__init__(self, *args, **kwargs)
def get_manipulator_field_objs(self):
- return [forms.IPAddressField]
+ return [oldforms.IPAddressField]
def validate(self, field_data, all_data):
validators.isValidIPAddress4(field_data, None)
@@ -689,22 +733,22 @@ class NullBooleanField(Field):
Field.__init__(self, *args, **kwargs)
def get_manipulator_field_objs(self):
- return [forms.NullBooleanField]
+ return [oldforms.NullBooleanField]
class PhoneNumberField(IntegerField):
def get_manipulator_field_objs(self):
- return [forms.PhoneNumberField]
+ return [oldforms.PhoneNumberField]
def validate(self, field_data, all_data):
validators.isValidPhone(field_data, all_data)
class PositiveIntegerField(IntegerField):
def get_manipulator_field_objs(self):
- return [forms.PositiveIntegerField]
+ return [oldforms.PositiveIntegerField]
class PositiveSmallIntegerField(IntegerField):
def get_manipulator_field_objs(self):
- return [forms.PositiveSmallIntegerField]
+ return [oldforms.PositiveSmallIntegerField]
class SlugField(Field):
def __init__(self, *args, **kwargs):
@@ -716,15 +760,20 @@ class SlugField(Field):
Field.__init__(self, *args, **kwargs)
def get_manipulator_field_objs(self):
- return [forms.TextField]
+ return [oldforms.TextField]
class SmallIntegerField(IntegerField):
def get_manipulator_field_objs(self):
- return [forms.SmallIntegerField]
+ return [oldforms.SmallIntegerField]
class TextField(Field):
def get_manipulator_field_objs(self):
- return [forms.LargeTextField]
+ return [oldforms.LargeTextField]
+
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'widget': forms.Textarea, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.CharField(**defaults)
class TimeField(Field):
empty_strings_allowed = False
@@ -760,24 +809,39 @@ class TimeField(Field):
return Field.get_db_prep_save(self, value)
def get_manipulator_field_objs(self):
- return [forms.TimeField]
+ return [oldforms.TimeField]
def flatten_data(self,follow, obj = None):
val = self._get_val_from_obj(obj)
return {self.attname: (val is not None and val.strftime("%H:%M:%S") or '')}
-class URLField(Field):
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.TimeField(**defaults)
+
+class URLField(CharField):
def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs):
+ kwargs['maxlength'] = kwargs.get('maxlength', 200)
if verify_exists:
kwargs.setdefault('validator_list', []).append(validators.isExistingURL)
- Field.__init__(self, verbose_name, name, **kwargs)
+ self.verify_exists = verify_exists
+ CharField.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
- return [forms.URLField]
+ return [oldforms.URLField]
+
+ def get_internal_type(self):
+ return "CharField"
+
+ def formfield(self, **kwargs):
+ defaults = {'required': not self.blank, 'verify_exists': self.verify_exists, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.URLField(**defaults)
class USStateField(Field):
def get_manipulator_field_objs(self):
- return [forms.USStateField]
+ return [oldforms.USStateField]
class XMLField(TextField):
def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs):
@@ -788,7 +852,7 @@ class XMLField(TextField):
return "TextField"
def get_manipulator_field_objs(self):
- return [curry(forms.XMLLargeTextField, schema_path=self.schema_path)]
+ return [curry(oldforms.XMLLargeTextField, schema_path=self.schema_path)]
class OrderingField(IntegerField):
empty_strings_allowed=False
@@ -801,4 +865,4 @@ class OrderingField(IntegerField):
return "IntegerField"
def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False, follow=True):
- return [forms.HiddenField(name_prefix + self.name)]
+ return [oldforms.HiddenField(name_prefix + self.name)]
diff --git a/django/db/models/fields/generic.py b/django/db/models/fields/generic.py
index 7d7651029c..1ad8346e42 100644
--- a/django/db/models/fields/generic.py
+++ b/django/db/models/fields/generic.py
@@ -2,7 +2,7 @@
Classes allowing "generic" relations through ContentType and object-id fields.
"""
-from django import forms
+from django import oldforms
from django.core.exceptions import ObjectDoesNotExist
from django.db import backend
from django.db.models import signals
@@ -98,7 +98,7 @@ class GenericRelation(RelatedField, Field):
def get_manipulator_field_objs(self):
choices = self.get_choices_default()
- return [curry(forms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)]
+ return [curry(oldforms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)]
def get_choices_default(self):
return Field.get_choices(self, include_blank=False)
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index bd9262d55a..b517747735 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -2,10 +2,12 @@ from django.db import backend, transaction
from django.db.models import signals, get_model
from django.db.models.fields import AutoField, Field, IntegerField, get_ul_class
from django.db.models.related import RelatedObject
+from django.utils.text import capfirst
from django.utils.translation import gettext_lazy, string_concat, ngettext
from django.utils.functional import curry
from django.core import validators
-from django import forms
+from django import oldforms
+from django import newforms as forms
from django.dispatch import dispatcher
# For Python 2.3
@@ -256,8 +258,7 @@ class ForeignRelatedObjectsDescriptor(object):
# Otherwise, just move the named objects into the set.
if self.related.field.null:
manager.clear()
- for obj in value:
- manager.add(obj)
+ manager.add(*value)
def create_many_related_manager(superclass):
"""Creates a manager that subclasses 'superclass' (which is a Manager)
@@ -315,28 +316,36 @@ def create_many_related_manager(superclass):
# join_table: name of the m2m link table
# source_col_name: the PK colname in join_table for the source object
# target_col_name: the PK colname in join_table for the target object
- # *objs - objects to add
+ # *objs - objects to add. Either object instances, or primary keys of object instances.
from django.db import connection
- # Add the newly created or already existing objects to the join table.
- # First find out which items are already added, to avoid adding them twice
- new_ids = set([obj._get_pk_val() for obj in objs])
- cursor = connection.cursor()
- cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \
- (target_col_name, self.join_table, source_col_name,
- target_col_name, ",".join(['%s'] * len(new_ids))),
- [self._pk_val] + list(new_ids))
- if cursor.rowcount is not None and cursor.rowcount != 0:
- existing_ids = set([row[0] for row in cursor.fetchmany(cursor.rowcount)])
- else:
- existing_ids = set()
+ # If there aren't any objects, there is nothing to do.
+ if objs:
+ # Check that all the objects are of the right type
+ new_ids = set()
+ for obj in objs:
+ if isinstance(obj, self.model):
+ new_ids.add(obj._get_pk_val())
+ else:
+ new_ids.add(obj)
+ # Add the newly created or already existing objects to the join table.
+ # First find out which items are already added, to avoid adding them twice
+ cursor = connection.cursor()
+ cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \
+ (target_col_name, self.join_table, source_col_name,
+ target_col_name, ",".join(['%s'] * len(new_ids))),
+ [self._pk_val] + list(new_ids))
+ if cursor.rowcount is not None and cursor.rowcount != 0:
+ existing_ids = set([row[0] for row in cursor.fetchmany(cursor.rowcount)])
+ else:
+ existing_ids = set()
- # Add the ones that aren't there already
- for obj_id in (new_ids - existing_ids):
- cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
- (self.join_table, source_col_name, target_col_name),
- [self._pk_val, obj_id])
- transaction.commit_unless_managed()
+ # Add the ones that aren't there already
+ for obj_id in (new_ids - existing_ids):
+ cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
+ (self.join_table, source_col_name, target_col_name),
+ [self._pk_val, obj_id])
+ transaction.commit_unless_managed()
def _remove_items(self, source_col_name, target_col_name, *objs):
# source_col_name: the PK colname in join_table for the source object
@@ -344,16 +353,22 @@ def create_many_related_manager(superclass):
# *objs - objects to remove
from django.db import connection
- for obj in objs:
- if not isinstance(obj, self.model):
- raise ValueError, "objects to remove() must be %s instances" % self.model._meta.object_name
- # Remove the specified objects from the join table
- cursor = connection.cursor()
- for obj in objs:
- cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s = %%s" % \
- (self.join_table, source_col_name, target_col_name),
- [self._pk_val, obj._get_pk_val()])
- transaction.commit_unless_managed()
+ # If there aren't any objects, there is nothing to do.
+ if objs:
+ # Check that all the objects are of the right type
+ old_ids = set()
+ for obj in objs:
+ if isinstance(obj, self.model):
+ old_ids.add(obj._get_pk_val())
+ else:
+ old_ids.add(obj)
+ # Remove the specified objects from the join table
+ cursor = connection.cursor()
+ cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \
+ (self.join_table, source_col_name,
+ target_col_name, ",".join(['%s'] * len(old_ids))),
+ [self._pk_val] + list(old_ids))
+ transaction.commit_unless_managed()
def _clear_items(self, source_col_name):
# source_col_name: the PK colname in join_table for the source object
@@ -405,8 +420,7 @@ class ManyRelatedObjectsDescriptor(object):
manager = self.__get__(instance)
manager.clear()
- for obj in value:
- manager.add(obj)
+ manager.add(*value)
class ReverseManyRelatedObjectsDescriptor(object):
# This class provides the functionality that makes the related-object
@@ -447,8 +461,7 @@ class ReverseManyRelatedObjectsDescriptor(object):
manager = self.__get__(instance)
manager.clear()
- for obj in value:
- manager.add(obj)
+ manager.add(*value)
class ForeignKey(RelatedField, Field):
empty_strings_allowed = False
@@ -493,13 +506,13 @@ class ForeignKey(RelatedField, Field):
params['validator_list'].append(curry(manipulator_valid_rel_key, self, manipulator))
else:
if self.radio_admin:
- field_objs = [forms.RadioSelectField]
+ field_objs = [oldforms.RadioSelectField]
params['ul_class'] = get_ul_class(self.radio_admin)
else:
if self.null:
- field_objs = [forms.NullSelectField]
+ field_objs = [oldforms.NullSelectField]
else:
- field_objs = [forms.SelectField]
+ field_objs = [oldforms.SelectField]
params['choices'] = self.get_choices_default()
return field_objs, params
@@ -508,7 +521,7 @@ class ForeignKey(RelatedField, Field):
if self.rel.raw_id_admin and not isinstance(rel_field, AutoField):
return rel_field.get_manipulator_field_objs()
else:
- return [forms.IntegerField]
+ return [oldforms.IntegerField]
def get_db_prep_save(self, value):
if value == '' or value == None:
@@ -539,6 +552,11 @@ class ForeignKey(RelatedField, Field):
def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
+ def formfield(self, **kwargs):
+ defaults = {'choices': self.get_choices_default(), 'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.ChoiceField(**defaults)
+
class OneToOneField(RelatedField, IntegerField):
def __init__(self, to, to_field=None, **kwargs):
try:
@@ -581,13 +599,13 @@ class OneToOneField(RelatedField, IntegerField):
params['validator_list'].append(curry(manipulator_valid_rel_key, self, manipulator))
else:
if self.radio_admin:
- field_objs = [forms.RadioSelectField]
+ field_objs = [oldforms.RadioSelectField]
params['ul_class'] = get_ul_class(self.radio_admin)
else:
if self.null:
- field_objs = [forms.NullSelectField]
+ field_objs = [oldforms.NullSelectField]
else:
- field_objs = [forms.SelectField]
+ field_objs = [oldforms.SelectField]
params['choices'] = self.get_choices_default()
return field_objs, params
@@ -600,6 +618,11 @@ class OneToOneField(RelatedField, IntegerField):
if not cls._meta.one_to_one_field:
cls._meta.one_to_one_field = self
+ def formfield(self, **kwargs):
+ defaults = {'choices': self.get_choices_default(), 'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.ChoiceField(**kwargs)
+
class ManyToManyField(RelatedField, Field):
def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
@@ -610,6 +633,7 @@ class ManyToManyField(RelatedField, Field):
limit_choices_to=kwargs.pop('limit_choices_to', None),
raw_id_admin=kwargs.pop('raw_id_admin', False),
symmetrical=kwargs.pop('symmetrical', True))
+ self.db_table = kwargs.pop('db_table', None)
if kwargs["rel"].raw_id_admin:
kwargs.setdefault("validator_list", []).append(self.isValidIDList)
Field.__init__(self, **kwargs)
@@ -622,17 +646,20 @@ class ManyToManyField(RelatedField, Field):
def get_manipulator_field_objs(self):
if self.rel.raw_id_admin:
- return [forms.RawIdAdminField]
+ return [oldforms.RawIdAdminField]
else:
choices = self.get_choices_default()
- return [curry(forms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)]
+ return [curry(oldforms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)]
def get_choices_default(self):
return Field.get_choices(self, include_blank=False)
def _get_m2m_db_table(self, opts):
"Function that can be curried to provide the m2m table name for this relation"
- return '%s_%s' % (opts.db_table, self.name)
+ if self.db_table:
+ return self.db_table
+ else:
+ return '%s_%s' % (opts.db_table, self.name)
def _get_m2m_column_name(self, related):
"Function that can be curried to provide the source column name for the m2m table"
@@ -706,6 +733,19 @@ class ManyToManyField(RelatedField, Field):
def set_attributes_from_rel(self):
pass
+ def value_from_object(self, obj):
+ "Returns the value of this field in the given model instance."
+ return getattr(obj, self.attname).all()
+
+ def formfield(self, **kwargs):
+ # If initial is passed in, it's a list of related objects, but the
+ # MultipleChoiceField takes a list of IDs.
+ if kwargs.get('initial') is not None:
+ kwargs['initial'] = [i._get_pk_val() for i in kwargs['initial']]
+ defaults = {'choices': self.get_choices_default(), 'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ defaults.update(kwargs)
+ return forms.MultipleChoiceField(**defaults)
+
class ManyToOneRel(object):
def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None,
max_num_in_admin=None, num_extra_on_change=1, edit_inline=False,
diff --git a/django/db/models/manager.py b/django/db/models/manager.py
index 6005874516..b60eed262a 100644
--- a/django/db/models/manager.py
+++ b/django/db/models/manager.py
@@ -1,4 +1,4 @@
-from django.db.models.query import QuerySet
+from django.db.models.query import QuerySet, EmptyQuerySet
from django.dispatch import dispatcher
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
@@ -41,12 +41,18 @@ class Manager(object):
#######################
# PROXIES TO QUERYSET #
#######################
+
+ def get_empty_query_set(self):
+ return EmptyQuerySet(self.model)
def get_query_set(self):
"""Returns a new QuerySet object. Subclasses can override this method
to easily customise the behaviour of the Manager.
"""
return QuerySet(self.model)
+
+ def none(self):
+ return self.get_empty_query_set()
def all(self):
return self.get_query_set()
diff --git a/django/db/models/manipulators.py b/django/db/models/manipulators.py
index c61e82f813..aea3aa70a7 100644
--- a/django/db/models/manipulators.py
+++ b/django/db/models/manipulators.py
@@ -1,5 +1,5 @@
from django.core.exceptions import ObjectDoesNotExist
-from django import forms
+from django import oldforms
from django.core import validators
from django.db.models.fields import FileField, AutoField
from django.dispatch import dispatcher
@@ -40,7 +40,7 @@ class ManipulatorDescriptor(object):
self.man._prepare(model)
return self.man
-class AutomaticManipulator(forms.Manipulator):
+class AutomaticManipulator(oldforms.Manipulator):
def _prepare(cls, model):
cls.model = model
cls.manager = model._default_manager
@@ -76,7 +76,7 @@ class AutomaticManipulator(forms.Manipulator):
# Add field for ordering.
if self.change and self.opts.get_ordered_objects():
- self.fields.append(forms.CommaSeparatedIntegerField(field_name="order_"))
+ self.fields.append(oldforms.CommaSeparatedIntegerField(field_name="order_"))
def save(self, new_data):
# TODO: big cleanup when core fields go -> use recursive manipulators.
@@ -308,7 +308,7 @@ def manipulator_validator_unique_together(field_name_list, opts, self, field_dat
def manipulator_validator_unique_for_date(from_field, date_field, opts, lookup_type, self, field_data, all_data):
from django.db.models.fields.related import ManyToOneRel
date_str = all_data.get(date_field.get_manipulator_field_names('')[0], None)
- date_val = forms.DateField.html2python(date_str)
+ date_val = oldforms.DateField.html2python(date_str)
if date_val is None:
return # Date was invalid. This will be caught by another validator.
lookup_kwargs = {'%s__year' % date_field.name: date_val.year}
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 53ed63ae5b..51fa334e63 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1,5 +1,6 @@
from django.db import backend, connection, transaction
from django.db.models.fields import DateField, FieldDoesNotExist
+from django.db.models.fields.generic import GenericRelation
from django.db.models import signals
from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
@@ -25,6 +26,9 @@ QUERY_TERMS = (
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
+class EmptyResultSet(Exception):
+ pass
+
####################
# HELPER FUNCTIONS #
####################
@@ -168,7 +172,12 @@ class QuerySet(object):
extra_select = self._select.items()
cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
+
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
fill_cache = self._select_related
index_end = len(self.model._meta.fields)
@@ -192,7 +201,12 @@ class QuerySet(object):
counter._offset = None
counter._limit = None
counter._select_related = False
- select, sql, params = counter._get_sql_clause()
+
+ try:
+ select, sql, params = counter._get_sql_clause()
+ except EmptyResultSet:
+ return 0
+
cursor = connection.cursor()
if self._distinct:
id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
@@ -523,7 +537,12 @@ class ValuesQuerySet(QuerySet):
field_names = [f.attname for f in self.model._meta.fields]
cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
+
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns]
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
while 1:
@@ -545,7 +564,12 @@ class DateQuerySet(QuerySet):
if self._field.null:
self._where.append('%s.%s IS NOT NULL' % \
(backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column)))
- select, sql, params = self._get_sql_clause()
+
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
(backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table),
backend.quote_name(self._field.column))), sql, self._order)
@@ -562,6 +586,25 @@ class DateQuerySet(QuerySet):
c._kind = self._kind
c._order = self._order
return c
+
+class EmptyQuerySet(QuerySet):
+ def __init__(self, model=None):
+ super(EmptyQuerySet, self).__init__(model)
+ self._result_cache = []
+
+ def iterator(self):
+ raise StopIteration
+
+ def count(self):
+ return 0
+
+ def delete(self):
+ pass
+
+ def _clone(self, klass=None, **kwargs):
+ c = super(EmptyQuerySet, self)._clone(klass, **kwargs)
+ c._result_cache = []
+ return c
class QOperator(object):
"Base class for QAnd and QOr"
@@ -571,10 +614,14 @@ class QOperator(object):
def get_sql(self, opts):
joins, where, params = SortedDict(), [], []
for val in self.args:
- joins2, where2, params2 = val.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
+ try:
+ joins2, where2, params2 = val.get_sql(opts)
+ joins.update(joins2)
+ where.extend(where2)
+ params.extend(params2)
+ except EmptyResultSet:
+ if not isinstance(self, QOr):
+ raise EmptyResultSet
if where:
return joins, ['(%s)' % self.operator.join(where)], params
return joins, [], params
@@ -628,8 +675,11 @@ class QNot(Q):
self.q = q
def get_sql(self, opts):
- joins, where, params = self.q.get_sql(opts)
- where2 = ['(NOT (%s))' % " AND ".join(where)]
+ try:
+ joins, where, params = self.q.get_sql(opts)
+ where2 = ['(NOT (%s))' % " AND ".join(where)]
+ except EmptyResultSet:
+ return SortedDict(), [], []
return joins, where2, params
def get_where_clause(lookup_type, table_prefix, field_name, value):
@@ -641,7 +691,11 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
except KeyError:
pass
if lookup_type == 'in':
- return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value]))
+ in_string = ','.join(['%s' for id in value])
+ if in_string:
+ return '%s%s IN (%s)' % (table_prefix, field_name, in_string)
+ else:
+ raise EmptyResultSet
elif lookup_type == 'range':
return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name)
elif lookup_type in ('year', 'month', 'day'):
@@ -926,18 +980,26 @@ def delete_objects(seen_objs):
pk_list = [pk for pk,instance in seen_objs[cls]]
for related in cls._meta.get_all_related_many_to_many_objects():
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(related.field.m2m_db_table()),
- qn(related.field.m2m_reverse_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ if not isinstance(related.field, GenericRelation):
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ (qn(related.field.m2m_db_table()),
+ qn(related.field.m2m_reverse_name()),
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
for f in cls._meta.many_to_many:
+ if isinstance(f, GenericRelation):
+ from django.contrib.contenttypes.models import ContentType
+ query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column
+ args_extra = [ContentType.objects.get_for_model(cls).id]
+ else:
+ query_extra = ''
+ args_extra = []
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \
(qn(f.m2m_db_table()), qn(f.m2m_column_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra,
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra)
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):