summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/annotations/models.py8
-rw-r--r--tests/annotations/tests.py21
2 files changed, 29 insertions, 0 deletions
diff --git a/tests/annotations/models.py b/tests/annotations/models.py
index fbb9ca6988..914770d2fe 100644
--- a/tests/annotations/models.py
+++ b/tests/annotations/models.py
@@ -58,3 +58,11 @@ class Company(models.Model):
class Ticket(models.Model):
active_at = models.DateTimeField()
duration = models.DurationField()
+
+
+class JsonModel(models.Model):
+ data = models.JSONField(default=dict, blank=True)
+ id = models.IntegerField(primary_key=True)
+
+ class Meta:
+ required_db_features = {"supports_json_field"}
diff --git a/tests/annotations/tests.py b/tests/annotations/tests.py
index 703847e1dd..29660a827e 100644
--- a/tests/annotations/tests.py
+++ b/tests/annotations/tests.py
@@ -1,7 +1,9 @@
import datetime
from decimal import Decimal
+from unittest import skipUnless
from django.core.exceptions import FieldDoesNotExist, FieldError
+from django.db import connection
from django.db.models import (
BooleanField,
Case,
@@ -15,6 +17,7 @@ from django.db.models import (
FloatField,
Func,
IntegerField,
+ JSONField,
Max,
OuterRef,
Q,
@@ -43,6 +46,7 @@ from .models import (
Company,
DepartmentStore,
Employee,
+ JsonModel,
Publisher,
Store,
Ticket,
@@ -1167,6 +1171,23 @@ class NonAggregateAnnotationTestCase(TestCase):
with self.assertRaisesMessage(ValueError, msg):
Book.objects.annotate(**{crafted_alias: Value(1)})
+ @skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
+ @skipUnlessDBFeature("supports_json_field")
+ def test_set_returning_functions(self):
+ class JSONBPathQuery(Func):
+ function = "jsonb_path_query"
+ output_field = JSONField()
+ set_returning = True
+
+ test_model = JsonModel.objects.create(
+ data={"key": [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}]}, id=1
+ )
+ qs = JsonModel.objects.annotate(
+ table_element=JSONBPathQuery("data", Value("$.key[*]"))
+ ).filter(pk=test_model.pk)
+
+ self.assertEqual(qs.count(), len(qs))
+
class AliasTests(TestCase):
@classmethod