author     Malcolm Tredinnick <malcolm.tredinnick@gmail.com>   2008-04-27 02:50:16 +0000
committer  Malcolm Tredinnick <malcolm.tredinnick@gmail.com>   2008-04-27 02:50:16 +0000
commit     9c52d56f6f8a9cdafb231adf9f4110473099c9b5 (patch)
tree       eeded174bec983e4415f5f52f187b3d5d9a1882d /django
parent     c91a30f00fd182faf8ca5c03cd7dbcf8b735b458 (diff)
Merged the queryset-refactor branch into trunk.
This is a big internal change, but mostly backwards compatible with existing code. Also adds a couple of new features.

Fixed #245, #1050, #1656, #1801, #2076, #2091, #2150, #2253, #2306, #2400, #2430, #2482, #2496, #2676, #2737, #2874, #2902, #2939, #3037, #3141, #3288, #3440, #3592, #3739, #4088, #4260, #4289, #4306, #4358, #4464, #4510, #4858, #5012, #5020, #5261, #5295, #5321, #5324, #5325, #5555, #5707, #5796, #5817, #5987, #6018, #6074, #6088, #6154, #6177, #6180, #6203, #6658

git-svn-id: http://code.djangoproject.com/svn/django/trunk@7477 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django')
-rw-r--r--  django/contrib/admin/views/main.py          |   18
-rw-r--r--  django/contrib/contenttypes/generic.py      |    5
-rw-r--r--  django/core/exceptions.py                   |    5
-rw-r--r--  django/core/management/sql.py               |   24
-rw-r--r--  django/core/management/validation.py        |   30
-rw-r--r--  django/core/serializers/base.py             |    2
-rw-r--r--  django/db/__init__.py                       |   14
-rw-r--r--  django/db/backends/__init__.py              |   39
-rw-r--r--  django/db/backends/mysql/base.py            |    5
-rw-r--r--  django/db/backends/mysql_old/base.py        |    5
-rw-r--r--  django/db/backends/oracle/base.py           |  274
-rw-r--r--  django/db/backends/oracle/query.py          |  151
-rw-r--r--  django/db/backends/postgresql/operations.py |    3
-rw-r--r--  django/db/backends/sqlite3/base.py          |    3
-rw-r--r--  django/db/models/base.py                    |  375
-rw-r--r--  django/db/models/fields/__init__.py         |   43
-rw-r--r--  django/db/models/fields/proxy.py            |   16
-rw-r--r--  django/db/models/fields/related.py          |  255
-rw-r--r--  django/db/models/manager.py                 |   40
-rw-r--r--  django/db/models/options.py                 |  318
-rw-r--r--  django/db/models/query.py                   | 1401
-rw-r--r--  django/db/models/query_utils.py             |   50
-rw-r--r--  django/db/models/sql/__init__.py            |    7
-rw-r--r--  django/db/models/sql/constants.py           |   36
-rw-r--r--  django/db/models/sql/datastructures.py      |  103
-rw-r--r--  django/db/models/sql/query.py               | 1504
-rw-r--r--  django/db/models/sql/subqueries.py          |  367
-rw-r--r--  django/db/models/sql/where.py               |  171
-rw-r--r--  django/utils/tree.py                        |  134
29 files changed, 3828 insertions(+), 1570 deletions(-)
diff --git a/django/contrib/admin/views/main.py b/django/contrib/admin/views/main.py
index 55f8ae4c32..2b7a690a83 100644
--- a/django/contrib/admin/views/main.py
+++ b/django/contrib/admin/views/main.py
@@ -8,7 +8,7 @@ from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, Per
from django.core.paginator import QuerySetPaginator, InvalidPage
from django.shortcuts import get_object_or_404, render_to_response
from django.db import models
-from django.db.models.query import handle_legacy_orderlist, QuerySet
+from django.db.models.query import QuerySet
from django.http import Http404, HttpResponse, HttpResponseRedirect
from django.utils.html import escape
from django.utils.text import capfirst, get_text_list
@@ -627,7 +627,7 @@ class ChangeList(object):
# Perform a slight optimization: Check to see whether any filters were
# given. If not, use paginator.hits to calculate the number of objects,
# because we've already done paginator.hits and the value is cached.
- if isinstance(self.query_set._filters, models.Q) and not self.query_set._filters.kwargs:
+ if not self.query_set.query.where:
full_result_count = result_count
else:
full_result_count = self.manager.count()
@@ -653,15 +653,12 @@ class ChangeList(object):
def get_ordering(self):
lookup_opts, params = self.lookup_opts, self.params
- # For ordering, first check the "ordering" parameter in the admin options,
- # then check the object's default ordering. If neither of those exist,
- # order descending by ID by default. Finally, look for manually-specified
- # ordering from the query string.
+ # For ordering, first check the "ordering" parameter in the admin
+ # options, then check the object's default ordering. If neither of
+ # those exist, order descending by ID by default. Finally, look for
+ # manually-specified ordering from the query string.
ordering = lookup_opts.admin.ordering or lookup_opts.ordering or ['-' + lookup_opts.pk.name]
- # Normalize it to new-style ordering.
- ordering = handle_legacy_orderlist(ordering)
-
if ordering[0].startswith('-'):
order_field, order_type = ordering[0][1:], 'desc'
else:
@@ -753,8 +750,7 @@ class ChangeList(object):
for bit in self.query.split():
or_queries = [models.Q(**{construct_search(field_name): bit}) for field_name in self.lookup_opts.admin.search_fields]
other_qs = QuerySet(self.model)
- if qs._select_related:
- other_qs = other_qs.select_related()
+ other_qs.dup_select_related(qs)
other_qs = other_qs.filter(reduce(operator.or_, or_queries))
qs = qs & other_qs
diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py
index 1282cc3df8..e91be70d1b 100644
--- a/django/contrib/contenttypes/generic.py
+++ b/django/contrib/contenttypes/generic.py
@@ -154,6 +154,11 @@ class GenericRelation(RelatedField, Field):
def get_internal_type(self):
return "ManyToManyField"
+ def db_type(self):
+ # Since we're simulating a ManyToManyField, in effect, best return the
+ # same db_type as well.
+ return None
+
class ReverseGenericRelatedObjectsDescriptor(object):
"""
This class provides the functionality that makes the related-object
diff --git a/django/core/exceptions.py b/django/core/exceptions.py
index d9fc326cf2..e5df8caca8 100644
--- a/django/core/exceptions.py
+++ b/django/core/exceptions.py
@@ -27,3 +27,8 @@ class MiddlewareNotUsed(Exception):
class ImproperlyConfigured(Exception):
"Django is somehow improperly configured"
pass
+
+class FieldError(Exception):
+ """Some kind of problem with a model field."""
+ pass
+
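
[Editor's note: a minimal sketch, not part of the patch, of where the new FieldError surfaces: the refactored ORM raises it when a lookup names a field it cannot resolve. `Article` is a hypothetical model; Python 2 syntax to match the codebase.]

    from django.core.exceptions import FieldError

    try:
        Article.objects.filter(no_such_field__exact=1).count()
    except FieldError, e:
        print "unusable lookup:", e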
diff --git a/django/core/management/sql.py b/django/core/management/sql.py
index a3d02696c9..1ccb100361 100644
--- a/django/core/management/sql.py
+++ b/django/core/management/sql.py
@@ -26,7 +26,7 @@ def django_table_list(only_existing=False):
for app in models.get_apps():
for model in models.get_models(app):
tables.append(model._meta.db_table)
- tables.extend([f.m2m_db_table() for f in model._meta.many_to_many])
+ tables.extend([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
existing = table_list()
tables = [t for t in tables if t in existing]
@@ -54,12 +54,12 @@ def sequence_list():
for app in apps:
for model in models.get_models(app):
- for f in model._meta.fields:
+ for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
- for f in model._meta.many_to_many:
+ for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
@@ -149,7 +149,7 @@ def sql_delete(app, style):
if cursor and table_name_converter(model._meta.db_table) in table_names:
# The table exists, so it needs to be dropped
opts = model._meta
- for f in opts.fields:
+ for f in opts.local_fields:
if f.rel and f.rel.to not in to_delete:
references_to_delete.setdefault(f.rel.to, []).append( (model, f) )
@@ -181,7 +181,7 @@ def sql_delete(app, style):
# Output DROP TABLE statements for many-to-many tables.
for model in app_models:
opts = model._meta
- for f in opts.many_to_many:
+ for f in opts.local_many_to_many:
if isinstance(f.rel, generic.GenericRel):
continue
if cursor and table_name_converter(f.m2m_db_table()) in table_names:
@@ -258,7 +258,7 @@ def sql_model_create(model, style, known_models=set()):
pending_references = {}
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
- for f in opts.fields:
+ for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
@@ -294,14 +294,8 @@ def sql_model_create(model, style, known_models=set()):
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
- constraint_output = [style.SQL_KEYWORD('UNIQUE')]
- constraint_output.append('(%s)' % \
+ table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
- if opts.db_tablespace and connection.features.supports_tablespaces \
- and connection.features.autoindexes_primary_keys:
- constraint_output.append(connection.ops.tablespace_sql(
- opts.db_tablespace, inline=True))
- table_output.append(' '.join(constraint_output))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
@@ -359,7 +353,7 @@ def many_to_many_sql_for_model(model, style):
final_output = []
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
- for f in opts.many_to_many:
+ for f in opts.local_many_to_many:
if not isinstance(f.rel, generic.GenericRel):
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
@@ -466,7 +460,7 @@ def sql_indexes_for_model(model, style):
output = []
qn = connection.ops.quote_name
- for f in model._meta.fields:
+ for f in model._meta.local_fields:
if f.db_index and not ((f.primary_key or f.unique) and connection.features.autoindexes_primary_keys):
unique = f.unique and 'UNIQUE ' or ''
tablespace = f.db_tablespace or model._meta.db_tablespace
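
[Editor's note: a minimal sketch, not part of the patch, of why the SQL generation above switches from `fields`/`many_to_many` to `local_fields`/`local_many_to_many`. With model inheritance, `_meta.fields` includes columns inherited from concrete parents, while `_meta.local_fields` covers only the columns belonging to this model's own table. The models are hypothetical.]

    from django.db import models

    class Place(models.Model):
        name = models.CharField(max_length=50)

    class Restaurant(Place):
        serves_pizza = models.BooleanField()

    # Only the parent link and the locally defined column live in
    # Restaurant's own table:
    print [f.column for f in Restaurant._meta.local_fields]
    # roughly: ['place_ptr_id', 'serves_pizza']
    print [f.name for f in Restaurant._meta.fields]
    # also includes the inherited 'id' and 'name'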
diff --git a/django/core/management/validation.py b/django/core/management/validation.py
index bc9faae056..cd1f84f34b 100644
--- a/django/core/management/validation.py
+++ b/django/core/management/validation.py
@@ -32,7 +32,7 @@ def get_validation_errors(outfile, app=None):
opts = cls._meta
# Do field-specific validation.
- for f in opts.fields:
+ for f in opts.local_fields:
if f.name == 'id' and not f.primary_key and opts.pk.name == 'id':
e.add(opts, '"%s": You can\'t use "id" as a field name, because each model automatically gets an "id" field if none of the fields have primary_key=True. You need to either remove/rename your "id" field or add primary_key=True to a field.' % f.name)
if f.name.endswith('_'):
@@ -69,8 +69,8 @@ def get_validation_errors(outfile, app=None):
if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
- # Check to see if the related field will clash with any
- # existing fields, m2m fields, m2m related objects or related objects
+ # Check to see if the related field will clash with any existing
+ # fields, m2m fields, m2m related objects or related objects
if f.rel:
if f.rel.to not in models.get_models():
e.add(opts, "'%s' has relation with model %s, which has not been installed" % (f.name, f.rel.to))
@@ -87,7 +87,7 @@ def get_validation_errors(outfile, app=None):
e.add(opts, "Accessor for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
- for r in rel_opts.many_to_many:
+ for r in rel_opts.local_many_to_many:
if r.name == rel_name:
e.add(opts, "Accessor for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
@@ -104,9 +104,10 @@ def get_validation_errors(outfile, app=None):
if r.get_accessor_name() == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
- for i, f in enumerate(opts.many_to_many):
+ for i, f in enumerate(opts.local_many_to_many):
# Check to see if the related m2m field will clash with any
- # existing fields, m2m fields, m2m related objects or related objects
+ # existing fields, m2m fields, m2m related objects or related
+ # objects
if f.rel.to not in models.get_models():
e.add(opts, "'%s' has m2m relation with model %s, which has not been installed" % (f.name, f.rel.to))
# it is a string and we could not find the model it refers to
@@ -117,17 +118,17 @@ def get_validation_errors(outfile, app=None):
rel_opts = f.rel.to._meta
rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name()
rel_query_name = f.related_query_name()
- # If rel_name is none, there is no reverse accessor.
- # (This only occurs for symmetrical m2m relations to self).
- # If this is the case, there are no clashes to check for this field, as
- # there are no reverse descriptors for this field.
+ # If rel_name is none, there is no reverse accessor (this only
+ # occurs for symmetrical m2m relations to self). If this is the
+ # case, there are no clashes to check for this field, as there are
+ # no reverse descriptors for this field.
if rel_name is not None:
for r in rel_opts.fields:
if r.name == rel_name:
e.add(opts, "Accessor for m2m field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
e.add(opts, "Reverse query name for m2m field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
- for r in rel_opts.many_to_many:
+ for r in rel_opts.local_many_to_many:
if r.name == rel_name:
e.add(opts, "Accessor for m2m field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
@@ -200,7 +201,10 @@ def get_validation_errors(outfile, app=None):
field_name = field_name[1:]
if opts.order_with_respect_to and field_name == '_order':
continue
- if '.' in field_name: continue # Skip ordering in the format 'table.field'.
+ # Skip ordering in the format field1__field2 (FIXME: checking
+ # this format would be nice, but it's a little fiddly).
+ if '_' in field_name:
+ continue
try:
opts.get_field(field_name, many_to_many=False)
except models.FieldDoesNotExist:
@@ -228,5 +232,7 @@ def get_validation_errors(outfile, app=None):
else:
if isinstance(f.rel, models.ManyToManyRel):
e.add(opts, '"unique_together" refers to %s. ManyToManyFields are not supported in unique_together.' % f.name)
+ if f not in opts.local_fields:
+ e.add(opts, '"unique_together" refers to %s. This is not in the same model as the unique_together statement.' % f.name)
return len(e.errors)
diff --git a/django/core/serializers/base.py b/django/core/serializers/base.py
index 2c92e2afad..a79497ecec 100644
--- a/django/core/serializers/base.py
+++ b/django/core/serializers/base.py
@@ -165,7 +165,7 @@ class DeserializedObject(object):
# This ensures that the data that is deserialized is literally
# what came from the file, not post-processed by pre_save/save
# methods.
- models.Model.save(self.object, raw=True)
+ models.Model.save_base(self.object, raw=True)
if self.m2m_data and save_m2m:
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
diff --git a/django/db/__init__.py b/django/db/__init__.py
index 8f75e0d7b8..95dd36822e 100644
--- a/django/db/__init__.py
+++ b/django/db/__init__.py
@@ -11,16 +11,18 @@ if not settings.DATABASE_ENGINE:
settings.DATABASE_ENGINE = 'dummy'
try:
- # Most of the time, the database backend will be one of the official
+ # Most of the time, the database backend will be one of the official
# backends that ships with Django, so look there first.
_import_path = 'django.db.backends.'
backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
+ creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
except ImportError, e:
- # If the import failed, we might be looking for a database backend
+ # If the import failed, we might be looking for a database backend
# distributed external to Django. So we'll try that next.
try:
_import_path = ''
backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, [''])
+ creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, [''])
except ImportError, e_user:
# The database backend wasn't found. Display a helpful error message
# listing all possible (built-in) database backends.
@@ -37,10 +39,12 @@ def _import_database_module(import_path='', module_name=''):
"""Lazily import a database module when requested."""
return __import__('%s%s.%s' % (import_path, settings.DATABASE_ENGINE, module_name), {}, {}, [''])
-# We don't want to import the introspect/creation modules unless
-# someone asks for 'em, so lazily load them on demmand.
+# We don't want to import the introspect module unless someone asks for it, so
+# lazily load it on demand.
get_introspection_module = curry(_import_database_module, _import_path, 'introspection')
-get_creation_module = curry(_import_database_module, _import_path, 'creation')
+
+def get_creation_module():
+ return creation
# We want runshell() to work the same way, but we have to treat it a
# little differently (since it just runs instead of returning a module like
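
[Editor's note: a minimal sketch, not part of the patch: the creation module is now imported up front alongside the backend, so get_creation_module() simply returns the already-imported module instead of importing lazily on each call.]

    from django.db import get_creation_module

    creation = get_creation_module()
    # e.g. the django.db.backends.sqlite3.creation module for the sqlite3
    # backend; repeated calls return the same module object.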
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index be1776e65f..8917fc3b23 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -49,7 +49,8 @@ class BaseDatabaseFeatures(object):
supports_constraints = True
supports_tablespaces = False
uses_case_insensitive_names = False
- uses_custom_queryset = False
+ uses_custom_query_class = False
+ empty_fetchmany_value = []
class BaseDatabaseOperations(object):
"""
@@ -86,10 +87,9 @@ class BaseDatabaseOperations(object):
Returns the SQL necessary to cast a datetime value so that it will be
retrieved as a Python datetime object instead of a string.
- This SQL should include a '%s' in place of the field's name. This
- method should return None if no casting is necessary.
+ This SQL should include a '%s' in place of the field's name.
"""
- return None
+ return "%s"
def deferrable_sql(self):
"""
@@ -169,6 +169,14 @@ class BaseDatabaseOperations(object):
sql += " OFFSET %s" % offset
return sql
+ def lookup_cast(self, lookup_type):
+ """
+ Returns the string to use in a query when performing lookups
+ ("contains", "like", etc). The resulting string should contain a '%s'
+ placeholder for the column being searched against.
+ """
+ return "%s"
+
def max_name_length(self):
"""
Returns the maximum length of table and column names, or None if there
@@ -176,6 +184,14 @@ class BaseDatabaseOperations(object):
"""
return None
+ def no_limit_value(self):
+ """
+ Returns the value to use for the LIMIT when we want "LIMIT
+ infinity". Returns None if the limit clause can be omitted in this case.
+ """
+ # FIXME: API may need to change once Oracle backend is repaired.
+ raise NotImplementedError()
+
def pk_default_value(self):
"""
Returns the value to use during an INSERT statement to specify that
@@ -183,11 +199,11 @@ class BaseDatabaseOperations(object):
"""
return 'DEFAULT'
- def query_set_class(self, DefaultQuerySet):
+ def query_class(self, DefaultQueryClass):
"""
Given the default QuerySet class, returns a custom QuerySet class
to use for this backend. Returns None if a custom QuerySet isn't used.
- See also BaseDatabaseFeatures.uses_custom_queryset, which regulates
+ See also BaseDatabaseFeatures.uses_custom_query_class, which regulates
whether this method is called at all.
"""
return None
@@ -205,6 +221,17 @@ class BaseDatabaseOperations(object):
"""
return 'RANDOM()'
+ def regex_lookup(self, lookup_type):
+ """
+ Returns the string to use in a query when performing regular expression
+ lookups (using "regex" or "iregex"). The resulting string should
+ contain a '%s' placeholder for the column being searched against.
+
+ If the feature is not supported (or part of it is not supported), a
+ NotImplementedError exception can be raised.
+ """
+ raise NotImplementedError
+
def sql_flush(self, style, tables, sequences):
"""
Returns a list of SQL statements required to remove all data from
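
[Editor's note: a minimal sketch, not part of the patch, of how a backend would fill in the new DatabaseOperations hooks introduced above; the class name and return values are illustrative only.]

    from django.db.backends import BaseDatabaseOperations

    class MyDatabaseOperations(BaseDatabaseOperations):
        def lookup_cast(self, lookup_type):
            # Wrap the searched column for case-insensitive lookups.
            if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
                return "LOWER(%s)"
            return "%s"

        def no_limit_value(self):
            # None means the LIMIT clause can simply be dropped when only
            # an offset is wanted.
            return None

        def regex_lookup(self, lookup_type):
            # This hypothetical engine has no regex operator.
            raise NotImplementedError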
diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
index 7782387f41..17aa6f13bf 100644
--- a/django/db/backends/mysql/base.py
+++ b/django/db/backends/mysql/base.py
@@ -62,6 +62,7 @@ server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
class DatabaseFeatures(BaseDatabaseFeatures):
autoindexes_primary_keys = False
inline_fk_references = False
+ empty_fetchmany_value = ()
class DatabaseOperations(BaseDatabaseOperations):
def date_extract_sql(self, lookup_type, field_name):
@@ -94,6 +95,10 @@ class DatabaseOperations(BaseDatabaseOperations):
sql += "%s," % offset
return sql + str(limit)
+ def no_limit_value(self):
+ # 2**64 - 1, as recommended by the MySQL documentation
+ return 18446744073709551615L
+
def quote_name(self, name):
if name.startswith("`") and name.endswith("`"):
return name # Quoting once is enough.
diff --git a/django/db/backends/mysql_old/base.py b/django/db/backends/mysql_old/base.py
index c22094b968..efbfeeafc5 100644
--- a/django/db/backends/mysql_old/base.py
+++ b/django/db/backends/mysql_old/base.py
@@ -66,6 +66,7 @@ class MysqlDebugWrapper:
class DatabaseFeatures(BaseDatabaseFeatures):
autoindexes_primary_keys = False
inline_fk_references = False
+ empty_fetchmany_value = ()
class DatabaseOperations(BaseDatabaseOperations):
def date_extract_sql(self, lookup_type, field_name):
@@ -98,6 +99,10 @@ class DatabaseOperations(BaseDatabaseOperations):
sql += "%s," % offset
return sql + str(limit)
+ def no_limit_value(self):
+ # 2**64 - 1, as recommended by the MySQL documentation
+ return 18446744073709551615L
+
def quote_name(self, name):
if name.startswith("`") and name.endswith("`"):
return name # Quoting once is enough.
diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
index 152adf7056..3635acdf2a 100644
--- a/django/db/backends/oracle/base.py
+++ b/django/db/backends/oracle/base.py
@@ -4,11 +4,12 @@ Oracle database backend for Django.
Requires cx_Oracle: http://www.python.net/crew/atuining/cx_Oracle/
"""
+import os
+
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
+from django.db.backends.oracle import query
from django.utils.datastructures import SortedDict
from django.utils.encoding import smart_str, force_unicode
-import datetime
-import os
# Oracle takes client-side character set encoding from the environment.
os.environ['NLS_LANG'] = '.UTF8'
@@ -24,11 +25,12 @@ IntegrityError = Database.IntegrityError
class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_ordinal = False
allows_unique_and_pk = False # Suppress UNIQUE/PK for Oracle (ORA-02259)
+ empty_fetchmany_value = ()
needs_datetime_string_cast = False
needs_upper_for_iops = True
supports_tablespaces = True
uses_case_insensitive_names = True
- uses_custom_queryset = True
+ uses_custom_query_class = True
class DatabaseOperations(BaseDatabaseOperations):
def autoinc_sql(self, table, column):
@@ -89,243 +91,16 @@ class DatabaseOperations(BaseDatabaseOperations):
# Instead, they are handled in django/db/backends/oracle/query.py.
return ""
+ def lookup_cast(self, lookup_type):
+ if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
+ return "UPPER(%s)"
+ return "%s"
+
def max_name_length(self):
return 30
- def query_set_class(self, DefaultQuerySet):
- from django.db import connection
- from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word
-
- class OracleQuerySet(DefaultQuerySet):
-
- def iterator(self):
- "Performs the SELECT database lookup of this QuerySet."
-
- from django.db.models.query import get_cached_row
-
- # self._select is a dictionary, and dictionaries' key order is
- # undefined, so we convert it to a list of tuples.
- extra_select = self._select.items()
-
- full_query = None
-
- try:
- try:
- select, sql, params, full_query = self._get_sql_clause(get_full_query=True)
- except TypeError:
- select, sql, params = self._get_sql_clause()
- except EmptyResultSet:
- raise StopIteration
- if not full_query:
- full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql)
-
- cursor = connection.cursor()
- cursor.execute(full_query, params)
-
- fill_cache = self._select_related
- fields = self.model._meta.fields
- index_end = len(fields)
-
- # so here's the logic;
- # 1. retrieve each row in turn
- # 2. convert NCLOBs
-
- while 1:
- rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
- if not rows:
- raise StopIteration
- for row in rows:
- row = self.resolve_columns(row, fields)
- if fill_cache:
- obj, index_end = get_cached_row(klass=self.model, row=row,
- index_start=0, max_depth=self._max_related_depth)
- else:
- obj = self.model(*row[:index_end])
- for i, k in enumerate(extra_select):
- setattr(obj, k[0], row[index_end+i])
- yield obj
-
-
- def _get_sql_clause(self, get_full_query=False):
- from django.db.models.query import fill_table_cache, \
- handle_legacy_orderlist, orderfield2column
-
- opts = self.model._meta
- qn = connection.ops.quote_name
-
- # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
- select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
- tables = [quote_only_if_word(t) for t in self._tables]
- joins = SortedDict()
- where = self._where[:]
- params = self._params[:]
-
- # Convert self._filters into SQL.
- joins2, where2, params2 = self._filters.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
-
- # Add additional tables and WHERE clauses based on select_related.
- if self._select_related:
- fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
-
- # Add any additional SELECTs.
- if self._select:
- select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])
-
- # Start composing the body of the SQL statement.
- sql = [" FROM", qn(opts.db_table)]
-
- # Compose the join dictionary into SQL describing the joins.
- if joins:
- sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition)
- for (alias, (table, join_type, condition)) in joins.items()]))
-
- # Compose the tables clause into SQL.
- if tables:
- sql.append(", " + ", ".join(tables))
-
- # Compose the where clause into SQL.
- if where:
- sql.append(where and "WHERE " + " AND ".join(where))
-
- # ORDER BY clause
- order_by = []
- if self._order_by is not None:
- ordering_to_use = self._order_by
- else:
- ordering_to_use = opts.ordering
- for f in handle_legacy_orderlist(ordering_to_use):
- if f == '?': # Special case.
- order_by.append(DatabaseOperations().random_function_sql())
- else:
- if f.startswith('-'):
- col_name = f[1:]
- order = "DESC"
- else:
- col_name = f
- order = "ASC"
- if "." in col_name:
- table_prefix, col_name = col_name.split('.', 1)
- table_prefix = qn(table_prefix) + '.'
- else:
- # Use the database table as a column prefix if it wasn't given,
- # and if the requested column isn't a custom SELECT.
- if "." not in col_name and col_name not in (self._select or ()):
- table_prefix = qn(opts.db_table) + '.'
- else:
- table_prefix = ''
- order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
- if order_by:
- sql.append("ORDER BY " + ", ".join(order_by))
-
- # Look for column name collisions in the select elements
- # and fix them with an AS alias. This allows us to do a
- # SELECT * later in the paging query.
- cols = [clause.split('.')[-1] for clause in select]
- for index, col in enumerate(cols):
- if cols.count(col) > 1:
- col = '%s%d' % (col.replace('"', ''), index)
- cols[index] = col
- select[index] = '%s AS %s' % (select[index], col)
-
- # LIMIT and OFFSET clauses
- # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query.
- select_clause = ",".join(select)
- distinct = (self._distinct and "DISTINCT " or "")
-
- if order_by:
- order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by))
- else:
- #Oracle's row_number() function always requires an order-by clause.
- #So we need to define a default order-by, since none was provided.
- order_by_clause = " OVER (ORDER BY %s.%s)" % \
- (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
- # limit_and_offset_clause
- if self._limit is None:
- assert self._offset is None, "'offset' is not allowed without 'limit'"
-
- if self._offset is not None:
- offset = int(self._offset)
- else:
- offset = 0
- if self._limit is not None:
- limit = int(self._limit)
- else:
- limit = None
-
- limit_and_offset_clause = ''
- if limit is not None:
- limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset)
- elif offset:
- limit_and_offset_clause = "WHERE rn > %s" % (offset)
-
- if len(limit_and_offset_clause) > 0:
- fmt = \
- """SELECT * FROM
- (SELECT %s%s,
- ROW_NUMBER()%s AS rn
- %s)
- %s"""
- full_query = fmt % (distinct, select_clause,
- order_by_clause, ' '.join(sql).strip(),
- limit_and_offset_clause)
- else:
- full_query = None
-
- if get_full_query:
- return select, " ".join(sql), params, full_query
- else:
- return select, " ".join(sql), params
-
- def resolve_columns(self, row, fields=()):
- from django.db.models.fields import DateField, DateTimeField, \
- TimeField, BooleanField, NullBooleanField, DecimalField, Field
- values = []
- for value, field in map(None, row, fields):
- if isinstance(value, Database.LOB):
- value = value.read()
- # Oracle stores empty strings as null. We need to undo this in
- # order to adhere to the Django convention of using the empty
- # string instead of null, but only if the field accepts the
- # empty string.
- if value is None and isinstance(field, Field) and field.empty_strings_allowed:
- value = u''
- # Convert 1 or 0 to True or False
- elif value in (1, 0) and isinstance(field, (BooleanField, NullBooleanField)):
- value = bool(value)
- # Convert floats to decimals
- elif value is not None and isinstance(field, DecimalField):
- value = util.typecast_decimal(field.format_number(value))
- # cx_Oracle always returns datetime.datetime objects for
- # DATE and TIMESTAMP columns, but Django wants to see a
- # python datetime.date, .time, or .datetime. We use the type
- # of the Field to determine which to cast to, but it's not
- # always available.
- # As a workaround, we cast to date if all the time-related
- # values are 0, or to time if the date is 1/1/1900.
- # This could be cleaned a bit by adding a method to the Field
- # classes to normalize values from the database (the to_python
- # method is used for validation and isn't what we want here).
- elif isinstance(value, Database.Timestamp):
- # In Python 2.3, the cx_Oracle driver returns its own
- # Timestamp object that we must convert to a datetime class.
- if not isinstance(value, datetime.datetime):
- value = datetime.datetime(value.year, value.month, value.day, value.hour,
- value.minute, value.second, value.fsecond)
- if isinstance(field, DateTimeField):
- pass # DateTimeField subclasses DateField so must be checked first.
- elif isinstance(field, DateField):
- value = value.date()
- elif isinstance(field, TimeField) or (value.year == 1900 and value.month == value.day == 1):
- value = value.time()
- elif value.hour == value.minute == value.second == value.microsecond == 0:
- value = value.date()
- values.append(value)
- return values
-
- return OracleQuerySet
+ def query_class(self, DefaultQueryClass):
+ return query.query_class(DefaultQueryClass, Database)
def quote_name(self, name):
# SQL92 requires delimited (quoted) names to be case-sensitive. When
@@ -339,6 +114,23 @@ class DatabaseOperations(BaseDatabaseOperations):
def random_function_sql(self):
return "DBMS_RANDOM.RANDOM"
+ def regex_lookup_9(self, lookup_type):
+ raise NotImplementedError("Regexes are not supported in Oracle before version 10g.")
+
+ def regex_lookup_10(self, lookup_type):
+ if lookup_type == 'regex':
+ match_option = "'c'"
+ else:
+ match_option = "'i'"
+ return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
+
+ def regex_lookup(self, lookup_type):
+ # If regex_lookup is called before it's been initialized, then create
+ # a cursor to initialize it and recur.
+ from django.db import connection
+ connection.cursor()
+ return connection.ops.regex_lookup(lookup_type)
+
def sql_flush(self, style, tables, sequences):
# Return a list of 'TRUNCATE x;', 'TRUNCATE y;',
# 'TRUNCATE z;'... style SQL statements
@@ -430,6 +222,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'")
try:
self.oracle_version = int(self.connection.version.split('.')[0])
+ # There's no way for the DatabaseOperations class to know the
+ # currently active Oracle version, so we do some setups here.
+ # TODO: Multi-db support will need a better solution (a way to
+ # communicate the current version).
+ if self.oracle_version <= 9:
+ self.ops.regex_lookup = self.ops.regex_lookup_9
+ else:
+ self.ops.regex_lookup = self.ops.regex_lookup_10
except ValueError:
pass
try:
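
[Editor's note: a minimal sketch, not part of the patch, of the version dispatch above in use. The first cursor() call connects and rebinds ops.regex_lookup to the 9i or 10g implementation; afterwards regex_lookup() returns the REGEXP_LIKE fragment (10g) or raises NotImplementedError (9i and earlier).]

    from django.db import connection

    connection.cursor()    # connect; DatabaseWrapper picks regex_lookup_9/10
    try:
        fragment = connection.ops.regex_lookup('iregex')
        # On 10g this is "REGEXP_LIKE(%s, %s, 'i')"
    except NotImplementedError:
        pass               # pre-10g servers have no regex support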
diff --git a/django/db/backends/oracle/query.py b/django/db/backends/oracle/query.py
new file mode 100644
index 0000000000..033ffe8533
--- /dev/null
+++ b/django/db/backends/oracle/query.py
@@ -0,0 +1,151 @@
+"""
+Custom Query class for this backend (a derivative of
+django.db.models.sql.query.Query).
+"""
+
+import datetime
+
+from django.db.backends import util
+
+# Cache. Maps default query class to new Oracle query class.
+_classes = {}
+
+def query_class(QueryClass, Database):
+ """
+ Returns a custom django.db.models.sql.query.Query subclass that is
+ appropriate for Oracle.
+
+ The 'Database' module (cx_Oracle) is passed in here so that all the setup
+ required to import it only needs to be done by the calling module.
+ """
+ global _classes
+ try:
+ return _classes[QueryClass]
+ except KeyError:
+ pass
+
+ class OracleQuery(QueryClass):
+ def resolve_columns(self, row, fields=()):
+ index_start = len(self.extra_select.keys())
+ values = [self.convert_values(v, None) for v in row[:index_start]]
+ for value, field in map(None, row[index_start:], fields):
+ values.append(self.convert_values(value, field))
+ return values
+
+ def convert_values(self, value, field):
+ from django.db.models.fields import DateField, DateTimeField, \
+ TimeField, BooleanField, NullBooleanField, DecimalField, Field
+ if isinstance(value, Database.LOB):
+ value = value.read()
+ # Oracle stores empty strings as null. We need to undo this in
+ # order to adhere to the Django convention of using the empty
+ # string instead of null, but only if the field accepts the
+ # empty string.
+ if value is None and isinstance(field, Field) and field.empty_strings_allowed:
+ value = u''
+ # Convert 1 or 0 to True or False
+ elif value in (1, 0) and isinstance(field, (BooleanField, NullBooleanField)):
+ value = bool(value)
+ # Convert floats to decimals
+ elif value is not None and isinstance(field, DecimalField):
+ value = util.typecast_decimal(field.format_number(value))
+ # cx_Oracle always returns datetime.datetime objects for
+ # DATE and TIMESTAMP columns, but Django wants to see a
+ # python datetime.date, .time, or .datetime. We use the type
+ # of the Field to determine which to cast to, but it's not
+ # always available.
+ # As a workaround, we cast to date if all the time-related
+ # values are 0, or to time if the date is 1/1/1900.
+ # This could be cleaned a bit by adding a method to the Field
+ # classes to normalize values from the database (the to_python
+ # method is used for validation and isn't what we want here).
+ elif isinstance(value, Database.Timestamp):
+ # In Python 2.3, the cx_Oracle driver returns its own
+ # Timestamp object that we must convert to a datetime class.
+ if not isinstance(value, datetime.datetime):
+ value = datetime.datetime(value.year, value.month,
+ value.day, value.hour, value.minute, value.second,
+ value.fsecond)
+ if isinstance(field, DateTimeField):
+ # DateTimeField subclasses DateField so must be checked
+ # first.
+ pass
+ elif isinstance(field, DateField):
+ value = value.date()
+ elif isinstance(field, TimeField) or (value.year == 1900 and value.month == value.day == 1):
+ value = value.time()
+ elif value.hour == value.minute == value.second == value.microsecond == 0:
+ value = value.date()
+ return value
+
+ def as_sql(self, with_limits=True, with_col_aliases=False):
+ """
+ Creates the SQL for this query. Returns the SQL string and list
+ of parameters. This is overridden from the original Query class
+ to accommodate Oracle's limit/offset SQL.
+
+ If 'with_limits' is False, any limit/offset information is not
+ included in the query.
+ """
+ # The `do_offset` flag indicates whether we need to construct
+ # the SQL needed to use limit/offset w/Oracle.
+ do_offset = with_limits and (self.high_mark or self.low_mark)
+
+ # If no offsets, just return the result of the base class
+ # `as_sql`.
+ if not do_offset:
+ return super(OracleQuery, self).as_sql(with_limits=False,
+ with_col_aliases=with_col_aliases)
+
+ # `get_columns` needs to be called before `get_ordering` to
+ # populate `_select_alias`.
+ self.pre_sql_setup()
+ out_cols = self.get_columns()
+ ordering = self.get_ordering()
+
+ # Getting the "ORDER BY" SQL for the ROW_NUMBER() result.
+ if ordering:
+ rn_orderby = ', '.join(ordering)
+ else:
+ # Oracle's ROW_NUMBER() function always requires an
+ # order-by clause. So we need to define a default
+ # order-by, since none was provided.
+ qn = self.quote_name_unless_alias
+ opts = self.model._meta
+ rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
+
+ # Getting the selection SQL and the params, which has the `rn`
+ # extra selection SQL.
+ self.extra_select['rn'] = 'ROW_NUMBER() OVER (ORDER BY %s )' % rn_orderby
+ sql, params = super(OracleQuery, self).as_sql(with_limits=False,
+ with_col_aliases=True)
+
+ # Constructing the result SQL, using the initial select SQL
+ # obtained above.
+ result = ['SELECT * FROM (%s)' % sql]
+
+ # Place WHERE condition on `rn` for the desired range.
+ result.append('WHERE rn > %d' % self.low_mark)
+ if self.high_mark:
+ result.append('AND rn <= %d' % self.high_mark)
+
+ # Returning the SQL w/params.
+ return ' '.join(result), params
+
+ def set_limits(self, low=None, high=None):
+ super(OracleQuery, self).set_limits(low, high)
+
+ # We need to select the row number for the LIMIT/OFFSET sql.
+ # A placeholder is added to extra_select now, because as_sql is
+ # too late to be modifying extra_select. However, the actual sql
+ # depends on the ordering, so that is generated in as_sql.
+ self.extra_select['rn'] = '1'
+
+ def clear_limits(self):
+ super(OracleQuery, self).clear_limits()
+ if 'rn' in self.extra_select:
+ del self.extra_select['rn']
+
+ _classes[QueryClass] = OracleQuery
+ return OracleQuery
+
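
[Editor's note: a minimal sketch, not part of the patch, of how the ORM picks this class up. When features.uses_custom_query_class is set, the connection's DatabaseOperations is asked for a Query subclass; query_class() caches its result per default class.]

    from django.db import connection
    from django.db.models.sql.query import Query

    if connection.features.uses_custom_query_class:
        QueryClass = connection.ops.query_class(Query)   # OracleQuery on Oracle
    else:
        QueryClass = Query
    # query_class() caches per default class, so asking again hands back the
    # same subclass object.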
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index cd46413453..7e27b23f72 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -44,6 +44,9 @@ class DatabaseOperations(BaseDatabaseOperations):
cursor.execute("SELECT CURRVAL('\"%s_%s_seq\"')" % (table_name, pk_name))
return cursor.fetchone()[0]
+ def no_limit_value(self):
+ return None
+
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index b4b445cd16..b8bf5c8f0b 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -63,6 +63,9 @@ class DatabaseOperations(BaseDatabaseOperations):
return name # Quoting once is enough.
return '"%s"' % name
+ def no_limit_value(self):
+ return -1
+
def sql_flush(self, style, tables, sequences):
# NB: The generated SQL below is specific to SQLite
# Note: The DELETE FROM... SQL generated below works for SQLite databases
diff --git a/django/db/models/base.py b/django/db/models/base.py
index 4f034bcf8b..bb02d7a00c 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -1,10 +1,16 @@
-import django.db.models.manipulators
-import django.db.models.manager
+import copy
+import types
+import sys
+import os
+from itertools import izip
+
+import django.db.models.manipulators # Imported to register signal handler.
+import django.db.models.manager # Ditto.
from django.core import validators
-from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
+from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned, FieldError
from django.db.models.fields import AutoField, ImageField, FieldDoesNotExist
-from django.db.models.fields.related import OneToOneRel, ManyToOneRel
-from django.db.models.query import delete_objects
+from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneField
+from django.db.models.query import delete_objects, Q
from django.db.models.options import Options, AdminOptions
from django.db import connection, transaction
from django.db.models import signals
@@ -14,10 +20,11 @@ from django.utils.datastructures import SortedDict
from django.utils.functional import curry
from django.utils.encoding import smart_str, force_unicode, smart_unicode
from django.conf import settings
-from itertools import izip
-import types
-import sys
-import os
+
+try:
+ set
+except NameError:
+ from sets import Set as set # Python 2.3 fallback
class ModelBase(type):
"Metaclass for all models"
@@ -25,29 +32,45 @@ class ModelBase(type):
# If this isn't a subclass of Model, don't do anything special.
try:
parents = [b for b in bases if issubclass(b, Model)]
- if not parents:
- return super(ModelBase, cls).__new__(cls, name, bases, attrs)
except NameError:
# 'Model' isn't defined yet, meaning we're looking at Django's own
# Model class, defined below.
+ parents = []
+ if not parents:
return super(ModelBase, cls).__new__(cls, name, bases, attrs)
# Create the class.
- new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')})
- new_class.add_to_class('_meta', Options(attrs.pop('Meta', None)))
- new_class.add_to_class('DoesNotExist', types.ClassType('DoesNotExist', (ObjectDoesNotExist,), {}))
- new_class.add_to_class('MultipleObjectsReturned',
- types.ClassType('MultipleObjectsReturned', (MultipleObjectsReturned, ), {}))
-
- # Build complete list of parents
- for base in parents:
- # Things without _meta aren't functional models, so they're
- # uninteresting parents.
- if hasattr(base, '_meta'):
- new_class._meta.parents.append(base)
- new_class._meta.parents.extend(base._meta.parents)
+ module = attrs.pop('__module__')
+ new_class = type.__new__(cls, name, bases, {'__module__': module})
+ attr_meta = attrs.pop('Meta', None)
+ abstract = getattr(attr_meta, 'abstract', False)
+ if not attr_meta:
+ meta = getattr(new_class, 'Meta', None)
+ else:
+ meta = attr_meta
+ base_meta = getattr(new_class, '_meta', None)
+ new_class.add_to_class('_meta', Options(meta))
+ if not abstract:
+ new_class.add_to_class('DoesNotExist',
+ subclass_exception('DoesNotExist', ObjectDoesNotExist, module))
+ new_class.add_to_class('MultipleObjectsReturned',
+ subclass_exception('MultipleObjectsReturned', MultipleObjectsReturned, module))
+ if base_meta and not base_meta.abstract:
+ # Non-abstract child classes inherit some attributes from their
+ # non-abstract parent (unless an ABC comes before it in the
+ # method resolution order).
+ if not hasattr(meta, 'ordering'):
+ new_class._meta.ordering = base_meta.ordering
+ if not hasattr(meta, 'get_latest_by'):
+ new_class._meta.get_latest_by = base_meta.get_latest_by
+ old_default_mgr = None
+ if getattr(new_class, '_default_manager', None):
+ # We have a parent who set the default manager.
+ if new_class._default_manager.model._meta.abstract:
+ old_default_mgr = new_class._default_manager
+ new_class._default_manager = None
if getattr(new_class._meta, 'app_label', None) is None:
# Figure out the app_label by looking one level up.
# For 'django.contrib.sites.models', this would be 'sites'.
@@ -63,21 +86,50 @@ class ModelBase(type):
for obj_name, obj in attrs.items():
new_class.add_to_class(obj_name, obj)
- # Add Fields inherited from parents
- for parent in new_class._meta.parents:
- for field in parent._meta.fields:
- # Only add parent fields if they aren't defined for this class.
- try:
- new_class._meta.get_field(field.name)
- except FieldDoesNotExist:
- field.contribute_to_class(new_class, field.name)
+ # Do the appropriate setup for any model parents.
+ o2o_map = dict([(f.rel.to, f) for f in new_class._meta.local_fields
+ if isinstance(f, OneToOneField)])
+ for base in parents:
+ if not hasattr(base, '_meta'):
+ # Things without _meta aren't functional models, so they're
+ # uninteresting parents.
+ continue
+ if not base._meta.abstract:
+ if base in o2o_map:
+ field = o2o_map[base]
+ field.primary_key = True
+ new_class._meta.setup_pk(field)
+ else:
+ attr_name = '%s_ptr' % base._meta.module_name
+ field = OneToOneField(base, name=attr_name,
+ auto_created=True, parent_link=True)
+ new_class.add_to_class(attr_name, field)
+ new_class._meta.parents[base] = field
+ else:
+ # The abstract base class case.
+ names = set([f.name for f in new_class._meta.local_fields + new_class._meta.many_to_many])
+ for field in base._meta.local_fields + base._meta.local_many_to_many:
+ if field.name in names:
+ raise FieldError('Local field %r in class %r clashes with field of similar name from abstract base class %r'
+ % (field.name, name, base.__name__))
+ new_class.add_to_class(field.name, copy.deepcopy(field))
- new_class._prepare()
+ if abstract:
+ # Abstract base models can't be instantiated and don't appear in
+ # the list of models for an app. We do the final setup for them a
+ # little differently from normal models.
+ attr_meta.abstract = False
+ new_class.Meta = attr_meta
+ return new_class
+ if old_default_mgr and not new_class._default_manager:
+ new_class._default_manager = old_default_mgr._copy_to_model(new_class)
+ new_class._prepare()
register_models(new_class._meta.app_label, new_class)
+
# Because of the way imports happen (recursively), we may or may not be
- # the first class for this model to register with the framework. There
- # should only be one class for each model, so we must always return the
+ # the first time this model tries to register with the framework. There
+ # should only be one class for each model, so we always return the
# registered version.
return get_model(new_class._meta.app_label, name, False)
@@ -113,31 +165,6 @@ class ModelBase(type):
class Model(object):
__metaclass__ = ModelBase
- def _get_pk_val(self):
- return getattr(self, self._meta.pk.attname)
-
- def _set_pk_val(self, value):
- return setattr(self, self._meta.pk.attname, value)
-
- pk = property(_get_pk_val, _set_pk_val)
-
- def __repr__(self):
- return smart_str(u'<%s: %s>' % (self.__class__.__name__, unicode(self)))
-
- def __str__(self):
- if hasattr(self, '__unicode__'):
- return force_unicode(self).encode('utf-8')
- return '%s object' % self.__class__.__name__
-
- def __eq__(self, other):
- return isinstance(other, self.__class__) and self._get_pk_val() == other._get_pk_val()
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash(self._get_pk_val())
-
def __init__(self, *args, **kwargs):
dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs)
@@ -210,72 +237,133 @@ class Model(object):
raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self)
- def save(self, raw=False):
- dispatcher.send(signal=signals.pre_save, sender=self.__class__,
- instance=self, raw=raw)
+ def from_sequence(cls, values):
+ """
+ An alternate class constructor, primarily for internal use.
- non_pks = [f for f in self._meta.fields if not f.primary_key]
- cursor = connection.cursor()
+ Creates a model instance from a sequence of values (which corresponds
+ to all the non-many-to-many fields, in creation order). If there are more
+ fields than values, the remaining (final) fields are given their
+ default values.
+
+ ForeignKey fields can only be initialised using id values, not
+ instances, in this method.
+ """
+ dispatcher.send(signal=signals.pre_init, sender=cls, args=values,
+ kwargs={})
+ obj = Empty()
+ obj.__class__ = cls
+ field_iter = iter(obj._meta.fields)
+ for val, field in izip(values, field_iter):
+ setattr(obj, field.attname, val)
+ for field in field_iter:
+ setattr(obj, field.attname, field.get_default())
+ dispatcher.send(signal=signals.post_init, sender=cls, instance=obj)
+ return obj
- qn = connection.ops.quote_name
+ from_sequence = classmethod(from_sequence)
+
+ def __repr__(self):
+ return smart_str(u'<%s: %s>' % (self.__class__.__name__, unicode(self)))
+
+ def __str__(self):
+ if hasattr(self, '__unicode__'):
+ return force_unicode(self).encode('utf-8')
+ return '%s object' % self.__class__.__name__
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and self._get_pk_val() == other._get_pk_val()
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash(self._get_pk_val())
+
+ def _get_pk_val(self, meta=None):
+ if not meta:
+ meta = self._meta
+ return getattr(self, meta.pk.attname)
+
+ def _set_pk_val(self, value):
+ return setattr(self, self._meta.pk.attname, value)
+
+ pk = property(_get_pk_val, _set_pk_val)
+
+ def save(self):
+ """
+ Save the current instance. Override this in a subclass if you want to
+ control the saving process.
+ """
+ self.save_base()
+
+ save.alters_data = True
+
+ def save_base(self, raw=False, cls=None):
+ """
+ Does the heavy-lifting involved in saving. Subclasses shouldn't need to
+ override this method. It's separate from save() in order to hide the
+ need for overrides of save() to pass around internal-only parameters
+ ('raw' and 'cls').
+ """
+ if not cls:
+ cls = self.__class__
+ meta = self._meta
+ signal = True
+ dispatcher.send(signal=signals.pre_save, sender=self.__class__,
+ instance=self, raw=raw)
+ else:
+ meta = cls._meta
+ signal = False
+
+ for parent, field in meta.parents.items():
+ self.save_base(raw, parent)
+ setattr(self, field.attname, self._get_pk_val(parent._meta))
+
+ non_pks = [f for f in meta.local_fields if not f.primary_key]
# First, try an UPDATE. If that doesn't update anything, do an INSERT.
- pk_val = self._get_pk_val()
+ pk_val = self._get_pk_val(meta)
# Note: the comparison with '' is required for compatibility with
# oldforms-style model creation.
pk_set = pk_val is not None and smart_unicode(pk_val) != u''
record_exists = True
+ manager = cls._default_manager
if pk_set:
# Determine whether a record with the primary key already exists.
- cursor.execute("SELECT 1 FROM %s WHERE %s=%%s" % \
- (qn(self._meta.db_table), qn(self._meta.pk.column)),
- self._meta.pk.get_db_prep_lookup('exact', pk_val))
- # If it does already exist, do an UPDATE.
- if cursor.fetchone():
- db_values = [f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, False)) for f in non_pks]
- if db_values:
- cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % \
- (qn(self._meta.db_table),
- ','.join(['%s=%%s' % qn(f.column) for f in non_pks]),
- qn(self._meta.pk.column)),
- db_values + self._meta.pk.get_db_prep_lookup('exact', pk_val))
+ if manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by():
+ # It does already exist, so do an UPDATE.
+ if non_pks:
+ values = [(f, None, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
+ manager.filter(pk=pk_val)._update(values)
else:
record_exists = False
if not pk_set or not record_exists:
- field_names = [qn(f.column) for f in self._meta.fields if not isinstance(f, AutoField)]
- db_values = [f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True)) for f in self._meta.fields if not isinstance(f, AutoField)]
- # If the PK has been manually set, respect that.
- if pk_set:
- field_names += [f.column for f in self._meta.fields if isinstance(f, AutoField)]
- db_values += [f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True)) for f in self._meta.fields if isinstance(f, AutoField)]
- placeholders = ['%s'] * len(field_names)
- if self._meta.order_with_respect_to:
- field_names.append(qn('_order'))
- placeholders.append('%s')
- subsel = 'SELECT COUNT(*) FROM %s WHERE %s = %%s' % (
- qn(self._meta.db_table),
- qn(self._meta.order_with_respect_to.column))
- cursor.execute(subsel, (getattr(self, self._meta.order_with_respect_to.attname),))
- db_values.append(cursor.fetchone()[0])
+ if not pk_set:
+ values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields if not isinstance(f, AutoField)]
+ else:
+ values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields]
+
+ if meta.order_with_respect_to:
+ field = meta.order_with_respect_to
+ values.append((meta.get_field_by_name('_order')[0], manager.filter(**{field.name: getattr(self, field.attname)}).count()))
record_exists = False
- if db_values:
- cursor.execute("INSERT INTO %s (%s) VALUES (%s)" % \
- (qn(self._meta.db_table), ','.join(field_names),
- ','.join(placeholders)), db_values)
+
+ update_pk = bool(meta.has_auto_field and not pk_set)
+ if values:
+ # Create a new record.
+ result = manager._insert(values, return_id=update_pk)
else:
# Create a new record with defaults for everything.
- cursor.execute("INSERT INTO %s (%s) VALUES (%s)" %
- (qn(self._meta.db_table), qn(self._meta.pk.column),
- connection.ops.pk_default_value()))
- if self._meta.has_auto_field and not pk_set:
- setattr(self, self._meta.pk.attname, connection.ops.last_insert_id(cursor, self._meta.db_table, self._meta.pk.column))
- transaction.commit_unless_managed()
+ result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True)
- # Run any post-save hooks.
- dispatcher.send(signal=signals.post_save, sender=self.__class__,
- instance=self, created=(not record_exists), raw=raw)
+ if update_pk:
+ setattr(self, meta.pk.attname, result)
+ transaction.commit_unless_managed()
- save.alters_data = True
+ if signal:
+ dispatcher.send(signal=signals.post_save, sender=self.__class__,
+ instance=self, created=(not record_exists), raw=raw)
def validate(self):
"""
@@ -341,32 +429,31 @@ class Model(object):
return force_unicode(dict(field.choices).get(value, value), strings_only=True)
def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
- qn = connection.ops.quote_name
- op = is_next and '>' or '<'
- where = '(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \
- (qn(field.column), op, qn(field.column),
- qn(self._meta.db_table), qn(self._meta.pk.column), op)
+ op = is_next and 'gt' or 'lt'
+ order = not is_next and '-' or ''
param = smart_str(getattr(self, field.attname))
- q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name)
- q._where.append(where)
- q._params.extend([param, param, getattr(self, self._meta.pk.attname)])
+ q = Q(**{'%s__%s' % (field.name, op): param})
+ q = q|Q(**{field.name: param, 'pk__%s' % op: self.pk})
+ qs = self.__class__._default_manager.filter(**kwargs).filter(q).order_by('%s%s' % (order, field.name), '%spk' % order)
try:
- return q[0]
+ return qs[0]
except IndexError:
raise self.DoesNotExist, "%s matching query does not exist." % self.__class__._meta.object_name
def _get_next_or_previous_in_order(self, is_next):
- qn = connection.ops.quote_name
cachename = "__%s_order_cache" % is_next
if not hasattr(self, cachename):
+ qn = connection.ops.quote_name
op = is_next and '>' or '<'
+ order = not is_next and '-_order' or '_order'
order_field = self._meta.order_with_respect_to
+ # FIXME: When querysets support nested queries, this can be turned
+ # into a pure queryset operation.
where = ['%s %s (SELECT %s FROM %s WHERE %s=%%s)' % \
(qn('_order'), op, qn('_order'),
- qn(self._meta.db_table), qn(self._meta.pk.column)),
- '%s=%%s' % qn(order_field.column)]
- params = [self._get_pk_val(), getattr(self, order_field.attname)]
- obj = self._default_manager.order_by('_order').extra(where=where, params=params)[:1].get()
+ qn(self._meta.db_table), qn(self._meta.pk.column))]
+ params = [self.pk]
+ obj = self._default_manager.filter(**{order_field.name: getattr(self, order_field.attname)}).extra(where=where, params=params).order_by(order)[:1].get()
setattr(self, cachename, obj)
return getattr(self, cachename)
@@ -446,29 +533,20 @@ class Model(object):
# ORDERING METHODS #########################
def method_set_order(ordered_obj, self, id_list):
- qn = connection.ops.quote_name
- cursor = connection.cursor()
- # Example: "UPDATE poll_choices SET _order = %s WHERE poll_id = %s AND id = %s"
- sql = "UPDATE %s SET %s = %%s WHERE %s = %%s AND %s = %%s" % \
- (qn(ordered_obj._meta.db_table), qn('_order'),
- qn(ordered_obj._meta.order_with_respect_to.column),
- qn(ordered_obj._meta.pk.column))
rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.rel.field_name)
- cursor.executemany(sql, [(i, rel_val, j) for i, j in enumerate(id_list)])
+ order_name = ordered_obj._meta.order_with_respect_to.name
+ # FIXME: It would be nice if there was an "update many" version of update
+ # for situations like this.
+ for i, j in enumerate(id_list):
+ ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i)
transaction.commit_unless_managed()
def method_get_order(ordered_obj, self):
- qn = connection.ops.quote_name
- cursor = connection.cursor()
- # Example: "SELECT id FROM poll_choices WHERE poll_id = %s ORDER BY _order"
- sql = "SELECT %s FROM %s WHERE %s = %%s ORDER BY %s" % \
- (qn(ordered_obj._meta.pk.column),
- qn(ordered_obj._meta.db_table),
- qn(ordered_obj._meta.order_with_respect_to.column),
- qn('_order'))
rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.rel.field_name)
- cursor.execute(sql, [rel_val])
- return [r[0] for r in cursor.fetchall()]
+ order_name = ordered_obj._meta.order_with_respect_to.name
+ pk_name = ordered_obj._meta.pk.name
+ return [r[pk_name] for r in
+ ordered_obj.objects.filter(**{order_name: rel_val}).values(pk_name)]
##############################################
# HELPER FUNCTIONS (CURRIED MODEL FUNCTIONS) #
@@ -476,3 +554,20 @@ def method_get_order(ordered_obj, self):
def get_absolute_url(opts, func, self, *args, **kwargs):
return settings.ABSOLUTE_URL_OVERRIDES.get('%s.%s' % (opts.app_label, opts.module_name), func)(self, *args, **kwargs)
+
+########
+# MISC #
+########
+
+class Empty(object):
+ pass
+
+if sys.version_info < (2, 5):
+ # Prior to Python 2.5, Exception was an old-style class
+ def subclass_exception(name, parent, unused):
+ return types.ClassType(name, (parent,), {})
+
+else:
+ def subclass_exception(name, parent, module):
+ return type(name, (parent,), {'__module__': module})
+
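The rewritten _get_next_or_previous_by_FIELD above expresses the old hand-built WHERE clause purely with Q objects. A minimal sketch of the same construction, using a hypothetical Article model with a pub_date field (names are illustrative only, not part of this patch):

    from django.db.models import Q

    def next_by_pub_date(article):
        # Equivalent of "(pub_date > v) OR (pub_date = v AND pk > article.pk)",
        # ordered so that the first row returned is the next object.
        value = article.pub_date
        q = Q(pub_date__gt=value) | Q(pub_date=value, pk__gt=article.pk)
        qs = article.__class__._default_manager.filter(q).order_by('pub_date', 'pk')
        return qs[0]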
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 13a84ece5f..7778117fb3 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -1,3 +1,4 @@
+import copy
import datetime
import os
import time
@@ -75,15 +76,19 @@ class Field(object):
# database level.
empty_strings_allowed = True
- # Tracks each time a Field instance is created. Used to retain order.
+ # These track each time a Field instance is created. Used to retain order.
+ # The auto_creation_counter is used for fields that Django implicitly
+    # creates, while creation_counter is used for all user-specified fields.
creation_counter = 0
+ auto_creation_counter = -1
def __init__(self, verbose_name=None, name=None, primary_key=False,
- max_length=None, unique=False, blank=False, null=False, db_index=False,
- core=False, rel=None, default=NOT_PROVIDED, editable=True, serialize=True,
- prepopulate_from=None, unique_for_date=None, unique_for_month=None,
- unique_for_year=None, validator_list=None, choices=None, radio_admin=None,
- help_text='', db_column=None, db_tablespace=None):
+ max_length=None, unique=False, blank=False, null=False,
+ db_index=False, core=False, rel=None, default=NOT_PROVIDED,
+ editable=True, serialize=True, prepopulate_from=None,
+ unique_for_date=None, unique_for_month=None, unique_for_year=None,
+ validator_list=None, choices=None, radio_admin=None, help_text='',
+ db_column=None, db_tablespace=None, auto_created=False):
self.name = name
self.verbose_name = verbose_name
self.primary_key = primary_key
@@ -109,14 +114,27 @@ class Field(object):
# Set db_index to True if the field has a relationship and doesn't explicitly set db_index.
self.db_index = db_index
- # Increase the creation counter, and save our local copy.
- self.creation_counter = Field.creation_counter
- Field.creation_counter += 1
+ # Adjust the appropriate creation counter, and save our local copy.
+ if auto_created:
+ self.creation_counter = Field.auto_creation_counter
+ Field.auto_creation_counter -= 1
+ else:
+ self.creation_counter = Field.creation_counter
+ Field.creation_counter += 1
def __cmp__(self, other):
# This is needed because bisect does not take a comparison function.
return cmp(self.creation_counter, other.creation_counter)
+ def __deepcopy__(self, memodict):
+ # We don't have to deepcopy very much here, since most things are not
+ # intended to be altered after initial creation.
+ obj = copy.copy(self)
+ if self.rel:
+ obj.rel = copy.copy(self.rel)
+ memodict[id(self)] = obj
+ return obj
+
def to_python(self, value):
"""
Converts the input value into the expected Python data type, raising
@@ -145,11 +163,10 @@ class Field(object):
# mapped to one of the built-in Django field types. In this case, you
# can implement db_type() instead of get_internal_type() to specify
# exactly which wacky database column type you want to use.
- data_types = get_creation_module().DATA_TYPES
- internal_type = self.get_internal_type()
- if internal_type not in data_types:
+ try:
+ return get_creation_module().DATA_TYPES[self.get_internal_type()] % self.__dict__
+ except KeyError:
return None
- return data_types[internal_type] % self.__dict__
def validate_full(self, field_data, all_data):
"""
diff --git a/django/db/models/fields/proxy.py b/django/db/models/fields/proxy.py
new file mode 100644
index 0000000000..31a31e3c3c
--- /dev/null
+++ b/django/db/models/fields/proxy.py
@@ -0,0 +1,16 @@
+"""
+Field-like classes that aren't really fields. It's easier to use objects that
+have the same attributes as fields sometimes (avoids a lot of special casing).
+"""
+
+from django.db.models import fields
+
+class OrderWrt(fields.IntegerField):
+ """
+ A proxy for the _order database field that is used when
+ Meta.order_with_respect_to is specified.
+ """
+ name = '_order'
+ attname = '_order'
+ column = '_order'
+
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index eceb49378f..f9b913ae50 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -2,6 +2,7 @@ from django.db import connection, transaction
from django.db.models import signals, get_model
from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, get_ul_class
from django.db.models.related import RelatedObject
+from django.db.models.query_utils import QueryWrapper
from django.utils.text import capfirst
from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _
from django.utils.functional import curry
@@ -27,21 +28,21 @@ def add_lazy_relation(cls, field, relation):
"""
Adds a lookup on ``cls`` when a related field is defined using a string,
i.e.::
-
+
class MyModel(Model):
fk = ForeignKey("AnotherModel")
-
+
This string can be:
-
+
* RECURSIVE_RELATIONSHIP_CONSTANT (i.e. "self") to indicate a recursive
relation.
-
+
* The name of a model (i.e "AnotherModel") to indicate another model in
the same app.
-
+
* An app-label and model name (i.e. "someapp.AnotherModel") to indicate
another model in a different app.
-
+
If the other model hasn't yet been loaded -- almost a given if you're using
lazy relationships -- then the relation won't be set up until the
class_prepared signal fires at the end of model initialization.
@@ -50,7 +51,7 @@ def add_lazy_relation(cls, field, relation):
if relation == RECURSIVE_RELATIONSHIP_CONSTANT:
app_label = cls._meta.app_label
model_name = cls.__name__
-
+
else:
# Look for an "app.Model" relation
try:
@@ -59,10 +60,10 @@ def add_lazy_relation(cls, field, relation):
# If we can't split, assume a model in current app
app_label = cls._meta.app_label
model_name = relation
-
+
# Try to look up the related model, and if it's already loaded resolve the
# string right away. If get_model returns None, it means that the related
- # model isn't loaded yet, so we need to pend the relation until the class
+ # model isn't loaded yet, so we need to pend the relation until the class
# is prepared.
model = get_model(app_label, model_name, False)
if model:
@@ -72,7 +73,7 @@ def add_lazy_relation(cls, field, relation):
key = (app_label, model_name)
value = (cls, field)
pending_lookups.setdefault(key, []).append(value)
-
+
def do_pending_lookups(sender):
"""
Handle any pending relations to the sending model. Sent from class_prepared.
@@ -107,6 +108,8 @@ class RelatedField(object):
add_lazy_relation(cls, self, other)
else:
self.do_related_class(other, cls)
+ if not cls._meta.abstract and self.rel.related_name:
+ self.rel.related_name = self.rel.related_name % {'class': cls.__name__.lower()}
def set_attributes_from_rel(self):
self.name = self.name or (self.rel.to._meta.object_name.lower() + '_' + self.rel.to._meta.pk.name)
@@ -136,6 +139,9 @@ class RelatedField(object):
pass
return v
+ if hasattr(value, 'as_sql'):
+ sql, params = value.as_sql()
+ return QueryWrapper(('(%s)' % sql), params)
if lookup_type == 'exact':
return [pk_trace(value)]
if lookup_type == 'in':
@@ -145,9 +151,10 @@ class RelatedField(object):
raise TypeError, "Related Field has invalid lookup: %s" % lookup_type
def _get_related_query_name(self, opts):
- # This method defines the name that can be used to identify this related object
- # in a table-spanning query. It uses the lower-cased object_name by default,
- # but this can be overridden with the "related_name" option.
+ # This method defines the name that can be used to identify this
+ # related object in a table-spanning query. It uses the lower-cased
+ # object_name by default, but this can be overridden with the
+ # "related_name" option.
return self.rel.related_name or opts.object_name.lower()
class SingleRelatedObjectDescriptor(object):
@@ -158,14 +165,19 @@ class SingleRelatedObjectDescriptor(object):
# SingleRelatedObjectDescriptor instance.
def __init__(self, related):
self.related = related
+ self.cache_name = '_%s_cache' % related.field.name
def __get__(self, instance, instance_type=None):
if instance is None:
raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name
- params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
- rel_obj = self.related.model._default_manager.get(**params)
- return rel_obj
+ try:
+ return getattr(instance, self.cache_name)
+ except AttributeError:
+ params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
+ rel_obj = self.related.model._default_manager.get(**params)
+ setattr(instance, self.cache_name, rel_obj)
+ return rel_obj
def __set__(self, instance, value):
if instance is None:
@@ -495,13 +507,77 @@ class ReverseManyRelatedObjectsDescriptor(object):
manager.clear()
manager.add(*value)
+class ManyToOneRel(object):
+ def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None,
+ max_num_in_admin=None, num_extra_on_change=1, edit_inline=False,
+ related_name=None, limit_choices_to=None, lookup_overrides=None,
+ raw_id_admin=False, parent_link=False):
+ try:
+ to._meta
+ except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
+ assert isinstance(to, basestring), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT
+ self.to, self.field_name = to, field_name
+ self.num_in_admin, self.edit_inline = num_in_admin, edit_inline
+ self.min_num_in_admin, self.max_num_in_admin = min_num_in_admin, max_num_in_admin
+ self.num_extra_on_change, self.related_name = num_extra_on_change, related_name
+ if limit_choices_to is None:
+ limit_choices_to = {}
+ self.limit_choices_to = limit_choices_to
+ self.lookup_overrides = lookup_overrides or {}
+ self.raw_id_admin = raw_id_admin
+ self.multiple = True
+ self.parent_link = parent_link
+
+ def get_related_field(self):
+ """
+ Returns the Field in the 'to' object to which this relationship is
+ tied.
+ """
+ data = self.to._meta.get_field_by_name(self.field_name)
+ if not data[2]:
+ raise FieldDoesNotExist("No related field named '%s'" %
+ self.field_name)
+ return data[0]
+
+class OneToOneRel(ManyToOneRel):
+ def __init__(self, to, field_name, num_in_admin=0, min_num_in_admin=None,
+ max_num_in_admin=None, num_extra_on_change=None, edit_inline=False,
+ related_name=None, limit_choices_to=None, lookup_overrides=None,
+ raw_id_admin=False, parent_link=False):
+ # NOTE: *_num_in_admin and num_extra_on_change are intentionally
+ # ignored here. We accept them as parameters only to match the calling
+ # signature of ManyToOneRel.__init__().
+ super(OneToOneRel, self).__init__(to, field_name, num_in_admin,
+ edit_inline=edit_inline, related_name=related_name,
+ limit_choices_to=limit_choices_to,
+ lookup_overrides=lookup_overrides, raw_id_admin=raw_id_admin,
+ parent_link=parent_link)
+ self.multiple = False
+
+class ManyToManyRel(object):
+ def __init__(self, to, num_in_admin=0, related_name=None,
+ filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
+ self.to = to
+ self.num_in_admin = num_in_admin
+ self.related_name = related_name
+ self.filter_interface = filter_interface
+ if limit_choices_to is None:
+ limit_choices_to = {}
+ self.limit_choices_to = limit_choices_to
+ self.edit_inline = False
+ self.raw_id_admin = raw_id_admin
+ self.symmetrical = symmetrical
+ self.multiple = True
+
+ assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
+
class ForeignKey(RelatedField, Field):
empty_strings_allowed = False
- def __init__(self, to, to_field=None, **kwargs):
+ def __init__(self, to, to_field=None, rel_class=ManyToOneRel, **kwargs):
try:
to_name = to._meta.object_name.lower()
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
- assert isinstance(to, basestring), "ForeignKey(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (to, RECURSIVE_RELATIONSHIP_CONSTANT)
+ assert isinstance(to, basestring), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT)
else:
to_field = to_field or to._meta.pk.name
kwargs['verbose_name'] = kwargs.get('verbose_name', '')
@@ -511,7 +587,7 @@ class ForeignKey(RelatedField, Field):
warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.", DeprecationWarning)
kwargs['edit_inline'] = kwargs.pop('edit_inline_type')
- kwargs['rel'] = ManyToOneRel(to, to_field,
+ kwargs['rel'] = rel_class(to, to_field,
num_in_admin=kwargs.pop('num_in_admin', 3),
min_num_in_admin=kwargs.pop('min_num_in_admin', None),
max_num_in_admin=kwargs.pop('max_num_in_admin', None),
@@ -520,7 +596,8 @@ class ForeignKey(RelatedField, Field):
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
lookup_overrides=kwargs.pop('lookup_overrides', None),
- raw_id_admin=kwargs.pop('raw_id_admin', False))
+ raw_id_admin=kwargs.pop('raw_id_admin', False),
+ parent_link=kwargs.pop('parent_link', False))
Field.__init__(self, **kwargs)
self.db_index = True
@@ -606,82 +683,25 @@ class ForeignKey(RelatedField, Field):
return IntegerField().db_type()
return rel_field.db_type()
-class OneToOneField(RelatedField, IntegerField):
+class OneToOneField(ForeignKey):
+ """
+    A OneToOneField is essentially the same as a ForeignKey, with the exception
+    that it always carries a "unique" constraint with it and the reverse
+    relation always returns the object pointed to (since there will only ever
+    be one), rather than returning a list.
+ """
def __init__(self, to, to_field=None, **kwargs):
- try:
- to_name = to._meta.object_name.lower()
- except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
- assert isinstance(to, basestring), "OneToOneField(%r) is invalid. First parameter to OneToOneField must be either a model, a model name, or the string %r" % (to, RECURSIVE_RELATIONSHIP_CONSTANT)
- else:
- to_field = to_field or to._meta.pk.name
- kwargs['verbose_name'] = kwargs.get('verbose_name', '')
-
- if 'edit_inline_type' in kwargs:
- import warnings
- warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.", DeprecationWarning)
- kwargs['edit_inline'] = kwargs.pop('edit_inline_type')
-
- kwargs['rel'] = OneToOneRel(to, to_field,
- num_in_admin=kwargs.pop('num_in_admin', 0),
- edit_inline=kwargs.pop('edit_inline', False),
- related_name=kwargs.pop('related_name', None),
- limit_choices_to=kwargs.pop('limit_choices_to', None),
- lookup_overrides=kwargs.pop('lookup_overrides', None),
- raw_id_admin=kwargs.pop('raw_id_admin', False))
- kwargs['primary_key'] = True
- IntegerField.__init__(self, **kwargs)
-
- self.db_index = True
-
- def get_attname(self):
- return '%s_id' % self.name
-
- def get_validator_unique_lookup_type(self):
- return '%s__%s__exact' % (self.name, self.rel.get_related_field().name)
-
- # TODO: Copied from ForeignKey... putting this in RelatedField adversely affects
- # ManyToManyField. This works for now.
- def prepare_field_objs_and_params(self, manipulator, name_prefix):
- params = {'validator_list': self.validator_list[:], 'member_name': name_prefix + self.attname}
- if self.rel.raw_id_admin:
- field_objs = self.get_manipulator_field_objs()
- params['validator_list'].append(curry(manipulator_valid_rel_key, self, manipulator))
- else:
- if self.radio_admin:
- field_objs = [oldforms.RadioSelectField]
- params['ul_class'] = get_ul_class(self.radio_admin)
- else:
- if self.null:
- field_objs = [oldforms.NullSelectField]
- else:
- field_objs = [oldforms.SelectField]
- params['choices'] = self.get_choices_default()
- return field_objs, params
-
- def contribute_to_class(self, cls, name):
- super(OneToOneField, self).contribute_to_class(cls, name)
- setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
+ kwargs['unique'] = True
+ if 'num_in_admin' not in kwargs:
+ kwargs['num_in_admin'] = 0
+ super(OneToOneField, self).__init__(to, to_field, OneToOneRel, **kwargs)
def contribute_to_related_class(self, cls, related):
- setattr(cls, related.get_accessor_name(), SingleRelatedObjectDescriptor(related))
+ setattr(cls, related.get_accessor_name(),
+ SingleRelatedObjectDescriptor(related))
if not cls._meta.one_to_one_field:
cls._meta.one_to_one_field = self
- def formfield(self, **kwargs):
- defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()}
- defaults.update(kwargs)
- return super(OneToOneField, self).formfield(**defaults)
-
- def db_type(self):
- # The database column type of a OneToOneField is the column type
- # of the field to which it points. An exception is if the OneToOneField
- # points to an AutoField/PositiveIntegerField/PositiveSmallIntegerField,
- # in which case the column type is simply that of an IntegerField.
- rel_field = self.rel.get_related_field()
- if isinstance(rel_field, (AutoField, PositiveIntegerField, PositiveSmallIntegerField)):
- return IntegerField().db_type()
- return rel_field.db_type()
-
class ManyToManyField(RelatedField, Field):
def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
@@ -798,7 +818,7 @@ class ManyToManyField(RelatedField, Field):
def save_form_data(self, instance, data):
setattr(instance, self.attname, data)
-
+
def formfield(self, **kwargs):
defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.all()}
defaults.update(kwargs)
@@ -813,56 +833,3 @@ class ManyToManyField(RelatedField, Field):
# so return None.
return None
-class ManyToOneRel(object):
- def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None,
- max_num_in_admin=None, num_extra_on_change=1, edit_inline=False,
- related_name=None, limit_choices_to=None, lookup_overrides=None, raw_id_admin=False):
- try:
- to._meta
- except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
- assert isinstance(to, basestring), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT
- self.to, self.field_name = to, field_name
- self.num_in_admin, self.edit_inline = num_in_admin, edit_inline
- self.min_num_in_admin, self.max_num_in_admin = min_num_in_admin, max_num_in_admin
- self.num_extra_on_change, self.related_name = num_extra_on_change, related_name
- if limit_choices_to is None:
- limit_choices_to = {}
- self.limit_choices_to = limit_choices_to
- self.lookup_overrides = lookup_overrides or {}
- self.raw_id_admin = raw_id_admin
- self.multiple = True
-
- def get_related_field(self):
- "Returns the Field in the 'to' object to which this relationship is tied."
- return self.to._meta.get_field(self.field_name)
-
-class OneToOneRel(ManyToOneRel):
- def __init__(self, to, field_name, num_in_admin=0, edit_inline=False,
- related_name=None, limit_choices_to=None, lookup_overrides=None,
- raw_id_admin=False):
- self.to, self.field_name = to, field_name
- self.num_in_admin, self.edit_inline = num_in_admin, edit_inline
- self.related_name = related_name
- if limit_choices_to is None:
- limit_choices_to = {}
- self.limit_choices_to = limit_choices_to
- self.lookup_overrides = lookup_overrides or {}
- self.raw_id_admin = raw_id_admin
- self.multiple = False
-
-class ManyToManyRel(object):
- def __init__(self, to, num_in_admin=0, related_name=None,
- filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
- self.to = to
- self.num_in_admin = num_in_admin
- self.related_name = related_name
- self.filter_interface = filter_interface
- if limit_choices_to is None:
- limit_choices_to = {}
- self.limit_choices_to = limit_choices_to
- self.edit_inline = False
- self.raw_id_admin = raw_id_admin
- self.symmetrical = symmetrical
- self.multiple = True
-
- assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
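One consequence of the new contribute_to_class logic above is that related_name on a field declared in an abstract base may contain '%(class)s', which is substituted with each concrete subclass's lowercased name. A hypothetical example (models invented for illustration; assumes django.contrib.auth is installed):

    from django.db import models

    class Authored(models.Model):
        author = models.ForeignKey('auth.User', related_name='%(class)s_set')

        class Meta:
            abstract = True

    class Article(Authored):   # reverse accessor: user.article_set
        pass

    class Comment(Authored):   # reverse accessor: user.comment_set
        pass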
diff --git a/django/db/models/manager.py b/django/db/models/manager.py
index 7b2e916738..3a9da34a49 100644
--- a/django/db/models/manager.py
+++ b/django/db/models/manager.py
@@ -1,11 +1,13 @@
-from django.db.models.query import QuerySet, EmptyQuerySet
+import copy
+
+from django.db.models.query import QuerySet, EmptyQuerySet, insert_query
from django.dispatch import dispatcher
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
def ensure_default_manager(sender):
cls = sender
- if not hasattr(cls, '_default_manager'):
+ if not getattr(cls, '_default_manager', None) and not cls._meta.abstract:
# Create the default manager, if needed.
try:
cls._meta.get_field('objects')
@@ -31,13 +33,24 @@ class Manager(object):
# TODO: Use weakref because of possible memory leak / circular reference.
self.model = model
setattr(model, name, ManagerDescriptor(self))
- if not hasattr(model, '_default_manager') or self.creation_counter < model._default_manager.creation_counter:
+ if not getattr(model, '_default_manager', None) or self.creation_counter < model._default_manager.creation_counter:
model._default_manager = self
+ def _copy_to_model(self, model):
+ """
+ Makes a copy of the manager and assigns it to 'model', which should be
+ a child of the existing model (used when inheriting a manager from an
+ abstract base class).
+ """
+ assert issubclass(model, self.model)
+ mgr = copy.copy(self)
+ mgr.model = model
+ return mgr
+
#######################
# PROXIES TO QUERYSET #
#######################
-
+
def get_empty_query_set(self):
return EmptyQuerySet(self.model)
@@ -46,7 +59,7 @@ class Manager(object):
to easily customize the behavior of the Manager.
"""
return QuerySet(self.model)
-
+
def none(self):
return self.get_empty_query_set()
@@ -70,7 +83,7 @@ class Manager(object):
def get_or_create(self, **kwargs):
return self.get_query_set().get_or_create(**kwargs)
-
+
def create(self, **kwargs):
return self.get_query_set().create(**kwargs)
@@ -101,6 +114,21 @@ class Manager(object):
def values(self, *args, **kwargs):
return self.get_query_set().values(*args, **kwargs)
+ def values_list(self, *args, **kwargs):
+ return self.get_query_set().values_list(*args, **kwargs)
+
+ def update(self, *args, **kwargs):
+ return self.get_query_set().update(*args, **kwargs)
+
+ def reverse(self, *args, **kwargs):
+ return self.get_query_set().reverse(*args, **kwargs)
+
+ def _insert(self, values, **kwargs):
+ return insert_query(self.model, values, **kwargs)
+
+ def _update(self, values, **kwargs):
+ return self.get_query_set()._update(values, **kwargs)
+
class ManagerDescriptor(object):
# This class ensures managers aren't accessible via model instances.
# For example, Poll.objects works, but poll_obj.objects raises AttributeError.
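The manager additions above are thin proxies onto the queryset, so the new bulk operations become available directly from Model.objects. A hypothetical usage sketch (the Entry model is invented for illustration):

    from django.db import models

    class Entry(models.Model):
        headline = models.CharField(max_length=100)
        status = models.CharField(max_length=20)

    # Flat list of a single column, instead of a list of dictionaries.
    headlines = Entry.objects.values_list('headline', flat=True)

    # Bulk UPDATE without loading and re-saving each instance.
    Entry.objects.filter(status='draft').update(status='published')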
diff --git a/django/db/models/options.py b/django/db/models/options.py
index 37ace0a7c1..8fcaed485e 100644
--- a/django/db/models/options.py
+++ b/django/db/models/options.py
@@ -1,25 +1,32 @@
+import re
+from bisect import bisect
+try:
+ set
+except NameError:
+ from sets import Set as set # Python 2.3 fallback
+
from django.conf import settings
from django.db.models.related import RelatedObject
from django.db.models.fields.related import ManyToManyRel
from django.db.models.fields import AutoField, FieldDoesNotExist
+from django.db.models.fields.proxy import OrderWrt
from django.db.models.loading import get_models, app_cache_ready
-from django.db.models.query import orderlist2sql
from django.db.models import Manager
from django.utils.translation import activate, deactivate_all, get_language, string_concat
from django.utils.encoding import force_unicode, smart_str
-from bisect import bisect
-import re
+from django.utils.datastructures import SortedDict
# Calculate the verbose_name by converting from InitialCaps to "lowercase with spaces".
get_verbose_name = lambda class_name: re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))', ' \\1', class_name).lower().strip()
DEFAULT_NAMES = ('verbose_name', 'db_table', 'ordering',
'unique_together', 'permissions', 'get_latest_by',
- 'order_with_respect_to', 'app_label', 'db_tablespace')
+ 'order_with_respect_to', 'app_label', 'db_tablespace',
+ 'abstract')
class Options(object):
def __init__(self, meta):
- self.fields, self.many_to_many = [], []
+ self.local_fields, self.local_many_to_many = [], []
self.module_name, self.verbose_name = None, None
self.verbose_name_plural = None
self.db_table = ''
@@ -35,7 +42,8 @@ class Options(object):
self.pk = None
self.has_auto_field, self.auto_field = False, None
self.one_to_one_field = None
- self.parents = []
+ self.abstract = False
+ self.parents = SortedDict()
def contribute_to_class(self, cls, name):
cls._meta = self
@@ -47,11 +55,14 @@ class Options(object):
# Next, apply any overridden values from 'class Meta'.
if self.meta:
- meta_attrs = self.meta.__dict__
+ meta_attrs = self.meta.__dict__.copy()
del meta_attrs['__module__']
del meta_attrs['__doc__']
for attr_name in DEFAULT_NAMES:
- setattr(self, attr_name, meta_attrs.pop(attr_name, getattr(self, attr_name)))
+ if attr_name in meta_attrs:
+ setattr(self, attr_name, meta_attrs.pop(attr_name))
+ elif hasattr(self.meta, attr_name):
+ setattr(self, attr_name, getattr(self.meta, attr_name))
# unique_together can be either a tuple of tuples, or a single
# tuple of two strings. Normalize it to a tuple of tuples, so that
@@ -82,9 +93,16 @@ class Options(object):
self.order_with_respect_to = None
if self.pk is None:
- auto = AutoField(verbose_name='ID', primary_key=True)
- auto.creation_counter = -1
- model.add_to_class('id', auto)
+ if self.parents:
+ # Promote the first parent link in lieu of adding yet another
+ # field.
+ field = self.parents.value_for_index(0)
+ field.primary_key = True
+ self.pk = field
+ else:
+ auto = AutoField(verbose_name='ID', primary_key=True,
+ auto_created=True)
+ model.add_to_class('id', auto)
# If the db_table wasn't provided, use the app_label + module_name.
if not self.db_table:
@@ -94,14 +112,26 @@ class Options(object):
def add_field(self, field):
# Insert the given field in the order in which it was created, using
# the "creation_counter" attribute of the field.
- # Move many-to-many related fields from self.fields into self.many_to_many.
+ # Move many-to-many related fields from self.fields into
+ # self.many_to_many.
if field.rel and isinstance(field.rel, ManyToManyRel):
- self.many_to_many.insert(bisect(self.many_to_many, field), field)
+ self.local_many_to_many.insert(bisect(self.local_many_to_many, field), field)
+ if hasattr(self, '_m2m_cache'):
+ del self._m2m_cache
else:
- self.fields.insert(bisect(self.fields, field), field)
- if not self.pk and field.primary_key:
- self.pk = field
- field.serialize = False
+ self.local_fields.insert(bisect(self.local_fields, field), field)
+ self.setup_pk(field)
+ if hasattr(self, '_field_cache'):
+ del self._field_cache
+ del self._field_name_cache
+
+ if hasattr(self, '_name_map'):
+ del self._name_map
+
+ def setup_pk(self, field):
+ if not self.pk and field.primary_key:
+ self.pk = field
+ field.serialize = False
def __repr__(self):
return '<Options for %s>' % self.object_name
@@ -122,19 +152,137 @@ class Options(object):
return raw
verbose_name_raw = property(verbose_name_raw)
+ def _fields(self):
+ """
+ The getter for self.fields. This returns the list of field objects
+ available to this model (including through parent models).
+
+ Callers are not permitted to modify this list, since it's a reference
+ to this instance (not a copy).
+ """
+ try:
+ self._field_name_cache
+ except AttributeError:
+ self._fill_fields_cache()
+ return self._field_name_cache
+ fields = property(_fields)
+
+ def get_fields_with_model(self):
+ """
+ Returns a sequence of (field, model) pairs for all fields. The "model"
+ element is None for fields on the current model. Mostly of use when
+ constructing queries so that we know which model a field belongs to.
+ """
+ try:
+ self._field_cache
+ except AttributeError:
+ self._fill_fields_cache()
+ return self._field_cache
+
+ def _fill_fields_cache(self):
+ cache = []
+ for parent in self.parents:
+ for field, model in parent._meta.get_fields_with_model():
+ if model:
+ cache.append((field, model))
+ else:
+ cache.append((field, parent))
+ cache.extend([(f, None) for f in self.local_fields])
+ self._field_cache = tuple(cache)
+ self._field_name_cache = [x for x, _ in cache]
+
+ def _many_to_many(self):
+ try:
+ self._m2m_cache
+ except AttributeError:
+ self._fill_m2m_cache()
+ return self._m2m_cache.keys()
+ many_to_many = property(_many_to_many)
+
+ def get_m2m_with_model(self):
+ """
+ The many-to-many version of get_fields_with_model().
+ """
+ try:
+ self._m2m_cache
+ except AttributeError:
+ self._fill_m2m_cache()
+ return self._m2m_cache.items()
+
+ def _fill_m2m_cache(self):
+ cache = SortedDict()
+ for parent in self.parents:
+ for field, model in parent._meta.get_m2m_with_model():
+ if model:
+ cache[field] = model
+ else:
+ cache[field] = parent
+ for field in self.local_many_to_many:
+ cache[field] = None
+ self._m2m_cache = cache
+
def get_field(self, name, many_to_many=True):
- "Returns the requested field by name. Raises FieldDoesNotExist on error."
+ """
+ Returns the requested field by name. Raises FieldDoesNotExist on error.
+ """
to_search = many_to_many and (self.fields + self.many_to_many) or self.fields
for f in to_search:
if f.name == name:
return f
raise FieldDoesNotExist, '%s has no field named %r' % (self.object_name, name)
- def get_order_sql(self, table_prefix=''):
- "Returns the full 'ORDER BY' clause for this object, according to self.ordering."
- if not self.ordering: return ''
- pre = table_prefix and (table_prefix + '.') or ''
- return 'ORDER BY ' + orderlist2sql(self.ordering, self, pre)
+ def get_field_by_name(self, name):
+ """
+ Returns the (field_object, model, direct, m2m), where field_object is
+ the Field instance for the given name, model is the model containing
+ this field (None for local fields), direct is True if the field exists
+ on this model, and m2m is True for many-to-many relations. When
+ 'direct' is False, 'field_object' is the corresponding RelatedObject
+ for this field (since the field doesn't have an instance associated
+ with it).
+
+ Uses a cache internally, so after the first access, this is very fast.
+ """
+ try:
+ try:
+ return self._name_map[name]
+ except AttributeError:
+ cache = self.init_name_map()
+ return self._name_map[name]
+ except KeyError:
+ raise FieldDoesNotExist('%s has no field named %r'
+ % (self.object_name, name))
+
+ def get_all_field_names(self):
+ """
+ Returns a list of all field names that are possible for this model
+ (including reverse relation names).
+ """
+ try:
+ cache = self._name_map
+ except AttributeError:
+ cache = self.init_name_map()
+ names = cache.keys()
+ names.sort()
+ return names
+
+ def init_name_map(self):
+ """
+ Initialises the field name -> field object mapping.
+ """
+ cache = dict([(f.name, (f, m, True, False)) for f, m in
+ self.get_fields_with_model()])
+ for f, model in self.get_m2m_with_model():
+ cache[f.name] = (f, model, True, True)
+ for f, model in self.get_all_related_m2m_objects_with_model():
+ cache[f.field.related_query_name()] = (f, model, False, True)
+ for f, model in self.get_all_related_objects_with_model():
+ cache[f.field.related_query_name()] = (f, model, False, False)
+ if self.order_with_respect_to:
+ cache['_order'] = OrderWrt(), None, True, False
+ if app_cache_ready():
+ self._name_map = cache
+ return cache
def get_add_permission(self):
return 'add_%s' % self.object_name.lower()
@@ -145,17 +293,81 @@ class Options(object):
def get_delete_permission(self):
return 'delete_%s' % self.object_name.lower()
- def get_all_related_objects(self):
- try: # Try the cache first.
- return self._all_related_objects
+ def get_all_related_objects(self, local_only=False):
+ try:
+ self._related_objects_cache
+ except AttributeError:
+ self._fill_related_objects_cache()
+ if local_only:
+ return [k for k, v in self._related_objects_cache.items() if not v]
+ return self._related_objects_cache.keys()
+
+ def get_all_related_objects_with_model(self):
+ """
+ Returns a list of (related-object, model) pairs. Similar to
+ get_fields_with_model().
+ """
+ try:
+ self._related_objects_cache
except AttributeError:
- rel_objs = []
- for klass in get_models():
- for f in klass._meta.fields:
- if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta:
- rel_objs.append(RelatedObject(f.rel.to, klass, f))
- self._all_related_objects = rel_objs
- return rel_objs
+ self._fill_related_objects_cache()
+ return self._related_objects_cache.items()
+
+ def _fill_related_objects_cache(self):
+ cache = SortedDict()
+ parent_list = self.get_parent_list()
+ for parent in self.parents:
+ for obj, model in parent._meta.get_all_related_objects_with_model():
+ if (obj.field.creation_counter < 0 or obj.field.rel.parent_link) and obj.model not in parent_list:
+ continue
+ if not model:
+ cache[obj] = parent
+ else:
+ cache[obj] = model
+ for klass in get_models():
+ for f in klass._meta.local_fields:
+ if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta:
+ cache[RelatedObject(f.rel.to, klass, f)] = None
+ self._related_objects_cache = cache
+
+ def get_all_related_many_to_many_objects(self, local_only=False):
+ try:
+ cache = self._related_many_to_many_cache
+ except AttributeError:
+ cache = self._fill_related_many_to_many_cache()
+ if local_only:
+ return [k for k, v in cache.items() if not v]
+ return cache.keys()
+
+ def get_all_related_m2m_objects_with_model(self):
+ """
+ Returns a list of (related-m2m-object, model) pairs. Similar to
+ get_fields_with_model().
+ """
+ try:
+ cache = self._related_many_to_many_cache
+ except AttributeError:
+ cache = self._fill_related_many_to_many_cache()
+ return cache.items()
+
+ def _fill_related_many_to_many_cache(self):
+ cache = SortedDict()
+ parent_list = self.get_parent_list()
+ for parent in self.parents:
+ for obj, model in parent._meta.get_all_related_m2m_objects_with_model():
+ if obj.field.creation_counter < 0 and obj.model not in parent_list:
+ continue
+ if not model:
+ cache[obj] = parent
+ else:
+ cache[obj] = model
+ for klass in get_models():
+ for f in klass._meta.local_many_to_many:
+ if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta:
+ cache[RelatedObject(f.rel.to, klass, f)] = None
+ if app_cache_ready():
+ self._related_many_to_many_cache = cache
+ return cache
def get_followed_related_objects(self, follow=None):
if follow == None:
@@ -179,18 +391,34 @@ class Options(object):
follow[f.name] = fol
return follow
- def get_all_related_many_to_many_objects(self):
- try: # Try the cache first.
- return self._all_related_many_to_many_objects
- except AttributeError:
- rel_objs = []
- for klass in get_models():
- for f in klass._meta.many_to_many:
- if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta:
- rel_objs.append(RelatedObject(f.rel.to, klass, f))
- if app_cache_ready():
- self._all_related_many_to_many_objects = rel_objs
- return rel_objs
+ def get_base_chain(self, model):
+ """
+        Returns a list of parent classes leading to 'model' (order from closest
+        to most distant ancestor). This has to handle the case where 'model' is
+        a grandparent or even more distant relation.
+ """
+ if not self.parents:
+ return
+ if model in self.parents:
+ return [model]
+ for parent in self.parents:
+ res = parent._meta.get_base_chain(model)
+ if res:
+ res.insert(0, parent)
+ return res
+ raise TypeError('%r is not an ancestor of this model'
+ % model._meta.module_name)
+
+ def get_parent_list(self):
+ """
+        Returns a list of all the ancestors of this model. Useful for
+        determining if something is an ancestor, regardless of lineage.
+ """
+ result = set()
+ for parent in self.parents:
+ result.add(parent)
+ result.update(parent._meta.get_parent_list())
+ return result
def get_ordered_objects(self):
"Returns a list of Options objects that are ordered with respect to this object."
diff --git a/django/db/models/query.py b/django/db/models/query.py
index a3d00c2ead..e1e2bb19f2 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1,173 +1,134 @@
+import warnings
+
from django.conf import settings
from django.db import connection, transaction, IntegrityError
from django.db.models.fields import DateField, FieldDoesNotExist
-from django.db.models import signals, loading
+from django.db.models.query_utils import Q
+from django.db.models import signals, sql
from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
-from django.utils.encoding import smart_unicode
-from django.contrib.contenttypes import generic
-import datetime
-import operator
-import re
-
-try:
- set
-except NameError:
- from sets import Set as set # Python 2.3 fallback
-
-# The string constant used to separate query parts
-LOOKUP_SEPARATOR = '__'
-
-# The list of valid query types
-QUERY_TERMS = (
- 'exact', 'iexact', 'contains', 'icontains',
- 'gt', 'gte', 'lt', 'lte', 'in',
- 'startswith', 'istartswith', 'endswith', 'iendswith',
- 'range', 'year', 'month', 'day', 'isnull', 'search',
- 'regex', 'iregex',
-)
-
-# Size of each "chunk" for get_iterator calls.
-# Larger values are slightly faster at the expense of more storage space.
-GET_ITERATOR_CHUNK_SIZE = 100
-
-class EmptyResultSet(Exception):
- pass
-
-####################
-# HELPER FUNCTIONS #
-####################
-
-# Django currently supports two forms of ordering.
-# Form 1 (deprecated) example:
-# order_by=(('pub_date', 'DESC'), ('headline', 'ASC'), (None, 'RANDOM'))
-# Form 2 (new-style) example:
-# order_by=('-pub_date', 'headline', '?')
-# Form 1 is deprecated and will no longer be supported for Django's first
-# official release. The following code converts from Form 1 to Form 2.
-LEGACY_ORDERING_MAPPING = {'ASC': '_', 'DESC': '-_', 'RANDOM': '?'}
+# Used to control how many objects are worked with at once in some cases (e.g.
+# when deleting objects).
+CHUNK_SIZE = 100
+ITER_CHUNK_SIZE = CHUNK_SIZE
-def handle_legacy_orderlist(order_list):
- if not order_list or isinstance(order_list[0], basestring):
- return order_list
- else:
- import warnings
- new_order_list = [LEGACY_ORDERING_MAPPING[j.upper()].replace('_', smart_unicode(i)) for i, j in order_list]
- warnings.warn("%r ordering syntax is deprecated. Use %r instead." % (order_list, new_order_list), DeprecationWarning)
- return new_order_list
+# Pull into this namespace for backwards compatibility
+EmptyResultSet = sql.EmptyResultSet
-def orderfield2column(f, opts):
- try:
- return opts.get_field(f, False).column
- except FieldDoesNotExist:
- return f
-
-def orderlist2sql(order_list, opts, prefix=''):
- qn = connection.ops.quote_name
- if prefix.endswith('.'):
- prefix = qn(prefix[:-1]) + '.'
- output = []
- for f in handle_legacy_orderlist(order_list):
- if f.startswith('-'):
- output.append('%s%s DESC' % (prefix, qn(orderfield2column(f[1:], opts))))
- elif f == '?':
- output.append(connection.ops.random_function_sql())
- else:
- output.append('%s%s ASC' % (prefix, qn(orderfield2column(f, opts))))
- return ', '.join(output)
-
-def quote_only_if_word(word):
- if re.search('\W', word): # Don't quote if there are spaces or non-word chars.
- return word
- else:
- return connection.ops.quote_name(word)
-
-class _QuerySet(object):
+class QuerySet(object):
"Represents a lazy database lookup for a set of objects"
- def __init__(self, model=None):
+ def __init__(self, model=None, query=None):
self.model = model
- self._filters = Q()
- self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
- self._select_related = False # Whether to fill cache for related objects.
- self._max_related_depth = 0 # Maximum "depth" for select_related
- self._distinct = False # Whether the query should use SELECT DISTINCT.
- self._select = {} # Dictionary of attname -> SQL.
- self._where = [] # List of extra WHERE clauses to use.
- self._params = [] # List of params to use for extra WHERE clauses.
- self._tables = [] # List of extra tables to use.
- self._offset = None # OFFSET clause.
- self._limit = None # LIMIT clause.
+ self.query = query or sql.Query(self.model, connection)
self._result_cache = None
+ self._iter = None
########################
# PYTHON MAGIC METHODS #
########################
def __repr__(self):
- return repr(self._get_data())
+ return repr(list(self))
def __len__(self):
- return len(self._get_data())
+ # Since __len__ is called quite frequently (for example, as part of
+        # list(qs)), we make some effort here to be as efficient as possible
+ # whilst not messing up any existing iterators against the queryset.
+ if self._result_cache is None:
+ if self._iter:
+ self._result_cache = list(self._iter())
+ else:
+ self._result_cache = list(self.iterator())
+ elif self._iter:
+ self._result_cache.extend(list(self._iter))
+ return len(self._result_cache)
def __iter__(self):
- return iter(self._get_data())
+ if self._result_cache is None:
+ self._iter = self.iterator()
+ self._result_cache = []
+ if self._iter:
+ return self._result_iter()
+ # Python's list iterator is better than our version when we're just
+ # iterating over the cache.
+ return iter(self._result_cache)
+
+ def _result_iter(self):
+ pos = 0
+ while 1:
+ upper = len(self._result_cache)
+ while pos < upper:
+ yield self._result_cache[pos]
+ pos = pos + 1
+ if not self._iter:
+ raise StopIteration
+ if len(self._result_cache) <= pos:
+ self._fill_cache()
+
+ def __nonzero__(self):
+ if self._result_cache is not None:
+ return bool(self._result_cache)
+ try:
+ iter(self).next()
+ except StopIteration:
+ return False
+ return True
def __getitem__(self, k):
"Retrieve an item or slice from the set of results."
if not isinstance(k, (slice, int, long)):
raise TypeError
- assert (not isinstance(k, slice) and (k >= 0)) \
- or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \
- "Negative indexing is not supported."
- if self._result_cache is None:
- if isinstance(k, slice):
- # Offset:
- if self._offset is None:
- offset = k.start
- elif k.start is None:
- offset = self._offset
- else:
- offset = self._offset + k.start
- # Now adjust offset to the bounds of any existing limit:
- if self._limit is not None and k.start is not None:
- limit = self._limit - k.start
- else:
- limit = self._limit
+ assert ((not isinstance(k, slice) and (k >= 0))
+ or (isinstance(k, slice) and (k.start is None or k.start >= 0)
+ and (k.stop is None or k.stop >= 0))), \
+ "Negative indexing is not supported."
- # Limit:
- if k.stop is not None and k.start is not None:
- if limit is None:
- limit = k.stop - k.start
+ if self._result_cache is not None:
+ if self._iter is not None:
+ # The result cache has only been partially populated, so we may
+ # need to fill it out a bit more.
+ if isinstance(k, slice):
+ if k.stop is not None:
+ # Some people insist on passing in strings here.
+ bound = int(k.stop)
else:
- limit = min((k.stop - k.start), limit)
+ bound = None
else:
- if limit is None:
- limit = k.stop
- else:
- if k.stop is not None:
- limit = min(k.stop, limit)
+ bound = k + 1
+ if len(self._result_cache) < bound:
+ self._fill_cache(bound - len(self._result_cache))
+ return self._result_cache[k]
- if k.step is None:
- return self._clone(_offset=offset, _limit=limit)
- else:
- return list(self._clone(_offset=offset, _limit=limit))[::k.step]
+ if isinstance(k, slice):
+ qs = self._clone()
+ if k.start is not None:
+ start = int(k.start)
else:
- try:
- return list(self._clone(_offset=k, _limit=1))[0]
- except self.model.DoesNotExist, e:
- raise IndexError, e.args
- else:
- return self._result_cache[k]
+ start = None
+ if k.stop is not None:
+ stop = int(k.stop)
+ else:
+ stop = None
+ qs.query.set_limits(start, stop)
+ return k.step and list(qs)[::k.step] or qs
+ try:
+ qs = self._clone()
+ qs.query.set_limits(k, k + 1)
+ return list(qs)[0]
+ except self.model.DoesNotExist, e:
+ raise IndexError, e.args
def __and__(self, other):
- combined = self._combine(other)
- combined._filters = self._filters & other._filters
+ self._merge_sanity_check(other)
+ combined = self._clone()
+ combined.query.combine(other.query, sql.AND)
return combined
def __or__(self, other):
- combined = self._combine(other)
- combined._filters = self._filters | other._filters
+ self._merge_sanity_check(other)
+ combined = self._clone()
+ combined.query.combine(other.query, sql.OR)
return combined
####################################
@@ -175,38 +136,27 @@ class _QuerySet(object):
####################################
def iterator(self):
- "Performs the SELECT database lookup of this QuerySet."
- try:
- select, sql, params = self._get_sql_clause()
- except EmptyResultSet:
- raise StopIteration
-
- # self._select is a dictionary, and dictionaries' key order is
- # undefined, so we convert it to a list of tuples.
- extra_select = self._select.items()
-
- cursor = connection.cursor()
- cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
-
- fill_cache = self._select_related
- fields = self.model._meta.fields
- index_end = len(fields)
- has_resolve_columns = hasattr(self, 'resolve_columns')
- while 1:
- rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
- if not rows:
- raise StopIteration
- for row in rows:
- if has_resolve_columns:
- row = self.resolve_columns(row, fields)
- if fill_cache:
- obj, index_end = get_cached_row(klass=self.model, row=row,
- index_start=0, max_depth=self._max_related_depth)
- else:
- obj = self.model(*row[:index_end])
- for i, k in enumerate(extra_select):
- setattr(obj, k[0], row[index_end+i])
- yield obj
+ """
+ An iterator over the results from applying this QuerySet to the
+ database.
+ """
+ fill_cache = self.query.select_related
+ if isinstance(fill_cache, dict):
+ requested = fill_cache
+ else:
+ requested = None
+ max_depth = self.query.max_depth
+ extra_select = self.query.extra_select.keys()
+ index_start = len(extra_select)
+ for row in self.query.results_iter():
+ if fill_cache:
+ obj, _ = get_cached_row(self.model, row, index_start,
+ max_depth, requested=requested)
+ else:
+ obj = self.model.from_sequence(row[index_start:])
+ for i, k in enumerate(extra_select):
+ setattr(obj, k, row[i])
+ yield obj
def count(self):
"""
@@ -220,50 +170,22 @@ class _QuerySet(object):
if self._result_cache is not None:
return len(self._result_cache)
- counter = self._clone()
- counter._order_by = ()
- counter._select_related = False
-
- offset = counter._offset
- limit = counter._limit
- counter._offset = None
- counter._limit = None
-
- try:
- select, sql, params = counter._get_sql_clause()
- except EmptyResultSet:
- return 0
-
- cursor = connection.cursor()
- if self._distinct:
- id_col = "%s.%s" % (connection.ops.quote_name(self.model._meta.db_table),
- connection.ops.quote_name(self.model._meta.pk.column))
- cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
- else:
- cursor.execute("SELECT COUNT(*)" + sql, params)
- count = cursor.fetchone()[0]
-
- # Apply any offset and limit constraints manually, since using LIMIT or
- # OFFSET in SQL doesn't change the output of COUNT.
- if offset:
- count = max(0, count - offset)
- if limit:
- count = min(limit, count)
-
- return count
+ return self.query.get_count()
def get(self, *args, **kwargs):
- "Performs the SELECT and returns a single object matching the given keyword arguments."
+ """
+ Performs the query and returns a single object matching the given
+ keyword arguments.
+ """
clone = self.filter(*args, **kwargs)
- # clean up SQL by removing unneeded ORDER BY
- if not clone._order_by:
- clone._order_by = ()
- obj_list = list(clone)
- if len(obj_list) < 1:
- raise self.model.DoesNotExist, "%s matching query does not exist." % self.model._meta.object_name
- elif len(obj_list) > 1:
- raise self.model.MultipleObjectsReturned, "get() returned more than one %s -- it returned %s! Lookup parameters were %s" % (self.model._meta.object_name, len(obj_list), kwargs)
- return obj_list[0]
+ num = len(clone)
+ if num == 1:
+ return clone._result_cache[0]
+ if not num:
+ raise self.model.DoesNotExist("%s matching query does not exist."
+ % self.model._meta.object_name)
+ raise self.model.MultipleObjectsReturned("get() returned more than one %s -- it returned %s! Lookup parameters were %s"
+ % (self.model._meta.object_name, num, kwargs))
def create(self, **kwargs):
"""
@@ -280,7 +202,8 @@ class _QuerySet(object):
Returns a tuple of (object, created), where created is a boolean
specifying whether an object was created.
"""
- assert len(kwargs), 'get_or_create() must be passed at least one keyword argument'
+ assert kwargs, \
+ 'get_or_create() must be passed at least one keyword argument'
defaults = kwargs.pop('defaults', {})
try:
return self.get(**kwargs), False
@@ -301,400 +224,384 @@ class _QuerySet(object):
"""
latest_by = field_name or self.model._meta.get_latest_by
assert bool(latest_by), "latest() requires either a field_name parameter or 'get_latest_by' in the model"
- assert self._limit is None and self._offset is None, \
+ assert self.query.can_filter(), \
"Cannot change a query once a slice has been taken."
- return self._clone(_limit=1, _order_by=('-'+latest_by,)).get()
+ obj = self._clone()
+ obj.query.set_limits(high=1)
+ obj.query.add_ordering('-%s' % latest_by)
+ return obj.get()
def in_bulk(self, id_list):
"""
Returns a dictionary mapping each of the given IDs to the object with
that ID.
"""
- assert self._limit is None and self._offset is None, \
+ assert self.query.can_filter(), \
"Cannot use 'limit' or 'offset' with in_bulk"
- assert isinstance(id_list, (tuple, list)), "in_bulk() must be provided with a list of IDs."
- qn = connection.ops.quote_name
- id_list = list(id_list)
- if id_list == []:
+ assert isinstance(id_list, (tuple, list)), \
+ "in_bulk() must be provided with a list of IDs."
+ if not id_list:
return {}
qs = self._clone()
- qs._where.append("%s.%s IN (%s)" % (qn(self.model._meta.db_table), qn(self.model._meta.pk.column), ",".join(['%s'] * len(id_list))))
- qs._params.extend(id_list)
+ qs.query.add_filter(('pk__in', id_list))
return dict([(obj._get_pk_val(), obj) for obj in qs.iterator()])
def delete(self):
"""
Deletes the records in the current QuerySet.
"""
- assert self._limit is None and self._offset is None, \
- "Cannot use 'limit' or 'offset' with delete."
+ assert self.query.can_filter(), \
+ "Cannot use 'limit' or 'offset' with delete."
del_query = self._clone()
- # disable non-supported fields
- del_query._select_related = False
- del_query._order_by = []
+ # Disable non-supported fields.
+ del_query.query.select_related = False
+ del_query.query.clear_ordering()
- # Delete objects in chunks to prevent an the list of
- # related objects from becoming too long
- more_objects = True
- while more_objects:
- # Collect all the objects to be deleted in this chunk, and all the objects
- # that are related to the objects that are to be deleted
+ # Delete objects in chunks to prevent the list of related objects from
+ # becoming too long.
+ while 1:
+ # Collect all the objects to be deleted in this chunk, and all the
+ # objects that are related to the objects that are to be deleted.
seen_objs = SortedDict()
- more_objects = False
- for object in del_query[0:GET_ITERATOR_CHUNK_SIZE]:
- more_objects = True
+ for object in del_query[:CHUNK_SIZE]:
object._collect_sub_objects(seen_objs)
- # If one or more objects were found, delete them.
- # Otherwise, stop looping.
- if more_objects:
- delete_objects(seen_objs)
+ if not seen_objs:
+ break
+ delete_objects(seen_objs)
# Clear the result cache, in case this QuerySet gets reused.
self._result_cache = None
delete.alters_data = True
+ def update(self, **kwargs):
+ """
+ Updates all elements in the current QuerySet, setting all the given
+ fields to the appropriate values.
+ """
+ query = self.query.clone(sql.UpdateQuery)
+ query.add_update_values(kwargs)
+ query.execute_sql(None)
+ self._result_cache = None
+ update.alters_data = True
+
+ def _update(self, values):
+ """
+ A version of update that accepts field objects instead of field names.
+ Used primarily for model saving and not intended for use by general
+ code (it requires too much poking around at model internals to be
+ useful at that level).
+ """
+ query = self.query.clone(sql.UpdateQuery)
+ query.add_update_fields(values)
+ query.execute_sql(None)
+ self._result_cache = None
+ _update.alters_data = True
+
##################################################
# PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
##################################################
def values(self, *fields):
- return self._clone(klass=ValuesQuerySet, _fields=fields)
+ return self._clone(klass=ValuesQuerySet, setup=True, _fields=fields)
+
+ def values_list(self, *fields, **kwargs):
+ flat = kwargs.pop('flat', False)
+ if kwargs:
+ raise TypeError('Unexpected keyword arguments to values_list: %s'
+ % (kwargs.keys(),))
+ if flat and len(fields) > 1:
+ raise TypeError("'flat' is not valid when values_list is called with more than one field.")
+ return self._clone(klass=ValuesListQuerySet, setup=True, flat=flat,
+ _fields=fields)
def dates(self, field_name, kind, order='ASC'):
"""
Returns a list of datetime objects representing all available dates
for the given field_name, scoped to 'kind'.
"""
- assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'."
- assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'."
+ assert kind in ("month", "year", "day"), \
+ "'kind' must be one of 'year', 'month' or 'day'."
+ assert order in ('ASC', 'DESC'), \
+ "'order' must be either 'ASC' or 'DESC'."
# Let the FieldDoesNotExist exception propagate.
field = self.model._meta.get_field(field_name, many_to_many=False)
- assert isinstance(field, DateField), "%r isn't a DateField." % field_name
- return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order)
+ assert isinstance(field, DateField), "%r isn't a DateField." \
+ % field_name
+ return self._clone(klass=DateQuerySet, setup=True, _field=field,
+ _kind=kind, _order=order)
+
+ def none(self):
+ """
+ Returns an empty queryset.
+ """
+ return self._clone(klass=EmptyQuerySet)
##################################################################
# PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
##################################################################
+ def all(self):
+ """
+ Returns a new QuerySet that is a copy of the current one. This allows a
+ QuerySet to proxy for a model manager in some cases.
+ """
+ return self._clone()
+
def filter(self, *args, **kwargs):
- "Returns a new QuerySet instance with the args ANDed to the existing set."
- return self._filter_or_exclude(None, *args, **kwargs)
+ """
+ Returns a new QuerySet instance with the args ANDed to the existing
+ set.
+ """
+ return self._filter_or_exclude(False, *args, **kwargs)
def exclude(self, *args, **kwargs):
- "Returns a new QuerySet instance with NOT (args) ANDed to the existing set."
- return self._filter_or_exclude(QNot, *args, **kwargs)
+ """
+ Returns a new QuerySet instance with NOT (args) ANDed to the existing
+ set.
+ """
+ return self._filter_or_exclude(True, *args, **kwargs)
- def _filter_or_exclude(self, mapper, *args, **kwargs):
- # mapper is a callable used to transform Q objects,
- # or None for identity transform
- if mapper is None:
- mapper = lambda x: x
- if len(args) > 0 or len(kwargs) > 0:
- assert self._limit is None and self._offset is None, \
- "Cannot filter a query once a slice has been taken."
+ def _filter_or_exclude(self, negate, *args, **kwargs):
+ if args or kwargs:
+ assert self.query.can_filter(), \
+ "Cannot filter a query once a slice has been taken."
clone = self._clone()
- if len(kwargs) > 0:
- clone._filters = clone._filters & mapper(Q(**kwargs))
- if len(args) > 0:
- clone._filters = clone._filters & reduce(operator.and_, map(mapper, args))
+ if negate:
+ clone.query.add_q(~Q(*args, **kwargs))
+ else:
+ clone.query.add_q(Q(*args, **kwargs))
return clone
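Since exclude() now simply negates the Q object before handing it to add_q(), these two expressions are equivalent (Entry and the lookup are illustrative):

    from django.db.models import Q

    Entry.objects.exclude(headline__startswith='Draft')
    Entry.objects.filter(~Q(headline__startswith='Draft'))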
def complex_filter(self, filter_obj):
- """Returns a new QuerySet instance with filter_obj added to the filters.
- filter_obj can be a Q object (has 'get_sql' method) or a dictionary of
- keyword lookup arguments."""
- # This exists to support framework features such as 'limit_choices_to',
- # and usually it will be more natural to use other methods.
- if hasattr(filter_obj, 'get_sql'):
+ """
+ Returns a new QuerySet instance with filter_obj added to the filters.
+ filter_obj can be a Q object (or anything with an add_to_query()
+ method) or a dictionary of keyword lookup arguments.
+
+ This exists to support framework features such as 'limit_choices_to',
+ and usually it will be more natural to use other methods.
+ """
+ if isinstance(filter_obj, Q) or hasattr(filter_obj, 'add_to_query'):
return self._filter_or_exclude(None, filter_obj)
else:
return self._filter_or_exclude(None, **filter_obj)
- def select_related(self, true_or_false=True, depth=0):
- "Returns a new QuerySet instance with '_select_related' modified."
- return self._clone(_select_related=true_or_false, _max_related_depth=depth)
+ def select_related(self, *fields, **kwargs):
+ """
+ Returns a new QuerySet instance that will select related objects. If
+ fields are specified, they must be ForeignKey fields and only those
+ related objects are included in the selection.
+ """
+ depth = kwargs.pop('depth', 0)
+ if kwargs:
+ raise TypeError('Unexpected keyword arguments to select_related: %s'
+ % (kwargs.keys(),))
+ obj = self._clone()
+ if fields:
+ if depth:
+ raise TypeError('Cannot pass both "depth" and fields to select_related()')
+ obj.query.add_select_related(fields)
+ else:
+ obj.query.select_related = True
+ if depth:
+ obj.query.max_depth = depth
+ return obj
+
+ def dup_select_related(self, other):
+ """
+ Copies the related selection status from the queryset 'other' to the
+ current queryset.
+ """
+ self.query.select_related = other.query.select_related
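A sketch of the calling styles the reworked select_related() accepts; Entry and its 'author' ForeignKey are assumptions for illustration:

    Entry.objects.select_related()             # follow all non-null FKs (depth-capped)
    Entry.objects.select_related(depth=2)      # explicit recursion limit
    Entry.objects.select_related('author')     # follow only the named relation

    # TypeError: named fields and 'depth' cannot be combined.
    Entry.objects.select_related('author', depth=1)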
def order_by(self, *field_names):
- "Returns a new QuerySet instance with the ordering changed."
- assert self._limit is None and self._offset is None, \
+ """Returns a new QuerySet instance with the ordering changed."""
+ assert self.query.can_filter(), \
"Cannot reorder a query once a slice has been taken."
- return self._clone(_order_by=field_names)
+ obj = self._clone()
+ obj.query.clear_ordering()
+ obj.query.add_ordering(*field_names)
+ return obj
def distinct(self, true_or_false=True):
- "Returns a new QuerySet instance with '_distinct' modified."
- return self._clone(_distinct=true_or_false)
+ """
+ Returns a new QuerySet instance that will select only distinct results.
+ """
+ obj = self._clone()
+ obj.query.distinct = true_or_false
+ return obj
- def extra(self, select=None, where=None, params=None, tables=None):
- assert self._limit is None and self._offset is None, \
+ def extra(self, select=None, where=None, params=None, tables=None,
+ order_by=None, select_params=None):
+ """
+ Add extra SQL fragments to the query.
+ """
+ assert self.query.can_filter(), \
"Cannot change a query once a slice has been taken"
clone = self._clone()
- if select: clone._select.update(select)
- if where: clone._where.extend(where)
- if params: clone._params.extend(params)
- if tables: clone._tables.extend(tables)
+ clone.query.add_extra(select, select_params, where, params, tables, order_by)
+ return clone
+
+ def reverse(self):
+ """
+ Reverses the ordering of the queryset.
+ """
+ clone = self._clone()
+ clone.query.standard_ordering = not clone.query.standard_ordering
return clone
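A rough sketch of the extended extra() signature and the new reverse(); the select alias, parameter and field names are illustrative:

    qs = Entry.objects.extra(
        select={'is_recent': 'pub_date > %s'},
        select_params=('2008-01-01',),
    ).order_by('-pub_date')
    qs = qs.reverse()   # flips standard_ordering, so pub_date now sorts ascending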
###################
# PRIVATE METHODS #
###################
- def _clone(self, klass=None, **kwargs):
+ def _clone(self, klass=None, setup=False, **kwargs):
if klass is None:
klass = self.__class__
- c = klass()
- c.model = self.model
- c._filters = self._filters
- c._order_by = self._order_by
- c._select_related = self._select_related
- c._max_related_depth = self._max_related_depth
- c._distinct = self._distinct
- c._select = self._select.copy()
- c._where = self._where[:]
- c._params = self._params[:]
- c._tables = self._tables[:]
- c._offset = self._offset
- c._limit = self._limit
+ c = klass(model=self.model, query=self.query.clone())
c.__dict__.update(kwargs)
+ if setup and hasattr(c, '_setup_query'):
+ c._setup_query()
return c
- def _combine(self, other):
- assert self._limit is None and self._offset is None \
- and other._limit is None and other._offset is None, \
- "Cannot combine queries once a slice has been taken."
- assert self._distinct == other._distinct, \
- "Cannot combine a unique query with a non-unique query"
- # use 'other's order by
- # (so that A.filter(args1) & A.filter(args2) does the same as
- # A.filter(args1).filter(args2)
- combined = other._clone()
- if self._select: combined._select.update(self._select)
- if self._where: combined._where.extend(self._where)
- if self._params: combined._params.extend(self._params)
- if self._tables: combined._tables.extend(self._tables)
- # If 'self' is ordered and 'other' isn't, propagate 'self's ordering
- if (self._order_by is not None and len(self._order_by) > 0) and \
- (combined._order_by is None or len(combined._order_by) == 0):
- combined._order_by = self._order_by
- return combined
-
- def _get_data(self):
- if self._result_cache is None:
- self._result_cache = list(self.iterator())
- return self._result_cache
-
- def _get_sql_clause(self):
- qn = connection.ops.quote_name
- opts = self.model._meta
-
- # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
- select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
- tables = [quote_only_if_word(t) for t in self._tables]
- joins = SortedDict()
- where = self._where[:]
- params = self._params[:]
-
- # Convert self._filters into SQL.
- joins2, where2, params2 = self._filters.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
-
- # Add additional tables and WHERE clauses based on select_related.
- if self._select_related:
- fill_table_cache(opts, select, tables, where,
- old_prefix=opts.db_table,
- cache_tables_seen=[opts.db_table],
- max_depth=self._max_related_depth)
-
- # Add any additional SELECTs.
- if self._select:
- select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])
-
- # Start composing the body of the SQL statement.
- sql = [" FROM", qn(opts.db_table)]
-
- # Compose the join dictionary into SQL describing the joins.
- if joins:
- sql.append(" ".join(["%s %s AS %s ON %s" % (join_type, table, alias, condition)
- for (alias, (table, join_type, condition)) in joins.items()]))
-
- # Compose the tables clause into SQL.
- if tables:
- sql.append(", " + ", ".join(tables))
-
- # Compose the where clause into SQL.
- if where:
- sql.append(where and "WHERE " + " AND ".join(where))
-
- # ORDER BY clause
- order_by = []
- if self._order_by is not None:
- ordering_to_use = self._order_by
- else:
- ordering_to_use = opts.ordering
- for f in handle_legacy_orderlist(ordering_to_use):
- if f == '?': # Special case.
- order_by.append(connection.ops.random_function_sql())
- else:
- if f.startswith('-'):
- col_name = f[1:]
- order = "DESC"
- else:
- col_name = f
- order = "ASC"
- if "." in col_name:
- table_prefix, col_name = col_name.split('.', 1)
- table_prefix = qn(table_prefix) + '.'
- else:
- # Use the database table as a column prefix if it wasn't given,
- # and if the requested column isn't a custom SELECT.
- if "." not in col_name and col_name not in (self._select or ()):
- table_prefix = qn(opts.db_table) + '.'
- else:
- table_prefix = ''
- order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
- if order_by:
- sql.append("ORDER BY " + ", ".join(order_by))
-
- # LIMIT and OFFSET clauses
- if self._limit is not None:
- sql.append("%s " % connection.ops.limit_offset_sql(self._limit, self._offset))
- else:
- assert self._offset is None, "'offset' is not allowed without 'limit'"
-
- return select, " ".join(sql), params
+ def _fill_cache(self, num=None):
+ """
+ Fills the result cache with 'num' more entries (or until the results
+ iterator is exhausted).
+ """
+ if self._iter:
+ try:
+ for i in range(num or ITER_CHUNK_SIZE):
+ self._result_cache.append(self._iter.next())
+ except StopIteration:
+ self._iter = None
-# Use the backend's QuerySet class if it defines one. Otherwise, use _QuerySet.
-if connection.features.uses_custom_queryset:
- QuerySet = connection.ops.query_set_class(_QuerySet)
-else:
- QuerySet = _QuerySet
+ def _merge_sanity_check(self, other):
+ """
+ Checks that we are merging two comparable queryset classes.
+ """
+ if self.__class__ is not other.__class__:
+ raise TypeError("Cannot merge querysets of different types ('%s' and '%s'."
+ % (self.__class__.__name__, other.__class__.__name__))
class ValuesQuerySet(QuerySet):
def __init__(self, *args, **kwargs):
super(ValuesQuerySet, self).__init__(*args, **kwargs)
- # select_related isn't supported in values().
- self._select_related = False
+ # select_related isn't supported in values(). (FIXME - #3358)
+ self.query.select_related = False
- def iterator(self):
- try:
- select, sql, params = self._get_sql_clause()
- except EmptyResultSet:
- raise StopIteration
+ # QuerySet._clone() will also set up the _fields attribute with the
+ # names of the model fields to select.
- qn = connection.ops.quote_name
+ def __iter__(self):
+ return self.iterator()
- # self._select is a dictionary, and dictionaries' key order is
- # undefined, so we convert it to a list of tuples.
- extra_select = self._select.items()
+ def iterator(self):
+ self.query.trim_extra_select(self.extra_names)
+ names = self.query.extra_select.keys() + self.field_names
+ for row in self.query.results_iter():
+ yield dict(zip(names, row))
+
+ def _setup_query(self):
+ """
+ Constructs the field_names list that the values query will be
+ retrieving.
- # Construct two objects -- fields and field_names.
- # fields is a list of Field objects to fetch.
- # field_names is a list of field names, which will be the keys in the
- # resulting dictionaries.
+ Called by the _clone() method after initialising the rest of the
+ instance.
+ """
+ self.extra_names = []
if self._fields:
- if not extra_select:
- fields = [self.model._meta.get_field(f, many_to_many=False) for f in self._fields]
- field_names = self._fields
+ if not self.query.extra_select:
+ field_names = list(self._fields)
else:
- fields = []
field_names = []
for f in self._fields:
- if f in [field.name for field in self.model._meta.fields]:
- fields.append(self.model._meta.get_field(f, many_to_many=False))
+ if self.query.extra_select.has_key(f):
+ self.extra_names.append(f)
+ else:
field_names.append(f)
- elif not self._select.has_key(f):
- raise FieldDoesNotExist('%s has no field named %r' % (self.model._meta.object_name, f))
- else: # Default to all fields.
- fields = self.model._meta.fields
- field_names = [f.attname for f in fields]
-
- columns = [f.column for f in fields]
- select = ['%s.%s' % (qn(self.model._meta.db_table), qn(c)) for c in columns]
- if extra_select:
- select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in extra_select])
- field_names.extend([f[0] for f in extra_select])
-
- cursor = connection.cursor()
- cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
+ else:
+ # Default to all fields.
+ field_names = [f.attname for f in self.model._meta.fields]
- has_resolve_columns = hasattr(self, 'resolve_columns')
- while 1:
- rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
- if not rows:
- raise StopIteration
- for row in rows:
- if has_resolve_columns:
- row = self.resolve_columns(row, fields)
- yield dict(zip(field_names, row))
+ self.query.add_fields(field_names, False)
+ self.query.default_cols = False
+ self.field_names = field_names
- def _clone(self, klass=None, **kwargs):
+ def _clone(self, klass=None, setup=False, **kwargs):
+ """
+ Cloning a ValuesQuerySet preserves the current fields.
+ """
c = super(ValuesQuerySet, self)._clone(klass, **kwargs)
c._fields = self._fields[:]
+ c.field_names = self.field_names
+ c.extra_names = self.extra_names
+ if setup and hasattr(c, '_setup_query'):
+ c._setup_query()
return c
-class DateQuerySet(QuerySet):
- def iterator(self):
- from django.db.backends.util import typecast_timestamp
- from django.db.models.fields import DateTimeField
+ def _merge_sanity_check(self, other):
+ super(ValuesQuerySet, self)._merge_sanity_check(other)
+ if (set(self.extra_names) != set(other.extra_names) or
+ set(self.field_names) != set(other.field_names)):
+ raise TypeError("Merging '%s' classes must involve the same values in each case."
+ % self.__class__.__name__)
- qn = connection.ops.quote_name
- self._order_by = () # Clear this because it'll mess things up otherwise.
- if self._field.null:
- self._where.append('%s.%s IS NOT NULL' % \
- (qn(self.model._meta.db_table), qn(self._field.column)))
- try:
- select, sql, params = self._get_sql_clause()
- except EmptyResultSet:
- raise StopIteration
+class ValuesListQuerySet(ValuesQuerySet):
+ def iterator(self):
+ self.query.trim_extra_select(self.extra_names)
+ if self.flat and len(self._fields) == 1:
+ for row in self.query.results_iter():
+ yield row[0]
+ elif not self.query.extra_select:
+ for row in self.query.results_iter():
+ yield row
+ else:
+ # When extra(select=...) is involved, the extra cols are always at
+ # the start of the row, so we need to reorder the fields to match
+ # the order in self._fields.
+ names = self.query.extra_select.keys() + self.field_names
+ for row in self.query.results_iter():
+ data = dict(zip(names, row))
+ yield tuple([data[f] for f in self._fields])
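The reordering branch above only matters when extra(select=...) columns are mixed into the projection; a sketch of the case it handles, with illustrative names:

    # The extra column arrives first in each raw database row, but the
    # returned tuples follow the requested order: (headline, two).
    Entry.objects.extra(select={'two': '1 + 1'}).values_list('headline', 'two')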
- table_name = qn(self.model._meta.db_table)
- field_name = qn(self._field.column)
+ def _clone(self, *args, **kwargs):
+ clone = super(ValuesListQuerySet, self)._clone(*args, **kwargs)
+ clone.flat = self.flat
+ return clone
- if connection.features.allows_group_by_ordinal:
- group_by = '1'
- else:
- group_by = connection.ops.date_trunc_sql(self._kind, '%s.%s' % (table_name, field_name))
+class DateQuerySet(QuerySet):
+ def iterator(self):
+ return self.query.results_iter()
- sql = 'SELECT %s %s GROUP BY %s ORDER BY 1 %s' % \
- (connection.ops.date_trunc_sql(self._kind, '%s.%s' % (qn(self.model._meta.db_table),
- qn(self._field.column))), sql, group_by, self._order)
- cursor = connection.cursor()
- cursor.execute(sql, params)
+ def _setup_query(self):
+ """
+ Sets up any special features of the query attribute.
- has_resolve_columns = hasattr(self, 'resolve_columns')
- needs_datetime_string_cast = connection.features.needs_datetime_string_cast
- dates = []
- # It would be better to use self._field here instead of DateTimeField(),
- # but in Oracle that will result in a list of datetime.date instead of
- # datetime.datetime.
- fields = [DateTimeField()]
- while 1:
- rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
- if not rows:
- return dates
- for row in rows:
- date = row[0]
- if has_resolve_columns:
- date = self.resolve_columns([date], fields)[0]
- elif needs_datetime_string_cast:
- date = typecast_timestamp(str(date))
- dates.append(date)
+ Called by the _clone() method after initialising the rest of the
+ instance.
+ """
+ self.query = self.query.clone(klass=sql.DateQuery, setup=True)
+ self.query.select = []
+ self.query.add_date_select(self._field.column, self._kind, self._order)
+ if self._field.null:
+ self.query.add_filter(('%s__isnull' % self._field.name, True))
- def _clone(self, klass=None, **kwargs):
- c = super(DateQuerySet, self)._clone(klass, **kwargs)
+ def _clone(self, klass=None, setup=False, **kwargs):
+ c = super(DateQuerySet, self)._clone(klass, False, **kwargs)
c._field = self._field
c._kind = self._kind
- c._order = self._order
+ if setup and hasattr(c, '_setup_query'):
+ c._setup_query()
return c
class EmptyQuerySet(QuerySet):
- def __init__(self, model=None):
- super(EmptyQuerySet, self).__init__(model)
+ def __init__(self, model=None, query=None):
+ super(EmptyQuerySet, self).__init__(model, query)
self._result_cache = []
def count(self):
@@ -703,488 +610,112 @@ class EmptyQuerySet(QuerySet):
def delete(self):
pass
- def _clone(self, klass=None, **kwargs):
+ def _clone(self, klass=None, setup=False, **kwargs):
c = super(EmptyQuerySet, self)._clone(klass, **kwargs)
c._result_cache = []
return c
- def _get_sql_clause(self):
- raise EmptyResultSet
-
-class QOperator(object):
- "Base class for QAnd and QOr"
- def __init__(self, *args):
- self.args = args
-
- def get_sql(self, opts):
- joins, where, params = SortedDict(), [], []
- for val in self.args:
- try:
- joins2, where2, params2 = val.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
- except EmptyResultSet:
- if not isinstance(self, QOr):
- raise EmptyResultSet
- if where:
- return joins, ['(%s)' % self.operator.join(where)], params
- return joins, [], params
-
-class QAnd(QOperator):
- "Encapsulates a combined query that uses 'AND'."
- operator = ' AND '
- def __or__(self, other):
- return QOr(self, other)
-
- def __and__(self, other):
- if isinstance(other, QAnd):
- return QAnd(*(self.args+other.args))
- elif isinstance(other, (Q, QOr)):
- return QAnd(*(self.args+(other,)))
- else:
- raise TypeError, other
-
-class QOr(QOperator):
- "Encapsulates a combined query that uses 'OR'."
- operator = ' OR '
- def __and__(self, other):
- return QAnd(self, other)
-
- def __or__(self, other):
- if isinstance(other, QOr):
- return QOr(*(self.args+other.args))
- elif isinstance(other, (Q, QAnd)):
- return QOr(*(self.args+(other,)))
- else:
- raise TypeError, other
-
-class Q(object):
- "Encapsulates queries as objects that can be combined logically."
- def __init__(self, **kwargs):
- self.kwargs = kwargs
-
- def __and__(self, other):
- return QAnd(self, other)
-
- def __or__(self, other):
- return QOr(self, other)
-
- def get_sql(self, opts):
- return parse_lookup(self.kwargs.items(), opts)
-
-class QNot(Q):
- "Encapsulates NOT (...) queries as objects"
- def __init__(self, q):
- "Creates a negation of the q object passed in."
- self.q = q
+ def iterator(self):
+ # This slightly odd construction is because we need an empty generator
+ # (it raises StopIteration immediately).
+ yield iter([]).next()
- def get_sql(self, opts):
- try:
- joins, where, params = self.q.get_sql(opts)
- where2 = ['(NOT (%s))' % " AND ".join(where)]
- except EmptyResultSet:
- return SortedDict(), [], []
- return joins, where2, params
+# QOperator, QNot, QAnd and QOr are temporarily retained for backwards
+# compatibility. All the old functionality is now part of the 'Q' class.
+class QOperator(Q):
+ def __init__(self, *args, **kwargs):
+ warnings.warn('Use Q instead of QOr, QAnd or QOperator.',
+ DeprecationWarning, stacklevel=2)
+ super(QOperator, self).__init__(*args, **kwargs)
-def get_where_clause(lookup_type, table_prefix, field_name, value, db_type):
- if table_prefix.endswith('.'):
- table_prefix = connection.ops.quote_name(table_prefix[:-1])+'.'
- field_name = connection.ops.quote_name(field_name)
- if type(value) == datetime.datetime and connection.ops.datetime_cast_sql():
- cast_sql = connection.ops.datetime_cast_sql()
- else:
- cast_sql = '%s'
- field_sql = connection.ops.field_cast_sql(db_type) % (table_prefix + field_name)
- if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith') and connection.features.needs_upper_for_iops:
- format = 'UPPER(%s) %s'
- else:
- format = '%s %s'
- try:
- return format % (field_sql, connection.operators[lookup_type] % cast_sql)
- except KeyError:
- pass
- if lookup_type == 'in':
- in_string = ','.join(['%s' for id in value])
- if in_string:
- return '%s IN (%s)' % (field_sql, in_string)
- else:
- raise EmptyResultSet
- elif lookup_type in ('range', 'year'):
- return '%s BETWEEN %%s AND %%s' % field_sql
- elif lookup_type in ('month', 'day'):
- return "%s = %%s" % connection.ops.date_extract_sql(lookup_type, field_sql)
- elif lookup_type == 'isnull':
- return "%s IS %sNULL" % (field_sql, (not value and 'NOT ' or ''))
- elif lookup_type == 'search':
- return connection.ops.fulltext_search_sql(field_sql)
- elif lookup_type in ('regex', 'iregex'):
- if settings.DATABASE_ENGINE == 'oracle':
- if connection.oracle_version and connection.oracle_version <= 9:
- msg = "Regexes are not supported in Oracle before version 10g."
- raise NotImplementedError(msg)
- if lookup_type == 'regex':
- match_option = 'c'
- else:
- match_option = 'i'
- return "REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql, match_option)
- else:
- raise NotImplementedError
- raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
+QOr = QAnd = QOperator
-def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0):
- """Helper function that recursively returns an object with cache filled"""
+def QNot(q):
+ warnings.warn('Use ~q instead of QNot(q)', DeprecationWarning, stacklevel=2)
+ return ~q
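A sketch of the migration these shims point users towards (lookups are illustrative):

    # Old combinators, now deprecated:
    #   QAnd(Q(headline__contains='a'), Q(rating__gte=3))
    #   QNot(Q(headline__contains='a'))
    # Equivalent new-style expressions:
    Q(headline__contains='a') & Q(rating__gte=3)
    ~Q(headline__contains='a')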
- # If we've got a max_depth set and we've exceeded that depth, bail now.
- if max_depth and cur_depth > max_depth:
+def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
+ requested=None):
+ """
+ Helper function that recursively returns an object with the specified
+ related attributes already populated.
+ """
+ if max_depth and requested is None and cur_depth > max_depth:
+ # We've recursed deeply enough; stop now.
return None
+ restricted = requested is not None
index_end = index_start + len(klass._meta.fields)
- obj = klass(*row[index_start:index_end])
+ obj = klass.from_sequence(row[index_start:index_end])
for f in klass._meta.fields:
- if f.rel and not f.null:
- cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1)
- if cached_row:
- rel_obj, index_end = cached_row
- setattr(obj, f.get_cache_name(), rel_obj)
+ if (not f.rel or (not restricted and f.null) or
+ (restricted and f.name not in requested) or f.rel.parent_link):
+ continue
+ if restricted:
+ next = requested[f.name]
+ else:
+ next = None
+ cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
+ cur_depth+1, next)
+ if cached_row:
+ rel_obj, index_end = cached_row
+ setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
-def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0):
- """
- Helper function that recursively populates the select, tables and where (in
- place) for select_related queries.
- """
-
- # If we've got a max_depth set and we've exceeded that depth, bail now.
- if max_depth and cur_depth > max_depth:
- return None
-
- qn = connection.ops.quote_name
- for f in opts.fields:
- if f.rel and not f.null:
- db_table = f.rel.to._meta.db_table
- if db_table not in cache_tables_seen:
- tables.append(qn(db_table))
- else: # The table was already seen, so give it a table alias.
- new_prefix = '%s%s' % (db_table, len(cache_tables_seen))
- tables.append('%s %s' % (qn(db_table), qn(new_prefix)))
- db_table = new_prefix
- cache_tables_seen.append(db_table)
- where.append('%s.%s = %s.%s' % \
- (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
- select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
- fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1)
-
-def parse_lookup(kwarg_items, opts):
- # Helper function that handles converting API kwargs
- # (e.g. "name__exact": "tom") to SQL.
- # Returns a tuple of (joins, where, params).
-
- # 'joins' is a sorted dictionary describing the tables that must be joined
- # to complete the query. The dictionary is sorted because creation order
- # is significant; it is a dictionary to ensure uniqueness of alias names.
- #
- # Each key-value pair follows the form
- # alias: (table, join_type, condition)
- # where
- # alias is the AS alias for the joined table
- # table is the actual table name to be joined
- # join_type is the type of join (INNER JOIN, LEFT OUTER JOIN, etc)
- # condition is the where-like statement over which narrows the join.
- # alias will be derived from the lookup list name.
- #
- # At present, this method only every returns INNER JOINs; the option is
- # there for others to implement custom Q()s, etc that return other join
- # types.
- joins, where, params = SortedDict(), [], []
-
- for kwarg, value in kwarg_items:
- path = kwarg.split(LOOKUP_SEPARATOR)
- # Extract the last elements of the kwarg.
- # The very-last is the lookup_type (equals, like, etc).
- # The second-last is the table column on which the lookup_type is
- # to be performed. If this name is 'pk', it will be substituted with
- # the name of the primary key.
- # If there is only one part, or the last part is not a query
- # term, assume that the query is an __exact
- lookup_type = path.pop()
- if lookup_type == 'pk':
- lookup_type = 'exact'
- path.append(None)
- elif len(path) == 0 or lookup_type not in QUERY_TERMS:
- path.append(lookup_type)
- lookup_type = 'exact'
-
- if len(path) < 1:
- raise TypeError, "Cannot parse keyword query %r" % kwarg
-
- if value is None:
- # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject
- # all uses of None as a query value.
- if lookup_type != 'exact':
- raise ValueError, "Cannot use None as a query value"
- elif callable(value):
- value = value()
-
- joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
- return joins, where, params
-
-class FieldFound(Exception):
- "Exception used to short circuit field-finding operations."
- pass
-
-def find_field(name, field_list, related_query):
+def delete_objects(seen_objs):
"""
- Finds a field with a specific name in a list of field instances.
- Returns None if there are no matches, or several matches.
+ Iterate through a list of seen classes, and remove any instances that are
+ referred to.
"""
- if related_query:
- matches = [f for f in field_list if f.field.related_query_name() == name]
- else:
- matches = [f for f in field_list if f.name == name]
- if len(matches) != 1:
- return None
- return matches[0]
-
-def field_choices(field_list, related_query):
- if related_query:
- choices = [f.field.related_query_name() for f in field_list]
- else:
- choices = [f.name for f in field_list]
- return choices
-
-def lookup_inner(path, lookup_type, value, opts, table, column):
- qn = connection.ops.quote_name
- joins, where, params = SortedDict(), [], []
- current_opts = opts
- current_table = table
- current_column = column
- intermediate_table = None
- join_required = False
-
- name = path.pop(0)
- # Has the primary key been requested? If so, expand it out
- # to be the name of the current class' primary key
- if name is None or name == 'pk':
- name = current_opts.pk.name
-
- # Try to find the name in the fields associated with the current class
- try:
- # Does the name belong to a defined many-to-many field?
- field = find_field(name, current_opts.many_to_many, False)
- if field:
- new_table = current_table + '__' + name
- new_opts = field.rel.to._meta
- new_column = new_opts.pk.column
-
- # Need to create an intermediate table join over the m2m table
- # This process hijacks current_table/column to point to the
- # intermediate table.
- current_table = "m2m_" + new_table
- intermediate_table = field.m2m_db_table()
- join_column = field.m2m_reverse_name()
- intermediate_column = field.m2m_column_name()
-
- raise FieldFound
-
- # Does the name belong to a reverse defined many-to-many field?
- field = find_field(name, current_opts.get_all_related_many_to_many_objects(), True)
- if field:
- new_table = current_table + '__' + name
- new_opts = field.opts
- new_column = new_opts.pk.column
-
- # Need to create an intermediate table join over the m2m table.
- # This process hijacks current_table/column to point to the
- # intermediate table.
- current_table = "m2m_" + new_table
- intermediate_table = field.field.m2m_db_table()
- join_column = field.field.m2m_column_name()
- intermediate_column = field.field.m2m_reverse_name()
-
- raise FieldFound
-
- # Does the name belong to a one-to-many field?
- field = find_field(name, current_opts.get_all_related_objects(), True)
- if field:
- new_table = table + '__' + name
- new_opts = field.opts
- new_column = field.field.column
- join_column = opts.pk.column
-
- # 1-N fields MUST be joined, regardless of any other conditions.
- join_required = True
-
- raise FieldFound
-
- # Does the name belong to a one-to-one, many-to-one, or regular field?
- field = find_field(name, current_opts.fields, False)
- if field:
- if field.rel: # One-to-One/Many-to-one field
- new_table = current_table + '__' + name
- new_opts = field.rel.to._meta
- new_column = new_opts.pk.column
- join_column = field.column
- raise FieldFound
- elif path:
- # For regular fields, if there are still items on the path,
- # an error has been made. We munge "name" so that the error
- # properly identifies the cause of the problem.
- name += LOOKUP_SEPARATOR + path[0]
- else:
- raise FieldFound
-
- except FieldFound: # Match found, loop has been shortcut.
- pass
- else: # No match found.
- choices = field_choices(current_opts.many_to_many, False) + \
- field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \
- field_choices(current_opts.get_all_related_objects(), True) + \
- field_choices(current_opts.fields, False)
- raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices))
-
- # Check whether an intermediate join is required between current_table
- # and new_table.
- if intermediate_table:
- joins[qn(current_table)] = (
- qn(intermediate_table), "LEFT OUTER JOIN",
- "%s.%s = %s.%s" % (qn(table), qn(current_opts.pk.column), qn(current_table), qn(intermediate_column))
- )
-
- if path:
- # There are elements left in the path. More joins are required.
- if len(path) == 1 and path[0] in (new_opts.pk.name, None) \
- and lookup_type in ('exact', 'isnull') and not join_required:
- # If the next and final name query is for a primary key,
- # and the search is for isnull/exact, then the current
- # (for N-1) or intermediate (for N-N) table can be used
- # for the search. No need to join an extra table just
- # to check the primary key.
- new_table = current_table
- else:
- # There are 1 or more name queries pending, and we have ruled out
- # any shortcuts; therefore, a join is required.
- joins[qn(new_table)] = (
- qn(new_opts.db_table), "INNER JOIN",
- "%s.%s = %s.%s" % (qn(current_table), qn(join_column), qn(new_table), qn(new_column))
- )
- # If we have made the join, we don't need to tell subsequent
- # recursive calls about the column name we joined on.
- join_column = None
-
- # There are name queries remaining. Recurse deeper.
- joins2, where2, params2 = lookup_inner(path, lookup_type, value, new_opts, new_table, join_column)
-
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
- else:
- # No elements left in path. Current element is the element on which
- # the search is being performed.
- db_type = None
-
- if join_required:
- # Last query term is a RelatedObject
- if field.field.rel.multiple:
- # RelatedObject is from a 1-N relation.
- # Join is required; query operates on joined table.
- column = new_opts.pk.name
- joins[qn(new_table)] = (
- qn(new_opts.db_table), "INNER JOIN",
- "%s.%s = %s.%s" % (qn(current_table), qn(join_column), qn(new_table), qn(new_column))
- )
- current_table = new_table
- else:
- # RelatedObject is from a 1-1 relation,
- # No need to join; get the pk value from the related object,
- # and compare using that.
- column = current_opts.pk.name
- elif intermediate_table:
- # Last query term is a related object from an N-N relation.
- # Join from intermediate table is sufficient.
- column = join_column
- elif name == current_opts.pk.name and lookup_type in ('exact', 'isnull') and current_column:
- # Last query term is for a primary key. If previous iterations
- # introduced a current/intermediate table that can be used to
- # optimize the query, then use that table and column name.
- column = current_column
- else:
- # Last query term was a normal field.
- column = field.column
- db_type = field.db_type()
-
- where.append(get_where_clause(lookup_type, current_table + '.', column, value, db_type))
- params.extend(field.get_db_prep_lookup(lookup_type, value))
-
- return joins, where, params
-
-def delete_objects(seen_objs):
- "Iterate through a list of seen classes, and remove any instances that are referred to"
- qn = connection.ops.quote_name
ordered_classes = seen_objs.keys()
ordered_classes.reverse()
- cursor = connection.cursor()
-
for cls in ordered_classes:
seen_objs[cls] = seen_objs[cls].items()
seen_objs[cls].sort()
# Pre notify all instances to be deleted
for pk_val, instance in seen_objs[cls]:
- dispatcher.send(signal=signals.pre_delete, sender=cls, instance=instance)
+ dispatcher.send(signal=signals.pre_delete, sender=cls,
+ instance=instance)
pk_list = [pk for pk,instance in seen_objs[cls]]
- for related in cls._meta.get_all_related_many_to_many_objects():
- if not isinstance(related.field, generic.GenericRelation):
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(related.field.m2m_db_table()),
- qn(related.field.m2m_reverse_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
- for f in cls._meta.many_to_many:
- if isinstance(f, generic.GenericRelation):
- from django.contrib.contenttypes.models import ContentType
- query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column
- args_extra = [ContentType.objects.get_for_model(cls).id]
- else:
- query_extra = ''
- args_extra = []
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(f.m2m_db_table()), qn(f.m2m_column_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra,
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra)
+ del_query = sql.DeleteQuery(cls, connection)
+ del_query.delete_batch_related(pk_list)
+
+ update_query = sql.UpdateQuery(cls, connection)
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("UPDATE %s SET %s=NULL WHERE %s IN (%s)" % \
- (qn(cls._meta.db_table), qn(field.column), qn(cls._meta.pk.column),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ update_query.clear_related(field, pk_list)
# Now delete the actual data
for cls in ordered_classes:
seen_objs[cls].reverse()
pk_list = [pk for pk,instance in seen_objs[cls]]
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(cls._meta.db_table), qn(cls._meta.pk.column),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ del_query = sql.DeleteQuery(cls, connection)
+ del_query.delete_batch(pk_list)
- # Last cleanup; set NULLs where there once was a reference to the object,
- # NULL the primary key of the found objects, and perform post-notification.
+ # Last cleanup; set NULLs where there once was a reference to the
+ # object, NULL the primary key of the found objects, and perform
+ # post-notification.
for pk_val, instance in seen_objs[cls]:
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
setattr(instance, field.attname, None)
- dispatcher.send(signal=signals.post_delete, sender=cls, instance=instance)
+ dispatcher.send(signal=signals.post_delete, sender=cls,
+ instance=instance)
setattr(instance, cls._meta.pk.attname, None)
transaction.commit_unless_managed()
+
+def insert_query(model, values, return_id=False, raw_values=False):
+ """
+ Inserts a new record for the given model. This provides an interface to
+ the InsertQuery class and is how Model.save() is implemented. It is not
+ part of the public API.
+ """
+ query = sql.InsertQuery(model, connection)
+ query.insert_values(values, raw_values)
+ return query.execute_sql(return_id)
+
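A rough sketch of how insert_query() is driven internally; MyModel is a stand-in and the (field, value) pairing is an assumption based on what Model.save() assembles:

    # values: a sequence of (field, prepared_value) pairs for the model's
    # fields; return_id asks the backend for the new primary key.
    new_pk = insert_query(MyModel, values, return_id=True)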
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
new file mode 100644
index 0000000000..0ce7900c74
--- /dev/null
+++ b/django/db/models/query_utils.py
@@ -0,0 +1,50 @@
+"""
+Various data structures used in query construction.
+
+Factored out from django.db.models.query so that they can also be used by other
+modules without getting into circular import difficulties.
+"""
+
+from copy import deepcopy
+
+from django.utils import tree
+
+class QueryWrapper(object):
+ """
+ A type that indicates the contents are an SQL fragment and the associated
+ parameters. Can be used to pass opaque data to a where-clause, for example.
+ """
+ def __init__(self, sql, params):
+ self.data = sql, params
+
+class Q(tree.Node):
+ """
+ Encapsulates filters as objects that can then be combined logically (using
+ & and |).
+ """
+ # Connection types
+ AND = 'AND'
+ OR = 'OR'
+ default = AND
+
+ def __init__(self, *args, **kwargs):
+ super(Q, self).__init__(children=list(args) + kwargs.items())
+
+ def _combine(self, other, conn):
+ if not isinstance(other, Q):
+ raise TypeError(other)
+ obj = deepcopy(self)
+ obj.add(other, conn)
+ return obj
+
+ def __or__(self, other):
+ return self._combine(other, self.OR)
+
+ def __and__(self, other):
+ return self._combine(other, self.AND)
+
+ def __invert__(self):
+ obj = deepcopy(self)
+ obj.negate()
+ return obj
+
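Because Q is now a tree.Node, combining instances is pure data manipulation and needs no database connection; a minimal sketch with illustrative lookups:

    from django.db.models.query_utils import Q

    q = Q(headline__startswith='W') | Q(pub_date__year=2008)
    q = ~q & Q(rating__gte=3)
    # q is a tree of (lookup, value) leaves joined by AND/OR connectors,
    # later consumed by Query.add_q().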
diff --git a/django/db/models/sql/__init__.py b/django/db/models/sql/__init__.py
new file mode 100644
index 0000000000..7310982690
--- /dev/null
+++ b/django/db/models/sql/__init__.py
@@ -0,0 +1,7 @@
+from query import *
+from subqueries import *
+from where import AND, OR
+from datastructures import EmptyResultSet
+
+__all__ = ['Query', 'AND', 'OR', 'EmptyResultSet']
+
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
new file mode 100644
index 0000000000..3075817385
--- /dev/null
+++ b/django/db/models/sql/constants.py
@@ -0,0 +1,36 @@
+import re
+
+# Valid query types (a dictionary is used for speedy lookups).
+QUERY_TERMS = dict([(x, None) for x in (
+ 'exact', 'iexact', 'contains', 'icontains', 'gt', 'gte', 'lt', 'lte', 'in',
+ 'startswith', 'istartswith', 'endswith', 'iendswith', 'range', 'year',
+ 'month', 'day', 'isnull', 'search', 'regex', 'iregex',
+ )])
+
+# Size of each "chunk" for get_iterator calls.
+# Larger values are slightly faster at the expense of more storage space.
+GET_ITERATOR_CHUNK_SIZE = 100
+
+# Separator used to split filter strings apart.
+LOOKUP_SEP = '__'
+
+# Constants to make looking up tuple values clearer.
+# Join lists
+TABLE_NAME = 0
+RHS_ALIAS = 1
+JOIN_TYPE = 2
+LHS_ALIAS = 3
+LHS_JOIN_COL = 4
+RHS_JOIN_COL = 5
+NULLABLE = 6
+
+# How many results to expect from a cursor.execute call
+MULTI = 'multi'
+SINGLE = 'single'
+
+ORDER_PATTERN = re.compile(r'\?|[-+]?\w+$')
+ORDER_DIR = {
+ 'ASC': ('ASC', 'DESC'),
+ 'DESC': ('DESC', 'ASC')}
+
+
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
new file mode 100644
index 0000000000..cb54a564f8
--- /dev/null
+++ b/django/db/models/sql/datastructures.py
@@ -0,0 +1,103 @@
+"""
+Useful auxiliary data structures for query construction. Not useful outside
+the SQL domain.
+"""
+
+class EmptyResultSet(Exception):
+ pass
+
+class FullResultSet(Exception):
+ pass
+
+class MultiJoin(Exception):
+ """
+ Used by join construction code to indicate the point at which a
+ multi-valued join was attempted (if the caller wants to treat that
+ exceptionally).
+ """
+ def __init__(self, level):
+ self.level = level
+
+class Empty(object):
+ pass
+
+class RawValue(object):
+ def __init__(self, value):
+ self.value = value
+
+class Aggregate(object):
+ """
+ Base class for all aggregate-related classes (min, max, avg, count, sum).
+ """
+ def relabel_aliases(self, change_map):
+ """
+ Relabel the column alias, if necessary. Must be implemented by
+ subclasses.
+ """
+ raise NotImplementedError
+
+ def as_sql(self, quote_func=None):
+ """
+ Returns the SQL string fragment for this object.
+
+ The quote_func function is used to quote the column components. If
+ None, it defaults to doing nothing.
+
+ Must be implemented by subclasses.
+ """
+ raise NotImplementedError
+
+class Count(Aggregate):
+ """
+ Perform a count on the given column.
+ """
+ def __init__(self, col='*', distinct=False):
+ """
+ Set the column to count on (defaults to '*') and set whether the count
+ should be distinct or not.
+ """
+ self.col = col
+ self.distinct = distinct
+
+ def relabel_aliases(self, change_map):
+ c = self.col
+ if isinstance(c, (list, tuple)):
+ self.col = (change_map.get(c[0], c[0]), c[1])
+
+ def as_sql(self, quote_func=None):
+ if not quote_func:
+ quote_func = lambda x: x
+ if isinstance(self.col, (list, tuple)):
+ col = ('%s.%s' % tuple([quote_func(c) for c in self.col]))
+ elif hasattr(self.col, 'as_sql'):
+ col = self.col.as_sql(quote_func)
+ else:
+ col = self.col
+ if self.distinct:
+ return 'COUNT(DISTINCT %s)' % col
+ else:
+ return 'COUNT(%s)' % col
+
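The aggregate helpers can be exercised on their own; a small sketch of Count.as_sql(), using an illustrative alias/column pair:

    from django.db.models.sql.datastructures import Count

    Count().as_sql()                             # 'COUNT(*)'
    Count(('T1', 'id'), distinct=True).as_sql()  # 'COUNT(DISTINCT T1.id)'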
+class Date(object):
+ """
+ Add a date selection column.
+ """
+ def __init__(self, col, lookup_type, date_sql_func):
+ self.col = col
+ self.lookup_type = lookup_type
+ self.date_sql_func = date_sql_func
+
+ def relabel_aliases(self, change_map):
+ c = self.col
+ if isinstance(c, (list, tuple)):
+ self.col = (change_map.get(c[0], c[0]), c[1])
+
+ def as_sql(self, quote_func=None):
+ if not quote_func:
+ quote_func = lambda x: x
+ if isinstance(self.col, (list, tuple)):
+ col = '%s.%s' % tuple([quote_func(c) for c in self.col])
+ else:
+ col = self.col
+ return self.date_sql_func(self.lookup_type, col)
+
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
new file mode 100644
index 0000000000..59b2ebdd68
--- /dev/null
+++ b/django/db/models/sql/query.py
@@ -0,0 +1,1504 @@
+"""
+Create SQL statements for QuerySets.
+
+The code in here encapsulates all of the SQL construction so that QuerySets
+themselves do not have to (and could be backed by things other than SQL
+databases). The abstraction barrier only works one way: this module has to know
+all about the internals of models in order to get the information it needs.
+"""
+
+from copy import deepcopy
+
+from django.utils.tree import Node
+from django.utils.datastructures import SortedDict
+from django.dispatch import dispatcher
+from django.db import connection
+from django.db.models import signals
+from django.db.models.sql.where import WhereNode, EverythingNode, AND, OR
+from django.db.models.sql.datastructures import Count
+from django.db.models.fields import FieldDoesNotExist
+from django.core.exceptions import FieldError
+from datastructures import EmptyResultSet, Empty, MultiJoin
+from constants import *
+
+try:
+ set
+except NameError:
+ from sets import Set as set # Python 2.3 fallback
+
+__all__ = ['Query']
+
+class Query(object):
+ """
+ A single SQL query.
+ """
+ # SQL join types. These are part of the class because their string forms
+ # vary from database to database and can be customised by a subclass.
+ INNER = 'INNER JOIN'
+ LOUTER = 'LEFT OUTER JOIN'
+
+ alias_prefix = 'T'
+ query_terms = QUERY_TERMS
+
+ def __init__(self, model, connection, where=WhereNode):
+ self.model = model
+ self.connection = connection
+ self.alias_refcount = {}
+ self.alias_map = {} # Maps alias to join information
+ self.table_map = {} # Maps table names to list of aliases.
+ self.join_map = {}
+ self.rev_join_map = {} # Reverse of join_map.
+ self.quote_cache = {}
+ self.default_cols = True
+ self.default_ordering = True
+ self.standard_ordering = True
+ self.ordering_aliases = []
+ self.start_meta = None
+ self.select_fields = []
+ self.related_select_fields = []
+
+ # SQL-related attributes
+ self.select = []
+ self.tables = [] # Aliases in the order they are created.
+ self.where = where()
+ self.where_class = where
+ self.group_by = []
+ self.having = []
+ self.order_by = []
+ self.low_mark, self.high_mark = 0, None # Used for offset/limit
+ self.distinct = False
+ self.select_related = False
+ self.related_select_cols = []
+
+ # Arbitrary maximum limit for select_related. Prevents infinite
+ # recursion. Can be changed by the depth parameter to select_related().
+ self.max_depth = 5
+
+ # These are for extensions. The contents are more or less appended
+ # verbatim to the appropriate clause.
+ self.extra_select = {} # Maps col_alias -> col_sql.
+ self.extra_select_params = ()
+ self.extra_tables = ()
+ self.extra_where = ()
+ self.extra_params = ()
+ self.extra_order_by = ()
+
+ def __str__(self):
+ """
+ Returns the query as a string of SQL with the parameter values
+ substituted in.
+
+ Parameter values won't necessarily be quoted correctly, since that is
+ done by the database interface at execution time.
+ """
+ sql, params = self.as_sql()
+ return sql % params
+
+ def __deepcopy__(self, memo):
+ result = self.clone()
+ memo[id(self)] = result
+ return result
+
+ def get_meta(self):
+ """
+ Returns the Options instance (the model._meta) from which to start
+ processing. Normally, this is self.model._meta, but it can change.
+ """
+ if self.start_meta:
+ return self.start_meta
+ return self.model._meta
+
+ def quote_name_unless_alias(self, name):
+ """
+ A wrapper around connection.ops.quote_name that doesn't quote aliases
+ for table names. This avoids problems with some SQL dialects that treat
+ quoted strings specially (e.g. PostgreSQL).
+ """
+ if name in self.quote_cache:
+ return self.quote_cache[name]
+ if ((name in self.alias_map and name not in self.table_map) or
+ name in self.extra_select):
+ self.quote_cache[name] = name
+ return name
+ r = self.connection.ops.quote_name(name)
+ self.quote_cache[name] = r
+ return r
+
+ def clone(self, klass=None, **kwargs):
+ """
+ Creates a copy of the current instance. The 'kwargs' parameter can be
+ used by clients to update attributes after copying has taken place.
+ """
+ obj = Empty()
+ obj.__class__ = klass or self.__class__
+ obj.model = self.model
+ obj.connection = self.connection
+ obj.alias_refcount = self.alias_refcount.copy()
+ obj.alias_map = self.alias_map.copy()
+ obj.table_map = self.table_map.copy()
+ obj.join_map = self.join_map.copy()
+ obj.rev_join_map = self.rev_join_map.copy()
+ obj.quote_cache = {}
+ obj.default_cols = self.default_cols
+ obj.default_ordering = self.default_ordering
+ obj.standard_ordering = self.standard_ordering
+ obj.ordering_aliases = []
+ obj.start_meta = self.start_meta
+ obj.select_fields = self.select_fields[:]
+ obj.related_select_fields = self.related_select_fields[:]
+ obj.select = self.select[:]
+ obj.tables = self.tables[:]
+ obj.where = deepcopy(self.where)
+ obj.where_class = self.where_class
+ obj.group_by = self.group_by[:]
+ obj.having = self.having[:]
+ obj.order_by = self.order_by[:]
+ obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
+ obj.distinct = self.distinct
+ obj.select_related = self.select_related
+ obj.related_select_cols = []
+ obj.max_depth = self.max_depth
+ obj.extra_select = self.extra_select.copy()
+ obj.extra_select_params = self.extra_select_params
+ obj.extra_tables = self.extra_tables
+ obj.extra_where = self.extra_where
+ obj.extra_params = self.extra_params
+ obj.extra_order_by = self.extra_order_by
+ obj.__dict__.update(kwargs)
+ if hasattr(obj, '_setup_query'):
+ obj._setup_query()
+ return obj
+
+ def results_iter(self):
+ """
+ Returns an iterator over the results from executing this query.
+ """
+ resolve_columns = hasattr(self, 'resolve_columns')
+ if resolve_columns:
+ if self.select_fields:
+ fields = self.select_fields + self.related_select_fields
+ else:
+ fields = self.model._meta.fields
+ for rows in self.execute_sql(MULTI):
+ for row in rows:
+ if resolve_columns:
+ row = self.resolve_columns(row, fields)
+ yield row
+
+ def get_count(self):
+ """
+ Performs a COUNT() query using the current filter constraints.
+ """
+ from subqueries import CountQuery
+ obj = self.clone()
+ obj.clear_ordering(True)
+ obj.clear_limits()
+ obj.select_related = False
+ obj.related_select_cols = []
+ obj.related_select_fields = []
+ if obj.distinct and len(obj.select) > 1:
+ obj = self.clone(CountQuery, _query=obj, where=self.where_class(),
+ distinct=False)
+ obj.select = []
+ obj.extra_select = {}
+ obj.add_count_column()
+ data = obj.execute_sql(SINGLE)
+ if not data:
+ return 0
+ number = data[0]
+
+ # Apply offset and limit constraints manually, since using LIMIT/OFFSET
+ # in SQL (in variants that provide them) doesn't change the COUNT
+ # output.
+ number = max(0, number - self.low_mark)
+ if self.high_mark:
+ number = min(number, self.high_mark - self.low_mark)
+
+ return number
+
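The arithmetic at the end of get_count() is what lets count() respect slicing without re-running the query; a worked example with illustrative numbers:

    # Unsliced query matches 50 rows; the queryset was sliced as qs[10:30],
    # so low_mark=10 and high_mark=30.
    number = 50
    number = max(0, number - 10)    # 40 rows remain after the offset
    number = min(number, 30 - 10)   # capped at the slice length -> 20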
+ def as_sql(self, with_limits=True, with_col_aliases=False):
+ """
+ Creates the SQL for this query. Returns the SQL string and list of
+ parameters.
+
+ If 'with_limits' is False, any limit/offset information is not included
+ in the query.
+ """
+ self.pre_sql_setup()
+ out_cols = self.get_columns(with_col_aliases)
+ ordering = self.get_ordering()
+
+ # This must come after 'select' and 'ordering' -- see docstring of
+ # get_from_clause() for details.
+ from_, f_params = self.get_from_clause()
+
+ where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias)
+ params = list(self.extra_select_params)
+
+ result = ['SELECT']
+ if self.distinct:
+ result.append('DISTINCT')
+ result.append(', '.join(out_cols + self.ordering_aliases))
+
+ result.append('FROM')
+ result.extend(from_)
+ params.extend(f_params)
+
+ if where:
+ result.append('WHERE %s' % where)
+ params.extend(w_params)
+ if self.extra_where:
+ if not where:
+ result.append('WHERE')
+ else:
+ result.append('AND')
+ result.append(' AND '.join(self.extra_where))
+
+ if self.group_by:
+ grouping = self.get_grouping()
+ result.append('GROUP BY %s' % ', '.join(grouping))
+
+ if ordering:
+ result.append('ORDER BY %s' % ', '.join(ordering))
+
+ # FIXME: Pull this out to make life easier for Oracle et al.
+ if with_limits:
+ if self.high_mark:
+ result.append('LIMIT %d' % (self.high_mark - self.low_mark))
+ if self.low_mark:
+ if not self.high_mark:
+ val = self.connection.ops.no_limit_value()
+ if val:
+ result.append('LIMIT %d' % val)
+ result.append('OFFSET %d' % self.low_mark)
+
+ params.extend(self.extra_params)
+ return ' '.join(result), tuple(params)
+
+ def combine(self, rhs, connector):
+ """
+ Merge the 'rhs' query into the current one, with any 'rhs' effects
+ being applied *after* (that is, "to the right of") anything in the
+ current query. 'rhs' is not modified during a call to this function.
+
+ The 'connector' parameter describes how to connect filters from the
+ 'rhs' query.
+ """
+ assert self.model == rhs.model, \
+ "Cannot combine queries on two different base models."
+ assert self.can_filter(), \
+ "Cannot combine queries once a slice has been taken."
+ assert self.distinct == rhs.distinct, \
+ "Cannot combine a unique query with a non-unique query."
+
+ # Work out how to relabel the rhs aliases, if necessary.
+ change_map = {}
+ used = set()
+ conjunction = (connector == AND)
+ first = True
+ for alias in rhs.tables:
+ if not rhs.alias_refcount[alias]:
+ # An unused alias.
+ continue
+ promote = (rhs.alias_map[alias][JOIN_TYPE] == self.LOUTER)
+ new_alias = self.join(rhs.rev_join_map[alias],
+ (conjunction and not first), used, promote, not conjunction)
+ used.add(new_alias)
+ change_map[alias] = new_alias
+ first = False
+
+ # So that we don't exclude valid results in an "or" query combination,
+ # the first join that is exclusive to the lhs (self) must be converted
+ # to an outer join.
+ if not conjunction:
+ for alias in self.tables[1:]:
+ if self.alias_refcount[alias] == 1:
+ self.promote_alias(alias, True)
+ break
+
+ # Now relabel a copy of the rhs where-clause and add it to the current
+ # one.
+ if rhs.where:
+ w = deepcopy(rhs.where)
+ w.relabel_aliases(change_map)
+ if not self.where:
+ # Since 'self' matches everything, add an explicit "include
+ # everything" where-constraint so that connections between the
+ # where clauses won't exclude valid results.
+ self.where.add(EverythingNode(), AND)
+ elif self.where:
+ # rhs has an empty where clause.
+ w = self.where_class()
+ w.add(EverythingNode(), AND)
+ else:
+ w = self.where_class()
+ self.where.add(w, connector)
+
+ # Selection columns and extra extensions are those provided by 'rhs'.
+ self.select = []
+ for col in rhs.select:
+ if isinstance(col, (list, tuple)):
+ self.select.append((change_map.get(col[0], col[0]), col[1]))
+ else:
+ item = deepcopy(col)
+ item.relabel_aliases(change_map)
+ self.select.append(item)
+ self.select_fields = rhs.select_fields[:]
+ self.extra_select = rhs.extra_select.copy()
+ self.extra_tables = rhs.extra_tables
+ self.extra_where = rhs.extra_where
+ self.extra_params = rhs.extra_params
+
+ # Ordering uses the 'rhs' ordering, unless it has none, in which case
+ # the current ordering is used.
+ self.order_by = rhs.order_by and rhs.order_by[:] or self.order_by
+ self.extra_order_by = rhs.extra_order_by or self.extra_order_by
+
+ def pre_sql_setup(self):
+ """
+ Does any necessary class setup immediately prior to producing SQL. This
+ is for things that can't necessarily be done in __init__ because we
+ might not have all the pieces in place at that time.
+ """
+ if not self.tables:
+ self.join((None, self.model._meta.db_table, None, None))
+ if self.select_related and not self.related_select_cols:
+ self.fill_related_selections()
+
+ def get_columns(self, with_aliases=False):
+ """
+ Return the list of columns to use in the select statement. If no
+ columns have been specified, returns all columns relating to fields in
+ the model.
+
+ If 'with_aliases' is true, any column names that are duplicated
+ (without the table names) are given unique aliases. This is needed in
+ some cases to avoid ambiguity with nested queries.
+ """
+ qn = self.quote_name_unless_alias
+ result = ['(%s) AS %s' % (col, alias) for alias, col in self.extra_select.iteritems()]
+ aliases = set(self.extra_select.keys())
+ if with_aliases:
+ col_aliases = aliases.copy()
+ else:
+ col_aliases = set()
+ if self.select:
+ for col in self.select:
+ if isinstance(col, (list, tuple)):
+ r = '%s.%s' % (qn(col[0]), qn(col[1]))
+ if with_aliases and col[1] in col_aliases:
+ c_alias = 'Col%d' % len(col_aliases)
+ result.append('%s AS %s' % (r, c_alias))
+ aliases.add(c_alias)
+ col_aliases.add(c_alias)
+ else:
+ result.append(r)
+ aliases.add(r)
+ col_aliases.add(col[1])
+ else:
+ result.append(col.as_sql(quote_func=qn))
+ if hasattr(col, 'alias'):
+ aliases.add(col.alias)
+ col_aliases.add(col.alias)
+ elif self.default_cols:
+ cols, new_aliases = self.get_default_columns(with_aliases,
+ col_aliases)
+ result.extend(cols)
+ aliases.update(new_aliases)
+ for table, col in self.related_select_cols:
+ r = '%s.%s' % (qn(table), qn(col))
+ if with_aliases and col in col_aliases:
+ c_alias = 'Col%d' % len(col_aliases)
+ result.append('%s AS %s' % (r, c_alias))
+ aliases.add(c_alias)
+ col_aliases.add(c_alias)
+ else:
+ result.append(r)
+ aliases.add(r)
+ col_aliases.add(col)
+
+ self._select_aliases = aliases
+ return result
+
+ def get_default_columns(self, with_aliases=False, col_aliases=None):
+ """
+ Computes the default columns for selecting every field in the base
+ model.
+
+ Returns a list of strings, quoted appropriately for use in SQL
+ directly, as well as a set of aliases used in the select statement.
+ """
+ result = []
+ table_alias = self.tables[0]
+ root_pk = self.model._meta.pk.column
+ seen = {None: table_alias}
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ aliases = set()
+ for field, model in self.model._meta.get_fields_with_model():
+ try:
+ alias = seen[model]
+ except KeyError:
+ alias = self.join((table_alias, model._meta.db_table,
+ root_pk, model._meta.pk.column))
+ seen[model] = alias
+ if with_aliases and field.column in col_aliases:
+ c_alias = 'Col%d' % len(col_aliases)
+ result.append('%s.%s AS %s' % (qn(alias),
+ qn2(field.column), c_alias))
+ col_aliases.add(c_alias)
+ aliases.add(c_alias)
+ else:
+ r = '%s.%s' % (qn(alias), qn2(field.column))
+ result.append(r)
+ aliases.add(r)
+ if with_aliases:
+ col_aliases.add(field.column)
+ return result, aliases
+
+ def get_from_clause(self):
+ """
+ Returns a list of strings that are joined together to go after the
+        "FROM" part of the query, as well as a list of any extra parameters
+        that need to be included. Subclasses can override this to create a
+        from-clause via a "select" (e.g. CountQuery).
+
+ This should only be called after any SQL construction methods that
+ might change the tables we need. This means the select columns and
+ ordering must be done first.
+ """
+ result = []
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ first = True
+ for alias in self.tables:
+ if not self.alias_refcount[alias]:
+ continue
+ try:
+ name, alias, join_type, lhs, lhs_col, col, nullable = self.alias_map[alias]
+ except KeyError:
+ # Extra tables can end up in self.tables, but not in the
+ # alias_map if they aren't in a join. That's OK. We skip them.
+ continue
+ alias_str = (alias != name and ' %s' % alias or '')
+ if join_type and not first:
+ result.append('%s %s%s ON (%s.%s = %s.%s)'
+ % (join_type, qn(name), alias_str, qn(lhs),
+ qn2(lhs_col), qn(alias), qn2(col)))
+ else:
+ connector = not first and ', ' or ''
+ result.append('%s%s%s' % (connector, qn(name), alias_str))
+ first = False
+ for t in self.extra_tables:
+ alias, unused = self.table_alias(t)
+ if alias not in self.alias_map:
+ connector = not first and ', ' or ''
+ result.append('%s%s' % (connector, qn(alias)))
+ first = False
+ return result, []
+
+ def get_grouping(self):
+ """
+ Returns a tuple representing the SQL elements in the "group by" clause.
+ """
+ qn = self.quote_name_unless_alias
+ result = []
+ for col in self.group_by:
+ if isinstance(col, (list, tuple)):
+ result.append('%s.%s' % (qn(col[0]), qn(col[1])))
+ elif hasattr(col, 'as_sql'):
+ result.append(col.as_sql(qn))
+ else:
+ result.append(str(col))
+ return result
+
+ def get_ordering(self):
+ """
+        Returns a list representing the SQL elements in the "order by" clause.
+ Also sets the ordering_aliases attribute on this instance to a list of
+ extra aliases needed in the select.
+
+ Determining the ordering SQL can change the tables we need to include,
+ so this should be run *before* get_from_clause().
+ """
+ if self.extra_order_by:
+ ordering = self.extra_order_by
+ elif not self.default_ordering:
+ ordering = []
+ else:
+ ordering = self.order_by or self.model._meta.ordering
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ distinct = self.distinct
+ select_aliases = self._select_aliases
+ result = []
+ ordering_aliases = []
+ if self.standard_ordering:
+ asc, desc = ORDER_DIR['ASC']
+ else:
+ asc, desc = ORDER_DIR['DESC']
+ for field in ordering:
+ if field == '?':
+ result.append(self.connection.ops.random_function_sql())
+ continue
+ if isinstance(field, int):
+ if field < 0:
+ order = desc
+ field = -field
+ else:
+ order = asc
+ result.append('%s %s' % (field, order))
+ continue
+ if '.' in field:
+ # This came in through an extra(order_by=...) addition. Pass it
+ # on verbatim.
+ col, order = get_order_dir(field, asc)
+ table, col = col.split('.', 1)
+ elt = '%s.%s' % (qn(table), col)
+ if not distinct or elt in select_aliases:
+ result.append('%s %s' % (elt, order))
+ elif get_order_dir(field)[0] not in self.extra_select:
+ # 'col' is of the form 'field' or 'field1__field2' or
+ # '-field1__field2__field', etc.
+ for table, col, order in self.find_ordering_name(field,
+ self.model._meta, default_order=asc):
+ elt = '%s.%s' % (qn(table), qn2(col))
+ if distinct and elt not in select_aliases:
+ ordering_aliases.append(elt)
+ result.append('%s %s' % (elt, order))
+ else:
+ col, order = get_order_dir(field, asc)
+ elt = qn(col)
+ if distinct and elt not in select_aliases:
+ ordering_aliases.append(elt)
+ result.append('%s %s' % (elt, order))
+ self.ordering_aliases = ordering_aliases
+ return result
+
+ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
+ already_seen=None):
+ """
+ Returns the table alias (the name might be ambiguous, the alias will
+ not be) and column name for ordering by the given 'name' parameter.
+ The 'name' is of the form 'field1__field2__...__fieldN'.
+ """
+ name, order = get_order_dir(name, default_order)
+ pieces = name.split(LOOKUP_SEP)
+ if not alias:
+ alias = self.get_initial_alias()
+ field, target, opts, joins, last = self.setup_joins(pieces, opts,
+ alias, False)
+ alias = joins[-1]
+ col = target.column
+
+ # If we get to this point and the field is a relation to another model,
+ # append the default ordering for that model.
+ if field.rel and len(joins) > 1 and opts.ordering:
+ # Firstly, avoid infinite loops.
+ if not already_seen:
+ already_seen = set()
+ join_tuple = tuple([self.alias_map[j][TABLE_NAME] for j in joins])
+ if join_tuple in already_seen:
+ raise FieldError('Infinite loop caused by ordering.')
+ already_seen.add(join_tuple)
+
+ results = []
+ for item in opts.ordering:
+ results.extend(self.find_ordering_name(item, opts, alias,
+ order, already_seen))
+ return results
+
+ if alias:
+ # We have to do the same "final join" optimisation as in
+ # add_filter, since the final column might not otherwise be part of
+ # the select set (so we can't order on it).
+ join = self.alias_map[alias]
+ if col == join[RHS_JOIN_COL]:
+ self.unref_alias(alias)
+ alias = join[LHS_ALIAS]
+ col = join[LHS_JOIN_COL]
+ return [(alias, col, order)]
+
+ def table_alias(self, table_name, create=False):
+ """
+ Returns a table alias for the given table_name and whether this is a
+ new alias or not.
+
+ If 'create' is true, a new alias is always created. Otherwise, the
+ most recently created alias for the table (if one exists) is reused.
+ """
+ current = self.table_map.get(table_name)
+ if not create and current:
+ alias = current[0]
+ self.alias_refcount[alias] += 1
+ return alias, False
+
+ # Create a new alias for this table.
+ if current:
+ alias = '%s%d' % (self.alias_prefix, len(self.alias_map) + 1)
+ current.append(alias)
+ else:
+            # The first occurrence of a table uses the table name directly.
+ alias = table_name
+ self.table_map[alias] = [alias]
+ self.alias_refcount[alias] = 1
+ #self.alias_map[alias] = None
+ self.tables.append(alias)
+ return alias, True
+
+ def ref_alias(self, alias):
+ """ Increases the reference count for this alias. """
+ self.alias_refcount[alias] += 1
+
+ def unref_alias(self, alias):
+ """ Decreases the reference count for this alias. """
+ self.alias_refcount[alias] -= 1
+
+ def promote_alias(self, alias, unconditional=False):
+ """
+ Promotes the join type of an alias to an outer join if it's possible
+ for the join to contain NULL values on the left. If 'unconditional' is
+ False, the join is only promoted if it is nullable, otherwise it is
+ always promoted.
+ """
+ if ((unconditional or self.alias_map[alias][NULLABLE]) and
+                self.alias_map[alias][JOIN_TYPE] != self.LOUTER):
+ data = list(self.alias_map[alias])
+ data[JOIN_TYPE] = self.LOUTER
+ self.alias_map[alias] = tuple(data)
+
+ def change_aliases(self, change_map):
+ """
+ Changes the aliases in change_map (which maps old-alias -> new-alias),
+ relabelling any references to them in select columns and the where
+ clause.
+ """
+ assert set(change_map.keys()).intersection(set(change_map.values())) == set()
+
+ # 1. Update references in "select" and "where".
+ self.where.relabel_aliases(change_map)
+ for pos, col in enumerate(self.select):
+ if isinstance(col, (list, tuple)):
+                self.select[pos] = (change_map.get(col[0], col[0]), col[1])
+ else:
+ col.relabel_aliases(change_map)
+
+ # 2. Rename the alias in the internal table/alias datastructures.
+ for old_alias, new_alias in change_map.iteritems():
+ alias_data = list(self.alias_map[old_alias])
+ alias_data[RHS_ALIAS] = new_alias
+
+ t = self.rev_join_map[old_alias]
+ data = list(self.join_map[t])
+ data[data.index(old_alias)] = new_alias
+ self.join_map[t] = tuple(data)
+ self.rev_join_map[new_alias] = t
+ del self.rev_join_map[old_alias]
+ self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
+ del self.alias_refcount[old_alias]
+ self.alias_map[new_alias] = tuple(alias_data)
+ del self.alias_map[old_alias]
+
+ table_aliases = self.table_map[alias_data[TABLE_NAME]]
+ for pos, alias in enumerate(table_aliases):
+ if alias == old_alias:
+ table_aliases[pos] = new_alias
+ break
+ for pos, alias in enumerate(self.tables):
+ if alias == old_alias:
+ self.tables[pos] = new_alias
+ break
+
+ # 3. Update any joins that refer to the old alias.
+ for alias, data in self.alias_map.iteritems():
+ lhs = data[LHS_ALIAS]
+ if lhs in change_map:
+ data = list(data)
+ data[LHS_ALIAS] = change_map[lhs]
+ self.alias_map[alias] = tuple(data)
+
+ def bump_prefix(self, exceptions=()):
+ """
+ Changes the alias prefix to the next letter in the alphabet and
+ relabels all the aliases. Even tables that previously had no alias will
+ get an alias after this call (it's mostly used for nested queries and
+ the outer query will already be using the non-aliased table name).
+
+        Subclasses that create their own prefix should override this method to
+ produce a similar result (a new prefix and relabelled aliases).
+
+ The 'exceptions' parameter is a container that holds alias names which
+ should not be changed.
+ """
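+        # Illustrative sketch (hypothetical names): with alias_prefix 'T' and
+        # self.tables == ['myapp_book', 'T2'], the prefix becomes 'U' and the
+        # aliases are relabelled to 'U0' and 'U1' via change_aliases() below.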
+ assert ord(self.alias_prefix) < ord('Z')
+ self.alias_prefix = chr(ord(self.alias_prefix) + 1)
+ change_map = {}
+ prefix = self.alias_prefix
+ for pos, alias in enumerate(self.tables):
+ if alias in exceptions:
+ continue
+ new_alias = '%s%d' % (prefix, pos)
+ change_map[alias] = new_alias
+ self.tables[pos] = new_alias
+ self.change_aliases(change_map)
+
+ def get_initial_alias(self):
+ """
+ Returns the first alias for this query, after increasing its reference
+ count.
+ """
+ if self.tables:
+ alias = self.tables[0]
+ self.ref_alias(alias)
+ else:
+ alias = self.join((None, self.model._meta.db_table, None, None))
+ return alias
+
+ def count_active_tables(self):
+ """
+ Returns the number of tables in this query with a non-zero reference
+ count.
+ """
+ return len([1 for count in self.alias_refcount.itervalues() if count])
+
+ def join(self, connection, always_create=False, exclusions=(),
+ promote=False, outer_if_first=False, nullable=False, reuse=None):
+ """
+ Returns an alias for the join in 'connection', either reusing an
+ existing alias for that join or creating a new one. 'connection' is a
+ tuple (lhs, table, lhs_col, col) where 'lhs' is either an existing
+        table alias or a table name. The join corresponds to the SQL equivalent
+ of::
+
+ lhs.lhs_col = table.col
+
+ If 'always_create' is True and 'reuse' is None, a new alias is always
+ created, regardless of whether one already exists or not. Otherwise
+ 'reuse' must be a set and a new join is created unless one of the
+ aliases in `reuse` can be used.
+
+ If 'exclusions' is specified, it is something satisfying the container
+ protocol ("foo in exclusions" must work) and specifies a list of
+ aliases that should not be returned, even if they satisfy the join.
+
+ If 'promote' is True, the join type for the alias will be LOUTER (if
+ the alias previously existed, the join type will be promoted from INNER
+ to LOUTER, if necessary).
+
+ If 'outer_if_first' is True and a new join is created, it will have the
+ LOUTER join type. This is used when joining certain types of querysets
+ and Q-objects together.
+
+ If 'nullable' is True, the join can potentially involve NULL values and
+ is a candidate for promotion (to "left outer") when combining querysets.
+ """
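+        # Illustrative sketch (hypothetical names): a foreign-key join might
+        # arrive as connection = ('myapp_book', 'myapp_author', 'author_id',
+        # 'id'); if no existing alias can be reused, a new alias is created
+        # and stored in self.alias_map as the 7-tuple
+        # (table, alias, join_type, lhs, lhs_col, col, nullable) built below.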
+ lhs, table, lhs_col, col = connection
+ if lhs in self.alias_map:
+ lhs_table = self.alias_map[lhs][TABLE_NAME]
+ else:
+ lhs_table = lhs
+
+ if reuse and always_create and table in self.table_map:
+            # Convert the 'reuse' case to be "exclude everything but the
+ # reusable set for this table".
+ exclusions = set(self.table_map[table]).difference(reuse)
+ always_create = False
+ t_ident = (lhs_table, table, lhs_col, col)
+ if not always_create:
+ for alias in self.join_map.get(t_ident, ()):
+ if alias not in exclusions:
+ self.ref_alias(alias)
+ if promote:
+ self.promote_alias(alias)
+ return alias
+
+ # No reuse is possible, so we need a new alias.
+ alias, _ = self.table_alias(table, True)
+ if not lhs:
+ # Not all tables need to be joined to anything. No join type
+ # means the later columns are ignored.
+ join_type = None
+ elif promote or outer_if_first:
+ join_type = self.LOUTER
+ else:
+ join_type = self.INNER
+ join = (table, alias, join_type, lhs, lhs_col, col, nullable)
+ self.alias_map[alias] = join
+ if t_ident in self.join_map:
+ self.join_map[t_ident] += (alias,)
+ else:
+ self.join_map[t_ident] = (alias,)
+ self.rev_join_map[alias] = t_ident
+ return alias
+
+ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
+ used=None, requested=None, restricted=None):
+ """
+ Fill in the information needed for a select_related query. The current
+ depth is measured as the number of connections away from the root model
+ (for example, cur_depth=1 means we are looking at models with direct
+ connections to the root model).
+ """
+ if not restricted and self.max_depth and cur_depth > self.max_depth:
+ # We've recursed far enough; bail out.
+ return
+ if not opts:
+ opts = self.get_meta()
+ root_alias = self.get_initial_alias()
+ self.related_select_cols = []
+ self.related_select_fields = []
+ if not used:
+ used = set()
+
+ # Setup for the case when only particular related fields should be
+ # included in the related selection.
+ if requested is None and restricted is not False:
+ if isinstance(self.select_related, dict):
+ requested = self.select_related
+ restricted = True
+ else:
+ restricted = False
+
+ for f, model in opts.get_fields_with_model():
+ if (not f.rel or (restricted and f.name not in requested) or
+ (not restricted and f.null) or f.rel.parent_link):
+ continue
+ table = f.rel.to._meta.db_table
+ if model:
+ int_opts = opts
+ alias = root_alias
+ for int_model in opts.get_base_chain(model):
+ lhs_col = int_opts.parents[int_model].column
+ int_opts = int_model._meta
+ alias = self.join((alias, int_opts.db_table, lhs_col,
+ int_opts.pk.column), exclusions=used,
+ promote=f.null)
+ else:
+ alias = root_alias
+ alias = self.join((alias, table, f.column,
+ f.rel.get_related_field().column), exclusions=used,
+ promote=f.null)
+ used.add(alias)
+ self.related_select_cols.extend([(alias, f2.column)
+ for f2 in f.rel.to._meta.fields])
+ self.related_select_fields.extend(f.rel.to._meta.fields)
+ if restricted:
+ next = requested.get(f.name, {})
+ else:
+ next = False
+ self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
+ used, next, restricted)
+
+ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
+ can_reuse=None):
+ """
+ Add a single filter to the query. The 'filter_expr' is a pair:
+ (filter_string, value). E.g. ('name__contains', 'fred')
+
+ If 'negate' is True, this is an exclude() filter. If 'trim' is True, we
+ automatically trim the final join group (used internally when
+ constructing nested queries).
+
+ If 'can_reuse' is a set, we are processing a component of a
+ multi-component filter (e.g. filter(Q1, Q2)). In this case, 'can_reuse'
+ will be a set of table aliases that can be reused in this filter, even
+ if we would otherwise force the creation of new aliases for a join
+ (needed for nested Q-filters). The set is updated by this method.
+ """
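+        # For illustration: add_filter(('name__contains', 'fred')) splits the
+        # lookup into parts = ['name'] and lookup_type = 'contains'; the
+        # resulting (alias, col, field, lookup_type, value) constraint is
+        # added to self.where further down.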
+ arg, value = filter_expr
+ parts = arg.split(LOOKUP_SEP)
+ if not parts:
+ raise FieldError("Cannot parse keyword query %r" % arg)
+
+ # Work out the lookup type and remove it from 'parts', if necessary.
+ if len(parts) == 1 or parts[-1] not in self.query_terms:
+ lookup_type = 'exact'
+ else:
+ lookup_type = parts.pop()
+
+ # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
+ # uses of None as a query value.
+ if value is None:
+ if lookup_type != 'exact':
+ raise ValueError("Cannot use None as a query value")
+ lookup_type = 'isnull'
+ value = True
+ elif callable(value):
+ value = value()
+
+ opts = self.get_meta()
+ alias = self.get_initial_alias()
+ allow_many = trim or not negate
+
+ try:
+ field, target, opts, join_list, last = self.setup_joins(parts, opts,
+ alias, True, allow_many, can_reuse=can_reuse)
+ except MultiJoin, e:
+ self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]))
+ return
+ final = len(join_list)
+ penultimate = last.pop()
+ if penultimate == final:
+ penultimate = last.pop()
+ if trim and len(join_list) > 1:
+ extra = join_list[penultimate:]
+ join_list = join_list[:penultimate]
+ final = penultimate
+ penultimate = last.pop()
+ col = self.alias_map[extra[0]][LHS_JOIN_COL]
+ for alias in extra:
+ self.unref_alias(alias)
+ else:
+ col = target.column
+ alias = join_list[-1]
+
+ if final > 1:
+ # An optimization: if the final join is against the same column as
+ # we are comparing against, we can go back one step in the join
+ # chain and compare against the lhs of the join instead. The result
+ # (potentially) involves one less table join.
+ join = self.alias_map[alias]
+ if col == join[RHS_JOIN_COL]:
+ self.unref_alias(alias)
+ alias = join[LHS_ALIAS]
+ col = join[LHS_JOIN_COL]
+ join_list = join_list[:-1]
+ final -= 1
+ if final == penultimate:
+ penultimate = last.pop()
+
+ if (lookup_type == 'isnull' and value is True and not negate and
+ final > 1):
+ # If the comparison is against NULL, we need to use a left outer
+ # join when connecting to the previous model. We make that
+ # adjustment here. We don't do this unless needed as it's less
+ # efficient at the database level.
+ self.promote_alias(join_list[penultimate])
+
+ if connector == OR:
+ # Some joins may need to be promoted when adding a new filter to a
+ # disjunction. We walk the list of new joins and where it diverges
+ # from any previous joins (ref count is 1 in the table list), we
+ # make the new additions (and any existing ones not used in the new
+ # join list) an outer join.
+ join_it = iter(join_list)
+ table_it = iter(self.tables)
+ join_it.next(), table_it.next()
+ for join in join_it:
+ table = table_it.next()
+ if join == table and self.alias_refcount[join] > 1:
+ continue
+ self.promote_alias(join)
+ if table != join:
+ self.promote_alias(table)
+ break
+ for join in join_it:
+ self.promote_alias(join)
+ for table in table_it:
+ # Some of these will have been promoted from the join_list, but
+ # that's harmless.
+ self.promote_alias(table)
+
+ self.where.add((alias, col, field, lookup_type, value), connector)
+ if negate:
+ self.where.negate()
+ for alias in join_list:
+ self.promote_alias(alias)
+ if final > 1 and lookup_type != 'isnull':
+ for alias in join_list:
+                    if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
+ j_col = self.alias_map[alias][RHS_JOIN_COL]
+ entry = Node([(alias, j_col, None, 'isnull', True)])
+ entry.negate()
+ self.where.add(entry, AND)
+ break
+ if can_reuse is not None:
+ can_reuse.update(join_list)
+
+ def add_q(self, q_object, used_aliases=None):
+ """
+ Adds a Q-object to the current filter.
+
+ Can also be used to add anything that has an 'add_to_query()' method.
+ """
+ if used_aliases is None:
+ used_aliases = set()
+ if hasattr(q_object, 'add_to_query'):
+ # Complex custom objects are responsible for adding themselves.
+ q_object.add_to_query(self, used_aliases)
+ return
+
+ if self.where and q_object.connector != AND and len(q_object) > 1:
+ self.where.start_subtree(AND)
+ subtree = True
+ else:
+ subtree = False
+ connector = AND
+ for child in q_object.children:
+ if isinstance(child, Node):
+ self.where.start_subtree(connector)
+ self.add_q(child, used_aliases)
+ self.where.end_subtree()
+ if q_object.negated:
+ self.where.children[-1].negate()
+ else:
+ self.add_filter(child, connector, q_object.negated,
+ can_reuse=used_aliases)
+ connector = q_object.connector
+ if subtree:
+ self.where.end_subtree()
+
+ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
+ allow_explicit_fk=False, can_reuse=None):
+ """
+ Compute the necessary table joins for the passage through the fields
+ given in 'names'. 'opts' is the Options class for the current model
+        (which gives the table we are joining to) and 'alias' is the alias for
+        the table we are joining to. If dupe_multis is True, any many-to-many or
+ many-to-one joins will always create a new alias (necessary for
+ disjunctive filters).
+
+ Returns the final field involved in the join, the target database
+ column (used for any 'where' constraint), the final 'opts' value and the
+ list of tables joined.
+ """
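+        # Illustrative sketch (hypothetical field names): for
+        # names = ['author', 'name'] with 'author' a ForeignKey, one join is
+        # added from the current table to the author table; 'name' is then a
+        # non-relation field, so it becomes the target and the loop stops.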
+ joins = [alias]
+ last = [0]
+ for pos, name in enumerate(names):
+ last.append(len(joins))
+ if name == 'pk':
+ name = opts.pk.name
+
+ try:
+ field, model, direct, m2m = opts.get_field_by_name(name)
+ except FieldDoesNotExist:
+ for f in opts.fields:
+ if allow_explicit_fk and name == f.attname:
+ # XXX: A hack to allow foo_id to work in values() for
+ # backwards compatibility purposes. If we dropped that
+ # feature, this could be removed.
+ field, model, direct, m2m = opts.get_field_by_name(f.name)
+ break
+ else:
+ names = opts.get_all_field_names()
+ raise FieldError("Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(names)))
+ if not allow_many and (m2m or not direct):
+ for alias in joins:
+ self.unref_alias(alias)
+ raise MultiJoin(pos + 1)
+ if model:
+ # The field lives on a base class of the current model.
+ alias_list = []
+ for int_model in opts.get_base_chain(model):
+ lhs_col = opts.parents[int_model].column
+ opts = int_model._meta
+ alias = self.join((alias, opts.db_table, lhs_col,
+ opts.pk.column), exclusions=joins)
+ joins.append(alias)
+ cached_data = opts._join_cache.get(name)
+ orig_opts = opts
+
+ if direct:
+ if m2m:
+ # Many-to-many field defined on the current model.
+ if cached_data:
+ (table1, from_col1, to_col1, table2, from_col2,
+ to_col2, opts, target) = cached_data
+ else:
+ table1 = field.m2m_db_table()
+ from_col1 = opts.pk.column
+ to_col1 = field.m2m_column_name()
+ opts = field.rel.to._meta
+ table2 = opts.db_table
+ from_col2 = field.m2m_reverse_name()
+ to_col2 = opts.pk.column
+ target = opts.pk
+ orig_opts._join_cache[name] = (table1, from_col1,
+ to_col1, table2, from_col2, to_col2, opts,
+ target)
+
+ int_alias = self.join((alias, table1, from_col1, to_col1),
+ dupe_multis, joins, nullable=True, reuse=can_reuse)
+ alias = self.join((int_alias, table2, from_col2, to_col2),
+ dupe_multis, joins, nullable=True, reuse=can_reuse)
+ joins.extend([int_alias, alias])
+ elif field.rel:
+ # One-to-one or many-to-one field
+ if cached_data:
+ (table, from_col, to_col, opts, target) = cached_data
+ else:
+ opts = field.rel.to._meta
+ target = field.rel.get_related_field()
+ table = opts.db_table
+ from_col = field.column
+ to_col = target.column
+ orig_opts._join_cache[name] = (table, from_col, to_col,
+ opts, target)
+
+ alias = self.join((alias, table, from_col, to_col),
+ exclusions=joins, nullable=field.null)
+ joins.append(alias)
+ else:
+ # Non-relation fields.
+ target = field
+ break
+ else:
+ orig_field = field
+ field = field.field
+ if m2m:
+ # Many-to-many field defined on the target model.
+ if cached_data:
+ (table1, from_col1, to_col1, table2, from_col2,
+ to_col2, opts, target) = cached_data
+ else:
+ table1 = field.m2m_db_table()
+ from_col1 = opts.pk.column
+ to_col1 = field.m2m_reverse_name()
+ opts = orig_field.opts
+ table2 = opts.db_table
+ from_col2 = field.m2m_column_name()
+ to_col2 = opts.pk.column
+ target = opts.pk
+ orig_opts._join_cache[name] = (table1, from_col1,
+ to_col1, table2, from_col2, to_col2, opts,
+ target)
+
+ int_alias = self.join((alias, table1, from_col1, to_col1),
+ dupe_multis, joins, nullable=True, reuse=can_reuse)
+ alias = self.join((int_alias, table2, from_col2, to_col2),
+ dupe_multis, joins, nullable=True, reuse=can_reuse)
+ joins.extend([int_alias, alias])
+ else:
+ # One-to-many field (ForeignKey defined on the target model)
+ if cached_data:
+ (table, from_col, to_col, opts, target) = cached_data
+ else:
+ local_field = opts.get_field_by_name(
+ field.rel.field_name)[0]
+ opts = orig_field.opts
+ table = opts.db_table
+ from_col = local_field.column
+ to_col = field.column
+ target = opts.pk
+ orig_opts._join_cache[name] = (table, from_col, to_col,
+ opts, target)
+
+ alias = self.join((alias, table, from_col, to_col),
+ dupe_multis, joins, nullable=True, reuse=can_reuse)
+ joins.append(alias)
+
+ if pos != len(names) - 1:
+ raise FieldError("Join on field %r not permitted." % name)
+
+ return field, target, opts, joins, last
+
+ def split_exclude(self, filter_expr, prefix):
+ """
+ When doing an exclude against any kind of N-to-many relation, we need
+ to use a subquery. This method constructs the nested query, given the
+ original exclude filter (filter_expr) and the portion up to the first
+ N-to-many relation field.
+ """
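+        # For illustration (hypothetical field name): excluding on a reverse
+        # relation such as exclude(children__name='x') builds a nested query
+        # for the inner filter and then adds a negated 'children__in' filter
+        # against it, as below.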
+ query = Query(self.model, self.connection)
+ query.add_filter(filter_expr)
+ query.set_start(prefix)
+ query.clear_ordering(True)
+ self.add_filter(('%s__in' % prefix, query), negate=True, trim=True)
+
+ def set_limits(self, low=None, high=None):
+ """
+ Adjusts the limits on the rows retrieved. We use low/high to set these,
+ as it makes it more Pythonic to read and write. When the SQL query is
+ created, they are converted to the appropriate offset and limit values.
+
+ Any limits passed in here are applied relative to the existing
+ constraints. So low is added to the current low value and both will be
+ clamped to any existing high value.
+ """
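+        # Worked example (illustrative): starting from low_mark=0 and
+        # high_mark=None, set_limits(2, 10) gives (2, 10); a subsequent
+        # set_limits(1, 5) then gives (3, 7), matching qs[2:10][1:5], i.e.
+        # rows 3..6 (0-based) of the original queryset.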
+ if high:
+ if self.high_mark:
+ self.high_mark = min(self.high_mark, self.low_mark + high)
+ else:
+ self.high_mark = self.low_mark + high
+ if low:
+ if self.high_mark:
+ self.low_mark = min(self.high_mark, self.low_mark + low)
+ else:
+ self.low_mark = self.low_mark + low
+
+ def clear_limits(self):
+ """
+ Clears any existing limits.
+ """
+ self.low_mark, self.high_mark = 0, None
+
+ def can_filter(self):
+ """
+ Returns True if adding filters to this instance is still possible.
+
+ Typically, this means no limits or offsets have been put on the results.
+ """
+ return not (self.low_mark or self.high_mark)
+
+ def add_fields(self, field_names, allow_m2m=True):
+ """
+ Adds the given (model) fields to the select set. The field names are
+ added in the order specified.
+ """
+ alias = self.get_initial_alias()
+ opts = self.get_meta()
+ try:
+ for name in field_names:
+ field, target, u2, joins, u3 = self.setup_joins(
+ name.split(LOOKUP_SEP), opts, alias, False, allow_m2m,
+ True)
+ final_alias = joins[-1]
+ col = target.column
+ if len(joins) > 1:
+ join = self.alias_map[final_alias]
+ if col == join[RHS_JOIN_COL]:
+ self.unref_alias(final_alias)
+ final_alias = join[LHS_ALIAS]
+ col = join[LHS_JOIN_COL]
+ joins = joins[:-1]
+ for join in joins[1:]:
+ # Only nullable aliases are promoted, so we don't end up
+ # doing unnecessary left outer joins here.
+ self.promote_alias(join)
+ self.select.append((final_alias, col))
+ self.select_fields.append(field)
+ except MultiJoin:
+ raise FieldError("Invalid field name: '%s'" % name)
+ except FieldError:
+ names = opts.get_all_field_names() + self.extra_select.keys()
+ names.sort()
+ raise FieldError("Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(names)))
+
+ def add_ordering(self, *ordering):
+ """
+ Adds items from the 'ordering' sequence to the query's "order by"
+ clause. These items are either field names (not column names) --
+ possibly with a direction prefix ('-' or '?') -- or ordinals,
+ corresponding to column positions in the 'select' list.
+
+ If 'ordering' is empty, all ordering is cleared from the query.
+ """
+ errors = []
+ for item in ordering:
+ if not ORDER_PATTERN.match(item):
+ errors.append(item)
+ if errors:
+ raise FieldError('Invalid order_by arguments: %s' % errors)
+ if ordering:
+ self.order_by.extend(ordering)
+ else:
+ self.default_ordering = False
+
+ def clear_ordering(self, force_empty=False):
+ """
+ Removes any ordering settings. If 'force_empty' is True, there will be
+ no ordering in the resulting query (not even the model's default).
+ """
+ self.order_by = []
+ self.extra_order_by = ()
+ if force_empty:
+ self.default_ordering = False
+
+ def add_count_column(self):
+ """
+ Converts the query to do count(...) or count(distinct(pk)) in order to
+ get its size.
+ """
+ # TODO: When group_by support is added, this needs to be adjusted so
+ # that it doesn't totally overwrite the select list.
+ if not self.distinct:
+ if not self.select:
+ select = Count()
+ else:
+ assert len(self.select) == 1, \
+ "Cannot add count col with multiple cols in 'select': %r" % self.select
+ select = Count(self.select[0])
+ else:
+ opts = self.model._meta
+ if not self.select:
+ select = Count((self.join((None, opts.db_table, None, None)),
+ opts.pk.column), True)
+ else:
+ # Because of SQL portability issues, multi-column, distinct
+ # counts need a sub-query -- see get_count() for details.
+ assert len(self.select) == 1, \
+ "Cannot add count col with multiple cols in 'select'."
+ select = Count(self.select[0], True)
+
+ # Distinct handling is done in Count(), so don't do it at this
+ # level.
+ self.distinct = False
+ self.select = [select]
+ self.select_fields = [None]
+ self.extra_select = {}
+ self.extra_select_params = ()
+
+ def add_select_related(self, fields):
+ """
+ Sets up the select_related data structure so that we only select
+ certain related models (as opposed to all models, when
+ self.select_related=True).
+ """
+ field_dict = {}
+ for field in fields:
+ d = field_dict
+ for part in field.split(LOOKUP_SEP):
+ d = d.setdefault(part, {})
+ self.select_related = field_dict
+ self.related_select_cols = []
+ self.related_select_fields = []
+
+ def add_extra(self, select, select_params, where, params, tables, order_by):
+ """
+ Adds data to the various extra_* attributes for user-created additions
+ to the query.
+ """
+ if select:
+ # The extra select might be ordered (because it will be accepting
+ # parameters).
+ if (isinstance(select, SortedDict) and
+ not isinstance(self.extra_select, SortedDict)):
+ self.extra_select = SortedDict(self.extra_select)
+ self.extra_select.update(select)
+ if select_params:
+ self.extra_select_params += tuple(select_params)
+ if where:
+ self.extra_where += tuple(where)
+ if params:
+ self.extra_params += tuple(params)
+ if tables:
+ self.extra_tables += tuple(tables)
+ if order_by:
+ self.extra_order_by = order_by
+
+ def trim_extra_select(self, names):
+ """
+ Removes any aliases in the extra_select dictionary that aren't in
+ 'names'.
+
+        This is needed if we are selecting certain values that don't include
+ all of the extra_select names.
+ """
+ for key in set(self.extra_select).difference(set(names)):
+ del self.extra_select[key]
+
+ def set_start(self, start):
+ """
+ Sets the table from which to start joining. The start position is
+ specified by the related attribute from the base model. This will
+        automatically set the select column to be the column linked from the
+ previous table.
+
+ This method is primarily for internal use and the error checking isn't
+ as friendly as add_filter(). Mostly useful for querying directly
+        against the join table of a many-to-many relation in a subquery.
+ """
+ opts = self.model._meta
+ alias = self.get_initial_alias()
+ field, col, opts, joins, last = self.setup_joins(
+ start.split(LOOKUP_SEP), opts, alias, False)
+ alias = joins[last[-1]]
+ self.select = [(alias, self.alias_map[alias][RHS_JOIN_COL])]
+ self.select_fields = [field]
+ self.start_meta = opts
+
+        # The call to setup_joins adds an extra reference to everything in
+ # joins. So we need to unref everything once, and everything prior to
+ # the final join a second time.
+ for alias in joins:
+ self.unref_alias(alias)
+ for alias in joins[:last[-1]]:
+ self.unref_alias(alias)
+
+ def execute_sql(self, result_type=MULTI):
+ """
+        Runs the query against the database and returns the result(s). The
+ return value is a single data item if result_type is SINGLE, or an
+ iterator over the results if the result_type is MULTI.
+
+ result_type is either MULTI (use fetchmany() to retrieve all rows),
+ SINGLE (only retrieve a single row), or None (no results expected, but
+ the cursor is returned, since it's used by subclasses such as
+ InsertQuery).
+ """
+ try:
+ sql, params = self.as_sql()
+ if not sql:
+ raise EmptyResultSet
+ except EmptyResultSet:
+ if result_type == MULTI:
+ return empty_iter()
+ else:
+ return
+
+ cursor = self.connection.cursor()
+ cursor.execute(sql, params)
+
+ if not result_type:
+ return cursor
+ if result_type == SINGLE:
+ if self.ordering_aliases:
+                return cursor.fetchone()[:-len(self.ordering_aliases)]
+ return cursor.fetchone()
+
+ # The MULTI case.
+ if self.ordering_aliases:
+ return order_modified_iter(cursor, len(self.ordering_aliases),
+ self.connection.features.empty_fetchmany_value)
+ return iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+ self.connection.features.empty_fetchmany_value)
+
+# Use the backend's custom Query class if it defines one. Otherwise, use the
+# default.
+if connection.features.uses_custom_query_class:
+ Query = connection.ops.query_class(Query)
+
+def get_order_dir(field, default='ASC'):
+ """
+ Returns the field name and direction for an order specification. For
+ example, '-foo' is returned as ('foo', 'DESC').
+
+ The 'default' param is used to indicate which way no prefix (or a '+'
+ prefix) should sort. The '-' prefix always sorts the opposite way.
+ """
+ dirn = ORDER_DIR[default]
+ if field[0] == '-':
+ return field[1:], dirn[1]
+ return field, dirn[0]
+
+def empty_iter():
+ """
+ Returns an iterator containing no results.
+ """
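+    # The next() call raises StopIteration on the first iteration, so this
+    # generator function yields nothing -- a Python 2 idiom for producing an
+    # empty generator.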
+ yield iter([]).next()
+
+def order_modified_iter(cursor, trim, sentinel):
+ """
+ Yields blocks of rows from a cursor. We use this iterator in the special
+ case when extra output columns have been added to support ordering
+ requirements. We must trim those extra columns before anything else can use
+ the results, since they're only needed to make the SQL valid.
+ """
+ for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+ sentinel):
+ yield [r[:-trim] for r in rows]
+
+def setup_join_cache(sender):
+ """
+ The information needed to join between model fields is something that is
+ invariant over the life of the model, so we cache it in the model's Options
+ class, rather than recomputing it all the time.
+
+ This method initialises the (empty) cache when the model is created.
+ """
+ sender._meta._join_cache = {}
+
+dispatcher.connect(setup_join_cache, signal=signals.class_prepared)
+
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
new file mode 100644
index 0000000000..382e6e94ff
--- /dev/null
+++ b/django/db/models/sql/subqueries.py
@@ -0,0 +1,367 @@
+"""
+Query subclasses which provide extra functionality beyond simple data retrieval.
+"""
+
+from django.contrib.contenttypes import generic
+from django.core.exceptions import FieldError
+from django.db.models.sql.constants import *
+from django.db.models.sql.datastructures import RawValue, Date
+from django.db.models.sql.query import Query
+from django.db.models.sql.where import AND
+
+__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
+ 'CountQuery']
+
+class DeleteQuery(Query):
+ """
+ Delete queries are done through this class, since they are more constrained
+ than general queries.
+ """
+ def as_sql(self):
+ """
+ Creates the SQL for this query. Returns the SQL string and list of
+ parameters.
+ """
+ assert len(self.tables) == 1, \
+ "Can only delete from one table at a time."
+ result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])]
+ where, params = self.where.as_sql()
+ result.append('WHERE %s' % where)
+ return ' '.join(result), tuple(params)
+
+ def do_query(self, table, where):
+ self.tables = [table]
+ self.where = where
+ self.execute_sql(None)
+
+ def delete_batch_related(self, pk_list):
+ """
+ Set up and execute delete queries for all the objects related to the
+ primary key values in pk_list. To delete the objects themselves, use
+ the delete_batch() method.
+
+ More than one physical query may be executed if there are a
+ lot of values in pk_list.
+ """
+ cls = self.model
+ for related in cls._meta.get_all_related_many_to_many_objects():
+ if not isinstance(related.field, generic.GenericRelation):
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ where = self.where_class()
+ where.add((None, related.field.m2m_reverse_name(),
+ related.field, 'in',
+ pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]),
+ AND)
+ self.do_query(related.field.m2m_db_table(), where)
+
+ for f in cls._meta.many_to_many:
+ w1 = self.where_class()
+ if isinstance(f, generic.GenericRelation):
+ from django.contrib.contenttypes.models import ContentType
+ field = f.rel.to._meta.get_field(f.content_type_field_name)
+ w1.add((None, field.column, field, 'exact',
+ ContentType.objects.get_for_model(cls).id), AND)
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ where = self.where_class()
+ where.add((None, f.m2m_column_name(), f, 'in',
+ pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
+ AND)
+ if w1:
+ where.add(w1, AND)
+ self.do_query(f.m2m_db_table(), where)
+
+ def delete_batch(self, pk_list):
+ """
+ Set up and execute delete queries for all the objects in pk_list. This
+ should be called after delete_batch_related(), if necessary.
+
+ More than one physical query may be executed if there are a
+ lot of values in pk_list.
+ """
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ where = self.where_class()
+ field = self.model._meta.pk
+ where.add((None, field.column, field, 'in',
+ pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
+ self.do_query(self.model._meta.db_table, where)
+
+class UpdateQuery(Query):
+ """
+ Represents an "update" SQL query.
+ """
+ def __init__(self, *args, **kwargs):
+ super(UpdateQuery, self).__init__(*args, **kwargs)
+ self._setup_query()
+
+ def _setup_query(self):
+ """
+ Runs on initialisation and after cloning. Any attributes that would
+ normally be set in __init__ should go in here, instead, so that they
+ are also set up after a clone() call.
+ """
+ self.values = []
+ self.related_ids = None
+ if not hasattr(self, 'related_updates'):
+ self.related_updates = {}
+
+ def clone(self, klass=None, **kwargs):
+ return super(UpdateQuery, self).clone(klass,
+                related_updates=self.related_updates.copy(), **kwargs)
+
+ def execute_sql(self, result_type=None):
+ super(UpdateQuery, self).execute_sql(result_type)
+ for query in self.get_related_updates():
+ query.execute_sql(result_type)
+
+ def as_sql(self):
+ """
+ Creates the SQL for this query. Returns the SQL string and list of
+ parameters.
+ """
+ self.pre_sql_setup()
+ if not self.values:
+ return '', ()
+ table = self.tables[0]
+ qn = self.quote_name_unless_alias
+ result = ['UPDATE %s' % qn(table)]
+ result.append('SET')
+ values, update_params = [], []
+ for name, val, placeholder in self.values:
+ if val is not None:
+ values.append('%s = %s' % (qn(name), placeholder))
+ update_params.append(val)
+ else:
+ values.append('%s = NULL' % qn(name))
+ result.append(', '.join(values))
+ where, params = self.where.as_sql()
+ if where:
+ result.append('WHERE %s' % where)
+ return ' '.join(result), tuple(update_params + params)
+
+ def pre_sql_setup(self):
+ """
+ If the update depends on results from other tables, we need to do some
+ munging of the "where" conditions to match the format required for
+ (portable) SQL updates. That is done here.
+
+ Further, if we are going to be running multiple updates, we pull out
+ the id values to update at this point so that they don't change as a
+ result of the progressive updates.
+ """
+ self.select_related = False
+ self.clear_ordering(True)
+ super(UpdateQuery, self).pre_sql_setup()
+ count = self.count_active_tables()
+ if not self.related_updates and count == 1:
+ return
+
+ # We need to use a sub-select in the where clause to filter on things
+ # from other tables.
+ query = self.clone(klass=Query)
+ query.bump_prefix()
+ query.select = []
+ query.extra_select = {}
+ query.add_fields([query.model._meta.pk.name])
+
+ # Now we adjust the current query: reset the where clause and get rid
+ # of all the tables we don't need (since they're in the sub-select).
+ self.where = self.where_class()
+ if self.related_updates:
+ idents = []
+ for rows in query.execute_sql(MULTI):
+ idents.extend([r[0] for r in rows])
+ self.add_filter(('pk__in', idents))
+ self.related_ids = idents
+ else:
+ self.add_filter(('pk__in', query))
+ for alias in self.tables[1:]:
+ self.alias_refcount[alias] = 0
+
+ def clear_related(self, related_field, pk_list):
+ """
+ Set up and execute an update query that clears related entries for the
+ keys in pk_list.
+
+ This is used by the QuerySet.delete_objects() method.
+ """
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ self.where = self.where_class()
+ f = self.model._meta.pk
+ self.where.add((None, f.column, f, 'in',
+ pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
+ AND)
+ self.values = [(related_field.column, None, '%s')]
+ self.execute_sql(None)
+
+ def add_update_values(self, values):
+ """
+ Convert a dictionary of field name to value mappings into an update
+ query. This is the entry point for the public update() method on
+ querysets.
+ """
+ values_seq = []
+ for name, val in values.iteritems():
+ field, model, direct, m2m = self.model._meta.get_field_by_name(name)
+ if not direct or m2m:
+ raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
+ values_seq.append((field, model, val))
+ return self.add_update_fields(values_seq)
+
+ def add_update_fields(self, values_seq):
+ """
+ Turn a sequence of (field, model, value) triples into an update query.
+ Used by add_update_values() as well as the "fast" update path when
+ saving models.
+ """
+ from django.db.models.base import Model
+ for field, model, val in values_seq:
+ # FIXME: Some sort of db_prep_* is probably more appropriate here.
+ if field.rel and isinstance(val, Model):
+ val = val.pk
+
+ # Getting the placeholder for the field.
+ if hasattr(field, 'get_placeholder'):
+ placeholder = field.get_placeholder(val)
+ else:
+ placeholder = '%s'
+
+ if model:
+ self.add_related_update(model, field.column, val, placeholder)
+ else:
+ self.values.append((field.column, val, placeholder))
+
+ def add_related_update(self, model, column, value, placeholder):
+ """
+ Adds (name, value) to an update query for an ancestor model.
+
+ Updates are coalesced so that we only run one update query per ancestor.
+ """
+ try:
+ self.related_updates[model].append((column, value, placeholder))
+ except KeyError:
+ self.related_updates[model] = [(column, value, placeholder)]
+
+ def get_related_updates(self):
+ """
+ Returns a list of query objects: one for each update required to an
+ ancestor model. Each query will have the same filtering conditions as
+ the current query but will only update a single table.
+ """
+ if not self.related_updates:
+ return []
+ result = []
+ for model, values in self.related_updates.iteritems():
+ query = UpdateQuery(model, self.connection)
+ query.values = values
+ if self.related_ids:
+ query.add_filter(('pk__in', self.related_ids))
+ result.append(query)
+ return result
+
+class InsertQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(InsertQuery, self).__init__(*args, **kwargs)
+ self.columns = []
+ self.values = []
+ self.params = ()
+
+ def clone(self, klass=None, **kwargs):
+ extras = {'columns': self.columns[:], 'values': self.values[:],
+ 'params': self.params}
+        return super(InsertQuery, self).clone(klass, **extras)
+
+ def as_sql(self):
+ # We don't need quote_name_unless_alias() here, since these are all
+ # going to be column names (so we can avoid the extra overhead).
+ qn = self.connection.ops.quote_name
+ result = ['INSERT INTO %s' % qn(self.model._meta.db_table)]
+ result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
+ result.append('VALUES (%s)' % ', '.join(self.values))
+ return ' '.join(result), self.params
+
+ def execute_sql(self, return_id=False):
+ cursor = super(InsertQuery, self).execute_sql(None)
+ if return_id:
+ return self.connection.ops.last_insert_id(cursor,
+ self.model._meta.db_table, self.model._meta.pk.column)
+
+ def insert_values(self, insert_values, raw_values=False):
+ """
+ Set up the insert query from the 'insert_values' dictionary. The
+ dictionary gives the model field names and their target values.
+
+ If 'raw_values' is True, the values in the 'insert_values' dictionary
+ are inserted directly into the query, rather than passed as SQL
+ parameters. This provides a way to insert NULL and DEFAULT keywords
+ into the query, for example.
+ """
+ placeholders, values = [], []
+ for field, val in insert_values:
+ if hasattr(field, 'get_placeholder'):
+ # Some fields (e.g. geo fields) need special munging before
+ # they can be inserted.
+ placeholders.append(field.get_placeholder(val))
+ else:
+ placeholders.append('%s')
+
+ self.columns.append(field.column)
+ values.append(val)
+ if raw_values:
+ self.values.extend(values)
+ else:
+ self.params += tuple(values)
+ self.values.extend(placeholders)
+
+class DateQuery(Query):
+ """
+ A DateQuery is a normal query, except that it specifically selects a single
+ date field. This requires some special handling when converting the results
+ back to Python objects, so we put it in a separate class.
+ """
+ def results_iter(self):
+ """
+ Returns an iterator over the results from executing this query.
+ """
+ resolve_columns = hasattr(self, 'resolve_columns')
+ if resolve_columns:
+ from django.db.models.fields import DateTimeField
+ fields = [DateTimeField()]
+ else:
+ from django.db.backends.util import typecast_timestamp
+ needs_string_cast = self.connection.features.needs_datetime_string_cast
+
+ offset = len(self.extra_select)
+ for rows in self.execute_sql(MULTI):
+ for row in rows:
+ date = row[offset]
+ if resolve_columns:
+ date = self.resolve_columns([date], fields)[0]
+ elif needs_string_cast:
+ date = typecast_timestamp(str(date))
+ yield date
+
+ def add_date_select(self, column, lookup_type, order='ASC'):
+ """
+ Converts the query into a date extraction query.
+ """
+ alias = self.join((None, self.model._meta.db_table, None, None))
+ select = Date((alias, column), lookup_type,
+ self.connection.ops.date_trunc_sql)
+ self.select = [select]
+ self.select_fields = [None]
+ self.distinct = True
+ self.order_by = order == 'ASC' and [1] or [-1]
+
+class CountQuery(Query):
+ """
+ A CountQuery knows how to take a normal query which would select over
+ multiple distinct columns and turn it into SQL that can be used on a
+ variety of backends (it requires a select in the FROM clause).
+ """
+ def get_from_clause(self):
+ result, params = self._query.as_sql()
+ return ['(%s) A1' % result], params
+
+ def get_ordering(self):
+ return ()
+
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
new file mode 100644
index 0000000000..c8857a01fe
--- /dev/null
+++ b/django/db/models/sql/where.py
@@ -0,0 +1,171 @@
+"""
+Code to manage the creation and SQL rendering of 'where' constraints.
+"""
+import datetime
+
+from django.utils import tree
+from django.db import connection
+from django.db.models.fields import Field
+from django.db.models.query_utils import QueryWrapper
+from datastructures import EmptyResultSet, FullResultSet
+
+# Connection types
+AND = 'AND'
+OR = 'OR'
+
+class WhereNode(tree.Node):
+ """
+ Used to represent the SQL where-clause.
+
+ The class is tied to the Query class that created it (in order to create
+    the correct SQL).
+
+ The children in this tree are usually either Q-like objects or lists of
+ [table_alias, field_name, field_class, lookup_type, value]. However, a
+ child could also be any class with as_sql() and relabel_aliases() methods.
+ """
+ default = AND
+
+ def as_sql(self, node=None, qn=None):
+ """
+ Returns the SQL version of the where clause and the value to be
+ substituted in. Returns None, None if this node is empty.
+
+ If 'node' is provided, that is the root of the SQL generation
+ (generally not needed except by the internal implementation for
+ recursion).
+ """
+ if node is None:
+ node = self
+ if not qn:
+ qn = connection.ops.quote_name
+ if not node.children:
+ return None, []
+ result = []
+ result_params = []
+ empty = True
+ for child in node.children:
+ try:
+ if hasattr(child, 'as_sql'):
+ sql, params = child.as_sql(qn=qn)
+ format = '(%s)'
+ elif isinstance(child, tree.Node):
+ sql, params = self.as_sql(child, qn)
+ if len(child.children) == 1:
+ format = '%s'
+ else:
+ format = '(%s)'
+ if child.negated:
+ format = 'NOT %s' % format
+ else:
+ sql, params = self.make_atom(child, qn)
+ format = '%s'
+ except EmptyResultSet:
+ if node.connector == AND and not node.negated:
+ # We can bail out early in this particular case (only).
+ raise
+ elif node.negated:
+ empty = False
+ continue
+ except FullResultSet:
+ if self.connector == OR:
+ if node.negated:
+ empty = True
+ break
+ # We match everything. No need for any constraints.
+ return '', []
+ if node.negated:
+ empty = True
+ continue
+ empty = False
+ if sql:
+ result.append(format % sql)
+ result_params.extend(params)
+ if empty:
+ raise EmptyResultSet
+ conn = ' %s ' % node.connector
+ return conn.join(result), result_params
+
+ def make_atom(self, child, qn):
+ """
+ Turn a tuple (table_alias, field_name, field_class, lookup_type, value)
+ into valid SQL.
+
+ Returns the string for the SQL fragment and the parameters to use for
+ it.
+ """
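+        # Illustrative example (hypothetical names): a child such as
+        # ('T1', 'name', <CharField>, 'exact', 'fred') renders roughly to
+        # '"T1"."name" = %s' with params ['fred'] on most backends; the exact
+        # operator and any casting come from connection.operators and the
+        # backend ops below.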
+ table_alias, name, field, lookup_type, value = child
+ if table_alias:
+ lhs = '%s.%s' % (qn(table_alias), qn(name))
+ else:
+ lhs = qn(name)
+ db_type = field and field.db_type() or None
+ field_sql = connection.ops.field_cast_sql(db_type) % lhs
+
+ if isinstance(value, datetime.datetime):
+ cast_sql = connection.ops.datetime_cast_sql()
+ else:
+ cast_sql = '%s'
+
+ if field:
+ params = field.get_db_prep_lookup(lookup_type, value)
+ else:
+ params = Field().get_db_prep_lookup(lookup_type, value)
+ if isinstance(params, QueryWrapper):
+ extra, params = params.data
+ else:
+ extra = ''
+
+ if lookup_type in connection.operators:
+ format = "%s %%s %s" % (connection.ops.lookup_cast(lookup_type),
+ extra)
+ return (format % (field_sql,
+ connection.operators[lookup_type] % cast_sql), params)
+
+ if lookup_type == 'in':
+ if not value:
+ raise EmptyResultSet
+ if extra:
+ return ('%s IN %s' % (field_sql, extra), params)
+ return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))),
+ params)
+ elif lookup_type in ('range', 'year'):
+ return ('%s BETWEEN %%s and %%s' % field_sql, params)
+ elif lookup_type in ('month', 'day'):
+ return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type,
+ field_sql), params)
+ elif lookup_type == 'isnull':
+ return ('%s IS %sNULL' % (field_sql, (not value and 'NOT ' or '')),
+ params)
+ elif lookup_type == 'search':
+ return (connection.ops.fulltext_search_sql(field_sql), params)
+ elif lookup_type in ('regex', 'iregex'):
+ return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
+
+ raise TypeError('Invalid lookup_type: %r' % lookup_type)
+
+ def relabel_aliases(self, change_map, node=None):
+ """
+ Relabels the alias values of any children. 'change_map' is a dictionary
+ mapping old (current) alias values to the new values.
+ """
+ if not node:
+ node = self
+ for pos, child in enumerate(node.children):
+ if hasattr(child, 'relabel_aliases'):
+ child.relabel_aliases(change_map)
+ elif isinstance(child, tree.Node):
+ self.relabel_aliases(change_map, child)
+ else:
+ if child[0] in change_map:
+ node.children[pos] = (change_map[child[0]],) + child[1:]
+
+class EverythingNode(object):
+ """
+ A node that matches everything.
+ """
+ def as_sql(self, qn=None):
+ raise FullResultSet
+
+ def relabel_aliases(self, change_map, node=None):
+ return
diff --git a/django/utils/tree.py b/django/utils/tree.py
new file mode 100644
index 0000000000..a62a4ae6c3
--- /dev/null
+++ b/django/utils/tree.py
@@ -0,0 +1,134 @@
+"""
+A class for storing a tree graph. Primarily used for filter constructs in the
+ORM.
+"""
+
+from copy import deepcopy
+
+class Node(object):
+ """
+ A single internal node in the tree graph. A Node should be viewed as a
+ connection (the root) with the children being either leaf nodes or other
+ Node instances.
+ """
+ # Standard connector type. Clients usually won't use this at all and
+ # subclasses will usually override the value.
+ default = 'DEFAULT'
+
+ def __init__(self, children=None, connector=None, negated=False):
+ """
+ Constructs a new Node. If no connector is given, the default will be
+ used.
+
+ Warning: You probably don't want to pass in the 'negated' parameter. It
+ is NOT the same as constructing a node and calling negate() on the
+ result.
+ """
+ self.children = children and children[:] or []
+ self.connector = connector or self.default
+ self.subtree_parents = []
+ self.negated = negated
+
+ def __str__(self):
+ if self.negated:
+ return '(NOT (%s: %s))' % (self.connector, ', '.join([str(c) for c
+ in self.children]))
+ return '(%s: %s)' % (self.connector, ', '.join([str(c) for c in
+ self.children]))
+
+ def __deepcopy__(self, memodict):
+ """
+ Utility method used by copy.deepcopy().
+ """
+ obj = Node(connector=self.connector, negated=self.negated)
+ obj.__class__ = self.__class__
+ obj.children = deepcopy(self.children, memodict)
+ obj.subtree_parents = deepcopy(self.subtree_parents, memodict)
+ return obj
+
+ def __len__(self):
+ """
+        The size of a node is the number of children it has.
+ """
+ return len(self.children)
+
+ def __nonzero__(self):
+ """
+ For truth value testing.
+ """
+ return bool(self.children)
+
+ def __contains__(self, other):
+ """
+        Returns True if 'other' is a direct child of this instance.
+ """
+ return other in self.children
+
+ def add(self, node, conn_type):
+ """
+ Adds a new node to the tree. If the conn_type is the same as the root's
+ current connector type, the node is added to the first level.
+ Otherwise, the whole tree is pushed down one level and a new root
+ connector is created, connecting the existing tree and the new node.
+ """
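+        # Illustrative example: for a node (AND: a, b), add(c, AND) simply
+        # appends c, while add(c, OR) pushes the existing children into a
+        # sub-node, producing (OR: (AND: a, b), c).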
+ if node in self.children:
+ return
+ if len(self.children) < 2:
+ self.connector = conn_type
+ if self.connector == conn_type:
+ if isinstance(node, Node) and (node.connector == conn_type or
+ len(node) == 1):
+ self.children.extend(node.children)
+ else:
+ self.children.append(node)
+ else:
+ obj = Node(self.children, self.connector, self.negated)
+ self.connector = conn_type
+ self.children = [obj, node]
+
+ def negate(self):
+ """
+ Negate the sense of the root connector. This reorganises the children
+ so that the current node has a single child: a negated node containing
+ all the previous children. This slightly odd construction makes adding
+ new children behave more intuitively.
+
+ Interpreting the meaning of this negate is up to client code. This
+ method is useful for implementing "not" arrangements.
+ """
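+        # For example: a node (AND: a, b) becomes a node whose single child is
+        # the negated node (NOT (AND: a, b)), so later additions attach
+        # outside the negation.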
+ self.children = [Node(self.children, self.connector, not self.negated)]
+ self.connector = self.default
+
+ def start_subtree(self, conn_type):
+ """
+ Sets up internal state so that new nodes are added to a subtree of the
+ current node. The conn_type specifies how the sub-tree is joined to the
+ existing children.
+ """
+ if len(self.children) == 1:
+ self.connector = conn_type
+ elif self.connector != conn_type:
+ self.children = [Node(self.children, self.connector, self.negated)]
+ self.connector = conn_type
+ self.negated = False
+
+ self.subtree_parents.append(Node(self.children, self.connector,
+ self.negated))
+ self.connector = self.default
+ self.negated = False
+ self.children = []
+
+ def end_subtree(self):
+ """
+ Closes off the most recently unmatched start_subtree() call.
+
+        This puts the current state into a node of the parent tree and restores
+        the instance's state to that of the parent.
+ """
+ obj = self.subtree_parents.pop()
+ node = Node(self.children, self.connector)
+ self.connector = obj.connector
+ self.negated = obj.negated
+ self.children = obj.children
+ self.children.append(node)
+