summaryrefslogtreecommitdiff
path: root/django/db/models/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/query.py')
-rw-r--r--django/db/models/query.py66
1 files changed, 59 insertions, 7 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 721bf33e57..4cccb383fd 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1166,8 +1166,6 @@ class QuerySet(AltersData):
"""
if self.query.is_sliced:
raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
- if not issubclass(self._iterable_class, ModelIterable):
- raise TypeError("in_bulk() cannot be used with values() or values_list().")
opts = self.model._meta
unique_fields = [
constraint.fields[0]
@@ -1184,6 +1182,59 @@ class QuerySet(AltersData):
"in_bulk()'s field_name must be a unique field but %r isn't."
% field_name
)
+
+ qs = self
+
+ def get_obj(obj):
+ return obj
+
+ if issubclass(self._iterable_class, ModelIterable):
+ # Raise an AttributeError if field_name is deferred.
+ get_key = operator.attrgetter(field_name)
+
+ elif issubclass(self._iterable_class, ValuesIterable):
+ if field_name not in self.query.values_select:
+ qs = qs.values(field_name, *self.query.values_select)
+
+ def get_obj(obj): # noqa: F811
+ # We can safely mutate the dictionaries returned by
+ # ValuesIterable here, since they are limited to the scope
+ # of this function, and get_key runs before get_obj.
+ del obj[field_name]
+ return obj
+
+ get_key = operator.itemgetter(field_name)
+
+ elif issubclass(self._iterable_class, ValuesListIterable):
+ try:
+ field_index = self.query.values_select.index(field_name)
+ except ValueError:
+ # field_name is missing from values_select, so add it.
+ field_index = 0
+ if issubclass(self._iterable_class, NamedValuesListIterable):
+ kwargs = {"named": True}
+ else:
+ kwargs = {}
+ get_obj = operator.itemgetter(slice(1, None))
+ qs = qs.values_list(field_name, *self.query.values_select, **kwargs)
+
+ get_key = operator.itemgetter(field_index)
+
+ elif issubclass(self._iterable_class, FlatValuesListIterable):
+ if self.query.values_select == (field_name,):
+ # Mapping field_name to itself.
+ get_key = get_obj
+ else:
+ # Transform it back into a non-flat values_list().
+ qs = qs.values_list(field_name, *self.query.values_select)
+ get_key = operator.itemgetter(0)
+ get_obj = operator.itemgetter(1)
+
+ else:
+ raise TypeError(
+ f"in_bulk() cannot be used with {self._iterable_class.__name__}."
+ )
+
if id_list is not None:
if not id_list:
return {}
@@ -1193,15 +1244,16 @@ class QuerySet(AltersData):
# If the database has a limit on the number of query parameters
# (e.g. SQLite), retrieve objects in batches if necessary.
if batch_size and batch_size < len(id_list):
- qs = ()
+ results = ()
for offset in range(0, len(id_list), batch_size):
batch = id_list[offset : offset + batch_size]
- qs += tuple(self.filter(**{filter_key: batch}))
+ results += tuple(qs.filter(**{filter_key: batch}))
+ qs = results
else:
- qs = self.filter(**{filter_key: id_list})
+ qs = qs.filter(**{filter_key: id_list})
else:
- qs = self._chain()
- return {getattr(obj, field_name): obj for obj in qs}
+ qs = qs._chain()
+ return {get_key(obj): get_obj(obj) for obj in qs}
async def ain_bulk(self, id_list=None, *, field_name="pk"):
return await sync_to_async(self.in_bulk)(