summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStéphane "Twidi" Angel <s.angel@twidi.com>2022-07-07 04:26:49 +0200
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-07-08 08:17:42 +0200
commitccbf714ebeff51d1370789e5e487a978d0e2dbfb (patch)
treeb358e51931b1bad5f0890969cb1630f7792d81cd
parent41019e48bbf082c985e6ba3bad34d118b903bff1 (diff)
Fixed #33829 -- Made BaseConstraint.deconstruct() and equality handle violation_error_message.
Regression in 667105877e6723c6985399803a364848891513cc.
-rw-r--r--django/contrib/postgres/constraints.py1
-rw-r--r--django/db/models/constraints.py20
-rw-r--r--tests/constraints/tests.py77
-rw-r--r--tests/postgres_tests/test_constraints.py22
4 files changed, 117 insertions, 3 deletions
diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py
index c19602b26e..2e6f7f7998 100644
--- a/django/contrib/postgres/constraints.py
+++ b/django/contrib/postgres/constraints.py
@@ -177,6 +177,7 @@ class ExclusionConstraint(BaseConstraint):
and self.deferrable == other.deferrable
and self.include == other.include
and self.opclasses == other.opclasses
+ and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py
index 9949b50b1e..86f015465a 100644
--- a/django/db/models/constraints.py
+++ b/django/db/models/constraints.py
@@ -14,12 +14,15 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"
class BaseConstraint:
- violation_error_message = _("Constraint “%(name)s” is violated.")
+ default_violation_error_message = _("Constraint “%(name)s” is violated.")
+ violation_error_message = None
def __init__(self, name, violation_error_message=None):
self.name = name
if violation_error_message is not None:
self.violation_error_message = violation_error_message
+ else:
+ self.violation_error_message = self.default_violation_error_message
@property
def contains_expressions(self):
@@ -43,7 +46,13 @@ class BaseConstraint:
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.constraints", "django.db.models")
- return (path, (), {"name": self.name})
+ kwargs = {"name": self.name}
+ if (
+ self.violation_error_message is not None
+ and self.violation_error_message != self.default_violation_error_message
+ ):
+ kwargs["violation_error_message"] = self.violation_error_message
+ return (path, (), kwargs)
def clone(self):
_, args, kwargs = self.deconstruct()
@@ -94,7 +103,11 @@ class CheckConstraint(BaseConstraint):
def __eq__(self, other):
if isinstance(other, CheckConstraint):
- return self.name == other.name and self.check == other.check
+ return (
+ self.name == other.name
+ and self.check == other.check
+ and self.violation_error_message == other.violation_error_message
+ )
return super().__eq__(other)
def deconstruct(self):
@@ -273,6 +286,7 @@ class UniqueConstraint(BaseConstraint):
and self.include == other.include
and self.opclasses == other.opclasses
and self.expressions == other.expressions
+ and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py
index d9e377438e..4032b418b4 100644
--- a/tests/constraints/tests.py
+++ b/tests/constraints/tests.py
@@ -65,6 +65,29 @@ class BaseConstraintTests(SimpleTestCase):
)
self.assertEqual(c.get_violation_error_message(), "custom base_name message")
+ def test_custom_violation_error_message_clone(self):
+ constraint = BaseConstraint(
+ "base_name",
+ violation_error_message="custom %(name)s message",
+ ).clone()
+ self.assertEqual(
+ constraint.get_violation_error_message(),
+ "custom base_name message",
+ )
+
+ def test_deconstruction(self):
+ constraint = BaseConstraint(
+ "base_name",
+ violation_error_message="custom %(name)s message",
+ )
+ path, args, kwargs = constraint.deconstruct()
+ self.assertEqual(path, "django.db.models.BaseConstraint")
+ self.assertEqual(args, ())
+ self.assertEqual(
+ kwargs,
+ {"name": "base_name", "violation_error_message": "custom %(name)s message"},
+ )
+
class CheckConstraintTests(TestCase):
def test_eq(self):
@@ -84,6 +107,28 @@ class CheckConstraintTests(TestCase):
models.CheckConstraint(check=check2, name="price"),
)
self.assertNotEqual(models.CheckConstraint(check=check1, name="price"), 1)
+ self.assertNotEqual(
+ models.CheckConstraint(check=check1, name="price"),
+ models.CheckConstraint(
+ check=check1, name="price", violation_error_message="custom error"
+ ),
+ )
+ self.assertNotEqual(
+ models.CheckConstraint(
+ check=check1, name="price", violation_error_message="custom error"
+ ),
+ models.CheckConstraint(
+ check=check1, name="price", violation_error_message="other custom error"
+ ),
+ )
+ self.assertEqual(
+ models.CheckConstraint(
+ check=check1, name="price", violation_error_message="custom error"
+ ),
+ models.CheckConstraint(
+ check=check1, name="price", violation_error_message="custom error"
+ ),
+ )
def test_repr(self):
constraint = models.CheckConstraint(
@@ -216,6 +261,38 @@ class UniqueConstraintTests(TestCase):
self.assertNotEqual(
models.UniqueConstraint(fields=["foo", "bar"], name="unique"), 1
)
+ self.assertNotEqual(
+ models.UniqueConstraint(fields=["foo", "bar"], name="unique"),
+ models.UniqueConstraint(
+ fields=["foo", "bar"],
+ name="unique",
+ violation_error_message="custom error",
+ ),
+ )
+ self.assertNotEqual(
+ models.UniqueConstraint(
+ fields=["foo", "bar"],
+ name="unique",
+ violation_error_message="custom error",
+ ),
+ models.UniqueConstraint(
+ fields=["foo", "bar"],
+ name="unique",
+ violation_error_message="other custom error",
+ ),
+ )
+ self.assertEqual(
+ models.UniqueConstraint(
+ fields=["foo", "bar"],
+ name="unique",
+ violation_error_message="custom error",
+ ),
+ models.UniqueConstraint(
+ fields=["foo", "bar"],
+ name="unique",
+ violation_error_message="custom error",
+ ),
+ )
def test_eq_with_condition(self):
self.assertEqual(
diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py
index d36c6fd9ed..a33c485a36 100644
--- a/tests/postgres_tests/test_constraints.py
+++ b/tests/postgres_tests/test_constraints.py
@@ -444,17 +444,39 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
)
self.assertNotEqual(constraint_2, constraint_9)
self.assertNotEqual(constraint_7, constraint_8)
+
+ constraint_10 = ExclusionConstraint(
+ name="exclude_overlapping",
+ expressions=[
+ (F("datespan"), RangeOperators.OVERLAPS),
+ (F("room"), RangeOperators.EQUAL),
+ ],
+ condition=Q(cancelled=False),
+ violation_error_message="custom error",
+ )
+ constraint_11 = ExclusionConstraint(
+ name="exclude_overlapping",
+ expressions=[
+ (F("datespan"), RangeOperators.OVERLAPS),
+ (F("room"), RangeOperators.EQUAL),
+ ],
+ condition=Q(cancelled=False),
+ violation_error_message="other custom error",
+ )
self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2)
self.assertNotEqual(constraint_1, constraint_3)
self.assertNotEqual(constraint_1, constraint_4)
+ self.assertNotEqual(constraint_1, constraint_10)
self.assertNotEqual(constraint_2, constraint_3)
self.assertNotEqual(constraint_2, constraint_4)
self.assertNotEqual(constraint_2, constraint_7)
self.assertNotEqual(constraint_4, constraint_5)
self.assertNotEqual(constraint_5, constraint_6)
self.assertNotEqual(constraint_1, object())
+ self.assertNotEqual(constraint_10, constraint_11)
+ self.assertEqual(constraint_10, constraint_10)
def test_deconstruct(self):
constraint = ExclusionConstraint(