diff options
Diffstat (limited to 'django/db/models/query.py')
| -rw-r--r-- | django/db/models/query.py | 159 |
1 files changed, 110 insertions, 49 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py index 8c71155c0e..8799b4a93b 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -2,7 +2,9 @@ The main QuerySet implementation. This provides the public API for the ORM. """ -from django.db import connection, transaction, IntegrityError +from copy import deepcopy + +from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS from django.db.models.aggregates import Aggregate from django.db.models.fields import DateField from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery @@ -24,9 +26,11 @@ class QuerySet(object): """ Represents a lazy database lookup for a set of objects. """ - def __init__(self, model=None, query=None): + def __init__(self, model=None, query=None, using=None): self.model = model - self.query = query or sql.Query(self.model, connection) + # EmptyQuerySet instantiates QuerySet with model as None + self._db = using + self.query = query or sql.Query(self.model) self._result_cache = None self._iter = None self._sticky_filter = False @@ -258,7 +262,8 @@ class QuerySet(object): init_list.append(field.attname) model_cls = deferred_class_factory(self.model, skip) - for row in self.query.results_iter(): + compiler = self.query.get_compiler(using=self.db) + for row in compiler.results_iter(): if fill_cache: obj, _ = get_cached_row(self.model, row, index_start, max_depth, @@ -280,6 +285,9 @@ class QuerySet(object): for i, aggregate in enumerate(aggregate_select): setattr(obj, aggregate, row[i+aggregate_start]) + # Store the source database of the object + obj._state.db = self.db + yield obj def aggregate(self, *args, **kwargs): @@ -299,7 +307,7 @@ class QuerySet(object): query.add_aggregate(aggregate_expr, self.model, alias, is_summary=True) - return query.get_aggregation() + return query.get_aggregation(using=self.db) def count(self): """ @@ -312,7 +320,7 @@ class QuerySet(object): if self._result_cache is not None and not self._iter: return len(self._result_cache) - return self.query.get_count() + return self.query.get_count(using=self.db) def get(self, *args, **kwargs): """ @@ -337,7 +345,7 @@ class QuerySet(object): and returning the created object. """ obj = self.model(**kwargs) - obj.save(force_insert=True) + obj.save(force_insert=True, using=self.db) return obj def get_or_create(self, **kwargs): @@ -356,12 +364,12 @@ class QuerySet(object): params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params.update(defaults) obj = self.model(**params) - sid = transaction.savepoint() - obj.save(force_insert=True) - transaction.savepoint_commit(sid) + sid = transaction.savepoint(using=self.db) + obj.save(force_insert=True, using=self.db) + transaction.savepoint_commit(sid, using=self.db) return obj, True except IntegrityError, e: - transaction.savepoint_rollback(sid) + transaction.savepoint_rollback(sid, using=self.db) try: return self.get(**kwargs), False except self.model.DoesNotExist: @@ -421,7 +429,7 @@ class QuerySet(object): if not seen_objs: break - delete_objects(seen_objs) + delete_objects(seen_objs, del_query.db) # Clear the result cache, in case this QuerySet gets reused. self._result_cache = None @@ -436,20 +444,20 @@ class QuerySet(object): "Cannot update a query once a slice has been taken." query = self.query.clone(sql.UpdateQuery) query.add_update_values(kwargs) - if not transaction.is_managed(): - transaction.enter_transaction_management() + if not transaction.is_managed(using=self.db): + transaction.enter_transaction_management(using=self.db) forced_managed = True else: forced_managed = False try: - rows = query.execute_sql(None) + rows = query.get_compiler(self.db).execute_sql(None) if forced_managed: - transaction.commit() + transaction.commit(using=self.db) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=self.db) finally: if forced_managed: - transaction.leave_transaction_management() + transaction.leave_transaction_management(using=self.db) self._result_cache = None return rows update.alters_data = True @@ -466,12 +474,12 @@ class QuerySet(object): query = self.query.clone(sql.UpdateQuery) query.add_update_fields(values) self._result_cache = None - return query.execute_sql(None) + return query.get_compiler(self.db).execute_sql(None) _update.alters_data = True def exists(self): if self._result_cache is None: - return self.query.has_results() + return self.query.has_results(using=self.db) return bool(self._result_cache) ################################################## @@ -678,6 +686,14 @@ class QuerySet(object): clone.query.add_immediate_loading(fields) return clone + def using(self, alias): + """ + Selects which database this QuerySet should excecute it's query against. + """ + clone = self._clone() + clone._db = alias + return clone + ################################### # PUBLIC INTROSPECTION ATTRIBUTES # ################################### @@ -695,6 +711,11 @@ class QuerySet(object): return False ordered = property(ordered) + @property + def db(self): + "Return the database that will be used if this query is executed now" + return self._db or DEFAULT_DB_ALIAS + ################### # PRIVATE METHODS # ################### @@ -706,6 +727,7 @@ class QuerySet(object): if self._sticky_filter: query.filter_is_sticky = True c = klass(model=self.model, query=query) + c._db = self._db c.__dict__.update(kwargs) if setup and hasattr(c, '_setup_query'): c._setup_query() @@ -755,12 +777,17 @@ class QuerySet(object): self.query.add_fields(field_names, False) self.query.set_group_by() - def _as_sql(self): + def _prepare(self): + return self + + def _as_sql(self, connection): """ Returns the internal query's SQL and parameters (as a tuple). """ obj = self.values("pk") - return obj.query.as_nested_sql() + if connection == connections[obj.db]: + return obj.query.get_compiler(connection=connection).as_nested_sql() + raise ValueError("Can't do subqueries with queries on different DBs.") # When used as part of a nested query, a queryset will never be an "always # empty" result. @@ -783,7 +810,7 @@ class ValuesQuerySet(QuerySet): names = extra_names + field_names + aggregate_names - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): yield dict(zip(names, row)) def _setup_query(self): @@ -866,7 +893,7 @@ class ValuesQuerySet(QuerySet): super(ValuesQuerySet, self)._setup_aggregate_query(aggregates) - def _as_sql(self): + def _as_sql(self, connection): """ For ValueQuerySet (and subclasses like ValuesListQuerySet), they can only be used as nested queries if they're already set up to select only @@ -878,15 +905,30 @@ class ValuesQuerySet(QuerySet): (not self._fields and len(self.model._meta.fields) > 1)): raise TypeError('Cannot use a multi-field %s as a filter value.' % self.__class__.__name__) - return self._clone().query.as_nested_sql() + + obj = self._clone() + if connection == connections[obj.db]: + return obj.query.get_compiler(connection=connection).as_nested_sql() + raise ValueError("Can't do subqueries with queries on different DBs.") + + def _prepare(self): + """ + Validates that we aren't trying to do a query like + value__in=qs.values('value1', 'value2'), which isn't valid. + """ + if ((self._fields and len(self._fields) > 1) or + (not self._fields and len(self.model._meta.fields) > 1)): + raise TypeError('Cannot use a multi-field %s as a filter value.' + % self.__class__.__name__) + return self class ValuesListQuerySet(ValuesQuerySet): def iterator(self): if self.flat and len(self._fields) == 1: - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): yield row[0] elif not self.query.extra_select and not self.query.aggregate_select: - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): yield tuple(row) else: # When extra(select=...) or an annotation is involved, the extra @@ -905,7 +947,7 @@ class ValuesListQuerySet(ValuesQuerySet): else: fields = names - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): data = dict(zip(names, row)) yield tuple([data[f] for f in fields]) @@ -917,7 +959,7 @@ class ValuesListQuerySet(ValuesQuerySet): class DateQuerySet(QuerySet): def iterator(self): - return self.query.results_iter() + return self.query.get_compiler(self.db).results_iter() def _setup_query(self): """ @@ -1032,13 +1074,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def delete_objects(seen_objs): +def delete_objects(seen_objs, using): """ Iterate through a list of seen classes, and remove any instances that are referred to. """ - if not transaction.is_managed(): - transaction.enter_transaction_management() + connection = connections[using] + if not transaction.is_managed(using=using): + transaction.enter_transaction_management(using=using) forced_managed = True else: forced_managed = False @@ -1064,19 +1107,18 @@ def delete_objects(seen_objs): signals.pre_delete.send(sender=cls, instance=instance) pk_list = [pk for pk,instance in items] - del_query = sql.DeleteQuery(cls, connection) - del_query.delete_batch_related(pk_list) + del_query = sql.DeleteQuery(cls) + del_query.delete_batch_related(pk_list, using=using) - update_query = sql.UpdateQuery(cls, connection) + update_query = sql.UpdateQuery(cls) for field, model in cls._meta.get_fields_with_model(): if (field.rel and field.null and field.rel.to in seen_objs and filter(lambda f: f.column == field.rel.get_related_field().column, field.rel.to._meta.fields)): if model: - sql.UpdateQuery(model, connection).clear_related(field, - pk_list) + sql.UpdateQuery(model).clear_related(field, pk_list, using=using) else: - update_query.clear_related(field, pk_list) + update_query.clear_related(field, pk_list, using=using) # Now delete the actual data. for cls in ordered_classes: @@ -1084,8 +1126,8 @@ def delete_objects(seen_objs): items.reverse() pk_list = [pk for pk,instance in items] - del_query = sql.DeleteQuery(cls, connection) - del_query.delete_batch(pk_list) + del_query = sql.DeleteQuery(cls) + del_query.delete_batch(pk_list, using=using) # Last cleanup; set NULLs where there once was a reference to the # object, NULL the primary key of the found objects, and perform @@ -1100,21 +1142,24 @@ def delete_objects(seen_objs): setattr(instance, cls._meta.pk.attname, None) if forced_managed: - transaction.commit() + transaction.commit(using=using) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=using) finally: if forced_managed: - transaction.leave_transaction_management() + transaction.leave_transaction_management(using=using) class RawQuerySet(object): """ Provides an iterator which converts the results of raw SQL queries into annotated model instances. """ - def __init__(self, query, model=None, query_obj=None, params=None, translations=None): + def __init__(self, raw_query, model=None, query=None, params=None, + translations=None, using=None): + self.raw_query = raw_query self.model = model - self.query = query_obj or sql.RawQuery(sql=query, connection=connection, params=params) + self._db = using + self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params) self.params = params or () self.translations = translations or {} @@ -1123,7 +1168,21 @@ class RawQuerySet(object): yield self.transform_results(row) def __repr__(self): - return "<RawQuerySet: %r>" % (self.query.sql % self.params) + return "<RawQuerySet: %r>" % (self.raw_query % self.params) + + @property + def db(self): + "Return the database that will be used if this query is executed now" + return self._db or DEFAULT_DB_ALIAS + + def using(self, alias): + """ + Selects which database this Raw QuerySet should excecute it's query against. + """ + return RawQuerySet(self.raw_query, model=self.model, + query=self.query.clone(using=alias), + params=self.params, translations=self.translations, + using=alias) @property def columns(self): @@ -1189,14 +1248,16 @@ class RawQuerySet(object): for field, value in annotations: setattr(instance, field, value) + instance._state.db = self.query.using + return instance -def insert_query(model, values, return_id=False, raw_values=False): +def insert_query(model, values, return_id=False, raw_values=False, using=None): """ Inserts a new record for the given model. This provides an interface to the InsertQuery class and is how Model.save() is implemented. It is not part of the public API. """ - query = sql.InsertQuery(model, connection) + query = sql.InsertQuery(model) query.insert_values(values, raw_values) - return query.execute_sql(return_id) + return query.get_compiler(using=using).execute_sql(return_id) |
