diff options
| author | Josh Smeaton <josh.smeaton@gmail.com> | 2015-01-17 16:03:46 +1100 |
|---|---|---|
| committer | Josh Smeaton <josh.smeaton@gmail.com> | 2015-01-27 12:20:06 +1100 |
| commit | 8196e4bdf498acb05e6680c81f9d7bf700f4295c (patch) | |
| tree | d7b8c51787581b4f9ead00686380bba89d678296 /django/db | |
| parent | 511be35779a98427387d9aa4abacce01dedd7272 (diff) | |
Fixed #24154 -- Backends can now check support for expressions
Diffstat (limited to 'django/db')
| -rw-r--r-- | django/db/backends/base/features.py | 8 | ||||
| -rw-r--r-- | django/db/backends/base/operations.py | 20 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/features.py | 2 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/operations.py | 24 | ||||
| -rw-r--r-- | django/db/models/aggregates.py | 11 | ||||
| -rw-r--r-- | django/db/models/expressions.py | 15 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 5 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 11 |
8 files changed, 41 insertions, 55 deletions
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 0f6ee0efe3..4b4d5c6d75 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -1,3 +1,5 @@ +from django.db.models.aggregates import StdDev +from django.db.models.expressions import Value from django.db.utils import ProgrammingError from django.utils.functional import cached_property @@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object): @cached_property def supports_stddev(self): """Confirm support for STDDEV and related stats functions.""" - class StdDevPop(object): - contains_aggregate = True - sql_function = 'STDDEV_POP' - try: - self.connection.ops.check_aggregate_support(StdDevPop()) + self.connection.ops.check_expression_support(StdDev(Value(1))) return True except NotImplementedError: return False diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 09a6316376..50f6c93a07 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -1,12 +1,14 @@ import datetime import decimal from importlib import import_module +import warnings from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db.backends import utils from django.utils import six, timezone from django.utils.dateparse import parse_duration +from django.utils.deprecation import RemovedInDjango21Warning from django.utils.encoding import force_text @@ -517,12 +519,20 @@ class BaseDatabaseOperations(object): return value def check_aggregate_support(self, aggregate_func): - """Check that the backend supports the provided aggregate + warnings.warn( + "check_aggregate_support has been deprecated. Use " + "check_expression_support instead.", + RemovedInDjango21Warning, stacklevel=2) + return self.check_expression_support(aggregate_func) - This is used on specific backends to rule out known aggregates - that are known to have faulty implementations. If the named - aggregate function has a known problem, the backend should - raise NotImplementedError. + def check_expression_support(self, expression): + """ + Check that the backend supports the provided expression. + + This is used on specific backends to rule out known expressions + that have problematic or nonexistent implementations. If the + expression has a known problem, the backend should raise + NotImplementedError. """ pass diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index fa5a002603..ee86469177 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): """Confirm support for STDDEV and related stats functions SQLite supports STDDEV as an extension package; so - connection.ops.check_aggregate_support() can't unilaterally + connection.ops.check_expression_support() can't unilaterally rule out support for STDDEV. We need to manually check whether the call works. """ diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index ad05f88585..cd29092cd7 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -4,7 +4,7 @@ import datetime import uuid from django.conf import settings -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, FieldError from django.db import utils from django.db.backends import utils as backend_utils from django.db.backends.base.operations import BaseDatabaseOperations @@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations): limit = 999 if len(fields) > 1 else 500 return (limit // len(fields)) if len(fields) > 0 else len(objs) - def check_aggregate_support(self, aggregate): + def check_expression_support(self, expression): bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField) - bad_aggregates = (aggregates.Sum, aggregates.Avg, - aggregates.Variance, aggregates.StdDev) - if aggregate.refs_field(bad_aggregates, bad_fields): - raise NotImplementedError( - 'You cannot use Sum, Avg, StdDev and Variance aggregations ' - 'on date/time fields in sqlite3 ' - 'since date/time is saved as text.') + bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev) + if isinstance(expression, bad_aggregates): + try: + output_field = expression.input_field.output_field + if isinstance(output_field, bad_fields): + raise NotImplementedError( + 'You cannot use Sum, Avg, StdDev and Variance aggregations ' + 'on date/time fields in sqlite3 ' + 'since date/time is saved as text.') + except FieldError: + # not every sub-expression has an output_field which is fine to + # ignore + pass def date_extract_sql(self, lookup_type, field_name): # sqlite doesn't support extract, so we fake it with the user-defined diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 06220123ca..668f79f622 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -25,17 +25,6 @@ class Aggregate(Func): c._patch_aggregate(query) # backward-compatibility support return c - def refs_field(self, aggregate_types, field_types): - try: - return (isinstance(self, aggregate_types) and - isinstance(self.input_field._output_field_or_none, field_types)) - except FieldError: - # Sometimes we don't know the input_field's output type (for example, - # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField()) - # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't - # have any output field. - return False - @property def input_field(self): return self.source_expressions[0] diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 97a2a9071d..fb094fd4ef 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -297,14 +297,6 @@ class BaseExpression(object): return agg, lookup return False, () - def refs_field(self, aggregate_types, field_types): - """ - Helper method for check_aggregate_support on backends - """ - return any( - node.refs_field(aggregate_types, field_types) - for node in self.get_source_expressions()) - def prepare_database_save(self, field): return self @@ -401,6 +393,7 @@ class DurationExpression(Expression): return compiler.compile(side) def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) expressions = [] expression_params = [] sql, params = self.compile(self.lhs, compiler, connection) @@ -473,6 +466,7 @@ class Func(ExpressionNode): return c def as_sql(self, compiler, connection, function=None, template=None): + connection.ops.check_expression_support(self) sql_parts = [] params = [] for arg in self.source_expressions: @@ -511,6 +505,7 @@ class Value(ExpressionNode): self.value = value def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) val = self.value # check _output_field to avoid triggering an exception if self._output_field is not None: @@ -536,6 +531,7 @@ class Value(ExpressionNode): class DurationValue(Value): def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) if (connection.features.has_native_duration_field and connection.features.driver_supports_timedelta_args): return super(DurationValue, self).as_sql(compiler, connection) @@ -650,6 +646,7 @@ class When(ExpressionNode): return c def as_sql(self, compiler, connection, template=None): + connection.ops.check_expression_support(self) template_params = {} sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) @@ -715,6 +712,7 @@ class Case(ExpressionNode): return c def as_sql(self, compiler, connection, template=None, extra=None): + connection.ops.check_expression_support(self) if not self.cases: return compiler.compile(self.default) template_params = dict(extra) if extra else {} @@ -851,6 +849,7 @@ class OrderBy(BaseExpression): return [self.expression] def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) expression_sql, params = compiler.compile(self.expression) placeholders = {'expression': expression_sql} placeholders['ordering'] = 'DESC' if self.descending else 'ASC' diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c5926c034d..f7d5556e0a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -230,11 +230,6 @@ class Query(object): raise ValueError("Need either using or connection") if using: connection = connections[using] - - # Check that the compiler will be able to execute the query - for alias, annotation in self.annotation_select.items(): - connection.ops.check_aggregate_support(annotation) - return connection.ops.compiler(self.compiler)(self, connection, using) def get_meta(self): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 3475ccaee5..fe428b2418 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -153,17 +153,6 @@ class WhereNode(tree.Node): def contains_aggregate(self): return self._contains_aggregate(self) - @classmethod - def _refs_field(cls, obj, aggregate_types, field_types): - if not isinstance(obj, tree.Node): - if hasattr(obj.rhs, 'refs_field'): - return obj.rhs.refs_field(aggregate_types, field_types) - return False - return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children) - - def refs_field(self, aggregate_types, field_types): - return self._refs_field(self, aggregate_types, field_types) - class EmptyWhere(WhereNode): def add(self, data, connector): |
