diff options
| author | Simon Charette <charette.s@gmail.com> | 2025-03-19 01:11:34 -0400 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2025-09-14 00:27:49 +0200 |
| commit | 55a0073b3beb9de8f7c1f7c44a7d0bc10126c841 (patch) | |
| tree | 616a0bf54b0d9e3d09a2d033980f07bbb2a83e0d /django/db | |
| parent | c48904a225e2e8f02274257247d5b7d29c5fe183 (diff) | |
Refs #27222 -- Refreshed GeneratedFields values on save() initiated update.
This required implementing UPDATE RETURNING machinery that heavily
borrows from the INSERT one.
Diffstat (limited to 'django/db')
| -rw-r--r-- | django/db/backends/base/features.py | 1 | ||||
| -rw-r--r-- | django/db/backends/oracle/base.py | 1 | ||||
| -rw-r--r-- | django/db/backends/oracle/features.py | 1 | ||||
| -rw-r--r-- | django/db/backends/postgresql/features.py | 1 | ||||
| -rw-r--r-- | django/db/backends/sqlite3/features.py | 4 | ||||
| -rw-r--r-- | django/db/models/base.py | 73 | ||||
| -rw-r--r-- | django/db/models/fields/generated.py | 10 | ||||
| -rw-r--r-- | django/db/models/query.py | 6 | ||||
| -rw-r--r-- | django/db/models/sql/compiler.py | 44 |
9 files changed, 117 insertions, 24 deletions
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 49582e6261..0c79e5c133 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -38,6 +38,7 @@ class BaseDatabaseFeatures: can_use_chunked_reads = True can_return_columns_from_insert = False can_return_rows_from_bulk_insert = False + can_return_rows_from_update = False has_bulk_insert = True uses_savepoints = True can_release_savepoints = False diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index c8b49609bd..bf79f7a6e3 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -243,6 +243,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): "use_returning_into", True ) self.features.can_return_columns_from_insert = use_returning_into + self.features.can_return_rows_from_update = use_returning_into @property def is_pool(self): diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 7ca40a8000..e87f495e5c 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_select_for_update_of = True select_for_update_of_column = True can_return_columns_from_insert = True + can_return_rows_from_update = True supports_subqueries_in_group_by = False ignores_unnecessary_order_by_in_subqueries = False supports_tuple_comparison_against_subquery = False diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 419fad8686..5f63b6c713 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): allows_group_by_selected_pks = True can_return_columns_from_insert = True can_return_rows_from_bulk_insert = True + can_return_rows_from_update = True has_real_datatype = True has_native_uuid_field = True has_native_duration_field = True diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 8604adf40a..143ee1e98b 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -171,3 +171,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): can_return_rows_from_bulk_insert = property( operator.attrgetter("can_return_columns_from_insert") ) + + can_return_rows_from_update = property( + operator.attrgetter("can_return_columns_from_insert") + ) diff --git a/django/db/models/base.py b/django/db/models/base.py index 3827b00346..93e53bde95 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1094,12 +1094,28 @@ class Model(AltersData, metaclass=ModelBase): ] forced_update = update_fields or force_update pk_val = self._get_pk_val(meta) - updated = self._do_update( - base_qs, using, pk_val, values, update_fields, forced_update + returning_fields = [ + f + for f in meta.local_concrete_fields + if ( + f.generated + and f.referenced_fields.intersection(non_pks_non_generated) + ) + ] + results = self._do_update( + base_qs, + using, + pk_val, + values, + update_fields, + forced_update, + returning_fields, ) - if force_update and not updated: + if updated := bool(results): + self._assign_returned_values(results[0], returning_fields) + elif force_update: raise self.NotUpdated("Forced update did not affect any rows.") - if update_fields and not updated: + elif update_fields: raise self.NotUpdated( "Save with update_fields did not affect any rows." ) @@ -1131,11 +1147,19 @@ class Model(AltersData, metaclass=ModelBase): cls._base_manager, using, fields, returning_fields, raw ) if results: - for value, field in zip(results[0], returning_fields): - setattr(self, field.attname, value) + self._assign_returned_values(results[0], returning_fields) return updated - def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): + def _do_update( + self, + base_qs, + using, + pk_val, + values, + update_fields, + forced_update, + returning_fields, + ): """ Try to update the model. Return True if the model was updated (if an update query was done and a matching row was found in the DB). @@ -1147,22 +1171,23 @@ class Model(AltersData, metaclass=ModelBase): # case we just say the update succeeded. Another case ending up # here is a model with just PK - in that case check that the PK # still exists. - return update_fields is not None or filtered.exists() + if update_fields is not None or filtered.exists(): + return [()] + return [] if self._meta.select_on_save and not forced_update: - return ( - filtered.exists() - and - # It may happen that the object is deleted from the DB right - # after this check, causing the subsequent UPDATE to return - # zero matching rows. The same result can occur in some rare - # cases when the database returns zero despite the UPDATE being - # executed successfully (a row is matched and updated). In - # order to distinguish these two cases, the object's existence - # in the database is again checked for if the UPDATE query - # returns 0. - (filtered._update(values) > 0 or filtered.exists()) - ) - return filtered._update(values) > 0 + # It may happen that the object is deleted from the DB right after + # this check, causing the subsequent UPDATE to return zero matching + # rows. The same result can occur in some rare cases when the + # database returns zero despite the UPDATE being executed + # successfully (a row is matched and updated). In order to + # distinguish these two cases, the object's existence in the + # database is again checked for if the UPDATE query returns 0. + if not filtered.exists(): + return [] + if results := filtered._update(values, returning_fields): + return results + return [()] if filtered.exists() else [] + return filtered._update(values, returning_fields) def _do_insert(self, manager, using, fields, returning_fields, raw): """ @@ -1177,6 +1202,10 @@ class Model(AltersData, metaclass=ModelBase): raw=raw, ) + def _assign_returned_values(self, returned_values, returning_fields): + for value, field in zip(returned_values, returning_fields): + setattr(self, field.attname, value) + def _prepare_related_fields_for_save(self, operation_name, fields=None): # Ensure that a model instance without a PK hasn't been assigned to # a ForeignKey, GenericForeignKey or OneToOneField on this model. If diff --git a/django/db/models/fields/generated.py b/django/db/models/fields/generated.py index f6e3445b71..f89269b5e6 100644 --- a/django/db/models/fields/generated.py +++ b/django/db/models/fields/generated.py @@ -66,6 +66,16 @@ class GeneratedField(Field): sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" return sql, params + @cached_property + def referenced_fields(self): + resolved_expression = self.expression.resolve_expression( + self._query, allow_joins=False + ) + referenced_fields = [] + for col in self._query._gen_cols([resolved_expression]): + referenced_fields.append(col.target) + return frozenset(referenced_fields) + def check(self, **kwargs): databases = kwargs.get("databases") or [] errors = [ diff --git a/django/db/models/query.py b/django/db/models/query.py index d2f31d15a0..2359ee3bb4 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1306,7 +1306,7 @@ class QuerySet(AltersData): aupdate.alters_data = True - def _update(self, values): + def _update(self, values, returning_fields=None): """ A version of update() that accepts field objects instead of field names. Used primarily for model saving and not intended for use by @@ -1320,7 +1320,9 @@ class QuerySet(AltersData): # Clear any annotations so that they won't be present in subqueries. query.annotations = {} self._result_cache = None - return query.get_compiler(self.db).execute_sql(ROW_COUNT) + if returning_fields is None: + return query.get_compiler(self.db).execute_sql(ROW_COUNT) + return query.get_compiler(self.db).execute_returning_sql(returning_fields) _update.alters_data = True _update.queryset_only = False diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 73dfa5b87c..0e483dc4f6 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -2020,6 +2020,9 @@ class SQLDeleteCompiler(SQLCompiler): class SQLUpdateCompiler(SQLCompiler): + returning_fields = None + returning_params = () + def as_sql(self): """ Create the SQL for this query. Return the SQL string and list of @@ -2087,6 +2090,15 @@ class SQLUpdateCompiler(SQLCompiler): params = [] else: result.append("WHERE %s" % where) + if self.returning_fields: + # Skip empty r_sql to allow subclasses to customize behavior for + # 3rd party backends. Refs #19096. + r_sql, self.returning_params = self.connection.ops.returning_columns( + self.returning_fields + ) + if r_sql: + result.append(r_sql) + params.extend(self.returning_params) return " ".join(result), tuple(update_params + params) def execute_sql(self, result_type): @@ -2110,6 +2122,38 @@ class SQLUpdateCompiler(SQLCompiler): is_empty = False return row_count + def execute_returning_sql(self, returning_fields): + """ + Execute the specified update and return rows of the returned columns + associated with the specified returning_field if the backend supports + it. + """ + if self.query.get_related_updates(): + raise NotImplementedError( + "Update returning is not implemented for queries with related updates." + ) + + if ( + not returning_fields + or not self.connection.features.can_return_rows_from_update + ): + row_count = self.execute_sql(ROW_COUNT) + return [()] * row_count + + self.returning_fields = returning_fields + with self.connection.cursor() as cursor: + sql, params = self.as_sql() + cursor.execute(sql, params) + rows = self.connection.ops.fetch_returned_rows( + cursor, self.returning_params + ) + opts = self.query.get_meta() + cols = [field.get_col(opts.db_table) for field in self.returning_fields] + converters = self.get_converters(cols) + if converters: + rows = self.apply_converters(rows, converters) + return list(rows) + def pre_sql_setup(self): """ If the update depends on results from other tables, munge the "where" |
