diff options
| author | Loic Bistuer <loic.bistuer@sixmedia.com> | 2013-09-19 00:31:07 +0700 |
|---|---|---|
| committer | Anssi Kääriäinen <akaariai@gmail.com> | 2013-09-25 21:15:59 +0300 |
| commit | 04a2a6b0f9cb6bb98edfe84bf4361216d60a4e38 (patch) | |
| tree | e33452e2b614bbe2f707927ac03f2389f43e512c /django/db | |
| parent | 83554b018ef283827c0e7459ab934d447b3419d5 (diff) | |
Fixed #3871 -- Custom managers when traversing reverse relations.
Diffstat (limited to 'django/db')
| -rw-r--r-- | django/db/models/fields/related.py | 188 |
1 files changed, 108 insertions, 80 deletions
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index fd9e8fa4d8..23959f6b18 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -365,6 +365,92 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec setattr(value, self.field.related.get_cache_name(), instance) +def create_foreign_related_manager(superclass, rel_field, rel_model): + class RelatedManager(superclass): + def __init__(self, instance): + super(RelatedManager, self).__init__() + self.instance = instance + self.core_filters = {'%s__exact' % rel_field.name: instance} + self.model = rel_model + + def __call__(self, **kwargs): + # We use **kwargs rather than a kwarg argument to enforce the + # `manager='manager_name'` syntax. + manager = getattr(self.model, kwargs.pop('manager')) + manager_class = create_foreign_related_manager(manager.__class__, rel_field, rel_model) + return manager_class(self.instance) + + def get_queryset(self): + try: + return self.instance._prefetched_objects_cache[rel_field.related_query_name()] + except (AttributeError, KeyError): + db = self._db or router.db_for_read(self.model, instance=self.instance) + qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters) + empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls + for field in rel_field.foreign_related_fields: + val = getattr(self.instance, field.attname) + if val is None or (val == '' and empty_strings_as_null): + return qs.none() + qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} + return qs + + def get_prefetch_queryset(self, instances): + rel_obj_attr = rel_field.get_local_related_value + instance_attr = rel_field.get_foreign_related_value + instances_dict = dict((instance_attr(inst), inst) for inst in instances) + db = self._db or router.db_for_read(self.model, instance=instances[0]) + query = {'%s__in' % rel_field.name: instances} + qs = super(RelatedManager, self).get_queryset().using(db).filter(**query) + # Since we just bypassed this class' get_queryset(), we must manage + # the reverse relation manually. + for rel_obj in qs: + instance = instances_dict[rel_obj_attr(rel_obj)] + setattr(rel_obj, rel_field.name, instance) + cache_name = rel_field.related_query_name() + return qs, rel_obj_attr, instance_attr, False, cache_name + + def add(self, *objs): + for obj in objs: + if not isinstance(obj, self.model): + raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) + setattr(obj, rel_field.name, self.instance) + obj.save() + add.alters_data = True + + def create(self, **kwargs): + kwargs[rel_field.name] = self.instance + db = router.db_for_write(self.model, instance=self.instance) + return super(RelatedManager, self.db_manager(db)).create(**kwargs) + create.alters_data = True + + def get_or_create(self, **kwargs): + # Update kwargs with the related object that this + # ForeignRelatedObjectsDescriptor knows about. + kwargs[rel_field.name] = self.instance + db = router.db_for_write(self.model, instance=self.instance) + return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) + get_or_create.alters_data = True + + # remove() and clear() are only provided if the ForeignKey can have a value of null. + if rel_field.null: + def remove(self, *objs): + val = rel_field.get_foreign_related_value(self.instance) + for obj in objs: + # Is obj actually part of this descriptor set? + if rel_field.get_local_related_value(obj) == val: + setattr(obj, rel_field.name, None) + obj.save() + else: + raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance)) + remove.alters_data = True + + def clear(self): + self.update(**{rel_field.name: None}) + clear.alters_data = True + + return RelatedManager + + class ForeignRelatedObjectsDescriptor(object): # This class provides the functionality that makes the related-object # managers available as attributes on a model class, for fields that have @@ -392,86 +478,11 @@ class ForeignRelatedObjectsDescriptor(object): def related_manager_cls(self): # Dynamically create a class that subclasses the related model's default # manager. - superclass = self.related.model._default_manager.__class__ - rel_field = self.related.field - rel_model = self.related.model - - class RelatedManager(superclass): - def __init__(self, instance): - super(RelatedManager, self).__init__() - self.instance = instance - self.core_filters = {'%s__exact' % rel_field.name: instance} - self.model = rel_model - - def get_queryset(self): - try: - return self.instance._prefetched_objects_cache[rel_field.related_query_name()] - except (AttributeError, KeyError): - db = self._db or router.db_for_read(self.model, instance=self.instance) - qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters) - empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls - for field in rel_field.foreign_related_fields: - val = getattr(self.instance, field.attname) - if val is None or (val == '' and empty_strings_as_null): - return qs.none() - qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} - return qs - - def get_prefetch_queryset(self, instances): - rel_obj_attr = rel_field.get_local_related_value - instance_attr = rel_field.get_foreign_related_value - instances_dict = dict((instance_attr(inst), inst) for inst in instances) - db = self._db or router.db_for_read(self.model, instance=instances[0]) - query = {'%s__in' % rel_field.name: instances} - qs = super(RelatedManager, self).get_queryset().using(db).filter(**query) - # Since we just bypassed this class' get_queryset(), we must manage - # the reverse relation manually. - for rel_obj in qs: - instance = instances_dict[rel_obj_attr(rel_obj)] - setattr(rel_obj, rel_field.name, instance) - cache_name = rel_field.related_query_name() - return qs, rel_obj_attr, instance_attr, False, cache_name - - def add(self, *objs): - for obj in objs: - if not isinstance(obj, self.model): - raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) - setattr(obj, rel_field.name, self.instance) - obj.save() - add.alters_data = True - - def create(self, **kwargs): - kwargs[rel_field.name] = self.instance - db = router.db_for_write(self.model, instance=self.instance) - return super(RelatedManager, self.db_manager(db)).create(**kwargs) - create.alters_data = True - - def get_or_create(self, **kwargs): - # Update kwargs with the related object that this - # ForeignRelatedObjectsDescriptor knows about. - kwargs[rel_field.name] = self.instance - db = router.db_for_write(self.model, instance=self.instance) - return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) - get_or_create.alters_data = True - - # remove() and clear() are only provided if the ForeignKey can have a value of null. - if rel_field.null: - def remove(self, *objs): - val = rel_field.get_foreign_related_value(self.instance) - for obj in objs: - # Is obj actually part of this descriptor set? - if rel_field.get_local_related_value(obj) == val: - setattr(obj, rel_field.name, None) - obj.save() - else: - raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance)) - remove.alters_data = True - - def clear(self): - self.update(**{rel_field.name: None}) - clear.alters_data = True - - return RelatedManager + return create_foreign_related_manager( + self.related.model._default_manager.__class__, + self.related.field, + self.related.model, + ) def create_many_related_manager(superclass, rel): @@ -513,6 +524,23 @@ def create_many_related_manager(superclass, rel): "a many-to-many relationship can be used." % instance.__class__.__name__) + def __call__(self, **kwargs): + # We use **kwargs rather than a kwarg argument to enforce the + # `manager='manager_name'` syntax. + manager = getattr(self.model, kwargs.pop('manager')) + manager_class = create_many_related_manager(manager.__class__, rel) + return manager_class( + model=self.model, + query_field_name=self.query_field_name, + instance=self.instance, + symmetrical=self.symmetrical, + source_field_name=self.source_field_name, + target_field_name=self.target_field_name, + reverse=self.reverse, + through=self.through, + prefetch_cache_name=self.prefetch_cache_name, + ) + def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] |
