summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Sanders <dsanders11@ucsbalum.com>2016-04-17 10:03:08 -0700
committerTim Graham <timograham@gmail.com>2016-06-29 14:08:13 -0400
commita84344bc539c66589c8d4fe30c6ceaecf8ba1af3 (patch)
tree5e4b4fac4942f56eb444016c1b15eb0e88e91f18
parent06acb3445f6a2decf17f9b7be91a6637024e02c1 (diff)
Fixed #19513, #18580 -- Fixed crash on QuerySet.update() after annotate().
-rw-r--r--django/db/models/query.py2
-rw-r--r--django/db/models/sql/subqueries.py6
-rw-r--r--tests/update/tests.py22
3 files changed, 28 insertions, 2 deletions
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 40f7ae6ea9..4085e618cf 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -632,6 +632,8 @@ class QuerySet(object):
self._for_write = True
query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs)
+ # Clear any annotations so that they won't be present in subqueries.
+ query._annotations = None
with transaction.atomic(using=self.db, savepoint=False):
rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 316a5c684d..fc9683064f 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -142,7 +142,11 @@ class UpdateQuery(Query):
that will be used to generate the UPDATE query. Might be more usefully
called add_update_targets() to hint at the extra information here.
"""
- self.values.extend(values_seq)
+ for field, model, val in values_seq:
+ if hasattr(val, 'resolve_expression'):
+ # Resolve expressions here so that annotations are no longer needed
+ val = val.resolve_expression(self, allow_joins=False, for_save=True)
+ self.values.append((field, model, val))
def add_related_update(self, model, field, value):
"""
diff --git a/tests/update/tests.py b/tests/update/tests.py
index 3dc97c9173..89593f8dfc 100644
--- a/tests/update/tests.py
+++ b/tests/update/tests.py
@@ -1,7 +1,7 @@
from __future__ import unicode_literals
from django.core.exceptions import FieldError
-from django.db.models import F, Max
+from django.db.models import Count, F, Max
from django.test import TestCase
from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint
@@ -158,3 +158,23 @@ class AdvancedTests(TestCase):
qs = DataPoint.objects.annotate(max=Max('value'))
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
qs.update(another_value=F('max'))
+
+ def test_update_annotated_multi_table_queryset(self):
+ """
+ Update of a queryset that's been annotated and involves multiple tables.
+ """
+ # Trivial annotated update
+ qs = DataPoint.objects.annotate(related_count=Count('relatedpoint'))
+ self.assertEqual(qs.update(value='Foo'), 3)
+ # Update where annotation is used for filtering
+ qs = DataPoint.objects.annotate(related_count=Count('relatedpoint'))
+ self.assertEqual(qs.filter(related_count=1).update(value='Foo'), 1)
+ # Update where annotation is used in update parameters
+ # #26539 - This isn't forbidden but also doesn't generate proper SQL
+ # qs = RelatedPoint.objects.annotate(data_name=F('data__name'))
+ # updated = qs.update(name=F('data_name'))
+ # self.assertEqual(updated, 1)
+ # Update where aggregation annotation is used in update parameters
+ qs = RelatedPoint.objects.annotate(max=Max('data__value'))
+ with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
+ qs.update(name=F('max'))