summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorontowhee <82607723+ontowhee@users.noreply.github.com>2023-05-16 19:12:53 -0700
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2024-02-16 08:57:16 +0100
commit66e47ac69a7e71cf32eee312d05668d8f1ba24bb (patch)
tree4e70a219a399f9d8142bfca7d22621ac22960f3d
parent0d8fbe2ade29f1b7bd9e6ba7a0281f5478603a43 (diff)
Fixed #29725 -- Removed unnecessary join in QuerySet.count() and exists() on a many to many relation.
Co-Authored-By: Shiwei Chen <april.chen.0615@gmail.com>
-rw-r--r--django/db/models/fields/related_descriptors.py53
-rw-r--r--tests/many_to_many/models.py12
-rw-r--r--tests/many_to_many/tests.py96
3 files changed, 151 insertions, 10 deletions
diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py
index 62ddfc60b3..a8f298230a 100644
--- a/django/db/models/fields/related_descriptors.py
+++ b/django/db/models/fields/related_descriptors.py
@@ -75,7 +75,7 @@ from django.db import (
router,
transaction,
)
-from django.db.models import Q, Window, signals
+from django.db.models import Manager, Q, Window, signals
from django.db.models.functions import RowNumber
from django.db.models.lookups import GreaterThan, LessThanOrEqual
from django.db.models.query import QuerySet
@@ -1121,6 +1121,12 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
queryset._defer_next_filter = True
return queryset._next_is_sticky().filter(**self.core_filters)
+ def get_prefetch_cache(self):
+ try:
+ return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
+ except (AttributeError, KeyError):
+ return None
+
def _remove_prefetched_objects(self):
try:
self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
@@ -1128,9 +1134,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
pass # nothing to clear from cache
def get_queryset(self):
- try:
- return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
- except (AttributeError, KeyError):
+ if (cache := self.get_prefetch_cache()) is not None:
+ return cache
+ else:
queryset = super().get_queryset()
return self._apply_rel_filters(queryset)
@@ -1195,6 +1201,45 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
False,
)
+ @property
+ def constrained_target(self):
+ # If the through relation's target field's foreign integrity is
+ # enforced, the query can be performed solely against the through
+ # table as the INNER JOIN'ing against target table is unnecessary.
+ if not self.target_field.db_constraint:
+ return None
+ db = router.db_for_read(self.through, instance=self.instance)
+ if not connections[db].features.supports_foreign_keys:
+ return None
+ hints = {"instance": self.instance}
+ manager = self.through._base_manager.db_manager(db, hints=hints)
+ filters = {self.source_field_name: self.instance.pk}
+ # Nullable target rows must be excluded as well as they would have
+ # been filtered out from an INNER JOIN.
+ if self.target_field.null:
+ filters["%s__isnull" % self.target_field_name] = False
+ return manager.filter(**filters)
+
+ def exists(self):
+ if (
+ superclass is Manager
+ and self.get_prefetch_cache() is None
+ and (constrained_target := self.constrained_target) is not None
+ ):
+ return constrained_target.exists()
+ else:
+ return super().exists()
+
+ def count(self):
+ if (
+ superclass is Manager
+ and self.get_prefetch_cache() is None
+ and (constrained_target := self.constrained_target) is not None
+ ):
+ return constrained_target.count()
+ else:
+ return super().count()
+
def add(self, *objs, through_defaults=None):
self._remove_prefetched_objects()
db = router.db_for_write(self.through, instance=self.instance)
diff --git a/tests/many_to_many/models.py b/tests/many_to_many/models.py
index 42fc426990..df7222e08d 100644
--- a/tests/many_to_many/models.py
+++ b/tests/many_to_many/models.py
@@ -78,3 +78,15 @@ class InheritedArticleA(AbstractArticle):
class InheritedArticleB(AbstractArticle):
pass
+
+
+class NullableTargetArticle(models.Model):
+ headline = models.CharField(max_length=100)
+ publications = models.ManyToManyField(
+ Publication, through="NullablePublicationThrough"
+ )
+
+
+class NullablePublicationThrough(models.Model):
+ article = models.ForeignKey(NullableTargetArticle, models.CASCADE)
+ publication = models.ForeignKey(Publication, models.CASCADE, null=True)
diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py
index 7ed3b80abc..351e4eb8cc 100644
--- a/tests/many_to_many/tests.py
+++ b/tests/many_to_many/tests.py
@@ -1,10 +1,18 @@
from unittest import mock
-from django.db import transaction
+from django.db import connection, transaction
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango60Warning
-from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User
+from .models import (
+ Article,
+ InheritedArticleA,
+ InheritedArticleB,
+ NullablePublicationThrough,
+ NullableTargetArticle,
+ Publication,
+ User,
+)
class ManyToManyTests(TestCase):
@@ -558,10 +566,16 @@ class ManyToManyTests(TestCase):
def test_custom_default_manager_exists_count(self):
a5 = Article.objects.create(headline="deleted")
a5.publications.add(self.p2)
- self.assertEqual(self.p2.article_set.count(), self.p2.article_set.all().count())
- self.assertEqual(
- self.p3.article_set.exists(), self.p3.article_set.all().exists()
- )
+ with self.assertNumQueries(2) as ctx:
+ self.assertEqual(
+ self.p2.article_set.count(), self.p2.article_set.all().count()
+ )
+ self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
+ with self.assertNumQueries(2) as ctx:
+ self.assertEqual(
+ self.p3.article_set.exists(), self.p3.article_set.all().exists()
+ )
+ self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
def test_get_prefetch_queryset_warning(self):
articles = Article.objects.all()
@@ -582,3 +596,73 @@ class ManyToManyTests(TestCase):
instances=articles,
querysets=[Publication.objects.all(), Publication.objects.all()],
)
+
+
+class ManyToManyQueryTests(TestCase):
+ """
+ SQL is optimized to reference the through table without joining against the
+ related table when using count() and exists() functions on a queryset for
+ many to many relations. The optimization applies to the case where there
+ are no filters.
+ """
+
+ @classmethod
+ def setUpTestData(cls):
+ cls.article = Article.objects.create(
+ headline="Django lets you build Web apps easily"
+ )
+ cls.nullable_target_article = NullableTargetArticle.objects.create(
+ headline="The python is good"
+ )
+ NullablePublicationThrough.objects.create(
+ article=cls.nullable_target_article, publication=None
+ )
+
+ @skipUnlessDBFeature("supports_foreign_keys")
+ def test_count_join_optimization(self):
+ with self.assertNumQueries(1) as ctx:
+ self.article.publications.count()
+ self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+
+ with self.assertNumQueries(1) as ctx:
+ self.article.publications.count()
+ self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+ self.assertEqual(self.nullable_target_article.publications.count(), 0)
+
+ def test_count_join_optimization_disabled(self):
+ with (
+ mock.patch.object(connection.features, "supports_foreign_keys", False),
+ self.assertNumQueries(1) as ctx,
+ ):
+ self.article.publications.count()
+
+ self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
+
+ @skipUnlessDBFeature("supports_foreign_keys")
+ def test_exists_join_optimization(self):
+ with self.assertNumQueries(1) as ctx:
+ self.article.publications.exists()
+ self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+
+ self.article.publications.prefetch_related()
+ with self.assertNumQueries(1) as ctx:
+ self.article.publications.exists()
+ self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+ self.assertIs(self.nullable_target_article.publications.exists(), False)
+
+ def test_exists_join_optimization_disabled(self):
+ with (
+ mock.patch.object(connection.features, "supports_foreign_keys", False),
+ self.assertNumQueries(1) as ctx,
+ ):
+ self.article.publications.exists()
+
+ self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
+
+ def test_prefetch_related_no_queries_optimization_disabled(self):
+ qs = Article.objects.prefetch_related("publications")
+ article = qs.get()
+ with self.assertNumQueries(0):
+ article.publications.count()
+ with self.assertNumQueries(0):
+ article.publications.exists()