diff options
Diffstat (limited to 'django/db/models/sql/where.py')
| -rw-r--r-- | django/db/models/sql/where.py | 77 |
1 files changed, 46 insertions, 31 deletions
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ec0545ca5b..4aa2351f17 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,7 +4,6 @@ Code to manage the creation and SQL rendering of 'where' constraints. import datetime from django.utils import tree -from django.db import connection from django.db.models.fields import Field from django.db.models.query_utils import QueryWrapper from datastructures import EmptyResultSet, FullResultSet @@ -51,18 +50,6 @@ class WhereNode(tree.Node): # Consume any generators immediately, so that we can determine # emptiness and transform any non-empty values correctly. value = list(value) - if hasattr(obj, "process"): - try: - obj, params = obj.process(lookup_type, value) - except (EmptyShortCircuit, EmptyResultSet): - # There are situations where we want to short-circuit any - # comparisons and make sure that nothing is returned. One - # example is when checking for a NULL pk value, or the - # equivalent. - super(WhereNode, self).add(NothingNode(), connector) - return - else: - params = Field().get_db_prep_lookup(lookup_type, value) # The "annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. Specifically, datetime @@ -75,10 +62,16 @@ class WhereNode(tree.Node): else: annotation = bool(value) - super(WhereNode, self).add((obj, lookup_type, annotation, params), + if hasattr(obj, "prepare"): + value = obj.prepare(lookup_type, value) + super(WhereNode, self).add((obj, lookup_type, annotation, value), + connector) + return + + super(WhereNode, self).add((obj, lookup_type, annotation, value), connector) - def as_sql(self, qn=None): + def as_sql(self, qn, connection): """ Returns the SQL version of the where clause and the value to be substituted in. Returns None, None if this node is empty. @@ -87,8 +80,6 @@ class WhereNode(tree.Node): (generally not needed except by the internal implementation for recursion). """ - if not qn: - qn = connection.ops.quote_name if not self.children: return None, [] result = [] @@ -97,10 +88,10 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn) + sql, params = child.as_sql(qn=qn, connection=connection) else: # A leaf node in the tree. - sql, params = self.make_atom(child, qn) + sql, params = self.make_atom(child, qn, connection) except EmptyResultSet: if self.connector == AND and not self.negated: @@ -136,7 +127,7 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def make_atom(self, child, qn): + def make_atom(self, child, qn, connection): """ Turn a tuple (table_alias, column_name, db_type, lookup_type, value_annot, params) into valid SQL. @@ -144,13 +135,21 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. """ - lvalue, lookup_type, value_annot, params = child + lvalue, lookup_type, value_annot, params_or_value = child + if hasattr(lvalue, 'process'): + try: + lvalue, params = lvalue.process(lookup_type, params_or_value, connection) + except EmptyShortCircuit: + raise EmptyResultSet + else: + params = Field().get_db_prep_lookup(lookup_type, params_or_value, + connection=connection, prepared=True) if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn) + field_sql = self.sql_for_columns(lvalue, qn, connection) else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(quote_func=qn) + field_sql = lvalue.as_sql(qn, connection) if value_annot is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql() @@ -158,11 +157,16 @@ class WhereNode(tree.Node): cast_sql = '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn) + extra, params = params.as_sql(qn, connection) cast_sql = '' else: extra = '' + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' + and connection.features.interprets_empty_strings_as_nulls): + lookup_type = 'isnull' + value_annot = True + if lookup_type in connection.operators: format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) return (format % (field_sql, @@ -191,7 +195,7 @@ class WhereNode(tree.Node): raise TypeError('Invalid lookup_type: %r' % lookup_type) - def sql_for_columns(self, data, qn): + def sql_for_columns(self, data, qn, connection): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause @@ -233,7 +237,8 @@ class EverythingNode(object): """ A node that matches everything. """ - def as_sql(self, qn=None): + + def as_sql(self, qn=None, connection=None): raise FullResultSet def relabel_aliases(self, change_map, node=None): @@ -243,7 +248,7 @@ class NothingNode(object): """ A node that matches nothing. """ - def as_sql(self, qn=None): + def as_sql(self, qn=None, connection=None): raise EmptyResultSet def relabel_aliases(self, change_map, node=None): @@ -257,7 +262,12 @@ class Constraint(object): def __init__(self, alias, col, field): self.alias, self.col, self.field = alias, col, field - def process(self, lookup_type, value): + def prepare(self, lookup_type, value): + if self.field: + return self.field.get_prep_lookup(lookup_type, value) + return value + + def process(self, lookup_type, value, connection): """ Returns a tuple of data suitable for inclusion in a WhereNode instance. @@ -266,16 +276,21 @@ class Constraint(object): from django.db.models.base import ObjectDoesNotExist try: if self.field: - params = self.field.get_db_prep_lookup(lookup_type, value) - db_type = self.field.db_type() + params = self.field.get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) + db_type = self.field.db_type(connection=connection) else: # This branch is used at times when we add a comparison to NULL # (we don't really want to waste time looking up the associated # field object at the calling location). - params = Field().get_db_prep_lookup(lookup_type, value) + params = Field().get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) db_type = None except ObjectDoesNotExist: raise EmptyShortCircuit return (self.alias, self.col, db_type), params + def relabel_aliases(self, change_map): + if self.alias in change_map: + self.alias = change_map[self.alias] |
