diff options
| author | Russell Keith-Magee <russell@keith-magee.com> | 2009-01-29 10:46:36 +0000 |
|---|---|---|
| committer | Russell Keith-Magee <russell@keith-magee.com> | 2009-01-29 10:46:36 +0000 |
| commit | cf37e4624a967f936ecbb5a4eefc9d38ed9d7892 (patch) | |
| tree | e44fab9a21ccdf130d85b6fb80c423181663f103 /django/db/models/sql | |
| parent | 08dd4176edc1019d9168608b55fe777512c641cb (diff) | |
Fixed #7210 -- Added F() expressions to query language. See the documentation for details on usage.
Many thanks to:
* Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code.
* Alex Gaynor for his help debugging and fixing a number of issues.
* Malcolm Tredinnick for his invaluable review notes.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@9792 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/expressions.py | 92 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 18 | ||||
| -rw-r--r-- | django/db/models/sql/subqueries.py | 9 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 10 |
4 files changed, 124 insertions, 5 deletions
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py new file mode 100644 index 0000000000..878f13bbf7 --- /dev/null +++ b/django/db/models/sql/expressions.py @@ -0,0 +1,92 @@ +from django.core.exceptions import FieldError +from django.db import connection +from django.db.models.fields import FieldDoesNotExist +from django.db.models.sql.constants import LOOKUP_SEP + +class SQLEvaluator(object): + def __init__(self, expression, query, allow_joins=True): + self.expression = expression + self.opts = query.get_meta() + self.cols = {} + + self.contains_aggregate = False + self.expression.prepare(self, query, allow_joins) + + def as_sql(self, qn=None): + return self.expression.evaluate(self, qn) + + def relabel_aliases(self, change_map): + for node, col in self.cols.items(): + self.cols[node] = (change_map.get(col[0], col[0]), col[1]) + + ##################################################### + # Vistor methods for initial expression preparation # + ##################################################### + + def prepare_node(self, node, query, allow_joins): + for child in node.children: + if hasattr(child, 'prepare'): + child.prepare(self, query, allow_joins) + + def prepare_leaf(self, node, query, allow_joins): + if not allow_joins and LOOKUP_SEP in node.name: + raise FieldError("Joined field references are not permitted in this query") + + field_list = node.name.split(LOOKUP_SEP) + if (len(field_list) == 1 and + node.name in query.aggregate_select.keys()): + self.contains_aggregate = True + self.cols[node] = query.aggregate_select[node.name] + else: + try: + field, source, opts, join_list, last, _ = query.setup_joins( + field_list, query.get_meta(), + query.get_initial_alias(), False) + _, _, col, _, join_list = query.trim_joins(source, join_list, last, False) + + self.cols[node] = (join_list[-1], col) + except FieldDoesNotExist: + raise FieldError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (self.name, + [f.name for f in self.opts.fields])) + + ################################################## + # Vistor methods for final expression evaluation # + ################################################## + + def evaluate_node(self, node, qn): + if not qn: + qn = connection.ops.quote_name + + expressions = [] + expression_params = [] + for child in node.children: + if hasattr(child, 'evaluate'): + sql, params = child.evaluate(self, qn) + else: + try: + sql, params = qn(child), () + except: + sql, params = str(child), () + + if hasattr(child, 'children') > 1: + format = '(%s)' + else: + format = '%s' + + if sql: + expressions.append(format % sql) + expression_params.extend(params) + conn = ' %s ' % node.connector + + return conn.join(expressions), expression_params + + def evaluate_leaf(self, node, qn): + if not qn: + qn = connection.ops.quote_name + + col = self.cols[node] + if hasattr(col, 'as_sql'): + return col.as_sql(qn), () + else: + return '%s.%s' % (qn(col[0]), qn(col[1])), () diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 88847d87e1..4e46da6424 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,6 +18,7 @@ from django.db.models import signals from django.db.models.fields import FieldDoesNotExist from django.db.models.query_utils import select_related_descend from django.db.models.sql import aggregates as base_aggregates_module +from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR from django.core.exceptions import FieldError from datastructures import EmptyResultSet, Empty, MultiJoin @@ -1271,6 +1272,10 @@ class BaseQuery(object): else: lookup_type = parts.pop() + # By default, this is a WHERE clause. If an aggregate is referenced + # in the value, the filter will be promoted to a HAVING + having_clause = False + # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. if value is None: @@ -1284,6 +1289,10 @@ class BaseQuery(object): value = True elif callable(value): value = value() + elif hasattr(value, 'evaluate'): + # If value is a query expression, evaluate it + value = SQLEvaluator(value, self) + having_clause = value.contains_aggregate for alias, aggregate in self.aggregate_select.items(): if alias == parts[0]: @@ -1340,8 +1349,13 @@ class BaseQuery(object): self.promote_alias_chain(join_it, join_promote) self.promote_alias_chain(table_it, table_promote) - self.where.add((Constraint(alias, col, field), lookup_type, value), - connector) + + if having_clause: + self.having.add((Constraint(alias, col, field), lookup_type, value), + connector) + else: + self.where.add((Constraint(alias, col, field), lookup_type, value), + connector) if negate: self.promote_alias_chain(join_list) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 0a59b403c8..f2589ea2b6 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -5,6 +5,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval. from django.core.exceptions import FieldError from django.db.models.sql.constants import * from django.db.models.sql.datastructures import Date +from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint @@ -136,7 +137,11 @@ class UpdateQuery(Query): result.append('SET') values, update_params = [], [] for name, val, placeholder in self.values: - if val is not None: + if hasattr(val, 'as_sql'): + sql, params = val.as_sql(qn) + values.append('%s = %s' % (qn(name), sql)) + update_params.extend(params) + elif val is not None: values.append('%s = %s' % (qn(name), placeholder)) update_params.append(val) else: @@ -251,6 +256,8 @@ class UpdateQuery(Query): else: placeholder = '%s' + if hasattr(val, 'evaluate'): + val = SQLEvaluator(val, self, allow_joins=False) if model: self.add_related_update(model, field.column, val, placeholder) else: diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 9ce1e7bf2d..8724906a8c 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -97,6 +97,7 @@ class WhereNode(tree.Node): else: # A leaf node in the tree. sql, params = self.make_atom(child, qn) + except EmptyResultSet: if self.connector == AND and not self.negated: # We can bail out early in this particular case (only). @@ -114,6 +115,7 @@ class WhereNode(tree.Node): if self.negated: empty = True continue + empty = False if sql: result.append(sql) @@ -151,8 +153,9 @@ class WhereNode(tree.Node): else: cast_sql = '%s' - if isinstance(params, QueryWrapper): - extra, params = params.data + if hasattr(params, 'as_sql'): + extra, params = params.as_sql(qn) + cast_sql = '' else: extra = '' @@ -214,6 +217,9 @@ class WhereNode(tree.Node): if elt[0] in change_map: elt[0] = change_map[elt[0]] node.children[pos] = (tuple(elt),) + child[1:] + # Check if the query value also requires relabelling + if hasattr(child[3], 'relabel_aliases'): + child[3].relabel_aliases(change_map) class EverythingNode(object): """ |
