diff options
Diffstat (limited to 'django/db/models/sql')
| -rw-r--r-- | django/db/models/sql/compiler.py | 46 |
1 files changed, 25 insertions, 21 deletions
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 8742de00d6..6c758fb526 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1690,11 +1690,12 @@ class SQLInsertCompiler(SQLCompiler): returning_fields = None returning_params = () - def field_as_sql(self, field, get_placeholder, val): + def field_as_sql(self, field, get_placeholder_sql, val): """ Take a field and a value intended to be saved on that field, and return placeholder SQL and accompanying params. Check for raw values, - expressions, and fields with get_placeholder() defined in that order. + fields with get_placeholder_sql(), and compilable defined in that + order. When field is None, consider the value raw and use it as the placeholder, with no corresponding parameters returned. @@ -1702,13 +1703,13 @@ class SQLInsertCompiler(SQLCompiler): if field is None: # A field value of None means the value is raw. sql, params = val, [] + elif get_placeholder_sql is not None: + # Some fields (e.g. geo fields) need special munging before + # they can be inserted. + sql, params = get_placeholder_sql(val, self, self.connection) elif hasattr(val, "as_sql"): # This is an expression, let's compile it. sql, params = self.compile(val) - elif get_placeholder is not None: - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - sql, params = get_placeholder(val, self, self.connection), [val] else: # Return the common case for the placeholder sql, params = "%s", [val] @@ -1777,11 +1778,15 @@ class SQLInsertCompiler(SQLCompiler): # list of (sql, [params]) tuples for each object to be saved # Shape: [n_objs][n_fields][2] - get_placeholders = [getattr(field, "get_placeholder", None) for field in fields] + get_placeholder_sqls = [ + getattr(field, "get_placeholder_sql", None) for field in fields + ] rows_of_fields_as_sql = ( ( - self.field_as_sql(field, get_placeholder, value) - for field, get_placeholder, value in zip(fields, get_placeholders, row) + self.field_as_sql(field, get_placeholder_sql, value) + for field, get_placeholder_sql, value in zip( + fields, get_placeholder_sqls, row + ) ) for row in value_rows ) @@ -2078,21 +2083,20 @@ class SQLUpdateCompiler(SQLCompiler): ) val = field.get_db_prep_save(val, connection=self.connection) - # Getting the placeholder for the field. - if hasattr(field, "get_placeholder"): - placeholder = field.get_placeholder(val, self, self.connection) - else: - placeholder = "%s" - name = field.column - if hasattr(val, "as_sql"): + quoted_name = qn(field.column) + if ( + get_placeholder_sql := getattr(field, "get_placeholder_sql", None) + ) is not None: + sql, params = get_placeholder_sql(val, self, self.connection) + values.append(f"{quoted_name} = {sql}") + update_params.extend(params) + elif hasattr(val, "as_sql"): sql, params = self.compile(val) - values.append("%s = %s" % (qn(name), placeholder % sql)) + values.append(f"{quoted_name} = {sql}") update_params.extend(params) - elif val is not None: - values.append("%s = %s" % (qn(name), placeholder)) - update_params.append(val) else: - values.append("%s = NULL" % qn(name)) + values.append(f"{quoted_name} = %s") + update_params.append(val) table = self.query.base_table result = [ "UPDATE %s SET" % qn(table), |
