summaryrefslogtreecommitdiff
path: root/django/db/models
diff options
context:
space:
mode:
authorAdam Johnson <me@adamj.eu>2023-11-29 09:35:34 +0000
committerJacob Walls <jacobtylerwalls@gmail.com>2025-10-16 14:52:22 -0400
commite097e8a12f21a4e92594830f1ad1942b31916d0f (patch)
tree43f448bf968f0c6c1a48577cbc4d1ba5b920624a /django/db/models
parentf6bd90c84050a1c74fe2161cced00e7282cb845c (diff)
Fixed #28586 -- Added model field fetch modes.
May your database queries be much reduced with minimal effort. co-authored-by: Andreas Pelme <andreas@pelme.se> co-authored-by: Simon Charette <charette.s@gmail.com> co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
Diffstat (limited to 'django/db/models')
-rw-r--r--django/db/models/__init__.py4
-rw-r--r--django/db/models/base.py25
-rw-r--r--django/db/models/fetch_modes.py52
-rw-r--r--django/db/models/fields/related_descriptors.py67
-rw-r--r--django/db/models/query.py33
-rw-r--r--django/db/models/query_utils.py17
6 files changed, 174 insertions, 24 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
index ec54b65240..f15ddecfaa 100644
--- a/django/db/models/__init__.py
+++ b/django/db/models/__init__.py
@@ -36,6 +36,7 @@ from django.db.models.expressions import (
WindowFrame,
WindowFrameExclusion,
)
+from django.db.models.fetch_modes import FETCH_ONE, FETCH_PEERS, RAISE
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.composite import CompositePrimaryKey
@@ -105,6 +106,9 @@ __all__ += [
"GeneratedField",
"JSONField",
"OrderWrt",
+ "FETCH_ONE",
+ "FETCH_PEERS",
+ "RAISE",
"Lookup",
"Transform",
"Manager",
diff --git a/django/db/models/base.py b/django/db/models/base.py
index fd51052d01..b92a198660 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -32,6 +32,7 @@ from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max,
from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import CASCADE, Collector
from django.db.models.expressions import DatabaseDefault
+from django.db.models.fetch_modes import FETCH_ONE
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.fields.related import (
ForeignObjectRel,
@@ -466,6 +467,14 @@ class ModelStateFieldsCacheDescriptor:
return res
+class ModelStateFetchModeDescriptor:
+ def __get__(self, instance, cls=None):
+ if instance is None:
+ return self
+ res = instance.fetch_mode = FETCH_ONE
+ return res
+
+
class ModelState:
"""Store model instance state."""
@@ -476,6 +485,14 @@ class ModelState:
# on the actual save.
adding = True
fields_cache = ModelStateFieldsCacheDescriptor()
+ fetch_mode = ModelStateFetchModeDescriptor()
+ peers = ()
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # Weak references can't be pickled.
+ state.pop("peers", None)
+ return state
class Model(AltersData, metaclass=ModelBase):
@@ -595,7 +612,7 @@ class Model(AltersData, metaclass=ModelBase):
post_init.send(sender=cls, instance=self)
@classmethod
- def from_db(cls, db, field_names, values):
+ def from_db(cls, db, field_names, values, *, fetch_mode=None):
if len(values) != len(cls._meta.concrete_fields):
values_iter = iter(values)
values = [
@@ -605,6 +622,8 @@ class Model(AltersData, metaclass=ModelBase):
new = cls(*values)
new._state.adding = False
new._state.db = db
+ if fetch_mode is not None:
+ new._state.fetch_mode = fetch_mode
return new
def __repr__(self):
@@ -714,8 +733,8 @@ class Model(AltersData, metaclass=ModelBase):
should be an iterable of field attnames. If fields is None, then
all non-deferred fields are reloaded.
- When accessing deferred fields of an instance, the deferred loading
- of the field will call this method.
+ When fetching deferred fields for a single instance (the FETCH_ONE
+ fetch mode), the deferred loading uses this method.
"""
if fields is None:
self._prefetched_objects_cache = {}
diff --git a/django/db/models/fetch_modes.py b/django/db/models/fetch_modes.py
new file mode 100644
index 0000000000..a22ccd8a23
--- /dev/null
+++ b/django/db/models/fetch_modes.py
@@ -0,0 +1,52 @@
+from django.core.exceptions import FieldFetchBlocked
+
+
+class FetchMode:
+ __slots__ = ()
+
+ track_peers = False
+
+ def fetch(self, fetcher, instance):
+ raise NotImplementedError("Subclasses must implement this method.")
+
+
+class FetchOne(FetchMode):
+ __slots__ = ()
+
+ def fetch(self, fetcher, instance):
+ fetcher.fetch_one(instance)
+
+
+FETCH_ONE = FetchOne()
+
+
+class FetchPeers(FetchMode):
+ __slots__ = ()
+
+ track_peers = True
+
+ def fetch(self, fetcher, instance):
+ instances = [
+ peer
+ for peer_weakref in instance._state.peers
+ if (peer := peer_weakref()) is not None
+ ]
+ if len(instances) > 1:
+ fetcher.fetch_many(instances)
+ else:
+ fetcher.fetch_one(instance)
+
+
+FETCH_PEERS = FetchPeers()
+
+
+class Raise(FetchMode):
+ __slots__ = ()
+
+ def fetch(self, fetcher, instance):
+ klass = instance.__class__.__qualname__
+ field_name = fetcher.field.name
+ raise FieldFetchBlocked(f"Fetching of {klass}.{field_name} blocked.") from None
+
+
+RAISE = Raise()
diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py
index 3e2150e0f6..2c8e59f1d9 100644
--- a/django/db/models/fields/related_descriptors.py
+++ b/django/db/models/fields/related_descriptors.py
@@ -78,7 +78,7 @@ from django.db.models.expressions import ColPairs
from django.db.models.fields.tuple_lookups import TupleIn
from django.db.models.functions import RowNumber
from django.db.models.lookups import GreaterThan, LessThanOrEqual
-from django.db.models.query import QuerySet
+from django.db.models.query import QuerySet, prefetch_related_objects
from django.db.models.query_utils import DeferredAttribute
from django.db.models.utils import AltersData, resolve_callables
from django.utils.functional import cached_property
@@ -254,13 +254,9 @@ class ForwardManyToOneDescriptor:
break
if rel_obj is None and has_value:
- rel_obj = self.get_object(instance)
- remote_field = self.field.remote_field
- # If this is a one-to-one relation, set the reverse accessor
- # cache on the related object to the current instance to avoid
- # an extra SQL query if it's accessed later on.
- if not remote_field.multiple:
- remote_field.set_cached_value(rel_obj, instance)
+ instance._state.fetch_mode.fetch(self, instance)
+ return self.field.get_cached_value(instance)
+
self.field.set_cached_value(instance, rel_obj)
if rel_obj is None and not self.field.null:
@@ -270,6 +266,21 @@ class ForwardManyToOneDescriptor:
else:
return rel_obj
+ def fetch_one(self, instance):
+ rel_obj = self.get_object(instance)
+ self.field.set_cached_value(instance, rel_obj)
+ # If this is a one-to-one relation, set the reverse accessor cache on
+ # the related object to the current instance to avoid an extra SQL
+ # query if it's accessed later on.
+ remote_field = self.field.remote_field
+ if not remote_field.multiple:
+ remote_field.set_cached_value(rel_obj, instance)
+
+ def fetch_many(self, instances):
+ is_cached = self.is_cached
+ missing_instances = [i for i in instances if not is_cached(i)]
+ prefetch_related_objects(missing_instances, self.field.name)
+
def __set__(self, instance, value):
"""
Set the related instance through the forward relation.
@@ -504,16 +515,8 @@ class ReverseOneToOneDescriptor:
if not instance._is_pk_set():
rel_obj = None
else:
- filter_args = self.related.field.get_forward_related_filter(instance)
- try:
- rel_obj = self.get_queryset(instance=instance).get(**filter_args)
- except self.related.related_model.DoesNotExist:
- rel_obj = None
- else:
- # Set the forward accessor cache on the related object to
- # the current instance to avoid an extra SQL query if it's
- # accessed later on.
- self.related.field.set_cached_value(rel_obj, instance)
+ instance._state.fetch_mode.fetch(self, instance)
+ rel_obj = self.related.get_cached_value(instance)
self.related.set_cached_value(instance, rel_obj)
if rel_obj is None:
@@ -524,6 +527,34 @@ class ReverseOneToOneDescriptor:
else:
return rel_obj
+ @property
+ def field(self):
+ """
+ Add compatibility with the fetcher protocol. While self.related is not
+ a field but a OneToOneRel, it quacks enough like a field to work.
+ """
+ return self.related
+
+ def fetch_one(self, instance):
+ # Kept for backwards compatibility with overridden
+ # get_forward_related_filter()
+ filter_args = self.related.field.get_forward_related_filter(instance)
+ try:
+ rel_obj = self.get_queryset(instance=instance).get(**filter_args)
+ except self.related.related_model.DoesNotExist:
+ rel_obj = None
+ else:
+ self.related.field.set_cached_value(rel_obj, instance)
+ self.related.set_cached_value(instance, rel_obj)
+
+ def fetch_many(self, instances):
+ is_cached = self.is_cached
+ missing_instances = [i for i in instances if not is_cached(i)]
+ prefetch_related_objects(
+ missing_instances,
+ self.related.get_accessor_name(),
+ )
+
def __set__(self, instance, value):
"""
Set the related instance through the reverse relation.
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 39cc9b6cb3..0811b90b5e 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -8,6 +8,7 @@ import warnings
from contextlib import nullcontext
from functools import reduce
from itertools import chain, islice
+from weakref import ref as weak_ref
from asgiref.sync import sync_to_async
@@ -26,6 +27,7 @@ from django.db.models import AutoField, DateField, DateTimeField, Field, Max, sq
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, DatabaseDefault, F, Value, When
+from django.db.models.fetch_modes import FETCH_ONE
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, ROW_COUNT
@@ -122,10 +124,18 @@ class ModelIterable(BaseIterable):
)
for field, related_objs in queryset._known_related_objects.items()
]
+ fetch_mode = queryset._fetch_mode
+ peers = []
for row in compiler.results_iter(results):
obj = model_cls.from_db(
- db, init_list, row[model_fields_start:model_fields_end]
+ db,
+ init_list,
+ row[model_fields_start:model_fields_end],
+ fetch_mode=fetch_mode,
)
+ if fetch_mode.track_peers:
+ peers.append(weak_ref(obj))
+ obj._state.peers = peers
for rel_populator in related_populators:
rel_populator.populate(row, obj)
if annotation_col_map:
@@ -183,10 +193,17 @@ class RawModelIterable(BaseIterable):
query_iterator = compiler.composite_fields_to_tuples(
query_iterator, cols
)
+ fetch_mode = self.queryset._fetch_mode
+ peers = []
for values in query_iterator:
# Associate fields to values
model_init_values = [values[pos] for pos in model_init_pos]
- instance = model_cls.from_db(db, model_init_names, model_init_values)
+ instance = model_cls.from_db(
+ db, model_init_names, model_init_values, fetch_mode=fetch_mode
+ )
+ if fetch_mode.track_peers:
+ peers.append(weak_ref(instance))
+ instance._state.peers = peers
if annotation_fields:
for column, pos in annotation_fields:
setattr(instance, column, values[pos])
@@ -293,6 +310,7 @@ class QuerySet(AltersData):
self._prefetch_done = False
self._known_related_objects = {} # {rel_field: {pk: rel_obj}}
self._iterable_class = ModelIterable
+ self._fetch_mode = FETCH_ONE
self._fields = None
self._defer_next_filter = False
self._deferred_filter = None
@@ -1442,6 +1460,7 @@ class QuerySet(AltersData):
params=params,
translations=translations,
using=using,
+ fetch_mode=self._fetch_mode,
)
qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
return qs
@@ -1913,6 +1932,12 @@ class QuerySet(AltersData):
clone._db = alias
return clone
+ def fetch_mode(self, fetch_mode):
+ """Set the fetch mode for the QuerySet."""
+ clone = self._chain()
+ clone._fetch_mode = fetch_mode
+ return clone
+
###################################
# PUBLIC INTROSPECTION ATTRIBUTES #
###################################
@@ -2051,6 +2076,7 @@ class QuerySet(AltersData):
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
c._known_related_objects = self._known_related_objects
c._iterable_class = self._iterable_class
+ c._fetch_mode = self._fetch_mode
c._fields = self._fields
return c
@@ -2186,6 +2212,7 @@ class RawQuerySet:
translations=None,
using=None,
hints=None,
+ fetch_mode=FETCH_ONE,
):
self.raw_query = raw_query
self.model = model
@@ -2197,6 +2224,7 @@ class RawQuerySet:
self._result_cache = None
self._prefetch_related_lookups = ()
self._prefetch_done = False
+ self._fetch_mode = fetch_mode
def resolve_model_init_order(self):
"""Resolve the init field names and value positions."""
@@ -2295,6 +2323,7 @@ class RawQuerySet:
params=self.params,
translations=self.translations,
using=alias,
+ fetch_mode=self._fetch_mode,
)
@cached_property
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index c383b80640..23d543211a 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -264,7 +264,8 @@ class DeferredAttribute:
f"Cannot retrieve deferred field {field_name!r} "
"from an unsaved model."
)
- instance.refresh_from_db(fields=[field_name])
+
+ instance._state.fetch_mode.fetch(self, instance)
else:
data[field_name] = val
return data[field_name]
@@ -281,6 +282,20 @@ class DeferredAttribute:
return getattr(instance, link_field.attname)
return None
+ def fetch_one(self, instance):
+ instance.refresh_from_db(fields=[self.field.attname])
+
+ def fetch_many(self, instances):
+ attname = self.field.attname
+ db = instances[0]._state.db
+ value_by_pk = (
+ self.field.model._base_manager.using(db)
+ .values_list(attname)
+ .in_bulk({i.pk for i in instances})
+ )
+ for instance in instances:
+ setattr(instance, attname, value_by_pk[instance.pk])
+
class class_or_instance_method:
"""