summaryrefslogtreecommitdiff
path: root/django/db
diff options
context:
space:
mode:
authorJosh Smeaton <josh.smeaton@gmail.com>2015-01-17 16:03:46 +1100
committerJosh Smeaton <josh.smeaton@gmail.com>2015-01-27 12:20:06 +1100
commit8196e4bdf498acb05e6680c81f9d7bf700f4295c (patch)
treed7b8c51787581b4f9ead00686380bba89d678296 /django/db
parent511be35779a98427387d9aa4abacce01dedd7272 (diff)
Fixed #24154 -- Backends can now check support for expressions
Diffstat (limited to 'django/db')
-rw-r--r--django/db/backends/base/features.py8
-rw-r--r--django/db/backends/base/operations.py20
-rw-r--r--django/db/backends/sqlite3/features.py2
-rw-r--r--django/db/backends/sqlite3/operations.py24
-rw-r--r--django/db/models/aggregates.py11
-rw-r--r--django/db/models/expressions.py15
-rw-r--r--django/db/models/sql/query.py5
-rw-r--r--django/db/models/sql/where.py11
8 files changed, 41 insertions, 55 deletions
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index 0f6ee0efe3..4b4d5c6d75 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -1,3 +1,5 @@
+from django.db.models.aggregates import StdDev
+from django.db.models.expressions import Value
from django.db.utils import ProgrammingError
from django.utils.functional import cached_property
@@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object):
@cached_property
def supports_stddev(self):
"""Confirm support for STDDEV and related stats functions."""
- class StdDevPop(object):
- contains_aggregate = True
- sql_function = 'STDDEV_POP'
-
try:
- self.connection.ops.check_aggregate_support(StdDevPop())
+ self.connection.ops.check_expression_support(StdDev(Value(1)))
return True
except NotImplementedError:
return False
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index 09a6316376..50f6c93a07 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -1,12 +1,14 @@
import datetime
import decimal
from importlib import import_module
+import warnings
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.backends import utils
from django.utils import six, timezone
from django.utils.dateparse import parse_duration
+from django.utils.deprecation import RemovedInDjango21Warning
from django.utils.encoding import force_text
@@ -517,12 +519,20 @@ class BaseDatabaseOperations(object):
return value
def check_aggregate_support(self, aggregate_func):
- """Check that the backend supports the provided aggregate
+ warnings.warn(
+ "check_aggregate_support has been deprecated. Use "
+ "check_expression_support instead.",
+ RemovedInDjango21Warning, stacklevel=2)
+ return self.check_expression_support(aggregate_func)
- This is used on specific backends to rule out known aggregates
- that are known to have faulty implementations. If the named
- aggregate function has a known problem, the backend should
- raise NotImplementedError.
+ def check_expression_support(self, expression):
+ """
+ Check that the backend supports the provided expression.
+
+ This is used on specific backends to rule out known expressions
+ that have problematic or nonexistent implementations. If the
+ expression has a known problem, the backend should raise
+ NotImplementedError.
"""
pass
diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py
index fa5a002603..ee86469177 100644
--- a/django/db/backends/sqlite3/features.py
+++ b/django/db/backends/sqlite3/features.py
@@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"""Confirm support for STDDEV and related stats functions
SQLite supports STDDEV as an extension package; so
- connection.ops.check_aggregate_support() can't unilaterally
+ connection.ops.check_expression_support() can't unilaterally
rule out support for STDDEV. We need to manually check
whether the call works.
"""
diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
index ad05f88585..cd29092cd7 100644
--- a/django/db/backends/sqlite3/operations.py
+++ b/django/db/backends/sqlite3/operations.py
@@ -4,7 +4,7 @@ import datetime
import uuid
from django.conf import settings
-from django.core.exceptions import ImproperlyConfigured
+from django.core.exceptions import ImproperlyConfigured, FieldError
from django.db import utils
from django.db.backends import utils as backend_utils
from django.db.backends.base.operations import BaseDatabaseOperations
@@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations):
limit = 999 if len(fields) > 1 else 500
return (limit // len(fields)) if len(fields) > 0 else len(objs)
- def check_aggregate_support(self, aggregate):
+ def check_expression_support(self, expression):
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
- bad_aggregates = (aggregates.Sum, aggregates.Avg,
- aggregates.Variance, aggregates.StdDev)
- if aggregate.refs_field(bad_aggregates, bad_fields):
- raise NotImplementedError(
- 'You cannot use Sum, Avg, StdDev and Variance aggregations '
- 'on date/time fields in sqlite3 '
- 'since date/time is saved as text.')
+ bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
+ if isinstance(expression, bad_aggregates):
+ try:
+ output_field = expression.input_field.output_field
+ if isinstance(output_field, bad_fields):
+ raise NotImplementedError(
+ 'You cannot use Sum, Avg, StdDev and Variance aggregations '
+ 'on date/time fields in sqlite3 '
+ 'since date/time is saved as text.')
+ except FieldError:
+ # not every sub-expression has an output_field which is fine to
+ # ignore
+ pass
def date_extract_sql(self, lookup_type, field_name):
# sqlite doesn't support extract, so we fake it with the user-defined
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index 06220123ca..668f79f622 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -25,17 +25,6 @@ class Aggregate(Func):
c._patch_aggregate(query) # backward-compatibility support
return c
- def refs_field(self, aggregate_types, field_types):
- try:
- return (isinstance(self, aggregate_types) and
- isinstance(self.input_field._output_field_or_none, field_types))
- except FieldError:
- # Sometimes we don't know the input_field's output type (for example,
- # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
- # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
- # have any output field.
- return False
-
@property
def input_field(self):
return self.source_expressions[0]
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 97a2a9071d..fb094fd4ef 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -297,14 +297,6 @@ class BaseExpression(object):
return agg, lookup
return False, ()
- def refs_field(self, aggregate_types, field_types):
- """
- Helper method for check_aggregate_support on backends
- """
- return any(
- node.refs_field(aggregate_types, field_types)
- for node in self.get_source_expressions())
-
def prepare_database_save(self, field):
return self
@@ -401,6 +393,7 @@ class DurationExpression(Expression):
return compiler.compile(side)
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
expressions = []
expression_params = []
sql, params = self.compile(self.lhs, compiler, connection)
@@ -473,6 +466,7 @@ class Func(ExpressionNode):
return c
def as_sql(self, compiler, connection, function=None, template=None):
+ connection.ops.check_expression_support(self)
sql_parts = []
params = []
for arg in self.source_expressions:
@@ -511,6 +505,7 @@ class Value(ExpressionNode):
self.value = value
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
val = self.value
# check _output_field to avoid triggering an exception
if self._output_field is not None:
@@ -536,6 +531,7 @@ class Value(ExpressionNode):
class DurationValue(Value):
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
if (connection.features.has_native_duration_field and
connection.features.driver_supports_timedelta_args):
return super(DurationValue, self).as_sql(compiler, connection)
@@ -650,6 +646,7 @@ class When(ExpressionNode):
return c
def as_sql(self, compiler, connection, template=None):
+ connection.ops.check_expression_support(self)
template_params = {}
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
@@ -715,6 +712,7 @@ class Case(ExpressionNode):
return c
def as_sql(self, compiler, connection, template=None, extra=None):
+ connection.ops.check_expression_support(self)
if not self.cases:
return compiler.compile(self.default)
template_params = dict(extra) if extra else {}
@@ -851,6 +849,7 @@ class OrderBy(BaseExpression):
return [self.expression]
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
expression_sql, params = compiler.compile(self.expression)
placeholders = {'expression': expression_sql}
placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index c5926c034d..f7d5556e0a 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -230,11 +230,6 @@ class Query(object):
raise ValueError("Need either using or connection")
if using:
connection = connections[using]
-
- # Check that the compiler will be able to execute the query
- for alias, annotation in self.annotation_select.items():
- connection.ops.check_aggregate_support(annotation)
-
return connection.ops.compiler(self.compiler)(self, connection, using)
def get_meta(self):
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 3475ccaee5..fe428b2418 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -153,17 +153,6 @@ class WhereNode(tree.Node):
def contains_aggregate(self):
return self._contains_aggregate(self)
- @classmethod
- def _refs_field(cls, obj, aggregate_types, field_types):
- if not isinstance(obj, tree.Node):
- if hasattr(obj.rhs, 'refs_field'):
- return obj.rhs.refs_field(aggregate_types, field_types)
- return False
- return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children)
-
- def refs_field(self, aggregate_types, field_types):
- return self._refs_field(self, aggregate_types, field_types)
-
class EmptyWhere(WhereNode):
def add(self, data, connector):