summaryrefslogtreecommitdiff
path: root/django/db/models/fields
diff options
context:
space:
mode:
authorRussell Keith-Magee <russell@keith-magee.com>2009-12-22 15:18:51 +0000
committerRussell Keith-Magee <russell@keith-magee.com>2009-12-22 15:18:51 +0000
commitff60c5f9de3e8690d1e86f3e9e3f7248a15397c8 (patch)
treea4cb0ebdd55fcaf8c8855231b6ad3e1a7bf45bee /django/db/models/fields
parent7ef212af149540aa2da577a960d0d87029fd1514 (diff)
Fixed #1142 -- Added multiple database support.
This monster of a patch is the result of Alex Gaynor's 2009 Google Summer of Code project. Congratulations to Alex for a job well done. Big thanks also go to: * Justin Bronn for keeping GIS in line with the changes, * Karen Tracey and Jani Tiainen for their help testing Oracle support * Brett Hoerner, Jon Loyens, and Craig Kimmerer for their feedback. * Malcolm Treddinick for his guidance during the GSoC submission process. * Simon Willison for driving the original design process * Cal Henderson for complaining about ponies he wanted. ... and everyone else too numerous to mention that helped to bring this feature into fruition. git-svn-id: http://code.djangoproject.com/svn/django/trunk@11952 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/fields')
-rw-r--r--django/db/models/fields/__init__.py134
-rw-r--r--django/db/models/fields/files.py6
-rw-r--r--django/db/models/fields/related.py87
-rw-r--r--django/db/models/fields/subclassing.py72
4 files changed, 223 insertions, 76 deletions
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 1be0bc353c..b70f320df3 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -8,6 +8,7 @@ import django.utils.copycompat as copy
from django.db import connection
from django.db.models import signals
+from django.db.models.fields.subclassing import LegacyConnection
from django.db.models.query_utils import QueryWrapper
from django.dispatch import dispatcher
from django.conf import settings
@@ -47,6 +48,9 @@ class FieldDoesNotExist(Exception):
# getattr(obj, opts.pk.attname)
class Field(object):
+ """Base class for all field types"""
+ __metaclass__ = LegacyConnection
+
# Designates whether empty strings fundamentally are allowed at the
# database level.
empty_strings_allowed = True
@@ -123,10 +127,10 @@ class Field(object):
"""
return value
- def db_type(self):
+ def db_type(self, connection):
"""
- Returns the database column data type for this field, taking into
- account the DATABASE_ENGINE setting.
+ Returns the database column data type for this field, for the provided
+ connection.
"""
# The default implementation of this method looks at the
# backend-specific DATA_TYPES dictionary, looking up the field by its
@@ -183,21 +187,56 @@ class Field(object):
"Returns field's value just before saving."
return getattr(model_instance, self.attname)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
+ "Perform preliminary non-db specific value checks and conversions."
+ return value
+
+ def get_db_prep_value(self, value, connection, prepared=False):
"""Returns field's value prepared for interacting with the database
backend.
Used by the default implementations of ``get_db_prep_save``and
`get_db_prep_lookup```
"""
+ if not prepared:
+ value = self.get_prep_value(value)
return value
- def get_db_prep_save(self, value):
+ def get_db_prep_save(self, value, connection):
"Returns field's value prepared for saving into a database."
- return self.get_db_prep_value(value)
+ return self.get_db_prep_value(value, connection=connection, prepared=False)
+
+ def get_prep_lookup(self, lookup_type, value):
+ "Perform preliminary non-db specific lookup checks and conversions"
+ if hasattr(value, 'prepare'):
+ return value.prepare()
+ if hasattr(value, '_prepare'):
+ return value._prepare()
+
+ if lookup_type in (
+ 'regex', 'iregex', 'month', 'day', 'week_day', 'search',
+ 'contains', 'icontains', 'iexact', 'startswith', 'istartswith',
+ 'endswith', 'iendswith', 'isnull'
+ ):
+ return value
+ elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
+ return self.get_prep_value(value)
+ elif lookup_type in ('range', 'in'):
+ return [self.get_prep_value(v) for v in value]
+ elif lookup_type == 'year':
+ try:
+ return int(value)
+ except ValueError:
+ raise ValueError("The __year lookup type requires an integer argument")
- def get_db_prep_lookup(self, lookup_type, value):
+ raise TypeError("Field has invalid lookup: %s" % lookup_type)
+
+ def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
"Returns field's value prepared for database lookup."
+ if not prepared:
+ value = self.get_prep_lookup(lookup_type, value)
+ if hasattr(value, 'get_compiler'):
+ value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabel_aliases method, it will need to
# be invoked before the final SQL is evaluated
@@ -206,15 +245,15 @@ class Field(object):
if hasattr(value, 'as_sql'):
sql, params = value.as_sql()
else:
- sql, params = value._as_sql()
+ sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)
if lookup_type in ('regex', 'iregex', 'month', 'day', 'week_day', 'search'):
return [value]
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
- return [self.get_db_prep_value(value)]
+ return [self.get_db_prep_value(value, connection=connection, prepared=prepared)]
elif lookup_type in ('range', 'in'):
- return [self.get_db_prep_value(v) for v in value]
+ return [self.get_db_prep_value(v, connection=connection, prepared=prepared) for v in value]
elif lookup_type in ('contains', 'icontains'):
return ["%%%s%%" % connection.ops.prep_for_like_query(value)]
elif lookup_type == 'iexact':
@@ -226,18 +265,11 @@ class Field(object):
elif lookup_type == 'isnull':
return []
elif lookup_type == 'year':
- try:
- value = int(value)
- except ValueError:
- raise ValueError("The __year lookup type requires an integer argument")
-
if self.get_internal_type() == 'DateField':
return connection.ops.year_lookup_bounds_for_date_field(value)
else:
return connection.ops.year_lookup_bounds(value)
- raise TypeError("Field has invalid lookup: %s" % lookup_type)
-
def has_default(self):
"Returns a boolean of whether this field has a default value."
return self.default is not NOT_PROVIDED
@@ -346,6 +378,7 @@ class Field(object):
class AutoField(Field):
description = ugettext_lazy("Integer")
+
empty_strings_allowed = False
def __init__(self, *args, **kwargs):
assert kwargs.get('primary_key', False) is True, "%ss must have primary_key=True." % self.__class__.__name__
@@ -361,7 +394,7 @@ class AutoField(Field):
raise exceptions.ValidationError(
_("This value must be an integer."))
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
if value is None:
return None
return int(value)
@@ -394,16 +427,16 @@ class BooleanField(Field):
raise exceptions.ValidationError(
_("This value must be either True or False."))
- def get_db_prep_lookup(self, lookup_type, value):
+ def get_prep_lookup(self, lookup_type, value):
# Special-case handling for filters coming from a web request (e.g. the
# admin interface). Only works for scalar values (not lists). If you're
# passing in a list, you might as well make things the right type when
# constructing the list.
if value in ('1', '0'):
value = bool(int(value))
- return super(BooleanField, self).get_db_prep_lookup(lookup_type, value)
+ return super(BooleanField, self).get_prep_lookup(lookup_type, value)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
if value is None:
return None
return bool(value)
@@ -421,6 +454,7 @@ class BooleanField(Field):
class CharField(Field):
description = ugettext_lazy("String (up to %(max_length)s)")
+
def get_internal_type(self):
return "CharField"
@@ -443,6 +477,7 @@ class CharField(Field):
# TODO: Maybe move this into contrib, because it's specialized.
class CommaSeparatedIntegerField(CharField):
description = ugettext_lazy("Comma-separated integers")
+
def formfield(self, **kwargs):
defaults = {
'form_class': forms.RegexField,
@@ -459,6 +494,7 @@ ansi_date_re = re.compile(r'^\d{4}-\d{1,2}-\d{1,2}$')
class DateField(Field):
description = ugettext_lazy("Date (without time)")
+
empty_strings_allowed = False
def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs):
self.auto_now, self.auto_now_add = auto_now, auto_now_add
@@ -509,16 +545,21 @@ class DateField(Field):
setattr(cls, 'get_previous_by_%s' % self.name,
curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False))
- def get_db_prep_lookup(self, lookup_type, value):
+ def get_prep_lookup(self, lookup_type, value):
# For "__month", "__day", and "__week_day" lookups, convert the value
# to an int so the database backend always sees a consistent type.
if lookup_type in ('month', 'day', 'week_day'):
- return [int(value)]
- return super(DateField, self).get_db_prep_lookup(lookup_type, value)
+ return int(value)
+ return super(DateField, self).get_prep_lookup(lookup_type, value)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
+ return self.to_python(value)
+
+ def get_db_prep_value(self, value, connection, prepared=False):
# Casts dates into the format expected by the backend
- return connection.ops.value_to_db_date(self.to_python(value))
+ if not prepared:
+ value = self.get_prep_value(value)
+ return connection.ops.value_to_db_date(value)
def value_to_string(self, obj):
val = self._get_val_from_obj(obj)
@@ -535,6 +576,7 @@ class DateField(Field):
class DateTimeField(DateField):
description = ugettext_lazy("Date (with time)")
+
def get_internal_type(self):
return "DateTimeField"
@@ -575,9 +617,14 @@ class DateTimeField(DateField):
raise exceptions.ValidationError(
_('Enter a valid date/time in YYYY-MM-DD HH:MM[:ss[.uuuuuu]] format.'))
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
+ return self.to_python(value)
+
+ def get_db_prep_value(self, value, connection, prepared=False):
# Casts dates into the format expected by the backend
- return connection.ops.value_to_db_datetime(self.to_python(value))
+ if not prepared:
+ value = self.get_prep_value(value)
+ return connection.ops.value_to_db_datetime(value)
def value_to_string(self, obj):
val = self._get_val_from_obj(obj)
@@ -632,11 +679,11 @@ class DecimalField(Field):
from django.db.backends import util
return util.format_number(value, self.max_digits, self.decimal_places)
- def get_db_prep_save(self, value):
+ def get_db_prep_save(self, value, connection):
return connection.ops.value_to_db_decimal(self.to_python(value),
self.max_digits, self.decimal_places)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
return self.to_python(value)
def formfield(self, **kwargs):
@@ -661,6 +708,7 @@ class EmailField(CharField):
class FilePathField(Field):
description = ugettext_lazy("File path")
+
def __init__(self, verbose_name=None, name=None, path='', match=None, recursive=False, **kwargs):
self.path, self.match, self.recursive = path, match, recursive
kwargs['max_length'] = kwargs.get('max_length', 100)
@@ -683,7 +731,7 @@ class FloatField(Field):
empty_strings_allowed = False
description = ugettext_lazy("Floating point number")
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
if value is None:
return None
return float(value)
@@ -708,7 +756,8 @@ class FloatField(Field):
class IntegerField(Field):
empty_strings_allowed = False
description = ugettext_lazy("Integer")
- def get_db_prep_value(self, value):
+
+ def get_prep_value(self, value):
if value is None:
return None
return int(value)
@@ -776,16 +825,16 @@ class NullBooleanField(Field):
raise exceptions.ValidationError(
_("This value must be either None, True or False."))
- def get_db_prep_lookup(self, lookup_type, value):
+ def get_prep_lookup(self, lookup_type, value):
# Special-case handling for filters coming from a web request (e.g. the
# admin interface). Only works for scalar values (not lists). If you're
# passing in a list, you might as well make things the right type when
# constructing the list.
if value in ('1', '0'):
value = bool(int(value))
- return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value)
+ return super(NullBooleanField, self).get_prep_lookup(lookup_type, value)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
if value is None:
return None
return bool(value)
@@ -801,6 +850,7 @@ class NullBooleanField(Field):
class PositiveIntegerField(IntegerField):
description = ugettext_lazy("Integer")
+
def get_internal_type(self):
return "PositiveIntegerField"
@@ -838,11 +888,13 @@ class SlugField(CharField):
class SmallIntegerField(IntegerField):
description = ugettext_lazy("Integer")
+
def get_internal_type(self):
return "SmallIntegerField"
class TextField(Field):
description = ugettext_lazy("Text")
+
def get_internal_type(self):
return "TextField"
@@ -853,6 +905,7 @@ class TextField(Field):
class TimeField(Field):
description = ugettext_lazy("Time")
+
empty_strings_allowed = False
def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs):
self.auto_now, self.auto_now_add = auto_now, auto_now_add
@@ -907,9 +960,14 @@ class TimeField(Field):
else:
return super(TimeField, self).pre_save(model_instance, add)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
+ return self.to_python(value)
+
+ def get_db_prep_value(self, value, connection, prepared=False):
# Casts times into the format expected by the backend
- return connection.ops.value_to_db_time(self.to_python(value))
+ if not prepared:
+ value = self.get_prep_value(value)
+ return connection.ops.value_to_db_time(value)
def value_to_string(self, obj):
val = self._get_val_from_obj(obj)
@@ -926,6 +984,7 @@ class TimeField(Field):
class URLField(CharField):
description = ugettext_lazy("URL")
+
def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs):
kwargs['max_length'] = kwargs.get('max_length', 200)
self.verify_exists = verify_exists
@@ -938,6 +997,7 @@ class URLField(CharField):
class XMLField(TextField):
description = ugettext_lazy("XML text")
+
def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs):
self.schema_path = schema_path
Field.__init__(self, verbose_name, name, **kwargs)
diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py
index 97cb4dc082..6dfeddbc41 100644
--- a/django/db/models/fields/files.py
+++ b/django/db/models/fields/files.py
@@ -235,12 +235,12 @@ class FileField(Field):
def get_internal_type(self):
return "FileField"
- def get_db_prep_lookup(self, lookup_type, value):
+ def get_prep_lookup(self, lookup_type, value):
if hasattr(value, 'name'):
value = value.name
- return super(FileField, self).get_db_prep_lookup(lookup_type, value)
+ return super(FileField, self).get_prep_lookup(lookup_type, value)
- def get_db_prep_value(self, value):
+ def get_prep_value(self, value):
"Returns field's value prepared for saving into a database."
# Need to convert File objects provided via a form to unicode for database insertion
if value is None:
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index e412e6b6b8..7cc9a03907 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -1,7 +1,8 @@
-from django.db import connection, transaction
+from django.db import connection, transaction, DEFAULT_DB_ALIAS
from django.db.backends import util
from django.db.models import signals, get_model
-from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist
+from django.db.models.fields import (AutoField, Field, IntegerField,
+ PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
from django.db.models.related import RelatedObject
from django.db.models.query import QuerySet
from django.db.models.query_utils import QueryWrapper
@@ -11,10 +12,6 @@ from django.utils.functional import curry
from django.core import exceptions
from django import forms
-try:
- set
-except NameError:
- from sets import Set as set # Python 2.3 fallback
RECURSIVE_RELATIONSHIP_CONSTANT = 'self'
@@ -120,7 +117,7 @@ class RelatedField(object):
if not cls._meta.abstract:
self.contribute_to_related_class(other, self.related)
- def get_db_prep_lookup(self, lookup_type, value):
+ def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
# If we are doing a lookup on a Related Field, we must be
# comparing object instances. The value should be the PK of value,
# not value itself.
@@ -140,11 +137,16 @@ class RelatedField(object):
if field:
if lookup_type in ('range', 'in'):
v = [v]
- v = field.get_db_prep_lookup(lookup_type, v)
+ v = field.get_db_prep_lookup(lookup_type, v,
+ connection=connection, prepared=prepared)
if isinstance(v, list):
v = v[0]
return v
+ if not prepared:
+ value = self.get_prep_lookup(lookup_type, value)
+ if hasattr(value, 'get_compiler'):
+ value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabel_aliases method, it will need to
# be invoked before the final SQL is evaluated
@@ -153,7 +155,7 @@ class RelatedField(object):
if hasattr(value, 'as_sql'):
sql, params = value.as_sql()
else:
- sql, params = value._as_sql()
+ sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)
# FIXME: lt and gt are explicitally allowed to make
@@ -192,7 +194,7 @@ class SingleRelatedObjectDescriptor(object):
return getattr(instance, self.cache_name)
except AttributeError:
params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
- rel_obj = self.related.model._base_manager.get(**params)
+ rel_obj = self.related.model._base_manager.using(instance._state.db).get(**params)
setattr(instance, self.cache_name, rel_obj)
return rel_obj
@@ -255,10 +257,11 @@ class ReverseSingleRelatedObjectDescriptor(object):
# If the related manager indicates that it should be used for
# related fields, respect that.
rel_mgr = self.field.rel.to._default_manager
+ using = instance._state.db or DEFAULT_DB_ALIAS
if getattr(rel_mgr, 'use_for_related_fields', False):
- rel_obj = rel_mgr.get(**params)
+ rel_obj = rel_mgr.using(using).get(**params)
else:
- rel_obj = QuerySet(self.field.rel.to).get(**params)
+ rel_obj = QuerySet(self.field.rel.to).using(using).get(**params)
setattr(instance, cache_name, rel_obj)
return rel_obj
@@ -275,6 +278,14 @@ class ReverseSingleRelatedObjectDescriptor(object):
raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' %
(value, instance._meta.object_name,
self.field.name, self.field.rel.to._meta.object_name))
+ elif value is not None and value._state.db != instance._state.db:
+ if instance._state.db is None:
+ instance._state.db = value._state.db
+ else:#elif value._state.db is None:
+ value._state.db = instance._state.db
+# elif value._state.db is not None and instance._state.db is not None:
+# raise ValueError('Cannot assign "%r": instance is on database "%s", value is is on database "%s"' %
+# (value, instance._state.db, value._state.db))
# If we're setting the value of a OneToOneField to None, we need to clear
# out the cache on any old related object. Otherwise, deleting the
@@ -356,14 +367,15 @@ class ForeignRelatedObjectsDescriptor(object):
class RelatedManager(superclass):
def get_query_set(self):
- return superclass.get_query_set(self).filter(**(self.core_filters))
+ using = instance._state.db or DEFAULT_DB_ALIAS
+ return superclass.get_query_set(self).using(using).filter(**(self.core_filters))
def add(self, *objs):
for obj in objs:
if not isinstance(obj, self.model):
raise TypeError, "'%s' instance expected" % self.model._meta.object_name
setattr(obj, rel_field.name, instance)
- obj.save()
+ obj.save(using=instance._state.db)
add.alters_data = True
def create(self, **kwargs):
@@ -375,7 +387,8 @@ class ForeignRelatedObjectsDescriptor(object):
# Update kwargs with the related object that this
# ForeignRelatedObjectsDescriptor knows about.
kwargs.update({rel_field.name: instance})
- return super(RelatedManager, self).get_or_create(**kwargs)
+ using = instance._state.db or DEFAULT_DB_ALIAS
+ return super(RelatedManager, self).using(using).get_or_create(**kwargs)
get_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a value of null.
@@ -386,7 +399,7 @@ class ForeignRelatedObjectsDescriptor(object):
# Is obj actually part of this descriptor set?
if getattr(obj, rel_field.attname) == val:
setattr(obj, rel_field.name, None)
- obj.save()
+ obj.save(using=instance._state.db)
else:
raise rel_field.rel.to.DoesNotExist, "%r is not related to %r." % (obj, instance)
remove.alters_data = True
@@ -394,7 +407,7 @@ class ForeignRelatedObjectsDescriptor(object):
def clear(self):
for obj in self.all():
setattr(obj, rel_field.name, None)
- obj.save()
+ obj.save(using=instance._state.db)
clear.alters_data = True
manager = RelatedManager()
@@ -425,7 +438,7 @@ def create_many_related_manager(superclass, rel=False):
raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)
def get_query_set(self):
- return superclass.get_query_set(self)._next_is_sticky().filter(**(self.core_filters))
+ return superclass.get_query_set(self).using(self.instance._state.db)._next_is_sticky().filter(**(self.core_filters))
# If the ManyToMany relation has an intermediary model,
# the add and remove methods do not exist.
@@ -460,14 +473,14 @@ def create_many_related_manager(superclass, rel=False):
if not rel.through._meta.auto_created:
opts = through._meta
raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
- new_obj = super(ManyRelatedManager, self).create(**kwargs)
+ new_obj = super(ManyRelatedManager, self).using(self.instance._state.db).create(**kwargs)
self.add(new_obj)
return new_obj
create.alters_data = True
def get_or_create(self, **kwargs):
obj, created = \
- super(ManyRelatedManager, self).get_or_create(**kwargs)
+ super(ManyRelatedManager, self).using(self.instance._state.db).get_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
@@ -487,12 +500,15 @@ def create_many_related_manager(superclass, rel=False):
new_ids = set()
for obj in objs:
if isinstance(obj, self.model):
+# if obj._state.db != self.instance._state.db:
+# raise ValueError('Cannot add "%r": instance is on database "%s", value is is on database "%s"' %
+# (obj, self.instance._state.db, obj._state.db))
new_ids.add(obj.pk)
elif isinstance(obj, Model):
raise TypeError, "'%s' instance expected" % self.model._meta.object_name
else:
new_ids.add(obj)
- vals = self.through._default_manager.values_list(target_field_name, flat=True)
+ vals = self.through._default_manager.using(self.instance._state.db).values_list(target_field_name, flat=True)
vals = vals.filter(**{
source_field_name: self._pk_val,
'%s__in' % target_field_name: new_ids,
@@ -501,7 +517,7 @@ def create_many_related_manager(superclass, rel=False):
# Add the ones that aren't there already
for obj_id in (new_ids - vals):
- self.through._default_manager.create(**{
+ self.through._default_manager.using(self.instance._state.db).create(**{
'%s_id' % source_field_name: self._pk_val,
'%s_id' % target_field_name: obj_id,
})
@@ -521,14 +537,14 @@ def create_many_related_manager(superclass, rel=False):
else:
old_ids.add(obj)
# Remove the specified objects from the join table
- self.through._default_manager.filter(**{
+ self.through._default_manager.using(self.instance._state.db).filter(**{
source_field_name: self._pk_val,
'%s__in' % target_field_name: old_ids
}).delete()
def _clear_items(self, source_field_name):
# source_col_name: the PK colname in join_table for the source object
- self.through._default_manager.filter(**{
+ self.through._default_manager.using(self.instance._state.db).filter(**{
source_field_name: self._pk_val
}).delete()
@@ -728,11 +744,12 @@ class ForeignKey(RelatedField, Field):
return getattr(field_default, self.rel.get_related_field().attname)
return field_default
- def get_db_prep_save(self, value):
+ def get_db_prep_save(self, value, connection):
if value == '' or value == None:
return None
else:
- return self.rel.get_related_field().get_db_prep_save(value)
+ return self.rel.get_related_field().get_db_prep_save(value,
+ connection=connection)
def value_to_string(self, obj):
if not obj:
@@ -764,16 +781,16 @@ class ForeignKey(RelatedField, Field):
self.rel.field_name = cls._meta.pk.name
def formfield(self, **kwargs):
+ db = kwargs.pop('using', None)
defaults = {
'form_class': forms.ModelChoiceField,
- 'queryset': self.rel.to._default_manager.complex_filter(
- self.rel.limit_choices_to),
+ 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to),
'to_field_name': self.rel.field_name,
}
defaults.update(kwargs)
return super(ForeignKey, self).formfield(**defaults)
- def db_type(self):
+ def db_type(self, connection):
# The database column type of a ForeignKey is the column type
# of the field to which it points. An exception is if the ForeignKey
# points to an AutoField/PositiveIntegerField/PositiveSmallIntegerField,
@@ -785,8 +802,8 @@ class ForeignKey(RelatedField, Field):
(not connection.features.related_fields_match_type and
isinstance(rel_field, (PositiveIntegerField,
PositiveSmallIntegerField)))):
- return IntegerField().db_type()
- return rel_field.db_type()
+ return IntegerField().db_type(connection=connection)
+ return rel_field.db_type(connection=connection)
class OneToOneField(ForeignKey):
"""
@@ -1012,7 +1029,11 @@ class ManyToManyField(RelatedField, Field):
setattr(instance, self.attname, data)
def formfield(self, **kwargs):
- defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.complex_filter(self.rel.limit_choices_to)}
+ db = kwargs.pop('using', None)
+ defaults = {
+ 'form_class': forms.ModelMultipleChoiceField,
+ 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to)
+ }
defaults.update(kwargs)
# If initial is passed in, it's a list of related objects, but the
# MultipleChoiceField takes a list of IDs.
@@ -1023,7 +1044,7 @@ class ManyToManyField(RelatedField, Field):
defaults['initial'] = [i._get_pk_val() for i in initial]
return super(ManyToManyField, self).formfield(**defaults)
- def db_type(self):
+ def db_type(self, connection):
# A ManyToManyField is not represented by a single column,
# so return None.
return None
diff --git a/django/db/models/fields/subclassing.py b/django/db/models/fields/subclassing.py
index 10add10739..bd11675ad3 100644
--- a/django/db/models/fields/subclassing.py
+++ b/django/db/models/fields/subclassing.py
@@ -1,11 +1,77 @@
"""
-Convenience routines for creating non-trivial Field subclasses.
+Convenience routines for creating non-trivial Field subclasses, as well as
+backwards compatibility utilities.
Add SubfieldBase as the __metaclass__ for your Field subclass, implement
to_python() and the other necessary methods and everything will work seamlessly.
"""
-class SubfieldBase(type):
+from inspect import getargspec
+from warnings import warn
+
+def call_with_connection(func):
+ arg_names, varargs, varkwargs, defaults = getargspec(func)
+ updated = ('connection' in arg_names or varkwargs)
+ if not updated:
+ warn("A Field class whose %s method hasn't been updated to take a "
+ "`connection` argument." % func.__name__,
+ PendingDeprecationWarning, stacklevel=2)
+
+ def inner(*args, **kwargs):
+ if 'connection' not in kwargs:
+ from django.db import connection
+ kwargs['connection'] = connection
+ warn("%s has been called without providing a connection argument. " %
+ func.__name__, PendingDeprecationWarning,
+ stacklevel=1)
+ if updated:
+ return func(*args, **kwargs)
+ if 'connection' in kwargs:
+ del kwargs['connection']
+ return func(*args, **kwargs)
+ return inner
+
+def call_with_connection_and_prepared(func):
+ arg_names, varargs, varkwargs, defaults = getargspec(func)
+ updated = (
+ ('connection' in arg_names or varkwargs) and
+ ('prepared' in arg_names or varkwargs)
+ )
+ if not updated:
+ warn("A Field class whose %s method hasn't been updated to take "
+ "`connection` and `prepared` arguments." % func.__name__,
+ PendingDeprecationWarning, stacklevel=2)
+
+ def inner(*args, **kwargs):
+ if 'connection' not in kwargs:
+ from django.db import connection
+ kwargs['connection'] = connection
+ warn("%s has been called without providing a connection argument. " %
+ func.__name__, PendingDeprecationWarning,
+ stacklevel=1)
+ if updated:
+ return func(*args, **kwargs)
+ if 'connection' in kwargs:
+ del kwargs['connection']
+ if 'prepared' in kwargs:
+ del kwargs['prepared']
+ return func(*args, **kwargs)
+ return inner
+
+class LegacyConnection(type):
+ """
+ A metaclass to normalize arguments give to the get_db_prep_* and db_type
+ methods on fields.
+ """
+ def __new__(cls, names, bases, attrs):
+ new_cls = super(LegacyConnection, cls).__new__(cls, names, bases, attrs)
+ for attr in ('db_type', 'get_db_prep_save'):
+ setattr(new_cls, attr, call_with_connection(getattr(new_cls, attr)))
+ for attr in ('get_db_prep_lookup', 'get_db_prep_value'):
+ setattr(new_cls, attr, call_with_connection_and_prepared(getattr(new_cls, attr)))
+ return new_cls
+
+class SubfieldBase(LegacyConnection):
"""
A metaclass for custom Field subclasses. This ensures the model's attribute
has the descriptor protocol attached to it.
@@ -26,7 +92,7 @@ class Creator(object):
def __get__(self, obj, type=None):
if obj is None:
raise AttributeError('Can only be accessed via an instance.')
- return obj.__dict__[self.field.name]
+ return obj.__dict__[self.field.name]
def __set__(self, obj, value):
obj.__dict__[self.field.name] = self.field.to_python(value)