diff options
Diffstat (limited to 'django/db/models/query.py')
| -rw-r--r-- | django/db/models/query.py | 229 |
1 files changed, 183 insertions, 46 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py index 53ed63ae5b..a6e702be18 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1,8 +1,9 @@ from django.db import backend, connection, transaction from django.db.models.fields import DateField, FieldDoesNotExist -from django.db.models import signals +from django.db.models import signals, loading from django.dispatch import dispatcher from django.utils.datastructures import SortedDict +from django.contrib.contenttypes import generic import operator import re @@ -25,6 +26,9 @@ QUERY_TERMS = ( # Larger values are slightly faster at the expense of more storage space. GET_ITERATOR_CHUNK_SIZE = 100 +class EmptyResultSet(Exception): + pass + #################### # HELPER FUNCTIONS # #################### @@ -80,6 +84,7 @@ class QuerySet(object): self._filters = Q() self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering. self._select_related = False # Whether to fill cache for related objects. + self._max_related_depth = 0 # Maximum "depth" for select_related self._distinct = False # Whether the query should use SELECT DISTINCT. self._select = {} # Dictionary of attname -> SQL. self._where = [] # List of extra WHERE clauses to use. @@ -104,6 +109,8 @@ class QuerySet(object): def __getitem__(self, k): "Retrieve an item or slice from the set of results." + if not isinstance(k, (slice, int)): + raise TypeError assert (not isinstance(k, slice) and (k >= 0)) \ or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \ "Negative indexing is not supported." @@ -163,12 +170,16 @@ class QuerySet(object): def iterator(self): "Performs the SELECT database lookup of this QuerySet." + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + # self._select is a dictionary, and dictionaries' key order is # undefined, so we convert it to a list of tuples. extra_select = self._select.items() cursor = connection.cursor() - select, sql, params = self._get_sql_clause() cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) fill_cache = self._select_related index_end = len(self.model._meta.fields) @@ -178,7 +189,8 @@ class QuerySet(object): raise StopIteration for row in rows: if fill_cache: - obj, index_end = get_cached_row(self.model, row, 0) + obj, index_end = get_cached_row(klass=self.model, row=row, + index_start=0, max_depth=self._max_related_depth) else: obj = self.model(*row[:index_end]) for i, k in enumerate(extra_select): @@ -186,13 +198,31 @@ class QuerySet(object): yield obj def count(self): - "Performs a SELECT COUNT() and returns the number of records as an integer." + """ + Performs a SELECT COUNT() and returns the number of records as an + integer. + + If the queryset is already cached (i.e. self._result_cache is set) this + simply returns the length of the cached results set to avoid multiple + SELECT COUNT(*) calls. + """ + if self._result_cache is not None: + return len(self._result_cache) + counter = self._clone() counter._order_by = () + counter._select_related = False + + offset = counter._offset + limit = counter._limit counter._offset = None counter._limit = None - counter._select_related = False - select, sql, params = counter._get_sql_clause() + + try: + select, sql, params = counter._get_sql_clause() + except EmptyResultSet: + return 0 + cursor = connection.cursor() if self._distinct: id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table), @@ -200,7 +230,16 @@ class QuerySet(object): cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params) else: cursor.execute("SELECT COUNT(*)" + sql, params) - return cursor.fetchone()[0] + count = cursor.fetchone()[0] + + # Apply any offset and limit constraints manually, since using LIMIT or + # OFFSET in SQL doesn't change the output of COUNT. + if offset: + count = max(0, count - offset) + if limit: + count = min(limit, count) + + return count def get(self, *args, **kwargs): "Performs the SELECT and returns a single object matching the given keyword arguments." @@ -359,9 +398,9 @@ class QuerySet(object): else: return self._filter_or_exclude(None, **filter_obj) - def select_related(self, true_or_false=True): + def select_related(self, true_or_false=True, depth=0): "Returns a new QuerySet instance with '_select_related' modified." - return self._clone(_select_related=true_or_false) + return self._clone(_select_related=true_or_false, _max_related_depth=depth) def order_by(self, *field_names): "Returns a new QuerySet instance with the ordering changed." @@ -395,6 +434,7 @@ class QuerySet(object): c._filters = self._filters c._order_by = self._order_by c._select_related = self._select_related + c._max_related_depth = self._max_related_depth c._distinct = self._distinct c._select = self._select.copy() c._where = self._where[:] @@ -448,7 +488,10 @@ class QuerySet(object): # Add additional tables and WHERE clauses based on select_related. if self._select_related: - fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) + fill_table_cache(opts, select, tables, where, + old_prefix=opts.db_table, + cache_tables_seen=[opts.db_table], + max_depth=self._max_related_depth) # Add any additional SELECTs. if self._select: @@ -509,22 +552,42 @@ class QuerySet(object): return select, " ".join(sql), params class ValuesQuerySet(QuerySet): - def iterator(self): - # select_related and select aren't supported in values(). + def __init__(self, *args, **kwargs): + super(ValuesQuerySet, self).__init__(*args, **kwargs) + # select_related isn't supported in values(). self._select_related = False - self._select = {} + + def iterator(self): + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration # self._fields is a list of field names to fetch. if self._fields: - columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + #columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + if not self._select: + columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + else: + columns = [] + for f in self._fields: + if f in [field.name for field in self.model._meta.fields]: + columns.append( self.model._meta.get_field(f, many_to_many=False).column ) + elif not self._select.has_key( f ): + raise FieldDoesNotExist, '%s has no field named %r' % ( self.model._meta.object_name, f ) + field_names = self._fields else: # Default to all fields. columns = [f.column for f in self.model._meta.fields] field_names = [f.attname for f in self.model._meta.fields] - cursor = connection.cursor() - select, sql, params = self._get_sql_clause() select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] + + # Add any additional SELECTs. + if self._select: + select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()]) + + cursor = connection.cursor() cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) while 1: rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) @@ -545,7 +608,12 @@ class DateQuerySet(QuerySet): if self._field.null: self._where.append('%s.%s IS NOT NULL' % \ (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))) - select, sql, params = self._get_sql_clause() + + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ (backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))), sql, self._order) @@ -563,6 +631,25 @@ class DateQuerySet(QuerySet): c._order = self._order return c +class EmptyQuerySet(QuerySet): + def __init__(self, model=None): + super(EmptyQuerySet, self).__init__(model) + self._result_cache = [] + + def count(self): + return 0 + + def delete(self): + pass + + def _clone(self, klass=None, **kwargs): + c = super(EmptyQuerySet, self)._clone(klass, **kwargs) + c._result_cache = [] + return c + + def _get_sql_clause(self): + raise EmptyResultSet + class QOperator(object): "Base class for QAnd and QOr" def __init__(self, *args): @@ -571,10 +658,14 @@ class QOperator(object): def get_sql(self, opts): joins, where, params = SortedDict(), [], [] for val in self.args: - joins2, where2, params2 = val.get_sql(opts) - joins.update(joins2) - where.extend(where2) - params.extend(params2) + try: + joins2, where2, params2 = val.get_sql(opts) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + except EmptyResultSet: + if not isinstance(self, QOr): + raise EmptyResultSet if where: return joins, ['(%s)' % self.operator.join(where)], params return joins, [], params @@ -628,8 +719,11 @@ class QNot(Q): self.q = q def get_sql(self, opts): - joins, where, params = self.q.get_sql(opts) - where2 = ['(NOT (%s))' % " AND ".join(where)] + try: + joins, where, params = self.q.get_sql(opts) + where2 = ['(NOT (%s))' % " AND ".join(where)] + except EmptyResultSet: + return SortedDict(), [], [] return joins, where2, params def get_where_clause(lookup_type, table_prefix, field_name, value): @@ -641,10 +735,14 @@ def get_where_clause(lookup_type, table_prefix, field_name, value): except KeyError: pass if lookup_type == 'in': - return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value])) - elif lookup_type == 'range': + in_string = ','.join(['%s' for id in value]) + if in_string: + return '%s%s IN (%s)' % (table_prefix, field_name, in_string) + else: + raise EmptyResultSet + elif lookup_type in ('range', 'year'): return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name) - elif lookup_type in ('year', 'month', 'day'): + elif lookup_type in ('month', 'day'): return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name) elif lookup_type == 'isnull': return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or '')) @@ -652,21 +750,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value): return backend.get_fulltext_search_sql(table_prefix + field_name) raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) -def get_cached_row(klass, row, index_start): - "Helper function that recursively returns an object with cache filled" +def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): + """Helper function that recursively returns an object with cache filled""" + + # If we've got a max_depth set and we've exceeded that depth, bail now. + if max_depth and cur_depth > max_depth: + return None + index_end = index_start + len(klass._meta.fields) obj = klass(*row[index_start:index_end]) for f in klass._meta.fields: if f.rel and not f.null: - rel_obj, index_end = get_cached_row(f.rel.to, row, index_end) - setattr(obj, f.get_cache_name(), rel_obj) + cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1) + if cached_row: + rel_obj, index_end = cached_row + setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen): +def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0): """ Helper function that recursively populates the select, tables and where (in place) for select_related queries. """ + + # If we've got a max_depth set and we've exceeded that depth, bail now. + if max_depth and cur_depth > max_depth: + return None + qn = backend.quote_name for f in opts.fields: if f.rel and not f.null: @@ -681,12 +791,12 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen) where.append('%s.%s = %s.%s' % \ (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) - fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen) + fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1) def parse_lookup(kwarg_items, opts): # Helper function that handles converting API kwargs # (e.g. "name__exact": "tom") to SQL. - # Returns a tuple of (tables, joins, where, params). + # Returns a tuple of (joins, where, params). # 'joins' is a sorted dictionary describing the tables that must be joined # to complete the query. The dictionary is sorted because creation order @@ -725,12 +835,14 @@ def parse_lookup(kwarg_items, opts): if len(path) < 1: raise TypeError, "Cannot parse keyword query %r" % kwarg - + if value is None: # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject # all uses of None as a query value. if lookup_type != 'exact': raise ValueError, "Cannot use None as a query value" + elif callable(value): + value = value() joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None) joins.update(joins2) @@ -755,6 +867,13 @@ def find_field(name, field_list, related_query): return None return matches[0] +def field_choices(field_list, related_query): + if related_query: + choices = [f.field.related_query_name() for f in field_list] + else: + choices = [f.name for f in field_list] + return choices + def lookup_inner(path, lookup_type, value, opts, table, column): qn = backend.quote_name joins, where, params = SortedDict(), [], [] @@ -827,13 +946,23 @@ def lookup_inner(path, lookup_type, value, opts, table, column): new_opts = field.rel.to._meta new_column = new_opts.pk.column join_column = field.column - - raise FieldFound + raise FieldFound + elif path: + # For regular fields, if there are still items on the path, + # an error has been made. We munge "name" so that the error + # properly identifies the cause of the problem. + name += LOOKUP_SEPARATOR + path[0] + else: + raise FieldFound except FieldFound: # Match found, loop has been shortcut. pass else: # No match found. - raise TypeError, "Cannot resolve keyword '%s' into field" % name + choices = field_choices(current_opts.many_to_many, False) + \ + field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \ + field_choices(current_opts.get_all_related_objects(), True) + \ + field_choices(current_opts.fields, False) + raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices)) # Check whether an intermediate join is required between current_table # and new_table. @@ -926,18 +1055,26 @@ def delete_objects(seen_objs): pk_list = [pk for pk,instance in seen_objs[cls]] for related in cls._meta.get_all_related_many_to_many_objects(): - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (qn(related.field.m2m_db_table()), - qn(related.field.m2m_reverse_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + if not isinstance(related.field, generic.GenericRelation): + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + (qn(related.field.m2m_db_table()), + qn(related.field.m2m_reverse_name()), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) for f in cls._meta.many_to_many: + if isinstance(f, generic.GenericRelation): + from django.contrib.contenttypes.models import ContentType + query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column + args_extra = [ContentType.objects.get_for_model(cls).id] + else: + query_extra = '' + args_extra = [] for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \ (qn(f.m2m_db_table()), qn(f.m2m_column_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra, + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra) for field in cls._meta.fields: if field.rel and field.null and field.rel.to in seen_objs: for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): |
