summaryrefslogtreecommitdiff
path: root/django/db/models
diff options
context:
space:
mode:
authorJosh Smeaton <josh.smeaton@gmail.com>2015-01-17 16:03:46 +1100
committerJosh Smeaton <josh.smeaton@gmail.com>2015-01-27 12:20:06 +1100
commit8196e4bdf498acb05e6680c81f9d7bf700f4295c (patch)
treed7b8c51787581b4f9ead00686380bba89d678296 /django/db/models
parent511be35779a98427387d9aa4abacce01dedd7272 (diff)
Fixed #24154 -- Backends can now check support for expressions
Diffstat (limited to 'django/db/models')
-rw-r--r--django/db/models/aggregates.py11
-rw-r--r--django/db/models/expressions.py15
-rw-r--r--django/db/models/sql/query.py5
-rw-r--r--django/db/models/sql/where.py11
4 files changed, 7 insertions, 35 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index 06220123ca..668f79f622 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -25,17 +25,6 @@ class Aggregate(Func):
c._patch_aggregate(query) # backward-compatibility support
return c
- def refs_field(self, aggregate_types, field_types):
- try:
- return (isinstance(self, aggregate_types) and
- isinstance(self.input_field._output_field_or_none, field_types))
- except FieldError:
- # Sometimes we don't know the input_field's output type (for example,
- # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
- # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
- # have any output field.
- return False
-
@property
def input_field(self):
return self.source_expressions[0]
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 97a2a9071d..fb094fd4ef 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -297,14 +297,6 @@ class BaseExpression(object):
return agg, lookup
return False, ()
- def refs_field(self, aggregate_types, field_types):
- """
- Helper method for check_aggregate_support on backends
- """
- return any(
- node.refs_field(aggregate_types, field_types)
- for node in self.get_source_expressions())
-
def prepare_database_save(self, field):
return self
@@ -401,6 +393,7 @@ class DurationExpression(Expression):
return compiler.compile(side)
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
expressions = []
expression_params = []
sql, params = self.compile(self.lhs, compiler, connection)
@@ -473,6 +466,7 @@ class Func(ExpressionNode):
return c
def as_sql(self, compiler, connection, function=None, template=None):
+ connection.ops.check_expression_support(self)
sql_parts = []
params = []
for arg in self.source_expressions:
@@ -511,6 +505,7 @@ class Value(ExpressionNode):
self.value = value
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
val = self.value
# check _output_field to avoid triggering an exception
if self._output_field is not None:
@@ -536,6 +531,7 @@ class Value(ExpressionNode):
class DurationValue(Value):
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
if (connection.features.has_native_duration_field and
connection.features.driver_supports_timedelta_args):
return super(DurationValue, self).as_sql(compiler, connection)
@@ -650,6 +646,7 @@ class When(ExpressionNode):
return c
def as_sql(self, compiler, connection, template=None):
+ connection.ops.check_expression_support(self)
template_params = {}
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
@@ -715,6 +712,7 @@ class Case(ExpressionNode):
return c
def as_sql(self, compiler, connection, template=None, extra=None):
+ connection.ops.check_expression_support(self)
if not self.cases:
return compiler.compile(self.default)
template_params = dict(extra) if extra else {}
@@ -851,6 +849,7 @@ class OrderBy(BaseExpression):
return [self.expression]
def as_sql(self, compiler, connection):
+ connection.ops.check_expression_support(self)
expression_sql, params = compiler.compile(self.expression)
placeholders = {'expression': expression_sql}
placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index c5926c034d..f7d5556e0a 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -230,11 +230,6 @@ class Query(object):
raise ValueError("Need either using or connection")
if using:
connection = connections[using]
-
- # Check that the compiler will be able to execute the query
- for alias, annotation in self.annotation_select.items():
- connection.ops.check_aggregate_support(annotation)
-
return connection.ops.compiler(self.compiler)(self, connection, using)
def get_meta(self):
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 3475ccaee5..fe428b2418 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -153,17 +153,6 @@ class WhereNode(tree.Node):
def contains_aggregate(self):
return self._contains_aggregate(self)
- @classmethod
- def _refs_field(cls, obj, aggregate_types, field_types):
- if not isinstance(obj, tree.Node):
- if hasattr(obj.rhs, 'refs_field'):
- return obj.rhs.refs_field(aggregate_types, field_types)
- return False
- return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children)
-
- def refs_field(self, aggregate_types, field_types):
- return self._refs_field(self, aggregate_types, field_types)
-
class EmptyWhere(WhereNode):
def add(self, data, connector):