diff options
Diffstat (limited to 'django/db/models/sql/query.py')
| -rw-r--r-- | django/db/models/sql/query.py | 179 |
1 files changed, 111 insertions, 68 deletions
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c3c8e55793..db4e6744bf 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -19,6 +19,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.aggregates import refs_aggregate from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist +from django.db.models.lookups import Transform from django.db.models.query_utils import Q from django.db.models.related import PathInfo from django.db.models.sql import aggregates as base_aggregates_module @@ -1028,13 +1029,16 @@ class Query(object): # Add the aggregate to the query aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) - def prepare_lookup_value(self, value, lookup_type, can_reuse): + def prepare_lookup_value(self, value, lookups, can_reuse): + # Default lookup if none given is exact. + if len(lookups) == 0: + lookups = ['exact'] # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. if value is None: - if lookup_type not in ('exact', 'iexact'): + if lookups[-1] not in ('exact', 'iexact'): raise ValueError("Cannot use None as a query value") - lookup_type = 'isnull' + lookups[-1] = 'isnull' value = True elif callable(value): warnings.warn( @@ -1055,40 +1059,54 @@ class Query(object): # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we # can do here. Similar thing is done in is_nullable(), too. if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and - lookup_type == 'exact' and value == ''): + lookups[-1] == 'exact' and value == ''): value = True - lookup_type = 'isnull' - return value, lookup_type + lookups[-1] = ['isnull'] + return value, lookups def solve_lookup_type(self, lookup): """ Solve the lookup type from the lookup (eg: 'foobar__id__icontains') """ - lookup_type = 'exact' # Default lookup type - lookup_parts = lookup.split(LOOKUP_SEP) - num_parts = len(lookup_parts) - if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms - and (not self._aggregates or lookup not in self._aggregates)): - # Traverse the lookup query to distinguish related fields from - # lookup types. - lookup_model = self.model - for counter, field_name in enumerate(lookup_parts): - try: - lookup_field = lookup_model._meta.get_field(field_name) - except FieldDoesNotExist: - # Not a field. Bail out. - lookup_type = lookup_parts.pop() - break - # Unless we're at the end of the list of lookups, let's attempt - # to continue traversing relations. - if (counter + 1) < num_parts: - try: - lookup_model = lookup_field.rel.to - except AttributeError: - # Not a related field. Bail out. - lookup_type = lookup_parts.pop() - break - return lookup_type, lookup_parts + lookup_splitted = lookup.split(LOOKUP_SEP) + if self._aggregates: + aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) + if aggregate: + return aggregate_lookups, (), aggregate + _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) + field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] + if len(lookup_parts) == 0: + lookup_parts = ['exact'] + elif len(lookup_parts) > 1: + if not field_parts: + raise FieldError( + 'Invalid lookup "%s" for model %s".' % + (lookup, self.get_meta().model.__name__)) + return lookup_parts, field_parts, False + + def build_lookup(self, lookups, lhs, rhs): + lookups = lookups[:] + while lookups: + lookup = lookups[0] + next = lhs.get_lookup(lookup) + if next: + if len(lookups) == 1: + # This was the last lookup, so return value lookup. + if issubclass(next, Transform): + lookups.append('exact') + lhs = next(lhs, lookups) + else: + return next(lhs, rhs) + else: + lhs = next(lhs, lookups) + # A field's get_lookup() can return None to opt for backwards + # compatibility path. + elif len(lookups) > 2: + raise FieldError( + "Unsupported lookup for field '%s'" % lhs.output_type.name) + else: + return None + lookups = lookups[1:] def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, connector=AND): @@ -1118,21 +1136,24 @@ class Query(object): is responsible for unreffing the joins used. """ arg, value = filter_expr - lookup_type, parts = self.solve_lookup_type(arg) - if not parts: + if not arg: raise FieldError("Cannot parse keyword query %r" % arg) + lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) + value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) used_joins = getattr(value, '_used_joins', []) clause = self.where_class() - if self._aggregates: - for alias, aggregate in self.aggregates.items(): - if alias in (parts[0], LOOKUP_SEP.join(parts)): - clause.add((aggregate, lookup_type, value), AND) - return clause, [] + if reffed_aggregate: + condition = self.build_lookup(lookups, reffed_aggregate, value) + if not condition: + # Backwards compat for custom lookups + assert len(lookups) == 1 + condition = (reffed_aggregate, lookups[0], value) + clause.add(condition, AND) + return clause, [] opts = self.get_meta() alias = self.get_initial_alias() @@ -1154,11 +1175,31 @@ class Query(object): targets, alias, join_list = self.trim_joins(sources, join_list, path) if hasattr(field, 'get_lookup_constraint'): - constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources, - lookup_type, value) + # For now foreign keys get special treatment. This should be + # refactored when composite fields lands. + condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, + lookups, value) + lookup_type = lookups[-1] else: - constraint = (Constraint(alias, targets[0].column, field), lookup_type, value) - clause.add(constraint, AND) + assert(len(targets) == 1) + col = Col(alias, targets[0], field) + condition = self.build_lookup(lookups, col, value) + if not condition: + # Backwards compat for custom lookups + if lookups[0] not in self.query_terms: + raise FieldError( + "Join on field '%s' not permitted. Did you " + "misspell '%s' for the lookup type?" % + (col.output_type.name, lookups[0])) + if len(lookups) > 1: + raise FieldError("Nested lookup '%s' not supported." % + LOOKUP_SEP.join(lookups)) + condition = (Constraint(alias, targets[0].column, field), lookups[0], value) + lookup_type = lookups[-1] + else: + lookup_type = condition.lookup_name + + clause.add(condition, AND) require_outer = lookup_type == 'isnull' and value is True and not current_negated if current_negated and (lookup_type != 'isnull' or value is False): @@ -1175,7 +1216,8 @@ class Query(object): # (col IS NULL OR col != someval) # <=> # NOT (col IS NOT NULL AND col = someval). - clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND) + lookup_class = targets[0].get_lookup('isnull') + clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND) return clause, used_joins if not require_outer else () def add_filter(self, filter_clause): @@ -1189,7 +1231,7 @@ class Query(object): if not self._aggregates: return False if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates) + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] or (hasattr(obj[1], 'contains_aggregate') and obj[1].contains_aggregate(self.aggregates))) return any(self.need_having(c) for c in obj.children) @@ -1277,7 +1319,7 @@ class Query(object): needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner - def names_to_path(self, names, opts, allow_many): + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walks the names path and turns them PathInfo tuples. Note that a single name in 'names' can generate multiple PathInfos (m2m for @@ -1297,9 +1339,10 @@ class Query(object): try: field, model, direct, m2m = opts.get_field_by_name(name) except FieldDoesNotExist: - available = opts.get_all_field_names() + list(self.aggregate_select) - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (name, ", ".join(available))) + # We didn't found the current field, so move position back + # one step. + pos -= 1 + break # Check if we need any joins for concrete inheritance cases (the # field lives in parent, but we are currently in one of its # children) @@ -1334,15 +1377,14 @@ class Query(object): final_field = field targets = (field,) break + if pos == -1 or (fail_on_missing and pos + 1 != len(names)): + self.raise_field_error(opts, name) + return path, final_field, targets, names[pos + 1:] - if pos != len(names) - 1: - if pos == len(names) - 2: - raise FieldError( - "Join on field %r not permitted. Did you misspell %r for " - "the lookup type?" % (name, names[pos + 1])) - else: - raise FieldError("Join on field %r not permitted." % name) - return path, final_field, targets + def raise_field_error(self, opts, name): + available = opts.get_all_field_names() + list(self.aggregate_select) + raise FieldError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(available))) def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ @@ -1371,8 +1413,9 @@ class Query(object): """ joins = [alias] # First, generate the path for the names - path, final_field, targets = self.names_to_path( - names, opts, allow_many) + path, final_field, targets, rest = self.names_to_path( + names, opts, allow_many, fail_on_missing=True) + # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. @@ -1387,8 +1430,6 @@ class Query(object): alias = self.join( connection, reuse=reuse, nullable=nullable, join_field=join.join_field) joins.append(alias) - if hasattr(final_field, 'field'): - final_field = final_field.field return final_field, targets, opts, joins, path def trim_joins(self, targets, joins, path): @@ -1451,17 +1492,19 @@ class Query(object): # nothing alias, col = query.select[0].col if self.is_nullable(query.select[0].field): - query.where.add((Constraint(alias, col, query.select[0].field), 'isnull', False), AND) + lookup_class = query.select[0].field.get_lookup('isnull') + lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False) + query.where.add(lookup, AND) if alias in can_reuse: - pk = query.select[0].field.model._meta.pk + select_field = query.select[0].field + pk = select_field.model._meta.pk # Need to add a restriction so that outer query's filters are in effect for # the subquery, too. query.bump_prefix(self) - query.where.add( - (Constraint(query.select[0].col[0], pk.column, pk), - 'exact', Col(alias, pk.column)), - AND - ) + lookup_class = select_field.get_lookup('exact') + lookup = lookup_class(Col(query.select[0].col[0], pk, pk), + Col(alias, pk, pk)) + query.where.add(lookup, AND) condition, needed_inner = self.build_filter( ('%s__in' % trimmed_prefix, query), |
