summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFloris den Hengst <florisdenhengst@gmail.com>2016-07-05 11:47:24 +0200
committerTim Graham <timograham@gmail.com>2018-06-28 20:29:33 -0400
commit96199e562dcc409ab4bdc2b2146fa7cf73c7c5fe (patch)
treecb047cbe692bfe22a10b25b1336730079ec027bf
parent2a0116266c4d81bd1cc4e3ea20efe9a7874f481b (diff)
Fixed #26067 -- Added ordering support to ArrayAgg and StringAgg.
-rw-r--r--django/contrib/postgres/aggregates/general.py10
-rw-r--r--django/contrib/postgres/aggregates/mixins.py47
-rw-r--r--docs/ref/contrib/postgres/aggregates.txt31
-rw-r--r--docs/releases/2.2.txt5
-rw-r--r--tests/postgres_tests/test_aggregates.py65
5 files changed, 146 insertions, 12 deletions
diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py
index 806ecd1b78..4b2da0b101 100644
--- a/django/contrib/postgres/aggregates/general.py
+++ b/django/contrib/postgres/aggregates/general.py
@@ -1,14 +1,16 @@
from django.contrib.postgres.fields import ArrayField, JSONField
from django.db.models.aggregates import Aggregate
+from .mixins import OrderableAggMixin
+
__all__ = [
'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'JSONBAgg', 'StringAgg',
]
-class ArrayAgg(Aggregate):
+class ArrayAgg(OrderableAggMixin, Aggregate):
function = 'ARRAY_AGG'
- template = '%(function)s(%(distinct)s%(expressions)s)'
+ template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
@property
def output_field(self):
@@ -49,9 +51,9 @@ class JSONBAgg(Aggregate):
return value
-class StringAgg(Aggregate):
+class StringAgg(OrderableAggMixin, Aggregate):
function = 'STRING_AGG'
- template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s')"
+ template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
def __init__(self, expression, delimiter, distinct=False, **extra):
distinct = 'DISTINCT ' if distinct else ''
diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py
new file mode 100644
index 0000000000..b270a5b653
--- /dev/null
+++ b/django/contrib/postgres/aggregates/mixins.py
@@ -0,0 +1,47 @@
+from django.db.models.expressions import F, OrderBy
+
+
+class OrderableAggMixin:
+
+ def __init__(self, expression, ordering=(), **extra):
+ if not isinstance(ordering, (list, tuple)):
+ ordering = [ordering]
+ ordering = ordering or []
+ # Transform minus sign prefixed strings into an OrderBy() expression.
+ ordering = (
+ (OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o)
+ for o in ordering
+ )
+ super().__init__(expression, **extra)
+ self.ordering = self._parse_expressions(*ordering)
+
+ def resolve_expression(self, *args, **kwargs):
+ self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering]
+ return super().resolve_expression(*args, **kwargs)
+
+ def as_sql(self, compiler, connection):
+ if self.ordering:
+ self.extra['ordering'] = 'ORDER BY ' + ', '.join((
+ ordering_element.as_sql(compiler, connection)[0]
+ for ordering_element in self.ordering
+ ))
+ else:
+ self.extra['ordering'] = ''
+ return super().as_sql(compiler, connection)
+
+ def get_source_expressions(self):
+ return self.source_expressions + self.ordering
+
+ def get_source_fields(self):
+ # Filter out fields contributed by the ordering expressions as
+ # these should not be used to determine which the return type of the
+ # expression.
+ return [
+ e._output_field_or_none
+ for e in self.get_source_expressions()[:self._get_ordering_expressions_index()]
+ ]
+
+ def _get_ordering_expressions_index(self):
+ """Return the index at which the ordering expressions start."""
+ source_expressions = self.get_source_expressions()
+ return len(source_expressions) - len(self.ordering)
diff --git a/docs/ref/contrib/postgres/aggregates.txt b/docs/ref/contrib/postgres/aggregates.txt
index 480c230c40..a605bc831c 100644
--- a/docs/ref/contrib/postgres/aggregates.txt
+++ b/docs/ref/contrib/postgres/aggregates.txt
@@ -22,7 +22,7 @@ General-purpose aggregation functions
``ArrayAgg``
------------
-.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
+.. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra)
Returns a list of values, including nulls, concatenated into an array.
@@ -31,6 +31,22 @@ General-purpose aggregation functions
An optional boolean argument that determines if array values
will be distinct. Defaults to ``False``.
+ .. attribute:: ordering
+
+ .. versionadded:: 2.2
+
+ An optional string of a field name (with an optional ``"-"`` prefix
+ which indicates descending order) or an expression (or a tuple or list
+ of strings and/or expressions) that specifies the ordering of the
+ elements in the result list.
+
+ Examples::
+
+ 'some_field'
+ '-some_field'
+ from django.db.models import F
+ F('some_field').desc()
+
``BitAnd``
----------
@@ -73,7 +89,7 @@ General-purpose aggregation functions
``StringAgg``
-------------
-.. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
+.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=())
Returns the input values concatenated into a string, separated by
the ``delimiter`` string.
@@ -87,6 +103,17 @@ General-purpose aggregation functions
An optional boolean argument that determines if concatenated values
will be distinct. Defaults to ``False``.
+ .. attribute:: ordering
+
+ .. versionadded:: 2.2
+
+ An optional string of a field name (with an optional ``"-"`` prefix
+ which indicates descending order) or an expression (or a tuple or list
+ of strings and/or expressions) that specifies the ordering of the
+ elements in the result string.
+
+ Examples are the same as for :attr:`ArrayAgg.ordering`.
+
Aggregate functions for statistics
==================================
diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt
index 840d4b4d0d..742f4893be 100644
--- a/docs/releases/2.2.txt
+++ b/docs/releases/2.2.txt
@@ -70,7 +70,10 @@ Minor features
:mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-* ...
+* The new ``ordering`` argument for
+ :class:`~django.contrib.postgres.aggregates.ArrayAgg` and
+ :class:`~django.contrib.postgres.aggregates.StringAgg` determines the
+ ordering of the aggregated elements.
:mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py
index d4a01ff027..85d6f45fd1 100644
--- a/tests/postgres_tests/test_aggregates.py
+++ b/tests/postgres_tests/test_aggregates.py
@@ -22,21 +22,57 @@ class TestGeneralAggregate(PostgreSQLTestCase):
def setUpTestData(cls):
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0)
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1)
- AggregateTestModel.objects.create(boolean_field=False, char_field='Foo3', integer_field=2)
- AggregateTestModel.objects.create(boolean_field=True, char_field='Foo4', integer_field=0)
+ AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2)
+ AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0)
def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
- self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']})
+ self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
+
+ def test_array_agg_charfield_ordering(self):
+ ordering_test_cases = (
+ (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
+ (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
+ (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
+ ([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
+ ((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
+ ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
+ ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
+ )
+ for ordering, expected_output in ordering_test_cases:
+ with self.subTest(ordering=ordering, expected_output=expected_output):
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg('char_field', ordering=ordering)
+ )
+ self.assertEqual(values, {'arrayagg': expected_output})
def test_array_agg_integerfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field'))
self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]})
+ def test_array_agg_integerfield_ordering(self):
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc())
+ )
+ self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]})
+
def test_array_agg_booleanfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
self.assertEqual(values, {'arrayagg': [True, False, False, True]})
+ def test_array_agg_booleanfield_ordering(self):
+ ordering_test_cases = (
+ (F('boolean_field').asc(), [False, False, True, True]),
+ (F('boolean_field').desc(), [True, True, False, False]),
+ (F('boolean_field'), [False, False, True, True]),
+ )
+ for ordering, expected_output in ordering_test_cases:
+ with self.subTest(ordering=ordering, expected_output=expected_output):
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg('boolean_field', ordering=ordering)
+ )
+ self.assertEqual(values, {'arrayagg': expected_output})
+
def test_array_agg_empty_result(self):
AggregateTestModel.objects.all().delete()
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
@@ -122,17 +158,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
def test_string_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
- self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo3;Foo4'})
+ self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'})
+
+ def test_string_agg_charfield_ordering(self):
+ ordering_test_cases = (
+ (F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'),
+ (F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'),
+ (F('char_field'), 'Foo1;Foo2;Foo3;Foo4'),
+ )
+ for ordering, expected_output in ordering_test_cases:
+ with self.subTest(ordering=ordering, expected_output=expected_output):
+ values = AggregateTestModel.objects.aggregate(
+ stringagg=StringAgg('char_field', delimiter=';', ordering=ordering)
+ )
+ self.assertEqual(values, {'stringagg': expected_output})
def test_string_agg_empty_result(self):
AggregateTestModel.objects.all().delete()
values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
self.assertEqual(values, {'stringagg': ''})
+ def test_orderable_agg_alternative_fields(self):
+ values = AggregateTestModel.objects.aggregate(
+ arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc())
+ )
+ self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]})
+
@skipUnlessDBFeature('has_jsonb_agg')
def test_json_agg(self):
values = AggregateTestModel.objects.aggregate(jsonagg=JSONBAgg('char_field'))
- self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']})
+ self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
@skipUnlessDBFeature('has_jsonb_agg')
def test_json_agg_empty(self):