summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMariusz Felisiak <felisiak.mariusz@gmail.com>2022-04-27 11:30:43 +0200
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-05-02 10:52:33 +0200
commit271a8e73ee382bb487d15e97ffaa675d78869413 (patch)
treec85f3ca863810dd3100578db7e9d2058f887b4f9
parent77926176b281b9c553c934e52acdd1c0377ea601 (diff)
Refs #33646 -- Made QuerySet.raw() async-compatible.
-rw-r--r--django/db/models/query.py16
-rw-r--r--tests/async_queryset/tests.py5
2 files changed, 21 insertions, 0 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 80dc83a41c..ce7e295d52 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -2025,6 +2025,12 @@ class RawQuerySet:
if self._prefetch_related_lookups and not self._prefetch_done:
self._prefetch_related_objects()
+ async def _async_fetch_all(self):
+ if self._result_cache is None:
+ self._result_cache = [result async for result in RawModelIterable(self)]
+ if self._prefetch_related_lookups and not self._prefetch_done:
+ sync_to_async(self._prefetch_related_objects)()
+
def __len__(self):
self._fetch_all()
return len(self._result_cache)
@@ -2037,6 +2043,16 @@ class RawQuerySet:
self._fetch_all()
return iter(self._result_cache)
+ def __aiter__(self):
+ # Remember, __aiter__ itself is synchronous, it's the thing it returns
+ # that is async!
+ async def generator():
+ await self._async_fetch_all()
+ for item in self._result_cache:
+ yield item
+
+ return generator()
+
def iterator(self):
yield from RawModelIterable(self)
diff --git a/tests/async_queryset/tests.py b/tests/async_queryset/tests.py
index f600cfe392..792797fb9d 100644
--- a/tests/async_queryset/tests.py
+++ b/tests/async_queryset/tests.py
@@ -225,3 +225,8 @@ class AsyncQuerySetTest(TestCase):
json.loads(result)
except json.JSONDecodeError as e:
self.fail(f"QuerySet.aexplain() result is not valid JSON: {e}")
+
+ async def test_raw(self):
+ sql = "SELECT id, field FROM async_queryset_simplemodel WHERE created=%s"
+ qs = SimpleModel.objects.raw(sql, [self.s1.created])
+ self.assertEqual([o async for o in qs], [self.s1])