summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
authorRussell Keith-Magee <russell@keith-magee.com>2010-03-20 15:02:59 +0000
committerRussell Keith-Magee <russell@keith-magee.com>2010-03-20 15:02:59 +0000
commitbfa080f402ddfc383abdbac96fa198e7f2f8ec59 (patch)
tree4bfb3649f863c0bbe6c3f010bf450faecc16e1a4 /django/db
parent4528f3988689272d511b1395efc578c0b0d9e671 (diff)
Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db')
-rw-r--r--django/db/models/query.py33
-rw-r--r--django/db/models/sql/compiler.py6
2 files changed, 31 insertions, 8 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index f6b4419d27..8adf0d555c 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet):
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
- requested=None, offset=0, only_load=None):
+ requested=None, offset=0, only_load=None, local_only=False):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
@@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
+ * local_only - Only populate local fields. This is used when building
+ following reverse select-related relations
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
@@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
skip = set()
init_list = []
# Build the list of fields that *haven't* been requested
- for field in klass._meta.fields:
+ for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields:
skip.add(field.name)
+ elif local_only and model is not None:
+ continue
else:
init_list.append(field.attname)
# Retrieve all the requested fields
@@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
else:
# Load all fields on klass
- field_count = len(klass._meta.fields)
+ if local_only:
+ field_names = [f.attname for f in klass._meta.local_fields]
+ else:
+ field_names = [f.attname for f in klass._meta.fields]
+ field_count = len(field_names)
fields = row[index_start : index_start + field_count]
# If all the select_related columns are None, then the related
# object must be non-existent - set the relation to None.
@@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
if fields == (None,) * field_count:
obj = None
else:
- obj = klass(*fields)
+ obj = klass(**dict(zip(field_names, fields)))
# If an object was retrieved, set the database state.
if obj:
@@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
next = requested[f.related_query_name()]
# Recursively retrieve the data for the related object
cached_row = get_cached_row(model, row, index_end, using,
- max_depth, cur_depth+1, next)
+ max_depth, cur_depth+1, next, local_only=True)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
@@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
-
+ # Now populate all the non-local field values
+ # on the related object
+ for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
+ if rel_model is not None:
+ setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
+ # populate the field cache for any related object
+ # that has already been retrieved
+ if rel_field.rel:
+ try:
+ cached_obj = getattr(obj, rel_field.get_cache_name())
+ setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
+ except AttributeError:
+ # Related object hasn't been cached yet
+ pass
return obj, index_end
def delete_objects(seen_objs, using):
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index b7d63d381e..2fe03302a9 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -215,7 +215,7 @@ class SQLCompiler(object):
return result
def get_default_columns(self, with_aliases=False, col_aliases=None,
- start_alias=None, opts=None, as_pairs=False):
+ start_alias=None, opts=None, as_pairs=False, local_only=False):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
@@ -240,6 +240,8 @@ class SQLCompiler(object):
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
+ if local_only and model is not None:
+ continue
if start_alias:
try:
alias = seen[model]
@@ -643,7 +645,7 @@ class SQLCompiler(object):
)
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
- opts=model._meta, as_pairs=True)
+ opts=model._meta, as_pairs=True, local_only=True)
self.query.related_select_cols.extend(columns)
self.query.related_select_fields.extend(model._meta.fields)