diff options
Diffstat (limited to 'django/db/models/sql/where.py')
| -rw-r--r-- | django/db/models/sql/where.py | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index cbb0546d6a..ef856893b5 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -8,11 +8,13 @@ import collections import datetime from itertools import repeat -from django.utils import tree -from django.db.models.fields import Field +from django.conf import settings +from django.db.models.fields import DateTimeField, Field from django.db.models.sql.datastructures import EmptyResultSet, Empty from django.db.models.sql.aggregates import Aggregate from django.utils.six.moves import xrange +from django.utils import timezone +from django.utils import tree # Connection types AND = 'AND' @@ -60,7 +62,8 @@ class WhereNode(tree.Node): # about the value(s) to the query construction. Specifically, datetime # and empty values need special handling. Other types could be used # here in the future (using Python types is suggested for consistency). - if isinstance(value, datetime.datetime): + if (isinstance(value, datetime.datetime) + or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')): value_annotation = datetime.datetime elif hasattr(value, 'value_annotation'): value_annotation = value.value_annotation @@ -169,15 +172,13 @@ class WhereNode(tree.Node): if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn, connection) + field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), [] else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(qn, connection) + field_sql, field_params = lvalue.as_sql(qn, connection) - if value_annotation is datetime.datetime: - cast_sql = connection.ops.datetime_cast_sql() - else: - cast_sql = '%s' + is_datetime_field = value_annotation is datetime.datetime + cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' if hasattr(params, 'as_sql'): extra, params = params.as_sql(qn, connection) @@ -185,6 +186,8 @@ class WhereNode(tree.Node): else: extra = '' + params = field_params + params + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' and connection.features.interprets_empty_strings_as_nulls): lookup_type = 'isnull' @@ -221,9 +224,14 @@ class WhereNode(tree.Node): params) elif lookup_type in ('range', 'year'): return ('%s BETWEEN %%s and %%s' % field_sql, params) + elif is_datetime_field and lookup_type in ('month', 'day', 'week_day', + 'hour', 'minute', 'second'): + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_extract_sql(lookup_type, field_sql, tzname) + return ('%s = %%s' % sql, tz_params + params) elif lookup_type in ('month', 'day', 'week_day'): - return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), - params) + return ('%s = %%s' + % connection.ops.date_extract_sql(lookup_type, field_sql), params) elif lookup_type == 'isnull': assert value_annotation in (True, False), "Invalid value_annotation for isnull" return ('%s IS %sNULL' % (field_sql, ('' if value_annotation else 'NOT ')), ()) @@ -238,7 +246,7 @@ class WhereNode(tree.Node): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause - "WHERE ... T1.foo = 6"). + "WHERE ... T1.foo = 6") and a list of parameters. """ table_alias, name, db_type = data if table_alias: @@ -331,7 +339,7 @@ class ExtraWhere(object): def as_sql(self, qn=None, connection=None): sqls = ["(%s)" % sql for sql in self.sqls] - return " AND ".join(sqls), tuple(self.params or ()) + return " AND ".join(sqls), list(self.params or ()) def clone(self): return self |
