summaryrefslogtreecommitdiff
path: root/django/db/models/query.py
diff options
context:
space:
mode:
authorChristopher Long <indirecthit@gmail.com>2007-06-17 22:18:54 +0000
committerChristopher Long <indirecthit@gmail.com>2007-06-17 22:18:54 +0000
commitae22b6d403dcf25098c77f0dfcf59ae58b186461 (patch)
treec37fc631e99a7e4d909d6b6d236f495003731ea7 /django/db/models/query.py
parent0cf7bc439129c66df8d64601e885f83b256b4f25 (diff)
per-object-permissions: Merged to trunk [5486] NOTE: Not fully tested, will be working on this over the next few weeks.
git-svn-id: http://code.djangoproject.com/svn/django/branches/per-object-permissions@5488 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/query.py')
-rw-r--r--django/db/models/query.py229
1 files changed, 183 insertions, 46 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 53ed63ae5b..a6e702be18 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1,8 +1,9 @@
from django.db import backend, connection, transaction
from django.db.models.fields import DateField, FieldDoesNotExist
-from django.db.models import signals
+from django.db.models import signals, loading
from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
+from django.contrib.contenttypes import generic
import operator
import re
@@ -25,6 +26,9 @@ QUERY_TERMS = (
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
+class EmptyResultSet(Exception):
+ pass
+
####################
# HELPER FUNCTIONS #
####################
@@ -80,6 +84,7 @@ class QuerySet(object):
self._filters = Q()
self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
self._select_related = False # Whether to fill cache for related objects.
+ self._max_related_depth = 0 # Maximum "depth" for select_related
self._distinct = False # Whether the query should use SELECT DISTINCT.
self._select = {} # Dictionary of attname -> SQL.
self._where = [] # List of extra WHERE clauses to use.
@@ -104,6 +109,8 @@ class QuerySet(object):
def __getitem__(self, k):
"Retrieve an item or slice from the set of results."
+ if not isinstance(k, (slice, int)):
+ raise TypeError
assert (not isinstance(k, slice) and (k >= 0)) \
or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \
"Negative indexing is not supported."
@@ -163,12 +170,16 @@ class QuerySet(object):
def iterator(self):
"Performs the SELECT database lookup of this QuerySet."
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
# self._select is a dictionary, and dictionaries' key order is
# undefined, so we convert it to a list of tuples.
extra_select = self._select.items()
cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
fill_cache = self._select_related
index_end = len(self.model._meta.fields)
@@ -178,7 +189,8 @@ class QuerySet(object):
raise StopIteration
for row in rows:
if fill_cache:
- obj, index_end = get_cached_row(self.model, row, 0)
+ obj, index_end = get_cached_row(klass=self.model, row=row,
+ index_start=0, max_depth=self._max_related_depth)
else:
obj = self.model(*row[:index_end])
for i, k in enumerate(extra_select):
@@ -186,13 +198,31 @@ class QuerySet(object):
yield obj
def count(self):
- "Performs a SELECT COUNT() and returns the number of records as an integer."
+ """
+ Performs a SELECT COUNT() and returns the number of records as an
+ integer.
+
+ If the queryset is already cached (i.e. self._result_cache is set) this
+ simply returns the length of the cached results set to avoid multiple
+ SELECT COUNT(*) calls.
+ """
+ if self._result_cache is not None:
+ return len(self._result_cache)
+
counter = self._clone()
counter._order_by = ()
+ counter._select_related = False
+
+ offset = counter._offset
+ limit = counter._limit
counter._offset = None
counter._limit = None
- counter._select_related = False
- select, sql, params = counter._get_sql_clause()
+
+ try:
+ select, sql, params = counter._get_sql_clause()
+ except EmptyResultSet:
+ return 0
+
cursor = connection.cursor()
if self._distinct:
id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
@@ -200,7 +230,16 @@ class QuerySet(object):
cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
else:
cursor.execute("SELECT COUNT(*)" + sql, params)
- return cursor.fetchone()[0]
+ count = cursor.fetchone()[0]
+
+ # Apply any offset and limit constraints manually, since using LIMIT or
+ # OFFSET in SQL doesn't change the output of COUNT.
+ if offset:
+ count = max(0, count - offset)
+ if limit:
+ count = min(limit, count)
+
+ return count
def get(self, *args, **kwargs):
"Performs the SELECT and returns a single object matching the given keyword arguments."
@@ -359,9 +398,9 @@ class QuerySet(object):
else:
return self._filter_or_exclude(None, **filter_obj)
- def select_related(self, true_or_false=True):
+ def select_related(self, true_or_false=True, depth=0):
"Returns a new QuerySet instance with '_select_related' modified."
- return self._clone(_select_related=true_or_false)
+ return self._clone(_select_related=true_or_false, _max_related_depth=depth)
def order_by(self, *field_names):
"Returns a new QuerySet instance with the ordering changed."
@@ -395,6 +434,7 @@ class QuerySet(object):
c._filters = self._filters
c._order_by = self._order_by
c._select_related = self._select_related
+ c._max_related_depth = self._max_related_depth
c._distinct = self._distinct
c._select = self._select.copy()
c._where = self._where[:]
@@ -448,7 +488,10 @@ class QuerySet(object):
# Add additional tables and WHERE clauses based on select_related.
if self._select_related:
- fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
+ fill_table_cache(opts, select, tables, where,
+ old_prefix=opts.db_table,
+ cache_tables_seen=[opts.db_table],
+ max_depth=self._max_related_depth)
# Add any additional SELECTs.
if self._select:
@@ -509,22 +552,42 @@ class QuerySet(object):
return select, " ".join(sql), params
class ValuesQuerySet(QuerySet):
- def iterator(self):
- # select_related and select aren't supported in values().
+ def __init__(self, *args, **kwargs):
+ super(ValuesQuerySet, self).__init__(*args, **kwargs)
+ # select_related isn't supported in values().
self._select_related = False
- self._select = {}
+
+ def iterator(self):
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
# self._fields is a list of field names to fetch.
if self._fields:
- columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ #columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ if not self._select:
+ columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ else:
+ columns = []
+ for f in self._fields:
+ if f in [field.name for field in self.model._meta.fields]:
+ columns.append( self.model._meta.get_field(f, many_to_many=False).column )
+ elif not self._select.has_key( f ):
+ raise FieldDoesNotExist, '%s has no field named %r' % ( self.model._meta.object_name, f )
+
field_names = self._fields
else: # Default to all fields.
columns = [f.column for f in self.model._meta.fields]
field_names = [f.attname for f in self.model._meta.fields]
- cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns]
+
+ # Add any additional SELECTs.
+ if self._select:
+ select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()])
+
+ cursor = connection.cursor()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
@@ -545,7 +608,12 @@ class DateQuerySet(QuerySet):
if self._field.null:
self._where.append('%s.%s IS NOT NULL' % \
(backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column)))
- select, sql, params = self._get_sql_clause()
+
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
(backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table),
backend.quote_name(self._field.column))), sql, self._order)
@@ -563,6 +631,25 @@ class DateQuerySet(QuerySet):
c._order = self._order
return c
+class EmptyQuerySet(QuerySet):
+ def __init__(self, model=None):
+ super(EmptyQuerySet, self).__init__(model)
+ self._result_cache = []
+
+ def count(self):
+ return 0
+
+ def delete(self):
+ pass
+
+ def _clone(self, klass=None, **kwargs):
+ c = super(EmptyQuerySet, self)._clone(klass, **kwargs)
+ c._result_cache = []
+ return c
+
+ def _get_sql_clause(self):
+ raise EmptyResultSet
+
class QOperator(object):
"Base class for QAnd and QOr"
def __init__(self, *args):
@@ -571,10 +658,14 @@ class QOperator(object):
def get_sql(self, opts):
joins, where, params = SortedDict(), [], []
for val in self.args:
- joins2, where2, params2 = val.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
+ try:
+ joins2, where2, params2 = val.get_sql(opts)
+ joins.update(joins2)
+ where.extend(where2)
+ params.extend(params2)
+ except EmptyResultSet:
+ if not isinstance(self, QOr):
+ raise EmptyResultSet
if where:
return joins, ['(%s)' % self.operator.join(where)], params
return joins, [], params
@@ -628,8 +719,11 @@ class QNot(Q):
self.q = q
def get_sql(self, opts):
- joins, where, params = self.q.get_sql(opts)
- where2 = ['(NOT (%s))' % " AND ".join(where)]
+ try:
+ joins, where, params = self.q.get_sql(opts)
+ where2 = ['(NOT (%s))' % " AND ".join(where)]
+ except EmptyResultSet:
+ return SortedDict(), [], []
return joins, where2, params
def get_where_clause(lookup_type, table_prefix, field_name, value):
@@ -641,10 +735,14 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
except KeyError:
pass
if lookup_type == 'in':
- return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value]))
- elif lookup_type == 'range':
+ in_string = ','.join(['%s' for id in value])
+ if in_string:
+ return '%s%s IN (%s)' % (table_prefix, field_name, in_string)
+ else:
+ raise EmptyResultSet
+ elif lookup_type in ('range', 'year'):
return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name)
- elif lookup_type in ('year', 'month', 'day'):
+ elif lookup_type in ('month', 'day'):
return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name)
elif lookup_type == 'isnull':
return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or ''))
@@ -652,21 +750,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
return backend.get_fulltext_search_sql(table_prefix + field_name)
raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
-def get_cached_row(klass, row, index_start):
- "Helper function that recursively returns an object with cache filled"
+def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0):
+ """Helper function that recursively returns an object with cache filled"""
+
+ # If we've got a max_depth set and we've exceeded that depth, bail now.
+ if max_depth and cur_depth > max_depth:
+ return None
+
index_end = index_start + len(klass._meta.fields)
obj = klass(*row[index_start:index_end])
for f in klass._meta.fields:
if f.rel and not f.null:
- rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
- setattr(obj, f.get_cache_name(), rel_obj)
+ cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1)
+ if cached_row:
+ rel_obj, index_end = cached_row
+ setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
-def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
+def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0):
"""
Helper function that recursively populates the select, tables and where (in
place) for select_related queries.
"""
+
+ # If we've got a max_depth set and we've exceeded that depth, bail now.
+ if max_depth and cur_depth > max_depth:
+ return None
+
qn = backend.quote_name
for f in opts.fields:
if f.rel and not f.null:
@@ -681,12 +791,12 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen)
where.append('%s.%s = %s.%s' % \
(qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
- fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
+ fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1)
def parse_lookup(kwarg_items, opts):
# Helper function that handles converting API kwargs
# (e.g. "name__exact": "tom") to SQL.
- # Returns a tuple of (tables, joins, where, params).
+ # Returns a tuple of (joins, where, params).
# 'joins' is a sorted dictionary describing the tables that must be joined
# to complete the query. The dictionary is sorted because creation order
@@ -725,12 +835,14 @@ def parse_lookup(kwarg_items, opts):
if len(path) < 1:
raise TypeError, "Cannot parse keyword query %r" % kwarg
-
+
if value is None:
# Interpret '__exact=None' as the sql '= NULL'; otherwise, reject
# all uses of None as a query value.
if lookup_type != 'exact':
raise ValueError, "Cannot use None as a query value"
+ elif callable(value):
+ value = value()
joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None)
joins.update(joins2)
@@ -755,6 +867,13 @@ def find_field(name, field_list, related_query):
return None
return matches[0]
+def field_choices(field_list, related_query):
+ if related_query:
+ choices = [f.field.related_query_name() for f in field_list]
+ else:
+ choices = [f.name for f in field_list]
+ return choices
+
def lookup_inner(path, lookup_type, value, opts, table, column):
qn = backend.quote_name
joins, where, params = SortedDict(), [], []
@@ -827,13 +946,23 @@ def lookup_inner(path, lookup_type, value, opts, table, column):
new_opts = field.rel.to._meta
new_column = new_opts.pk.column
join_column = field.column
-
- raise FieldFound
+ raise FieldFound
+ elif path:
+ # For regular fields, if there are still items on the path,
+ # an error has been made. We munge "name" so that the error
+ # properly identifies the cause of the problem.
+ name += LOOKUP_SEPARATOR + path[0]
+ else:
+ raise FieldFound
except FieldFound: # Match found, loop has been shortcut.
pass
else: # No match found.
- raise TypeError, "Cannot resolve keyword '%s' into field" % name
+ choices = field_choices(current_opts.many_to_many, False) + \
+ field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \
+ field_choices(current_opts.get_all_related_objects(), True) + \
+ field_choices(current_opts.fields, False)
+ raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices))
# Check whether an intermediate join is required between current_table
# and new_table.
@@ -926,18 +1055,26 @@ def delete_objects(seen_objs):
pk_list = [pk for pk,instance in seen_objs[cls]]
for related in cls._meta.get_all_related_many_to_many_objects():
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(related.field.m2m_db_table()),
- qn(related.field.m2m_reverse_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ if not isinstance(related.field, generic.GenericRelation):
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ (qn(related.field.m2m_db_table()),
+ qn(related.field.m2m_reverse_name()),
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
for f in cls._meta.many_to_many:
+ if isinstance(f, generic.GenericRelation):
+ from django.contrib.contenttypes.models import ContentType
+ query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column
+ args_extra = [ContentType.objects.get_for_model(cls).id]
+ else:
+ query_extra = ''
+ args_extra = []
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \
(qn(f.m2m_db_table()), qn(f.m2m_column_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra,
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra)
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):