diff options
| author | Adam Johnson <me@adamj.eu> | 2023-11-29 09:35:34 +0000 |
|---|---|---|
| committer | Jacob Walls <jacobtylerwalls@gmail.com> | 2025-10-16 14:52:22 -0400 |
| commit | e097e8a12f21a4e92594830f1ad1942b31916d0f (patch) | |
| tree | 43f448bf968f0c6c1a48577cbc4d1ba5b920624a /django/db/models | |
| parent | f6bd90c84050a1c74fe2161cced00e7282cb845c (diff) | |
Fixed #28586 -- Added model field fetch modes.
May your database queries be much reduced with minimal effort.
co-authored-by: Andreas Pelme <andreas@pelme.se>
co-authored-by: Simon Charette <charette.s@gmail.com>
co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
Diffstat (limited to 'django/db/models')
| -rw-r--r-- | django/db/models/__init__.py | 4 | ||||
| -rw-r--r-- | django/db/models/base.py | 25 | ||||
| -rw-r--r-- | django/db/models/fetch_modes.py | 52 | ||||
| -rw-r--r-- | django/db/models/fields/related_descriptors.py | 67 | ||||
| -rw-r--r-- | django/db/models/query.py | 33 | ||||
| -rw-r--r-- | django/db/models/query_utils.py | 17 |
6 files changed, 174 insertions, 24 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index ec54b65240..f15ddecfaa 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -36,6 +36,7 @@ from django.db.models.expressions import ( WindowFrame, WindowFrameExclusion, ) +from django.db.models.fetch_modes import FETCH_ONE, FETCH_PEERS, RAISE from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all from django.db.models.fields.composite import CompositePrimaryKey @@ -105,6 +106,9 @@ __all__ += [ "GeneratedField", "JSONField", "OrderWrt", + "FETCH_ONE", + "FETCH_PEERS", + "RAISE", "Lookup", "Transform", "Manager", diff --git a/django/db/models/base.py b/django/db/models/base.py index fd51052d01..b92a198660 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -32,6 +32,7 @@ from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, from django.db.models.constants import LOOKUP_SEP from django.db.models.deletion import CASCADE, Collector from django.db.models.expressions import DatabaseDefault +from django.db.models.fetch_modes import FETCH_ONE from django.db.models.fields.composite import CompositePrimaryKey from django.db.models.fields.related import ( ForeignObjectRel, @@ -466,6 +467,14 @@ class ModelStateFieldsCacheDescriptor: return res +class ModelStateFetchModeDescriptor: + def __get__(self, instance, cls=None): + if instance is None: + return self + res = instance.fetch_mode = FETCH_ONE + return res + + class ModelState: """Store model instance state.""" @@ -476,6 +485,14 @@ class ModelState: # on the actual save. adding = True fields_cache = ModelStateFieldsCacheDescriptor() + fetch_mode = ModelStateFetchModeDescriptor() + peers = () + + def __getstate__(self): + state = self.__dict__.copy() + # Weak references can't be pickled. + state.pop("peers", None) + return state class Model(AltersData, metaclass=ModelBase): @@ -595,7 +612,7 @@ class Model(AltersData, metaclass=ModelBase): post_init.send(sender=cls, instance=self) @classmethod - def from_db(cls, db, field_names, values): + def from_db(cls, db, field_names, values, *, fetch_mode=None): if len(values) != len(cls._meta.concrete_fields): values_iter = iter(values) values = [ @@ -605,6 +622,8 @@ class Model(AltersData, metaclass=ModelBase): new = cls(*values) new._state.adding = False new._state.db = db + if fetch_mode is not None: + new._state.fetch_mode = fetch_mode return new def __repr__(self): @@ -714,8 +733,8 @@ class Model(AltersData, metaclass=ModelBase): should be an iterable of field attnames. If fields is None, then all non-deferred fields are reloaded. - When accessing deferred fields of an instance, the deferred loading - of the field will call this method. + When fetching deferred fields for a single instance (the FETCH_ONE + fetch mode), the deferred loading uses this method. """ if fields is None: self._prefetched_objects_cache = {} diff --git a/django/db/models/fetch_modes.py b/django/db/models/fetch_modes.py new file mode 100644 index 0000000000..a22ccd8a23 --- /dev/null +++ b/django/db/models/fetch_modes.py @@ -0,0 +1,52 @@ +from django.core.exceptions import FieldFetchBlocked + + +class FetchMode: + __slots__ = () + + track_peers = False + + def fetch(self, fetcher, instance): + raise NotImplementedError("Subclasses must implement this method.") + + +class FetchOne(FetchMode): + __slots__ = () + + def fetch(self, fetcher, instance): + fetcher.fetch_one(instance) + + +FETCH_ONE = FetchOne() + + +class FetchPeers(FetchMode): + __slots__ = () + + track_peers = True + + def fetch(self, fetcher, instance): + instances = [ + peer + for peer_weakref in instance._state.peers + if (peer := peer_weakref()) is not None + ] + if len(instances) > 1: + fetcher.fetch_many(instances) + else: + fetcher.fetch_one(instance) + + +FETCH_PEERS = FetchPeers() + + +class Raise(FetchMode): + __slots__ = () + + def fetch(self, fetcher, instance): + klass = instance.__class__.__qualname__ + field_name = fetcher.field.name + raise FieldFetchBlocked(f"Fetching of {klass}.{field_name} blocked.") from None + + +RAISE = Raise() diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 3e2150e0f6..2c8e59f1d9 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -78,7 +78,7 @@ from django.db.models.expressions import ColPairs from django.db.models.fields.tuple_lookups import TupleIn from django.db.models.functions import RowNumber from django.db.models.lookups import GreaterThan, LessThanOrEqual -from django.db.models.query import QuerySet +from django.db.models.query import QuerySet, prefetch_related_objects from django.db.models.query_utils import DeferredAttribute from django.db.models.utils import AltersData, resolve_callables from django.utils.functional import cached_property @@ -254,13 +254,9 @@ class ForwardManyToOneDescriptor: break if rel_obj is None and has_value: - rel_obj = self.get_object(instance) - remote_field = self.field.remote_field - # If this is a one-to-one relation, set the reverse accessor - # cache on the related object to the current instance to avoid - # an extra SQL query if it's accessed later on. - if not remote_field.multiple: - remote_field.set_cached_value(rel_obj, instance) + instance._state.fetch_mode.fetch(self, instance) + return self.field.get_cached_value(instance) + self.field.set_cached_value(instance, rel_obj) if rel_obj is None and not self.field.null: @@ -270,6 +266,21 @@ class ForwardManyToOneDescriptor: else: return rel_obj + def fetch_one(self, instance): + rel_obj = self.get_object(instance) + self.field.set_cached_value(instance, rel_obj) + # If this is a one-to-one relation, set the reverse accessor cache on + # the related object to the current instance to avoid an extra SQL + # query if it's accessed later on. + remote_field = self.field.remote_field + if not remote_field.multiple: + remote_field.set_cached_value(rel_obj, instance) + + def fetch_many(self, instances): + is_cached = self.is_cached + missing_instances = [i for i in instances if not is_cached(i)] + prefetch_related_objects(missing_instances, self.field.name) + def __set__(self, instance, value): """ Set the related instance through the forward relation. @@ -504,16 +515,8 @@ class ReverseOneToOneDescriptor: if not instance._is_pk_set(): rel_obj = None else: - filter_args = self.related.field.get_forward_related_filter(instance) - try: - rel_obj = self.get_queryset(instance=instance).get(**filter_args) - except self.related.related_model.DoesNotExist: - rel_obj = None - else: - # Set the forward accessor cache on the related object to - # the current instance to avoid an extra SQL query if it's - # accessed later on. - self.related.field.set_cached_value(rel_obj, instance) + instance._state.fetch_mode.fetch(self, instance) + rel_obj = self.related.get_cached_value(instance) self.related.set_cached_value(instance, rel_obj) if rel_obj is None: @@ -524,6 +527,34 @@ class ReverseOneToOneDescriptor: else: return rel_obj + @property + def field(self): + """ + Add compatibility with the fetcher protocol. While self.related is not + a field but a OneToOneRel, it quacks enough like a field to work. + """ + return self.related + + def fetch_one(self, instance): + # Kept for backwards compatibility with overridden + # get_forward_related_filter() + filter_args = self.related.field.get_forward_related_filter(instance) + try: + rel_obj = self.get_queryset(instance=instance).get(**filter_args) + except self.related.related_model.DoesNotExist: + rel_obj = None + else: + self.related.field.set_cached_value(rel_obj, instance) + self.related.set_cached_value(instance, rel_obj) + + def fetch_many(self, instances): + is_cached = self.is_cached + missing_instances = [i for i in instances if not is_cached(i)] + prefetch_related_objects( + missing_instances, + self.related.get_accessor_name(), + ) + def __set__(self, instance, value): """ Set the related instance through the reverse relation. diff --git a/django/db/models/query.py b/django/db/models/query.py index 39cc9b6cb3..0811b90b5e 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -8,6 +8,7 @@ import warnings from contextlib import nullcontext from functools import reduce from itertools import chain, islice +from weakref import ref as weak_ref from asgiref.sync import sync_to_async @@ -26,6 +27,7 @@ from django.db.models import AutoField, DateField, DateTimeField, Field, Max, sq from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.deletion import Collector from django.db.models.expressions import Case, DatabaseDefault, F, Value, When +from django.db.models.fetch_modes import FETCH_ONE from django.db.models.functions import Cast, Trunc from django.db.models.query_utils import FilteredRelation, Q from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, ROW_COUNT @@ -122,10 +124,18 @@ class ModelIterable(BaseIterable): ) for field, related_objs in queryset._known_related_objects.items() ] + fetch_mode = queryset._fetch_mode + peers = [] for row in compiler.results_iter(results): obj = model_cls.from_db( - db, init_list, row[model_fields_start:model_fields_end] + db, + init_list, + row[model_fields_start:model_fields_end], + fetch_mode=fetch_mode, ) + if fetch_mode.track_peers: + peers.append(weak_ref(obj)) + obj._state.peers = peers for rel_populator in related_populators: rel_populator.populate(row, obj) if annotation_col_map: @@ -183,10 +193,17 @@ class RawModelIterable(BaseIterable): query_iterator = compiler.composite_fields_to_tuples( query_iterator, cols ) + fetch_mode = self.queryset._fetch_mode + peers = [] for values in query_iterator: # Associate fields to values model_init_values = [values[pos] for pos in model_init_pos] - instance = model_cls.from_db(db, model_init_names, model_init_values) + instance = model_cls.from_db( + db, model_init_names, model_init_values, fetch_mode=fetch_mode + ) + if fetch_mode.track_peers: + peers.append(weak_ref(instance)) + instance._state.peers = peers if annotation_fields: for column, pos in annotation_fields: setattr(instance, column, values[pos]) @@ -293,6 +310,7 @@ class QuerySet(AltersData): self._prefetch_done = False self._known_related_objects = {} # {rel_field: {pk: rel_obj}} self._iterable_class = ModelIterable + self._fetch_mode = FETCH_ONE self._fields = None self._defer_next_filter = False self._deferred_filter = None @@ -1442,6 +1460,7 @@ class QuerySet(AltersData): params=params, translations=translations, using=using, + fetch_mode=self._fetch_mode, ) qs._prefetch_related_lookups = self._prefetch_related_lookups[:] return qs @@ -1913,6 +1932,12 @@ class QuerySet(AltersData): clone._db = alias return clone + def fetch_mode(self, fetch_mode): + """Set the fetch mode for the QuerySet.""" + clone = self._chain() + clone._fetch_mode = fetch_mode + return clone + ################################### # PUBLIC INTROSPECTION ATTRIBUTES # ################################### @@ -2051,6 +2076,7 @@ class QuerySet(AltersData): c._prefetch_related_lookups = self._prefetch_related_lookups[:] c._known_related_objects = self._known_related_objects c._iterable_class = self._iterable_class + c._fetch_mode = self._fetch_mode c._fields = self._fields return c @@ -2186,6 +2212,7 @@ class RawQuerySet: translations=None, using=None, hints=None, + fetch_mode=FETCH_ONE, ): self.raw_query = raw_query self.model = model @@ -2197,6 +2224,7 @@ class RawQuerySet: self._result_cache = None self._prefetch_related_lookups = () self._prefetch_done = False + self._fetch_mode = fetch_mode def resolve_model_init_order(self): """Resolve the init field names and value positions.""" @@ -2295,6 +2323,7 @@ class RawQuerySet: params=self.params, translations=self.translations, using=alias, + fetch_mode=self._fetch_mode, ) @cached_property diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index c383b80640..23d543211a 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -264,7 +264,8 @@ class DeferredAttribute: f"Cannot retrieve deferred field {field_name!r} " "from an unsaved model." ) - instance.refresh_from_db(fields=[field_name]) + + instance._state.fetch_mode.fetch(self, instance) else: data[field_name] = val return data[field_name] @@ -281,6 +282,20 @@ class DeferredAttribute: return getattr(instance, link_field.attname) return None + def fetch_one(self, instance): + instance.refresh_from_db(fields=[self.field.attname]) + + def fetch_many(self, instances): + attname = self.field.attname + db = instances[0]._state.db + value_by_pk = ( + self.field.model._base_manager.using(db) + .values_list(attname) + .in_bulk({i.pk for i in instances}) + ) + for instance in instances: + setattr(instance, attname, value_by_pk[instance.pk]) + class class_or_instance_method: """ |
