summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBendeguz Csirmaz <csirmazbendeguz@gmail.com>2024-10-15 01:31:27 +0800
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-11-04 09:20:54 +0100
commitf7601aed515a125cde776ebbf6ff6e8432cbafdb (patch)
tree0c40dd6bb13f16dc7b080611602219724796fd7b
parent611bf6c2e2a1b4ab93273980c45150c099ab146d (diff)
Refs #373 -- Added TupleIn subqueries.
-rw-r--r--django/db/models/fields/tuple_lookups.py41
-rw-r--r--tests/foreign_object/test_tuple_lookups.py41
2 files changed, 79 insertions, 3 deletions
diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py
index a94582db95..6342937cd6 100644
--- a/django/db/models/fields/tuple_lookups.py
+++ b/django/db/models/fields/tuple_lookups.py
@@ -12,6 +12,7 @@ from django.db.models.lookups import (
LessThan,
LessThanOrEqual,
)
+from django.db.models.sql import Query
from django.db.models.sql.where import AND, OR, WhereNode
@@ -211,9 +212,14 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
class TupleIn(TupleLookupMixin, In):
def get_prep_lookup(self):
- self.check_rhs_is_tuple_or_list()
- self.check_rhs_is_collection_of_tuples_or_lists()
- self.check_rhs_elements_length_equals_lhs_length()
+ if self.rhs_is_direct_value():
+ self.check_rhs_is_tuple_or_list()
+ self.check_rhs_is_collection_of_tuples_or_lists()
+ self.check_rhs_elements_length_equals_lhs_length()
+ else:
+ self.check_rhs_is_query()
+ self.check_rhs_select_length_equals_lhs_length()
+
return self.rhs # skip checks from mixin
def check_rhs_is_collection_of_tuples_or_lists(self):
@@ -233,6 +239,25 @@ class TupleIn(TupleLookupMixin, In):
f"must have {len_lhs} elements each"
)
+ def check_rhs_is_query(self):
+ if not isinstance(self.rhs, Query):
+ lhs_str = self.get_lhs_str()
+ rhs_cls = self.rhs.__class__.__name__
+ raise ValueError(
+ f"{self.lookup_name!r} subquery lookup of {lhs_str} "
+ f"must be a Query object (received {rhs_cls!r})"
+ )
+
+ def check_rhs_select_length_equals_lhs_length(self):
+ len_rhs = len(self.rhs.select)
+ len_lhs = len(self.lhs)
+ if len_rhs != len_lhs:
+ lhs_str = self.get_lhs_str()
+ raise ValueError(
+ f"{self.lookup_name!r} subquery lookup of {lhs_str} "
+ f"must have {len_lhs} fields (received {len_rhs})"
+ )
+
def process_rhs(self, compiler, connection):
rhs = self.rhs
if not rhs:
@@ -255,10 +280,17 @@ class TupleIn(TupleLookupMixin, In):
return Tuple(*result).as_sql(compiler, connection)
+ def as_sql(self, compiler, connection):
+ if not self.rhs_is_direct_value():
+ return self.as_subquery(compiler, connection)
+ return super().as_sql(compiler, connection)
+
def as_sqlite(self, compiler, connection):
rhs = self.rhs
if not rhs:
raise EmptyResultSet
+ if not self.rhs_is_direct_value():
+ return self.as_subquery(compiler, connection)
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
@@ -271,6 +303,9 @@ class TupleIn(TupleLookupMixin, In):
return root.as_sql(compiler, connection)
+ def as_subquery(self, compiler, connection):
+ return compiler.compile(In(self.lhs, self.rhs))
+
tuple_lookups = {
"exact": TupleExact,
diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py
index 499329e7ca..797fea1c8a 100644
--- a/tests/foreign_object/test_tuple_lookups.py
+++ b/tests/foreign_object/test_tuple_lookups.py
@@ -11,6 +11,7 @@ from django.db.models.fields.tuple_lookups import (
TupleLessThan,
TupleLessThanOrEqual,
)
+from django.db.models.lookups import In
from django.test import TestCase, skipUnlessDBFeature
from .models import Contact, Customer
@@ -126,6 +127,46 @@ class TupleLookupsTests(TestCase):
(self.contact_1, self.contact_2, self.contact_5),
)
+ def test_tuple_in_subquery_must_be_query(self):
+ lhs = (F("customer_code"), F("company_code"))
+ # If rhs is any non-Query object with an as_sql() function.
+ rhs = In(F("customer_code"), [1, 2, 3])
+ with self.assertRaisesMessage(
+ ValueError,
+ "'in' subquery lookup of ('customer_code', 'company_code') "
+ "must be a Query object (received 'In')",
+ ):
+ TupleIn(lhs, rhs)
+
+ def test_tuple_in_subquery_must_have_2_fields(self):
+ lhs = (F("customer_code"), F("company_code"))
+ rhs = Customer.objects.values_list("customer_id").query
+ with self.assertRaisesMessage(
+ ValueError,
+ "'in' subquery lookup of ('customer_code', 'company_code') "
+ "must have 2 fields (received 1)",
+ ):
+ TupleIn(lhs, rhs)
+
+ def test_tuple_in_subquery(self):
+ customers = Customer.objects.values_list("customer_id", "company")
+ test_cases = (
+ (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
+ (self.customer_2, (self.contact_3,)),
+ (self.customer_3, (self.contact_4,)),
+ (self.customer_4, ()),
+ (self.customer_5, (self.contact_6,)),
+ )
+
+ for customer, contacts in test_cases:
+ lhs = (F("customer_code"), F("company_code"))
+ rhs = customers.filter(id=customer.id).query
+ lookup = TupleIn(lhs, rhs)
+ qs = Contact.objects.filter(lookup).order_by("id")
+
+ with self.subTest(customer=customer.id, query=str(qs.query)):
+ self.assertSequenceEqual(qs, contacts)
+
def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
test_cases = (
(1, 2, 3),