summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Walls <jacobtylerwalls@gmail.com>2025-01-01 15:27:52 -0500
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2025-01-14 16:47:07 +0100
commitd206d4c200d71c0847e7f6720d88c587e7b46843 (patch)
tree9ca44ecf0e7762653d1a3dbae334ee3d2615187c
parentf07360e8087d3b403d1d12ff696da3138116055a (diff)
Fixed #36051 -- Declared arity on aggregate functions.
Follow-up to 4a66a69239c493c05b322815b18c605cd4c96e7c.
-rw-r--r--django/db/models/aggregates.py7
-rw-r--r--docs/ref/models/expressions.txt1
-rw-r--r--docs/releases/5.2.txt4
-rw-r--r--tests/aggregation/tests.py24
4 files changed, 36 insertions, 0 deletions
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index 911b60a86d..ea16cc440c 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -158,6 +158,7 @@ class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = "AVG"
name = "Avg"
allow_distinct = True
+ arity = 1
class Count(Aggregate):
@@ -166,6 +167,7 @@ class Count(Aggregate):
output_field = IntegerField()
allow_distinct = True
empty_result_set_value = 0
+ arity = 1
allows_composite_expressions = True
def __init__(self, expression, filter=None, **extra):
@@ -195,15 +197,18 @@ class Count(Aggregate):
class Max(Aggregate):
function = "MAX"
name = "Max"
+ arity = 1
class Min(Aggregate):
function = "MIN"
name = "Min"
+ arity = 1
class StdDev(NumericOutputFieldMixin, Aggregate):
name = "StdDev"
+ arity = 1
def __init__(self, expression, sample=False, **extra):
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
@@ -217,10 +222,12 @@ class Sum(FixDurationInputMixin, Aggregate):
function = "SUM"
name = "Sum"
allow_distinct = True
+ arity = 1
class Variance(NumericOutputFieldMixin, Aggregate):
name = "Variance"
+ arity = 1
def __init__(self, expression, sample=False, **extra):
self.function = "VAR_SAMP" if sample else "VAR_POP"
diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt
index 6faec969c3..5d5504dbe7 100644
--- a/docs/ref/models/expressions.txt
+++ b/docs/ref/models/expressions.txt
@@ -516,6 +516,7 @@ generated. Here's a brief example::
function = "SUM"
template = "%(function)s(%(all_values)s%(expressions)s)"
allow_distinct = False
+ arity = 1
def __init__(self, expression, all_values=False, **extra):
super().__init__(expression, all_values="ALL " if all_values else "", **extra)
diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt
index 3d3a958b6d..716f217aee 100644
--- a/docs/releases/5.2.txt
+++ b/docs/releases/5.2.txt
@@ -511,6 +511,10 @@ Miscellaneous
* The minimum supported version of ``oracledb`` is increased from 1.3.2 to
2.3.0.
+* Built-in aggregate functions accepting only one argument (``Avg``, ``Count``,
+ ``Max``, ``Min``, ``StdDev``, ``Sum``, and ``Variance``) now raise
+ :exc:`TypeError` when called with an incorrect number of arguments.
+
.. _deprecated-features-5.2:
Features deprecated in 5.2
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py
index b6ba728e77..861b2c5dfc 100644
--- a/tests/aggregation/tests.py
+++ b/tests/aggregation/tests.py
@@ -1276,6 +1276,8 @@ class AggregateTestCase(TestCase):
Book.objects.annotate(Max("id")).annotate(Sum("id__max"))
class MyMax(Max):
+ arity = None
+
def as_sql(self, compiler, connection):
self.set_source_expressions(self.get_source_expressions()[0:1])
return super().as_sql(compiler, connection)
@@ -1288,6 +1290,7 @@ class AggregateTestCase(TestCase):
def test_multi_arg_aggregate(self):
class MyMax(Max):
output_field = DecimalField()
+ arity = None
def as_sql(self, compiler, connection):
copy = self.copy()
@@ -2178,6 +2181,27 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(list(author_qs), [337])
+ def test_aggregate_arity(self):
+ funcs_with_inherited_constructors = [Avg, Max, Min, Sum]
+ msg = "takes exactly 1 argument (2 given)"
+ for function in funcs_with_inherited_constructors:
+ with (
+ self.subTest(function=function),
+ self.assertRaisesMessage(TypeError, msg),
+ ):
+ function(Value(1), Value(2))
+
+ funcs_with_custom_constructors = [Count, StdDev, Variance]
+ for function in funcs_with_custom_constructors:
+ with self.subTest(function=function):
+ # Extra arguments are rejected via the constructor.
+ with self.assertRaises(TypeError):
+ function(Value(1), True, Value(2))
+ # If the constructor is skipped, the arity check runs.
+ func_instance = function(Value(1), True)
+ with self.assertRaisesMessage(TypeError, msg):
+ super(function, func_instance).__init__(Value(1), Value(2))
+
class AggregateAnnotationPruningTests(TestCase):
@classmethod