diff options
Diffstat (limited to 'tests/expressions_window/tests.py')
| -rw-r--r-- | tests/expressions_window/tests.py | 310 |
1 files changed, 280 insertions, 30 deletions
diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index 15f8a4d6b2..a71a3f947d 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -6,10 +6,9 @@ from django.core.exceptions import FieldError from django.db import NotSupportedError, connection from django.db.models import ( Avg, - BooleanField, Case, + Count, F, - Func, IntegerField, Max, Min, @@ -41,15 +40,17 @@ from django.db.models.functions import ( RowNumber, Upper, ) +from django.db.models.lookups import Exact from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature -from .models import Detail, Employee +from .models import Classification, Detail, Employee, PastEmployeeDepartment @skipUnlessDBFeature("supports_over_clause") class WindowFunctionTests(TestCase): @classmethod def setUpTestData(cls): + classification = Classification.objects.create() Employee.objects.bulk_create( [ Employee( @@ -59,6 +60,7 @@ class WindowFunctionTests(TestCase): hire_date=e[3], age=e[4], bonus=Decimal(e[1]) / 400, + classification=classification, ) for e in [ ("Jones", 45000, "Accounting", datetime.datetime(2005, 11, 1), 20), @@ -82,6 +84,13 @@ class WindowFunctionTests(TestCase): ] ] ) + employees = list(Employee.objects.order_by("pk")) + PastEmployeeDepartment.objects.bulk_create( + [ + PastEmployeeDepartment(employee=employees[6], department="Sales"), + PastEmployeeDepartment(employee=employees[10], department="IT"), + ] + ) def test_dense_rank(self): tests = [ @@ -902,6 +911,263 @@ class WindowFunctionTests(TestCase): ) self.assertEqual(qs.count(), 12) + def test_filter(self): + qs = Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ).order_by("department", "name") + # Direct window reference. + self.assertQuerysetEqual( + qs.filter(department_salary_rank=1), + ["Adams", "Wilkinson", "Miller", "Johnson", "Smith"], + lambda employee: employee.name, + ) + # Through a combined expression containing a window. + self.assertQuerysetEqual( + qs.filter(department_avg_age_diff__gt=0), + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + # Intersection of multiple windows. + self.assertQuerysetEqual( + qs.filter(department_salary_rank=1, department_avg_age_diff__gt=0), + ["Miller"], + lambda employee: employee.name, + ) + # Union of multiple windows. + self.assertQuerysetEqual( + qs.filter(Q(department_salary_rank=1) | Q(department_avg_age_diff__gt=0)), + [ + "Adams", + "Jenson", + "Jones", + "Williams", + "Wilkinson", + "Miller", + "Johnson", + "Smith", + "Smith", + ], + lambda employee: employee.name, + ) + + def test_filter_conditional_annotation(self): + qs = ( + Employee.objects.annotate( + rank=Window(Rank(), partition_by="department", order_by="-salary"), + case_first_rank=Case( + When(rank=1, then=True), + default=False, + ), + q_first_rank=Q(rank=1), + ) + .order_by("name") + .values_list("name", flat=True) + ) + for annotation in ["case_first_rank", "q_first_rank"]: + with self.subTest(annotation=annotation): + self.assertSequenceEqual( + qs.filter(**{annotation: True}), + ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"], + ) + + def test_filter_conditional_expression(self): + qs = ( + Employee.objects.filter( + Exact(Window(Rank(), partition_by="department", order_by="-salary"), 1) + ) + .order_by("name") + .values_list("name", flat=True) + ) + self.assertSequenceEqual( + qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"] + ) + + def test_filter_column_ref_rhs(self): + qs = ( + Employee.objects.annotate( + max_dept_salary=Window(Max("salary"), partition_by="department") + ) + .filter(max_dept_salary=F("salary")) + .order_by("name") + .values_list("name", flat=True) + ) + self.assertSequenceEqual( + qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"] + ) + + def test_filter_values(self): + qs = ( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + ) + .order_by("department", "name") + .values_list(Upper("name"), flat=True) + ) + self.assertSequenceEqual( + qs.filter(department_salary_rank=1), + ["ADAMS", "WILKINSON", "MILLER", "JOHNSON", "SMITH"], + ) + + def test_filter_alias(self): + qs = Employee.objects.alias( + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ).order_by("department", "name") + self.assertQuerysetEqual( + qs.filter(department_avg_age_diff__gt=0), + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + + def test_filter_select_related(self): + qs = ( + Employee.objects.alias( + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ) + .select_related("classification") + .filter(department_avg_age_diff__gt=0) + .order_by("department", "name") + ) + self.assertQuerysetEqual( + qs, + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + with self.assertNumQueries(0): + qs[0].classification + + def test_exclude(self): + qs = Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + department_avg_age_diff=( + Window(Avg("age"), partition_by="department") - F("age") + ), + ).order_by("department", "name") + # Direct window reference. + self.assertQuerysetEqual( + qs.exclude(department_salary_rank__gt=1), + ["Adams", "Wilkinson", "Miller", "Johnson", "Smith"], + lambda employee: employee.name, + ) + # Through a combined expression containing a window. + self.assertQuerysetEqual( + qs.exclude(department_avg_age_diff__lte=0), + ["Jenson", "Jones", "Williams", "Miller", "Smith"], + lambda employee: employee.name, + ) + # Union of multiple windows. + self.assertQuerysetEqual( + qs.exclude( + Q(department_salary_rank__gt=1) | Q(department_avg_age_diff__lte=0) + ), + ["Miller"], + lambda employee: employee.name, + ) + # Intersection of multiple windows. + self.assertQuerysetEqual( + qs.exclude(department_salary_rank__gt=1, department_avg_age_diff__lte=0), + [ + "Adams", + "Jenson", + "Jones", + "Williams", + "Wilkinson", + "Miller", + "Johnson", + "Smith", + "Smith", + ], + lambda employee: employee.name, + ) + + def test_heterogeneous_filter(self): + qs = ( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ), + ) + .order_by("name") + .values_list("name", flat=True) + ) + # Heterogeneous filter between window function and aggregates pushes + # the WHERE clause to the QUALIFY outer query. + self.assertSequenceEqual( + qs.filter( + department_salary_rank=1, department__in=["Accounting", "Management"] + ), + ["Adams", "Miller"], + ) + self.assertSequenceEqual( + qs.filter( + Q(department_salary_rank=1) + | Q(department__in=["Accounting", "Management"]) + ), + [ + "Adams", + "Jenson", + "Johnson", + "Johnson", + "Jones", + "Miller", + "Smith", + "Wilkinson", + "Williams", + ], + ) + # Heterogeneous filter between window function and aggregates pushes + # the HAVING clause to the QUALIFY outer query. + qs = qs.annotate(past_department_count=Count("past_departments")) + self.assertSequenceEqual( + qs.filter(department_salary_rank=1, past_department_count__gte=1), + ["Johnson", "Miller"], + ) + self.assertSequenceEqual( + qs.filter(Q(department_salary_rank=1) | Q(past_department_count__gte=1)), + ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"], + ) + + def test_limited_filter(self): + """ + A query filtering against a window function have its limit applied + after window filtering takes place. + """ + self.assertQuerysetEqual( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ) + ) + .filter(department_salary_rank=1) + .order_by("department")[0:3], + ["Adams", "Wilkinson", "Miller"], + lambda employee: employee.name, + ) + + def test_filter_count(self): + self.assertEqual( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ) + ) + .filter(department_salary_rank=1) + .count(), + 5, + ) + @skipUnlessDBFeature("supports_frame_range_fixed_distance") def test_range_n_preceding_and_following(self): qs = Employee.objects.annotate( @@ -1071,6 +1337,7 @@ class WindowFunctionTests(TestCase): ), year=ExtractYear("hire_date"), ) + .filter(sum__gte=45000) .values("year", "sum") .distinct("year") .order_by("year") @@ -1081,7 +1348,6 @@ class WindowFunctionTests(TestCase): {"year": 2008, "sum": 45000}, {"year": 2009, "sum": 128000}, {"year": 2011, "sum": 60000}, - {"year": 2012, "sum": 40000}, {"year": 2013, "sum": 84000}, ] for idx, val in zip(range(len(results)), results): @@ -1348,34 +1614,18 @@ class NonQueryWindowTests(SimpleTestCase): frame.window_frame_start_end(None, None, None) def test_invalid_filter(self): - msg = "Window is disallowed in the filter clause" - qs = Employee.objects.annotate(dense_rank=Window(expression=DenseRank())) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(dense_rank__gte=1) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.annotate(inc_rank=F("dense_rank") + Value(1)).filter(inc_rank__gte=1) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(id=F("dense_rank")) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(id=Func("dense_rank", 2, function="div")) - with self.assertRaisesMessage(NotSupportedError, msg): - qs.annotate(total=Sum("dense_rank", filter=Q(name="Jones"))).filter(total=1) - - def test_conditional_annotation(self): + msg = ( + "Heterogeneous disjunctive predicates against window functions are not " + "implemented when performing conditional aggregation." + ) qs = Employee.objects.annotate( - dense_rank=Window(expression=DenseRank()), - ).annotate( - equal=Case( - When(id=F("dense_rank"), then=Value(True)), - default=Value(False), - output_field=BooleanField(), - ), + window=Window(Rank()), + past_dept_cnt=Count("past_departments"), ) - # The SQL standard disallows referencing window functions in the WHERE - # clause. - msg = "Window is disallowed in the filter clause" - with self.assertRaisesMessage(NotSupportedError, msg): - qs.filter(equal=True) + with self.assertRaisesMessage(NotImplementedError, msg): + list(qs.filter(Q(window=1) | Q(department="Accounting"))) + with self.assertRaisesMessage(NotImplementedError, msg): + list(qs.exclude(window=1, department="Accounting")) def test_invalid_order_by(self): msg = ( |
