summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Walls <jacobtylerwalls@gmail.com>2025-02-16 21:35:12 -0500
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2025-02-18 17:29:34 +0100
commit9525135698bd4f97cf1431776ef52ae393dfb3c0 (patch)
tree3ded1c4ebc817866fd9488ebcf0b21378f465096
parentbb4f65ec8719db114b42947605e7ceb131698662 (diff)
[5.2.x] Fixed #35167 -- Delegated to super() in JSONField.get_db_prep_save().
Avoids reports of bulk_update() sending Cast expressions to JSONField.get_prep_value(). Co-authored-by: Simon Charette <charette.s@gmail.com> Backport of 0bf412111be686b6b23e00863f5d449d63557dbf from main.
-rw-r--r--django/db/models/fields/json.py19
-rw-r--r--tests/model_fields/models.py11
-rw-r--r--tests/model_fields/test_jsonfield.py19
3 files changed, 41 insertions, 8 deletions
diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py
index 8d743c436a..188fcf520c 100644
--- a/django/db/models/fields/json.py
+++ b/django/db/models/fields/json.py
@@ -99,18 +99,23 @@ class JSONField(CheckFieldDefaultMixin, Field):
def get_db_prep_value(self, value, connection, prepared=False):
if not prepared:
value = self.get_prep_value(value)
- if isinstance(value, expressions.Value) and isinstance(
- value.output_field, JSONField
- ):
- value = value.value
- elif hasattr(value, "as_sql"):
- return value
return connection.ops.adapt_json_value(value, self.encoder)
def get_db_prep_save(self, value, connection):
+ # This slightly involved logic is to allow for `None` to be used to
+ # store SQL `NULL` while `Value(None, JSONField())` can be used to
+ # store JSON `null` while preventing compilable `as_sql` values from
+ # making their way to `get_db_prep_value`, which is what the `super()`
+ # implementation does.
if value is None:
return value
- return self.get_db_prep_value(value, connection)
+ if (
+ isinstance(value, expressions.Value)
+ and value.value is None
+ and isinstance(value.output_field, JSONField)
+ ):
+ value = None
+ return super().get_db_prep_save(value, connection)
def get_transform(self, name):
transform = super().get_transform(name)
diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py
index fdea06b23d..299e927615 100644
--- a/tests/model_fields/models.py
+++ b/tests/model_fields/models.py
@@ -430,6 +430,17 @@ class RelatedJSONModel(models.Model):
required_db_features = {"supports_json_field"}
+class CustomSerializationJSONModel(models.Model):
+ class StringifiedJSONField(models.JSONField):
+ def get_prep_value(self, value):
+ return json.dumps(value, cls=self.encoder)
+
+ json_field = StringifiedJSONField()
+
+ class Meta:
+ required_db_features = {"supports_json_field"}
+
+
class AllFieldsModel(models.Model):
big_integer = models.BigIntegerField()
binary = models.BinaryField()
diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py
index 5a9cf9ad7a..267b9a0e66 100644
--- a/tests/model_fields/test_jsonfield.py
+++ b/tests/model_fields/test_jsonfield.py
@@ -40,7 +40,13 @@ from django.db.models.functions import Cast
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
-from .models import CustomJSONDecoder, JSONModel, NullableJSONModel, RelatedJSONModel
+from .models import (
+ CustomJSONDecoder,
+ CustomSerializationJSONModel,
+ JSONModel,
+ NullableJSONModel,
+ RelatedJSONModel,
+)
@skipUnlessDBFeature("supports_json_field")
@@ -298,6 +304,17 @@ class TestSaveLoad(TestCase):
obj.refresh_from_db()
self.assertEqual(obj.value, value)
+ def test_bulk_update_custom_get_prep_value(self):
+ objs = CustomSerializationJSONModel.objects.bulk_create(
+ [CustomSerializationJSONModel(pk=1, json_field={"version": "1"})]
+ )
+ objs[0].json_field["version"] = "1-alpha"
+ CustomSerializationJSONModel.objects.bulk_update(objs, ["json_field"])
+ self.assertSequenceEqual(
+ CustomSerializationJSONModel.objects.values("json_field"),
+ [{"json_field": '{"version": "1-alpha"}'}],
+ )
+
@skipUnlessDBFeature("supports_json_field")
class TestQuerying(TestCase):