summaryrefslogtreecommitdiff
path: root/django/db/models/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/query.py')
-rw-r--r--django/db/models/query.py159
1 files changed, 110 insertions, 49 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 8c71155c0e..8799b4a93b 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -2,7 +2,9 @@
The main QuerySet implementation. This provides the public API for the ORM.
"""
-from django.db import connection, transaction, IntegrityError
+from copy import deepcopy
+
+from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS
from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField
from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery
@@ -24,9 +26,11 @@ class QuerySet(object):
"""
Represents a lazy database lookup for a set of objects.
"""
- def __init__(self, model=None, query=None):
+ def __init__(self, model=None, query=None, using=None):
self.model = model
- self.query = query or sql.Query(self.model, connection)
+ # EmptyQuerySet instantiates QuerySet with model as None
+ self._db = using
+ self.query = query or sql.Query(self.model)
self._result_cache = None
self._iter = None
self._sticky_filter = False
@@ -258,7 +262,8 @@ class QuerySet(object):
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
- for row in self.query.results_iter():
+ compiler = self.query.get_compiler(using=self.db)
+ for row in compiler.results_iter():
if fill_cache:
obj, _ = get_cached_row(self.model, row,
index_start, max_depth,
@@ -280,6 +285,9 @@ class QuerySet(object):
for i, aggregate in enumerate(aggregate_select):
setattr(obj, aggregate, row[i+aggregate_start])
+ # Store the source database of the object
+ obj._state.db = self.db
+
yield obj
def aggregate(self, *args, **kwargs):
@@ -299,7 +307,7 @@ class QuerySet(object):
query.add_aggregate(aggregate_expr, self.model, alias,
is_summary=True)
- return query.get_aggregation()
+ return query.get_aggregation(using=self.db)
def count(self):
"""
@@ -312,7 +320,7 @@ class QuerySet(object):
if self._result_cache is not None and not self._iter:
return len(self._result_cache)
- return self.query.get_count()
+ return self.query.get_count(using=self.db)
def get(self, *args, **kwargs):
"""
@@ -337,7 +345,7 @@ class QuerySet(object):
and returning the created object.
"""
obj = self.model(**kwargs)
- obj.save(force_insert=True)
+ obj.save(force_insert=True, using=self.db)
return obj
def get_or_create(self, **kwargs):
@@ -356,12 +364,12 @@ class QuerySet(object):
params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
params.update(defaults)
obj = self.model(**params)
- sid = transaction.savepoint()
- obj.save(force_insert=True)
- transaction.savepoint_commit(sid)
+ sid = transaction.savepoint(using=self.db)
+ obj.save(force_insert=True, using=self.db)
+ transaction.savepoint_commit(sid, using=self.db)
return obj, True
except IntegrityError, e:
- transaction.savepoint_rollback(sid)
+ transaction.savepoint_rollback(sid, using=self.db)
try:
return self.get(**kwargs), False
except self.model.DoesNotExist:
@@ -421,7 +429,7 @@ class QuerySet(object):
if not seen_objs:
break
- delete_objects(seen_objs)
+ delete_objects(seen_objs, del_query.db)
# Clear the result cache, in case this QuerySet gets reused.
self._result_cache = None
@@ -436,20 +444,20 @@ class QuerySet(object):
"Cannot update a query once a slice has been taken."
query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs)
- if not transaction.is_managed():
- transaction.enter_transaction_management()
+ if not transaction.is_managed(using=self.db):
+ transaction.enter_transaction_management(using=self.db)
forced_managed = True
else:
forced_managed = False
try:
- rows = query.execute_sql(None)
+ rows = query.get_compiler(self.db).execute_sql(None)
if forced_managed:
- transaction.commit()
+ transaction.commit(using=self.db)
else:
- transaction.commit_unless_managed()
+ transaction.commit_unless_managed(using=self.db)
finally:
if forced_managed:
- transaction.leave_transaction_management()
+ transaction.leave_transaction_management(using=self.db)
self._result_cache = None
return rows
update.alters_data = True
@@ -466,12 +474,12 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery)
query.add_update_fields(values)
self._result_cache = None
- return query.execute_sql(None)
+ return query.get_compiler(self.db).execute_sql(None)
_update.alters_data = True
def exists(self):
if self._result_cache is None:
- return self.query.has_results()
+ return self.query.has_results(using=self.db)
return bool(self._result_cache)
##################################################
@@ -678,6 +686,14 @@ class QuerySet(object):
clone.query.add_immediate_loading(fields)
return clone
+ def using(self, alias):
+ """
+ Selects which database this QuerySet should excecute it's query against.
+ """
+ clone = self._clone()
+ clone._db = alias
+ return clone
+
###################################
# PUBLIC INTROSPECTION ATTRIBUTES #
###################################
@@ -695,6 +711,11 @@ class QuerySet(object):
return False
ordered = property(ordered)
+ @property
+ def db(self):
+ "Return the database that will be used if this query is executed now"
+ return self._db or DEFAULT_DB_ALIAS
+
###################
# PRIVATE METHODS #
###################
@@ -706,6 +727,7 @@ class QuerySet(object):
if self._sticky_filter:
query.filter_is_sticky = True
c = klass(model=self.model, query=query)
+ c._db = self._db
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@@ -755,12 +777,17 @@ class QuerySet(object):
self.query.add_fields(field_names, False)
self.query.set_group_by()
- def _as_sql(self):
+ def _prepare(self):
+ return self
+
+ def _as_sql(self, connection):
"""
Returns the internal query's SQL and parameters (as a tuple).
"""
obj = self.values("pk")
- return obj.query.as_nested_sql()
+ if connection == connections[obj.db]:
+ return obj.query.get_compiler(connection=connection).as_nested_sql()
+ raise ValueError("Can't do subqueries with queries on different DBs.")
# When used as part of a nested query, a queryset will never be an "always
# empty" result.
@@ -783,7 +810,7 @@ class ValuesQuerySet(QuerySet):
names = extra_names + field_names + aggregate_names
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
yield dict(zip(names, row))
def _setup_query(self):
@@ -866,7 +893,7 @@ class ValuesQuerySet(QuerySet):
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
- def _as_sql(self):
+ def _as_sql(self, connection):
"""
For ValueQuerySet (and subclasses like ValuesListQuerySet), they can
only be used as nested queries if they're already set up to select only
@@ -878,15 +905,30 @@ class ValuesQuerySet(QuerySet):
(not self._fields and len(self.model._meta.fields) > 1)):
raise TypeError('Cannot use a multi-field %s as a filter value.'
% self.__class__.__name__)
- return self._clone().query.as_nested_sql()
+
+ obj = self._clone()
+ if connection == connections[obj.db]:
+ return obj.query.get_compiler(connection=connection).as_nested_sql()
+ raise ValueError("Can't do subqueries with queries on different DBs.")
+
+ def _prepare(self):
+ """
+ Validates that we aren't trying to do a query like
+ value__in=qs.values('value1', 'value2'), which isn't valid.
+ """
+ if ((self._fields and len(self._fields) > 1) or
+ (not self._fields and len(self.model._meta.fields) > 1)):
+ raise TypeError('Cannot use a multi-field %s as a filter value.'
+ % self.__class__.__name__)
+ return self
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
if self.flat and len(self._fields) == 1:
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
yield row[0]
elif not self.query.extra_select and not self.query.aggregate_select:
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
yield tuple(row)
else:
# When extra(select=...) or an annotation is involved, the extra
@@ -905,7 +947,7 @@ class ValuesListQuerySet(ValuesQuerySet):
else:
fields = names
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
data = dict(zip(names, row))
yield tuple([data[f] for f in fields])
@@ -917,7 +959,7 @@ class ValuesListQuerySet(ValuesQuerySet):
class DateQuerySet(QuerySet):
def iterator(self):
- return self.query.results_iter()
+ return self.query.get_compiler(self.db).results_iter()
def _setup_query(self):
"""
@@ -1032,13 +1074,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
-def delete_objects(seen_objs):
+def delete_objects(seen_objs, using):
"""
Iterate through a list of seen classes, and remove any instances that are
referred to.
"""
- if not transaction.is_managed():
- transaction.enter_transaction_management()
+ connection = connections[using]
+ if not transaction.is_managed(using=using):
+ transaction.enter_transaction_management(using=using)
forced_managed = True
else:
forced_managed = False
@@ -1064,19 +1107,18 @@ def delete_objects(seen_objs):
signals.pre_delete.send(sender=cls, instance=instance)
pk_list = [pk for pk,instance in items]
- del_query = sql.DeleteQuery(cls, connection)
- del_query.delete_batch_related(pk_list)
+ del_query = sql.DeleteQuery(cls)
+ del_query.delete_batch_related(pk_list, using=using)
- update_query = sql.UpdateQuery(cls, connection)
+ update_query = sql.UpdateQuery(cls)
for field, model in cls._meta.get_fields_with_model():
if (field.rel and field.null and field.rel.to in seen_objs and
filter(lambda f: f.column == field.rel.get_related_field().column,
field.rel.to._meta.fields)):
if model:
- sql.UpdateQuery(model, connection).clear_related(field,
- pk_list)
+ sql.UpdateQuery(model).clear_related(field, pk_list, using=using)
else:
- update_query.clear_related(field, pk_list)
+ update_query.clear_related(field, pk_list, using=using)
# Now delete the actual data.
for cls in ordered_classes:
@@ -1084,8 +1126,8 @@ def delete_objects(seen_objs):
items.reverse()
pk_list = [pk for pk,instance in items]
- del_query = sql.DeleteQuery(cls, connection)
- del_query.delete_batch(pk_list)
+ del_query = sql.DeleteQuery(cls)
+ del_query.delete_batch(pk_list, using=using)
# Last cleanup; set NULLs where there once was a reference to the
# object, NULL the primary key of the found objects, and perform
@@ -1100,21 +1142,24 @@ def delete_objects(seen_objs):
setattr(instance, cls._meta.pk.attname, None)
if forced_managed:
- transaction.commit()
+ transaction.commit(using=using)
else:
- transaction.commit_unless_managed()
+ transaction.commit_unless_managed(using=using)
finally:
if forced_managed:
- transaction.leave_transaction_management()
+ transaction.leave_transaction_management(using=using)
class RawQuerySet(object):
"""
Provides an iterator which converts the results of raw SQL queries into
annotated model instances.
"""
- def __init__(self, query, model=None, query_obj=None, params=None, translations=None):
+ def __init__(self, raw_query, model=None, query=None, params=None,
+ translations=None, using=None):
+ self.raw_query = raw_query
self.model = model
- self.query = query_obj or sql.RawQuery(sql=query, connection=connection, params=params)
+ self._db = using
+ self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params)
self.params = params or ()
self.translations = translations or {}
@@ -1123,7 +1168,21 @@ class RawQuerySet(object):
yield self.transform_results(row)
def __repr__(self):
- return "<RawQuerySet: %r>" % (self.query.sql % self.params)
+ return "<RawQuerySet: %r>" % (self.raw_query % self.params)
+
+ @property
+ def db(self):
+ "Return the database that will be used if this query is executed now"
+ return self._db or DEFAULT_DB_ALIAS
+
+ def using(self, alias):
+ """
+ Selects which database this Raw QuerySet should excecute it's query against.
+ """
+ return RawQuerySet(self.raw_query, model=self.model,
+ query=self.query.clone(using=alias),
+ params=self.params, translations=self.translations,
+ using=alias)
@property
def columns(self):
@@ -1189,14 +1248,16 @@ class RawQuerySet(object):
for field, value in annotations:
setattr(instance, field, value)
+ instance._state.db = self.query.using
+
return instance
-def insert_query(model, values, return_id=False, raw_values=False):
+def insert_query(model, values, return_id=False, raw_values=False, using=None):
"""
Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API.
"""
- query = sql.InsertQuery(model, connection)
+ query = sql.InsertQuery(model)
query.insert_values(values, raw_values)
- return query.execute_sql(return_id)
+ return query.get_compiler(using=using).execute_sql(return_id)