diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-12-22 15:18:51 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-12-22 15:18:51 +0000 |
| commit | ff60c5f9de3e8690d1e86f3e9e3f7248a15397c8 (patch) | |
| tree | a4cb0ebdd55fcaf8c8855231b6ad3e1a7bf45bee /django/db/models | |
| parent | 7ef212af149540aa2da577a960d0d87029fd1514 (diff) | |
Fixed #1142 -- Added multiple database support.
This monster of a patch is the result of Alex Gaynor's 2009 Google Summer of Code project.
Congratulations to Alex for a job well done.
Big thanks also go to:
* Justin Bronn for keeping GIS in line with the changes,
* Karen Tracey and Jani Tiainen for their help testing Oracle support
* Brett Hoerner, Jon Loyens, and Craig Kimmerer for their feedback.
* Malcolm Tredinnick for his guidance during the GSoC submission process.
* Simon Willison for driving the original design process
* Cal Henderson for complaining about ponies he wanted.
... and everyone else too numerous to mention that helped to bring this feature to fruition.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@11952 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models')
| -rw-r--r-- | django/db/models/aggregates.py | 3 | ||||
| -rw-r--r-- | django/db/models/base.py | 77 | ||||
| -rw-r--r-- | django/db/models/expressions.py | 8 | ||||
| -rw-r--r-- | django/db/models/fields/__init__.py | 134 | ||||
| -rw-r--r-- | django/db/models/fields/files.py | 6 | ||||
| -rw-r--r-- | django/db/models/fields/related.py | 87 | ||||
| -rw-r--r-- | django/db/models/fields/subclassing.py | 72 | ||||
| -rw-r--r-- | django/db/models/manager.py | 26 | ||||
| -rw-r--r-- | django/db/models/query.py | 159 | ||||
| -rw-r--r-- | django/db/models/query_utils.py | 4 | ||||
| -rw-r--r-- | django/db/models/related.py | 5 | ||||
| -rw-r--r-- | django/db/models/sql/aggregates.py | 9 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 921 | ||||
| -rw-r--r-- | django/db/models/sql/datastructures.py | 12 | ||||
| -rw-r--r-- | django/db/models/sql/expressions.py | 22 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 766 | ||||
| -rw-r--r-- | django/db/models/sql/subqueries.py | 265 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 77 |
18 files changed, 1478 insertions, 1175 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index ce8829c593..a2349cf5c6 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -43,9 +43,6 @@ class Aggregate(object): """ klass = getattr(query.aggregates_module, self.name) aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) - # Validate that the backend has a fully supported, correct - # implementation of this aggregate - query.connection.ops.check_aggregate_support(aggregate) query.aggregates[alias] = aggregate class Avg(Aggregate): diff --git a/django/db/models/base.py b/django/db/models/base.py index 5b727a059f..3464ae6712 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -9,7 +9,7 @@ from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneF from django.db.models.query import delete_objects, Q from django.db.models.query_utils import CollectedObjects, DeferredAttribute from django.db.models.options import Options -from django.db import connection, transaction, DatabaseError +from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.loading import register_models, get_model import django.utils.copycompat as copy @@ -230,6 +230,13 @@ class ModelBase(type): signals.class_prepared.send(sender=cls) +class ModelState(object): + """ + A class for storing instance state + """ + def __init__(self, db=None): + self.db = db + class Model(object): __metaclass__ = ModelBase _deferred = False @@ -237,6 +244,9 @@ class Model(object): def __init__(self, *args, **kwargs): signals.pre_init.send(sender=self.__class__, args=args, kwargs=kwargs) + # Set up the storage for instane state + self._state = ModelState() + # There is a rather weird disparity here; if kwargs, it's set, then args # overrides it. 
It should be one or the other; don't duplicate the work # The reason for the kwargs check is that standard iterator passes in by @@ -404,7 +414,7 @@ class Model(object): return getattr(self, field_name) return getattr(self, field.attname) - def save(self, force_insert=False, force_update=False): + def save(self, force_insert=False, force_update=False, using=None): """ Saves the current instance. Override this in a subclass if you want to control the saving process. @@ -416,18 +426,20 @@ class Model(object): if force_insert and force_update: raise ValueError("Cannot force both insert and updating in " "model saving.") - self.save_base(force_insert=force_insert, force_update=force_update) + self.save_base(using=using, force_insert=force_insert, force_update=force_update) save.alters_data = True - def save_base(self, raw=False, cls=None, origin=None, - force_insert=False, force_update=False): + def save_base(self, raw=False, cls=None, origin=None, force_insert=False, + force_update=False, using=None): """ Does the heavy-lifting involved in saving. Subclasses shouldn't need to override this method. It's separate from save() in order to hide the need for overrides of save() to pass around internal-only parameters ('raw', 'cls', and 'origin'). """ + using = using or self._state.db or DEFAULT_DB_ALIAS + connection = connections[using] assert not (force_insert and force_update) if cls is None: cls = self.__class__ @@ -458,7 +470,7 @@ class Model(object): if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None: setattr(self, parent._meta.pk.attname, getattr(self, field.attname)) - self.save_base(cls=parent, origin=org) + self.save_base(cls=parent, origin=org, using=using) if field: setattr(self, field.attname, self._get_pk_val(parent._meta)) @@ -476,11 +488,11 @@ class Model(object): if pk_set: # Determine whether a record with the primary key already exists. 
if (force_update or (not force_insert and - manager.filter(pk=pk_val).exists())): + manager.using(using).filter(pk=pk_val).exists())): # It does already exist, so do an UPDATE. if force_update or non_pks: values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks] - rows = manager.filter(pk=pk_val)._update(values) + rows = manager.using(using).filter(pk=pk_val)._update(values) if force_update and not rows: raise DatabaseError("Forced update did not affect any rows.") else: @@ -489,27 +501,33 @@ class Model(object): if not pk_set: if force_update: raise ValueError("Cannot force an update in save() with no primary key.") - values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields if not isinstance(f, AutoField)] + values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) + for f in meta.local_fields if not isinstance(f, AutoField)] else: - values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields] + values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) + for f in meta.local_fields] if meta.order_with_respect_to: field = meta.order_with_respect_to - values.append((meta.get_field_by_name('_order')[0], manager.filter(**{field.name: getattr(self, field.attname)}).count())) + values.append((meta.get_field_by_name('_order')[0], manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count())) record_exists = False update_pk = bool(meta.has_auto_field and not pk_set) if values: # Create a new record. - result = manager._insert(values, return_id=update_pk) + result = manager._insert(values, return_id=update_pk, using=using) else: # Create a new record with defaults for everything. 
- result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True) + result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using) if update_pk: setattr(self, meta.pk.attname, result) - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=using) + + # Store the database on which the object was saved + self._state.db = using + # Signal that the save is complete if origin and not meta.auto_created: signals.post_save.send(sender=origin, instance=self, created=(not record_exists), raw=raw) @@ -572,7 +590,9 @@ class Model(object): # delete it and all its descendents. parent_obj._collect_sub_objects(seen_objs) - def delete(self): + def delete(self, using=None): + using = using or self._state.db or DEFAULT_DB_ALIAS + connection = connections[using] assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname) # Find all the objects than need to be deleted. @@ -580,7 +600,7 @@ class Model(object): self._collect_sub_objects(seen_objs) # Actually delete the objects. 
- delete_objects(seen_objs) + delete_objects(seen_objs, using) delete.alters_data = True @@ -594,7 +614,7 @@ class Model(object): param = smart_str(getattr(self, field.attname)) q = Q(**{'%s__%s' % (field.name, op): param}) q = q|Q(**{field.name: param, 'pk__%s' % op: self.pk}) - qs = self.__class__._default_manager.filter(**kwargs).filter(q).order_by('%s%s' % (order, field.name), '%spk' % order) + qs = self.__class__._default_manager.using(self._state.db).filter(**kwargs).filter(q).order_by('%s%s' % (order, field.name), '%spk' % order) try: return qs[0] except IndexError: @@ -603,17 +623,16 @@ class Model(object): def _get_next_or_previous_in_order(self, is_next): cachename = "__%s_order_cache" % is_next if not hasattr(self, cachename): - qn = connection.ops.quote_name - op = is_next and '>' or '<' + op = is_next and 'gt' or 'lt' order = not is_next and '-_order' or '_order' order_field = self._meta.order_with_respect_to - # FIXME: When querysets support nested queries, this can be turned - # into a pure queryset operation. 
- where = ['%s %s (SELECT %s FROM %s WHERE %s=%%s)' % \ - (qn('_order'), op, qn('_order'), - qn(self._meta.db_table), qn(self._meta.pk.column))] - params = [self.pk] - obj = self._default_manager.filter(**{order_field.name: getattr(self, order_field.attname)}).extra(where=where, params=params).order_by(order)[:1].get() + obj = self._default_manager.filter(**{ + order_field.name: getattr(self, order_field.attname) + }).filter(**{ + '_order__%s' % op: self._default_manager.values('_order').filter(**{ + self._meta.pk.name: self.pk + }) + }).order_by(order)[:1].get() setattr(self, cachename, obj) return getattr(self, cachename) @@ -627,14 +646,16 @@ class Model(object): # ORDERING METHODS ######################### -def method_set_order(ordered_obj, self, id_list): +def method_set_order(ordered_obj, self, id_list, using=None): + if using is None: + using = DEFAULT_DB_ALIAS rel_val = getattr(self, ordered_obj._meta.order_with_respect_to.rel.field_name) order_name = ordered_obj._meta.order_with_respect_to.name # FIXME: It would be nice if there was an "update many" version of update # for situations like this. 
for i, j in enumerate(id_list): ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i) - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=using) def method_get_order(ordered_obj, self): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 68abf9d22b..f760e4c5f3 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -41,8 +41,8 @@ class ExpressionNode(tree.Node): def prepare(self, evaluator, query, allow_joins): return evaluator.prepare_node(self, query, allow_joins) - def evaluate(self, evaluator, qn): - return evaluator.evaluate_node(self, qn) + def evaluate(self, evaluator, qn, connection): + return evaluator.evaluate_node(self, qn, connection) ############# # OPERATORS # @@ -109,5 +109,5 @@ class F(ExpressionNode): def prepare(self, evaluator, query, allow_joins): return evaluator.prepare_leaf(self, query, allow_joins) - def evaluate(self, evaluator, qn): - return evaluator.evaluate_leaf(self, qn) + def evaluate(self, evaluator, qn, connection): + return evaluator.evaluate_leaf(self, qn, connection) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 1be0bc353c..b70f320df3 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -8,6 +8,7 @@ import django.utils.copycompat as copy from django.db import connection from django.db.models import signals +from django.db.models.fields.subclassing import LegacyConnection from django.db.models.query_utils import QueryWrapper from django.dispatch import dispatcher from django.conf import settings @@ -47,6 +48,9 @@ class FieldDoesNotExist(Exception): # getattr(obj, opts.pk.attname) class Field(object): + """Base class for all field types""" + __metaclass__ = LegacyConnection + # Designates whether empty strings fundamentally are allowed at the # database level. 
empty_strings_allowed = True @@ -123,10 +127,10 @@ class Field(object): """ return value - def db_type(self): + def db_type(self, connection): """ - Returns the database column data type for this field, taking into - account the DATABASE_ENGINE setting. + Returns the database column data type for this field, for the provided + connection. """ # The default implementation of this method looks at the # backend-specific DATA_TYPES dictionary, looking up the field by its @@ -183,21 +187,56 @@ class Field(object): "Returns field's value just before saving." return getattr(model_instance, self.attname) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + "Perform preliminary non-db specific value checks and conversions." + return value + + def get_db_prep_value(self, value, connection, prepared=False): """Returns field's value prepared for interacting with the database backend. Used by the default implementations of ``get_db_prep_save``and `get_db_prep_lookup``` """ + if not prepared: + value = self.get_prep_value(value) return value - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): "Returns field's value prepared for saving into a database." 
- return self.get_db_prep_value(value) + return self.get_db_prep_value(value, connection=connection, prepared=False) + + def get_prep_lookup(self, lookup_type, value): + "Perform preliminary non-db specific lookup checks and conversions" + if hasattr(value, 'prepare'): + return value.prepare() + if hasattr(value, '_prepare'): + return value._prepare() + + if lookup_type in ( + 'regex', 'iregex', 'month', 'day', 'week_day', 'search', + 'contains', 'icontains', 'iexact', 'startswith', 'istartswith', + 'endswith', 'iendswith', 'isnull' + ): + return value + elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): + return self.get_prep_value(value) + elif lookup_type in ('range', 'in'): + return [self.get_prep_value(v) for v in value] + elif lookup_type == 'year': + try: + return int(value) + except ValueError: + raise ValueError("The __year lookup type requires an integer argument") - def get_db_prep_lookup(self, lookup_type, value): + raise TypeError("Field has invalid lookup: %s" % lookup_type) + + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): "Returns field's value prepared for database lookup." 
+ if not prepared: + value = self.get_prep_lookup(lookup_type, value) + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabel_aliases method, it will need to # be invoked before the final SQL is evaluated @@ -206,15 +245,15 @@ class Field(object): if hasattr(value, 'as_sql'): sql, params = value.as_sql() else: - sql, params = value._as_sql() + sql, params = value._as_sql(connection=connection) return QueryWrapper(('(%s)' % sql), params) if lookup_type in ('regex', 'iregex', 'month', 'day', 'week_day', 'search'): return [value] elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): - return [self.get_db_prep_value(value)] + return [self.get_db_prep_value(value, connection=connection, prepared=prepared)] elif lookup_type in ('range', 'in'): - return [self.get_db_prep_value(v) for v in value] + return [self.get_db_prep_value(v, connection=connection, prepared=prepared) for v in value] elif lookup_type in ('contains', 'icontains'): return ["%%%s%%" % connection.ops.prep_for_like_query(value)] elif lookup_type == 'iexact': @@ -226,18 +265,11 @@ class Field(object): elif lookup_type == 'isnull': return [] elif lookup_type == 'year': - try: - value = int(value) - except ValueError: - raise ValueError("The __year lookup type requires an integer argument") - if self.get_internal_type() == 'DateField': return connection.ops.year_lookup_bounds_for_date_field(value) else: return connection.ops.year_lookup_bounds(value) - raise TypeError("Field has invalid lookup: %s" % lookup_type) - def has_default(self): "Returns a boolean of whether this field has a default value." 
return self.default is not NOT_PROVIDED @@ -346,6 +378,7 @@ class Field(object): class AutoField(Field): description = ugettext_lazy("Integer") + empty_strings_allowed = False def __init__(self, *args, **kwargs): assert kwargs.get('primary_key', False) is True, "%ss must have primary_key=True." % self.__class__.__name__ @@ -361,7 +394,7 @@ class AutoField(Field): raise exceptions.ValidationError( _("This value must be an integer.")) - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return int(value) @@ -394,16 +427,16 @@ class BooleanField(Field): raise exceptions.ValidationError( _("This value must be either True or False.")) - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): # Special-case handling for filters coming from a web request (e.g. the # admin interface). Only works for scalar values (not lists). If you're # passing in a list, you might as well make things the right type when # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return super(BooleanField, self).get_db_prep_lookup(lookup_type, value) + return super(BooleanField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return bool(value) @@ -421,6 +454,7 @@ class BooleanField(Field): class CharField(Field): description = ugettext_lazy("String (up to %(max_length)s)") + def get_internal_type(self): return "CharField" @@ -443,6 +477,7 @@ class CharField(Field): # TODO: Maybe move this into contrib, because it's specialized. 
class CommaSeparatedIntegerField(CharField): description = ugettext_lazy("Comma-separated integers") + def formfield(self, **kwargs): defaults = { 'form_class': forms.RegexField, @@ -459,6 +494,7 @@ ansi_date_re = re.compile(r'^\d{4}-\d{1,2}-\d{1,2}$') class DateField(Field): description = ugettext_lazy("Date (without time)") + empty_strings_allowed = False def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): self.auto_now, self.auto_now_add = auto_now, auto_now_add @@ -509,16 +545,21 @@ class DateField(Field): setattr(cls, 'get_previous_by_%s' % self.name, curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)) - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): # For "__month", "__day", and "__week_day" lookups, convert the value # to an int so the database backend always sees a consistent type. if lookup_type in ('month', 'day', 'week_day'): - return [int(value)] - return super(DateField, self).get_db_prep_lookup(lookup_type, value) + return int(value) + return super(DateField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + return self.to_python(value) + + def get_db_prep_value(self, value, connection, prepared=False): # Casts dates into the format expected by the backend - return connection.ops.value_to_db_date(self.to_python(value)) + if not prepared: + value = self.get_prep_value(value) + return connection.ops.value_to_db_date(value) def value_to_string(self, obj): val = self._get_val_from_obj(obj) @@ -535,6 +576,7 @@ class DateField(Field): class DateTimeField(DateField): description = ugettext_lazy("Date (with time)") + def get_internal_type(self): return "DateTimeField" @@ -575,9 +617,14 @@ class DateTimeField(DateField): raise exceptions.ValidationError( _('Enter a valid date/time in YYYY-MM-DD HH:MM[:ss[.uuuuuu]] format.')) - def get_db_prep_value(self, value): + def 
get_prep_value(self, value): + return self.to_python(value) + + def get_db_prep_value(self, value, connection, prepared=False): # Casts dates into the format expected by the backend - return connection.ops.value_to_db_datetime(self.to_python(value)) + if not prepared: + value = self.get_prep_value(value) + return connection.ops.value_to_db_datetime(value) def value_to_string(self, obj): val = self._get_val_from_obj(obj) @@ -632,11 +679,11 @@ class DecimalField(Field): from django.db.backends import util return util.format_number(value, self.max_digits, self.decimal_places) - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): return connection.ops.value_to_db_decimal(self.to_python(value), self.max_digits, self.decimal_places) - def get_db_prep_value(self, value): + def get_prep_value(self, value): return self.to_python(value) def formfield(self, **kwargs): @@ -661,6 +708,7 @@ class EmailField(CharField): class FilePathField(Field): description = ugettext_lazy("File path") + def __init__(self, verbose_name=None, name=None, path='', match=None, recursive=False, **kwargs): self.path, self.match, self.recursive = path, match, recursive kwargs['max_length'] = kwargs.get('max_length', 100) @@ -683,7 +731,7 @@ class FloatField(Field): empty_strings_allowed = False description = ugettext_lazy("Floating point number") - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return float(value) @@ -708,7 +756,8 @@ class FloatField(Field): class IntegerField(Field): empty_strings_allowed = False description = ugettext_lazy("Integer") - def get_db_prep_value(self, value): + + def get_prep_value(self, value): if value is None: return None return int(value) @@ -776,16 +825,16 @@ class NullBooleanField(Field): raise exceptions.ValidationError( _("This value must be either None, True or False.")) - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): # 
Special-case handling for filters coming from a web request (e.g. the # admin interface). Only works for scalar values (not lists). If you're # passing in a list, you might as well make things the right type when # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value) + return super(NullBooleanField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): if value is None: return None return bool(value) @@ -801,6 +850,7 @@ class NullBooleanField(Field): class PositiveIntegerField(IntegerField): description = ugettext_lazy("Integer") + def get_internal_type(self): return "PositiveIntegerField" @@ -838,11 +888,13 @@ class SlugField(CharField): class SmallIntegerField(IntegerField): description = ugettext_lazy("Integer") + def get_internal_type(self): return "SmallIntegerField" class TextField(Field): description = ugettext_lazy("Text") + def get_internal_type(self): return "TextField" @@ -853,6 +905,7 @@ class TextField(Field): class TimeField(Field): description = ugettext_lazy("Time") + empty_strings_allowed = False def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): self.auto_now, self.auto_now_add = auto_now, auto_now_add @@ -907,9 +960,14 @@ class TimeField(Field): else: return super(TimeField, self).pre_save(model_instance, add) - def get_db_prep_value(self, value): + def get_prep_value(self, value): + return self.to_python(value) + + def get_db_prep_value(self, value, connection, prepared=False): # Casts times into the format expected by the backend - return connection.ops.value_to_db_time(self.to_python(value)) + if not prepared: + value = self.get_prep_value(value) + return connection.ops.value_to_db_time(value) def value_to_string(self, obj): val = self._get_val_from_obj(obj) @@ -926,6 +984,7 @@ class TimeField(Field): class URLField(CharField): description = 
ugettext_lazy("URL") + def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs): kwargs['max_length'] = kwargs.get('max_length', 200) self.verify_exists = verify_exists @@ -938,6 +997,7 @@ class URLField(CharField): class XMLField(TextField): description = ugettext_lazy("XML text") + def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs): self.schema_path = schema_path Field.__init__(self, verbose_name, name, **kwargs) diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index 97cb4dc082..6dfeddbc41 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -235,12 +235,12 @@ class FileField(Field): def get_internal_type(self): return "FileField" - def get_db_prep_lookup(self, lookup_type, value): + def get_prep_lookup(self, lookup_type, value): if hasattr(value, 'name'): value = value.name - return super(FileField, self).get_db_prep_lookup(lookup_type, value) + return super(FileField, self).get_prep_lookup(lookup_type, value) - def get_db_prep_value(self, value): + def get_prep_value(self, value): "Returns field's value prepared for saving into a database." 
# Need to convert File objects provided via a form to unicode for database insertion if value is None: diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index e412e6b6b8..7cc9a03907 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1,7 +1,8 @@ -from django.db import connection, transaction +from django.db import connection, transaction, DEFAULT_DB_ALIAS from django.db.backends import util from django.db.models import signals, get_model -from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist +from django.db.models.fields import (AutoField, Field, IntegerField, + PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) from django.db.models.related import RelatedObject from django.db.models.query import QuerySet from django.db.models.query_utils import QueryWrapper @@ -11,10 +12,6 @@ from django.utils.functional import curry from django.core import exceptions from django import forms -try: - set -except NameError: - from sets import Set as set # Python 2.3 fallback RECURSIVE_RELATIONSHIP_CONSTANT = 'self' @@ -120,7 +117,7 @@ class RelatedField(object): if not cls._meta.abstract: self.contribute_to_related_class(other, self.related) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): # If we are doing a lookup on a Related Field, we must be # comparing object instances. The value should be the PK of value, # not value itself. 
@@ -140,11 +137,16 @@ class RelatedField(object): if field: if lookup_type in ('range', 'in'): v = [v] - v = field.get_db_prep_lookup(lookup_type, v) + v = field.get_db_prep_lookup(lookup_type, v, + connection=connection, prepared=prepared) if isinstance(v, list): v = v[0] return v + if not prepared: + value = self.get_prep_lookup(lookup_type, value) + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabel_aliases method, it will need to # be invoked before the final SQL is evaluated @@ -153,7 +155,7 @@ class RelatedField(object): if hasattr(value, 'as_sql'): sql, params = value.as_sql() else: - sql, params = value._as_sql() + sql, params = value._as_sql(connection=connection) return QueryWrapper(('(%s)' % sql), params) # FIXME: lt and gt are explicitally allowed to make @@ -192,7 +194,7 @@ class SingleRelatedObjectDescriptor(object): return getattr(instance, self.cache_name) except AttributeError: params = {'%s__pk' % self.related.field.name: instance._get_pk_val()} - rel_obj = self.related.model._base_manager.get(**params) + rel_obj = self.related.model._base_manager.using(instance._state.db).get(**params) setattr(instance, self.cache_name, rel_obj) return rel_obj @@ -255,10 +257,11 @@ class ReverseSingleRelatedObjectDescriptor(object): # If the related manager indicates that it should be used for # related fields, respect that. 
rel_mgr = self.field.rel.to._default_manager + using = instance._state.db or DEFAULT_DB_ALIAS if getattr(rel_mgr, 'use_for_related_fields', False): - rel_obj = rel_mgr.get(**params) + rel_obj = rel_mgr.using(using).get(**params) else: - rel_obj = QuerySet(self.field.rel.to).get(**params) + rel_obj = QuerySet(self.field.rel.to).using(using).get(**params) setattr(instance, cache_name, rel_obj) return rel_obj @@ -275,6 +278,14 @@ class ReverseSingleRelatedObjectDescriptor(object): raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (value, instance._meta.object_name, self.field.name, self.field.rel.to._meta.object_name)) + elif value is not None and value._state.db != instance._state.db: + if instance._state.db is None: + instance._state.db = value._state.db + else:#elif value._state.db is None: + value._state.db = instance._state.db +# elif value._state.db is not None and instance._state.db is not None: +# raise ValueError('Cannot assign "%r": instance is on database "%s", value is is on database "%s"' % +# (value, instance._state.db, value._state.db)) # If we're setting the value of a OneToOneField to None, we need to clear # out the cache on any old related object. 
Otherwise, deleting the @@ -356,14 +367,15 @@ class ForeignRelatedObjectsDescriptor(object): class RelatedManager(superclass): def get_query_set(self): - return superclass.get_query_set(self).filter(**(self.core_filters)) + using = instance._state.db or DEFAULT_DB_ALIAS + return superclass.get_query_set(self).using(using).filter(**(self.core_filters)) def add(self, *objs): for obj in objs: if not isinstance(obj, self.model): raise TypeError, "'%s' instance expected" % self.model._meta.object_name setattr(obj, rel_field.name, instance) - obj.save() + obj.save(using=instance._state.db) add.alters_data = True def create(self, **kwargs): @@ -375,7 +387,8 @@ class ForeignRelatedObjectsDescriptor(object): # Update kwargs with the related object that this # ForeignRelatedObjectsDescriptor knows about. kwargs.update({rel_field.name: instance}) - return super(RelatedManager, self).get_or_create(**kwargs) + using = instance._state.db or DEFAULT_DB_ALIAS + return super(RelatedManager, self).using(using).get_or_create(**kwargs) get_or_create.alters_data = True # remove() and clear() are only provided if the ForeignKey can have a value of null. @@ -386,7 +399,7 @@ class ForeignRelatedObjectsDescriptor(object): # Is obj actually part of this descriptor set? if getattr(obj, rel_field.attname) == val: setattr(obj, rel_field.name, None) - obj.save() + obj.save(using=instance._state.db) else: raise rel_field.rel.to.DoesNotExist, "%r is not related to %r." % (obj, instance) remove.alters_data = True @@ -394,7 +407,7 @@ class ForeignRelatedObjectsDescriptor(object): def clear(self): for obj in self.all(): setattr(obj, rel_field.name, None) - obj.save() + obj.save(using=instance._state.db) clear.alters_data = True manager = RelatedManager() @@ -425,7 +438,7 @@ def create_many_related_manager(superclass, rel=False): raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." 
% instance.__class__.__name__) def get_query_set(self): - return superclass.get_query_set(self)._next_is_sticky().filter(**(self.core_filters)) + return superclass.get_query_set(self).using(self.instance._state.db)._next_is_sticky().filter(**(self.core_filters)) # If the ManyToMany relation has an intermediary model, # the add and remove methods do not exist. @@ -460,14 +473,14 @@ def create_many_related_manager(superclass, rel=False): if not rel.through._meta.auto_created: opts = through._meta raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name) - new_obj = super(ManyRelatedManager, self).create(**kwargs) + new_obj = super(ManyRelatedManager, self).using(self.instance._state.db).create(**kwargs) self.add(new_obj) return new_obj create.alters_data = True def get_or_create(self, **kwargs): obj, created = \ - super(ManyRelatedManager, self).get_or_create(**kwargs) + super(ManyRelatedManager, self).using(self.instance._state.db).get_or_create(**kwargs) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. 
if created: @@ -487,12 +500,15 @@ def create_many_related_manager(superclass, rel=False): new_ids = set() for obj in objs: if isinstance(obj, self.model): +# if obj._state.db != self.instance._state.db: +# raise ValueError('Cannot add "%r": instance is on database "%s", value is is on database "%s"' % +# (obj, self.instance._state.db, obj._state.db)) new_ids.add(obj.pk) elif isinstance(obj, Model): raise TypeError, "'%s' instance expected" % self.model._meta.object_name else: new_ids.add(obj) - vals = self.through._default_manager.values_list(target_field_name, flat=True) + vals = self.through._default_manager.using(self.instance._state.db).values_list(target_field_name, flat=True) vals = vals.filter(**{ source_field_name: self._pk_val, '%s__in' % target_field_name: new_ids, @@ -501,7 +517,7 @@ def create_many_related_manager(superclass, rel=False): # Add the ones that aren't there already for obj_id in (new_ids - vals): - self.through._default_manager.create(**{ + self.through._default_manager.using(self.instance._state.db).create(**{ '%s_id' % source_field_name: self._pk_val, '%s_id' % target_field_name: obj_id, }) @@ -521,14 +537,14 @@ def create_many_related_manager(superclass, rel=False): else: old_ids.add(obj) # Remove the specified objects from the join table - self.through._default_manager.filter(**{ + self.through._default_manager.using(self.instance._state.db).filter(**{ source_field_name: self._pk_val, '%s__in' % target_field_name: old_ids }).delete() def _clear_items(self, source_field_name): # source_col_name: the PK colname in join_table for the source object - self.through._default_manager.filter(**{ + self.through._default_manager.using(self.instance._state.db).filter(**{ source_field_name: self._pk_val }).delete() @@ -728,11 +744,12 @@ class ForeignKey(RelatedField, Field): return getattr(field_default, self.rel.get_related_field().attname) return field_default - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): if 
value == '' or value == None: return None else: - return self.rel.get_related_field().get_db_prep_save(value) + return self.rel.get_related_field().get_db_prep_save(value, + connection=connection) def value_to_string(self, obj): if not obj: @@ -764,16 +781,16 @@ class ForeignKey(RelatedField, Field): self.rel.field_name = cls._meta.pk.name def formfield(self, **kwargs): + db = kwargs.pop('using', None) defaults = { 'form_class': forms.ModelChoiceField, - 'queryset': self.rel.to._default_manager.complex_filter( - self.rel.limit_choices_to), + 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to), 'to_field_name': self.rel.field_name, } defaults.update(kwargs) return super(ForeignKey, self).formfield(**defaults) - def db_type(self): + def db_type(self, connection): # The database column type of a ForeignKey is the column type # of the field to which it points. An exception is if the ForeignKey # points to an AutoField/PositiveIntegerField/PositiveSmallIntegerField, @@ -785,8 +802,8 @@ class ForeignKey(RelatedField, Field): (not connection.features.related_fields_match_type and isinstance(rel_field, (PositiveIntegerField, PositiveSmallIntegerField)))): - return IntegerField().db_type() - return rel_field.db_type() + return IntegerField().db_type(connection=connection) + return rel_field.db_type(connection=connection) class OneToOneField(ForeignKey): """ @@ -1012,7 +1029,11 @@ class ManyToManyField(RelatedField, Field): setattr(instance, self.attname, data) def formfield(self, **kwargs): - defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.complex_filter(self.rel.limit_choices_to)} + db = kwargs.pop('using', None) + defaults = { + 'form_class': forms.ModelMultipleChoiceField, + 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to) + } defaults.update(kwargs) # If initial is passed in, it's a list of related objects, but the # MultipleChoiceField 
takes a list of IDs. @@ -1023,7 +1044,7 @@ class ManyToManyField(RelatedField, Field): defaults['initial'] = [i._get_pk_val() for i in initial] return super(ManyToManyField, self).formfield(**defaults) - def db_type(self): + def db_type(self, connection): # A ManyToManyField is not represented by a single column, # so return None. return None diff --git a/django/db/models/fields/subclassing.py b/django/db/models/fields/subclassing.py index 10add10739..bd11675ad3 100644 --- a/django/db/models/fields/subclassing.py +++ b/django/db/models/fields/subclassing.py @@ -1,11 +1,77 @@ """ -Convenience routines for creating non-trivial Field subclasses. +Convenience routines for creating non-trivial Field subclasses, as well as +backwards compatibility utilities. Add SubfieldBase as the __metaclass__ for your Field subclass, implement to_python() and the other necessary methods and everything will work seamlessly. """ -class SubfieldBase(type): +from inspect import getargspec +from warnings import warn + +def call_with_connection(func): + arg_names, varargs, varkwargs, defaults = getargspec(func) + updated = ('connection' in arg_names or varkwargs) + if not updated: + warn("A Field class whose %s method hasn't been updated to take a " + "`connection` argument." % func.__name__, + PendingDeprecationWarning, stacklevel=2) + + def inner(*args, **kwargs): + if 'connection' not in kwargs: + from django.db import connection + kwargs['connection'] = connection + warn("%s has been called without providing a connection argument. 
" % + func.__name__, PendingDeprecationWarning, + stacklevel=1) + if updated: + return func(*args, **kwargs) + if 'connection' in kwargs: + del kwargs['connection'] + return func(*args, **kwargs) + return inner + +def call_with_connection_and_prepared(func): + arg_names, varargs, varkwargs, defaults = getargspec(func) + updated = ( + ('connection' in arg_names or varkwargs) and + ('prepared' in arg_names or varkwargs) + ) + if not updated: + warn("A Field class whose %s method hasn't been updated to take " + "`connection` and `prepared` arguments." % func.__name__, + PendingDeprecationWarning, stacklevel=2) + + def inner(*args, **kwargs): + if 'connection' not in kwargs: + from django.db import connection + kwargs['connection'] = connection + warn("%s has been called without providing a connection argument. " % + func.__name__, PendingDeprecationWarning, + stacklevel=1) + if updated: + return func(*args, **kwargs) + if 'connection' in kwargs: + del kwargs['connection'] + if 'prepared' in kwargs: + del kwargs['prepared'] + return func(*args, **kwargs) + return inner + +class LegacyConnection(type): + """ + A metaclass to normalize arguments give to the get_db_prep_* and db_type + methods on fields. + """ + def __new__(cls, names, bases, attrs): + new_cls = super(LegacyConnection, cls).__new__(cls, names, bases, attrs) + for attr in ('db_type', 'get_db_prep_save'): + setattr(new_cls, attr, call_with_connection(getattr(new_cls, attr))) + for attr in ('get_db_prep_lookup', 'get_db_prep_value'): + setattr(new_cls, attr, call_with_connection_and_prepared(getattr(new_cls, attr))) + return new_cls + +class SubfieldBase(LegacyConnection): """ A metaclass for custom Field subclasses. This ensures the model's attribute has the descriptor protocol attached to it. 
@@ -26,7 +92,7 @@ class Creator(object): def __get__(self, obj, type=None): if obj is None: raise AttributeError('Can only be accessed via an instance.') - return obj.__dict__[self.field.name] + return obj.__dict__[self.field.name] def __set__(self, obj, value): obj.__dict__[self.field.name] = self.field.to_python(value) diff --git a/django/db/models/manager.py b/django/db/models/manager.py index c4d47e0d36..7f96daaa4e 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -1,4 +1,6 @@ -import django.utils.copycompat as copy +from django.utils import copycompat as copy + +from django.db import DEFAULT_DB_ALIAS from django.db.models.query import QuerySet, EmptyQuerySet, insert_query, RawQuerySet from django.db.models import signals from django.db.models.fields import FieldDoesNotExist @@ -49,6 +51,7 @@ class Manager(object): self._set_creation_counter() self.model = None self._inherited = False + self._db = None def contribute_to_class(self, model, name): # TODO: Use weakref because of possible memory leak / circular reference. @@ -84,6 +87,15 @@ class Manager(object): mgr._inherited = True return mgr + def db_manager(self, alias): + obj = copy.copy(self) + obj._db = alias + return obj + + @property + def db(self): + return self._db or DEFAULT_DB_ALIAS + ####################### # PROXIES TO QUERYSET # ####################### @@ -95,7 +107,10 @@ class Manager(object): """Returns a new QuerySet object. Subclasses can override this method to easily customize the behavior of the Manager. 
""" - return QuerySet(self.model) + qs = QuerySet(self.model) + if self._db is not None: + qs = qs.using(self._db) + return qs def none(self): return self.get_empty_query_set() @@ -172,6 +187,9 @@ class Manager(object): def only(self, *args, **kwargs): return self.get_query_set().only(*args, **kwargs) + def using(self, *args, **kwargs): + return self.get_query_set().using(*args, **kwargs) + def exists(self, *args, **kwargs): return self.get_query_set().exists(*args, **kwargs) @@ -181,8 +199,8 @@ class Manager(object): def _update(self, values, **kwargs): return self.get_query_set()._update(values, **kwargs) - def raw(self, query, params=None, *args, **kwargs): - return RawQuerySet(model=self.model, query=query, params=params, *args, **kwargs) + def raw(self, raw_query, params=None, *args, **kwargs): + return RawQuerySet(raw_query=raw_query, model=self.model, params=params, using=self.db, *args, **kwargs) class ManagerDescriptor(object): # This class ensures managers aren't accessible via model instances. diff --git a/django/db/models/query.py b/django/db/models/query.py index 8c71155c0e..8799b4a93b 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -2,7 +2,9 @@ The main QuerySet implementation. This provides the public API for the ORM. """ -from django.db import connection, transaction, IntegrityError +from copy import deepcopy + +from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS from django.db.models.aggregates import Aggregate from django.db.models.fields import DateField from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery @@ -24,9 +26,11 @@ class QuerySet(object): """ Represents a lazy database lookup for a set of objects. 
""" - def __init__(self, model=None, query=None): + def __init__(self, model=None, query=None, using=None): self.model = model - self.query = query or sql.Query(self.model, connection) + # EmptyQuerySet instantiates QuerySet with model as None + self._db = using + self.query = query or sql.Query(self.model) self._result_cache = None self._iter = None self._sticky_filter = False @@ -258,7 +262,8 @@ class QuerySet(object): init_list.append(field.attname) model_cls = deferred_class_factory(self.model, skip) - for row in self.query.results_iter(): + compiler = self.query.get_compiler(using=self.db) + for row in compiler.results_iter(): if fill_cache: obj, _ = get_cached_row(self.model, row, index_start, max_depth, @@ -280,6 +285,9 @@ class QuerySet(object): for i, aggregate in enumerate(aggregate_select): setattr(obj, aggregate, row[i+aggregate_start]) + # Store the source database of the object + obj._state.db = self.db + yield obj def aggregate(self, *args, **kwargs): @@ -299,7 +307,7 @@ class QuerySet(object): query.add_aggregate(aggregate_expr, self.model, alias, is_summary=True) - return query.get_aggregation() + return query.get_aggregation(using=self.db) def count(self): """ @@ -312,7 +320,7 @@ class QuerySet(object): if self._result_cache is not None and not self._iter: return len(self._result_cache) - return self.query.get_count() + return self.query.get_count(using=self.db) def get(self, *args, **kwargs): """ @@ -337,7 +345,7 @@ class QuerySet(object): and returning the created object. 
""" obj = self.model(**kwargs) - obj.save(force_insert=True) + obj.save(force_insert=True, using=self.db) return obj def get_or_create(self, **kwargs): @@ -356,12 +364,12 @@ class QuerySet(object): params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params.update(defaults) obj = self.model(**params) - sid = transaction.savepoint() - obj.save(force_insert=True) - transaction.savepoint_commit(sid) + sid = transaction.savepoint(using=self.db) + obj.save(force_insert=True, using=self.db) + transaction.savepoint_commit(sid, using=self.db) return obj, True except IntegrityError, e: - transaction.savepoint_rollback(sid) + transaction.savepoint_rollback(sid, using=self.db) try: return self.get(**kwargs), False except self.model.DoesNotExist: @@ -421,7 +429,7 @@ class QuerySet(object): if not seen_objs: break - delete_objects(seen_objs) + delete_objects(seen_objs, del_query.db) # Clear the result cache, in case this QuerySet gets reused. self._result_cache = None @@ -436,20 +444,20 @@ class QuerySet(object): "Cannot update a query once a slice has been taken." 
query = self.query.clone(sql.UpdateQuery) query.add_update_values(kwargs) - if not transaction.is_managed(): - transaction.enter_transaction_management() + if not transaction.is_managed(using=self.db): + transaction.enter_transaction_management(using=self.db) forced_managed = True else: forced_managed = False try: - rows = query.execute_sql(None) + rows = query.get_compiler(self.db).execute_sql(None) if forced_managed: - transaction.commit() + transaction.commit(using=self.db) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=self.db) finally: if forced_managed: - transaction.leave_transaction_management() + transaction.leave_transaction_management(using=self.db) self._result_cache = None return rows update.alters_data = True @@ -466,12 +474,12 @@ class QuerySet(object): query = self.query.clone(sql.UpdateQuery) query.add_update_fields(values) self._result_cache = None - return query.execute_sql(None) + return query.get_compiler(self.db).execute_sql(None) _update.alters_data = True def exists(self): if self._result_cache is None: - return self.query.has_results() + return self.query.has_results(using=self.db) return bool(self._result_cache) ################################################## @@ -678,6 +686,14 @@ class QuerySet(object): clone.query.add_immediate_loading(fields) return clone + def using(self, alias): + """ + Selects which database this QuerySet should excecute it's query against. 
+ """ + clone = self._clone() + clone._db = alias + return clone + ################################### # PUBLIC INTROSPECTION ATTRIBUTES # ################################### @@ -695,6 +711,11 @@ class QuerySet(object): return False ordered = property(ordered) + @property + def db(self): + "Return the database that will be used if this query is executed now" + return self._db or DEFAULT_DB_ALIAS + ################### # PRIVATE METHODS # ################### @@ -706,6 +727,7 @@ class QuerySet(object): if self._sticky_filter: query.filter_is_sticky = True c = klass(model=self.model, query=query) + c._db = self._db c.__dict__.update(kwargs) if setup and hasattr(c, '_setup_query'): c._setup_query() @@ -755,12 +777,17 @@ class QuerySet(object): self.query.add_fields(field_names, False) self.query.set_group_by() - def _as_sql(self): + def _prepare(self): + return self + + def _as_sql(self, connection): """ Returns the internal query's SQL and parameters (as a tuple). """ obj = self.values("pk") - return obj.query.as_nested_sql() + if connection == connections[obj.db]: + return obj.query.get_compiler(connection=connection).as_nested_sql() + raise ValueError("Can't do subqueries with queries on different DBs.") # When used as part of a nested query, a queryset will never be an "always # empty" result. 
@@ -783,7 +810,7 @@ class ValuesQuerySet(QuerySet): names = extra_names + field_names + aggregate_names - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): yield dict(zip(names, row)) def _setup_query(self): @@ -866,7 +893,7 @@ class ValuesQuerySet(QuerySet): super(ValuesQuerySet, self)._setup_aggregate_query(aggregates) - def _as_sql(self): + def _as_sql(self, connection): """ For ValueQuerySet (and subclasses like ValuesListQuerySet), they can only be used as nested queries if they're already set up to select only @@ -878,15 +905,30 @@ class ValuesQuerySet(QuerySet): (not self._fields and len(self.model._meta.fields) > 1)): raise TypeError('Cannot use a multi-field %s as a filter value.' % self.__class__.__name__) - return self._clone().query.as_nested_sql() + + obj = self._clone() + if connection == connections[obj.db]: + return obj.query.get_compiler(connection=connection).as_nested_sql() + raise ValueError("Can't do subqueries with queries on different DBs.") + + def _prepare(self): + """ + Validates that we aren't trying to do a query like + value__in=qs.values('value1', 'value2'), which isn't valid. + """ + if ((self._fields and len(self._fields) > 1) or + (not self._fields and len(self.model._meta.fields) > 1)): + raise TypeError('Cannot use a multi-field %s as a filter value.' + % self.__class__.__name__) + return self class ValuesListQuerySet(ValuesQuerySet): def iterator(self): if self.flat and len(self._fields) == 1: - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): yield row[0] elif not self.query.extra_select and not self.query.aggregate_select: - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): yield tuple(row) else: # When extra(select=...) 
or an annotation is involved, the extra @@ -905,7 +947,7 @@ class ValuesListQuerySet(ValuesQuerySet): else: fields = names - for row in self.query.results_iter(): + for row in self.query.get_compiler(self.db).results_iter(): data = dict(zip(names, row)) yield tuple([data[f] for f in fields]) @@ -917,7 +959,7 @@ class ValuesListQuerySet(ValuesQuerySet): class DateQuerySet(QuerySet): def iterator(self): - return self.query.results_iter() + return self.query.get_compiler(self.db).results_iter() def _setup_query(self): """ @@ -1032,13 +1074,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def delete_objects(seen_objs): +def delete_objects(seen_objs, using): """ Iterate through a list of seen classes, and remove any instances that are referred to. """ - if not transaction.is_managed(): - transaction.enter_transaction_management() + connection = connections[using] + if not transaction.is_managed(using=using): + transaction.enter_transaction_management(using=using) forced_managed = True else: forced_managed = False @@ -1064,19 +1107,18 @@ def delete_objects(seen_objs): signals.pre_delete.send(sender=cls, instance=instance) pk_list = [pk for pk,instance in items] - del_query = sql.DeleteQuery(cls, connection) - del_query.delete_batch_related(pk_list) + del_query = sql.DeleteQuery(cls) + del_query.delete_batch_related(pk_list, using=using) - update_query = sql.UpdateQuery(cls, connection) + update_query = sql.UpdateQuery(cls) for field, model in cls._meta.get_fields_with_model(): if (field.rel and field.null and field.rel.to in seen_objs and filter(lambda f: f.column == field.rel.get_related_field().column, field.rel.to._meta.fields)): if model: - sql.UpdateQuery(model, connection).clear_related(field, - pk_list) + sql.UpdateQuery(model).clear_related(field, pk_list, using=using) else: - update_query.clear_related(field, pk_list) + update_query.clear_related(field, pk_list, 
using=using) # Now delete the actual data. for cls in ordered_classes: @@ -1084,8 +1126,8 @@ def delete_objects(seen_objs): items.reverse() pk_list = [pk for pk,instance in items] - del_query = sql.DeleteQuery(cls, connection) - del_query.delete_batch(pk_list) + del_query = sql.DeleteQuery(cls) + del_query.delete_batch(pk_list, using=using) # Last cleanup; set NULLs where there once was a reference to the # object, NULL the primary key of the found objects, and perform @@ -1100,21 +1142,24 @@ def delete_objects(seen_objs): setattr(instance, cls._meta.pk.attname, None) if forced_managed: - transaction.commit() + transaction.commit(using=using) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=using) finally: if forced_managed: - transaction.leave_transaction_management() + transaction.leave_transaction_management(using=using) class RawQuerySet(object): """ Provides an iterator which converts the results of raw SQL queries into annotated model instances. """ - def __init__(self, query, model=None, query_obj=None, params=None, translations=None): + def __init__(self, raw_query, model=None, query=None, params=None, + translations=None, using=None): + self.raw_query = raw_query self.model = model - self.query = query_obj or sql.RawQuery(sql=query, connection=connection, params=params) + self._db = using + self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params) self.params = params or () self.translations = translations or {} @@ -1123,7 +1168,21 @@ class RawQuerySet(object): yield self.transform_results(row) def __repr__(self): - return "<RawQuerySet: %r>" % (self.query.sql % self.params) + return "<RawQuerySet: %r>" % (self.raw_query % self.params) + + @property + def db(self): + "Return the database that will be used if this query is executed now" + return self._db or DEFAULT_DB_ALIAS + + def using(self, alias): + """ + Selects which database this Raw QuerySet should excecute it's query against. 
+ """ + return RawQuerySet(self.raw_query, model=self.model, + query=self.query.clone(using=alias), + params=self.params, translations=self.translations, + using=alias) @property def columns(self): @@ -1189,14 +1248,16 @@ class RawQuerySet(object): for field, value in annotations: setattr(instance, field, value) + instance._state.db = self.query.using + return instance -def insert_query(model, values, return_id=False, raw_values=False): +def insert_query(model, values, return_id=False, raw_values=False, using=None): """ Inserts a new record for the given model. This provides an interface to the InsertQuery class and is how Model.save() is implemented. It is not part of the public API. """ - query = sql.InsertQuery(model, connection) + query = sql.InsertQuery(model) query.insert_values(values, raw_values) - return query.execute_sql(return_id) + return query.get_compiler(using=using).execute_sql(return_id) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 746b04d4fb..9f6083ce7e 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -134,7 +134,7 @@ class QueryWrapper(object): def __init__(self, sql, params): self.data = sql, params - def as_sql(self, qn=None): + def as_sql(self, qn=None, connection=None): return self.data class Q(tree.Node): @@ -187,7 +187,7 @@ class DeferredAttribute(object): cls = self.model_ref() data = instance.__dict__ if data.get(self.field_name, self) is self: - data[self.field_name] = cls._base_manager.filter(pk=instance.pk).values_list(self.field_name, flat=True).get() + data[self.field_name] = cls._base_manager.filter(pk=instance.pk).values_list(self.field_name, flat=True).using(instance._state.db).get() return data[self.field_name] def __set__(self, instance, value): diff --git a/django/db/models/related.py b/django/db/models/related.py index ff7c787a93..afdf3f7b61 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -18,9 +18,10 @@ class 
RelatedObject(object): self.name = '%s:%s' % (self.opts.app_label, self.opts.module_name) self.var_name = self.opts.object_name.lower() - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): # Defer to the actual field definition for db prep - return self.field.get_db_prep_lookup(lookup_type, value) + return self.field.get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=prepared) def editable_fields(self): "Get the fields in this class that should be edited inline." diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 6fdaf188c4..8a14bdf2df 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -72,15 +72,13 @@ class Aggregate(object): if isinstance(self.col, (list, tuple)): self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) - def as_sql(self, quote_func=None): + def as_sql(self, qn, connection): "Return the aggregate, rendered as SQL." 
- if not quote_func: - quote_func = lambda x: x if hasattr(self.col, 'as_sql'): - field_name = self.col.as_sql(quote_func) + field_name = self.col.as_sql(qn, connection) elif isinstance(self.col, (list, tuple)): - field_name = '.'.join([quote_func(c) for c in self.col]) + field_name = '.'.join([qn(c) for c in self.col]) else: field_name = self.col @@ -127,4 +125,3 @@ class Variance(Aggregate): def __init__(self, col, sample=False, **extra): super(Variance, self).__init__(col, **extra) self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP' - diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py new file mode 100644 index 0000000000..6a95d32259 --- /dev/null +++ b/django/db/models/sql/compiler.py @@ -0,0 +1,921 @@ +from django.core.exceptions import FieldError +from django.db import connections +from django.db.backends.util import truncate_name +from django.db.models.sql.constants import * +from django.db.models.sql.datastructures import EmptyResultSet +from django.db.models.sql.expressions import SQLEvaluator +from django.db.models.sql.query import get_proxied_model, get_order_dir, \ + select_related_descend, Query + +class SQLCompiler(object): + def __init__(self, query, connection, using): + self.query = query + self.connection = connection + self.using = using + self.quote_cache = {} + + # Check that the compiler will be able to execute the query + for alias, aggregate in self.query.aggregate_select.items(): + self.connection.ops.check_aggregate_support(aggregate) + + def pre_sql_setup(self): + """ + Does any necessary class setup immediately prior to producing SQL. This + is for things that can't necessarily be done in __init__ because we + might not have all the pieces in place at that time. 
+ """ + if not self.query.tables: + self.query.join((None, self.query.model._meta.db_table, None, None)) + if (not self.query.select and self.query.default_cols and not + self.query.included_inherited_models): + self.query.setup_inherited_models() + if self.query.select_related and not self.query.related_select_cols: + self.fill_related_selections() + + def quote_name_unless_alias(self, name): + """ + A wrapper around connection.ops.quote_name that doesn't quote aliases + for table names. This avoids problems with some SQL dialects that treat + quoted strings specially (e.g. PostgreSQL). + """ + if name in self.quote_cache: + return self.quote_cache[name] + if ((name in self.query.alias_map and name not in self.query.table_map) or + name in self.query.extra_select): + self.quote_cache[name] = name + return name + r = self.connection.ops.quote_name(name) + self.quote_cache[name] = r + return r + + def as_sql(self, with_limits=True, with_col_aliases=False): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + + If 'with_limits' is False, any limit/offset information is not included + in the query. + """ + self.pre_sql_setup() + out_cols = self.get_columns(with_col_aliases) + ordering, ordering_group_by = self.get_ordering() + + # This must come after 'select' and 'ordering' -- see docstring of + # get_from_clause() for details. 
+ from_, f_params = self.get_from_clause() + + qn = self.quote_name_unless_alias + + where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) + having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) + params = [] + for val in self.query.extra_select.itervalues(): + params.extend(val[1]) + + result = ['SELECT'] + if self.query.distinct: + result.append('DISTINCT') + result.append(', '.join(out_cols + self.query.ordering_aliases)) + + result.append('FROM') + result.extend(from_) + params.extend(f_params) + + if where: + result.append('WHERE %s' % where) + params.extend(w_params) + if self.query.extra_where: + if not where: + result.append('WHERE') + else: + result.append('AND') + result.append(' AND '.join(self.query.extra_where)) + + grouping, gb_params = self.get_grouping() + if grouping: + if ordering: + # If the backend can't group by PK (i.e., any database + # other than MySQL), then any fields mentioned in the + # ordering clause needs to be in the group by clause. 
+ if not self.connection.features.allows_group_by_pk: + for col, col_params in ordering_group_by: + if col not in grouping: + grouping.append(str(col)) + gb_params.extend(col_params) + else: + ordering = self.connection.ops.force_no_ordering() + result.append('GROUP BY %s' % ', '.join(grouping)) + params.extend(gb_params) + + if having: + result.append('HAVING %s' % having) + params.extend(h_params) + + if ordering: + result.append('ORDER BY %s' % ', '.join(ordering)) + + if with_limits: + if self.query.high_mark is not None: + result.append('LIMIT %d' % (self.query.high_mark - self.query.low_mark)) + if self.query.low_mark: + if self.query.high_mark is None: + val = self.connection.ops.no_limit_value() + if val: + result.append('LIMIT %d' % val) + result.append('OFFSET %d' % self.query.low_mark) + + params.extend(self.query.extra_params) + return ' '.join(result), tuple(params) + + def as_nested_sql(self): + """ + Perform the same functionality as the as_sql() method, returning an + SQL string and parameters. However, the alias prefixes are bumped + beforehand (in a copy -- the current query isn't changed) and any + ordering is removed. + + Used when nesting this query inside another. + """ + obj = self.query.clone() + obj.clear_ordering(True) + obj.bump_prefix() + return obj.get_compiler(connection=self.connection).as_sql() + + def get_columns(self, with_aliases=False): + """ + Returns the list of columns to use in the select statement. If no + columns have been specified, returns all columns relating to fields in + the model. + + If 'with_aliases' is true, any column names that are duplicated + (without the table names) are given unique aliases. This is needed in + some cases to avoid ambiguity with nested queries. 
+ """ + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] + aliases = set(self.query.extra_select.keys()) + if with_aliases: + col_aliases = aliases.copy() + else: + col_aliases = set() + if self.query.select: + only_load = self.deferred_to_columns() + for col in self.query.select: + if isinstance(col, (list, tuple)): + alias, column = col + table = self.query.alias_map[alias][TABLE_NAME] + if table in only_load and col not in only_load[table]: + continue + r = '%s.%s' % (qn(alias), qn(column)) + if with_aliases: + if col[1] in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append('%s AS %s' % (r, qn2(col[1]))) + aliases.add(r) + col_aliases.add(col[1]) + else: + result.append(r) + aliases.add(r) + col_aliases.add(col[1]) + else: + result.append(col.as_sql(qn, self.connection)) + + if hasattr(col, 'alias'): + aliases.add(col.alias) + col_aliases.add(col.alias) + + elif self.query.default_cols: + cols, new_aliases = self.get_default_columns(with_aliases, + col_aliases) + result.extend(cols) + aliases.update(new_aliases) + + max_name_length = self.connection.ops.max_name_length() + result.extend([ + '%s%s' % ( + aggregate.as_sql(qn, self.connection), + alias is not None + and ' AS %s' % qn(truncate_name(alias, max_name_length)) + or '' + ) + for alias, aggregate in self.query.aggregate_select.items() + ]) + + for table, col in self.query.related_select_cols: + r = '%s.%s' % (qn(table), qn(col)) + if with_aliases and col in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append(r) + aliases.add(r) + col_aliases.add(col) + + self._select_aliases = aliases + return result + + def get_default_columns(self, 
with_aliases=False, col_aliases=None, + start_alias=None, opts=None, as_pairs=False): + """ + Computes the default columns for selecting every field in the base + model. Will sometimes be called to pull in related models (e.g. via + select_related), in which case "opts" and "start_alias" will be given + to provide a starting point for the traversal. + + Returns a list of strings, quoted appropriately for use in SQL + directly, as well as a set of aliases used in the select statement (if + 'as_pairs' is True, returns a list of (alias, col_name) pairs instead + of strings as the first component and None as the second component). + """ + result = [] + if opts is None: + opts = self.query.model._meta + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + aliases = set() + only_load = self.deferred_to_columns() + # Skip all proxy to the root proxied model + proxied_model = get_proxied_model(opts) + + if start_alias: + seen = {None: start_alias} + for field, model in opts.get_fields_with_model(): + if start_alias: + try: + alias = seen[model] + except KeyError: + if model is proxied_model: + alias = start_alias + else: + link_field = opts.get_ancestor_link(model) + alias = self.query.join((start_alias, model._meta.db_table, + link_field.column, model._meta.pk.column)) + seen[model] = alias + else: + # If we're starting from the base model of the queryset, the + # aliases will have already been set up in pre_sql_setup(), so + # we can save time here. 
+ alias = self.query.included_inherited_models[model] + table = self.query.alias_map[alias][TABLE_NAME] + if table in only_load and field.column not in only_load[table]: + continue + if as_pairs: + result.append((alias, field.column)) + aliases.add(alias) + continue + if with_aliases and field.column in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s.%s AS %s' % (qn(alias), + qn2(field.column), c_alias)) + col_aliases.add(c_alias) + aliases.add(c_alias) + else: + r = '%s.%s' % (qn(alias), qn2(field.column)) + result.append(r) + aliases.add(r) + if with_aliases: + col_aliases.add(field.column) + return result, aliases + + def get_ordering(self): + """ + Returns a tuple containing a list representing the SQL elements in the + "order by" clause, and the list of SQL elements that need to be added + to the GROUP BY clause as a result of the ordering. + + Also sets the ordering_aliases attribute on this instance to a list of + extra aliases needed in the select. + + Determining the ordering SQL can change the tables we need to include, + so this should be run *before* get_from_clause(). + """ + if self.query.extra_order_by: + ordering = self.query.extra_order_by + elif not self.query.default_ordering: + ordering = self.query.order_by + else: + ordering = self.query.order_by or self.query.model._meta.ordering + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + distinct = self.query.distinct + select_aliases = self._select_aliases + result = [] + group_by = [] + ordering_aliases = [] + if self.query.standard_ordering: + asc, desc = ORDER_DIR['ASC'] + else: + asc, desc = ORDER_DIR['DESC'] + + # It's possible, due to model inheritance, that normal usage might try + # to include the same field more than once in the ordering. We track + # the table/column pairs we use and discard any after the first use. 
+ processed_pairs = set() + + for field in ordering: + if field == '?': + result.append(self.connection.ops.random_function_sql()) + continue + if isinstance(field, int): + if field < 0: + order = desc + field = -field + else: + order = asc + result.append('%s %s' % (field, order)) + group_by.append((field, [])) + continue + col, order = get_order_dir(field, asc) + if col in self.query.aggregate_select: + result.append('%s %s' % (col, order)) + continue + if '.' in field: + # This came in through an extra(order_by=...) addition. Pass it + # on verbatim. + table, col = col.split('.', 1) + if (table, col) not in processed_pairs: + elt = '%s.%s' % (qn(table), col) + processed_pairs.add((table, col)) + if not distinct or elt in select_aliases: + result.append('%s %s' % (elt, order)) + group_by.append((elt, [])) + elif get_order_dir(field)[0] not in self.query.extra_select: + # 'col' is of the form 'field' or 'field1__field2' or + # '-field1__field2__field', etc. + for table, col, order in self.find_ordering_name(field, + self.query.model._meta, default_order=asc): + if (table, col) not in processed_pairs: + elt = '%s.%s' % (qn(table), qn2(col)) + processed_pairs.add((table, col)) + if distinct and elt not in select_aliases: + ordering_aliases.append(elt) + result.append('%s %s' % (elt, order)) + group_by.append((elt, [])) + else: + elt = qn2(col) + if distinct and col not in select_aliases: + ordering_aliases.append(elt) + result.append('%s %s' % (elt, order)) + group_by.append(self.query.extra_select[col]) + self.query.ordering_aliases = ordering_aliases + return result, group_by + + def find_ordering_name(self, name, opts, alias=None, default_order='ASC', + already_seen=None): + """ + Returns the table alias (the name might be ambiguous, the alias will + not be) and column name for ordering by the given 'name' parameter. + The 'name' is of the form 'field1__field2__...__fieldN'. 
+ """ + name, order = get_order_dir(name, default_order) + pieces = name.split(LOOKUP_SEP) + if not alias: + alias = self.query.get_initial_alias() + field, target, opts, joins, last, extra = self.query.setup_joins(pieces, + opts, alias, False) + alias = joins[-1] + col = target.column + if not field.rel: + # To avoid inadvertent trimming of a necessary alias, use the + # refcount to show that we are referencing a non-relation field on + # the model. + self.query.ref_alias(alias) + + # Must use left outer joins for nullable fields and their relations. + self.query.promote_alias_chain(joins, + self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER) + + # If we get to this point and the field is a relation to another model, + # append the default ordering for that model. + if field.rel and len(joins) > 1 and opts.ordering: + # Firstly, avoid infinite loops. + if not already_seen: + already_seen = set() + join_tuple = tuple([self.query.alias_map[j][TABLE_NAME] for j in joins]) + if join_tuple in already_seen: + raise FieldError('Infinite loop caused by ordering.') + already_seen.add(join_tuple) + + results = [] + for item in opts.ordering: + results.extend(self.find_ordering_name(item, opts, alias, + order, already_seen)) + return results + + if alias: + # We have to do the same "final join" optimisation as in + # add_filter, since the final column might not otherwise be part of + # the select set (so we can't order on it). + while 1: + join = self.query.alias_map[alias] + if col != join[RHS_JOIN_COL]: + break + self.query.unref_alias(alias) + alias = join[LHS_ALIAS] + col = join[LHS_JOIN_COL] + return [(alias, col, order)] + + def get_from_clause(self): + """ + Returns a list of strings that are joined together to go after the + "FROM" part of the query, as well as a list any extra parameters that + need to be included. Sub-classes, can override this to create a + from-clause via a "select". 
+ + This should only be called after any SQL construction methods that + might change the tables we need. This means the select columns and + ordering must be done first. + """ + result = [] + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + first = True + for alias in self.query.tables: + if not self.query.alias_refcount[alias]: + continue + try: + name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias] + except KeyError: + # Extra tables can end up in self.tables, but not in the + # alias_map if they aren't in a join. That's OK. We skip them. + continue + alias_str = (alias != name and ' %s' % alias or '') + if join_type and not first: + result.append('%s %s%s ON (%s.%s = %s.%s)' + % (join_type, qn(name), alias_str, qn(lhs), + qn2(lhs_col), qn(alias), qn2(col))) + else: + connector = not first and ', ' or '' + result.append('%s%s%s' % (connector, qn(name), alias_str)) + first = False + for t in self.query.extra_tables: + alias, unused = self.query.table_alias(t) + # Only add the alias if it's not already present (the table_alias() + # calls increments the refcount, so an alias refcount of one means + # this is the only reference. + if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1: + connector = not first and ', ' or '' + result.append('%s%s' % (connector, qn(alias))) + first = False + return result, [] + + def get_grouping(self): + """ + Returns a tuple representing the SQL elements in the "group by" clause. 
+ """ + qn = self.quote_name_unless_alias + result, params = [], [] + if self.query.group_by is not None: + if len(self.query.model._meta.fields) == len(self.query.select) and \ + self.connection.features.allows_group_by_pk: + self.query.group_by = [(self.query.model._meta.db_table, self.query.model._meta.pk.column)] + + group_by = self.query.group_by or [] + + extra_selects = [] + for extra_select, extra_params in self.query.extra_select.itervalues(): + extra_selects.append(extra_select) + params.extend(extra_params) + for col in group_by + self.query.related_select_cols + extra_selects: + if isinstance(col, (list, tuple)): + result.append('%s.%s' % (qn(col[0]), qn(col[1]))) + elif hasattr(col, 'as_sql'): + result.append(col.as_sql(qn)) + else: + result.append(str(col)) + return result, params + + def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, + used=None, requested=None, restricted=None, nullable=None, + dupe_set=None, avoid_set=None): + """ + Fill in the information needed for a select_related query. The current + depth is measured as the number of connections away from the root model + (for example, cur_depth=1 means we are looking at models with direct + connections to the root model). + """ + if not restricted and self.query.max_depth and cur_depth > self.query.max_depth: + # We've recursed far enough; bail out. + return + + if not opts: + opts = self.query.get_meta() + root_alias = self.query.get_initial_alias() + self.query.related_select_cols = [] + self.query.related_select_fields = [] + if not used: + used = set() + if dupe_set is None: + dupe_set = set() + if avoid_set is None: + avoid_set = set() + orig_dupe_set = dupe_set + + # Setup for the case when only particular related fields should be + # included in the related selection. 
+ if requested is None and restricted is not False: + if isinstance(self.query.select_related, dict): + requested = self.query.select_related + restricted = True + else: + restricted = False + + for f, model in opts.get_fields_with_model(): + if not select_related_descend(f, restricted, requested): + continue + # The "avoid" set is aliases we want to avoid just for this + # particular branch of the recursion. They aren't permanently + # forbidden from reuse in the related selection tables (which is + # what "used" specifies). + avoid = avoid_set.copy() + dupe_set = orig_dupe_set.copy() + table = f.rel.to._meta.db_table + if nullable or f.null: + promote = True + else: + promote = False + if model: + int_opts = opts + alias = root_alias + alias_chain = [] + for int_model in opts.get_base_chain(model): + # Proxy model have elements in base chain + # with no parents, assign the new options + # object and skip to the next base in that + # case + if not int_opts.parents[int_model]: + int_opts = int_model._meta + continue + lhs_col = int_opts.parents[int_model].column + dedupe = lhs_col in opts.duplicate_targets + if dedupe: + avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), + ()) + dupe_set.add((opts, lhs_col)) + int_opts = int_model._meta + alias = self.query.join((alias, int_opts.db_table, lhs_col, + int_opts.pk.column), exclusions=used, + promote=promote) + alias_chain.append(alias) + for (dupe_opts, dupe_col) in dupe_set: + self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) + if self.query.alias_map[root_alias][JOIN_TYPE] == self.query.LOUTER: + self.query.promote_alias_chain(alias_chain, True) + else: + alias = root_alias + + dedupe = f.column in opts.duplicate_targets + if dupe_set or dedupe: + avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) + if dedupe: + dupe_set.add((opts, f.column)) + + alias = self.query.join((alias, table, f.column, + f.rel.get_related_field().column), + exclusions=used.union(avoid), 
promote=promote) + used.add(alias) + columns, aliases = self.get_default_columns(start_alias=alias, + opts=f.rel.to._meta, as_pairs=True) + self.query.related_select_cols.extend(columns) + if self.query.alias_map[alias][JOIN_TYPE] == self.query.LOUTER: + self.query.promote_alias_chain(aliases, True) + self.query.related_select_fields.extend(f.rel.to._meta.fields) + if restricted: + next = requested.get(f.name, {}) + else: + next = False + if f.null is not None: + new_nullable = f.null + else: + new_nullable = None + for dupe_opts, dupe_col in dupe_set: + self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) + self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, + used, next, restricted, new_nullable, dupe_set, avoid) + + def deferred_to_columns(self): + """ + Converts the self.deferred_loading data structure to mapping of table + names to sets of column names which are to be loaded. Returns the + dictionary. + """ + columns = {} + self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb) + return columns + + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + resolve_columns = hasattr(self, 'resolve_columns') + fields = None + for rows in self.execute_sql(MULTI): + for row in rows: + if resolve_columns: + if fields is None: + # We only set this up here because + # related_select_fields isn't populated until + # execute_sql() has been called. + if self.query.select_fields: + fields = self.query.select_fields + self.query.related_select_fields + else: + fields = self.query.model._meta.fields + # If the field was deferred, exclude it from being passed + # into `resolve_columns` because it wasn't selected. 
+ only_load = self.deferred_to_columns() + if only_load: + db_table = self.query.model._meta.db_table + fields = [f for f in fields if db_table in only_load and + f.column in only_load[db_table]] + row = self.resolve_columns(row, fields) + + if self.query.aggregate_select: + aggregate_start = len(self.query.extra_select.keys()) + len(self.query.select) + aggregate_end = aggregate_start + len(self.query.aggregate_select) + row = tuple(row[:aggregate_start]) + tuple([ + self.query.resolve_aggregate(value, aggregate, self.connection) + for (alias, aggregate), value + in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end]) + ]) + tuple(row[aggregate_end:]) + + yield row + + def execute_sql(self, result_type=MULTI): + """ + Run the query against the database and returns the result(s). The + return value is a single data item if result_type is SINGLE, or an + iterator over the results if the result_type is MULTI. + + result_type is either MULTI (use fetchmany() to retrieve all rows), + SINGLE (only retrieve a single row), or None. In this last case, the + cursor is returned if any query is executed, since it's used by + subclasses such as InsertQuery). It's possible, however, that no query + is needed, as the filters describe an empty set. In that case, None is + returned, to avoid any unnecessary database interaction. + """ + try: + sql, params = self.as_sql() + if not sql: + raise EmptyResultSet + except EmptyResultSet: + if result_type == MULTI: + return empty_iter() + else: + return + + cursor = self.connection.cursor() + cursor.execute(sql, params) + + if not result_type: + return cursor + if result_type == SINGLE: + if self.query.ordering_aliases: + return cursor.fetchone()[:-len(self.query.ordering_aliases)] + return cursor.fetchone() + + # The MULTI case. 
+ if self.query.ordering_aliases: + result = order_modified_iter(cursor, len(self.query.ordering_aliases), + self.connection.features.empty_fetchmany_value) + else: + result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + self.connection.features.empty_fetchmany_value) + if not self.connection.features.can_use_chunked_reads: + # If we are using non-chunked reads, we return the same data + # structure as normally, but ensure it is all read into memory + # before going any further. + return list(result) + return result + + +class SQLInsertCompiler(SQLCompiler): + def placeholder(self, field, val): + if field is None: + # A field value of None means the value is raw. + return val + elif hasattr(field, 'get_placeholder'): + # Some fields (e.g. geo fields) need special munging before + # they can be inserted. + return field.get_placeholder(val, self.connection) + else: + # Return the common case for the placeholder + return '%s' + + def as_sql(self): + # We don't need quote_name_unless_alias() here, since these are all + # going to be column names (so we can avoid the extra overhead). 
+ qn = self.connection.ops.quote_name + opts = self.query.model._meta + result = ['INSERT INTO %s' % qn(opts.db_table)] + result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns])) + values = [self.placeholder(*v) for v in self.query.values] + result.append('VALUES (%s)' % ', '.join(values)) + params = self.query.params + if self.return_id and self.connection.features.can_return_id_from_insert: + col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) + r_fmt, r_params = self.connection.ops.return_insert_id() + result.append(r_fmt % col) + params = params + r_params + return ' '.join(result), params + + def execute_sql(self, return_id=False): + self.return_id = return_id + cursor = super(SQLInsertCompiler, self).execute_sql(None) + if not (return_id and cursor): + return + if self.connection.features.can_return_id_from_insert: + return self.connection.ops.fetch_returned_insert_id(cursor) + return self.connection.ops.last_insert_id(cursor, + self.query.model._meta.db_table, self.query.model._meta.pk.column) + + +class SQLDeleteCompiler(SQLCompiler): + def as_sql(self): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + assert len(self.query.tables) == 1, \ + "Can only delete from one table at a time." + qn = self.quote_name_unless_alias + result = ['DELETE FROM %s' % qn(self.query.tables[0])] + where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + result.append('WHERE %s' % where) + return ' '.join(result), tuple(params) + +class SQLUpdateCompiler(SQLCompiler): + def as_sql(self): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. 
+ """ + from django.db.models.base import Model + + self.pre_sql_setup() + if not self.query.values: + return '', () + table = self.query.tables[0] + qn = self.quote_name_unless_alias + result = ['UPDATE %s' % qn(table)] + result.append('SET') + values, update_params = [], [] + for field, model, val in self.query.values: + if hasattr(val, 'prepare_database_save'): + val = val.prepare_database_save(field) + else: + val = field.get_db_prep_save(val, connection=self.connection) + + # Getting the placeholder for the field. + if hasattr(field, 'get_placeholder'): + placeholder = field.get_placeholder(val, self.connection) + else: + placeholder = '%s' + + if hasattr(val, 'evaluate'): + val = SQLEvaluator(val, self.query, allow_joins=False) + name = field.column + if hasattr(val, 'as_sql'): + sql, params = val.as_sql(qn, self.connection) + values.append('%s = %s' % (qn(name), sql)) + update_params.extend(params) + elif val is not None: + values.append('%s = %s' % (qn(name), placeholder)) + update_params.append(val) + else: + values.append('%s = NULL' % qn(name)) + if not values: + return '', () + result.append(', '.join(values)) + where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + if where: + result.append('WHERE %s' % where) + return ' '.join(result), tuple(update_params + params) + + def execute_sql(self, result_type): + """ + Execute the specified update. Returns the number of rows affected by + the primary update query. The "primary update query" is the first + non-empty query that is executed. Row counts for any subsequent, + related queries are not available. 
+ """ + cursor = super(SQLUpdateCompiler, self).execute_sql(result_type) + rows = cursor and cursor.rowcount or 0 + is_empty = cursor is None + del cursor + for query in self.query.get_related_updates(): + aux_rows = query.get_compiler(self.using).execute_sql(result_type) + if is_empty: + rows = aux_rows + is_empty = False + return rows + + def pre_sql_setup(self): + """ + If the update depends on results from other tables, we need to do some + munging of the "where" conditions to match the format required for + (portable) SQL updates. That is done here. + + Further, if we are going to be running multiple updates, we pull out + the id values to update at this point so that they don't change as a + result of the progressive updates. + """ + self.query.select_related = False + self.query.clear_ordering(True) + super(SQLUpdateCompiler, self).pre_sql_setup() + count = self.query.count_active_tables() + if not self.query.related_updates and count == 1: + return + + # We need to use a sub-select in the where clause to filter on things + # from other tables. + query = self.query.clone(klass=Query) + query.bump_prefix() + query.extra = {} + query.select = [] + query.add_fields([query.model._meta.pk.name]) + must_pre_select = count > 1 and not self.connection.features.update_can_self_select + + # Now we adjust the current query: reset the where clause and get rid + # of all the tables we don't need (since they're in the sub-select). + self.query.where = self.query.where_class() + if self.query.related_updates or must_pre_select: + # Either we're using the idents in multiple update queries (so + # don't want them to change), or the db backend doesn't support + # selecting from the updating table (e.g. MySQL). + idents = [] + for rows in query.get_compiler(self.using).execute_sql(MULTI): + idents.extend([r[0] for r in rows]) + self.query.add_filter(('pk__in', idents)) + self.query.related_ids = idents + else: + # The fast path. Filters and updates in one query. 
+ self.query.add_filter(('pk__in', query)) + for alias in self.query.tables[1:]: + self.query.alias_refcount[alias] = 0 + +class SQLAggregateCompiler(SQLCompiler): + def as_sql(self, qn=None): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + if qn is None: + qn = self.quote_name_unless_alias + sql = ('SELECT %s FROM (%s) subquery' % ( + ', '.join([ + aggregate.as_sql(qn, self.connection) + for aggregate in self.query.aggregate_select.values() + ]), + self.query.subquery) + ) + params = self.query.sub_params + return (sql, params) + +class SQLDateCompiler(SQLCompiler): + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + resolve_columns = hasattr(self, 'resolve_columns') + if resolve_columns: + from django.db.models.fields import DateTimeField + fields = [DateTimeField()] + else: + from django.db.backends.util import typecast_timestamp + needs_string_cast = self.connection.features.needs_datetime_string_cast + + offset = len(self.query.extra_select) + for rows in self.execute_sql(MULTI): + for row in rows: + date = row[offset] + if resolve_columns: + date = self.resolve_columns(row, fields)[offset] + elif needs_string_cast: + date = typecast_timestamp(str(date)) + yield date + + +def empty_iter(): + """ + Returns an iterator containing no results. + """ + yield iter([]).next() + + +def order_modified_iter(cursor, trim, sentinel): + """ + Yields blocks of rows from a cursor. We use this iterator in the special + case when extra output columns have been added to support ordering + requirements. We must trim those extra columns before anything else can use + the results, since they're only needed to make the SQL valid. 
+ """ + for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + sentinel): + yield [r[:-trim] for r in rows] diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 4d53999c79..92d64e15dd 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -29,22 +29,18 @@ class Date(object): """ Add a date selection column. """ - def __init__(self, col, lookup_type, date_sql_func): + def __init__(self, col, lookup_type): self.col = col self.lookup_type = lookup_type - self.date_sql_func = date_sql_func def relabel_aliases(self, change_map): c = self.col if isinstance(c, (list, tuple)): self.col = (change_map.get(c[0], c[0]), c[1]) - def as_sql(self, quote_func=None): - if not quote_func: - quote_func = lambda x: x + def as_sql(self, qn, connection): if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple([quote_func(c) for c in self.col]) + col = '%s.%s' % tuple([qn(c) for c in self.col]) else: col = self.col - return self.date_sql_func(self.lookup_type, col) - + return connection.ops.date_trunc_sql(self.lookup_type, col) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 0914c2b3c1..9bbc16ec8a 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -1,5 +1,4 @@ from django.core.exceptions import FieldError -from django.db import connection from django.db.models.fields import FieldDoesNotExist from django.db.models.sql.constants import LOOKUP_SEP @@ -12,8 +11,11 @@ class SQLEvaluator(object): self.contains_aggregate = False self.expression.prepare(self, query, allow_joins) - def as_sql(self, qn=None): - return self.expression.evaluate(self, qn) + def prepare(self): + return self + + def as_sql(self, qn, connection): + return self.expression.evaluate(self, qn, connection) def relabel_aliases(self, change_map): for node, col in self.cols.items(): @@ -54,15 +56,12 @@ class 
SQLEvaluator(object): # Vistor methods for final expression evaluation # ################################################## - def evaluate_node(self, node, qn): - if not qn: - qn = connection.ops.quote_name - + def evaluate_node(self, node, qn, connection): expressions = [] expression_params = [] for child in node.children: if hasattr(child, 'evaluate'): - sql, params = child.evaluate(self, qn) + sql, params = child.evaluate(self, qn, connection) else: sql, params = '%s', (child,) @@ -77,12 +76,9 @@ class SQLEvaluator(object): return connection.ops.combine_expression(node.connector, expressions), expression_params - def evaluate_leaf(self, node, qn): - if not qn: - qn = connection.ops.quote_name - + def evaluate_leaf(self, node, qn, connection): col = self.cols[node] if hasattr(col, 'as_sql'): - return col.as_sql(qn), () + return col.as_sql(qn, connection), () else: return '%s.%s' % (qn(col[0]), qn(col[1])), () diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9ecf273be3..d821c0ee02 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -11,32 +11,34 @@ from django.utils.copycompat import deepcopy from django.utils.tree import Node from django.utils.datastructures import SortedDict from django.utils.encoding import force_unicode -from django.db.backends.util import truncate_name -from django.db import connection +from django.db import connections, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend, InvalidQuery from django.db.models.sql import aggregates as base_aggregates_module +from django.db.models.sql.constants import * +from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR from django.core.exceptions import 
FieldError -from datastructures import EmptyResultSet, Empty, MultiJoin -from constants import * -__all__ = ['Query', 'BaseQuery', 'RawQuery'] +__all__ = ['Query', 'RawQuery'] class RawQuery(object): """ A single raw SQL query """ - def __init__(self, sql, connection, params=None): + def __init__(self, sql, using, params=None): self.validate_sql(sql) self.params = params or () self.sql = sql - self.connection = connection + self.using = using self.cursor = None + def clone(self, using): + return RawQuery(self.sql, using, params=self.params) + def get_columns(self): if self.cursor is None: self._execute_query() @@ -57,10 +59,11 @@ class RawQuery(object): return "<RawQuery: %r>" % (self.sql % self.params) def _execute_query(self): - self.cursor = self.connection.cursor() + self.cursor = connections[self.using].cursor() self.cursor.execute(self.sql, self.params) -class BaseQuery(object): + +class Query(object): """ A single SQL query. """ @@ -73,9 +76,10 @@ class BaseQuery(object): query_terms = QUERY_TERMS aggregates_module = base_aggregates_module - def __init__(self, model, connection, where=WhereNode): + compiler = 'SQLCompiler' + + def __init__(self, model, where=WhereNode): self.model = model - self.connection = connection self.alias_refcount = {} self.alias_map = {} # Maps alias to join information self.table_map = {} # Maps table names to list of aliases. @@ -139,7 +143,7 @@ class BaseQuery(object): Parameter values won't necessarily be quoted correctly, since that is done by the database interface at execution time. """ - sql, params = self.as_sql() + sql, params = self.get_compiler(DEFAULT_DB_ALIAS).as_sql() return sql % params def __deepcopy__(self, memo): @@ -154,7 +158,6 @@ class BaseQuery(object): obj_dict = self.__dict__.copy() obj_dict['related_select_fields'] = [] obj_dict['related_select_cols'] = [] - del obj_dict['connection'] # Fields can't be pickled, so if a field list has been # specified, we pickle the list of field names instead. 
@@ -176,10 +179,16 @@ class BaseQuery(object): ] self.__dict__.update(obj_dict) - # XXX: Need a better solution for this when multi-db stuff is - # supported. It's the only class-reference to the module-level - # connection variable. - self.connection = connection + + def prepare(self): + return self + + def get_compiler(self, using=None, connection=None): + if using is None and connection is None: + raise ValueError("Need either using or connection") + if using: + connection = connections[using] + return connection.ops.compiler(self.compiler)(self, connection, using) def get_meta(self): """ @@ -189,22 +198,6 @@ class BaseQuery(object): """ return self.model._meta - def quote_name_unless_alias(self, name): - """ - A wrapper around connection.ops.quote_name that doesn't quote aliases - for table names. This avoids problems with some SQL dialects that treat - quoted strings specially (e.g. PostgreSQL). - """ - if name in self.quote_cache: - return self.quote_cache[name] - if ((name in self.alias_map and name not in self.table_map) or - name in self.extra_select): - self.quote_cache[name] = name - return name - r = self.connection.ops.quote_name(name) - self.quote_cache[name] = r - return r - def clone(self, klass=None, **kwargs): """ Creates a copy of the current instance. The 'kwargs' parameter can be @@ -213,7 +206,6 @@ class BaseQuery(object): obj = Empty() obj.__class__ = klass or self.__class__ obj.model = self.model - obj.connection = self.connection obj.alias_refcount = self.alias_refcount.copy() obj.alias_map = self.alias_map.copy() obj.table_map = self.table_map.copy() @@ -276,16 +268,16 @@ class BaseQuery(object): obj._setup_query() return obj - def convert_values(self, value, field): + def convert_values(self, value, field, connection): """Convert the database-returned value into a type that is consistent across database backends. By default, this defers to the underlying backend operations, but it can be overridden by Query classes for specific backends. 
""" - return self.connection.ops.convert_values(value, field) + return connection.ops.convert_values(value, field) - def resolve_aggregate(self, value, aggregate): + def resolve_aggregate(self, value, aggregate, connection): """Resolve the value of aggregates returned by the database to consistent (and reasonable) types. @@ -305,39 +297,9 @@ class BaseQuery(object): return float(value) else: # Return value depends on the type of the field being processed. - return self.convert_values(value, aggregate.field) - - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - resolve_columns = hasattr(self, 'resolve_columns') - fields = None - for rows in self.execute_sql(MULTI): - for row in rows: - if resolve_columns: - if fields is None: - # We only set this up here because - # related_select_fields isn't populated until - # execute_sql() has been called. - if self.select_fields: - fields = self.select_fields + self.related_select_fields - else: - fields = self.model._meta.fields - row = self.resolve_columns(row, fields) - - if self.aggregate_select: - aggregate_start = len(self.extra_select.keys()) + len(self.select) - aggregate_end = aggregate_start + len(self.aggregate_select) - row = tuple(row[:aggregate_start]) + tuple([ - self.resolve_aggregate(value, aggregate) - for (alias, aggregate), value - in zip(self.aggregate_select.items(), row[aggregate_start:aggregate_end]) - ]) + tuple(row[aggregate_end:]) - - yield row + return self.convert_values(value, aggregate.field, connection) - def get_aggregation(self): + def get_aggregation(self, using): """ Returns the dictionary with the values of the existing aggregations. """ @@ -349,7 +311,7 @@ class BaseQuery(object): # over the subquery instead. 
if self.group_by is not None: from subqueries import AggregateQuery - query = AggregateQuery(self.model, self.connection) + query = AggregateQuery(self.model) obj = self.clone() @@ -360,7 +322,7 @@ class BaseQuery(object): query.aggregate_select[alias] = aggregate del obj.aggregate_select[alias] - query.add_subquery(obj) + query.add_subquery(obj, using) else: query = self self.select = [] @@ -374,17 +336,17 @@ class BaseQuery(object): query.related_select_cols = [] query.related_select_fields = [] - result = query.execute_sql(SINGLE) + result = query.get_compiler(using).execute_sql(SINGLE) if result is None: result = [None for q in query.aggregate_select.items()] return dict([ - (alias, self.resolve_aggregate(val, aggregate)) + (alias, self.resolve_aggregate(val, aggregate, connection=connections[using])) for (alias, aggregate), val in zip(query.aggregate_select.items(), result) ]) - def get_count(self): + def get_count(self, using): """ Performs a COUNT() query using the current filter constraints. """ @@ -398,11 +360,11 @@ class BaseQuery(object): subquery.clear_ordering(True) subquery.clear_limits() - obj = AggregateQuery(obj.model, obj.connection) - obj.add_subquery(subquery) + obj = AggregateQuery(obj.model) + obj.add_subquery(subquery, using=using) obj.add_count_column() - number = obj.get_aggregation()[None] + number = obj.get_aggregation(using=using)[None] # Apply offset and limit constraints manually, since using LIMIT/OFFSET # in SQL (in variants that provide them) doesn't change the COUNT @@ -413,7 +375,7 @@ class BaseQuery(object): return number - def has_results(self): + def has_results(self, using): q = self.clone() q.add_extra({'a': 1}, None, None, None, None, None) q.add_fields(()) @@ -421,99 +383,8 @@ class BaseQuery(object): q.set_aggregate_mask(()) q.clear_ordering() q.set_limits(high=1) - return bool(q.execute_sql(SINGLE)) - - def as_sql(self, with_limits=True, with_col_aliases=False): - """ - Creates the SQL for this query. 
Returns the SQL string and list of - parameters. - - If 'with_limits' is False, any limit/offset information is not included - in the query. - """ - self.pre_sql_setup() - out_cols = self.get_columns(with_col_aliases) - ordering, ordering_group_by = self.get_ordering() - - # This must come after 'select' and 'ordering' -- see docstring of - # get_from_clause() for details. - from_, f_params = self.get_from_clause() - - qn = self.quote_name_unless_alias - where, w_params = self.where.as_sql(qn=qn) - having, h_params = self.having.as_sql(qn=qn) - params = [] - for val in self.extra_select.itervalues(): - params.extend(val[1]) - - result = ['SELECT'] - if self.distinct: - result.append('DISTINCT') - result.append(', '.join(out_cols + self.ordering_aliases)) - - result.append('FROM') - result.extend(from_) - params.extend(f_params) - - if where: - result.append('WHERE %s' % where) - params.extend(w_params) - if self.extra_where: - if not where: - result.append('WHERE') - else: - result.append('AND') - result.append(' AND '.join(self.extra_where)) - - grouping, gb_params = self.get_grouping() - if grouping: - if ordering: - # If the backend can't group by PK (i.e., any database - # other than MySQL), then any fields mentioned in the - # ordering clause needs to be in the group by clause. 
- if not self.connection.features.allows_group_by_pk: - for col, col_params in ordering_group_by: - if col not in grouping: - grouping.append(str(col)) - gb_params.extend(col_params) - else: - ordering = self.connection.ops.force_no_ordering() - result.append('GROUP BY %s' % ', '.join(grouping)) - params.extend(gb_params) - - if having: - result.append('HAVING %s' % having) - params.extend(h_params) - - if ordering: - result.append('ORDER BY %s' % ', '.join(ordering)) - - if with_limits: - if self.high_mark is not None: - result.append('LIMIT %d' % (self.high_mark - self.low_mark)) - if self.low_mark: - if self.high_mark is None: - val = self.connection.ops.no_limit_value() - if val: - result.append('LIMIT %d' % val) - result.append('OFFSET %d' % self.low_mark) - - params.extend(self.extra_params) - return ' '.join(result), tuple(params) - - def as_nested_sql(self): - """ - Perform the same functionality as the as_sql() method, returning an - SQL string and parameters. However, the alias prefixes are bumped - beforehand (in a copy -- the current query isn't changed) and any - ordering is removed. - - Used when nesting this query inside another. - """ - obj = self.clone() - obj.clear_ordering(True) - obj.bump_prefix() - return obj.as_sql() + compiler = q.get_compiler(using=using) + return bool(compiler.execute_sql(SINGLE)) def combine(self, rhs, connector): """ @@ -613,20 +484,6 @@ class BaseQuery(object): self.order_by = rhs.order_by and rhs.order_by[:] or self.order_by self.extra_order_by = rhs.extra_order_by or self.extra_order_by - def pre_sql_setup(self): - """ - Does any necessary class setup immediately prior to producing SQL. This - is for things that can't necessarily be done in __init__ because we - might not have all the pieces in place at that time. 
- """ - if not self.tables: - self.join((None, self.model._meta.db_table, None, None)) - if (not self.select and self.default_cols and not - self.included_inherited_models): - self.setup_inherited_models() - if self.select_related and not self.related_select_cols: - self.fill_related_selections() - def deferred_to_data(self, target, callback): """ Converts the self.deferred_loading data structure to an alternate data @@ -705,15 +562,6 @@ class BaseQuery(object): for model, values in seen.iteritems(): callback(target, model, values) - def deferred_to_columns(self): - """ - Converts the self.deferred_loading data structure to mapping of table - names to sets of column names which are to be loaded. Returns the - dictionary. - """ - columns = {} - self.deferred_to_data(columns, self.deferred_to_columns_cb) - return columns def deferred_to_columns_cb(self, target, model, fields): """ @@ -726,349 +574,6 @@ class BaseQuery(object): for field in fields: target[table].add(field.column) - def get_columns(self, with_aliases=False): - """ - Returns the list of columns to use in the select statement. If no - columns have been specified, returns all columns relating to fields in - the model. - - If 'with_aliases' is true, any column names that are duplicated - (without the table names) are given unique aliases. This is needed in - some cases to avoid ambiguity with nested queries. 
- """ - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.extra_select.iteritems()] - aliases = set(self.extra_select.keys()) - if with_aliases: - col_aliases = aliases.copy() - else: - col_aliases = set() - if self.select: - only_load = self.deferred_to_columns() - for col in self.select: - if isinstance(col, (list, tuple)): - alias, column = col - table = self.alias_map[alias][TABLE_NAME] - if table in only_load and col not in only_load[table]: - continue - r = '%s.%s' % (qn(alias), qn(column)) - if with_aliases: - if col[1] in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append('%s AS %s' % (r, qn2(col[1]))) - aliases.add(r) - col_aliases.add(col[1]) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col[1]) - else: - result.append(col.as_sql(quote_func=qn)) - - if hasattr(col, 'alias'): - aliases.add(col.alias) - col_aliases.add(col.alias) - - elif self.default_cols: - cols, new_aliases = self.get_default_columns(with_aliases, - col_aliases) - result.extend(cols) - aliases.update(new_aliases) - - result.extend([ - '%s%s' % ( - aggregate.as_sql(quote_func=qn), - alias is not None and ' AS %s' % qn(alias) or '' - ) - for alias, aggregate in self.aggregate_select.items() - ]) - - for table, col in self.related_select_cols: - r = '%s.%s' % (qn(table), qn(col)) - if with_aliases and col in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col) - - self._select_aliases = aliases - return result - - def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False): - """ - Computes the default columns for selecting every field in the base - model. 
Will sometimes be called to pull in related models (e.g. via - select_related), in which case "opts" and "start_alias" will be given - to provide a starting point for the traversal. - - Returns a list of strings, quoted appropriately for use in SQL - directly, as well as a set of aliases used in the select statement (if - 'as_pairs' is True, returns a list of (alias, col_name) pairs instead - of strings as the first component and None as the second component). - """ - result = [] - if opts is None: - opts = self.model._meta - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - aliases = set() - only_load = self.deferred_to_columns() - # Skip all proxy to the root proxied model - proxied_model = get_proxied_model(opts) - - if start_alias: - seen = {None: start_alias} - for field, model in opts.get_fields_with_model(): - if start_alias: - try: - alias = seen[model] - except KeyError: - if model is proxied_model: - alias = start_alias - else: - link_field = opts.get_ancestor_link(model) - alias = self.join((start_alias, model._meta.db_table, - link_field.column, model._meta.pk.column)) - seen[model] = alias - else: - # If we're starting from the base model of the queryset, the - # aliases will have already been set up in pre_sql_setup(), so - # we can save time here. 
- alias = self.included_inherited_models[model] - table = self.alias_map[alias][TABLE_NAME] - if table in only_load and field.column not in only_load[table]: - continue - if as_pairs: - result.append((alias, field.column)) - aliases.add(alias) - continue - if with_aliases and field.column in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s.%s AS %s' % (qn(alias), - qn2(field.column), c_alias)) - col_aliases.add(c_alias) - aliases.add(c_alias) - else: - r = '%s.%s' % (qn(alias), qn2(field.column)) - result.append(r) - aliases.add(r) - if with_aliases: - col_aliases.add(field.column) - return result, aliases - - def get_from_clause(self): - """ - Returns a list of strings that are joined together to go after the - "FROM" part of the query, as well as a list any extra parameters that - need to be included. Sub-classes, can override this to create a - from-clause via a "select". - - This should only be called after any SQL construction methods that - might change the tables we need. This means the select columns and - ordering must be done first. - """ - result = [] - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - first = True - for alias in self.tables: - if not self.alias_refcount[alias]: - continue - try: - name, alias, join_type, lhs, lhs_col, col, nullable = self.alias_map[alias] - except KeyError: - # Extra tables can end up in self.tables, but not in the - # alias_map if they aren't in a join. That's OK. We skip them. 
- continue - alias_str = (alias != name and ' %s' % alias or '') - if join_type and not first: - result.append('%s %s%s ON (%s.%s = %s.%s)' - % (join_type, qn(name), alias_str, qn(lhs), - qn2(lhs_col), qn(alias), qn2(col))) - else: - connector = not first and ', ' or '' - result.append('%s%s%s' % (connector, qn(name), alias_str)) - first = False - for t in self.extra_tables: - alias, unused = self.table_alias(t) - # Only add the alias if it's not already present (the table_alias() - # calls increments the refcount, so an alias refcount of one means - # this is the only reference. - if alias not in self.alias_map or self.alias_refcount[alias] == 1: - connector = not first and ', ' or '' - result.append('%s%s' % (connector, qn(alias))) - first = False - return result, [] - - def get_grouping(self): - """ - Returns a tuple representing the SQL elements in the "group by" clause. - """ - qn = self.quote_name_unless_alias - result, params = [], [] - if self.group_by is not None: - group_by = self.group_by or [] - - extra_selects = [] - for extra_select, extra_params in self.extra_select.itervalues(): - extra_selects.append(extra_select) - params.extend(extra_params) - for col in group_by + self.related_select_cols + extra_selects: - if isinstance(col, (list, tuple)): - result.append('%s.%s' % (qn(col[0]), qn(col[1]))) - elif hasattr(col, 'as_sql'): - result.append(col.as_sql(qn)) - else: - result.append(str(col)) - return result, params - - def get_ordering(self): - """ - Returns a tuple containing a list representing the SQL elements in the - "order by" clause, and the list of SQL elements that need to be added - to the GROUP BY clause as a result of the ordering. - - Also sets the ordering_aliases attribute on this instance to a list of - extra aliases needed in the select. - - Determining the ordering SQL can change the tables we need to include, - so this should be run *before* get_from_clause(). 
- """ - if self.extra_order_by: - ordering = self.extra_order_by - elif not self.default_ordering: - ordering = self.order_by - else: - ordering = self.order_by or self.model._meta.ordering - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - distinct = self.distinct - select_aliases = self._select_aliases - result = [] - group_by = [] - ordering_aliases = [] - if self.standard_ordering: - asc, desc = ORDER_DIR['ASC'] - else: - asc, desc = ORDER_DIR['DESC'] - - # It's possible, due to model inheritance, that normal usage might try - # to include the same field more than once in the ordering. We track - # the table/column pairs we use and discard any after the first use. - processed_pairs = set() - - for field in ordering: - if field == '?': - result.append(self.connection.ops.random_function_sql()) - continue - if isinstance(field, int): - if field < 0: - order = desc - field = -field - else: - order = asc - result.append('%s %s' % (field, order)) - group_by.append((field, [])) - continue - col, order = get_order_dir(field, asc) - if col in self.aggregate_select: - result.append('%s %s' % (col, order)) - continue - if '.' in field: - # This came in through an extra(order_by=...) addition. Pass it - # on verbatim. - table, col = col.split('.', 1) - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), col) - processed_pairs.add((table, col)) - if not distinct or elt in select_aliases: - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) - elif get_order_dir(field)[0] not in self.extra_select: - # 'col' is of the form 'field' or 'field1__field2' or - # '-field1__field2__field', etc. 
- for table, col, order in self.find_ordering_name(field, - self.model._meta, default_order=asc): - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), qn2(col)) - processed_pairs.add((table, col)) - if distinct and elt not in select_aliases: - ordering_aliases.append(elt) - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) - else: - elt = qn2(col) - if distinct and col not in select_aliases: - ordering_aliases.append(elt) - result.append('%s %s' % (elt, order)) - group_by.append(self.extra_select[col]) - self.ordering_aliases = ordering_aliases - return result, group_by - - def find_ordering_name(self, name, opts, alias=None, default_order='ASC', - already_seen=None): - """ - Returns the table alias (the name might be ambiguous, the alias will - not be) and column name for ordering by the given 'name' parameter. - The 'name' is of the form 'field1__field2__...__fieldN'. - """ - name, order = get_order_dir(name, default_order) - pieces = name.split(LOOKUP_SEP) - if not alias: - alias = self.get_initial_alias() - field, target, opts, joins, last, extra = self.setup_joins(pieces, - opts, alias, False) - alias = joins[-1] - col = target.column - if not field.rel: - # To avoid inadvertent trimming of a necessary alias, use the - # refcount to show that we are referencing a non-relation field on - # the model. - self.ref_alias(alias) - - # Must use left outer joins for nullable fields and their relations. - self.promote_alias_chain(joins, - self.alias_map[joins[0]][JOIN_TYPE] == self.LOUTER) - - # If we get to this point and the field is a relation to another model, - # append the default ordering for that model. - if field.rel and len(joins) > 1 and opts.ordering: - # Firstly, avoid infinite loops. 
- if not already_seen: - already_seen = set() - join_tuple = tuple([self.alias_map[j][TABLE_NAME] for j in joins]) - if join_tuple in already_seen: - raise FieldError('Infinite loop caused by ordering.') - already_seen.add(join_tuple) - - results = [] - for item in opts.ordering: - results.extend(self.find_ordering_name(item, opts, alias, - order, already_seen)) - return results - - if alias: - # We have to do the same "final join" optimisation as in - # add_filter, since the final column might not otherwise be part of - # the select set (so we can't order on it). - while 1: - join = self.alias_map[alias] - if col != join[RHS_JOIN_COL]: - break - self.unref_alias(alias) - alias = join[LHS_ALIAS] - col = join[LHS_JOIN_COL] - return [(alias, col, order)] def table_alias(self, table_name, create=False): """ @@ -1372,113 +877,6 @@ class BaseQuery(object): self.unref_alias(alias) self.included_inherited_models = {} - def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, - used=None, requested=None, restricted=None, nullable=None, - dupe_set=None, avoid_set=None): - """ - Fill in the information needed for a select_related query. The current - depth is measured as the number of connections away from the root model - (for example, cur_depth=1 means we are looking at models with direct - connections to the root model). - """ - if not restricted and self.max_depth and cur_depth > self.max_depth: - # We've recursed far enough; bail out. - return - - if not opts: - opts = self.get_meta() - root_alias = self.get_initial_alias() - self.related_select_cols = [] - self.related_select_fields = [] - if not used: - used = set() - if dupe_set is None: - dupe_set = set() - if avoid_set is None: - avoid_set = set() - orig_dupe_set = dupe_set - - # Setup for the case when only particular related fields should be - # included in the related selection. 
- if requested is None and restricted is not False: - if isinstance(self.select_related, dict): - requested = self.select_related - restricted = True - else: - restricted = False - - for f, model in opts.get_fields_with_model(): - if not select_related_descend(f, restricted, requested): - continue - # The "avoid" set is aliases we want to avoid just for this - # particular branch of the recursion. They aren't permanently - # forbidden from reuse in the related selection tables (which is - # what "used" specifies). - avoid = avoid_set.copy() - dupe_set = orig_dupe_set.copy() - table = f.rel.to._meta.db_table - if nullable or f.null: - promote = True - else: - promote = False - if model: - int_opts = opts - alias = root_alias - alias_chain = [] - for int_model in opts.get_base_chain(model): - # Proxy model have elements in base chain - # with no parents, assign the new options - # object and skip to the next base in that - # case - if not int_opts.parents[int_model]: - int_opts = int_model._meta - continue - lhs_col = int_opts.parents[int_model].column - dedupe = lhs_col in opts.duplicate_targets - if dedupe: - avoid.update(self.dupe_avoidance.get(id(opts), lhs_col), - ()) - dupe_set.add((opts, lhs_col)) - int_opts = int_model._meta - alias = self.join((alias, int_opts.db_table, lhs_col, - int_opts.pk.column), exclusions=used, - promote=promote) - alias_chain.append(alias) - for (dupe_opts, dupe_col) in dupe_set: - self.update_dupe_avoidance(dupe_opts, dupe_col, alias) - if self.alias_map[root_alias][JOIN_TYPE] == self.LOUTER: - self.promote_alias_chain(alias_chain, True) - else: - alias = root_alias - - dedupe = f.column in opts.duplicate_targets - if dupe_set or dedupe: - avoid.update(self.dupe_avoidance.get((id(opts), f.column), ())) - if dedupe: - dupe_set.add((opts, f.column)) - - alias = self.join((alias, table, f.column, - f.rel.get_related_field().column), - exclusions=used.union(avoid), promote=promote) - used.add(alias) - columns, aliases = 
self.get_default_columns(start_alias=alias, - opts=f.rel.to._meta, as_pairs=True) - self.related_select_cols.extend(columns) - if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: - self.promote_alias_chain(aliases, True) - self.related_select_fields.extend(f.rel.to._meta.fields) - if restricted: - next = requested.get(f.name, {}) - else: - next = False - if f.null is not None: - new_nullable = f.null - else: - new_nullable = None - for dupe_opts, dupe_col in dupe_set: - self.update_dupe_avoidance(dupe_opts, dupe_col, alias) - self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, - used, next, restricted, new_nullable, dupe_set, avoid) def add_aggregate(self, aggregate, model, alias, is_summary): """ @@ -1527,7 +925,6 @@ class BaseQuery(object): col = field_name # Add the aggregate to the query - alias = truncate_name(alias, self.connection.ops.max_name_length()) aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, @@ -1578,10 +975,6 @@ class BaseQuery(object): raise ValueError("Cannot use None as a query value") lookup_type = 'isnull' value = True - elif (value == '' and lookup_type == 'exact' and - connection.features.interprets_empty_strings_as_nulls): - lookup_type = 'isnull' - value = True elif callable(value): value = value() elif hasattr(value, 'evaluate'): @@ -1999,7 +1392,7 @@ class BaseQuery(object): original exclude filter (filter_expr) and the portion up to the first N-to-many relation field. """ - query = Query(self.model, self.connection) + query = Query(self.model) query.add_filter(filter_expr, can_reuse=can_reuse) query.bump_prefix() query.clear_ordering(True) @@ -2138,11 +1531,6 @@ class BaseQuery(object): will be made automatically. 
""" self.group_by = [] - if self.connection.features.allows_group_by_pk: - if len(self.select) == len(self.model._meta.fields): - self.group_by.append((self.model._meta.db_table, - self.model._meta.pk.column)) - return for sel in self.select: self.group_by.append(sel) @@ -2382,58 +1770,6 @@ class BaseQuery(object): self.select = [(select_alias, select_col)] self.remove_inherited_models() - def execute_sql(self, result_type=MULTI): - """ - Run the query against the database and returns the result(s). The - return value is a single data item if result_type is SINGLE, or an - iterator over the results if the result_type is MULTI. - - result_type is either MULTI (use fetchmany() to retrieve all rows), - SINGLE (only retrieve a single row), or None. In this last case, the - cursor is returned if any query is executed, since it's used by - subclasses such as InsertQuery). It's possible, however, that no query - is needed, as the filters describe an empty set. In that case, None is - returned, to avoid any unnecessary database interaction. - """ - try: - sql, params = self.as_sql() - if not sql: - raise EmptyResultSet - except EmptyResultSet: - if result_type == MULTI: - return empty_iter() - else: - return - cursor = self.connection.cursor() - cursor.execute(sql, params) - - if not result_type: - return cursor - if result_type == SINGLE: - if self.ordering_aliases: - return cursor.fetchone()[:-len(self.ordering_aliases)] - return cursor.fetchone() - - # The MULTI case. - if self.ordering_aliases: - result = order_modified_iter(cursor, len(self.ordering_aliases), - self.connection.features.empty_fetchmany_value) - else: - result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - self.connection.features.empty_fetchmany_value) - if not self.connection.features.can_use_chunked_reads: - # If we are using non-chunked reads, we return the same data - # structure as normally, but ensure it is all read into memory - # before going any further. 
- return list(result) - return result - -# Use the backend's custom Query class if it defines one. Otherwise, use the -# default. -if connection.features.uses_custom_query_class: - Query = connection.ops.query_class(BaseQuery) -else: - Query = BaseQuery def get_order_dir(field, default='ASC'): """ @@ -2448,22 +1784,6 @@ def get_order_dir(field, default='ASC'): return field[1:], dirn[1] return field, dirn[0] -def empty_iter(): - """ - Returns an iterator containing no results. - """ - yield iter([]).next() - -def order_modified_iter(cursor, trim, sentinel): - """ - Yields blocks of rows from a cursor. We use this iterator in the special - case when extra output columns have been added to support ordering - requirements. We must trim those extra columns before anything else can use - the results, since they're only needed to make the SQL valid. - """ - for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - sentinel): - yield [r[:-trim] for r in rows] def setup_join_cache(sender, **kwargs): """ diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f00f1bd68a..e80a023699 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -3,6 +3,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval. """ from django.core.exceptions import FieldError +from django.db import connections from django.db.models.sql.constants import * from django.db.models.sql.datastructures import Date from django.db.models.sql.expressions import SQLEvaluator @@ -17,24 +18,15 @@ class DeleteQuery(Query): Delete queries are done through this class, since they are more constrained than general queries. """ - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - assert len(self.tables) == 1, \ - "Can only delete from one table at a time." 
- result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] - where, params = self.where.as_sql() - result.append('WHERE %s' % where) - return ' '.join(result), tuple(params) - def do_query(self, table, where): + compiler = 'SQLDeleteCompiler' + + def do_query(self, table, where, using): self.tables = [table] self.where = where - self.execute_sql(None) + self.get_compiler(using).execute_sql(None) - def delete_batch_related(self, pk_list): + def delete_batch_related(self, pk_list, using): """ Set up and execute delete queries for all the objects related to the primary key values in pk_list. To delete the objects themselves, use @@ -54,7 +46,7 @@ class DeleteQuery(Query): 'in', pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(related.field.m2m_db_table(), where) + self.do_query(related.field.m2m_db_table(), where, using=using) for f in cls._meta.many_to_many: w1 = self.where_class() @@ -70,9 +62,9 @@ class DeleteQuery(Query): AND) if w1: where.add(w1, AND) - self.do_query(f.m2m_db_table(), where) + self.do_query(f.m2m_db_table(), where, using=using) - def delete_batch(self, pk_list): + def delete_batch(self, pk_list, using): """ Set up and execute delete queries for all the objects in pk_list. This should be called after delete_batch_related(), if necessary. @@ -85,12 +77,15 @@ class DeleteQuery(Query): field = self.model._meta.pk where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(self.model._meta.db_table, where) + self.do_query(self.model._meta.db_table, where, using=using) class UpdateQuery(Query): """ Represents an "update" SQL query. 
""" + + compiler = 'SQLUpdateCompiler' + def __init__(self, *args, **kwargs): super(UpdateQuery, self).__init__(*args, **kwargs) self._setup_query() @@ -110,98 +105,8 @@ class UpdateQuery(Query): return super(UpdateQuery, self).clone(klass, related_updates=self.related_updates.copy(), **kwargs) - def execute_sql(self, result_type=None): - """ - Execute the specified update. Returns the number of rows affected by - the primary update query. The "primary update query" is the first - non-empty query that is executed. Row counts for any subsequent, - related queries are not available. - """ - cursor = super(UpdateQuery, self).execute_sql(result_type) - rows = cursor and cursor.rowcount or 0 - is_empty = cursor is None - del cursor - for query in self.get_related_updates(): - aux_rows = query.execute_sql(result_type) - if is_empty: - rows = aux_rows - is_empty = False - return rows - - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - self.pre_sql_setup() - if not self.values: - return '', () - table = self.tables[0] - qn = self.quote_name_unless_alias - result = ['UPDATE %s' % qn(table)] - result.append('SET') - values, update_params = [], [] - for name, val, placeholder in self.values: - if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn) - values.append('%s = %s' % (qn(name), sql)) - update_params.extend(params) - elif val is not None: - values.append('%s = %s' % (qn(name), placeholder)) - update_params.append(val) - else: - values.append('%s = NULL' % qn(name)) - result.append(', '.join(values)) - where, params = self.where.as_sql() - if where: - result.append('WHERE %s' % where) - return ' '.join(result), tuple(update_params + params) - - def pre_sql_setup(self): - """ - If the update depends on results from other tables, we need to do some - munging of the "where" conditions to match the format required for - (portable) SQL updates. That is done here. 
- Further, if we are going to be running multiple updates, we pull out - the id values to update at this point so that they don't change as a - result of the progressive updates. - """ - self.select_related = False - self.clear_ordering(True) - super(UpdateQuery, self).pre_sql_setup() - count = self.count_active_tables() - if not self.related_updates and count == 1: - return - - # We need to use a sub-select in the where clause to filter on things - # from other tables. - query = self.clone(klass=Query) - query.bump_prefix() - query.extra = {} - query.select = [] - query.add_fields([query.model._meta.pk.name]) - must_pre_select = count > 1 and not self.connection.features.update_can_self_select - - # Now we adjust the current query: reset the where clause and get rid - # of all the tables we don't need (since they're in the sub-select). - self.where = self.where_class() - if self.related_updates or must_pre_select: - # Either we're using the idents in multiple update queries (so - # don't want them to change), or the db backend doesn't support - # selecting from the updating table (e.g. MySQL). - idents = [] - for rows in query.execute_sql(MULTI): - idents.extend([r[0] for r in rows]) - self.add_filter(('pk__in', idents)) - self.related_ids = idents - else: - # The fast path. Filters and updates in one query. - self.add_filter(('pk__in', query)) - for alias in self.tables[1:]: - self.alias_refcount[alias] = 0 - - def clear_related(self, related_field, pk_list): + def clear_related(self, related_field, pk_list, using): """ Set up and execute an update query that clears related entries for the keys in pk_list. 
@@ -214,8 +119,8 @@ class UpdateQuery(Query): self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.values = [(related_field.column, None, '%s')] - self.execute_sql(None) + self.values = [(related_field, None, None)] + self.get_compiler(using).execute_sql(None) def add_update_values(self, values): """ @@ -228,6 +133,9 @@ class UpdateQuery(Query): field, model, direct, m2m = self.model._meta.get_field_by_name(name) if not direct or m2m: raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) + if model: + self.add_related_update(model, field, val) + continue values_seq.append((field, model, val)) return self.add_update_fields(values_seq) @@ -237,36 +145,18 @@ class UpdateQuery(Query): Used by add_update_values() as well as the "fast" update path when saving models. """ - from django.db.models.base import Model - for field, model, val in values_seq: - if hasattr(val, 'prepare_database_save'): - val = val.prepare_database_save(field) - else: - val = field.get_db_prep_save(val) + self.values.extend(values_seq) - # Getting the placeholder for the field. - if hasattr(field, 'get_placeholder'): - placeholder = field.get_placeholder(val) - else: - placeholder = '%s' - - if hasattr(val, 'evaluate'): - val = SQLEvaluator(val, self, allow_joins=False) - if model: - self.add_related_update(model, field.column, val, placeholder) - else: - self.values.append((field.column, val, placeholder)) - - def add_related_update(self, model, column, value, placeholder): + def add_related_update(self, model, field, value): """ Adds (name, value) to an update query for an ancestor model. Updates are coalesced so that we only run one update query per ancestor. 
""" try: - self.related_updates[model].append((column, value, placeholder)) + self.related_updates[model].append((field, None, value)) except KeyError: - self.related_updates[model] = [(column, value, placeholder)] + self.related_updates[model] = [(field, None, value)] def get_related_updates(self): """ @@ -278,7 +168,7 @@ class UpdateQuery(Query): return [] result = [] for model, values in self.related_updates.iteritems(): - query = UpdateQuery(model, self.connection) + query = UpdateQuery(model) query.values = values if self.related_ids: query.add_filter(('pk__in', self.related_ids)) @@ -286,45 +176,23 @@ class UpdateQuery(Query): return result class InsertQuery(Query): + compiler = 'SQLInsertCompiler' + def __init__(self, *args, **kwargs): super(InsertQuery, self).__init__(*args, **kwargs) self.columns = [] self.values = [] self.params = () - self.return_id = False def clone(self, klass=None, **kwargs): - extras = {'columns': self.columns[:], 'values': self.values[:], - 'params': self.params, 'return_id': self.return_id} + extras = { + 'columns': self.columns[:], + 'values': self.values[:], + 'params': self.params + } extras.update(kwargs) return super(InsertQuery, self).clone(klass, **extras) - def as_sql(self): - # We don't need quote_name_unless_alias() here, since these are all - # going to be column names (so we can avoid the extra overhead). 
- qn = self.connection.ops.quote_name - opts = self.model._meta - result = ['INSERT INTO %s' % qn(opts.db_table)] - result.append('(%s)' % ', '.join([qn(c) for c in self.columns])) - result.append('VALUES (%s)' % ', '.join(self.values)) - params = self.params - if self.return_id and self.connection.features.can_return_id_from_insert: - col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - r_fmt, r_params = self.connection.ops.return_insert_id() - result.append(r_fmt % col) - params = params + r_params - return ' '.join(result), params - - def execute_sql(self, return_id=False): - self.return_id = return_id - cursor = super(InsertQuery, self).execute_sql(None) - if not (return_id and cursor): - return - if self.connection.features.can_return_id_from_insert: - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id(cursor, - self.model._meta.db_table, self.model._meta.pk.column) - def insert_values(self, insert_values, raw_values=False): """ Set up the insert query from the 'insert_values' dictionary. The @@ -337,17 +205,11 @@ class InsertQuery(Query): """ placeholders, values = [], [] for field, val in insert_values: - if hasattr(field, 'get_placeholder'): - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - placeholders.append(field.get_placeholder(val)) - else: - placeholders.append('%s') - + placeholders.append((field, val)) self.columns.append(field.column) values.append(val) if raw_values: - self.values.extend(values) + self.values.extend([(None, v) for v in values]) else: self.params += tuple(values) self.values.extend(placeholders) @@ -358,44 +220,8 @@ class DateQuery(Query): date field. This requires some special handling when converting the results back to Python objects, so we put it in a separate class. """ - def __getstate__(self): - """ - Special DateQuery-specific pickle handling. 
- """ - for elt in self.select: - if isinstance(elt, Date): - # Eliminate a method reference that can't be pickled. The - # __setstate__ method restores this. - elt.date_sql_func = None - return super(DateQuery, self).__getstate__() - def __setstate__(self, obj_dict): - super(DateQuery, self).__setstate__(obj_dict) - for elt in self.select: - if isinstance(elt, Date): - self.date_sql_func = self.connection.ops.date_trunc_sql - - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - resolve_columns = hasattr(self, 'resolve_columns') - if resolve_columns: - from django.db.models.fields import DateTimeField - fields = [DateTimeField()] - else: - from django.db.backends.util import typecast_timestamp - needs_string_cast = self.connection.features.needs_datetime_string_cast - - offset = len(self.extra_select) - for rows in self.execute_sql(MULTI): - for row in rows: - date = row[offset] - if resolve_columns: - date = self.resolve_columns(row, fields)[offset] - elif needs_string_cast: - date = typecast_timestamp(str(date)) - yield date + compiler = 'SQLDateCompiler' def add_date_select(self, field, lookup_type, order='ASC'): """ @@ -404,8 +230,7 @@ class DateQuery(Query): result = self.setup_joins([field.name], self.get_meta(), self.get_initial_alias(), False) alias = result[3][-1] - select = Date((alias, field.column), lookup_type, - self.connection.ops.date_trunc_sql) + select = Date((alias, field.column), lookup_type) self.select = [select] self.select_fields = [None] self.select_related = False # See #7097. @@ -418,20 +243,8 @@ class AggregateQuery(Query): An AggregateQuery takes another query as a parameter to the FROM clause and only selects the elements in the provided list. """ - def add_subquery(self, query): - self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) - def as_sql(self, quote_func=None): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. 
- """ - sql = ('SELECT %s FROM (%s) subquery' % ( - ', '.join([ - aggregate.as_sql() - for aggregate in self.aggregate_select.values() - ]), - self.subquery) - ) - params = self.sub_params - return (sql, params) + compiler = 'SQLAggregateCompiler' + + def add_subquery(self, query, using): + self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ec0545ca5b..4aa2351f17 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,7 +4,6 @@ Code to manage the creation and SQL rendering of 'where' constraints. import datetime from django.utils import tree -from django.db import connection from django.db.models.fields import Field from django.db.models.query_utils import QueryWrapper from datastructures import EmptyResultSet, FullResultSet @@ -51,18 +50,6 @@ class WhereNode(tree.Node): # Consume any generators immediately, so that we can determine # emptiness and transform any non-empty values correctly. value = list(value) - if hasattr(obj, "process"): - try: - obj, params = obj.process(lookup_type, value) - except (EmptyShortCircuit, EmptyResultSet): - # There are situations where we want to short-circuit any - # comparisons and make sure that nothing is returned. One - # example is when checking for a NULL pk value, or the - # equivalent. - super(WhereNode, self).add(NothingNode(), connector) - return - else: - params = Field().get_db_prep_lookup(lookup_type, value) # The "annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. 
Specifically, datetime @@ -75,10 +62,16 @@ class WhereNode(tree.Node): else: annotation = bool(value) - super(WhereNode, self).add((obj, lookup_type, annotation, params), + if hasattr(obj, "prepare"): + value = obj.prepare(lookup_type, value) + super(WhereNode, self).add((obj, lookup_type, annotation, value), + connector) + return + + super(WhereNode, self).add((obj, lookup_type, annotation, value), connector) - def as_sql(self, qn=None): + def as_sql(self, qn, connection): """ Returns the SQL version of the where clause and the value to be substituted in. Returns None, None if this node is empty. @@ -87,8 +80,6 @@ class WhereNode(tree.Node): (generally not needed except by the internal implementation for recursion). """ - if not qn: - qn = connection.ops.quote_name if not self.children: return None, [] result = [] @@ -97,10 +88,10 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn) + sql, params = child.as_sql(qn=qn, connection=connection) else: # A leaf node in the tree. - sql, params = self.make_atom(child, qn) + sql, params = self.make_atom(child, qn, connection) except EmptyResultSet: if self.connector == AND and not self.negated: @@ -136,7 +127,7 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def make_atom(self, child, qn): + def make_atom(self, child, qn, connection): """ Turn a tuple (table_alias, column_name, db_type, lookup_type, value_annot, params) into valid SQL. @@ -144,13 +135,21 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. 
""" - lvalue, lookup_type, value_annot, params = child + lvalue, lookup_type, value_annot, params_or_value = child + if hasattr(lvalue, 'process'): + try: + lvalue, params = lvalue.process(lookup_type, params_or_value, connection) + except EmptyShortCircuit: + raise EmptyResultSet + else: + params = Field().get_db_prep_lookup(lookup_type, params_or_value, + connection=connection, prepared=True) if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn) + field_sql = self.sql_for_columns(lvalue, qn, connection) else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(quote_func=qn) + field_sql = lvalue.as_sql(qn, connection) if value_annot is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql() @@ -158,11 +157,16 @@ class WhereNode(tree.Node): cast_sql = '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn) + extra, params = params.as_sql(qn, connection) cast_sql = '' else: extra = '' + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' + and connection.features.interprets_empty_strings_as_nulls): + lookup_type = 'isnull' + value_annot = True + if lookup_type in connection.operators: format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) return (format % (field_sql, @@ -191,7 +195,7 @@ class WhereNode(tree.Node): raise TypeError('Invalid lookup_type: %r' % lookup_type) - def sql_for_columns(self, data, qn): + def sql_for_columns(self, data, qn, connection): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause @@ -233,7 +237,8 @@ class EverythingNode(object): """ A node that matches everything. """ - def as_sql(self, qn=None): + + def as_sql(self, qn=None, connection=None): raise FullResultSet def relabel_aliases(self, change_map, node=None): @@ -243,7 +248,7 @@ class NothingNode(object): """ A node that matches nothing. 
""" - def as_sql(self, qn=None): + def as_sql(self, qn=None, connection=None): raise EmptyResultSet def relabel_aliases(self, change_map, node=None): @@ -257,7 +262,12 @@ class Constraint(object): def __init__(self, alias, col, field): self.alias, self.col, self.field = alias, col, field - def process(self, lookup_type, value): + def prepare(self, lookup_type, value): + if self.field: + return self.field.get_prep_lookup(lookup_type, value) + return value + + def process(self, lookup_type, value, connection): """ Returns a tuple of data suitable for inclusion in a WhereNode instance. @@ -266,16 +276,21 @@ class Constraint(object): from django.db.models.base import ObjectDoesNotExist try: if self.field: - params = self.field.get_db_prep_lookup(lookup_type, value) - db_type = self.field.db_type() + params = self.field.get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) + db_type = self.field.db_type(connection=connection) else: # This branch is used at times when we add a comparison to NULL # (we don't really want to waste time looking up the associated # field object at the calling location). - params = Field().get_db_prep_lookup(lookup_type, value) + params = Field().get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) db_type = None except ObjectDoesNotExist: raise EmptyShortCircuit return (self.alias, self.col, db_type), params + def relabel_aliases(self, change_map): + if self.alias in change_map: + self.alias = change_map[self.alias] |
