diff options
| author | Josh Smeaton <josh.smeaton@gmail.com> | 2015-01-17 16:03:46 +1100 |
|---|---|---|
| committer | Josh Smeaton <josh.smeaton@gmail.com> | 2015-01-27 12:20:06 +1100 |
| commit | 8196e4bdf498acb05e6680c81f9d7bf700f4295c (patch) | |
| tree | d7b8c51787581b4f9ead00686380bba89d678296 /django/db/models | |
| parent | 511be35779a98427387d9aa4abacce01dedd7272 (diff) | |
Fixed #24154 -- Backends can now check support for expressions
Diffstat (limited to 'django/db/models')
| -rw-r--r-- | django/db/models/aggregates.py | 11 | ||||
| -rw-r--r-- | django/db/models/expressions.py | 15 | ||||
| -rw-r--r-- | django/db/models/sql/query.py | 5 | ||||
| -rw-r--r-- | django/db/models/sql/where.py | 11 |
4 files changed, 7 insertions, 35 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 06220123ca..668f79f622 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -25,17 +25,6 @@ class Aggregate(Func): c._patch_aggregate(query) # backward-compatibility support return c - def refs_field(self, aggregate_types, field_types): - try: - return (isinstance(self, aggregate_types) and - isinstance(self.input_field._output_field_or_none, field_types)) - except FieldError: - # Sometimes we don't know the input_field's output type (for example, - # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField()) - # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't - # have any output field. - return False - @property def input_field(self): return self.source_expressions[0] diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 97a2a9071d..fb094fd4ef 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -297,14 +297,6 @@ class BaseExpression(object): return agg, lookup return False, () - def refs_field(self, aggregate_types, field_types): - """ - Helper method for check_aggregate_support on backends - """ - return any( - node.refs_field(aggregate_types, field_types) - for node in self.get_source_expressions()) - def prepare_database_save(self, field): return self @@ -401,6 +393,7 @@ class DurationExpression(Expression): return compiler.compile(side) def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) expressions = [] expression_params = [] sql, params = self.compile(self.lhs, compiler, connection) @@ -473,6 +466,7 @@ class Func(ExpressionNode): return c def as_sql(self, compiler, connection, function=None, template=None): + connection.ops.check_expression_support(self) sql_parts = [] params = [] for arg in self.source_expressions: @@ -511,6 +505,7 @@ class Value(ExpressionNode): self.value = value def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) val = self.value # check _output_field to avoid triggering an exception if self._output_field is not None: @@ -536,6 +531,7 @@ class Value(ExpressionNode): class DurationValue(Value): def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) if (connection.features.has_native_duration_field and connection.features.driver_supports_timedelta_args): return super(DurationValue, self).as_sql(compiler, connection) @@ -650,6 +646,7 @@ class When(ExpressionNode): return c def as_sql(self, compiler, connection, template=None): + connection.ops.check_expression_support(self) template_params = {} sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) @@ -715,6 +712,7 @@ class Case(ExpressionNode): return c def as_sql(self, compiler, connection, template=None, extra=None): + connection.ops.check_expression_support(self) if not self.cases: return compiler.compile(self.default) template_params = dict(extra) if extra else {} @@ -851,6 +849,7 @@ class OrderBy(BaseExpression): return [self.expression] def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) expression_sql, params = compiler.compile(self.expression) placeholders = {'expression': expression_sql} placeholders['ordering'] = 'DESC' if self.descending else 'ASC' diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c5926c034d..f7d5556e0a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -230,11 +230,6 @@ class Query(object): raise ValueError("Need either using or connection") if using: connection = connections[using] - - # Check that the compiler will be able to execute the query - for alias, annotation in self.annotation_select.items(): - connection.ops.check_aggregate_support(annotation) - return connection.ops.compiler(self.compiler)(self, connection, using) def get_meta(self): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 3475ccaee5..fe428b2418 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -153,17 +153,6 @@ class WhereNode(tree.Node): def contains_aggregate(self): return self._contains_aggregate(self) - @classmethod - def _refs_field(cls, obj, aggregate_types, field_types): - if not isinstance(obj, tree.Node): - if hasattr(obj.rhs, 'refs_field'): - return obj.rhs.refs_field(aggregate_types, field_types) - return False - return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children) - - def refs_field(self, aggregate_types, field_types): - return self._refs_field(self, aggregate_types, field_types) - class EmptyWhere(WhereNode): def add(self, data, connector): |
