summaryrefslogtreecommitdiff
path: root/tests/aggregation/tests.py
diff options
context:
space:
mode:
authorChris Muthig <camuthig@gmail.com>2024-12-22 16:30:55 +0100
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2025-03-03 11:37:00 +0100
commit4b977a5d7283e7ca51288cc0ed0860e0004653ca (patch)
tree0e9473c86dadb67af745171084564d9359054564 /tests/aggregation/tests.py
parent6d1cf5375f6fbc1496095d2356357c3b08a46324 (diff)
Fixed #35444 -- Added generic support for Aggregate.order_by.
This moves the behaviors of `order_by` used in Postgres aggregates into the `Aggregate` class. This allows for creating aggregate functions that support this behavior across all database engines. This is shown by moving the `StringAgg` class into the shared `aggregates` module and adding support for all databases. The Postgres `StringAgg` class is now a thin wrapper on the new shared `StringAgg` class. Thank you Simon Charette for the review.
Diffstat (limited to 'tests/aggregation/tests.py')
-rw-r--r--tests/aggregation/tests.py231
1 files changed, 220 insertions, 11 deletions
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py
index 86151000b5..8caefc060c 100644
--- a/tests/aggregation/tests.py
+++ b/tests/aggregation/tests.py
@@ -4,10 +4,11 @@ import re
from decimal import Decimal
from django.core.exceptions import FieldError
-from django.db import connection
+from django.db import NotSupportedError, connection
from django.db.models import (
Avg,
Case,
+ CharField,
Count,
DateField,
DateTimeField,
@@ -22,6 +23,7 @@ from django.db.models import (
OuterRef,
Q,
StdDev,
+ StringAgg,
Subquery,
Sum,
TimeField,
@@ -32,9 +34,11 @@ from django.db.models import (
Window,
)
from django.db.models.expressions import Func, RawSQL
+from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import (
Cast,
Coalesce,
+ Concat,
Greatest,
Least,
Lower,
@@ -45,11 +49,11 @@ from django.db.models.functions import (
TruncHour,
)
from django.test import TestCase
-from django.test.testcases import skipUnlessDBFeature
+from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext
from django.utils import timezone
-from .models import Author, Book, Publisher, Store
+from .models import Author, Book, Employee, Publisher, Store
class NowUTC(Now):
@@ -566,6 +570,28 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(books["ratings"], expected_result)
+ @skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument")
+ def test_distinct_on_stringagg(self):
+ books = Book.objects.aggregate(
+ ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True)
+ )
+ self.assertEqual(books["ratings"], "3,4,4.5,5")
+
+ @skipIfDBFeature("supports_aggregate_distinct_multiple_argument")
+ def test_raises_error_on_multiple_argument_distinct(self):
+ message = (
+ "StringAgg does not support distinct with multiple expressions on this "
+ "database backend."
+ )
+ with self.assertRaisesMessage(NotSupportedError, message):
+ Book.objects.aggregate(
+ ratings=StringAgg(
+ Cast(F("rating"), CharField()),
+ Value(","),
+ distinct=True,
+ )
+ )
+
def test_non_grouped_annotation_not_in_group_by(self):
"""
An annotation not included in values() before an aggregate should be
@@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase):
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
def test_multi_arg_aggregate(self):
- class MyMax(Max):
+ class MultiArgAgg(Max):
output_field = DecimalField()
arity = None
- def as_sql(self, compiler, connection):
+ def as_sql(self, compiler, connection, **extra_context):
copy = self.copy()
- copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
- return super(MyMax, copy).as_sql(compiler, connection)
+ # Most database backends do not support compiling multiple arguments on
+ # the Max aggregate, and that isn't what is being tested here anyway. To
+ # avoid errors, the extra argument is just dropped.
+ copy.set_source_expressions(
+ copy.get_source_expressions()[0:1] + [None, None]
+ )
+
+ return super(MultiArgAgg, copy).as_sql(compiler, connection)
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
- Book.objects.aggregate(MyMax("pages", "price"))
+ Book.objects.aggregate(MultiArgAgg("pages", "price"))
with self.assertRaisesMessage(
TypeError, "Complex annotations require an alias"
):
- Book.objects.annotate(MyMax("pages", "price"))
+ Book.objects.annotate(MultiArgAgg("pages", "price"))
- Book.objects.aggregate(max_field=MyMax("pages", "price"))
+ Book.objects.aggregate(max_field=MultiArgAgg("pages", "price"))
def test_add_implementation(self):
class MySum(Sum):
@@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase):
"function": self.function.lower(),
"expressions": sql,
"distinct": "",
+ "filter": "",
+ "order_by": "",
}
substitutions.update(self.extra)
return self.template % substitutions, params
@@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase):
# test overriding all parts of the template
def be_evil(self, compiler, connection):
- substitutions = {"function": "MAX", "expressions": "2", "distinct": ""}
+ substitutions = {
+ "function": "MAX",
+ "expressions": "2",
+ "distinct": "",
+ "filter": "",
+ "order_by": "",
+ }
substitutions.update(self.extra)
return self.template % substitutions, ()
@@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase):
Publisher.objects.none().aggregate(
sum_awards=Sum("num_awards"),
books_count=Count("book"),
+ all_names=StringAgg("name", Value(",")),
),
{
"sum_awards": None,
"books_count": 0,
+ "all_names": None,
},
)
# Expression without empty_result_set_value forces queries to be
@@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(result["value"], 35)
+ def test_stringagg_default_value(self):
+ result = Author.objects.filter(age__gt=100).aggregate(
+ value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")),
+ )
+ self.assertEqual(result["value"], "<empty>")
+
def test_aggregation_default_group_by(self):
qs = (
Publisher.objects.values("name")
@@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase):
with self.assertRaisesMessage(TypeError, msg):
super(function, func_instance).__init__(Value(1), Value(2))
+ def test_string_agg_requires_delimiter(self):
+ with self.assertRaises(TypeError):
+ Book.objects.aggregate(stringagg=StringAgg("name"))
+
+ def test_string_agg_escapes_delimiter(self):
+ values = Publisher.objects.aggregate(
+ stringagg=StringAgg("name", delimiter=Value("'"))
+ )
+
+ self.assertEqual(
+ values,
+ {
+ "stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House "
+ "of Books",
+ },
+ )
+
+ @skipUnlessDBFeature("supports_aggregate_order_by_clause")
+ def test_string_agg_order_by(self):
+ order_by_test_cases = (
+ (
+ F("original_opening").desc(),
+ "Books.com;Amazon.com;Mamma and Pappa's Books",
+ ),
+ (
+ F("original_opening").asc(),
+ "Mamma and Pappa's Books;Amazon.com;Books.com",
+ ),
+ (F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"),
+ ("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"),
+ ("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"),
+ (
+ Concat("original_opening", Value("@")),
+ "Mamma and Pappa's Books;Amazon.com;Books.com",
+ ),
+ (
+ Concat("original_opening", Value("@")).desc(),
+ "Books.com;Amazon.com;Mamma and Pappa's Books",
+ ),
+ )
+ for order_by, expected_output in order_by_test_cases:
+ with self.subTest(order_by=order_by, expected_output=expected_output):
+ values = Store.objects.aggregate(
+ stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by)
+ )
+ self.assertEqual(values, {"stringagg": expected_output})
+
+ @skipIfDBFeature("supports_aggregate_order_by_clause")
+ def test_string_agg_order_by_is_not_supported(self):
+ message = (
+ "This database backend does not support specifying an order on aggregates."
+ )
+ with self.assertRaisesMessage(NotSupportedError, message):
+ Store.objects.aggregate(
+ stringagg=StringAgg(
+ "name",
+ delimiter=Value(";"),
+ order_by="original_opening",
+ )
+ )
+
+ def test_string_agg_filter(self):
+ values = Book.objects.aggregate(
+ stringagg=StringAgg(
+ "name",
+ delimiter=Value(";"),
+ filter=Q(name__startswith="P"),
+ )
+ )
+
+ expected_values = {
+ "stringagg": "Practical Django Projects;"
+ "Python Web Development with Django;Paradigms of Artificial "
+ "Intelligence Programming: Case Studies in Common Lisp",
+ }
+ self.assertEqual(values, expected_values)
+
+ @skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause")
+ def test_string_agg_jsonfield_order_by(self):
+ Employee.objects.bulk_create(
+ [
+ Employee(work_day_preferences={"Monday": "morning"}),
+ Employee(work_day_preferences={"Monday": "afternoon"}),
+ ]
+ )
+ values = Employee.objects.aggregate(
+ stringagg=StringAgg(
+ KeyTextTransform("Monday", "work_day_preferences"),
+ delimiter=Value(","),
+ order_by=KeyTextTransform(
+ "Monday",
+ "work_day_preferences",
+ ),
+ output_field=CharField(),
+ ),
+ )
+ self.assertEqual(values, {"stringagg": "afternoon,morning"})
+
+ def test_string_agg_filter_in_subquery(self):
+ aggregate = StringAgg(
+ "authors__name",
+ delimiter=Value(";"),
+ filter=~Q(authors__name__startswith="J"),
+ )
+ subquery = (
+ Book.objects.filter(
+ pk=OuterRef("pk"),
+ )
+ .annotate(agg=aggregate)
+ .values("agg")
+ )
+ values = list(
+ Book.objects.annotate(
+ agg=Subquery(subquery),
+ ).values_list("agg", flat=True)
+ )
+
+ expected_values = [
+ "Adrian Holovaty",
+ "Brad Dayley",
+ "Paul Bissex;Wesley J. Chun",
+ "Peter Norvig;Stuart Russell",
+ "Peter Norvig",
+ "" if connection.features.interprets_empty_strings_as_nulls else None,
+ ]
+
+ self.assertQuerySetEqual(expected_values, values, ordered=False)
+
+ @skipUnlessDBFeature("supports_aggregate_order_by_clause")
+ def test_order_by_in_subquery(self):
+ aggregate = StringAgg(
+ "authors__name",
+ delimiter=Value(";"),
+ order_by="authors__name",
+ )
+ subquery = (
+ Book.objects.filter(
+ pk=OuterRef("pk"),
+ )
+ .annotate(agg=aggregate)
+ .values("agg")
+ )
+ values = list(
+ Book.objects.annotate(
+ agg=Subquery(subquery),
+ )
+ .order_by("agg")
+ .values_list("agg", flat=True)
+ )
+
+ expected_values = [
+ "Adrian Holovaty;Jacob Kaplan-Moss",
+ "Brad Dayley",
+ "James Bennett",
+ "Jeffrey Forcier;Paul Bissex;Wesley J. Chun",
+ "Peter Norvig",
+ "Peter Norvig;Stuart Russell",
+ ]
+
+ self.assertEqual(expected_values, values)
+
class AggregateAnnotationPruningTests(TestCase):
@classmethod