diff options
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/aggregates.py | 11 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 99 | ||||
| -rw-r--r-- | django/db/models/sql/constants.py | 3 | ||||
| -rw-r--r-- | django/db/models/sql/datastructures.py | 23 | ||||
| -rw-r--r-- | django/db/models/sql/expressions.py | 4 | ||||
| -rw-r--r-- | django/db/models/sql/subqueries.py | 48 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 34 |
7 files changed, 161 insertions, 61 deletions
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 75a330f22a..3c8720210b 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -73,22 +73,23 @@ class Aggregate(object): self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) def as_sql(self, qn, connection): - "Return the aggregate, rendered as SQL." + "Return the aggregate, rendered as SQL with parameters." + params = [] if hasattr(self.col, 'as_sql'): - field_name = self.col.as_sql(qn, connection) + field_name, params = self.col.as_sql(qn, connection) elif isinstance(self.col, (list, tuple)): field_name = '.'.join([qn(c) for c in self.col]) else: field_name = self.col - params = { + substitutions = { 'function': self.sql_function, 'field': field_name } - params.update(self.extra) + substitutions.update(self.extra) - return self.sql_template % params + return self.sql_template % substitutions, params class Avg(Aggregate): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 79b5d99452..1b6654b670 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1,5 +1,6 @@ -from django.utils.six.moves import zip +import datetime +from django.conf import settings from django.core.exceptions import FieldError from django.db import transaction from django.db.backends.util import truncate_name @@ -12,6 +13,8 @@ from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import get_order_dir, Query from django.db.utils import DatabaseError from django.utils import six +from django.utils.six.moves import zip +from django.utils import timezone class SQLCompiler(object): @@ -71,7 +74,7 @@ class SQLCompiler(object): # as the pre_sql_setup will modify query state in a way that forbids # another run of it. self.refcounts_before = self.query.alias_refcount.copy() - out_cols = self.get_columns(with_col_aliases) + out_cols, s_params = self.get_columns(with_col_aliases) ordering, ordering_group_by = self.get_ordering() distinct_fields = self.get_distinct() @@ -94,6 +97,7 @@ class SQLCompiler(object): result.append(self.connection.ops.distinct_sql(distinct_fields)) result.append(', '.join(out_cols + self.query.ordering_aliases)) + params.extend(s_params) result.append('FROM') result.extend(from_) @@ -161,9 +165,10 @@ class SQLCompiler(object): def get_columns(self, with_aliases=False): """ - Returns the list of columns to use in the select statement. If no - columns have been specified, returns all columns relating to fields in - the model. + Returns the list of columns to use in the select statement, as well as + a list any extra parameters that need to be included. If no columns + have been specified, returns all columns relating to fields in the + model. If 'with_aliases' is true, any column names that are duplicated (without the table names) are given unique aliases. This is needed in @@ -172,6 +177,7 @@ class SQLCompiler(object): qn = self.quote_name_unless_alias qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] + params = [] aliases = set(self.query.extra_select.keys()) if with_aliases: col_aliases = aliases.copy() @@ -201,7 +207,9 @@ class SQLCompiler(object): aliases.add(r) col_aliases.add(col[1]) else: - result.append(col.as_sql(qn, self.connection)) + col_sql, col_params = col.as_sql(qn, self.connection) + result.append(col_sql) + params.extend(col_params) if hasattr(col, 'alias'): aliases.add(col.alias) @@ -214,15 +222,13 @@ class SQLCompiler(object): aliases.update(new_aliases) max_name_length = self.connection.ops.max_name_length() - result.extend([ - '%s%s' % ( - aggregate.as_sql(qn, self.connection), - alias is not None - and ' AS %s' % qn(truncate_name(alias, max_name_length)) - or '' - ) - for alias, aggregate in self.query.aggregate_select.items() - ]) + for alias, aggregate in self.query.aggregate_select.items(): + agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + if alias is None: + result.append(agg_sql) + else: + result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) + params.extend(agg_params) for (table, col), _ in self.query.related_select_cols: r = '%s.%s' % (qn(table), qn(col)) @@ -237,7 +243,7 @@ class SQLCompiler(object): col_aliases.add(col) self._select_aliases = aliases - return result + return result, params def get_default_columns(self, with_aliases=False, col_aliases=None, start_alias=None, opts=None, as_pairs=False, from_parent=None): @@ -542,14 +548,16 @@ class SQLCompiler(object): seen = set() cols = self.query.group_by + select_cols for col in cols: + col_params = () if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - sql = col.as_sql(qn, self.connection) + sql, col_params = col.as_sql(qn, self.connection) else: sql = '(%s)' % str(col) if sql not in seen: result.append(sql) + params.extend(col_params) seen.add(sql) # Still, we need to add all stuff in ordering (except if the backend can @@ -988,15 +996,17 @@ class SQLAggregateCompiler(SQLCompiler): if qn is None: qn = self.quote_name_unless_alias - sql = ('SELECT %s FROM (%s) subquery' % ( - ', '.join([ - aggregate.as_sql(qn, self.connection) - for aggregate in self.query.aggregate_select.values() - ]), - self.query.subquery) - ) - params = self.query.sub_params - return (sql, params) + sql, params = [], [] + for aggregate in self.query.aggregate_select.values(): + agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + sql.append(agg_sql) + params.extend(agg_params) + sql = ', '.join(sql) + params = tuple(params) + + sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery) + params = params + self.query.sub_params + return sql, params class SQLDateCompiler(SQLCompiler): def results_iter(self): @@ -1005,10 +1015,10 @@ class SQLDateCompiler(SQLCompiler): """ resolve_columns = hasattr(self, 'resolve_columns') if resolve_columns: - from django.db.models.fields import DateTimeField - fields = [DateTimeField()] + from django.db.models.fields import DateField + fields = [DateField()] else: - from django.db.backends.util import typecast_timestamp + from django.db.backends.util import typecast_date needs_string_cast = self.connection.features.needs_datetime_string_cast offset = len(self.query.extra_select) @@ -1018,9 +1028,38 @@ class SQLDateCompiler(SQLCompiler): if resolve_columns: date = self.resolve_columns(row, fields)[offset] elif needs_string_cast: - date = typecast_timestamp(str(date)) + date = typecast_date(str(date)) + if isinstance(date, datetime.datetime): + date = date.date() yield date +class SQLDateTimeCompiler(SQLCompiler): + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + resolve_columns = hasattr(self, 'resolve_columns') + if resolve_columns: + from django.db.models.fields import DateTimeField + fields = [DateTimeField()] + else: + from django.db.backends.util import typecast_timestamp + needs_string_cast = self.connection.features.needs_datetime_string_cast + + offset = len(self.query.extra_select) + for rows in self.execute_sql(MULTI): + for row in rows: + datetime = row[offset] + if resolve_columns: + datetime = self.resolve_columns(row, fields)[offset] + elif needs_string_cast: + datetime = typecast_timestamp(str(datetime)) + # Datetimes are artifically returned in UTC on databases that + # don't support time zone. Restore the zone used in the query. + if settings.USE_TZ: + datetime = datetime.replace(tzinfo=None) + datetime = timezone.make_aware(datetime, self.query.tzinfo) + yield datetime def order_modified_iter(cursor, trim, sentinel): """ diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 1764db7fcc..81bd646d69 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -11,7 +11,8 @@ import re QUERY_TERMS = set([ 'exact', 'iexact', 'contains', 'icontains', 'gt', 'gte', 'lt', 'lte', 'in', 'startswith', 'istartswith', 'endswith', 'iendswith', 'range', 'year', - 'month', 'day', 'week_day', 'isnull', 'search', 'regex', 'iregex', + 'month', 'day', 'week_day', 'hour', 'minute', 'second', 'isnull', 'search', + 'regex', 'iregex', ]) # Size of each "chunk" for get_iterator calls. diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index b8e06daf01..612eb8f2d9 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -40,4 +40,25 @@ class Date(object): col = '%s.%s' % tuple([qn(c) for c in self.col]) else: col = self.col - return connection.ops.date_trunc_sql(self.lookup_type, col) + return connection.ops.date_trunc_sql(self.lookup_type, col), [] + +class DateTime(object): + """ + Add a datetime selection column. + """ + def __init__(self, col, lookup_type, tzname): + self.col = col + self.lookup_type = lookup_type + self.tzname = tzname + + def relabel_aliases(self, change_map): + c = self.col + if isinstance(c, (list, tuple)): + self.col = (change_map.get(c[0], c[0]), c[1]) + + def as_sql(self, qn, connection): + if isinstance(self.col, (list, tuple)): + col = '%s.%s' % tuple([qn(c) for c in self.col]) + else: + col = self.col + return connection.ops.datetime_trunc_sql(self.lookup_type, col, self.tzname) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index a4c1d85c65..2a5008f067 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -94,9 +94,9 @@ class SQLEvaluator(object): if col is None: raise ValueError("Given node not found") if hasattr(col, 'as_sql'): - return col.as_sql(qn, connection), () + return col.as_sql(qn, connection) else: - return '%s.%s' % (qn(col[0]), qn(col[1])), () + return '%s.%s' % (qn(col[0]), qn(col[1])), [] def evaluate_date_modifier_node(self, node, qn, connection): timedelta = node.children.pop() diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 6072804697..6aac5c898c 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -2,22 +2,23 @@ Query subclasses which provide extra functionality beyond simple data retrieval. """ +from django.conf import settings from django.core.exceptions import FieldError from django.db import connections from django.db.models.constants import LOOKUP_SEP -from django.db.models.fields import DateField, FieldDoesNotExist +from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.sql.constants import * -from django.db.models.sql.datastructures import Date +from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint -from django.utils.datastructures import SortedDict from django.utils.functional import Promise from django.utils.encoding import force_text from django.utils import six +from django.utils import timezone __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', - 'AggregateQuery'] + 'DateTimeQuery', 'AggregateQuery'] class DeleteQuery(Query): """ @@ -223,9 +224,9 @@ class DateQuery(Query): compiler = 'SQLDateCompiler' - def add_date_select(self, field_name, lookup_type, order='ASC'): + def add_select(self, field_name, lookup_type, order='ASC'): """ - Converts the query into a date extraction query. + Converts the query into an extraction query. """ try: result = self.setup_joins( @@ -238,10 +239,9 @@ class DateQuery(Query): self.model._meta.object_name, field_name )) field = result[0] - assert isinstance(field, DateField), "%r isn't a DateField." \ - % field.name + self._check_field(field) # overridden in DateTimeQuery alias = result[3][-1] - select = Date((alias, field.column), lookup_type) + select = self._get_select((alias, field.column), lookup_type) self.clear_select_clause() self.select = [SelectInfo(select, None)] self.distinct = True @@ -250,6 +250,36 @@ class DateQuery(Query): if field.null: self.add_filter(("%s__isnull" % field_name, False)) + def _check_field(self, field): + assert isinstance(field, DateField), \ + "%r isn't a DateField." % field.name + if settings.USE_TZ: + assert not isinstance(field, DateTimeField), \ + "%r is a DateTimeField, not a DateField." % field.name + + def _get_select(self, col, lookup_type): + return Date(col, lookup_type) + +class DateTimeQuery(DateQuery): + """ + A DateTimeQuery is like a DateQuery but for a datetime field. If time zone + support is active, the tzinfo attribute contains the time zone to use for + converting the values before truncating them. Otherwise it's set to None. + """ + + compiler = 'SQLDateTimeCompiler' + + def _check_field(self, field): + assert isinstance(field, DateTimeField), \ + "%r isn't a DateTimeField." % field.name + + def _get_select(self, col, lookup_type): + if self.tzinfo is None: + tzname = None + else: + tzname = timezone._get_timezone_name(self.tzinfo) + return DateTime(col, lookup_type, tzname) + class AggregateQuery(Query): """ An AggregateQuery takes another query as a parameter to the FROM diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index cbb0546d6a..ef856893b5 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -8,11 +8,13 @@ import collections import datetime from itertools import repeat -from django.utils import tree -from django.db.models.fields import Field +from django.conf import settings +from django.db.models.fields import DateTimeField, Field from django.db.models.sql.datastructures import EmptyResultSet, Empty from django.db.models.sql.aggregates import Aggregate from django.utils.six.moves import xrange +from django.utils import timezone +from django.utils import tree # Connection types AND = 'AND' @@ -60,7 +62,8 @@ class WhereNode(tree.Node): # about the value(s) to the query construction. Specifically, datetime # and empty values need special handling. Other types could be used # here in the future (using Python types is suggested for consistency). - if isinstance(value, datetime.datetime): + if (isinstance(value, datetime.datetime) + or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')): value_annotation = datetime.datetime elif hasattr(value, 'value_annotation'): value_annotation = value.value_annotation @@ -169,15 +172,13 @@ class WhereNode(tree.Node): if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn, connection) + field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), [] else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(qn, connection) + field_sql, field_params = lvalue.as_sql(qn, connection) - if value_annotation is datetime.datetime: - cast_sql = connection.ops.datetime_cast_sql() - else: - cast_sql = '%s' + is_datetime_field = value_annotation is datetime.datetime + cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' if hasattr(params, 'as_sql'): extra, params = params.as_sql(qn, connection) @@ -185,6 +186,8 @@ class WhereNode(tree.Node): else: extra = '' + params = field_params + params + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' and connection.features.interprets_empty_strings_as_nulls): lookup_type = 'isnull' @@ -221,9 +224,14 @@ class WhereNode(tree.Node): params) elif lookup_type in ('range', 'year'): return ('%s BETWEEN %%s and %%s' % field_sql, params) + elif is_datetime_field and lookup_type in ('month', 'day', 'week_day', + 'hour', 'minute', 'second'): + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_extract_sql(lookup_type, field_sql, tzname) + return ('%s = %%s' % sql, tz_params + params) elif lookup_type in ('month', 'day', 'week_day'): - return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), - params) + return ('%s = %%s' + % connection.ops.date_extract_sql(lookup_type, field_sql), params) elif lookup_type == 'isnull': assert value_annotation in (True, False), "Invalid value_annotation for isnull" return ('%s IS %sNULL' % (field_sql, ('' if value_annotation else 'NOT ')), ()) @@ -238,7 +246,7 @@ class WhereNode(tree.Node): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause - "WHERE ... T1.foo = 6"). + "WHERE ... T1.foo = 6") and a list of parameters. """ table_alias, name, db_type = data if table_alias: @@ -331,7 +339,7 @@ class ExtraWhere(object): def as_sql(self, qn=None, connection=None): sqls = ["(%s)" % sql for sql in self.sqls] - return " AND ".join(sqls), tuple(self.params or ()) + return " AND ".join(sqls), list(self.params or ()) def clone(self): return self |
