summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Charette <simon.charette@zapier.com>2019-05-12 17:17:47 -0400
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2020-07-15 10:58:29 +0200
commit1e38f1191de21b6e96736f58df57dfb851a28c1f (patch)
tree99b2c92d87d4d351d1f180e93a543b8903ec28d3
parentd08e6f55e3a986a8d4b3a58431d9615c7bc81eaa (diff)
Fixed #30446 -- Resolved Value.output_field for stdlib types.
This required implementing a limited form of dynamic dispatch to combine expressions with numerical output. Refs #26355 should eventually provide a better interface for that.
-rw-r--r--django/contrib/gis/db/models/functions.py11
-rw-r--r--django/contrib/postgres/fields/ranges.py3
-rw-r--r--django/db/models/expressions.py66
-rw-r--r--django/db/models/lookups.py4
-rw-r--r--docs/ref/models/expressions.txt10
-rw-r--r--docs/releases/3.2.txt9
-rw-r--r--tests/aggregation/tests.py6
-rw-r--r--tests/aggregation_regress/tests.py2
-rw-r--r--tests/expressions/tests.py38
-rw-r--r--tests/ordering/tests.py12
10 files changed, 122 insertions, 39 deletions
diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py
index 8ab5de6db7..1f2d372ebb 100644
--- a/django/contrib/gis/db/models/functions.py
+++ b/django/contrib/gis/db/models/functions.py
@@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin:
is not acceptable by the GIS functions expecting numeric values.
"""
def as_sqlite(self, compiler, connection, **extra_context):
- for expr in self.get_source_expressions():
- if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
- expr.value = float(expr.value)
- return super().as_sql(compiler, connection, **extra_context)
+ copy = self.copy()
+ copy.set_source_expressions([
+ Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal)
+ else expr
+ for expr in copy.get_source_expressions()
+ ])
+ return copy.as_sql(compiler, connection, **extra_context)
class OracleToleranceMixin:
diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py
index c2f24eb5ed..8eab2cd2d9 100644
--- a/django/contrib/postgres/fields/ranges.py
+++ b/django/contrib/postgres/fields/ranges.py
@@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup):
def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup.
if isinstance(self.rhs, datetime.date):
- output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
- value = models.Value(self.rhs, output_field=output_field)
+ value = models.Value(self.rhs)
self.rhs = value.resolve_expression(compiler.query)
return super().process_rhs(compiler, connection)
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 8cceb7d966..5b5a0ae4aa 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -1,7 +1,9 @@
import copy
import datetime
+import functools
import inspect
from decimal import Decimal
+from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError
from django.db import NotSupportedError, connection
@@ -56,12 +58,7 @@ class Combinable:
def _combine(self, other, connector, reversed):
if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression
- output_field = (
- fields.DurationField()
- if isinstance(other, datetime.timedelta) else
- None
- )
- other = Value(other, output_field=output_field)
+ other = Value(other)
if reversed:
return CombinedExpression(other, connector, self)
@@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
pass
+_connector_combinators = {
+ connector: [
+ (fields.IntegerField, fields.DecimalField, fields.DecimalField),
+ (fields.DecimalField, fields.IntegerField, fields.DecimalField),
+ (fields.IntegerField, fields.FloatField, fields.FloatField),
+ (fields.FloatField, fields.IntegerField, fields.FloatField),
+ ]
+ for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
+}
+
+
+@functools.lru_cache(maxsize=128)
+def _resolve_combined_type(connector, lhs_type, rhs_type):
+ combinators = _connector_combinators.get(connector, ())
+ for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
+ if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
+ return combined_type
+
+
class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__(self, lhs, connector, rhs, output_field=None):
@@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs):
self.lhs, self.rhs = exprs
+ def _resolve_output_field(self):
+ try:
+ return super()._resolve_output_field()
+ except FieldError:
+ combined_type = _resolve_combined_type(
+ self.connector,
+ type(self.lhs.output_field),
+ type(self.rhs.output_field),
+ )
+ if combined_type is None:
+ raise
+ return combined_type()
+
def as_sql(self, compiler, connection):
expressions = []
expression_params = []
@@ -721,6 +750,30 @@ class Value(Expression):
def get_group_by_cols(self, alias=None):
return []
+ def _resolve_output_field(self):
+ if isinstance(self.value, str):
+ return fields.CharField()
+ if isinstance(self.value, bool):
+ return fields.BooleanField()
+ if isinstance(self.value, int):
+ return fields.IntegerField()
+ if isinstance(self.value, float):
+ return fields.FloatField()
+ if isinstance(self.value, datetime.datetime):
+ return fields.DateTimeField()
+ if isinstance(self.value, datetime.date):
+ return fields.DateField()
+ if isinstance(self.value, datetime.time):
+ return fields.TimeField()
+ if isinstance(self.value, datetime.timedelta):
+ return fields.DurationField()
+ if isinstance(self.value, Decimal):
+ return fields.DecimalField()
+ if isinstance(self.value, bytes):
+ return fields.BinaryField()
+ if isinstance(self.value, UUID):
+ return fields.UUIDField()
+
class RawSQL(Expression):
def __init__(self, sql, params, output_field=None):
@@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
copy.expression = Case(
When(self.expression, then=True),
default=False,
- output_field=fields.BooleanField(),
)
return copy.as_sql(compiler, connection)
return self.as_sql(compiler, connection)
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index 79313ddd46..f358e50d5b 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -6,7 +6,7 @@ from copy import copy
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Exists, Func, Value, When
from django.db.models.fields import (
- BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
+ CharField, DateTimeField, Field, IntegerField, UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
@@ -123,7 +123,7 @@ class Lookup:
exprs = []
for expr in (self.lhs, self.rhs):
if isinstance(expr, Exists):
- expr = Case(When(expr, then=True), default=False, output_field=BooleanField())
+ expr = Case(When(expr, then=True), default=False)
wrapped = True
exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self
diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt
index 252966bb6c..31d2572288 100644
--- a/docs/ref/models/expressions.txt
+++ b/docs/ref/models/expressions.txt
@@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like
after it's retrieved from the database. Usually no arguments are needed when
instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
-output value.
+output value. If no ``output_field`` is specified it will be tentatively
+inferred from the :py:class:`type` of the provided ``value``, if possible. For
+example, passing an instance of :py:class:`datetime.datetime` as ``value``
+would default ``output_field`` to :class:`~django.db.models.DateTimeField`.
+
+.. versionchanged:: 3.2
+
+ Support for inferring a default ``output_field`` from the type of ``value``
+ was added.
``ExpressionWrapper()`` expressions
-----------------------------------
diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt
index e0e80323e3..aa25b0f511 100644
--- a/docs/releases/3.2.txt
+++ b/docs/releases/3.2.txt
@@ -233,6 +233,15 @@ Models
* The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed
on MySQL 8.0.1+.
+* :class:`Value() <django.db.models.Value>` expression now
+ automatically resolves its ``output_field`` to the appropriate
+ :class:`Field <django.db.models.Field>` subclass based on the type of
+ it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`,
+ :py:class:`float`, :py:class:`int`, :py:class:`str`,
+ :py:class:`datetime.date`, :py:class:`datetime.datetime`,
+ :py:class:`datetime.time`, :py:class:`datetime.timedelta`,
+ :py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances.
+
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py
index a8377c9a26..da78b8d9a9 100644
--- a/tests/aggregation/tests.py
+++ b/tests/aggregation/tests.py
@@ -848,10 +848,6 @@ class AggregateTestCase(TestCase):
book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first()
self.assertEqual(book.val, 2)
- def test_missing_output_field_raises_error(self):
- with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'):
- Book.objects.annotate(val=Max(2)).first()
-
def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
@@ -893,7 +889,7 @@ class AggregateTestCase(TestCase):
def test_combine_different_types(self):
msg = (
- 'Expression contains mixed types: FloatField, IntegerField. '
+ 'Expression contains mixed types: FloatField, DecimalField. '
'You must set output_field.'
)
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))
diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py
index bdfcb1d89b..b298a1f132 100644
--- a/tests/aggregation_regress/tests.py
+++ b/tests/aggregation_regress/tests.py
@@ -388,7 +388,7 @@ class AggregationTests(TestCase):
)
def test_annotated_conditional_aggregate(self):
- annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75)
+ annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75'))
self.assertAlmostEqual(
annotated_qs.aggregate(test=Avg(Case(
When(pages__lt=400, then='discount_price'),
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index 42b8c8f34b..c89bb1d69e 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -3,15 +3,17 @@ import pickle
import unittest
import uuid
from copy import deepcopy
+from decimal import Decimal
from unittest import mock
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import (
- Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField,
- DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F,
- Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev,
- Subquery, Sum, TimeField, UUIDField, Value, Variance, When,
+ Avg, BinaryField, BooleanField, Case, CharField, Count, DateField,
+ DateTimeField, DecimalField, DurationField, Exists, Expression,
+ ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
+ Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField,
+ UUIDField, Value, Variance, When,
)
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
from django.db.models.functions import (
@@ -1711,6 +1713,30 @@ class ValueTests(TestCase):
value = Value('foo', output_field=CharField())
self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo']))
+ def test_resolve_output_field(self):
+ value_types = [
+ ('str', CharField),
+ (True, BooleanField),
+ (42, IntegerField),
+ (3.14, FloatField),
+ (datetime.date(2019, 5, 15), DateField),
+ (datetime.datetime(2019, 5, 15), DateTimeField),
+ (datetime.time(3, 16), TimeField),
+ (datetime.timedelta(1), DurationField),
+ (Decimal('3.14'), DecimalField),
+ (b'', BinaryField),
+ (uuid.uuid4(), UUIDField),
+ ]
+ for value, ouput_field_type in value_types:
+ with self.subTest(type=type(value)):
+ expr = Value(value)
+ self.assertIsInstance(expr.output_field, ouput_field_type)
+
+ def test_resolve_output_field_failure(self):
+ msg = 'Cannot resolve expression type, unknown output_field'
+ with self.assertRaisesMessage(FieldError, msg):
+ Value(object()).output_field
+
class FieldTransformTests(TestCase):
@@ -1848,7 +1874,9 @@ class ExpressionWrapperTests(SimpleTestCase):
self.assertEqual(expr.get_group_by_cols(alias=None), [])
def test_non_empty_group_by(self):
- expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField())
+ value = Value('f')
+ value.output_field = None
+ expr = ExpressionWrapper(Lower(value), output_field=IntegerField())
group_by_cols = expr.get_group_by_cols(alias=None)
self.assertEqual(group_by_cols, [expr.expression])
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
diff --git a/tests/ordering/tests.py b/tests/ordering/tests.py
index 61ec3a8592..fe319b3859 100644
--- a/tests/ordering/tests.py
+++ b/tests/ordering/tests.py
@@ -1,7 +1,6 @@
from datetime import datetime
from operator import attrgetter
-from django.core.exceptions import FieldError
from django.db.models import (
CharField, DateTimeField, F, Max, OuterRef, Subquery, Value,
)
@@ -439,17 +438,6 @@ class OrderingTests(TestCase):
qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline')
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])
- def test_order_by_constant_value_without_output_field(self):
- msg = 'Cannot resolve expression type, unknown output_field'
- qs = Article.objects.annotate(constant=Value('1')).order_by('constant')
- for ordered_qs in (
- qs,
- qs.values('headline'),
- Article.objects.order_by(Value('1')),
- ):
- with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg):
- ordered_qs.first()
-
def test_related_ordering_duplicate_table_reference(self):
"""
An ordering referencing a model with an ordering referencing a model