summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Pope <nick.pope@flightdataservices.com>2018-12-01 23:46:28 +0000
committerTim Graham <timograham@gmail.com>2019-01-14 14:35:41 -0500
commitc690afb873cac8035a3cb3be7c597a5ff0e4b261 (patch)
tree19d60b4ceaf53f86a9c1ba37b7e27ac080256f13
parent3d5e0f8394688d40036e27cfcfac295e6fe62269 (diff)
Refs #28643 -- Changed Avg() to use NumericOutputFieldMixin.
Keeps precision instead of forcing DecimalField to FloatField.
-rw-r--r--django/db/models/aggregates.py11
-rw-r--r--django/db/models/functions/mixins.py10
-rw-r--r--docs/ref/models/querysets.txt6
-rw-r--r--docs/releases/2.2.txt3
-rw-r--r--tests/aggregation/tests.py10
-rw-r--r--tests/aggregation_regress/tests.py4
6 files changed, 23 insertions, 21 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index a7dc55ee98..f9202543a3 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
-from django.db.models.fields import DecimalField, FloatField, IntegerField
+from django.db.models.fields import FloatField, IntegerField
+from django.db.models.functions.mixins import NumericOutputFieldMixin
__all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@@ -93,16 +94,10 @@ class Aggregate(Func):
return options
-class Avg(Aggregate):
+class Avg(NumericOutputFieldMixin, Aggregate):
function = 'AVG'
name = 'Avg'
- def _resolve_output_field(self):
- source_field = self.get_source_fields()[0]
- if isinstance(source_field, (IntegerField, DecimalField)):
- return FloatField()
- return super()._resolve_output_field()
-
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py
index 1bf3d6cbd0..9b46987788 100644
--- a/django/db/models/functions/mixins.py
+++ b/django/db/models/functions/mixins.py
@@ -1,6 +1,6 @@
import sys
-from django.db.models.fields import DecimalField, FloatField
+from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast
@@ -23,5 +23,9 @@ class FixDecimalInputMixin:
class NumericOutputFieldMixin:
def _resolve_output_field(self):
- has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
- return DecimalField() if has_decimals else FloatField()
+ source_expressions = self.get_source_expressions()
+ if any(isinstance(s.output_field, DecimalField) for s in source_expressions):
+ return DecimalField()
+ if any(isinstance(s.output_field, IntegerField) for s in source_expressions):
+ return FloatField()
+ return super()._resolve_output_field() if source_expressions else FloatField()
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index 78eb175329..46fcd50e37 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -3336,14 +3336,14 @@ by the aggregate.
``Avg``
~~~~~~~
-.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
+.. class:: Avg(expression, output_field=None, filter=None, **extra)
Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``.
* Default alias: ``<field>__avg``
- * Return type: ``float`` (or the type of whatever ``output_field`` is
- specified)
+ * Return type: ``float`` if input is ``int``, otherwise same as input
+ field, or ``output_field`` if supplied
``Count``
~~~~~~~~~
diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt
index 13f7617888..b96b0ed1ef 100644
--- a/docs/releases/2.2.txt
+++ b/docs/releases/2.2.txt
@@ -493,6 +493,9 @@ Miscellaneous
* :djadmin:`runserver` no longer supports `pyinotify` (replaced by Watchman).
+* The :class:`~django.db.models.Avg` aggregate function now returns a
+ ``Decimal`` instead of a ``float`` when the input is ``Decimal``.
+
.. _deprecated-features-2.2:
Features deprecated in 2.2
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py
index 75d2ecb1c5..8cac90f020 100644
--- a/tests/aggregation/tests.py
+++ b/tests/aggregation/tests.py
@@ -865,15 +865,15 @@ class AggregateTestCase(TestCase):
def test_avg_decimal_field(self):
v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price']
- self.assertIsInstance(v, float)
- self.assertEqual(v, Approximate(47.39, places=2))
+ self.assertIsInstance(v, Decimal)
+ self.assertEqual(v, Approximate(Decimal('47.39'), places=2))
def test_order_of_precedence(self):
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
- self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)})
+ self.assertEqual(p1, {'avg_price': Approximate(Decimal('148.18'), places=2)})
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
- self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
+ self.assertEqual(p2, {'avg_price': Approximate(Decimal('53.39'), places=2)})
def test_combine_different_types(self):
msg = 'Expression contains mixed types. You must set output_field.'
@@ -1087,7 +1087,7 @@ class AggregateTestCase(TestCase):
return super().as_sql(compiler, connection, function='MAX', **extra_context)
qs = Publisher.objects.annotate(
- price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
+ price_or_median=Greatest(Avg('book__rating', output_field=DecimalField()), Avg('book__price'))
).filter(price_or_median__gte=F('num_awards')).order_by('num_awards')
self.assertQuerysetEqual(
qs, [1, 3, 7, 9], lambda v: v.num_awards)
diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py
index 2b3948a0b4..64bbc13f80 100644
--- a/tests/aggregation_regress/tests.py
+++ b/tests/aggregation_regress/tests.py
@@ -401,7 +401,7 @@ class AggregationTests(TestCase):
When(pages__lt=400, then='discount_price'),
output_field=DecimalField()
)))['test'],
- 22.27, places=2
+ Decimal('22.27'), places=2
)
def test_distinct_conditional_aggregate(self):
@@ -1041,7 +1041,7 @@ class AggregationTests(TestCase):
books = Book.objects.values_list("publisher__name").annotate(
Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages")
).order_by("-publisher__name")
- self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0))
+ self.assertEqual(books[0], ('Sams', 1, Decimal('23.09'), 45.0, 528.0))
def test_annotation_disjunction(self):
qs = Book.objects.annotate(n_authors=Count("authors")).filter(