summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--django/db/models/query.py4
-rw-r--r--tests/composite_pk/tests.py30
-rw-r--r--tests/lookup/tests.py19
3 files changed, 24 insertions(+), 29 deletions(-)
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 9f245b02ca..7ae9f53bfd 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1187,10 +1187,8 @@ class QuerySet(AltersData):
if not id_list:
return {}
filter_key = "{}__in".format(field_name)
- max_params = connections[self.db].features.max_query_params or 0
- num_fields = len(opts.pk_fields) if field_name == "pk" else 1
- batch_size = max_params // num_fields
id_list = tuple(id_list)
+ batch_size = connections[self.db].ops.bulk_batch_size([opts.pk], id_list)
# If the database has a limit on the number of query parameters
# (e.g. SQLite), retrieve objects in batches if necessary.
if batch_size and batch_size < len(id_list):
diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py
index cc78f3495a..c4a8e6ca8c 100644
--- a/tests/composite_pk/tests.py
+++ b/tests/composite_pk/tests.py
@@ -147,20 +147,24 @@ class CompositePKTests(TestCase):
result = Comment.objects.in_bulk([self.comment.pk])
self.assertEqual(result, {self.comment.pk: self.comment})
- @unittest.mock.patch.object(
- type(connection.features), "max_query_params", new_callable=lambda: 10
- )
- def test_in_bulk_batching(self, mocked_max_query_params):
+ def test_in_bulk_batching(self):
Comment.objects.all().delete()
- num_requiring_batching = (connection.features.max_query_params // 2) + 1
- comments = [
- Comment(id=i, tenant=self.tenant, user=self.user)
- for i in range(1, num_requiring_batching + 1)
- ]
- Comment.objects.bulk_create(comments)
- id_list = list(Comment.objects.values_list("pk", flat=True))
- with self.assertNumQueries(2):
- comment_dict = Comment.objects.in_bulk(id_list=id_list)
+ batching_required = connection.features.max_query_params is not None
+ expected_queries = 2 if batching_required else 1
+ with unittest.mock.patch.object(
+ type(connection.features), "max_query_params", 10
+ ):
+ num_requiring_batching = (
+ connection.ops.bulk_batch_size([Comment._meta.pk], []) + 1
+ )
+ comments = [
+ Comment(id=i, tenant=self.tenant, user=self.user)
+ for i in range(1, num_requiring_batching + 1)
+ ]
+ Comment.objects.bulk_create(comments)
+ id_list = list(Comment.objects.values_list("pk", flat=True))
+ with self.assertNumQueries(expected_queries):
+ comment_dict = Comment.objects.in_bulk(id_list=id_list)
self.assertQuerySetEqual(comment_dict, id_list)
def test_iterator(self):
diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py
index 25336cbee7..ef54472e54 100644
--- a/tests/lookup/tests.py
+++ b/tests/lookup/tests.py
@@ -248,28 +248,21 @@ class LookupTests(TestCase):
with self.assertRaisesMessage(ValueError, msg):
Article.objects.in_bulk([self.au1], field_name="author")
- @skipUnlessDBFeature("can_distinct_on_fields")
def test_in_bulk_preserve_ordering(self):
- articles = (
- Article.objects.order_by("author_id", "-pub_date")
- .distinct("author_id")
- .in_bulk([self.au1.id, self.au2.id], field_name="author_id")
- )
self.assertEqual(
- articles,
- {self.au1.id: self.a4, self.au2.id: self.a5},
+ list(Article.objects.in_bulk([self.au2.id, self.au1.id])),
+ [self.au2.id, self.au1.id],
)
- @skipUnlessDBFeature("can_distinct_on_fields")
def test_in_bulk_preserve_ordering_with_batch_size(self):
- qs = Article.objects.order_by("author_id", "-pub_date").distinct("author_id")
+ qs = Article.objects.all()
with (
- mock.patch.object(connection.features.__class__, "max_query_params", 1),
+ mock.patch.object(connection.ops, "bulk_batch_size", return_value=2),
self.assertNumQueries(2),
):
self.assertEqual(
- qs.in_bulk([self.au1.id, self.au2.id], field_name="author_id"),
- {self.au1.id: self.a4, self.au2.id: self.a5},
+ list(qs.in_bulk([self.a4.id, self.a3.id, self.a2.id, self.a1.id])),
+ [self.a4.id, self.a3.id, self.a2.id, self.a1.id],
)
@skipUnlessDBFeature("can_distinct_on_fields")