summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBendeguz Csirmaz <csirmazbendeguz@gmail.com>2024-07-23 21:17:34 +0800
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-08-01 17:26:09 +0200
commit1eac690d25dd49088256954d4046813daa37dc95 (patch)
tree5888b12e950501b204f0dbb2453a89ec484449b9
parent3dac3271d286f2790780e89d31ddbb7197f8defa (diff)
Refs #373 -- Added tuple lookups.
-rw-r--r--AUTHORS1
-rw-r--r--django/db/models/expressions.py46
-rw-r--r--django/db/models/fields/related_lookups.py83
-rw-r--r--django/db/models/fields/tuple_lookups.py244
-rw-r--r--django/db/models/sql/query.py6
-rw-r--r--tests/foreign_object/test_tuple_lookups.py242
6 files changed, 554 insertions, 68 deletions
diff --git a/AUTHORS b/AUTHORS
index 2915761b96..faf6420618 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -152,6 +152,7 @@ answer newbie questions, and generally made Django that much better:
Ben Lomax <lomax.on.the.run@gmail.com>
Ben Slavin <benjamin.slavin@gmail.com>
Ben Sturmfels <ben@sturm.com.au>
+ Bendegúz Csirmaz <csirmazbendeguz@gmail.com>
Berker Peksag <berker.peksag@gmail.com>
Bernd Schlapsi
Bernhard Essl <me@bernhardessl.com>
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index ffb9f3c816..4a242012ee 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -1295,6 +1295,52 @@ class Col(Expression):
) + self.target.get_db_converters(connection)
+class ColPairs(Expression):
+ def __init__(self, alias, targets, sources, output_field):
+ super().__init__(output_field=output_field)
+ self.alias, self.targets, self.sources = alias, targets, sources
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __iter__(self):
+ return iter(self.get_cols())
+
+ def get_cols(self):
+ return [
+ Col(self.alias, target, source)
+ for target, source in zip(self.targets, self.sources)
+ ]
+
+ def get_source_expressions(self):
+ return self.get_cols()
+
+ def set_source_expressions(self, exprs):
+ assert all(isinstance(expr, Col) and expr.alias == self.alias for expr in exprs)
+ self.targets = [col.target for col in exprs]
+ self.sources = [col.field for col in exprs]
+
+ def as_sql(self, compiler, connection):
+ cols_sql = []
+ cols_params = []
+ cols = self.get_cols()
+
+ for col in cols:
+ sql, params = col.as_sql(compiler, connection)
+ cols_sql.append(sql)
+ cols_params.extend(params)
+
+ return ", ".join(cols_sql), cols_params
+
+ def relabeled_clone(self, relabels):
+ return self.__class__(
+ relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
+ )
+
+ def resolve_expression(self, *args, **kwargs):
+ return self
+
+
class Ref(Expression):
"""
Reference to column alias of the query. For example, Ref('sum_cost') in
diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py
index 07a06e1686..22fd17ab4f 100644
--- a/django/db/models/fields/related_lookups.py
+++ b/django/db/models/fields/related_lookups.py
@@ -1,3 +1,6 @@
+from django.db import NotSupportedError
+from django.db.models.expressions import ColPairs
+from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
from django.db.models.lookups import (
Exact,
GreaterThan,
@@ -9,34 +12,6 @@ from django.db.models.lookups import (
)
-class MultiColSource:
- contains_aggregate = False
- contains_over_clause = False
-
- def __init__(self, alias, targets, sources, field):
- self.targets, self.sources, self.field, self.alias = (
- targets,
- sources,
- field,
- alias,
- )
- self.output_field = self.field
-
- def __repr__(self):
- return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
-
- def relabeled_clone(self, relabels):
- return self.__class__(
- relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
- )
-
- def get_lookup(self, lookup):
- return self.output_field.get_lookup(lookup)
-
- def resolve_expression(self, *args, **kwargs):
- return self
-
-
def get_normalized_value(value, lhs):
from django.db.models import Model
@@ -64,7 +39,7 @@ def get_normalized_value(value, lhs):
class RelatedIn(In):
def get_prep_lookup(self):
- if not isinstance(self.lhs, MultiColSource):
+ if not isinstance(self.lhs, ColPairs):
if self.rhs_is_direct_value():
# If we get here, we are dealing with single-column relations.
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
@@ -98,49 +73,33 @@ class RelatedIn(In):
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
- if isinstance(self.lhs, MultiColSource):
+ if isinstance(self.lhs, ColPairs):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need
# to be compiled to SQL) or an OR-combined list of
# (col1 = val1 AND col2 = val2 AND ...) clauses.
- from django.db.models.sql.where import (
- AND,
- OR,
- SubqueryConstraint,
- WhereNode,
- )
+ from django.db.models.sql.where import SubqueryConstraint
- root_constraint = WhereNode(connector=OR)
if self.rhs_is_direct_value():
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
- for value in values:
- value_constraint = WhereNode()
- for source, target, val in zip(
- self.lhs.sources, self.lhs.targets, value
- ):
- lookup_class = target.get_lookup("exact")
- lookup = lookup_class(
- target.get_col(self.lhs.alias, source), val
- )
- value_constraint.add(lookup, AND)
- root_constraint.add(value_constraint, OR)
+ lookup = TupleIn(self.lhs, values)
+ return compiler.compile(lookup)
else:
- root_constraint.add(
+ return compiler.compile(
SubqueryConstraint(
self.lhs.alias,
[target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources],
self.rhs,
),
- AND,
)
- return root_constraint.as_sql(compiler, connection)
+
return super().as_sql(compiler, connection)
class RelatedLookupMixin:
def get_prep_lookup(self):
- if not isinstance(self.lhs, MultiColSource) and not hasattr(
+ if not isinstance(self.lhs, ColPairs) and not hasattr(
self.rhs, "resolve_expression"
):
# If we get here, we are dealing with single-column relations.
@@ -158,20 +117,16 @@ class RelatedLookupMixin:
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
- if isinstance(self.lhs, MultiColSource):
- assert self.rhs_is_direct_value()
+ if isinstance(self.lhs, ColPairs):
+ if not self.rhs_is_direct_value():
+ raise NotSupportedError(
+ f"'{self.lookup_name}' doesn't support multi-column subqueries."
+ )
self.rhs = get_normalized_value(self.rhs, self.lhs)
- from django.db.models.sql.where import AND, WhereNode
+ lookup_class = tuple_lookups[self.lookup_name]
+ lookup = lookup_class(self.lhs, self.rhs)
+ return compiler.compile(lookup)
- root_constraint = WhereNode()
- for target, source, val in zip(
- self.lhs.targets, self.lhs.sources, self.rhs
- ):
- lookup_class = target.get_lookup(self.lookup_name)
- root_constraint.add(
- lookup_class(target.get_col(self.lhs.alias, source), val), AND
- )
- return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py
new file mode 100644
index 0000000000..468826f224
--- /dev/null
+++ b/django/db/models/fields/tuple_lookups.py
@@ -0,0 +1,244 @@
+import itertools
+
+from django.core.exceptions import EmptyResultSet
+from django.db.models.expressions import ColPairs, Func, Value
+from django.db.models.lookups import (
+ Exact,
+ GreaterThan,
+ GreaterThanOrEqual,
+ In,
+ IsNull,
+ LessThan,
+ LessThanOrEqual,
+)
+from django.db.models.sql.where import AND, OR, WhereNode
+
+
+class Tuple(Func):
+ function = ""
+
+
+class TupleLookupMixin:
+ def get_prep_lookup(self):
+ self.check_tuple_lookup()
+ return super().get_prep_lookup()
+
+ def check_tuple_lookup(self):
+ assert isinstance(self.lhs, ColPairs)
+ self.check_rhs_is_tuple_or_list()
+ self.check_rhs_length_equals_lhs_length()
+
+ def check_rhs_is_tuple_or_list(self):
+ if not isinstance(self.rhs, (tuple, list)):
+ raise ValueError(
+ f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
+ "must be a tuple or a list"
+ )
+
+ def check_rhs_length_equals_lhs_length(self):
+ if len(self.lhs) != len(self.rhs):
+ raise ValueError(
+ f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
+ f"must have {len(self.lhs)} elements"
+ )
+
+ def check_rhs_is_collection_of_tuples_or_lists(self):
+ if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
+ raise ValueError(
+ f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
+ f"must be a collection of tuples or lists"
+ )
+
+ def check_rhs_elements_length_equals_lhs_length(self):
+ if not all(len(self.lhs) == len(vals) for vals in self.rhs):
+ raise ValueError(
+ f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
+ f"must have {len(self.lhs)} elements each"
+ )
+
+ def as_sql(self, compiler, connection):
+ # e.g.: (a, b, c) == (x, y, z) as SQL:
+ # WHERE (a, b, c) = (x, y, z)
+ vals = [
+ Value(val, output_field=col.output_field)
+ for col, val in zip(self.lhs, self.rhs)
+ ]
+ lookup_class = self.__class__.__bases__[-1]
+ lookup = lookup_class(Tuple(self.lhs), Tuple(*vals))
+ return lookup.as_sql(compiler, connection)
+
+
+class TupleExact(TupleLookupMixin, Exact):
+ def as_oracle(self, compiler, connection):
+ # e.g.: (a, b, c) == (x, y, z) as SQL:
+ # WHERE a = x AND b = y AND c = z
+ cols = self.lhs.get_cols()
+ lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)]
+ root = WhereNode(lookups, connector=AND)
+
+ return root.as_sql(compiler, connection)
+
+
+class TupleIsNull(IsNull):
+ def as_sql(self, compiler, connection):
+ # e.g.: (a, b, c) is None as SQL:
+ # WHERE a IS NULL AND b IS NULL AND c IS NULL
+ vals = self.rhs
+ if isinstance(vals, bool):
+ vals = [vals] * len(self.lhs)
+
+ cols = self.lhs.get_cols()
+ lookups = [IsNull(col, val) for col, val in zip(cols, vals)]
+ root = WhereNode(lookups, connector=AND)
+
+ return root.as_sql(compiler, connection)
+
+
+class TupleGreaterThan(TupleLookupMixin, GreaterThan):
+ def as_oracle(self, compiler, connection):
+ # e.g.: (a, b, c) > (x, y, z) as SQL:
+ # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
+ cols = self.lhs.get_cols()
+ lookups = itertools.cycle([GreaterThan, Exact])
+ connectors = itertools.cycle([OR, AND])
+ cols_list = [col for col in cols for _ in range(2)]
+ vals_list = [val for val in self.rhs for _ in range(2)]
+ cols_iter = iter(cols_list[:-1])
+ vals_iter = iter(vals_list[:-1])
+ col, val = next(cols_iter), next(vals_iter)
+ lookup, connector = next(lookups), next(connectors)
+ root = node = WhereNode([lookup(col, val)], connector=connector)
+
+ for col, val in zip(cols_iter, vals_iter):
+ lookup, connector = next(lookups), next(connectors)
+ child = WhereNode([lookup(col, val)], connector=connector)
+ node.children.append(child)
+ node = child
+
+ return root.as_sql(compiler, connection)
+
+
+class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
+ def as_oracle(self, compiler, connection):
+ # e.g.: (a, b, c) >= (x, y, z) as SQL:
+ # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
+ cols = self.lhs.get_cols()
+ lookups = itertools.cycle([GreaterThan, Exact])
+ connectors = itertools.cycle([OR, AND])
+ cols_list = [col for col in cols for _ in range(2)]
+ vals_list = [val for val in self.rhs for _ in range(2)]
+ cols_iter = iter(cols_list)
+ vals_iter = iter(vals_list)
+ col, val = next(cols_iter), next(vals_iter)
+ lookup, connector = next(lookups), next(connectors)
+ root = node = WhereNode([lookup(col, val)], connector=connector)
+
+ for col, val in zip(cols_iter, vals_iter):
+ lookup, connector = next(lookups), next(connectors)
+ child = WhereNode([lookup(col, val)], connector=connector)
+ node.children.append(child)
+ node = child
+
+ return root.as_sql(compiler, connection)
+
+
+class TupleLessThan(TupleLookupMixin, LessThan):
+ def as_oracle(self, compiler, connection):
+ # e.g.: (a, b, c) < (x, y, z) as SQL:
+ # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
+ cols = self.lhs.get_cols()
+ lookups = itertools.cycle([LessThan, Exact])
+ connectors = itertools.cycle([OR, AND])
+ cols_list = [col for col in cols for _ in range(2)]
+ vals_list = [val for val in self.rhs for _ in range(2)]
+ cols_iter = iter(cols_list[:-1])
+ vals_iter = iter(vals_list[:-1])
+ col, val = next(cols_iter), next(vals_iter)
+ lookup, connector = next(lookups), next(connectors)
+ root = node = WhereNode([lookup(col, val)], connector=connector)
+
+ for col, val in zip(cols_iter, vals_iter):
+ lookup, connector = next(lookups), next(connectors)
+ child = WhereNode([lookup(col, val)], connector=connector)
+ node.children.append(child)
+ node = child
+
+ return root.as_sql(compiler, connection)
+
+
+class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
+ def as_oracle(self, compiler, connection):
+ # e.g.: (a, b, c) <= (x, y, z) as SQL:
+ # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
+ cols = self.lhs.get_cols()
+ lookups = itertools.cycle([LessThan, Exact])
+ connectors = itertools.cycle([OR, AND])
+ cols_list = [col for col in cols for _ in range(2)]
+ vals_list = [val for val in self.rhs for _ in range(2)]
+ cols_iter = iter(cols_list)
+ vals_iter = iter(vals_list)
+ col, val = next(cols_iter), next(vals_iter)
+ lookup, connector = next(lookups), next(connectors)
+ root = node = WhereNode([lookup(col, val)], connector=connector)
+
+ for col, val in zip(cols_iter, vals_iter):
+ lookup, connector = next(lookups), next(connectors)
+ child = WhereNode([lookup(col, val)], connector=connector)
+ node.children.append(child)
+ node = child
+
+ return root.as_sql(compiler, connection)
+
+
+class TupleIn(TupleLookupMixin, In):
+ def check_tuple_lookup(self):
+ assert isinstance(self.lhs, ColPairs)
+ self.check_rhs_is_tuple_or_list()
+ self.check_rhs_is_collection_of_tuples_or_lists()
+ self.check_rhs_elements_length_equals_lhs_length()
+
+ def as_sql(self, compiler, connection):
+ if not self.rhs:
+ raise EmptyResultSet
+
+ # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
+ # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
+ rhs = []
+ for vals in self.rhs:
+ rhs.append(
+ Tuple(
+ *[
+ Value(val, output_field=col.output_field)
+ for col, val in zip(self.lhs, vals)
+ ]
+ )
+ )
+
+ lookup = In(Tuple(self.lhs), Tuple(*rhs))
+ return lookup.as_sql(compiler, connection)
+
+ def as_sqlite(self, compiler, connection):
+ if not self.rhs:
+ raise EmptyResultSet
+
+ # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
+ # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
+ root = WhereNode([], connector=OR)
+ cols = self.lhs.get_cols()
+
+ for vals in self.rhs:
+ lookups = [Exact(col, val) for col, val in zip(cols, vals)]
+ root.children.append(WhereNode(lookups, connector=AND))
+
+ return root.as_sql(compiler, connection)
+
+
+tuple_lookups = {
+ "exact": TupleExact,
+ "gt": TupleGreaterThan,
+ "gte": TupleGreaterThanOrEqual,
+ "lt": TupleLessThan,
+ "lte": TupleLessThanOrEqual,
+ "in": TupleIn,
+ "isnull": TupleIsNull,
+}
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 9a57af2bf3..09916277bc 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -23,6 +23,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import (
BaseExpression,
Col,
+ ColPairs,
Exists,
F,
OuterRef,
@@ -32,7 +33,6 @@ from django.db.models.expressions import (
Value,
)
from django.db.models.fields import Field
-from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
Q,
@@ -1549,9 +1549,7 @@ class Query(BaseExpression):
if len(targets) == 1:
col = self._get_col(targets[0], join_info.final_field, alias)
else:
- col = MultiColSource(
- alias, targets, join_info.targets, join_info.final_field
- )
+ col = ColPairs(alias, targets, join_info.targets, join_info.final_field)
else:
col = self._get_col(targets[0], join_info.final_field, alias)
diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py
new file mode 100644
index 0000000000..cf080d084b
--- /dev/null
+++ b/tests/foreign_object/test_tuple_lookups.py
@@ -0,0 +1,242 @@
+import unittest
+
+from django.db import NotSupportedError, connection
+from django.test import TestCase
+
+from .models import Contact, Customer
+
+
+class TupleLookupsTests(TestCase):
+ @classmethod
+ def setUpTestData(cls):
+ super().setUpTestData()
+ cls.customer_1 = Customer.objects.create(customer_id=1, company="a")
+ cls.customer_2 = Customer.objects.create(customer_id=1, company="b")
+ cls.customer_3 = Customer.objects.create(customer_id=2, company="c")
+ cls.customer_4 = Customer.objects.create(customer_id=3, company="d")
+ cls.customer_5 = Customer.objects.create(customer_id=1, company="e")
+ cls.contact_1 = Contact.objects.create(customer=cls.customer_1)
+ cls.contact_2 = Contact.objects.create(customer=cls.customer_1)
+ cls.contact_3 = Contact.objects.create(customer=cls.customer_2)
+ cls.contact_4 = Contact.objects.create(customer=cls.customer_3)
+ cls.contact_5 = Contact.objects.create(customer=cls.customer_1)
+ cls.contact_6 = Contact.objects.create(customer=cls.customer_5)
+
+ def test_exact(self):
+ test_cases = (
+ (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
+ (self.customer_2, (self.contact_3,)),
+ (self.customer_3, (self.contact_4,)),
+ (self.customer_4, ()),
+ (self.customer_5, (self.contact_6,)),
+ )
+
+ for customer, contacts in test_cases:
+ with self.subTest(customer=customer, contacts=contacts):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer=customer).order_by("id"), contacts
+ )
+
+ def test_exact_subquery(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "'exact' doesn't support multi-column subqueries."
+ ):
+ subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer=subquery).order_by("id"), ()
+ )
+
+ def test_in(self):
+ cust_1, cust_2, cust_3, cust_4, cust_5 = (
+ self.customer_1,
+ self.customer_2,
+ self.customer_3,
+ self.customer_4,
+ self.customer_5,
+ )
+ c1, c2, c3, c4, c5, c6 = (
+ self.contact_1,
+ self.contact_2,
+ self.contact_3,
+ self.contact_4,
+ self.contact_5,
+ self.contact_6,
+ )
+ test_cases = (
+ ((), ()),
+ ((cust_1,), (c1, c2, c5)),
+ ((cust_1, cust_2), (c1, c2, c3, c5)),
+ ((cust_1, cust_2, cust_3), (c1, c2, c3, c4, c5)),
+ ((cust_1, cust_2, cust_3, cust_4), (c1, c2, c3, c4, c5)),
+ ((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)),
+ )
+
+ for contacts, customers in test_cases:
+ with self.subTest(contacts=contacts, customers=customers):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__in=contacts).order_by("id"),
+ customers,
+ )
+
+ @unittest.skipIf(
+ connection.vendor == "mysql",
+ "MySQL doesn't support LIMIT & IN/ALL/ANY/SOME subquery",
+ )
+ def test_in_subquery(self):
+ subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__in=subquery).order_by("id"),
+ (self.contact_1, self.contact_2, self.contact_5),
+ )
+
+ def test_lt(self):
+ c1, c2, c3, c4, c5, c6 = (
+ self.contact_1,
+ self.contact_2,
+ self.contact_3,
+ self.contact_4,
+ self.contact_5,
+ self.contact_6,
+ )
+ test_cases = (
+ (self.customer_1, ()),
+ (self.customer_2, (c1, c2, c5)),
+ (self.customer_5, (c1, c2, c3, c5)),
+ (self.customer_3, (c1, c2, c3, c5, c6)),
+ (self.customer_4, (c1, c2, c3, c4, c5, c6)),
+ )
+
+ for customer, contacts in test_cases:
+ with self.subTest(customer=customer, contacts=contacts):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__lt=customer).order_by("id"),
+ contacts,
+ )
+
+ def test_lt_subquery(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "'lt' doesn't support multi-column subqueries."
+ ):
+ subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__lt=subquery).order_by("id"), ()
+ )
+
+ def test_lte(self):
+ c1, c2, c3, c4, c5, c6 = (
+ self.contact_1,
+ self.contact_2,
+ self.contact_3,
+ self.contact_4,
+ self.contact_5,
+ self.contact_6,
+ )
+ test_cases = (
+ (self.customer_1, (c1, c2, c5)),
+ (self.customer_2, (c1, c2, c3, c5)),
+ (self.customer_5, (c1, c2, c3, c5, c6)),
+ (self.customer_3, (c1, c2, c3, c4, c5, c6)),
+ (self.customer_4, (c1, c2, c3, c4, c5, c6)),
+ )
+
+ for customer, contacts in test_cases:
+ with self.subTest(customer=customer, contacts=contacts):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__lte=customer).order_by("id"),
+ contacts,
+ )
+
+ def test_lte_subquery(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "'lte' doesn't support multi-column subqueries."
+ ):
+ subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__lte=subquery).order_by("id"), ()
+ )
+
+ def test_gt(self):
+ test_cases = (
+ (self.customer_1, (self.contact_3, self.contact_4, self.contact_6)),
+ (self.customer_2, (self.contact_4, self.contact_6)),
+ (self.customer_5, (self.contact_4,)),
+ (self.customer_3, ()),
+ (self.customer_4, ()),
+ )
+
+ for customer, contacts in test_cases:
+ with self.subTest(customer=customer, contacts=contacts):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__gt=customer).order_by("id"),
+ contacts,
+ )
+
+ def test_gt_subquery(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "'gt' doesn't support multi-column subqueries."
+ ):
+ subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__gt=subquery).order_by("id"), ()
+ )
+
+ def test_gte(self):
+ c1, c2, c3, c4, c5, c6 = (
+ self.contact_1,
+ self.contact_2,
+ self.contact_3,
+ self.contact_4,
+ self.contact_5,
+ self.contact_6,
+ )
+ test_cases = (
+ (self.customer_1, (c1, c2, c3, c4, c5, c6)),
+ (self.customer_2, (c3, c4, c6)),
+ (self.customer_5, (c4, c6)),
+ (self.customer_3, (c4,)),
+ (self.customer_4, ()),
+ )
+
+ for customer, contacts in test_cases:
+ with self.subTest(customer=customer, contacts=contacts):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__gte=customer).order_by("pk"),
+ contacts,
+ )
+
+ def test_gte_subquery(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "'gte' doesn't support multi-column subqueries."
+ ):
+ subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__gte=subquery).order_by("id"), ()
+ )
+
+ def test_isnull(self):
+ with self.subTest("customer__isnull=True"):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__isnull=True).order_by("id"),
+ (),
+ )
+ with self.subTest("customer__isnull=False"):
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__isnull=False).order_by("id"),
+ (
+ self.contact_1,
+ self.contact_2,
+ self.contact_3,
+ self.contact_4,
+ self.contact_5,
+ self.contact_6,
+ ),
+ )
+
+ def test_isnull_subquery(self):
+ with self.assertRaisesMessage(
+ NotSupportedError, "'isnull' doesn't support multi-column subqueries."
+ ):
+ subquery = Customer.objects.filter(id=0)[:1]
+ self.assertSequenceEqual(
+ Contact.objects.filter(customer__isnull=subquery).order_by("id"), ()
+ )