diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2010-01-27 13:30:29 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2010-01-27 13:30:29 +0000 |
| commit | 58cd220f51d5e294cb9e67c12a6e9d08523e282f (patch) | |
| tree | c87c968bb69215449924efe57e8b13d70acd4fa3 /django | |
| parent | 8e8d4b5888b73e5c0b2cfc77be4c6d5898546654 (diff) | |
Fixed #7270 -- Added the ability to follow reverse OneToOneFields in select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django')
| -rw-r--r-- | django/db/models/fields/related.py | 4 | ||||
| -rw-r--r-- | django/db/models/query.py | 74 | ||||
| -rw-r--r-- | django/db/models/query_utils.py | 18 | ||||
| -rw-r--r-- | django/db/models/related.py | 3 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 68 |
5 files changed, 159 insertions, 8 deletions
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4020d5e268..5de6fb1067 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object): # SingleRelatedObjectDescriptor instance. def __init__(self, related): self.related = related - self.cache_name = '_%s_cache' % related.get_accessor_name() + self.cache_name = related.get_cache_name() def __get__(self, instance, instance_type=None): if instance is None: @@ -319,7 +319,7 @@ class ReverseSingleRelatedObjectDescriptor(object): # cache. This cache also might not exist if the related object # hasn't been accessed yet. if related: - cache_name = '_%s_cache' % self.field.related.get_accessor_name() + cache_name = self.field.related.get_cache_name() try: delattr(related, cache_name) except AttributeError: diff --git a/django/db/models/query.py b/django/db/models/query.py index 3b290a6457..8cb3dbecfc 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, """ Helper function that recursively returns an object with the specified related attributes already populated. + + This method may be called recursively to populate deep select_related() + clauses. + + Arguments: + * klass - the class to retrieve (and instantiate) + * row - the row of data returned by the database cursor + * index_start - the index of the row at which data for this + object is known to start + * max_depth - the maximum depth to which a select_related() + relationship should be explored. + * cur_depth - the current depth in the select_related() tree. + Used in recursive calls to determin if we should dig deeper. + * requested - A dictionary describing the select_related() tree + that is to be retrieved. keys are field names; values are + dictionaries describing the keys on that related object that + are themselves to be select_related(). + * offset - the number of additional fields that are known to + exist in `row` for `klass`. This usually means the number of + annotated results on `klass`. + * only_load - if the query has had only() or defer() applied, + this is the list of field names that will be returned. If None, + the full field list for `klass` can be assumed. """ if max_depth and requested is None and cur_depth > max_depth: # We've recursed deeply enough; stop now. @@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, # Handle deferred fields. skip = set() init_list = [] - pk_val = row[index_start + klass._meta.pk_index()] + # Build the list of fields that *haven't* been requested for field in klass._meta.fields: if field.name not in load_fields: skip.add(field.name) else: init_list.append(field.attname) + # Retrieve all the requested fields field_count = len(init_list) fields = row[index_start : index_start + field_count] + # If all the select_related columns are None, then the related + # object must be non-existent - set the relation to None. + # Otherwise, construct the related object. if fields == (None,) * field_count: obj = None elif skip: @@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, else: obj = klass(*fields) else: + # Load all fields on klass field_count = len(klass._meta.fields) fields = row[index_start : index_start + field_count] + # If all the select_related columns are None, then the related + # object must be non-existent - set the relation to None. + # Otherwise, construct the related object. if fields == (None,) * field_count: obj = None else: obj = klass(*fields) index_end = index_start + field_count + offset + # Iterate over each related object, populating any + # select_related() fields for f in klass._meta.fields: if not select_related_descend(f, restricted, requested): continue @@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, next = requested[f.name] else: next = None + # Recursively retrieve the data for the related object cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1, next) + # If the recursive descent found an object, populate the + # descriptor caches relevant to the object if cached_row: rel_obj, index_end = cached_row if obj is not None: + # If the base object exists, populate the + # descriptor cache setattr(obj, f.get_cache_name(), rel_obj) + if f.unique: + # If the field is unique, populate the + # reverse descriptor cache on the related object + setattr(rel_obj, f.related.get_cache_name(), obj) + + # Now do the same, but for reverse related objects. + # Only handle the restricted case - i.e., don't do a depth + # descent into reverse relations unless explicitly requested + if restricted: + related_fields = [ + (o.field, o.model) + for o in klass._meta.get_all_related_objects() + if o.field.unique + ] + for f, model in related_fields: + if not select_related_descend(f, restricted, requested, reverse=True): + continue + next = requested[f.related_query_name()] + # Recursively retrieve the data for the related object + cached_row = get_cached_row(model, row, index_end, max_depth, + cur_depth+1, next) + # If the recursive descent found an object, populate the + # descriptor caches relevant to the object + if cached_row: + rel_obj, index_end = cached_row + if obj is not None: + # If the field is unique, populate the + # reverse descriptor cache + setattr(obj, f.related.get_cache_name(), rel_obj) + if rel_obj is not None: + # If the related object exists, populate + # the descriptor cache. + setattr(rel_obj, f.get_cache_name(), obj) + return obj, index_end def delete_objects(seen_objs, using): diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 9f6083ce7e..8e804ec3ef 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -197,19 +197,29 @@ class DeferredAttribute(object): """ instance.__dict__[self.field_name] = value -def select_related_descend(field, restricted, requested): +def select_related_descend(field, restricted, requested, reverse=False): """ Returns True if this field should be used to descend deeper for select_related() purposes. Used by both the query construction code (sql.query.fill_related_selections()) and the model instance creation code (query.get_cached_row()). + + Arguments: + * field - the field to be checked + * restricted - a boolean field, indicating if the field list has been + manually restricted using a requested clause) + * requested - The select_related() dictionary. + * reverse - boolean, True if we are checking a reverse select related """ if not field.rel: return False - if field.rel.parent_link: - return False - if restricted and field.name not in requested: + if field.rel.parent_link and not reverse: return False + if restricted: + if reverse and field.related_query_name() not in requested: + return False + if not reverse and field.name not in requested: + return False if not restricted and field.null: return False return True diff --git a/django/db/models/related.py b/django/db/models/related.py index afdf3f7b61..e4afd8a6f8 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -45,3 +45,6 @@ class RelatedObject(object): return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') else: return self.field.rel.related_name or (self.opts.object_name.lower()) + + def get_cache_name(self): + return "_%s_cache" % self.get_accessor_name() diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 6a95d32259..1625a0e6c9 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -520,7 +520,7 @@ class SQLCompiler(object): # Setup for the case when only particular related fields should be # included in the related selection. - if requested is None and restricted is not False: + if requested is None: if isinstance(self.query.select_related, dict): requested = self.query.select_related restricted = True @@ -600,6 +600,72 @@ class SQLCompiler(object): self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, used, next, restricted, new_nullable, dupe_set, avoid) + if restricted: + related_fields = [ + (o.field, o.model) + for o in opts.get_all_related_objects() + if o.field.unique + ] + for f, model in related_fields: + if not select_related_descend(f, restricted, requested, reverse=True): + continue + # The "avoid" set is aliases we want to avoid just for this + # particular branch of the recursion. They aren't permanently + # forbidden from reuse in the related selection tables (which is + # what "used" specifies). + avoid = avoid_set.copy() + dupe_set = orig_dupe_set.copy() + table = model._meta.db_table + + int_opts = opts + alias = root_alias + alias_chain = [] + chain = opts.get_base_chain(f.rel.to) + if chain is not None: + for int_model in chain: + # Proxy model have elements in base chain + # with no parents, assign the new options + # object and skip to the next base in that + # case + if not int_opts.parents[int_model]: + int_opts = int_model._meta + continue + lhs_col = int_opts.parents[int_model].column + dedupe = lhs_col in opts.duplicate_targets + if dedupe: + avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), + ()) + dupe_set.add((opts, lhs_col)) + int_opts = int_model._meta + alias = self.query.join( + (alias, int_opts.db_table, lhs_col, int_opts.pk.column), + exclusions=used, promote=True, reuse=used + ) + alias_chain.append(alias) + for dupe_opts, dupe_col in dupe_set: + self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) + dedupe = f.column in opts.duplicate_targets + if dupe_set or dedupe: + avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) + if dedupe: + dupe_set.add((opts, f.column)) + alias = self.query.join( + (alias, table, f.rel.get_related_field().column, f.column), + exclusions=used.union(avoid), + promote=True + ) + used.add(alias) + columns, aliases = self.get_default_columns(start_alias=alias, + opts=model._meta, as_pairs=True) + self.query.related_select_cols.extend(columns) + self.query.related_select_fields.extend(model._meta.fields) + + next = requested.get(f.related_query_name(), {}) + new_nullable = f.null or None + + self.fill_related_selections(model._meta, table, cur_depth+1, + used, next, restricted, new_nullable) + def deferred_to_columns(self): """ Converts the self.deferred_loading data structure to mapping of table |
