diff options
| author | Chris Muthig <camuthig@gmail.com> | 2024-12-22 16:30:55 +0100 |
|---|---|---|
| committer | Sarah Boyce <42296566+sarahboyce@users.noreply.github.com> | 2025-03-03 11:37:00 +0100 |
| commit | 4b977a5d7283e7ca51288cc0ed0860e0004653ca (patch) | |
| tree | 0e9473c86dadb67af745171084564d9359054564 /tests/aggregation/tests.py | |
| parent | 6d1cf5375f6fbc1496095d2356357c3b08a46324 (diff) | |
Fixed #35444 -- Added generic support for Aggregate.order_by.
This moves the behaviors of `order_by` used in Postgres aggregates into
the `Aggregate` class. This allows for creating aggregate functions that
support this behavior across all database engines. This is shown by
moving the `StringAgg` class into the shared `aggregates` module and
adding support for all databases. The Postgres `StringAgg` class is now
a thin wrapper on the new shared `StringAgg` class.
Thank you Simon Charette for the review.
Diffstat (limited to 'tests/aggregation/tests.py')
| -rw-r--r-- | tests/aggregation/tests.py | 231 |
1 files changed, 220 insertions, 11 deletions
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 86151000b5..8caefc060c 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -4,10 +4,11 @@ import re from decimal import Decimal from django.core.exceptions import FieldError -from django.db import connection +from django.db import NotSupportedError, connection from django.db.models import ( Avg, Case, + CharField, Count, DateField, DateTimeField, @@ -22,6 +23,7 @@ from django.db.models import ( OuterRef, Q, StdDev, + StringAgg, Subquery, Sum, TimeField, @@ -32,9 +34,11 @@ from django.db.models import ( Window, ) from django.db.models.expressions import Func, RawSQL +from django.db.models.fields.json import KeyTextTransform from django.db.models.functions import ( Cast, Coalesce, + Concat, Greatest, Least, Lower, @@ -45,11 +49,11 @@ from django.db.models.functions import ( TruncHour, ) from django.test import TestCase -from django.test.testcases import skipUnlessDBFeature +from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature from django.test.utils import Approximate, CaptureQueriesContext from django.utils import timezone -from .models import Author, Book, Publisher, Store +from .models import Author, Book, Employee, Publisher, Store class NowUTC(Now): @@ -566,6 +570,28 @@ class AggregateTestCase(TestCase): ) self.assertEqual(books["ratings"], expected_result) + @skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument") + def test_distinct_on_stringagg(self): + books = Book.objects.aggregate( + ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True) + ) + self.assertEqual(books["ratings"], "3,4,4.5,5") + + @skipIfDBFeature("supports_aggregate_distinct_multiple_argument") + def test_raises_error_on_multiple_argument_distinct(self): + message = ( + "StringAgg does not support distinct with multiple expressions on this " + "database backend." + ) + with self.assertRaisesMessage(NotSupportedError, message): + Book.objects.aggregate( + ratings=StringAgg( + Cast(F("rating"), CharField()), + Value(","), + distinct=True, + ) + ) + def test_non_grouped_annotation_not_in_group_by(self): """ An annotation not included in values() before an aggregate should be @@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase): Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price")) def test_multi_arg_aggregate(self): - class MyMax(Max): + class MultiArgAgg(Max): output_field = DecimalField() arity = None - def as_sql(self, compiler, connection): + def as_sql(self, compiler, connection, **extra_context): copy = self.copy() - copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None]) - return super(MyMax, copy).as_sql(compiler, connection) + # Most database backends do not support compiling multiple arguments on + # the Max aggregate, and that isn't what is being tested here anyway. To + # avoid errors, the extra argument is just dropped. + copy.set_source_expressions( + copy.get_source_expressions()[0:1] + [None, None] + ) + + return super(MultiArgAgg, copy).as_sql(compiler, connection) with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"): - Book.objects.aggregate(MyMax("pages", "price")) + Book.objects.aggregate(MultiArgAgg("pages", "price")) with self.assertRaisesMessage( TypeError, "Complex annotations require an alias" ): - Book.objects.annotate(MyMax("pages", "price")) + Book.objects.annotate(MultiArgAgg("pages", "price")) - Book.objects.aggregate(max_field=MyMax("pages", "price")) + Book.objects.aggregate(max_field=MultiArgAgg("pages", "price")) def test_add_implementation(self): class MySum(Sum): @@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase): "function": self.function.lower(), "expressions": sql, "distinct": "", + "filter": "", + "order_by": "", } substitutions.update(self.extra) return self.template % substitutions, params @@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase): # test overriding all parts of the template def be_evil(self, compiler, connection): - substitutions = {"function": "MAX", "expressions": "2", "distinct": ""} + substitutions = { + "function": "MAX", + "expressions": "2", + "distinct": "", + "filter": "", + "order_by": "", + } substitutions.update(self.extra) return self.template % substitutions, () @@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase): Publisher.objects.none().aggregate( sum_awards=Sum("num_awards"), books_count=Count("book"), + all_names=StringAgg("name", Value(",")), ), { "sum_awards": None, "books_count": 0, + "all_names": None, }, ) # Expression without empty_result_set_value forces queries to be @@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase): ) self.assertEqual(result["value"], 35) + def test_stringagg_default_value(self): + result = Author.objects.filter(age__gt=100).aggregate( + value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")), + ) + self.assertEqual(result["value"], "<empty>") + def test_aggregation_default_group_by(self): qs = ( Publisher.objects.values("name") @@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase): with self.assertRaisesMessage(TypeError, msg): super(function, func_instance).__init__(Value(1), Value(2)) + def test_string_agg_requires_delimiter(self): + with self.assertRaises(TypeError): + Book.objects.aggregate(stringagg=StringAgg("name")) + + def test_string_agg_escapes_delimiter(self): + values = Publisher.objects.aggregate( + stringagg=StringAgg("name", delimiter=Value("'")) + ) + + self.assertEqual( + values, + { + "stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House " + "of Books", + }, + ) + + @skipUnlessDBFeature("supports_aggregate_order_by_clause") + def test_string_agg_order_by(self): + order_by_test_cases = ( + ( + F("original_opening").desc(), + "Books.com;Amazon.com;Mamma and Pappa's Books", + ), + ( + F("original_opening").asc(), + "Mamma and Pappa's Books;Amazon.com;Books.com", + ), + (F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"), + ("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"), + ("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"), + ( + Concat("original_opening", Value("@")), + "Mamma and Pappa's Books;Amazon.com;Books.com", + ), + ( + Concat("original_opening", Value("@")).desc(), + "Books.com;Amazon.com;Mamma and Pappa's Books", + ), + ) + for order_by, expected_output in order_by_test_cases: + with self.subTest(order_by=order_by, expected_output=expected_output): + values = Store.objects.aggregate( + stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by) + ) + self.assertEqual(values, {"stringagg": expected_output}) + + @skipIfDBFeature("supports_aggregate_order_by_clause") + def test_string_agg_order_by_is_not_supported(self): + message = ( + "This database backend does not support specifying an order on aggregates." + ) + with self.assertRaisesMessage(NotSupportedError, message): + Store.objects.aggregate( + stringagg=StringAgg( + "name", + delimiter=Value(";"), + order_by="original_opening", + ) + ) + + def test_string_agg_filter(self): + values = Book.objects.aggregate( + stringagg=StringAgg( + "name", + delimiter=Value(";"), + filter=Q(name__startswith="P"), + ) + ) + + expected_values = { + "stringagg": "Practical Django Projects;" + "Python Web Development with Django;Paradigms of Artificial " + "Intelligence Programming: Case Studies in Common Lisp", + } + self.assertEqual(values, expected_values) + + @skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause") + def test_string_agg_jsonfield_order_by(self): + Employee.objects.bulk_create( + [ + Employee(work_day_preferences={"Monday": "morning"}), + Employee(work_day_preferences={"Monday": "afternoon"}), + ] + ) + values = Employee.objects.aggregate( + stringagg=StringAgg( + KeyTextTransform("Monday", "work_day_preferences"), + delimiter=Value(","), + order_by=KeyTextTransform( + "Monday", + "work_day_preferences", + ), + output_field=CharField(), + ), + ) + self.assertEqual(values, {"stringagg": "afternoon,morning"}) + + def test_string_agg_filter_in_subquery(self): + aggregate = StringAgg( + "authors__name", + delimiter=Value(";"), + filter=~Q(authors__name__startswith="J"), + ) + subquery = ( + Book.objects.filter( + pk=OuterRef("pk"), + ) + .annotate(agg=aggregate) + .values("agg") + ) + values = list( + Book.objects.annotate( + agg=Subquery(subquery), + ).values_list("agg", flat=True) + ) + + expected_values = [ + "Adrian Holovaty", + "Brad Dayley", + "Paul Bissex;Wesley J. Chun", + "Peter Norvig;Stuart Russell", + "Peter Norvig", + "" if connection.features.interprets_empty_strings_as_nulls else None, + ] + + self.assertQuerySetEqual(expected_values, values, ordered=False) + + @skipUnlessDBFeature("supports_aggregate_order_by_clause") + def test_order_by_in_subquery(self): + aggregate = StringAgg( + "authors__name", + delimiter=Value(";"), + order_by="authors__name", + ) + subquery = ( + Book.objects.filter( + pk=OuterRef("pk"), + ) + .annotate(agg=aggregate) + .values("agg") + ) + values = list( + Book.objects.annotate( + agg=Subquery(subquery), + ) + .order_by("agg") + .values_list("agg", flat=True) + ) + + expected_values = [ + "Adrian Holovaty;Jacob Kaplan-Moss", + "Brad Dayley", + "James Bennett", + "Jeffrey Forcier;Paul Bissex;Wesley J. Chun", + "Peter Norvig", + "Peter Norvig;Stuart Russell", + ] + + self.assertEqual(expected_values, values) + class AggregateAnnotationPruningTests(TestCase): @classmethod |
