diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-12-22 15:18:51 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-12-22 15:18:51 +0000 |
| commit | ff60c5f9de3e8690d1e86f3e9e3f7248a15397c8 (patch) | |
| tree | a4cb0ebdd55fcaf8c8855231b6ad3e1a7bf45bee /django/db/models/fields | |
| parent | 7ef212af149540aa2da577a960d0d87029fd1514 (diff) | |
Fixed #1142 -- Added multiple database support.
This monster of a patch is the result of Alex Gaynor's 2009 Google Summer of Code project.
Congratulations to Alex for a job well done.
Big thanks also go to:
* Justin Bronn for keeping GIS in line with the changes,
* Karen Tracey and Jani Tiainen for their help testing Oracle support
* Brett Hoerner, Jon Loyens, and Craig Kimmerer for their feedback.
* Malcolm Treddinick for his guidance during the GSoC submission process.
* Simon Willison for driving the original design process
* Cal Henderson for complaining about ponies he wanted.
... and everyone else too numerous to mention that helped to bring this feature into fruition.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@11952 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/fields')
| -rw-r--r-- | django/db/models/fields/__init__.py | 134 | ||||
| -rw-r--r-- | django/db/models/fields/files.py | 6 | ||||
| -rw-r--r-- | django/db/models/fields/related.py | 87 | ||||
| -rw-r--r-- | django/db/models/fields/subclassing.py | 72 |
4 files changed, 223 insertions, 76 deletions
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 1be0bc353c..b70f320df3 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -8,6 +8,7 @@ import django.utils.copycompat as copy from django.db import connection from django.db.models import signals +from django.db.models.fields.subclassing import LegacyConnection from django.db.models.query_utils import QueryWrapper from django.dispatch import dispatcher from django.conf import settings @@ -47,6 +48,9 @@ class FieldDoesNotExist(Exception): # getattr(obj, opts.pk.attname) class Field(object): + """Base class for all field types""" + __metaclass__ = LegacyConnection + # Designates whether empty strings fundamentally are allowed at the # database level. empty_strings_allowed = True @@ -123,10 +127,10 @@ class Field(object): """ return value - def db_type(self): + def db_type(self, connection): """ - Returns the database column data type for this field, taking into - account the DATABASE_ENGINE setting. + Returns the database column data type for this field, for the provided + connection. """ # The default implementation of this method looks at the # backend-specific DATA_TYPES dictionary, looking up the field by its @@ -183,21 +187,56 @@ class Field(object): "Returns field's value just before saving." return getattr(model_instance, self.attname) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + "Perform preliminary non-db specific value checks and conversions." + return value + + def get_db_prep_value(self, value, connection, prepared=False): """Returns field's value prepared for interacting with the database backend. Used by the default implementations of ``get_db_prep_save``and `get_db_prep_lookup``` """ + if not prepared: + value = self.get_prep_value(value) return value - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): "Returns field's value prepared for saving into a database." - return self.get_db_prep_value(value) + return self.get_db_prep_value(value, connection=connection, prepared=False) + + def get_prep_lookup(self, lookup_type, value): + "Perform preliminary non-db specific lookup checks and conversions" + if hasattr(value, 'prepare'): + return value.prepare() + if hasattr(value, '_prepare'): + return value._prepare() + + if lookup_type in ( + 'regex', 'iregex', 'month', 'day', 'week_day', 'search', + 'contains', 'icontains', 'iexact', 'startswith', 'istartswith', + 'endswith', 'iendswith', 'isnull' + ): + return value + elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): + return self.get_prep_value(value) + elif lookup_type in ('range', 'in'): + return [self.get_prep_value(v) for v in value] + elif lookup_type == 'year': + try: + return int(value) + except ValueError: + raise ValueError("The __year lookup type requires an integer argument") - def get_db_prep_lookup(self, lookup_type, value): + raise TypeError("Field has invalid lookup: %s" % lookup_type) + + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): "Returns field's value prepared for database lookup." + if not prepared: + value = self.get_prep_lookup(lookup_type, value) + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabel_aliases method, it will need to # be invoked before the final SQL is evaluated @@ -206,15 +245,15 @@ class Field(object): if hasattr(value, 'as_sql'): sql, params = value.as_sql() else: - sql, params = value._as_sql() + sql, params = value._as_sql(connection=connection) return QueryWrapper(('(%s)' % sql), params) if lookup_type in ('regex', 'iregex', 'month', 'day', 'week_day', 'search'): return [value] elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): - return [self.get_db_prep_value(value)] + return [self.get_db_prep_value(value, connection=connection, prepared=prepared)] elif lookup_type in ('range', 'in'): - return [self.get_db_prep_value(v) for v in value] + return [self.get_db_prep_value(v, connection=connection, prepared=prepared) for v in value] elif lookup_type in ('contains', 'icontains'): return ["%%%s%%" % connection.ops.prep_for_like_query(value)] elif lookup_type == 'iexact': @@ -226,18 +265,11 @@ class Field(object): elif lookup_type == 'isnull': return [] elif lookup_type == 'year': - try: - value = int(value) - except ValueError: - raise ValueError("The __year lookup type requires an integer argument") - if self.get_internal_type() == 'DateField': return connection.ops.year_lookup_bounds_for_date_field(value) else: return connection.ops.year_lookup_bounds(value) - raise TypeError("Field has invalid lookup: %s" % lookup_type) - def has_default(self): "Returns a boolean of whether this field has a default value." return self.default is not NOT_PROVIDED @@ -346,6 +378,7 @@ class Field(object): class AutoField(Field): description = ugettext_lazy("Integer") + empty_strings_allowed = False def __init__(self, *args, **kwargs): assert kwargs.get('primary_key', False) is True, "%ss must have primary_key=True." % self.__class__.__name__ @@ -361,7 +394,7 @@ class AutoField(Field): raise exceptions.ValidationError( _("This value must be an integer.")) - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return int(value) @@ -394,16 +427,16 @@ class BooleanField(Field): raise exceptions.ValidationError( _("This value must be either True or False.")) - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): # Special-case handling for filters coming from a web request (e.g. the # admin interface). Only works for scalar values (not lists). If you're # passing in a list, you might as well make things the right type when # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return super(BooleanField, self).get_db_prep_lookup(lookup_type, value) + return super(BooleanField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return bool(value) @@ -421,6 +454,7 @@ class BooleanField(Field): class CharField(Field): description = ugettext_lazy("String (up to %(max_length)s)") + def get_internal_type(self): return "CharField" @@ -443,6 +477,7 @@ class CharField(Field): # TODO: Maybe move this into contrib, because it's specialized. class CommaSeparatedIntegerField(CharField): description = ugettext_lazy("Comma-separated integers") + def formfield(self, **kwargs): defaults = { 'form_class': forms.RegexField, @@ -459,6 +494,7 @@ ansi_date_re = re.compile(r'^\d{4}-\d{1,2}-\d{1,2}$') class DateField(Field): description = ugettext_lazy("Date (without time)") + empty_strings_allowed = False def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): self.auto_now, self.auto_now_add = auto_now, auto_now_add @@ -509,16 +545,21 @@ class DateField(Field): setattr(cls, 'get_previous_by_%s' % self.name, curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)) - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): # For "__month", "__day", and "__week_day" lookups, convert the value # to an int so the database backend always sees a consistent type. if lookup_type in ('month', 'day', 'week_day'): - return [int(value)] - return super(DateField, self).get_db_prep_lookup(lookup_type, value) + return int(value) + return super(DateField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + return self.to_python(value) + + def get_db_prep_value(self, value, connection, prepared=False): # Casts dates into the format expected by the backend - return connection.ops.value_to_db_date(self.to_python(value)) + if not prepared: + value = self.get_prep_value(value) + return connection.ops.value_to_db_date(value) def value_to_string(self, obj): val = self._get_val_from_obj(obj) @@ -535,6 +576,7 @@ class DateField(Field): class DateTimeField(DateField): description = ugettext_lazy("Date (with time)") + def get_internal_type(self): return "DateTimeField" @@ -575,9 +617,14 @@ class DateTimeField(DateField): raise exceptions.ValidationError( _('Enter a valid date/time in YYYY-MM-DD HH:MM[:ss[.uuuuuu]] format.')) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + return self.to_python(value) + + def get_db_prep_value(self, value, connection, prepared=False): # Casts dates into the format expected by the backend - return connection.ops.value_to_db_datetime(self.to_python(value)) + if not prepared: + value = self.get_prep_value(value) + return connection.ops.value_to_db_datetime(value) def value_to_string(self, obj): val = self._get_val_from_obj(obj) @@ -632,11 +679,11 @@ class DecimalField(Field): from django.db.backends import util return util.format_number(value, self.max_digits, self.decimal_places) - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): return connection.ops.value_to_db_decimal(self.to_python(value), self.max_digits, self.decimal_places) - def get_db_prep_value(self, value): + def get_prep_value(self, value): return self.to_python(value) def formfield(self, **kwargs): @@ -661,6 +708,7 @@ class EmailField(CharField): class FilePathField(Field): description = ugettext_lazy("File path") + def __init__(self, verbose_name=None, name=None, path='', match=None, recursive=False, **kwargs): self.path, self.match, self.recursive = path, match, recursive kwargs['max_length'] = kwargs.get('max_length', 100) @@ -683,7 +731,7 @@ class FloatField(Field): empty_strings_allowed = False description = ugettext_lazy("Floating point number") - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return float(value) @@ -708,7 +756,8 @@ class FloatField(Field): class IntegerField(Field): empty_strings_allowed = False description = ugettext_lazy("Integer") - def get_db_prep_value(self, value): + + def get_prep_value(self, value): if value is None: return None return int(value) @@ -776,16 +825,16 @@ class NullBooleanField(Field): raise exceptions.ValidationError( _("This value must be either None, True or False.")) - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): # Special-case handling for filters coming from a web request (e.g. the # admin interface). Only works for scalar values (not lists). If you're # passing in a list, you might as well make things the right type when # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value) + return super(NullBooleanField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return bool(value) @@ -801,6 +850,7 @@ class NullBooleanField(Field): class PositiveIntegerField(IntegerField): description = ugettext_lazy("Integer") + def get_internal_type(self): return "PositiveIntegerField" @@ -838,11 +888,13 @@ class SlugField(CharField): class SmallIntegerField(IntegerField): description = ugettext_lazy("Integer") + def get_internal_type(self): return "SmallIntegerField" class TextField(Field): description = ugettext_lazy("Text") + def get_internal_type(self): return "TextField" @@ -853,6 +905,7 @@ class TextField(Field): class TimeField(Field): description = ugettext_lazy("Time") + empty_strings_allowed = False def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): self.auto_now, self.auto_now_add = auto_now, auto_now_add @@ -907,9 +960,14 @@ class TimeField(Field): else: return super(TimeField, self).pre_save(model_instance, add) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + return self.to_python(value) + + def get_db_prep_value(self, value, connection, prepared=False): # Casts times into the format expected by the backend - return connection.ops.value_to_db_time(self.to_python(value)) + if not prepared: + value = self.get_prep_value(value) + return connection.ops.value_to_db_time(value) def value_to_string(self, obj): val = self._get_val_from_obj(obj) @@ -926,6 +984,7 @@ class TimeField(Field): class URLField(CharField): description = ugettext_lazy("URL") + def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs): kwargs['max_length'] = kwargs.get('max_length', 200) self.verify_exists = verify_exists @@ -938,6 +997,7 @@ class URLField(CharField): class XMLField(TextField): description = ugettext_lazy("XML text") + def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs): self.schema_path = schema_path Field.__init__(self, verbose_name, name, **kwargs) diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index 97cb4dc082..6dfeddbc41 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -235,12 +235,12 @@ class FileField(Field): def get_internal_type(self): return "FileField" - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): if hasattr(value, 'name'): value = value.name - return super(FileField, self).get_db_prep_lookup(lookup_type, value) + return super(FileField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): "Returns field's value prepared for saving into a database." # Need to convert File objects provided via a form to unicode for database insertion if value is None: diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index e412e6b6b8..7cc9a03907 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1,7 +1,8 @@ -from django.db import connection, transaction +from django.db import connection, transaction, DEFAULT_DB_ALIAS from django.db.backends import util from django.db.models import signals, get_model -from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist +from django.db.models.fields import (AutoField, Field, IntegerField, + PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) from django.db.models.related import RelatedObject from django.db.models.query import QuerySet from django.db.models.query_utils import QueryWrapper @@ -11,10 +12,6 @@ from django.utils.functional import curry from django.core import exceptions from django import forms -try: - set -except NameError: - from sets import Set as set # Python 2.3 fallback RECURSIVE_RELATIONSHIP_CONSTANT = 'self' @@ -120,7 +117,7 @@ class RelatedField(object): if not cls._meta.abstract: self.contribute_to_related_class(other, self.related) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): # If we are doing a lookup on a Related Field, we must be # comparing object instances. The value should be the PK of value, # not value itself. @@ -140,11 +137,16 @@ class RelatedField(object): if field: if lookup_type in ('range', 'in'): v = [v] - v = field.get_db_prep_lookup(lookup_type, v) + v = field.get_db_prep_lookup(lookup_type, v, + connection=connection, prepared=prepared) if isinstance(v, list): v = v[0] return v + if not prepared: + value = self.get_prep_lookup(lookup_type, value) + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabel_aliases method, it will need to # be invoked before the final SQL is evaluated @@ -153,7 +155,7 @@ class RelatedField(object): if hasattr(value, 'as_sql'): sql, params = value.as_sql() else: - sql, params = value._as_sql() + sql, params = value._as_sql(connection=connection) return QueryWrapper(('(%s)' % sql), params) # FIXME: lt and gt are explicitally allowed to make @@ -192,7 +194,7 @@ class SingleRelatedObjectDescriptor(object): return getattr(instance, self.cache_name) except AttributeError: params = {'%s__pk' % self.related.field.name: instance._get_pk_val()} - rel_obj = self.related.model._base_manager.get(**params) + rel_obj = self.related.model._base_manager.using(instance._state.db).get(**params) setattr(instance, self.cache_name, rel_obj) return rel_obj @@ -255,10 +257,11 @@ class ReverseSingleRelatedObjectDescriptor(object): # If the related manager indicates that it should be used for # related fields, respect that. rel_mgr = self.field.rel.to._default_manager + using = instance._state.db or DEFAULT_DB_ALIAS if getattr(rel_mgr, 'use_for_related_fields', False): - rel_obj = rel_mgr.get(**params) + rel_obj = rel_mgr.using(using).get(**params) else: - rel_obj = QuerySet(self.field.rel.to).get(**params) + rel_obj = QuerySet(self.field.rel.to).using(using).get(**params) setattr(instance, cache_name, rel_obj) return rel_obj @@ -275,6 +278,14 @@ class ReverseSingleRelatedObjectDescriptor(object): raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (value, instance._meta.object_name, self.field.name, self.field.rel.to._meta.object_name)) + elif value is not None and value._state.db != instance._state.db: + if instance._state.db is None: + instance._state.db = value._state.db + else:#elif value._state.db is None: + value._state.db = instance._state.db +# elif value._state.db is not None and instance._state.db is not None: +# raise ValueError('Cannot assign "%r": instance is on database "%s", value is is on database "%s"' % +# (value, instance._state.db, value._state.db)) # If we're setting the value of a OneToOneField to None, we need to clear # out the cache on any old related object. Otherwise, deleting the @@ -356,14 +367,15 @@ class ForeignRelatedObjectsDescriptor(object): class RelatedManager(superclass): def get_query_set(self): - return superclass.get_query_set(self).filter(**(self.core_filters)) + using = instance._state.db or DEFAULT_DB_ALIAS + return superclass.get_query_set(self).using(using).filter(**(self.core_filters)) def add(self, *objs): for obj in objs: if not isinstance(obj, self.model): raise TypeError, "'%s' instance expected" % self.model._meta.object_name setattr(obj, rel_field.name, instance) - obj.save() + obj.save(using=instance._state.db) add.alters_data = True def create(self, **kwargs): @@ -375,7 +387,8 @@ class ForeignRelatedObjectsDescriptor(object): # Update kwargs with the related object that this # ForeignRelatedObjectsDescriptor knows about. kwargs.update({rel_field.name: instance}) - return super(RelatedManager, self).get_or_create(**kwargs) + using = instance._state.db or DEFAULT_DB_ALIAS + return super(RelatedManager, self).using(using).get_or_create(**kwargs) get_or_create.alters_data = True # remove() and clear() are only provided if the ForeignKey can have a value of null. @@ -386,7 +399,7 @@ class ForeignRelatedObjectsDescriptor(object): # Is obj actually part of this descriptor set? if getattr(obj, rel_field.attname) == val: setattr(obj, rel_field.name, None) - obj.save() + obj.save(using=instance._state.db) else: raise rel_field.rel.to.DoesNotExist, "%r is not related to %r." % (obj, instance) remove.alters_data = True @@ -394,7 +407,7 @@ class ForeignRelatedObjectsDescriptor(object): def clear(self): for obj in self.all(): setattr(obj, rel_field.name, None) - obj.save() + obj.save(using=instance._state.db) clear.alters_data = True manager = RelatedManager() @@ -425,7 +438,7 @@ def create_many_related_manager(superclass, rel=False): raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__) def get_query_set(self): - return superclass.get_query_set(self)._next_is_sticky().filter(**(self.core_filters)) + return superclass.get_query_set(self).using(self.instance._state.db)._next_is_sticky().filter(**(self.core_filters)) # If the ManyToMany relation has an intermediary model, # the add and remove methods do not exist. @@ -460,14 +473,14 @@ def create_many_related_manager(superclass, rel=False): if not rel.through._meta.auto_created: opts = through._meta raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name) - new_obj = super(ManyRelatedManager, self).create(**kwargs) + new_obj = super(ManyRelatedManager, self).using(self.instance._state.db).create(**kwargs) self.add(new_obj) return new_obj create.alters_data = True def get_or_create(self, **kwargs): obj, created = \ - super(ManyRelatedManager, self).get_or_create(**kwargs) + super(ManyRelatedManager, self).using(self.instance._state.db).get_or_create(**kwargs) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. if created: @@ -487,12 +500,15 @@ def create_many_related_manager(superclass, rel=False): new_ids = set() for obj in objs: if isinstance(obj, self.model): +# if obj._state.db != self.instance._state.db: +# raise ValueError('Cannot add "%r": instance is on database "%s", value is is on database "%s"' % +# (obj, self.instance._state.db, obj._state.db)) new_ids.add(obj.pk) elif isinstance(obj, Model): raise TypeError, "'%s' instance expected" % self.model._meta.object_name else: new_ids.add(obj) - vals = self.through._default_manager.values_list(target_field_name, flat=True) + vals = self.through._default_manager.using(self.instance._state.db).values_list(target_field_name, flat=True) vals = vals.filter(**{ source_field_name: self._pk_val, '%s__in' % target_field_name: new_ids, @@ -501,7 +517,7 @@ def create_many_related_manager(superclass, rel=False): # Add the ones that aren't there already for obj_id in (new_ids - vals): - self.through._default_manager.create(**{ + self.through._default_manager.using(self.instance._state.db).create(**{ '%s_id' % source_field_name: self._pk_val, '%s_id' % target_field_name: obj_id, }) @@ -521,14 +537,14 @@ def create_many_related_manager(superclass, rel=False): else: old_ids.add(obj) # Remove the specified objects from the join table - self.through._default_manager.filter(**{ + self.through._default_manager.using(self.instance._state.db).filter(**{ source_field_name: self._pk_val, '%s__in' % target_field_name: old_ids }).delete() def _clear_items(self, source_field_name): # source_col_name: the PK colname in join_table for the source object - self.through._default_manager.filter(**{ + self.through._default_manager.using(self.instance._state.db).filter(**{ source_field_name: self._pk_val }).delete() @@ -728,11 +744,12 @@ class ForeignKey(RelatedField, Field): return getattr(field_default, self.rel.get_related_field().attname) return field_default - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): if value == '' or value == None: return None else: - return self.rel.get_related_field().get_db_prep_save(value) + return self.rel.get_related_field().get_db_prep_save(value, + connection=connection) def value_to_string(self, obj): if not obj: @@ -764,16 +781,16 @@ class ForeignKey(RelatedField, Field): self.rel.field_name = cls._meta.pk.name def formfield(self, **kwargs): + db = kwargs.pop('using', None) defaults = { 'form_class': forms.ModelChoiceField, - 'queryset': self.rel.to._default_manager.complex_filter( - self.rel.limit_choices_to), + 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to), 'to_field_name': self.rel.field_name, } defaults.update(kwargs) return super(ForeignKey, self).formfield(**defaults) - def db_type(self): + def db_type(self, connection): # The database column type of a ForeignKey is the column type # of the field to which it points. An exception is if the ForeignKey # points to an AutoField/PositiveIntegerField/PositiveSmallIntegerField, @@ -785,8 +802,8 @@ class ForeignKey(RelatedField, Field): (not connection.features.related_fields_match_type and isinstance(rel_field, (PositiveIntegerField, PositiveSmallIntegerField)))): - return IntegerField().db_type() - return rel_field.db_type() + return IntegerField().db_type(connection=connection) + return rel_field.db_type(connection=connection) class OneToOneField(ForeignKey): """ @@ -1012,7 +1029,11 @@ class ManyToManyField(RelatedField, Field): setattr(instance, self.attname, data) def formfield(self, **kwargs): - defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.complex_filter(self.rel.limit_choices_to)} + db = kwargs.pop('using', None) + defaults = { + 'form_class': forms.ModelMultipleChoiceField, + 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to) + } defaults.update(kwargs) # If initial is passed in, it's a list of related objects, but the # MultipleChoiceField takes a list of IDs. @@ -1023,7 +1044,7 @@ class ManyToManyField(RelatedField, Field): defaults['initial'] = [i._get_pk_val() for i in initial] return super(ManyToManyField, self).formfield(**defaults) - def db_type(self): + def db_type(self, connection): # A ManyToManyField is not represented by a single column, # so return None. return None diff --git a/django/db/models/fields/subclassing.py b/django/db/models/fields/subclassing.py index 10add10739..bd11675ad3 100644 --- a/django/db/models/fields/subclassing.py +++ b/django/db/models/fields/subclassing.py @@ -1,11 +1,77 @@ """ -Convenience routines for creating non-trivial Field subclasses. +Convenience routines for creating non-trivial Field subclasses, as well as +backwards compatibility utilities. Add SubfieldBase as the __metaclass__ for your Field subclass, implement to_python() and the other necessary methods and everything will work seamlessly. """ -class SubfieldBase(type): +from inspect import getargspec +from warnings import warn + +def call_with_connection(func): + arg_names, varargs, varkwargs, defaults = getargspec(func) + updated = ('connection' in arg_names or varkwargs) + if not updated: + warn("A Field class whose %s method hasn't been updated to take a " + "`connection` argument." % func.__name__, + PendingDeprecationWarning, stacklevel=2) + + def inner(*args, **kwargs): + if 'connection' not in kwargs: + from django.db import connection + kwargs['connection'] = connection + warn("%s has been called without providing a connection argument. " % + func.__name__, PendingDeprecationWarning, + stacklevel=1) + if updated: + return func(*args, **kwargs) + if 'connection' in kwargs: + del kwargs['connection'] + return func(*args, **kwargs) + return inner + +def call_with_connection_and_prepared(func): + arg_names, varargs, varkwargs, defaults = getargspec(func) + updated = ( + ('connection' in arg_names or varkwargs) and + ('prepared' in arg_names or varkwargs) + ) + if not updated: + warn("A Field class whose %s method hasn't been updated to take " + "`connection` and `prepared` arguments." % func.__name__, + PendingDeprecationWarning, stacklevel=2) + + def inner(*args, **kwargs): + if 'connection' not in kwargs: + from django.db import connection + kwargs['connection'] = connection + warn("%s has been called without providing a connection argument. " % + func.__name__, PendingDeprecationWarning, + stacklevel=1) + if updated: + return func(*args, **kwargs) + if 'connection' in kwargs: + del kwargs['connection'] + if 'prepared' in kwargs: + del kwargs['prepared'] + return func(*args, **kwargs) + return inner + +class LegacyConnection(type): + """ + A metaclass to normalize arguments give to the get_db_prep_* and db_type + methods on fields. + """ + def __new__(cls, names, bases, attrs): + new_cls = super(LegacyConnection, cls).__new__(cls, names, bases, attrs) + for attr in ('db_type', 'get_db_prep_save'): + setattr(new_cls, attr, call_with_connection(getattr(new_cls, attr))) + for attr in ('get_db_prep_lookup', 'get_db_prep_value'): + setattr(new_cls, attr, call_with_connection_and_prepared(getattr(new_cls, attr))) + return new_cls + +class SubfieldBase(LegacyConnection): """ A metaclass for custom Field subclasses. This ensures the model's attribute has the descriptor protocol attached to it. @@ -26,7 +92,7 @@ class Creator(object): def __get__(self, obj, type=None): if obj is None: raise AttributeError('Can only be accessed via an instance.') - return obj.__dict__[self.field.name] + return obj.__dict__[self.field.name] def __set__(self, obj, value): obj.__dict__[self.field.name] = self.field.to_python(value) |
