diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2010-03-20 15:02:59 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2010-03-20 15:02:59 +0000 |
| commit | bfa080f402ddfc383abdbac96fa198e7f2f8ec59 (patch) | |
| tree | 4bfb3649f863c0bbe6c3f010bf450faecc16e1a4 /django/db | |
| parent | 4528f3988689272d511b1395efc578c0b0d9e671 (diff) | |
Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db')
| -rw-r--r-- | django/db/models/query.py | 33 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 6 |
2 files changed, 31 insertions, 8 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py index f6b4419d27..8adf0d555c 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet): def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, - requested=None, offset=0, only_load=None): + requested=None, offset=0, only_load=None, local_only=False): """ Helper function that recursively returns an object with the specified related attributes already populated. @@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, * only_load - if the query has had only() or defer() applied, this is the list of field names that will be returned. If None, the full field list for `klass` can be assumed. + * local_only - Only populate local fields. This is used when building + following reverse select-related relations """ if max_depth and requested is None and cur_depth > max_depth: # We've recursed deeply enough; stop now. @@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, skip = set() init_list = [] # Build the list of fields that *haven't* been requested - for field in klass._meta.fields: + for field, model in klass._meta.get_fields_with_model(): if field.name not in load_fields: skip.add(field.name) + elif local_only and model is not None: + continue else: init_list.append(field.attname) # Retrieve all the requested fields @@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, else: # Load all fields on klass - field_count = len(klass._meta.fields) + if local_only: + field_names = [f.attname for f in klass._meta.local_fields] + else: + field_names = [f.attname for f in klass._meta.fields] + field_count = len(field_names) fields = row[index_start : index_start + field_count] # If all the select_related columns are None, then the related # object must be non-existent - set the relation to None. @@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, if fields == (None,) * field_count: obj = None else: - obj = klass(*fields) + obj = klass(**dict(zip(field_names, fields))) # If an object was retrieved, set the database state. if obj: @@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, next = requested[f.related_query_name()] # Recursively retrieve the data for the related object cached_row = get_cached_row(model, row, index_end, using, - max_depth, cur_depth+1, next) + max_depth, cur_depth+1, next, local_only=True) # If the recursive descent found an object, populate the # descriptor caches relevant to the object if cached_row: @@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, # If the related object exists, populate # the descriptor cache. setattr(rel_obj, f.get_cache_name(), obj) - + # Now populate all the non-local field values + # on the related object + for rel_field,rel_model in rel_obj._meta.get_fields_with_model(): + if rel_model is not None: + setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) + # populate the field cache for any related object + # that has already been retrieved + if rel_field.rel: + try: + cached_obj = getattr(obj, rel_field.get_cache_name()) + setattr(rel_obj, rel_field.get_cache_name(), cached_obj) + except AttributeError: + # Related object hasn't been cached yet + pass return obj, index_end def delete_objects(seen_objs, using): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b7d63d381e..2fe03302a9 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -215,7 +215,7 @@ class SQLCompiler(object): return result def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False): + start_alias=None, opts=None, as_pairs=False, local_only=False): """ Computes the default columns for selecting every field in the base model. Will sometimes be called to pull in related models (e.g. via @@ -240,6 +240,8 @@ class SQLCompiler(object): if start_alias: seen = {None: start_alias} for field, model in opts.get_fields_with_model(): + if local_only and model is not None: + continue if start_alias: try: alias = seen[model] @@ -643,7 +645,7 @@ class SQLCompiler(object): ) used.add(alias) columns, aliases = self.get_default_columns(start_alias=alias, - opts=model._meta, as_pairs=True) + opts=model._meta, as_pairs=True, local_only=True) self.query.related_select_cols.extend(columns) self.query.related_select_fields.extend(model._meta.fields) |
