summaryrefslogtreecommitdiff
path: root/django/db/models/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models/sql/compiler.py')
-rw-r--r--django/db/models/sql/compiler.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 73dfa5b87c..0e483dc4f6 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -2020,6 +2020,9 @@ class SQLDeleteCompiler(SQLCompiler):
class SQLUpdateCompiler(SQLCompiler):
+ returning_fields = None
+ returning_params = ()
+
def as_sql(self):
"""
Create the SQL for this query. Return the SQL string and list of
@@ -2087,6 +2090,15 @@ class SQLUpdateCompiler(SQLCompiler):
params = []
else:
result.append("WHERE %s" % where)
+ if self.returning_fields:
+ # Skip empty r_sql to allow subclasses to customize behavior for
+ # 3rd party backends. Refs #19096.
+ r_sql, self.returning_params = self.connection.ops.returning_columns(
+ self.returning_fields
+ )
+ if r_sql:
+ result.append(r_sql)
+ params.extend(self.returning_params)
return " ".join(result), tuple(update_params + params)
def execute_sql(self, result_type):
@@ -2110,6 +2122,38 @@ class SQLUpdateCompiler(SQLCompiler):
is_empty = False
return row_count
+ def execute_returning_sql(self, returning_fields):
+ """
+ Execute the specified update and return rows of the returned columns
+ associated with the specified returning_field if the backend supports
+ it.
+ """
+ if self.query.get_related_updates():
+ raise NotImplementedError(
+ "Update returning is not implemented for queries with related updates."
+ )
+
+ if (
+ not returning_fields
+ or not self.connection.features.can_return_rows_from_update
+ ):
+ row_count = self.execute_sql(ROW_COUNT)
+ return [()] * row_count
+
+ self.returning_fields = returning_fields
+ with self.connection.cursor() as cursor:
+ sql, params = self.as_sql()
+ cursor.execute(sql, params)
+ rows = self.connection.ops.fetch_returned_rows(
+ cursor, self.returning_params
+ )
+ opts = self.query.get_meta()
+ cols = [field.get_col(opts.db_table) for field in self.returning_fields]
+ converters = self.get_converters(cols)
+ if converters:
+ rows = self.apply_converters(rows, converters)
+ return list(rows)
+
def pre_sql_setup(self):
"""
If the update depends on results from other tables, munge the "where"