summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py63
1 files changed, 38 insertions, 25 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 41bba93206..123427cf8b 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -45,7 +45,7 @@ class SQLCompiler(object):
if self.query.select_related and not self.query.related_select_cols:
self.fill_related_selections()
- def quote_name_unless_alias(self, name):
+ def __call__(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
@@ -61,6 +61,22 @@ class SQLCompiler(object):
self.quote_cache[name] = r
return r
+ def quote_name_unless_alias(self, name):
+ """
+ A wrapper around connection.ops.quote_name that doesn't quote aliases
+ for table names. This avoids problems with some SQL dialects that treat
+ quoted strings specially (e.g. PostgreSQL).
+ """
+ return self(name)
+
+ def compile(self, node):
+ vendor_impl = getattr(
+ node, 'as_' + self.connection.vendor, None)
+ if vendor_impl:
+ return vendor_impl(self, self.connection)
+ else:
+ return node.as_sql(self, self.connection)
+
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@@ -88,11 +104,9 @@ class SQLCompiler(object):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
- qn = self.quote_name_unless_alias
-
- where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
- having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
- having_group_by = self.query.having.get_cols()
+ where, w_params = self.compile(self.query.where)
+ having, h_params = self.compile(self.query.having)
+ having_group_by = self.query.having.get_group_by_cols()
params = []
for val in six.itervalues(self.query.extra_select):
params.extend(val[1])
@@ -180,7 +194,7 @@ class SQLCompiler(object):
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
@@ -213,7 +227,7 @@ class SQLCompiler(object):
aliases.add(r)
col_aliases.add(col[1])
else:
- col_sql, col_params = col.as_sql(qn, self.connection)
+ col_sql, col_params = self.compile(col)
result.append(col_sql)
params.extend(col_params)
@@ -229,7 +243,7 @@ class SQLCompiler(object):
max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
if alias is None:
result.append(agg_sql)
else:
@@ -267,7 +281,7 @@ class SQLCompiler(object):
result = []
if opts is None:
opts = self.query.get_meta()
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
@@ -319,7 +333,7 @@ class SQLCompiler(object):
Note that this method can alter the tables in the query, and thus it
must be called before get_from_clause().
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = []
opts = self.query.get_meta()
@@ -352,7 +366,7 @@ class SQLCompiler(object):
ordering = (self.query.order_by
or self.query.get_meta().ordering
or [])
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
distinct = self.query.distinct
select_aliases = self._select_aliases
@@ -490,7 +504,7 @@ class SQLCompiler(object):
ordering and distinct must be done first.
"""
result = []
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
first = True
from_params = []
@@ -508,8 +522,7 @@ class SQLCompiler(object):
extra_cond = join_field.get_extra_restriction(
self.query.where_class, alias, lhs)
if extra_cond:
- extra_sql, extra_params = extra_cond.as_sql(
- qn, self.connection)
+ extra_sql, extra_params = self.compile(extra_cond)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
@@ -541,7 +554,7 @@ class SQLCompiler(object):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
- qn = self.quote_name_unless_alias
+ qn = self
result, params = [], []
if self.query.group_by is not None:
select_cols = self.query.select + self.query.related_select_cols
@@ -560,7 +573,7 @@ class SQLCompiler(object):
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
- sql, col_params = col.as_sql(qn, self.connection)
+ self.compile(col)
else:
sql = '(%s)' % str(col)
if sql not in seen:
@@ -784,7 +797,7 @@ class SQLCompiler(object):
return result
def as_subquery_condition(self, alias, columns, qn):
- inner_qn = self.quote_name_unless_alias
+ inner_qn = self
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
@@ -895,9 +908,9 @@ class SQLDeleteCompiler(SQLCompiler):
"""
assert len(self.query.tables) == 1, \
"Can only delete from one table at a time."
- qn = self.quote_name_unless_alias
+ qn = self
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
@@ -913,7 +926,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not self.query.values:
return '', ()
table = self.query.tables[0]
- qn = self.quote_name_unless_alias
+ qn = self
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
@@ -933,7 +946,7 @@ class SQLUpdateCompiler(SQLCompiler):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
- sql, params = val.as_sql(qn, self.connection)
+ sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
@@ -944,7 +957,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not values:
return '', ()
result.append(', '.join(values))
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
@@ -1024,11 +1037,11 @@ class SQLAggregateCompiler(SQLCompiler):
parameters.
"""
if qn is None:
- qn = self.quote_name_unless_alias
+ qn = self
sql, params = [], []
for aggregate in self.query.aggregate_select.values():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)