diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-12-22 15:18:51 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-12-22 15:18:51 +0000 |
| commit | ff60c5f9de3e8690d1e86f3e9e3f7248a15397c8 (patch) | |
| tree | a4cb0ebdd55fcaf8c8855231b6ad3e1a7bf45bee /django/db/models/sql | |
| parent | 7ef212af149540aa2da577a960d0d87029fd1514 (diff) | |
Fixed #1142 -- Added multiple database support.
This monster of a patch is the result of Alex Gaynor's 2009 Google Summer of Code project.
Congratulations to Alex for a job well done.
Big thanks also go to:
* Justin Bronn for keeping GIS in line with the changes,
* Karen Tracey and Jani Tiainen for their help testing Oracle support
* Brett Hoerner, Jon Loyens, and Craig Kimmerer for their feedback.
* Malcolm Tredinnick for his guidance during the GSoC submission process.
* Simon Willison for driving the original design process
* Cal Henderson for complaining about ponies he wanted.
... and everyone else too numerous to mention that helped to bring this feature to fruition.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@11952 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/aggregates.py | 9 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 921 | ||||
| -rw-r--r-- | django/db/models/sql/datastructures.py | 12 | ||||
| -rw-r--r-- | django/db/models/sql/expressions.py | 22 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 766 | ||||
| -rw-r--r-- | django/db/models/sql/subqueries.py | 265 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 77 |
7 files changed, 1065 insertions, 1007 deletions
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 6fdaf188c4..8a14bdf2df 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -72,15 +72,13 @@ class Aggregate(object): if isinstance(self.col, (list, tuple)): self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) - def as_sql(self, quote_func=None): + def as_sql(self, qn, connection): "Return the aggregate, rendered as SQL." - if not quote_func: - quote_func = lambda x: x if hasattr(self.col, 'as_sql'): - field_name = self.col.as_sql(quote_func) + field_name = self.col.as_sql(qn, connection) elif isinstance(self.col, (list, tuple)): - field_name = '.'.join([quote_func(c) for c in self.col]) + field_name = '.'.join([qn(c) for c in self.col]) else: field_name = self.col @@ -127,4 +125,3 @@ class Variance(Aggregate): def __init__(self, col, sample=False, **extra): super(Variance, self).__init__(col, **extra) self.sql_function = sample and 'VAR_SAMP' or 'VAR_POP' - diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py new file mode 100644 index 0000000000..6a95d32259 --- /dev/null +++ b/django/db/models/sql/compiler.py @@ -0,0 +1,921 @@ +from django.core.exceptions import FieldError +from django.db import connections +from django.db.backends.util import truncate_name +from django.db.models.sql.constants import * +from django.db.models.sql.datastructures import EmptyResultSet +from django.db.models.sql.expressions import SQLEvaluator +from django.db.models.sql.query import get_proxied_model, get_order_dir, \ + select_related_descend, Query + +class SQLCompiler(object): + def __init__(self, query, connection, using): + self.query = query + self.connection = connection + self.using = using + self.quote_cache = {} + + # Check that the compiler will be able to execute the query + for alias, aggregate in self.query.aggregate_select.items(): + self.connection.ops.check_aggregate_support(aggregate) + + def 
pre_sql_setup(self): + """ + Does any necessary class setup immediately prior to producing SQL. This + is for things that can't necessarily be done in __init__ because we + might not have all the pieces in place at that time. + """ + if not self.query.tables: + self.query.join((None, self.query.model._meta.db_table, None, None)) + if (not self.query.select and self.query.default_cols and not + self.query.included_inherited_models): + self.query.setup_inherited_models() + if self.query.select_related and not self.query.related_select_cols: + self.fill_related_selections() + + def quote_name_unless_alias(self, name): + """ + A wrapper around connection.ops.quote_name that doesn't quote aliases + for table names. This avoids problems with some SQL dialects that treat + quoted strings specially (e.g. PostgreSQL). + """ + if name in self.quote_cache: + return self.quote_cache[name] + if ((name in self.query.alias_map and name not in self.query.table_map) or + name in self.query.extra_select): + self.quote_cache[name] = name + return name + r = self.connection.ops.quote_name(name) + self.quote_cache[name] = r + return r + + def as_sql(self, with_limits=True, with_col_aliases=False): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + + If 'with_limits' is False, any limit/offset information is not included + in the query. + """ + self.pre_sql_setup() + out_cols = self.get_columns(with_col_aliases) + ordering, ordering_group_by = self.get_ordering() + + # This must come after 'select' and 'ordering' -- see docstring of + # get_from_clause() for details. 
+ from_, f_params = self.get_from_clause() + + qn = self.quote_name_unless_alias + + where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) + having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) + params = [] + for val in self.query.extra_select.itervalues(): + params.extend(val[1]) + + result = ['SELECT'] + if self.query.distinct: + result.append('DISTINCT') + result.append(', '.join(out_cols + self.query.ordering_aliases)) + + result.append('FROM') + result.extend(from_) + params.extend(f_params) + + if where: + result.append('WHERE %s' % where) + params.extend(w_params) + if self.query.extra_where: + if not where: + result.append('WHERE') + else: + result.append('AND') + result.append(' AND '.join(self.query.extra_where)) + + grouping, gb_params = self.get_grouping() + if grouping: + if ordering: + # If the backend can't group by PK (i.e., any database + # other than MySQL), then any fields mentioned in the + # ordering clause needs to be in the group by clause. 
+ if not self.connection.features.allows_group_by_pk: + for col, col_params in ordering_group_by: + if col not in grouping: + grouping.append(str(col)) + gb_params.extend(col_params) + else: + ordering = self.connection.ops.force_no_ordering() + result.append('GROUP BY %s' % ', '.join(grouping)) + params.extend(gb_params) + + if having: + result.append('HAVING %s' % having) + params.extend(h_params) + + if ordering: + result.append('ORDER BY %s' % ', '.join(ordering)) + + if with_limits: + if self.query.high_mark is not None: + result.append('LIMIT %d' % (self.query.high_mark - self.query.low_mark)) + if self.query.low_mark: + if self.query.high_mark is None: + val = self.connection.ops.no_limit_value() + if val: + result.append('LIMIT %d' % val) + result.append('OFFSET %d' % self.query.low_mark) + + params.extend(self.query.extra_params) + return ' '.join(result), tuple(params) + + def as_nested_sql(self): + """ + Perform the same functionality as the as_sql() method, returning an + SQL string and parameters. However, the alias prefixes are bumped + beforehand (in a copy -- the current query isn't changed) and any + ordering is removed. + + Used when nesting this query inside another. + """ + obj = self.query.clone() + obj.clear_ordering(True) + obj.bump_prefix() + return obj.get_compiler(connection=self.connection).as_sql() + + def get_columns(self, with_aliases=False): + """ + Returns the list of columns to use in the select statement. If no + columns have been specified, returns all columns relating to fields in + the model. + + If 'with_aliases' is true, any column names that are duplicated + (without the table names) are given unique aliases. This is needed in + some cases to avoid ambiguity with nested queries. 
+ """ + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] + aliases = set(self.query.extra_select.keys()) + if with_aliases: + col_aliases = aliases.copy() + else: + col_aliases = set() + if self.query.select: + only_load = self.deferred_to_columns() + for col in self.query.select: + if isinstance(col, (list, tuple)): + alias, column = col + table = self.query.alias_map[alias][TABLE_NAME] + if table in only_load and col not in only_load[table]: + continue + r = '%s.%s' % (qn(alias), qn(column)) + if with_aliases: + if col[1] in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append('%s AS %s' % (r, qn2(col[1]))) + aliases.add(r) + col_aliases.add(col[1]) + else: + result.append(r) + aliases.add(r) + col_aliases.add(col[1]) + else: + result.append(col.as_sql(qn, self.connection)) + + if hasattr(col, 'alias'): + aliases.add(col.alias) + col_aliases.add(col.alias) + + elif self.query.default_cols: + cols, new_aliases = self.get_default_columns(with_aliases, + col_aliases) + result.extend(cols) + aliases.update(new_aliases) + + max_name_length = self.connection.ops.max_name_length() + result.extend([ + '%s%s' % ( + aggregate.as_sql(qn, self.connection), + alias is not None + and ' AS %s' % qn(truncate_name(alias, max_name_length)) + or '' + ) + for alias, aggregate in self.query.aggregate_select.items() + ]) + + for table, col in self.query.related_select_cols: + r = '%s.%s' % (qn(table), qn(col)) + if with_aliases and col in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append(r) + aliases.add(r) + col_aliases.add(col) + + self._select_aliases = aliases + return result + + def get_default_columns(self, 
with_aliases=False, col_aliases=None, + start_alias=None, opts=None, as_pairs=False): + """ + Computes the default columns for selecting every field in the base + model. Will sometimes be called to pull in related models (e.g. via + select_related), in which case "opts" and "start_alias" will be given + to provide a starting point for the traversal. + + Returns a list of strings, quoted appropriately for use in SQL + directly, as well as a set of aliases used in the select statement (if + 'as_pairs' is True, returns a list of (alias, col_name) pairs instead + of strings as the first component and None as the second component). + """ + result = [] + if opts is None: + opts = self.query.model._meta + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + aliases = set() + only_load = self.deferred_to_columns() + # Skip all proxy to the root proxied model + proxied_model = get_proxied_model(opts) + + if start_alias: + seen = {None: start_alias} + for field, model in opts.get_fields_with_model(): + if start_alias: + try: + alias = seen[model] + except KeyError: + if model is proxied_model: + alias = start_alias + else: + link_field = opts.get_ancestor_link(model) + alias = self.query.join((start_alias, model._meta.db_table, + link_field.column, model._meta.pk.column)) + seen[model] = alias + else: + # If we're starting from the base model of the queryset, the + # aliases will have already been set up in pre_sql_setup(), so + # we can save time here. 
+ alias = self.query.included_inherited_models[model] + table = self.query.alias_map[alias][TABLE_NAME] + if table in only_load and field.column not in only_load[table]: + continue + if as_pairs: + result.append((alias, field.column)) + aliases.add(alias) + continue + if with_aliases and field.column in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s.%s AS %s' % (qn(alias), + qn2(field.column), c_alias)) + col_aliases.add(c_alias) + aliases.add(c_alias) + else: + r = '%s.%s' % (qn(alias), qn2(field.column)) + result.append(r) + aliases.add(r) + if with_aliases: + col_aliases.add(field.column) + return result, aliases + + def get_ordering(self): + """ + Returns a tuple containing a list representing the SQL elements in the + "order by" clause, and the list of SQL elements that need to be added + to the GROUP BY clause as a result of the ordering. + + Also sets the ordering_aliases attribute on this instance to a list of + extra aliases needed in the select. + + Determining the ordering SQL can change the tables we need to include, + so this should be run *before* get_from_clause(). + """ + if self.query.extra_order_by: + ordering = self.query.extra_order_by + elif not self.query.default_ordering: + ordering = self.query.order_by + else: + ordering = self.query.order_by or self.query.model._meta.ordering + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + distinct = self.query.distinct + select_aliases = self._select_aliases + result = [] + group_by = [] + ordering_aliases = [] + if self.query.standard_ordering: + asc, desc = ORDER_DIR['ASC'] + else: + asc, desc = ORDER_DIR['DESC'] + + # It's possible, due to model inheritance, that normal usage might try + # to include the same field more than once in the ordering. We track + # the table/column pairs we use and discard any after the first use. 
+ processed_pairs = set() + + for field in ordering: + if field == '?': + result.append(self.connection.ops.random_function_sql()) + continue + if isinstance(field, int): + if field < 0: + order = desc + field = -field + else: + order = asc + result.append('%s %s' % (field, order)) + group_by.append((field, [])) + continue + col, order = get_order_dir(field, asc) + if col in self.query.aggregate_select: + result.append('%s %s' % (col, order)) + continue + if '.' in field: + # This came in through an extra(order_by=...) addition. Pass it + # on verbatim. + table, col = col.split('.', 1) + if (table, col) not in processed_pairs: + elt = '%s.%s' % (qn(table), col) + processed_pairs.add((table, col)) + if not distinct or elt in select_aliases: + result.append('%s %s' % (elt, order)) + group_by.append((elt, [])) + elif get_order_dir(field)[0] not in self.query.extra_select: + # 'col' is of the form 'field' or 'field1__field2' or + # '-field1__field2__field', etc. + for table, col, order in self.find_ordering_name(field, + self.query.model._meta, default_order=asc): + if (table, col) not in processed_pairs: + elt = '%s.%s' % (qn(table), qn2(col)) + processed_pairs.add((table, col)) + if distinct and elt not in select_aliases: + ordering_aliases.append(elt) + result.append('%s %s' % (elt, order)) + group_by.append((elt, [])) + else: + elt = qn2(col) + if distinct and col not in select_aliases: + ordering_aliases.append(elt) + result.append('%s %s' % (elt, order)) + group_by.append(self.query.extra_select[col]) + self.query.ordering_aliases = ordering_aliases + return result, group_by + + def find_ordering_name(self, name, opts, alias=None, default_order='ASC', + already_seen=None): + """ + Returns the table alias (the name might be ambiguous, the alias will + not be) and column name for ordering by the given 'name' parameter. + The 'name' is of the form 'field1__field2__...__fieldN'. 
+ """ + name, order = get_order_dir(name, default_order) + pieces = name.split(LOOKUP_SEP) + if not alias: + alias = self.query.get_initial_alias() + field, target, opts, joins, last, extra = self.query.setup_joins(pieces, + opts, alias, False) + alias = joins[-1] + col = target.column + if not field.rel: + # To avoid inadvertent trimming of a necessary alias, use the + # refcount to show that we are referencing a non-relation field on + # the model. + self.query.ref_alias(alias) + + # Must use left outer joins for nullable fields and their relations. + self.query.promote_alias_chain(joins, + self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER) + + # If we get to this point and the field is a relation to another model, + # append the default ordering for that model. + if field.rel and len(joins) > 1 and opts.ordering: + # Firstly, avoid infinite loops. + if not already_seen: + already_seen = set() + join_tuple = tuple([self.query.alias_map[j][TABLE_NAME] for j in joins]) + if join_tuple in already_seen: + raise FieldError('Infinite loop caused by ordering.') + already_seen.add(join_tuple) + + results = [] + for item in opts.ordering: + results.extend(self.find_ordering_name(item, opts, alias, + order, already_seen)) + return results + + if alias: + # We have to do the same "final join" optimisation as in + # add_filter, since the final column might not otherwise be part of + # the select set (so we can't order on it). + while 1: + join = self.query.alias_map[alias] + if col != join[RHS_JOIN_COL]: + break + self.query.unref_alias(alias) + alias = join[LHS_ALIAS] + col = join[LHS_JOIN_COL] + return [(alias, col, order)] + + def get_from_clause(self): + """ + Returns a list of strings that are joined together to go after the + "FROM" part of the query, as well as a list any extra parameters that + need to be included. Sub-classes, can override this to create a + from-clause via a "select". 
+ + This should only be called after any SQL construction methods that + might change the tables we need. This means the select columns and + ordering must be done first. + """ + result = [] + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + first = True + for alias in self.query.tables: + if not self.query.alias_refcount[alias]: + continue + try: + name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias] + except KeyError: + # Extra tables can end up in self.tables, but not in the + # alias_map if they aren't in a join. That's OK. We skip them. + continue + alias_str = (alias != name and ' %s' % alias or '') + if join_type and not first: + result.append('%s %s%s ON (%s.%s = %s.%s)' + % (join_type, qn(name), alias_str, qn(lhs), + qn2(lhs_col), qn(alias), qn2(col))) + else: + connector = not first and ', ' or '' + result.append('%s%s%s' % (connector, qn(name), alias_str)) + first = False + for t in self.query.extra_tables: + alias, unused = self.query.table_alias(t) + # Only add the alias if it's not already present (the table_alias() + # calls increments the refcount, so an alias refcount of one means + # this is the only reference. + if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1: + connector = not first and ', ' or '' + result.append('%s%s' % (connector, qn(alias))) + first = False + return result, [] + + def get_grouping(self): + """ + Returns a tuple representing the SQL elements in the "group by" clause. 
+ """ + qn = self.quote_name_unless_alias + result, params = [], [] + if self.query.group_by is not None: + if len(self.query.model._meta.fields) == len(self.query.select) and \ + self.connection.features.allows_group_by_pk: + self.query.group_by = [(self.query.model._meta.db_table, self.query.model._meta.pk.column)] + + group_by = self.query.group_by or [] + + extra_selects = [] + for extra_select, extra_params in self.query.extra_select.itervalues(): + extra_selects.append(extra_select) + params.extend(extra_params) + for col in group_by + self.query.related_select_cols + extra_selects: + if isinstance(col, (list, tuple)): + result.append('%s.%s' % (qn(col[0]), qn(col[1]))) + elif hasattr(col, 'as_sql'): + result.append(col.as_sql(qn)) + else: + result.append(str(col)) + return result, params + + def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, + used=None, requested=None, restricted=None, nullable=None, + dupe_set=None, avoid_set=None): + """ + Fill in the information needed for a select_related query. The current + depth is measured as the number of connections away from the root model + (for example, cur_depth=1 means we are looking at models with direct + connections to the root model). + """ + if not restricted and self.query.max_depth and cur_depth > self.query.max_depth: + # We've recursed far enough; bail out. + return + + if not opts: + opts = self.query.get_meta() + root_alias = self.query.get_initial_alias() + self.query.related_select_cols = [] + self.query.related_select_fields = [] + if not used: + used = set() + if dupe_set is None: + dupe_set = set() + if avoid_set is None: + avoid_set = set() + orig_dupe_set = dupe_set + + # Setup for the case when only particular related fields should be + # included in the related selection. 
+ if requested is None and restricted is not False: + if isinstance(self.query.select_related, dict): + requested = self.query.select_related + restricted = True + else: + restricted = False + + for f, model in opts.get_fields_with_model(): + if not select_related_descend(f, restricted, requested): + continue + # The "avoid" set is aliases we want to avoid just for this + # particular branch of the recursion. They aren't permanently + # forbidden from reuse in the related selection tables (which is + # what "used" specifies). + avoid = avoid_set.copy() + dupe_set = orig_dupe_set.copy() + table = f.rel.to._meta.db_table + if nullable or f.null: + promote = True + else: + promote = False + if model: + int_opts = opts + alias = root_alias + alias_chain = [] + for int_model in opts.get_base_chain(model): + # Proxy model have elements in base chain + # with no parents, assign the new options + # object and skip to the next base in that + # case + if not int_opts.parents[int_model]: + int_opts = int_model._meta + continue + lhs_col = int_opts.parents[int_model].column + dedupe = lhs_col in opts.duplicate_targets + if dedupe: + avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), + ()) + dupe_set.add((opts, lhs_col)) + int_opts = int_model._meta + alias = self.query.join((alias, int_opts.db_table, lhs_col, + int_opts.pk.column), exclusions=used, + promote=promote) + alias_chain.append(alias) + for (dupe_opts, dupe_col) in dupe_set: + self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) + if self.query.alias_map[root_alias][JOIN_TYPE] == self.query.LOUTER: + self.query.promote_alias_chain(alias_chain, True) + else: + alias = root_alias + + dedupe = f.column in opts.duplicate_targets + if dupe_set or dedupe: + avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) + if dedupe: + dupe_set.add((opts, f.column)) + + alias = self.query.join((alias, table, f.column, + f.rel.get_related_field().column), + exclusions=used.union(avoid), 
promote=promote) + used.add(alias) + columns, aliases = self.get_default_columns(start_alias=alias, + opts=f.rel.to._meta, as_pairs=True) + self.query.related_select_cols.extend(columns) + if self.query.alias_map[alias][JOIN_TYPE] == self.query.LOUTER: + self.query.promote_alias_chain(aliases, True) + self.query.related_select_fields.extend(f.rel.to._meta.fields) + if restricted: + next = requested.get(f.name, {}) + else: + next = False + if f.null is not None: + new_nullable = f.null + else: + new_nullable = None + for dupe_opts, dupe_col in dupe_set: + self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) + self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, + used, next, restricted, new_nullable, dupe_set, avoid) + + def deferred_to_columns(self): + """ + Converts the self.deferred_loading data structure to mapping of table + names to sets of column names which are to be loaded. Returns the + dictionary. + """ + columns = {} + self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb) + return columns + + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + resolve_columns = hasattr(self, 'resolve_columns') + fields = None + for rows in self.execute_sql(MULTI): + for row in rows: + if resolve_columns: + if fields is None: + # We only set this up here because + # related_select_fields isn't populated until + # execute_sql() has been called. + if self.query.select_fields: + fields = self.query.select_fields + self.query.related_select_fields + else: + fields = self.query.model._meta.fields + # If the field was deferred, exclude it from being passed + # into `resolve_columns` because it wasn't selected. 
+ only_load = self.deferred_to_columns() + if only_load: + db_table = self.query.model._meta.db_table + fields = [f for f in fields if db_table in only_load and + f.column in only_load[db_table]] + row = self.resolve_columns(row, fields) + + if self.query.aggregate_select: + aggregate_start = len(self.query.extra_select.keys()) + len(self.query.select) + aggregate_end = aggregate_start + len(self.query.aggregate_select) + row = tuple(row[:aggregate_start]) + tuple([ + self.query.resolve_aggregate(value, aggregate, self.connection) + for (alias, aggregate), value + in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end]) + ]) + tuple(row[aggregate_end:]) + + yield row + + def execute_sql(self, result_type=MULTI): + """ + Run the query against the database and returns the result(s). The + return value is a single data item if result_type is SINGLE, or an + iterator over the results if the result_type is MULTI. + + result_type is either MULTI (use fetchmany() to retrieve all rows), + SINGLE (only retrieve a single row), or None. In this last case, the + cursor is returned if any query is executed, since it's used by + subclasses such as InsertQuery). It's possible, however, that no query + is needed, as the filters describe an empty set. In that case, None is + returned, to avoid any unnecessary database interaction. + """ + try: + sql, params = self.as_sql() + if not sql: + raise EmptyResultSet + except EmptyResultSet: + if result_type == MULTI: + return empty_iter() + else: + return + + cursor = self.connection.cursor() + cursor.execute(sql, params) + + if not result_type: + return cursor + if result_type == SINGLE: + if self.query.ordering_aliases: + return cursor.fetchone()[:-len(self.query.ordering_aliases)] + return cursor.fetchone() + + # The MULTI case. 
+ if self.query.ordering_aliases: + result = order_modified_iter(cursor, len(self.query.ordering_aliases), + self.connection.features.empty_fetchmany_value) + else: + result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + self.connection.features.empty_fetchmany_value) + if not self.connection.features.can_use_chunked_reads: + # If we are using non-chunked reads, we return the same data + # structure as normally, but ensure it is all read into memory + # before going any further. + return list(result) + return result + + +class SQLInsertCompiler(SQLCompiler): + def placeholder(self, field, val): + if field is None: + # A field value of None means the value is raw. + return val + elif hasattr(field, 'get_placeholder'): + # Some fields (e.g. geo fields) need special munging before + # they can be inserted. + return field.get_placeholder(val, self.connection) + else: + # Return the common case for the placeholder + return '%s' + + def as_sql(self): + # We don't need quote_name_unless_alias() here, since these are all + # going to be column names (so we can avoid the extra overhead). 
+ qn = self.connection.ops.quote_name + opts = self.query.model._meta + result = ['INSERT INTO %s' % qn(opts.db_table)] + result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns])) + values = [self.placeholder(*v) for v in self.query.values] + result.append('VALUES (%s)' % ', '.join(values)) + params = self.query.params + if self.return_id and self.connection.features.can_return_id_from_insert: + col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) + r_fmt, r_params = self.connection.ops.return_insert_id() + result.append(r_fmt % col) + params = params + r_params + return ' '.join(result), params + + def execute_sql(self, return_id=False): + self.return_id = return_id + cursor = super(SQLInsertCompiler, self).execute_sql(None) + if not (return_id and cursor): + return + if self.connection.features.can_return_id_from_insert: + return self.connection.ops.fetch_returned_insert_id(cursor) + return self.connection.ops.last_insert_id(cursor, + self.query.model._meta.db_table, self.query.model._meta.pk.column) + + +class SQLDeleteCompiler(SQLCompiler): + def as_sql(self): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + assert len(self.query.tables) == 1, \ + "Can only delete from one table at a time." + qn = self.quote_name_unless_alias + result = ['DELETE FROM %s' % qn(self.query.tables[0])] + where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + result.append('WHERE %s' % where) + return ' '.join(result), tuple(params) + +class SQLUpdateCompiler(SQLCompiler): + def as_sql(self): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. 
+ """ + from django.db.models.base import Model + + self.pre_sql_setup() + if not self.query.values: + return '', () + table = self.query.tables[0] + qn = self.quote_name_unless_alias + result = ['UPDATE %s' % qn(table)] + result.append('SET') + values, update_params = [], [] + for field, model, val in self.query.values: + if hasattr(val, 'prepare_database_save'): + val = val.prepare_database_save(field) + else: + val = field.get_db_prep_save(val, connection=self.connection) + + # Getting the placeholder for the field. + if hasattr(field, 'get_placeholder'): + placeholder = field.get_placeholder(val, self.connection) + else: + placeholder = '%s' + + if hasattr(val, 'evaluate'): + val = SQLEvaluator(val, self.query, allow_joins=False) + name = field.column + if hasattr(val, 'as_sql'): + sql, params = val.as_sql(qn, self.connection) + values.append('%s = %s' % (qn(name), sql)) + update_params.extend(params) + elif val is not None: + values.append('%s = %s' % (qn(name), placeholder)) + update_params.append(val) + else: + values.append('%s = NULL' % qn(name)) + if not values: + return '', () + result.append(', '.join(values)) + where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + if where: + result.append('WHERE %s' % where) + return ' '.join(result), tuple(update_params + params) + + def execute_sql(self, result_type): + """ + Execute the specified update. Returns the number of rows affected by + the primary update query. The "primary update query" is the first + non-empty query that is executed. Row counts for any subsequent, + related queries are not available. 
+ """ + cursor = super(SQLUpdateCompiler, self).execute_sql(result_type) + rows = cursor and cursor.rowcount or 0 + is_empty = cursor is None + del cursor + for query in self.query.get_related_updates(): + aux_rows = query.get_compiler(self.using).execute_sql(result_type) + if is_empty: + rows = aux_rows + is_empty = False + return rows + + def pre_sql_setup(self): + """ + If the update depends on results from other tables, we need to do some + munging of the "where" conditions to match the format required for + (portable) SQL updates. That is done here. + + Further, if we are going to be running multiple updates, we pull out + the id values to update at this point so that they don't change as a + result of the progressive updates. + """ + self.query.select_related = False + self.query.clear_ordering(True) + super(SQLUpdateCompiler, self).pre_sql_setup() + count = self.query.count_active_tables() + if not self.query.related_updates and count == 1: + return + + # We need to use a sub-select in the where clause to filter on things + # from other tables. + query = self.query.clone(klass=Query) + query.bump_prefix() + query.extra = {} + query.select = [] + query.add_fields([query.model._meta.pk.name]) + must_pre_select = count > 1 and not self.connection.features.update_can_self_select + + # Now we adjust the current query: reset the where clause and get rid + # of all the tables we don't need (since they're in the sub-select). + self.query.where = self.query.where_class() + if self.query.related_updates or must_pre_select: + # Either we're using the idents in multiple update queries (so + # don't want them to change), or the db backend doesn't support + # selecting from the updating table (e.g. MySQL). + idents = [] + for rows in query.get_compiler(self.using).execute_sql(MULTI): + idents.extend([r[0] for r in rows]) + self.query.add_filter(('pk__in', idents)) + self.query.related_ids = idents + else: + # The fast path. Filters and updates in one query. 
+ self.query.add_filter(('pk__in', query)) + for alias in self.query.tables[1:]: + self.query.alias_refcount[alias] = 0 + +class SQLAggregateCompiler(SQLCompiler): + def as_sql(self, qn=None): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + if qn is None: + qn = self.quote_name_unless_alias + sql = ('SELECT %s FROM (%s) subquery' % ( + ', '.join([ + aggregate.as_sql(qn, self.connection) + for aggregate in self.query.aggregate_select.values() + ]), + self.query.subquery) + ) + params = self.query.sub_params + return (sql, params) + +class SQLDateCompiler(SQLCompiler): + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + resolve_columns = hasattr(self, 'resolve_columns') + if resolve_columns: + from django.db.models.fields import DateTimeField + fields = [DateTimeField()] + else: + from django.db.backends.util import typecast_timestamp + needs_string_cast = self.connection.features.needs_datetime_string_cast + + offset = len(self.query.extra_select) + for rows in self.execute_sql(MULTI): + for row in rows: + date = row[offset] + if resolve_columns: + date = self.resolve_columns(row, fields)[offset] + elif needs_string_cast: + date = typecast_timestamp(str(date)) + yield date + + +def empty_iter(): + """ + Returns an iterator containing no results. + """ + yield iter([]).next() + + +def order_modified_iter(cursor, trim, sentinel): + """ + Yields blocks of rows from a cursor. We use this iterator in the special + case when extra output columns have been added to support ordering + requirements. We must trim those extra columns before anything else can use + the results, since they're only needed to make the SQL valid. 
+ """ + for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), + sentinel): + yield [r[:-trim] for r in rows] diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 4d53999c79..92d64e15dd 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -29,22 +29,18 @@ class Date(object): """ Add a date selection column. """ - def __init__(self, col, lookup_type, date_sql_func): + def __init__(self, col, lookup_type): self.col = col self.lookup_type = lookup_type - self.date_sql_func = date_sql_func def relabel_aliases(self, change_map): c = self.col if isinstance(c, (list, tuple)): self.col = (change_map.get(c[0], c[0]), c[1]) - def as_sql(self, quote_func=None): - if not quote_func: - quote_func = lambda x: x + def as_sql(self, qn, connection): if isinstance(self.col, (list, tuple)): - col = '%s.%s' % tuple([quote_func(c) for c in self.col]) + col = '%s.%s' % tuple([qn(c) for c in self.col]) else: col = self.col - return self.date_sql_func(self.lookup_type, col) - + return connection.ops.date_trunc_sql(self.lookup_type, col) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 0914c2b3c1..9bbc16ec8a 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -1,5 +1,4 @@ from django.core.exceptions import FieldError -from django.db import connection from django.db.models.fields import FieldDoesNotExist from django.db.models.sql.constants import LOOKUP_SEP @@ -12,8 +11,11 @@ class SQLEvaluator(object): self.contains_aggregate = False self.expression.prepare(self, query, allow_joins) - def as_sql(self, qn=None): - return self.expression.evaluate(self, qn) + def prepare(self): + return self + + def as_sql(self, qn, connection): + return self.expression.evaluate(self, qn, connection) def relabel_aliases(self, change_map): for node, col in self.cols.items(): @@ -54,15 +56,12 @@ class 
SQLEvaluator(object): # Vistor methods for final expression evaluation # ################################################## - def evaluate_node(self, node, qn): - if not qn: - qn = connection.ops.quote_name - + def evaluate_node(self, node, qn, connection): expressions = [] expression_params = [] for child in node.children: if hasattr(child, 'evaluate'): - sql, params = child.evaluate(self, qn) + sql, params = child.evaluate(self, qn, connection) else: sql, params = '%s', (child,) @@ -77,12 +76,9 @@ class SQLEvaluator(object): return connection.ops.combine_expression(node.connector, expressions), expression_params - def evaluate_leaf(self, node, qn): - if not qn: - qn = connection.ops.quote_name - + def evaluate_leaf(self, node, qn, connection): col = self.cols[node] if hasattr(col, 'as_sql'): - return col.as_sql(qn), () + return col.as_sql(qn, connection), () else: return '%s.%s' % (qn(col[0]), qn(col[1])), () diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9ecf273be3..d821c0ee02 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -11,32 +11,34 @@ from django.utils.copycompat import deepcopy from django.utils.tree import Node from django.utils.datastructures import SortedDict from django.utils.encoding import force_unicode -from django.db.backends.util import truncate_name -from django.db import connection +from django.db import connections, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend, InvalidQuery from django.db.models.sql import aggregates as base_aggregates_module +from django.db.models.sql.constants import * +from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR from django.core.exceptions import 
FieldError -from datastructures import EmptyResultSet, Empty, MultiJoin -from constants import * -__all__ = ['Query', 'BaseQuery', 'RawQuery'] +__all__ = ['Query', 'RawQuery'] class RawQuery(object): """ A single raw SQL query """ - def __init__(self, sql, connection, params=None): + def __init__(self, sql, using, params=None): self.validate_sql(sql) self.params = params or () self.sql = sql - self.connection = connection + self.using = using self.cursor = None + def clone(self, using): + return RawQuery(self.sql, using, params=self.params) + def get_columns(self): if self.cursor is None: self._execute_query() @@ -57,10 +59,11 @@ class RawQuery(object): return "<RawQuery: %r>" % (self.sql % self.params) def _execute_query(self): - self.cursor = self.connection.cursor() + self.cursor = connections[self.using].cursor() self.cursor.execute(self.sql, self.params) -class BaseQuery(object): + +class Query(object): """ A single SQL query. """ @@ -73,9 +76,10 @@ class BaseQuery(object): query_terms = QUERY_TERMS aggregates_module = base_aggregates_module - def __init__(self, model, connection, where=WhereNode): + compiler = 'SQLCompiler' + + def __init__(self, model, where=WhereNode): self.model = model - self.connection = connection self.alias_refcount = {} self.alias_map = {} # Maps alias to join information self.table_map = {} # Maps table names to list of aliases. @@ -139,7 +143,7 @@ class BaseQuery(object): Parameter values won't necessarily be quoted correctly, since that is done by the database interface at execution time. """ - sql, params = self.as_sql() + sql, params = self.get_compiler(DEFAULT_DB_ALIAS).as_sql() return sql % params def __deepcopy__(self, memo): @@ -154,7 +158,6 @@ class BaseQuery(object): obj_dict = self.__dict__.copy() obj_dict['related_select_fields'] = [] obj_dict['related_select_cols'] = [] - del obj_dict['connection'] # Fields can't be pickled, so if a field list has been # specified, we pickle the list of field names instead. 
@@ -176,10 +179,16 @@ class BaseQuery(object): ] self.__dict__.update(obj_dict) - # XXX: Need a better solution for this when multi-db stuff is - # supported. It's the only class-reference to the module-level - # connection variable. - self.connection = connection + + def prepare(self): + return self + + def get_compiler(self, using=None, connection=None): + if using is None and connection is None: + raise ValueError("Need either using or connection") + if using: + connection = connections[using] + return connection.ops.compiler(self.compiler)(self, connection, using) def get_meta(self): """ @@ -189,22 +198,6 @@ class BaseQuery(object): """ return self.model._meta - def quote_name_unless_alias(self, name): - """ - A wrapper around connection.ops.quote_name that doesn't quote aliases - for table names. This avoids problems with some SQL dialects that treat - quoted strings specially (e.g. PostgreSQL). - """ - if name in self.quote_cache: - return self.quote_cache[name] - if ((name in self.alias_map and name not in self.table_map) or - name in self.extra_select): - self.quote_cache[name] = name - return name - r = self.connection.ops.quote_name(name) - self.quote_cache[name] = r - return r - def clone(self, klass=None, **kwargs): """ Creates a copy of the current instance. The 'kwargs' parameter can be @@ -213,7 +206,6 @@ class BaseQuery(object): obj = Empty() obj.__class__ = klass or self.__class__ obj.model = self.model - obj.connection = self.connection obj.alias_refcount = self.alias_refcount.copy() obj.alias_map = self.alias_map.copy() obj.table_map = self.table_map.copy() @@ -276,16 +268,16 @@ class BaseQuery(object): obj._setup_query() return obj - def convert_values(self, value, field): + def convert_values(self, value, field, connection): """Convert the database-returned value into a type that is consistent across database backends. By default, this defers to the underlying backend operations, but it can be overridden by Query classes for specific backends. 
""" - return self.connection.ops.convert_values(value, field) + return connection.ops.convert_values(value, field) - def resolve_aggregate(self, value, aggregate): + def resolve_aggregate(self, value, aggregate, connection): """Resolve the value of aggregates returned by the database to consistent (and reasonable) types. @@ -305,39 +297,9 @@ class BaseQuery(object): return float(value) else: # Return value depends on the type of the field being processed. - return self.convert_values(value, aggregate.field) - - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - resolve_columns = hasattr(self, 'resolve_columns') - fields = None - for rows in self.execute_sql(MULTI): - for row in rows: - if resolve_columns: - if fields is None: - # We only set this up here because - # related_select_fields isn't populated until - # execute_sql() has been called. - if self.select_fields: - fields = self.select_fields + self.related_select_fields - else: - fields = self.model._meta.fields - row = self.resolve_columns(row, fields) - - if self.aggregate_select: - aggregate_start = len(self.extra_select.keys()) + len(self.select) - aggregate_end = aggregate_start + len(self.aggregate_select) - row = tuple(row[:aggregate_start]) + tuple([ - self.resolve_aggregate(value, aggregate) - for (alias, aggregate), value - in zip(self.aggregate_select.items(), row[aggregate_start:aggregate_end]) - ]) + tuple(row[aggregate_end:]) - - yield row + return self.convert_values(value, aggregate.field, connection) - def get_aggregation(self): + def get_aggregation(self, using): """ Returns the dictionary with the values of the existing aggregations. """ @@ -349,7 +311,7 @@ class BaseQuery(object): # over the subquery instead. 
if self.group_by is not None: from subqueries import AggregateQuery - query = AggregateQuery(self.model, self.connection) + query = AggregateQuery(self.model) obj = self.clone() @@ -360,7 +322,7 @@ class BaseQuery(object): query.aggregate_select[alias] = aggregate del obj.aggregate_select[alias] - query.add_subquery(obj) + query.add_subquery(obj, using) else: query = self self.select = [] @@ -374,17 +336,17 @@ class BaseQuery(object): query.related_select_cols = [] query.related_select_fields = [] - result = query.execute_sql(SINGLE) + result = query.get_compiler(using).execute_sql(SINGLE) if result is None: result = [None for q in query.aggregate_select.items()] return dict([ - (alias, self.resolve_aggregate(val, aggregate)) + (alias, self.resolve_aggregate(val, aggregate, connection=connections[using])) for (alias, aggregate), val in zip(query.aggregate_select.items(), result) ]) - def get_count(self): + def get_count(self, using): """ Performs a COUNT() query using the current filter constraints. """ @@ -398,11 +360,11 @@ class BaseQuery(object): subquery.clear_ordering(True) subquery.clear_limits() - obj = AggregateQuery(obj.model, obj.connection) - obj.add_subquery(subquery) + obj = AggregateQuery(obj.model) + obj.add_subquery(subquery, using=using) obj.add_count_column() - number = obj.get_aggregation()[None] + number = obj.get_aggregation(using=using)[None] # Apply offset and limit constraints manually, since using LIMIT/OFFSET # in SQL (in variants that provide them) doesn't change the COUNT @@ -413,7 +375,7 @@ class BaseQuery(object): return number - def has_results(self): + def has_results(self, using): q = self.clone() q.add_extra({'a': 1}, None, None, None, None, None) q.add_fields(()) @@ -421,99 +383,8 @@ class BaseQuery(object): q.set_aggregate_mask(()) q.clear_ordering() q.set_limits(high=1) - return bool(q.execute_sql(SINGLE)) - - def as_sql(self, with_limits=True, with_col_aliases=False): - """ - Creates the SQL for this query. 
Returns the SQL string and list of - parameters. - - If 'with_limits' is False, any limit/offset information is not included - in the query. - """ - self.pre_sql_setup() - out_cols = self.get_columns(with_col_aliases) - ordering, ordering_group_by = self.get_ordering() - - # This must come after 'select' and 'ordering' -- see docstring of - # get_from_clause() for details. - from_, f_params = self.get_from_clause() - - qn = self.quote_name_unless_alias - where, w_params = self.where.as_sql(qn=qn) - having, h_params = self.having.as_sql(qn=qn) - params = [] - for val in self.extra_select.itervalues(): - params.extend(val[1]) - - result = ['SELECT'] - if self.distinct: - result.append('DISTINCT') - result.append(', '.join(out_cols + self.ordering_aliases)) - - result.append('FROM') - result.extend(from_) - params.extend(f_params) - - if where: - result.append('WHERE %s' % where) - params.extend(w_params) - if self.extra_where: - if not where: - result.append('WHERE') - else: - result.append('AND') - result.append(' AND '.join(self.extra_where)) - - grouping, gb_params = self.get_grouping() - if grouping: - if ordering: - # If the backend can't group by PK (i.e., any database - # other than MySQL), then any fields mentioned in the - # ordering clause needs to be in the group by clause. 
- if not self.connection.features.allows_group_by_pk: - for col, col_params in ordering_group_by: - if col not in grouping: - grouping.append(str(col)) - gb_params.extend(col_params) - else: - ordering = self.connection.ops.force_no_ordering() - result.append('GROUP BY %s' % ', '.join(grouping)) - params.extend(gb_params) - - if having: - result.append('HAVING %s' % having) - params.extend(h_params) - - if ordering: - result.append('ORDER BY %s' % ', '.join(ordering)) - - if with_limits: - if self.high_mark is not None: - result.append('LIMIT %d' % (self.high_mark - self.low_mark)) - if self.low_mark: - if self.high_mark is None: - val = self.connection.ops.no_limit_value() - if val: - result.append('LIMIT %d' % val) - result.append('OFFSET %d' % self.low_mark) - - params.extend(self.extra_params) - return ' '.join(result), tuple(params) - - def as_nested_sql(self): - """ - Perform the same functionality as the as_sql() method, returning an - SQL string and parameters. However, the alias prefixes are bumped - beforehand (in a copy -- the current query isn't changed) and any - ordering is removed. - - Used when nesting this query inside another. - """ - obj = self.clone() - obj.clear_ordering(True) - obj.bump_prefix() - return obj.as_sql() + compiler = q.get_compiler(using=using) + return bool(compiler.execute_sql(SINGLE)) def combine(self, rhs, connector): """ @@ -613,20 +484,6 @@ class BaseQuery(object): self.order_by = rhs.order_by and rhs.order_by[:] or self.order_by self.extra_order_by = rhs.extra_order_by or self.extra_order_by - def pre_sql_setup(self): - """ - Does any necessary class setup immediately prior to producing SQL. This - is for things that can't necessarily be done in __init__ because we - might not have all the pieces in place at that time. 
- """ - if not self.tables: - self.join((None, self.model._meta.db_table, None, None)) - if (not self.select and self.default_cols and not - self.included_inherited_models): - self.setup_inherited_models() - if self.select_related and not self.related_select_cols: - self.fill_related_selections() - def deferred_to_data(self, target, callback): """ Converts the self.deferred_loading data structure to an alternate data @@ -705,15 +562,6 @@ class BaseQuery(object): for model, values in seen.iteritems(): callback(target, model, values) - def deferred_to_columns(self): - """ - Converts the self.deferred_loading data structure to mapping of table - names to sets of column names which are to be loaded. Returns the - dictionary. - """ - columns = {} - self.deferred_to_data(columns, self.deferred_to_columns_cb) - return columns def deferred_to_columns_cb(self, target, model, fields): """ @@ -726,349 +574,6 @@ class BaseQuery(object): for field in fields: target[table].add(field.column) - def get_columns(self, with_aliases=False): - """ - Returns the list of columns to use in the select statement. If no - columns have been specified, returns all columns relating to fields in - the model. - - If 'with_aliases' is true, any column names that are duplicated - (without the table names) are given unique aliases. This is needed in - some cases to avoid ambiguity with nested queries. 
- """ - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.extra_select.iteritems()] - aliases = set(self.extra_select.keys()) - if with_aliases: - col_aliases = aliases.copy() - else: - col_aliases = set() - if self.select: - only_load = self.deferred_to_columns() - for col in self.select: - if isinstance(col, (list, tuple)): - alias, column = col - table = self.alias_map[alias][TABLE_NAME] - if table in only_load and col not in only_load[table]: - continue - r = '%s.%s' % (qn(alias), qn(column)) - if with_aliases: - if col[1] in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append('%s AS %s' % (r, qn2(col[1]))) - aliases.add(r) - col_aliases.add(col[1]) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col[1]) - else: - result.append(col.as_sql(quote_func=qn)) - - if hasattr(col, 'alias'): - aliases.add(col.alias) - col_aliases.add(col.alias) - - elif self.default_cols: - cols, new_aliases = self.get_default_columns(with_aliases, - col_aliases) - result.extend(cols) - aliases.update(new_aliases) - - result.extend([ - '%s%s' % ( - aggregate.as_sql(quote_func=qn), - alias is not None and ' AS %s' % qn(alias) or '' - ) - for alias, aggregate in self.aggregate_select.items() - ]) - - for table, col in self.related_select_cols: - r = '%s.%s' % (qn(table), qn(col)) - if with_aliases and col in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col) - - self._select_aliases = aliases - return result - - def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False): - """ - Computes the default columns for selecting every field in the base - model. 
Will sometimes be called to pull in related models (e.g. via - select_related), in which case "opts" and "start_alias" will be given - to provide a starting point for the traversal. - - Returns a list of strings, quoted appropriately for use in SQL - directly, as well as a set of aliases used in the select statement (if - 'as_pairs' is True, returns a list of (alias, col_name) pairs instead - of strings as the first component and None as the second component). - """ - result = [] - if opts is None: - opts = self.model._meta - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - aliases = set() - only_load = self.deferred_to_columns() - # Skip all proxy to the root proxied model - proxied_model = get_proxied_model(opts) - - if start_alias: - seen = {None: start_alias} - for field, model in opts.get_fields_with_model(): - if start_alias: - try: - alias = seen[model] - except KeyError: - if model is proxied_model: - alias = start_alias - else: - link_field = opts.get_ancestor_link(model) - alias = self.join((start_alias, model._meta.db_table, - link_field.column, model._meta.pk.column)) - seen[model] = alias - else: - # If we're starting from the base model of the queryset, the - # aliases will have already been set up in pre_sql_setup(), so - # we can save time here. 
- alias = self.included_inherited_models[model] - table = self.alias_map[alias][TABLE_NAME] - if table in only_load and field.column not in only_load[table]: - continue - if as_pairs: - result.append((alias, field.column)) - aliases.add(alias) - continue - if with_aliases and field.column in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s.%s AS %s' % (qn(alias), - qn2(field.column), c_alias)) - col_aliases.add(c_alias) - aliases.add(c_alias) - else: - r = '%s.%s' % (qn(alias), qn2(field.column)) - result.append(r) - aliases.add(r) - if with_aliases: - col_aliases.add(field.column) - return result, aliases - - def get_from_clause(self): - """ - Returns a list of strings that are joined together to go after the - "FROM" part of the query, as well as a list any extra parameters that - need to be included. Sub-classes, can override this to create a - from-clause via a "select". - - This should only be called after any SQL construction methods that - might change the tables we need. This means the select columns and - ordering must be done first. - """ - result = [] - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - first = True - for alias in self.tables: - if not self.alias_refcount[alias]: - continue - try: - name, alias, join_type, lhs, lhs_col, col, nullable = self.alias_map[alias] - except KeyError: - # Extra tables can end up in self.tables, but not in the - # alias_map if they aren't in a join. That's OK. We skip them. 
- continue - alias_str = (alias != name and ' %s' % alias or '') - if join_type and not first: - result.append('%s %s%s ON (%s.%s = %s.%s)' - % (join_type, qn(name), alias_str, qn(lhs), - qn2(lhs_col), qn(alias), qn2(col))) - else: - connector = not first and ', ' or '' - result.append('%s%s%s' % (connector, qn(name), alias_str)) - first = False - for t in self.extra_tables: - alias, unused = self.table_alias(t) - # Only add the alias if it's not already present (the table_alias() - # calls increments the refcount, so an alias refcount of one means - # this is the only reference. - if alias not in self.alias_map or self.alias_refcount[alias] == 1: - connector = not first and ', ' or '' - result.append('%s%s' % (connector, qn(alias))) - first = False - return result, [] - - def get_grouping(self): - """ - Returns a tuple representing the SQL elements in the "group by" clause. - """ - qn = self.quote_name_unless_alias - result, params = [], [] - if self.group_by is not None: - group_by = self.group_by or [] - - extra_selects = [] - for extra_select, extra_params in self.extra_select.itervalues(): - extra_selects.append(extra_select) - params.extend(extra_params) - for col in group_by + self.related_select_cols + extra_selects: - if isinstance(col, (list, tuple)): - result.append('%s.%s' % (qn(col[0]), qn(col[1]))) - elif hasattr(col, 'as_sql'): - result.append(col.as_sql(qn)) - else: - result.append(str(col)) - return result, params - - def get_ordering(self): - """ - Returns a tuple containing a list representing the SQL elements in the - "order by" clause, and the list of SQL elements that need to be added - to the GROUP BY clause as a result of the ordering. - - Also sets the ordering_aliases attribute on this instance to a list of - extra aliases needed in the select. - - Determining the ordering SQL can change the tables we need to include, - so this should be run *before* get_from_clause(). 
- """ - if self.extra_order_by: - ordering = self.extra_order_by - elif not self.default_ordering: - ordering = self.order_by - else: - ordering = self.order_by or self.model._meta.ordering - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - distinct = self.distinct - select_aliases = self._select_aliases - result = [] - group_by = [] - ordering_aliases = [] - if self.standard_ordering: - asc, desc = ORDER_DIR['ASC'] - else: - asc, desc = ORDER_DIR['DESC'] - - # It's possible, due to model inheritance, that normal usage might try - # to include the same field more than once in the ordering. We track - # the table/column pairs we use and discard any after the first use. - processed_pairs = set() - - for field in ordering: - if field == '?': - result.append(self.connection.ops.random_function_sql()) - continue - if isinstance(field, int): - if field < 0: - order = desc - field = -field - else: - order = asc - result.append('%s %s' % (field, order)) - group_by.append((field, [])) - continue - col, order = get_order_dir(field, asc) - if col in self.aggregate_select: - result.append('%s %s' % (col, order)) - continue - if '.' in field: - # This came in through an extra(order_by=...) addition. Pass it - # on verbatim. - table, col = col.split('.', 1) - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), col) - processed_pairs.add((table, col)) - if not distinct or elt in select_aliases: - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) - elif get_order_dir(field)[0] not in self.extra_select: - # 'col' is of the form 'field' or 'field1__field2' or - # '-field1__field2__field', etc. 
- for table, col, order in self.find_ordering_name(field, - self.model._meta, default_order=asc): - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), qn2(col)) - processed_pairs.add((table, col)) - if distinct and elt not in select_aliases: - ordering_aliases.append(elt) - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) - else: - elt = qn2(col) - if distinct and col not in select_aliases: - ordering_aliases.append(elt) - result.append('%s %s' % (elt, order)) - group_by.append(self.extra_select[col]) - self.ordering_aliases = ordering_aliases - return result, group_by - - def find_ordering_name(self, name, opts, alias=None, default_order='ASC', - already_seen=None): - """ - Returns the table alias (the name might be ambiguous, the alias will - not be) and column name for ordering by the given 'name' parameter. - The 'name' is of the form 'field1__field2__...__fieldN'. - """ - name, order = get_order_dir(name, default_order) - pieces = name.split(LOOKUP_SEP) - if not alias: - alias = self.get_initial_alias() - field, target, opts, joins, last, extra = self.setup_joins(pieces, - opts, alias, False) - alias = joins[-1] - col = target.column - if not field.rel: - # To avoid inadvertent trimming of a necessary alias, use the - # refcount to show that we are referencing a non-relation field on - # the model. - self.ref_alias(alias) - - # Must use left outer joins for nullable fields and their relations. - self.promote_alias_chain(joins, - self.alias_map[joins[0]][JOIN_TYPE] == self.LOUTER) - - # If we get to this point and the field is a relation to another model, - # append the default ordering for that model. - if field.rel and len(joins) > 1 and opts.ordering: - # Firstly, avoid infinite loops. 
- if not already_seen: - already_seen = set() - join_tuple = tuple([self.alias_map[j][TABLE_NAME] for j in joins]) - if join_tuple in already_seen: - raise FieldError('Infinite loop caused by ordering.') - already_seen.add(join_tuple) - - results = [] - for item in opts.ordering: - results.extend(self.find_ordering_name(item, opts, alias, - order, already_seen)) - return results - - if alias: - # We have to do the same "final join" optimisation as in - # add_filter, since the final column might not otherwise be part of - # the select set (so we can't order on it). - while 1: - join = self.alias_map[alias] - if col != join[RHS_JOIN_COL]: - break - self.unref_alias(alias) - alias = join[LHS_ALIAS] - col = join[LHS_JOIN_COL] - return [(alias, col, order)] def table_alias(self, table_name, create=False): """ @@ -1372,113 +877,6 @@ class BaseQuery(object): self.unref_alias(alias) self.included_inherited_models = {} - def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, - used=None, requested=None, restricted=None, nullable=None, - dupe_set=None, avoid_set=None): - """ - Fill in the information needed for a select_related query. The current - depth is measured as the number of connections away from the root model - (for example, cur_depth=1 means we are looking at models with direct - connections to the root model). - """ - if not restricted and self.max_depth and cur_depth > self.max_depth: - # We've recursed far enough; bail out. - return - - if not opts: - opts = self.get_meta() - root_alias = self.get_initial_alias() - self.related_select_cols = [] - self.related_select_fields = [] - if not used: - used = set() - if dupe_set is None: - dupe_set = set() - if avoid_set is None: - avoid_set = set() - orig_dupe_set = dupe_set - - # Setup for the case when only particular related fields should be - # included in the related selection. 
- if requested is None and restricted is not False: - if isinstance(self.select_related, dict): - requested = self.select_related - restricted = True - else: - restricted = False - - for f, model in opts.get_fields_with_model(): - if not select_related_descend(f, restricted, requested): - continue - # The "avoid" set is aliases we want to avoid just for this - # particular branch of the recursion. They aren't permanently - # forbidden from reuse in the related selection tables (which is - # what "used" specifies). - avoid = avoid_set.copy() - dupe_set = orig_dupe_set.copy() - table = f.rel.to._meta.db_table - if nullable or f.null: - promote = True - else: - promote = False - if model: - int_opts = opts - alias = root_alias - alias_chain = [] - for int_model in opts.get_base_chain(model): - # Proxy model have elements in base chain - # with no parents, assign the new options - # object and skip to the next base in that - # case - if not int_opts.parents[int_model]: - int_opts = int_model._meta - continue - lhs_col = int_opts.parents[int_model].column - dedupe = lhs_col in opts.duplicate_targets - if dedupe: - avoid.update(self.dupe_avoidance.get(id(opts), lhs_col), - ()) - dupe_set.add((opts, lhs_col)) - int_opts = int_model._meta - alias = self.join((alias, int_opts.db_table, lhs_col, - int_opts.pk.column), exclusions=used, - promote=promote) - alias_chain.append(alias) - for (dupe_opts, dupe_col) in dupe_set: - self.update_dupe_avoidance(dupe_opts, dupe_col, alias) - if self.alias_map[root_alias][JOIN_TYPE] == self.LOUTER: - self.promote_alias_chain(alias_chain, True) - else: - alias = root_alias - - dedupe = f.column in opts.duplicate_targets - if dupe_set or dedupe: - avoid.update(self.dupe_avoidance.get((id(opts), f.column), ())) - if dedupe: - dupe_set.add((opts, f.column)) - - alias = self.join((alias, table, f.column, - f.rel.get_related_field().column), - exclusions=used.union(avoid), promote=promote) - used.add(alias) - columns, aliases = 
self.get_default_columns(start_alias=alias, - opts=f.rel.to._meta, as_pairs=True) - self.related_select_cols.extend(columns) - if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: - self.promote_alias_chain(aliases, True) - self.related_select_fields.extend(f.rel.to._meta.fields) - if restricted: - next = requested.get(f.name, {}) - else: - next = False - if f.null is not None: - new_nullable = f.null - else: - new_nullable = None - for dupe_opts, dupe_col in dupe_set: - self.update_dupe_avoidance(dupe_opts, dupe_col, alias) - self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, - used, next, restricted, new_nullable, dupe_set, avoid) def add_aggregate(self, aggregate, model, alias, is_summary): """ @@ -1527,7 +925,6 @@ class BaseQuery(object): col = field_name # Add the aggregate to the query - alias = truncate_name(alias, self.connection.ops.max_name_length()) aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, @@ -1578,10 +975,6 @@ class BaseQuery(object): raise ValueError("Cannot use None as a query value") lookup_type = 'isnull' value = True - elif (value == '' and lookup_type == 'exact' and - connection.features.interprets_empty_strings_as_nulls): - lookup_type = 'isnull' - value = True elif callable(value): value = value() elif hasattr(value, 'evaluate'): @@ -1999,7 +1392,7 @@ class BaseQuery(object): original exclude filter (filter_expr) and the portion up to the first N-to-many relation field. """ - query = Query(self.model, self.connection) + query = Query(self.model) query.add_filter(filter_expr, can_reuse=can_reuse) query.bump_prefix() query.clear_ordering(True) @@ -2138,11 +1531,6 @@ class BaseQuery(object): will be made automatically. 
""" self.group_by = [] - if self.connection.features.allows_group_by_pk: - if len(self.select) == len(self.model._meta.fields): - self.group_by.append((self.model._meta.db_table, - self.model._meta.pk.column)) - return for sel in self.select: self.group_by.append(sel) @@ -2382,58 +1770,6 @@ class BaseQuery(object): self.select = [(select_alias, select_col)] self.remove_inherited_models() - def execute_sql(self, result_type=MULTI): - """ - Run the query against the database and returns the result(s). The - return value is a single data item if result_type is SINGLE, or an - iterator over the results if the result_type is MULTI. - - result_type is either MULTI (use fetchmany() to retrieve all rows), - SINGLE (only retrieve a single row), or None. In this last case, the - cursor is returned if any query is executed, since it's used by - subclasses such as InsertQuery). It's possible, however, that no query - is needed, as the filters describe an empty set. In that case, None is - returned, to avoid any unnecessary database interaction. - """ - try: - sql, params = self.as_sql() - if not sql: - raise EmptyResultSet - except EmptyResultSet: - if result_type == MULTI: - return empty_iter() - else: - return - cursor = self.connection.cursor() - cursor.execute(sql, params) - - if not result_type: - return cursor - if result_type == SINGLE: - if self.ordering_aliases: - return cursor.fetchone()[:-len(self.ordering_aliases)] - return cursor.fetchone() - - # The MULTI case. - if self.ordering_aliases: - result = order_modified_iter(cursor, len(self.ordering_aliases), - self.connection.features.empty_fetchmany_value) - else: - result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - self.connection.features.empty_fetchmany_value) - if not self.connection.features.can_use_chunked_reads: - # If we are using non-chunked reads, we return the same data - # structure as normally, but ensure it is all read into memory - # before going any further. 
- return list(result) - return result - -# Use the backend's custom Query class if it defines one. Otherwise, use the -# default. -if connection.features.uses_custom_query_class: - Query = connection.ops.query_class(BaseQuery) -else: - Query = BaseQuery def get_order_dir(field, default='ASC'): """ @@ -2448,22 +1784,6 @@ def get_order_dir(field, default='ASC'): return field[1:], dirn[1] return field, dirn[0] -def empty_iter(): - """ - Returns an iterator containing no results. - """ - yield iter([]).next() - -def order_modified_iter(cursor, trim, sentinel): - """ - Yields blocks of rows from a cursor. We use this iterator in the special - case when extra output columns have been added to support ordering - requirements. We must trim those extra columns before anything else can use - the results, since they're only needed to make the SQL valid. - """ - for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - sentinel): - yield [r[:-trim] for r in rows] def setup_join_cache(sender, **kwargs): """ diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f00f1bd68a..e80a023699 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -3,6 +3,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval. """ from django.core.exceptions import FieldError +from django.db import connections from django.db.models.sql.constants import * from django.db.models.sql.datastructures import Date from django.db.models.sql.expressions import SQLEvaluator @@ -17,24 +18,15 @@ class DeleteQuery(Query): Delete queries are done through this class, since they are more constrained than general queries. """ - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - assert len(self.tables) == 1, \ - "Can only delete from one table at a time." 
- result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] - where, params = self.where.as_sql() - result.append('WHERE %s' % where) - return ' '.join(result), tuple(params) - def do_query(self, table, where): + compiler = 'SQLDeleteCompiler' + + def do_query(self, table, where, using): self.tables = [table] self.where = where - self.execute_sql(None) + self.get_compiler(using).execute_sql(None) - def delete_batch_related(self, pk_list): + def delete_batch_related(self, pk_list, using): """ Set up and execute delete queries for all the objects related to the primary key values in pk_list. To delete the objects themselves, use @@ -54,7 +46,7 @@ class DeleteQuery(Query): 'in', pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(related.field.m2m_db_table(), where) + self.do_query(related.field.m2m_db_table(), where, using=using) for f in cls._meta.many_to_many: w1 = self.where_class() @@ -70,9 +62,9 @@ class DeleteQuery(Query): AND) if w1: where.add(w1, AND) - self.do_query(f.m2m_db_table(), where) + self.do_query(f.m2m_db_table(), where, using=using) - def delete_batch(self, pk_list): + def delete_batch(self, pk_list, using): """ Set up and execute delete queries for all the objects in pk_list. This should be called after delete_batch_related(), if necessary. @@ -85,12 +77,15 @@ class DeleteQuery(Query): field = self.model._meta.pk where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(self.model._meta.db_table, where) + self.do_query(self.model._meta.db_table, where, using=using) class UpdateQuery(Query): """ Represents an "update" SQL query. 
""" + + compiler = 'SQLUpdateCompiler' + def __init__(self, *args, **kwargs): super(UpdateQuery, self).__init__(*args, **kwargs) self._setup_query() @@ -110,98 +105,8 @@ class UpdateQuery(Query): return super(UpdateQuery, self).clone(klass, related_updates=self.related_updates.copy(), **kwargs) - def execute_sql(self, result_type=None): - """ - Execute the specified update. Returns the number of rows affected by - the primary update query. The "primary update query" is the first - non-empty query that is executed. Row counts for any subsequent, - related queries are not available. - """ - cursor = super(UpdateQuery, self).execute_sql(result_type) - rows = cursor and cursor.rowcount or 0 - is_empty = cursor is None - del cursor - for query in self.get_related_updates(): - aux_rows = query.execute_sql(result_type) - if is_empty: - rows = aux_rows - is_empty = False - return rows - - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - self.pre_sql_setup() - if not self.values: - return '', () - table = self.tables[0] - qn = self.quote_name_unless_alias - result = ['UPDATE %s' % qn(table)] - result.append('SET') - values, update_params = [], [] - for name, val, placeholder in self.values: - if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn) - values.append('%s = %s' % (qn(name), sql)) - update_params.extend(params) - elif val is not None: - values.append('%s = %s' % (qn(name), placeholder)) - update_params.append(val) - else: - values.append('%s = NULL' % qn(name)) - result.append(', '.join(values)) - where, params = self.where.as_sql() - if where: - result.append('WHERE %s' % where) - return ' '.join(result), tuple(update_params + params) - - def pre_sql_setup(self): - """ - If the update depends on results from other tables, we need to do some - munging of the "where" conditions to match the format required for - (portable) SQL updates. That is done here. 
- Further, if we are going to be running multiple updates, we pull out - the id values to update at this point so that they don't change as a - result of the progressive updates. - """ - self.select_related = False - self.clear_ordering(True) - super(UpdateQuery, self).pre_sql_setup() - count = self.count_active_tables() - if not self.related_updates and count == 1: - return - - # We need to use a sub-select in the where clause to filter on things - # from other tables. - query = self.clone(klass=Query) - query.bump_prefix() - query.extra = {} - query.select = [] - query.add_fields([query.model._meta.pk.name]) - must_pre_select = count > 1 and not self.connection.features.update_can_self_select - - # Now we adjust the current query: reset the where clause and get rid - # of all the tables we don't need (since they're in the sub-select). - self.where = self.where_class() - if self.related_updates or must_pre_select: - # Either we're using the idents in multiple update queries (so - # don't want them to change), or the db backend doesn't support - # selecting from the updating table (e.g. MySQL). - idents = [] - for rows in query.execute_sql(MULTI): - idents.extend([r[0] for r in rows]) - self.add_filter(('pk__in', idents)) - self.related_ids = idents - else: - # The fast path. Filters and updates in one query. - self.add_filter(('pk__in', query)) - for alias in self.tables[1:]: - self.alias_refcount[alias] = 0 - - def clear_related(self, related_field, pk_list): + def clear_related(self, related_field, pk_list, using): """ Set up and execute an update query that clears related entries for the keys in pk_list. 
@@ -214,8 +119,8 @@ class UpdateQuery(Query): self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.values = [(related_field.column, None, '%s')] - self.execute_sql(None) + self.values = [(related_field, None, None)] + self.get_compiler(using).execute_sql(None) def add_update_values(self, values): """ @@ -228,6 +133,9 @@ class UpdateQuery(Query): field, model, direct, m2m = self.model._meta.get_field_by_name(name) if not direct or m2m: raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) + if model: + self.add_related_update(model, field, val) + continue values_seq.append((field, model, val)) return self.add_update_fields(values_seq) @@ -237,36 +145,18 @@ class UpdateQuery(Query): Used by add_update_values() as well as the "fast" update path when saving models. """ - from django.db.models.base import Model - for field, model, val in values_seq: - if hasattr(val, 'prepare_database_save'): - val = val.prepare_database_save(field) - else: - val = field.get_db_prep_save(val) + self.values.extend(values_seq) - # Getting the placeholder for the field. - if hasattr(field, 'get_placeholder'): - placeholder = field.get_placeholder(val) - else: - placeholder = '%s' - - if hasattr(val, 'evaluate'): - val = SQLEvaluator(val, self, allow_joins=False) - if model: - self.add_related_update(model, field.column, val, placeholder) - else: - self.values.append((field.column, val, placeholder)) - - def add_related_update(self, model, column, value, placeholder): + def add_related_update(self, model, field, value): """ Adds (name, value) to an update query for an ancestor model. Updates are coalesced so that we only run one update query per ancestor. 
""" try: - self.related_updates[model].append((column, value, placeholder)) + self.related_updates[model].append((field, None, value)) except KeyError: - self.related_updates[model] = [(column, value, placeholder)] + self.related_updates[model] = [(field, None, value)] def get_related_updates(self): """ @@ -278,7 +168,7 @@ class UpdateQuery(Query): return [] result = [] for model, values in self.related_updates.iteritems(): - query = UpdateQuery(model, self.connection) + query = UpdateQuery(model) query.values = values if self.related_ids: query.add_filter(('pk__in', self.related_ids)) @@ -286,45 +176,23 @@ class UpdateQuery(Query): return result class InsertQuery(Query): + compiler = 'SQLInsertCompiler' + def __init__(self, *args, **kwargs): super(InsertQuery, self).__init__(*args, **kwargs) self.columns = [] self.values = [] self.params = () - self.return_id = False def clone(self, klass=None, **kwargs): - extras = {'columns': self.columns[:], 'values': self.values[:], - 'params': self.params, 'return_id': self.return_id} + extras = { + 'columns': self.columns[:], + 'values': self.values[:], + 'params': self.params + } extras.update(kwargs) return super(InsertQuery, self).clone(klass, **extras) - def as_sql(self): - # We don't need quote_name_unless_alias() here, since these are all - # going to be column names (so we can avoid the extra overhead). 
- qn = self.connection.ops.quote_name - opts = self.model._meta - result = ['INSERT INTO %s' % qn(opts.db_table)] - result.append('(%s)' % ', '.join([qn(c) for c in self.columns])) - result.append('VALUES (%s)' % ', '.join(self.values)) - params = self.params - if self.return_id and self.connection.features.can_return_id_from_insert: - col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - r_fmt, r_params = self.connection.ops.return_insert_id() - result.append(r_fmt % col) - params = params + r_params - return ' '.join(result), params - - def execute_sql(self, return_id=False): - self.return_id = return_id - cursor = super(InsertQuery, self).execute_sql(None) - if not (return_id and cursor): - return - if self.connection.features.can_return_id_from_insert: - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id(cursor, - self.model._meta.db_table, self.model._meta.pk.column) - def insert_values(self, insert_values, raw_values=False): """ Set up the insert query from the 'insert_values' dictionary. The @@ -337,17 +205,11 @@ class InsertQuery(Query): """ placeholders, values = [], [] for field, val in insert_values: - if hasattr(field, 'get_placeholder'): - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - placeholders.append(field.get_placeholder(val)) - else: - placeholders.append('%s') - + placeholders.append((field, val)) self.columns.append(field.column) values.append(val) if raw_values: - self.values.extend(values) + self.values.extend([(None, v) for v in values]) else: self.params += tuple(values) self.values.extend(placeholders) @@ -358,44 +220,8 @@ class DateQuery(Query): date field. This requires some special handling when converting the results back to Python objects, so we put it in a separate class. """ - def __getstate__(self): - """ - Special DateQuery-specific pickle handling. 
- """ - for elt in self.select: - if isinstance(elt, Date): - # Eliminate a method reference that can't be pickled. The - # __setstate__ method restores this. - elt.date_sql_func = None - return super(DateQuery, self).__getstate__() - def __setstate__(self, obj_dict): - super(DateQuery, self).__setstate__(obj_dict) - for elt in self.select: - if isinstance(elt, Date): - self.date_sql_func = self.connection.ops.date_trunc_sql - - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - resolve_columns = hasattr(self, 'resolve_columns') - if resolve_columns: - from django.db.models.fields import DateTimeField - fields = [DateTimeField()] - else: - from django.db.backends.util import typecast_timestamp - needs_string_cast = self.connection.features.needs_datetime_string_cast - - offset = len(self.extra_select) - for rows in self.execute_sql(MULTI): - for row in rows: - date = row[offset] - if resolve_columns: - date = self.resolve_columns(row, fields)[offset] - elif needs_string_cast: - date = typecast_timestamp(str(date)) - yield date + compiler = 'SQLDateCompiler' def add_date_select(self, field, lookup_type, order='ASC'): """ @@ -404,8 +230,7 @@ class DateQuery(Query): result = self.setup_joins([field.name], self.get_meta(), self.get_initial_alias(), False) alias = result[3][-1] - select = Date((alias, field.column), lookup_type, - self.connection.ops.date_trunc_sql) + select = Date((alias, field.column), lookup_type) self.select = [select] self.select_fields = [None] self.select_related = False # See #7097. @@ -418,20 +243,8 @@ class AggregateQuery(Query): An AggregateQuery takes another query as a parameter to the FROM clause and only selects the elements in the provided list. """ - def add_subquery(self, query): - self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) - def as_sql(self, quote_func=None): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. 
- """ - sql = ('SELECT %s FROM (%s) subquery' % ( - ', '.join([ - aggregate.as_sql() - for aggregate in self.aggregate_select.values() - ]), - self.subquery) - ) - params = self.sub_params - return (sql, params) + compiler = 'SQLAggregateCompiler' + + def add_subquery(self, query, using): + self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ec0545ca5b..4aa2351f17 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,7 +4,6 @@ Code to manage the creation and SQL rendering of 'where' constraints. import datetime from django.utils import tree -from django.db import connection from django.db.models.fields import Field from django.db.models.query_utils import QueryWrapper from datastructures import EmptyResultSet, FullResultSet @@ -51,18 +50,6 @@ class WhereNode(tree.Node): # Consume any generators immediately, so that we can determine # emptiness and transform any non-empty values correctly. value = list(value) - if hasattr(obj, "process"): - try: - obj, params = obj.process(lookup_type, value) - except (EmptyShortCircuit, EmptyResultSet): - # There are situations where we want to short-circuit any - # comparisons and make sure that nothing is returned. One - # example is when checking for a NULL pk value, or the - # equivalent. - super(WhereNode, self).add(NothingNode(), connector) - return - else: - params = Field().get_db_prep_lookup(lookup_type, value) # The "annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. 
Specifically, datetime @@ -75,10 +62,16 @@ class WhereNode(tree.Node): else: annotation = bool(value) - super(WhereNode, self).add((obj, lookup_type, annotation, params), + if hasattr(obj, "prepare"): + value = obj.prepare(lookup_type, value) + super(WhereNode, self).add((obj, lookup_type, annotation, value), + connector) + return + + super(WhereNode, self).add((obj, lookup_type, annotation, value), connector) - def as_sql(self, qn=None): + def as_sql(self, qn, connection): """ Returns the SQL version of the where clause and the value to be substituted in. Returns None, None if this node is empty. @@ -87,8 +80,6 @@ class WhereNode(tree.Node): (generally not needed except by the internal implementation for recursion). """ - if not qn: - qn = connection.ops.quote_name if not self.children: return None, [] result = [] @@ -97,10 +88,10 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn) + sql, params = child.as_sql(qn=qn, connection=connection) else: # A leaf node in the tree. - sql, params = self.make_atom(child, qn) + sql, params = self.make_atom(child, qn, connection) except EmptyResultSet: if self.connector == AND and not self.negated: @@ -136,7 +127,7 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def make_atom(self, child, qn): + def make_atom(self, child, qn, connection): """ Turn a tuple (table_alias, column_name, db_type, lookup_type, value_annot, params) into valid SQL. @@ -144,13 +135,21 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. 
""" - lvalue, lookup_type, value_annot, params = child + lvalue, lookup_type, value_annot, params_or_value = child + if hasattr(lvalue, 'process'): + try: + lvalue, params = lvalue.process(lookup_type, params_or_value, connection) + except EmptyShortCircuit: + raise EmptyResultSet + else: + params = Field().get_db_prep_lookup(lookup_type, params_or_value, + connection=connection, prepared=True) if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn) + field_sql = self.sql_for_columns(lvalue, qn, connection) else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(quote_func=qn) + field_sql = lvalue.as_sql(qn, connection) if value_annot is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql() @@ -158,11 +157,16 @@ class WhereNode(tree.Node): cast_sql = '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn) + extra, params = params.as_sql(qn, connection) cast_sql = '' else: extra = '' + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' + and connection.features.interprets_empty_strings_as_nulls): + lookup_type = 'isnull' + value_annot = True + if lookup_type in connection.operators: format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) return (format % (field_sql, @@ -191,7 +195,7 @@ class WhereNode(tree.Node): raise TypeError('Invalid lookup_type: %r' % lookup_type) - def sql_for_columns(self, data, qn): + def sql_for_columns(self, data, qn, connection): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause @@ -233,7 +237,8 @@ class EverythingNode(object): """ A node that matches everything. """ - def as_sql(self, qn=None): + + def as_sql(self, qn=None, connection=None): raise FullResultSet def relabel_aliases(self, change_map, node=None): @@ -243,7 +248,7 @@ class NothingNode(object): """ A node that matches nothing. 
""" - def as_sql(self, qn=None): + def as_sql(self, qn=None, connection=None): raise EmptyResultSet def relabel_aliases(self, change_map, node=None): @@ -257,7 +262,12 @@ class Constraint(object): def __init__(self, alias, col, field): self.alias, self.col, self.field = alias, col, field - def process(self, lookup_type, value): + def prepare(self, lookup_type, value): + if self.field: + return self.field.get_prep_lookup(lookup_type, value) + return value + + def process(self, lookup_type, value, connection): """ Returns a tuple of data suitable for inclusion in a WhereNode instance. @@ -266,16 +276,21 @@ class Constraint(object): from django.db.models.base import ObjectDoesNotExist try: if self.field: - params = self.field.get_db_prep_lookup(lookup_type, value) - db_type = self.field.db_type() + params = self.field.get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) + db_type = self.field.db_type(connection=connection) else: # This branch is used at times when we add a comparison to NULL # (we don't really want to waste time looking up the associated # field object at the calling location). - params = Field().get_db_prep_lookup(lookup_type, value) + params = Field().get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) db_type = None except ObjectDoesNotExist: raise EmptyShortCircuit return (self.alias, self.col, db_type), params + def relabel_aliases(self, change_map): + if self.alias in change_map: + self.alias = change_map[self.alias] |
