diff options
Diffstat (limited to 'django/db/models/query.py')
| -rw-r--r-- | django/db/models/query.py | 66 |
1 files changed, 59 insertions, 7 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py index 721bf33e57..4cccb383fd 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1166,8 +1166,6 @@ class QuerySet(AltersData): """ if self.query.is_sliced: raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().") - if not issubclass(self._iterable_class, ModelIterable): - raise TypeError("in_bulk() cannot be used with values() or values_list().") opts = self.model._meta unique_fields = [ constraint.fields[0] @@ -1184,6 +1182,59 @@ class QuerySet(AltersData): "in_bulk()'s field_name must be a unique field but %r isn't." % field_name ) + + qs = self + + def get_obj(obj): + return obj + + if issubclass(self._iterable_class, ModelIterable): + # Raise an AttributeError if field_name is deferred. + get_key = operator.attrgetter(field_name) + + elif issubclass(self._iterable_class, ValuesIterable): + if field_name not in self.query.values_select: + qs = qs.values(field_name, *self.query.values_select) + + def get_obj(obj): # noqa: F811 + # We can safely mutate the dictionaries returned by + # ValuesIterable here, since they are limited to the scope + # of this function, and get_key runs before get_obj. + del obj[field_name] + return obj + + get_key = operator.itemgetter(field_name) + + elif issubclass(self._iterable_class, ValuesListIterable): + try: + field_index = self.query.values_select.index(field_name) + except ValueError: + # field_name is missing from values_select, so add it. + field_index = 0 + if issubclass(self._iterable_class, NamedValuesListIterable): + kwargs = {"named": True} + else: + kwargs = {} + get_obj = operator.itemgetter(slice(1, None)) + qs = qs.values_list(field_name, *self.query.values_select, **kwargs) + + get_key = operator.itemgetter(field_index) + + elif issubclass(self._iterable_class, FlatValuesListIterable): + if self.query.values_select == (field_name,): + # Mapping field_name to itself. + get_key = get_obj + else: + # Transform it back into a non-flat values_list(). + qs = qs.values_list(field_name, *self.query.values_select) + get_key = operator.itemgetter(0) + get_obj = operator.itemgetter(1) + + else: + raise TypeError( + f"in_bulk() cannot be used with {self._iterable_class.__name__}." + ) + if id_list is not None: if not id_list: return {} @@ -1193,15 +1244,16 @@ class QuerySet(AltersData): # If the database has a limit on the number of query parameters # (e.g. SQLite), retrieve objects in batches if necessary. if batch_size and batch_size < len(id_list): - qs = () + results = () for offset in range(0, len(id_list), batch_size): batch = id_list[offset : offset + batch_size] - qs += tuple(self.filter(**{filter_key: batch})) + results += tuple(qs.filter(**{filter_key: batch})) + qs = results else: - qs = self.filter(**{filter_key: id_list}) + qs = qs.filter(**{filter_key: id_list}) else: - qs = self._chain() - return {getattr(obj, field_name): obj for obj in qs} + qs = qs._chain() + return {get_key(obj): get_obj(obj) for obj in qs} async def ain_bulk(self, id_list=None, *, field_name="pk"): return await sync_to_async(self.in_bulk)( |
