diff options
Diffstat (limited to 'django/db/models/sql/compiler.py')
| -rw-r--r-- | django/db/models/sql/compiler.py | 63 |
1 files changed, 38 insertions, 25 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 41bba93206..123427cf8b 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -45,7 +45,7 @@ class SQLCompiler(object): if self.query.select_related and not self.query.related_select_cols: self.fill_related_selections() - def quote_name_unless_alias(self, name): + def __call__(self, name): """ A wrapper around connection.ops.quote_name that doesn't quote aliases for table names. This avoids problems with some SQL dialects that treat @@ -61,6 +61,22 @@ class SQLCompiler(object): self.quote_cache[name] = r return r + def quote_name_unless_alias(self, name): + """ + A wrapper around connection.ops.quote_name that doesn't quote aliases + for table names. This avoids problems with some SQL dialects that treat + quoted strings specially (e.g. PostgreSQL). + """ + return self(name) + + def compile(self, node): + vendor_impl = getattr( + node, 'as_' + self.connection.vendor, None) + if vendor_impl: + return vendor_impl(self, self.connection) + else: + return node.as_sql(self, self.connection) + def as_sql(self, with_limits=True, with_col_aliases=False): """ Creates the SQL for this query. Returns the SQL string and list of @@ -88,11 +104,9 @@ class SQLCompiler(object): # docstring of get_from_clause() for details. from_, f_params = self.get_from_clause() - qn = self.quote_name_unless_alias - - where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) - having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) - having_group_by = self.query.having.get_cols() + where, w_params = self.compile(self.query.where) + having, h_params = self.compile(self.query.having) + having_group_by = self.query.having.get_group_by_cols() params = [] for val in six.itervalues(self.query.extra_select): params.extend(val[1]) @@ -180,7 +194,7 @@ class SQLCompiler(object): (without the table names) are given unique aliases. This is needed in some cases to avoid ambiguity with nested queries. """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] params = [] @@ -213,7 +227,7 @@ class SQLCompiler(object): aliases.add(r) col_aliases.add(col[1]) else: - col_sql, col_params = col.as_sql(qn, self.connection) + col_sql, col_params = self.compile(col) result.append(col_sql) params.extend(col_params) @@ -229,7 +243,7 @@ class SQLCompiler(object): max_name_length = self.connection.ops.max_name_length() for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) if alias is None: result.append(agg_sql) else: @@ -267,7 +281,7 @@ class SQLCompiler(object): result = [] if opts is None: opts = self.query.get_meta() - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name aliases = set() only_load = self.deferred_to_columns() @@ -319,7 +333,7 @@ class SQLCompiler(object): Note that this method can alter the tables in the query, and thus it must be called before get_from_clause(). """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = [] opts = self.query.get_meta() @@ -352,7 +366,7 @@ class SQLCompiler(object): ordering = (self.query.order_by or self.query.get_meta().ordering or []) - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name distinct = self.query.distinct select_aliases = self._select_aliases @@ -490,7 +504,7 @@ class SQLCompiler(object): ordering and distinct must be done first. """ result = [] - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name first = True from_params = [] @@ -508,8 +522,7 @@ class SQLCompiler(object): extra_cond = join_field.get_extra_restriction( self.query.where_class, alias, lhs) if extra_cond: - extra_sql, extra_params = extra_cond.as_sql( - qn, self.connection) + extra_sql, extra_params = self.compile(extra_cond) extra_sql = 'AND (%s)' % extra_sql from_params.extend(extra_params) else: @@ -541,7 +554,7 @@ class SQLCompiler(object): """ Returns a tuple representing the SQL elements in the "group by" clause. """ - qn = self.quote_name_unless_alias + qn = self result, params = [], [] if self.query.group_by is not None: select_cols = self.query.select + self.query.related_select_cols @@ -560,7 +573,7 @@ class SQLCompiler(object): if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - sql, col_params = col.as_sql(qn, self.connection) + self.compile(col) else: sql = '(%s)' % str(col) if sql not in seen: @@ -784,7 +797,7 @@ class SQLCompiler(object): return result def as_subquery_condition(self, alias, columns, qn): - inner_qn = self.quote_name_unless_alias + inner_qn = self qn2 = self.connection.ops.quote_name if len(columns) == 1: sql, params = self.as_sql() @@ -895,9 +908,9 @@ class SQLDeleteCompiler(SQLCompiler): """ assert len(self.query.tables) == 1, \ "Can only delete from one table at a time." - qn = self.quote_name_unless_alias + qn = self result = ['DELETE FROM %s' % qn(self.query.tables[0])] - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(params) @@ -913,7 +926,7 @@ class SQLUpdateCompiler(SQLCompiler): if not self.query.values: return '', () table = self.query.tables[0] - qn = self.quote_name_unless_alias + qn = self result = ['UPDATE %s' % qn(table)] result.append('SET') values, update_params = [], [] @@ -933,7 +946,7 @@ class SQLUpdateCompiler(SQLCompiler): val = SQLEvaluator(val, self.query, allow_joins=False) name = field.column if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn, self.connection) + sql, params = self.compile(val) values.append('%s = %s' % (qn(name), sql)) update_params.extend(params) elif val is not None: @@ -944,7 +957,7 @@ class SQLUpdateCompiler(SQLCompiler): if not values: return '', () result.append(', '.join(values)) - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(update_params + params) @@ -1024,11 +1037,11 @@ class SQLAggregateCompiler(SQLCompiler): parameters. """ if qn is None: - qn = self.quote_name_unless_alias + qn = self sql, params = [], [] for aggregate in self.query.aggregate_select.values(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) sql.append(agg_sql) params.extend(agg_params) sql = ', '.join(sql) |
