diff options
Diffstat (limited to 'django/db/models/sql/subqueries.py')
| -rw-r--r-- | django/db/models/sql/subqueries.py | 265 |
1 files changed, 39 insertions, 226 deletions
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f00f1bd68a..e80a023699 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -3,6 +3,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval. """ from django.core.exceptions import FieldError +from django.db import connections from django.db.models.sql.constants import * from django.db.models.sql.datastructures import Date from django.db.models.sql.expressions import SQLEvaluator @@ -17,24 +18,15 @@ class DeleteQuery(Query): Delete queries are done through this class, since they are more constrained than general queries. """ - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - assert len(self.tables) == 1, \ - "Can only delete from one table at a time." - result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] - where, params = self.where.as_sql() - result.append('WHERE %s' % where) - return ' '.join(result), tuple(params) - def do_query(self, table, where): + compiler = 'SQLDeleteCompiler' + + def do_query(self, table, where, using): self.tables = [table] self.where = where - self.execute_sql(None) + self.get_compiler(using).execute_sql(None) - def delete_batch_related(self, pk_list): + def delete_batch_related(self, pk_list, using): """ Set up and execute delete queries for all the objects related to the primary key values in pk_list. To delete the objects themselves, use @@ -54,7 +46,7 @@ class DeleteQuery(Query): 'in', pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(related.field.m2m_db_table(), where) + self.do_query(related.field.m2m_db_table(), where, using=using) for f in cls._meta.many_to_many: w1 = self.where_class() @@ -70,9 +62,9 @@ class DeleteQuery(Query): AND) if w1: where.add(w1, AND) - self.do_query(f.m2m_db_table(), where) + self.do_query(f.m2m_db_table(), where, using=using) - def delete_batch(self, pk_list): + def delete_batch(self, pk_list, using): """ Set up and execute delete queries for all the objects in pk_list. This should be called after delete_batch_related(), if necessary. @@ -85,12 +77,15 @@ class DeleteQuery(Query): field = self.model._meta.pk where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(self.model._meta.db_table, where) + self.do_query(self.model._meta.db_table, where, using=using) class UpdateQuery(Query): """ Represents an "update" SQL query. """ + + compiler = 'SQLUpdateCompiler' + def __init__(self, *args, **kwargs): super(UpdateQuery, self).__init__(*args, **kwargs) self._setup_query() @@ -110,98 +105,8 @@ class UpdateQuery(Query): return super(UpdateQuery, self).clone(klass, related_updates=self.related_updates.copy(), **kwargs) - def execute_sql(self, result_type=None): - """ - Execute the specified update. Returns the number of rows affected by - the primary update query. The "primary update query" is the first - non-empty query that is executed. Row counts for any subsequent, - related queries are not available. - """ - cursor = super(UpdateQuery, self).execute_sql(result_type) - rows = cursor and cursor.rowcount or 0 - is_empty = cursor is None - del cursor - for query in self.get_related_updates(): - aux_rows = query.execute_sql(result_type) - if is_empty: - rows = aux_rows - is_empty = False - return rows - - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - self.pre_sql_setup() - if not self.values: - return '', () - table = self.tables[0] - qn = self.quote_name_unless_alias - result = ['UPDATE %s' % qn(table)] - result.append('SET') - values, update_params = [], [] - for name, val, placeholder in self.values: - if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn) - values.append('%s = %s' % (qn(name), sql)) - update_params.extend(params) - elif val is not None: - values.append('%s = %s' % (qn(name), placeholder)) - update_params.append(val) - else: - values.append('%s = NULL' % qn(name)) - result.append(', '.join(values)) - where, params = self.where.as_sql() - if where: - result.append('WHERE %s' % where) - return ' '.join(result), tuple(update_params + params) - - def pre_sql_setup(self): - """ - If the update depends on results from other tables, we need to do some - munging of the "where" conditions to match the format required for - (portable) SQL updates. That is done here. - Further, if we are going to be running multiple updates, we pull out - the id values to update at this point so that they don't change as a - result of the progressive updates. - """ - self.select_related = False - self.clear_ordering(True) - super(UpdateQuery, self).pre_sql_setup() - count = self.count_active_tables() - if not self.related_updates and count == 1: - return - - # We need to use a sub-select in the where clause to filter on things - # from other tables. - query = self.clone(klass=Query) - query.bump_prefix() - query.extra = {} - query.select = [] - query.add_fields([query.model._meta.pk.name]) - must_pre_select = count > 1 and not self.connection.features.update_can_self_select - - # Now we adjust the current query: reset the where clause and get rid - # of all the tables we don't need (since they're in the sub-select). - self.where = self.where_class() - if self.related_updates or must_pre_select: - # Either we're using the idents in multiple update queries (so - # don't want them to change), or the db backend doesn't support - # selecting from the updating table (e.g. MySQL). - idents = [] - for rows in query.execute_sql(MULTI): - idents.extend([r[0] for r in rows]) - self.add_filter(('pk__in', idents)) - self.related_ids = idents - else: - # The fast path. Filters and updates in one query. - self.add_filter(('pk__in', query)) - for alias in self.tables[1:]: - self.alias_refcount[alias] = 0 - - def clear_related(self, related_field, pk_list): + def clear_related(self, related_field, pk_list, using): """ Set up and execute an update query that clears related entries for the keys in pk_list. @@ -214,8 +119,8 @@ class UpdateQuery(Query): self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.values = [(related_field.column, None, '%s')] - self.execute_sql(None) + self.values = [(related_field, None, None)] + self.get_compiler(using).execute_sql(None) def add_update_values(self, values): """ @@ -228,6 +133,9 @@ class UpdateQuery(Query): field, model, direct, m2m = self.model._meta.get_field_by_name(name) if not direct or m2m: raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) + if model: + self.add_related_update(model, field, val) + continue values_seq.append((field, model, val)) return self.add_update_fields(values_seq) @@ -237,36 +145,18 @@ class UpdateQuery(Query): Used by add_update_values() as well as the "fast" update path when saving models. """ - from django.db.models.base import Model - for field, model, val in values_seq: - if hasattr(val, 'prepare_database_save'): - val = val.prepare_database_save(field) - else: - val = field.get_db_prep_save(val) + self.values.extend(values_seq) - # Getting the placeholder for the field. - if hasattr(field, 'get_placeholder'): - placeholder = field.get_placeholder(val) - else: - placeholder = '%s' - - if hasattr(val, 'evaluate'): - val = SQLEvaluator(val, self, allow_joins=False) - if model: - self.add_related_update(model, field.column, val, placeholder) - else: - self.values.append((field.column, val, placeholder)) - - def add_related_update(self, model, column, value, placeholder): + def add_related_update(self, model, field, value): """ Adds (name, value) to an update query for an ancestor model. Updates are coalesced so that we only run one update query per ancestor. """ try: - self.related_updates[model].append((column, value, placeholder)) + self.related_updates[model].append((field, None, value)) except KeyError: - self.related_updates[model] = [(column, value, placeholder)] + self.related_updates[model] = [(field, None, value)] def get_related_updates(self): """ @@ -278,7 +168,7 @@ class UpdateQuery(Query): return [] result = [] for model, values in self.related_updates.iteritems(): - query = UpdateQuery(model, self.connection) + query = UpdateQuery(model) query.values = values if self.related_ids: query.add_filter(('pk__in', self.related_ids)) @@ -286,45 +176,23 @@ class UpdateQuery(Query): return result class InsertQuery(Query): + compiler = 'SQLInsertCompiler' + def __init__(self, *args, **kwargs): super(InsertQuery, self).__init__(*args, **kwargs) self.columns = [] self.values = [] self.params = () - self.return_id = False def clone(self, klass=None, **kwargs): - extras = {'columns': self.columns[:], 'values': self.values[:], - 'params': self.params, 'return_id': self.return_id} + extras = { + 'columns': self.columns[:], + 'values': self.values[:], + 'params': self.params + } extras.update(kwargs) return super(InsertQuery, self).clone(klass, **extras) - def as_sql(self): - # We don't need quote_name_unless_alias() here, since these are all - # going to be column names (so we can avoid the extra overhead). - qn = self.connection.ops.quote_name - opts = self.model._meta - result = ['INSERT INTO %s' % qn(opts.db_table)] - result.append('(%s)' % ', '.join([qn(c) for c in self.columns])) - result.append('VALUES (%s)' % ', '.join(self.values)) - params = self.params - if self.return_id and self.connection.features.can_return_id_from_insert: - col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - r_fmt, r_params = self.connection.ops.return_insert_id() - result.append(r_fmt % col) - params = params + r_params - return ' '.join(result), params - - def execute_sql(self, return_id=False): - self.return_id = return_id - cursor = super(InsertQuery, self).execute_sql(None) - if not (return_id and cursor): - return - if self.connection.features.can_return_id_from_insert: - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id(cursor, - self.model._meta.db_table, self.model._meta.pk.column) - def insert_values(self, insert_values, raw_values=False): """ Set up the insert query from the 'insert_values' dictionary. The @@ -337,17 +205,11 @@ class InsertQuery(Query): """ placeholders, values = [], [] for field, val in insert_values: - if hasattr(field, 'get_placeholder'): - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - placeholders.append(field.get_placeholder(val)) - else: - placeholders.append('%s') - + placeholders.append((field, val)) self.columns.append(field.column) values.append(val) if raw_values: - self.values.extend(values) + self.values.extend([(None, v) for v in values]) else: self.params += tuple(values) self.values.extend(placeholders) @@ -358,44 +220,8 @@ class DateQuery(Query): date field. This requires some special handling when converting the results back to Python objects, so we put it in a separate class. """ - def __getstate__(self): - """ - Special DateQuery-specific pickle handling. - """ - for elt in self.select: - if isinstance(elt, Date): - # Eliminate a method reference that can't be pickled. The - # __setstate__ method restores this. - elt.date_sql_func = None - return super(DateQuery, self).__getstate__() - def __setstate__(self, obj_dict): - super(DateQuery, self).__setstate__(obj_dict) - for elt in self.select: - if isinstance(elt, Date): - self.date_sql_func = self.connection.ops.date_trunc_sql - - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - resolve_columns = hasattr(self, 'resolve_columns') - if resolve_columns: - from django.db.models.fields import DateTimeField - fields = [DateTimeField()] - else: - from django.db.backends.util import typecast_timestamp - needs_string_cast = self.connection.features.needs_datetime_string_cast - - offset = len(self.extra_select) - for rows in self.execute_sql(MULTI): - for row in rows: - date = row[offset] - if resolve_columns: - date = self.resolve_columns(row, fields)[offset] - elif needs_string_cast: - date = typecast_timestamp(str(date)) - yield date + compiler = 'SQLDateCompiler' def add_date_select(self, field, lookup_type, order='ASC'): """ @@ -404,8 +230,7 @@ class DateQuery(Query): result = self.setup_joins([field.name], self.get_meta(), self.get_initial_alias(), False) alias = result[3][-1] - select = Date((alias, field.column), lookup_type, - self.connection.ops.date_trunc_sql) + select = Date((alias, field.column), lookup_type) self.select = [select] self.select_fields = [None] self.select_related = False # See #7097. @@ -418,20 +243,8 @@ class AggregateQuery(Query): An AggregateQuery takes another query as a parameter to the FROM clause and only selects the elements in the provided list. """ - def add_subquery(self, query): - self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) - def as_sql(self, quote_func=None): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - sql = ('SELECT %s FROM (%s) subquery' % ( - ', '.join([ - aggregate.as_sql() - for aggregate in self.aggregate_select.values() - ]), - self.subquery) - ) - params = self.sub_params - return (sql, params) + compiler = 'SQLAggregateCompiler' + + def add_subquery(self, query, using): + self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True) |
