summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/utils/choices.py26
-rw-r--r--tests/utils_tests/test_choices.py64
2 files changed, 83 insertions, 7 deletions
diff --git a/django/utils/choices.py b/django/utils/choices.py
index a0611d96f1..734b9331a1 100644
--- a/django/utils/choices.py
+++ b/django/utils/choices.py
@@ -1,11 +1,37 @@
from collections.abc import Callable, Iterable, Iterator, Mapping
+from itertools import islice, zip_longest
from django.utils.functional import Promise
+__all__ = [
+ "BaseChoiceIterator",
+ "CallableChoiceIterator",
+ "normalize_choices",
+]
+
class BaseChoiceIterator:
"""Base class for lazy iterators for choices."""
+ def __eq__(self, other):
+ if isinstance(other, Iterable):
+ return all(a == b for a, b in zip_longest(self, other, fillvalue=object()))
+ return super().__eq__(other)
+
+ def __getitem__(self, index):
+ if index < 0:
+ # Suboptimally consume whole iterator to handle negative index.
+ return list(self)[index]
+ try:
+ return next(islice(self, index, index + 1))
+ except StopIteration:
+ raise IndexError("index out of range") from None
+
+ def __iter__(self):
+ raise NotImplementedError(
+ "BaseChoiceIterator subclasses must implement __iter__()."
+ )
+
class CallableChoiceIterator(BaseChoiceIterator):
"""Iterator to lazily normalize choices generated by a callable."""
diff --git a/tests/utils_tests/test_choices.py b/tests/utils_tests/test_choices.py
index d96c3d49c4..a2ad5541a4 100644
--- a/tests/utils_tests/test_choices.py
+++ b/tests/utils_tests/test_choices.py
@@ -2,10 +2,60 @@ from unittest import mock
from django.db.models import TextChoices
from django.test import SimpleTestCase
-from django.utils.choices import CallableChoiceIterator, normalize_choices
+from django.utils.choices import (
+ BaseChoiceIterator,
+ CallableChoiceIterator,
+ normalize_choices,
+)
from django.utils.translation import gettext_lazy as _
+class SimpleChoiceIterator(BaseChoiceIterator):
+ def __iter__(self):
+ return ((i, f"Item #{i}") for i in range(1, 4))
+
+
+class ChoiceIteratorTests(SimpleTestCase):
+ def test_not_implemented_error_on_missing_iter(self):
+ class InvalidChoiceIterator(BaseChoiceIterator):
+ pass # Not overriding __iter__().
+
+ msg = "BaseChoiceIterator subclasses must implement __iter__()."
+ with self.assertRaisesMessage(NotImplementedError, msg):
+ iter(InvalidChoiceIterator())
+
+ def test_eq(self):
+ unrolled = [(1, "Item #1"), (2, "Item #2"), (3, "Item #3")]
+ self.assertEqual(SimpleChoiceIterator(), unrolled)
+ self.assertEqual(unrolled, SimpleChoiceIterator())
+
+ def test_eq_instances(self):
+ self.assertEqual(SimpleChoiceIterator(), SimpleChoiceIterator())
+
+ def test_not_equal_subset(self):
+ self.assertNotEqual(SimpleChoiceIterator(), [(1, "Item #1"), (2, "Item #2")])
+
+ def test_not_equal_superset(self):
+ self.assertNotEqual(
+ SimpleChoiceIterator(),
+ [(1, "Item #1"), (2, "Item #2"), (3, "Item #3"), None],
+ )
+
+ def test_getitem(self):
+ choices = SimpleChoiceIterator()
+ for i, expected in [(0, (1, "Item #1")), (-1, (3, "Item #3"))]:
+ with self.subTest(index=i):
+ self.assertEqual(choices[i], expected)
+
+ def test_getitem_indexerror(self):
+ choices = SimpleChoiceIterator()
+ for i in (4, -4):
+ with self.subTest(index=i):
+ with self.assertRaises(IndexError) as ctx:
+ choices[i]
+ self.assertTrue(str(ctx.exception).endswith("index out of range"))
+
+
class NormalizeFieldChoicesTests(SimpleTestCase):
expected = [
("C", _("Club")),
@@ -84,7 +134,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
- self.assertEqual(list(output), self.expected)
+ self.assertEqual(output, self.expected)
get_choices_spy.assert_called_once()
def test_mapping(self):
@@ -134,7 +184,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_media_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
- self.assertEqual(list(output), self.expected_nested)
+ self.assertEqual(output, self.expected_nested)
get_media_choices_spy.assert_called_once()
def test_nested_mapping(self):
@@ -185,7 +235,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
- self.assertEqual(list(output), self.expected)
+ self.assertEqual(output, self.expected)
get_choices_spy.assert_called_once()
def test_iterable_non_canonical(self):
@@ -230,7 +280,7 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
get_media_choices_spy.assert_not_called()
self.assertIsInstance(output, CallableChoiceIterator)
- self.assertEqual(list(output), self.expected_nested)
+ self.assertEqual(output, self.expected_nested)
get_media_choices_spy.assert_called_once()
def test_nested_iterable_non_canonical(self):
@@ -294,12 +344,12 @@ class NormalizeFieldChoicesTests(SimpleTestCase):
def test_unsupported_values_from_callable_returned_unmodified(self):
for value in self.invalid_iterable + self.invalid_nested:
with self.subTest(value=value):
- self.assertEqual(list(normalize_choices(lambda: value)), value)
+ self.assertEqual(normalize_choices(lambda: value), value)
def test_unsupported_values_from_iterator_returned_unmodified(self):
for value in self.invalid_nested:
with self.subTest(value=value):
self.assertEqual(
- list(normalize_choices((lambda: (yield from value))())),
+ normalize_choices((lambda: (yield from value))()),
value,
)