summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Gensler <mark.gensler@protonmail.com>2024-07-18 08:38:06 +0100
committerSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2024-08-12 13:45:57 +0200
commit228128618bd895ecad235d2215f4ad4e3232595d (patch)
treee0e73d595dce097d69909bee06c86019c61dc20b
parentf883bef05457a5a49eb31109429fc01737f82532 (diff)
Fixed #35575 -- Added support for constraint validation on GeneratedFields.
-rw-r--r--django/contrib/postgres/constraints.py12
-rw-r--r--django/db/models/base.py35
-rw-r--r--django/db/models/constraints.py81
-rw-r--r--docs/releases/5.2.txt3
-rw-r--r--tests/constraints/models.py41
-rw-r--r--tests/constraints/tests.py111
-rw-r--r--tests/postgres_tests/test_constraints.py34
7 files changed, 273 insertions, 44 deletions
diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py
index 2701c4ba48..49124adc15 100644
--- a/django/contrib/postgres/constraints.py
+++ b/django/contrib/postgres/constraints.py
@@ -183,17 +183,11 @@ class ExclusionConstraint(BaseConstraint):
)
replacements = {F(field): value for field, value in replacement_map.items()}
lookups = []
- for idx, (expression, operator) in enumerate(self.expressions):
+ for expression, operator in self.expressions:
if isinstance(expression, str):
expression = F(expression)
- if exclude:
- if isinstance(expression, F):
- if expression.name in exclude:
- return
- else:
- for expr in expression.flatten():
- if isinstance(expr, F) and expr.name in exclude:
- return
+ if exclude and self._expression_refs_exclude(model, expression, exclude):
+ return
rhs_expression = expression.replace_expressions(replacements)
if hasattr(expression, "get_expression_for_validation"):
expression = expression.get_expression_for_validation()
diff --git a/django/db/models/base.py b/django/db/models/base.py
index d4b8bab963..a89ceafbef 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase):
if exclude is None:
exclude = set()
meta = meta or self._meta
- field_map = {
- field.name: (
- value
- if (value := getattr(self, field.attname))
- and hasattr(value, "resolve_expression")
- else Value(value, field)
- )
- for field in meta.local_concrete_fields
- if field.name not in exclude and not field.generated
- }
+ field_map = {}
+ generated_fields = []
+ for field in meta.local_concrete_fields:
+ if field.name in exclude:
+ continue
+ if field.generated:
+ if any(
+ ref[0] in exclude
+ for ref in self._get_expr_references(field.expression)
+ ):
+ continue
+ generated_fields.append(field)
+ continue
+ value = getattr(self, field.attname)
+ if not value or not hasattr(value, "resolve_expression"):
+ value = Value(value, field)
+ field_map[field.name] = value
if "pk" not in exclude:
field_map["pk"] = Value(self.pk, meta.pk)
+ if generated_fields:
+ replacements = {F(name): value for name, value in field_map.items()}
+ for generated_field in generated_fields:
+ field_map[generated_field.name] = ExpressionWrapper(
+ generated_field.expression.replace_expressions(replacements),
+ generated_field.output_field,
+ )
+
return field_map
def prepare_database_save(self, field):
diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py
index 915ace5129..b5952def6a 100644
--- a/django/db/models/constraints.py
+++ b/django/db/models/constraints.py
@@ -68,6 +68,19 @@ class BaseConstraint:
def remove_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
+ @classmethod
+ def _expression_refs_exclude(cls, model, expression, exclude):
+ get_field = model._meta.get_field
+ for field_name, *__ in model._get_expr_references(expression):
+ if field_name in exclude:
+ return True
+ field = get_field(field_name)
+ if field.generated and cls._expression_refs_exclude(
+ model, field.expression, exclude
+ ):
+ return True
+ return False
+
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
raise NotImplementedError("This method must be implemented by a subclass.")
@@ -606,36 +619,56 @@ class UniqueConstraint(BaseConstraint):
queryset = model._default_manager.using(using)
if self.fields:
lookup_kwargs = {}
+ generated_field_names = []
for field_name in self.fields:
if exclude and field_name in exclude:
return
field = model._meta.get_field(field_name)
- lookup_value = getattr(instance, field.attname)
- if (
- self.nulls_distinct is not False
- and lookup_value is None
- or (
- lookup_value == ""
- and connections[
- using
- ].features.interprets_empty_strings_as_nulls
- )
- ):
- # A composite constraint containing NULL value cannot cause
- # a violation since NULL != NULL in SQL.
- return
- lookup_kwargs[field.name] = lookup_value
- queryset = queryset.filter(**lookup_kwargs)
+ if field.generated:
+ if exclude and self._expression_refs_exclude(
+ model, field.expression, exclude
+ ):
+ return
+ generated_field_names.append(field.name)
+ else:
+ lookup_value = getattr(instance, field.attname)
+ if (
+ self.nulls_distinct is not False
+ and lookup_value is None
+ or (
+ lookup_value == ""
+ and connections[
+ using
+ ].features.interprets_empty_strings_as_nulls
+ )
+ ):
+ # A composite constraint containing NULL value cannot cause
+ # a violation since NULL != NULL in SQL.
+ return
+ lookup_kwargs[field.name] = lookup_value
+ lookup_args = []
+ if generated_field_names:
+ field_expression_map = instance._get_field_expression_map(
+ meta=model._meta, exclude=exclude
+ )
+ for field_name in generated_field_names:
+ expression = field_expression_map[field_name]
+ if self.nulls_distinct is False:
+ lhs = F(field_name)
+ condition = Q(Exact(lhs, expression)) | Q(
+ IsNull(lhs, True), IsNull(expression, True)
+ )
+ lookup_args.append(condition)
+ else:
+ lookup_kwargs[field_name] = expression
+ queryset = queryset.filter(*lookup_args, **lookup_kwargs)
else:
# Ignore constraints with excluded fields.
- if exclude:
- for expression in self.expressions:
- if hasattr(expression, "flatten"):
- for expr in expression.flatten():
- if isinstance(expr, F) and expr.name in exclude:
- return
- elif isinstance(expression, F) and expression.name in exclude:
- return
+ if exclude and any(
+ self._expression_refs_exclude(model, expression, exclude)
+ for expression in self.expressions
+ ):
+ return
replacements = {
F(field): value
for field, value in instance._get_field_expression_map(
diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt
index 7ae44cfd97..02a068e5af 100644
--- a/docs/releases/5.2.txt
+++ b/docs/releases/5.2.txt
@@ -215,6 +215,9 @@ Models
methods such as
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
+* Added support for validation of model constraints which use a
+ :class:`~django.db.models.GeneratedField`.
+
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
diff --git a/tests/constraints/models.py b/tests/constraints/models.py
index 983d550502..829f671cdd 100644
--- a/tests/constraints/models.py
+++ b/tests/constraints/models.py
@@ -1,4 +1,5 @@
from django.db import models
+from django.db.models.functions import Coalesce, Lower
class Product(models.Model):
@@ -28,6 +29,46 @@ class Product(models.Model):
]
+class GeneratedFieldStoredProduct(models.Model):
+ name = models.CharField(max_length=255, null=True)
+ price = models.IntegerField(null=True)
+ discounted_price = models.IntegerField(null=True)
+ rebate = models.GeneratedField(
+ expression=Coalesce("price", 0)
+ - Coalesce("discounted_price", Coalesce("price", 0)),
+ output_field=models.IntegerField(),
+ db_persist=True,
+ )
+ lower_name = models.GeneratedField(
+ expression=Lower(models.F("name")),
+ output_field=models.CharField(max_length=255, null=True),
+ db_persist=True,
+ )
+
+ class Meta:
+ required_db_features = {"supports_stored_generated_columns"}
+
+
+class GeneratedFieldVirtualProduct(models.Model):
+ name = models.CharField(max_length=255, null=True)
+ price = models.IntegerField(null=True)
+ discounted_price = models.IntegerField(null=True)
+ rebate = models.GeneratedField(
+ expression=Coalesce("price", 0)
+ - Coalesce("discounted_price", Coalesce("price", 0)),
+ output_field=models.IntegerField(),
+ db_persist=False,
+ )
+ lower_name = models.GeneratedField(
+ expression=Lower(models.F("name")),
+ output_field=models.CharField(max_length=255, null=True),
+ db_persist=False,
+ )
+
+ class Meta:
+ required_db_features = {"supports_virtual_generated_columns"}
+
+
class UniqueConstraintProduct(models.Model):
name = models.CharField(max_length=255)
color = models.CharField(max_length=32, null=True)
diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py
index 350f05f2b8..9ca889ca6d 100644
--- a/tests/constraints/tests.py
+++ b/tests/constraints/tests.py
@@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models
from django.db.models import F
from django.db.models.constraints import BaseConstraint, UniqueConstraint
-from django.db.models.functions import Abs, Lower, Upper
+from django.db.models.functions import Abs, Lower, Sqrt, Upper
from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import ignore_warnings
@@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning
from .models import (
ChildModel,
ChildUniqueConstraintProduct,
+ GeneratedFieldStoredProduct,
+ GeneratedFieldVirtualProduct,
JSONFieldModel,
ModelWithDatabaseDefault,
Product,
@@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase):
with self.assertRaisesMessage(ValidationError, msg):
json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
+ @skipUnlessDBFeature("supports_stored_generated_columns")
+ def test_validate_generated_field_stored(self):
+ self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct)
+
+ @skipUnlessDBFeature("supports_virtual_generated_columns")
+ def test_validate_generated_field_virtual(self):
+ self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct)
+
+ def assertGeneratedFieldIsValidated(self, model):
+ constraint = models.CheckConstraint(
+ condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate"
+ )
+ constraint.validate(model, model(price=50, discounted_price=20))
+
+ invalid_product = model(price=1200, discounted_price=500)
+ msg = f"Constraint “{constraint.name}” is violated."
+ with self.assertRaisesMessage(ValidationError, msg):
+ constraint.validate(model, invalid_product)
+
+ # Excluding referenced or generated fields should skip validation.
+ constraint.validate(model, invalid_product, exclude={"price"})
+ constraint.validate(model, invalid_product, exclude={"rebate"})
+
def test_check_deprecation(self):
msg = "CheckConstraint.check is deprecated in favor of `.condition`."
condition = models.Q(foo="bar")
@@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase):
exclude={"name"},
)
+ @skipUnlessDBFeature("supports_stored_generated_columns")
+ def test_validate_expression_generated_field_stored(self):
+ self.assertGeneratedFieldWithExpressionIsValidated(
+ model=GeneratedFieldStoredProduct
+ )
+
+ @skipUnlessDBFeature("supports_virtual_generated_columns")
+ def test_validate_expression_generated_field_virtual(self):
+ self.assertGeneratedFieldWithExpressionIsValidated(
+ model=GeneratedFieldVirtualProduct
+ )
+
+ def assertGeneratedFieldWithExpressionIsValidated(self, model):
+ constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt")
+ model.objects.create(price=100, discounted_price=84)
+
+ valid_product = model(price=100, discounted_price=75)
+ constraint.validate(model, valid_product)
+
+ invalid_product = model(price=20, discounted_price=4)
+ with self.assertRaisesMessage(
+ ValidationError, f"Constraint “{constraint.name}” is violated."
+ ):
+ constraint.validate(model, invalid_product)
+
+ # Excluding referenced or generated fields should skip validation.
+ constraint.validate(model, invalid_product, exclude={"rebate"})
+ constraint.validate(model, invalid_product, exclude={"price"})
+
+ @skipUnlessDBFeature("supports_stored_generated_columns")
+ def test_validate_fields_generated_field_stored(self):
+ self.assertGeneratedFieldWithFieldsIsValidated(
+ model=GeneratedFieldStoredProduct
+ )
+
+ @skipUnlessDBFeature("supports_virtual_generated_columns")
+ def test_validate_fields_generated_field_virtual(self):
+ self.assertGeneratedFieldWithFieldsIsValidated(
+ model=GeneratedFieldVirtualProduct
+ )
+
+ def assertGeneratedFieldWithFieldsIsValidated(self, model):
+ constraint = models.UniqueConstraint(
+ fields=["lower_name"], name="lower_name_unique"
+ )
+ model.objects.create(name="Box")
+ constraint.validate(model, model(name="Case"))
+
+ invalid_product = model(name="BOX")
+ msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
+ with self.assertRaisesMessage(ValidationError, msg):
+ constraint.validate(model, invalid_product)
+
+ # Excluding referenced or generated fields should skip validation.
+ constraint.validate(model, invalid_product, exclude={"lower_name"})
+ constraint.validate(model, invalid_product, exclude={"name"})
+
+ @skipUnlessDBFeature("supports_stored_generated_columns")
+ def test_validate_fields_generated_field_stored_nulls_distinct(self):
+ self.assertGeneratedFieldNullsDistinctIsValidated(
+ model=GeneratedFieldStoredProduct
+ )
+
+ @skipUnlessDBFeature("supports_virtual_generated_columns")
+ def test_validate_fields_generated_field_virtual_nulls_distinct(self):
+ self.assertGeneratedFieldNullsDistinctIsValidated(
+ model=GeneratedFieldVirtualProduct
+ )
+
+ def assertGeneratedFieldNullsDistinctIsValidated(self, model):
+ constraint = models.UniqueConstraint(
+ fields=["lower_name"],
+ name="lower_name_unique_nulls_distinct",
+ nulls_distinct=False,
+ )
+ model.objects.create(name=None)
+ valid_product = model(name="Box")
+ constraint.validate(model, valid_product)
+
+ invalid_product = model(name=None)
+ msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
+ with self.assertRaisesMessage(ValidationError, msg):
+ constraint.validate(model, invalid_product)
+
@skipUnlessDBFeature("supports_table_check_constraints")
def test_validate_nullable_textfield_with_isnull_true(self):
is_null_constraint = models.UniqueConstraint(
diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py
index f571a96f35..ab5bf2bab1 100644
--- a/tests/postgres_tests/test_constraints.py
+++ b/tests/postgres_tests/test_constraints.py
@@ -14,6 +14,7 @@ from django.db.models import (
F,
ForeignKey,
Func,
+ GeneratedField,
IntegerField,
Model,
Q,
@@ -32,6 +33,7 @@ try:
from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import (
DateTimeRangeField,
+ IntegerRangeField,
RangeBoundary,
RangeOperators,
)
@@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
+ @skipUnlessDBFeature("supports_stored_generated_columns")
+ @isolate_apps("postgres_tests")
+ def test_validate_generated_field_range_adjacent(self):
+ class RangesModelGeneratedField(Model):
+ ints = IntegerRangeField(blank=True, null=True)
+ ints_generated = GeneratedField(
+ expression=F("ints"),
+ output_field=IntegerRangeField(null=True),
+ db_persist=True,
+ )
+
+ with connection.schema_editor() as editor:
+ editor.create_model(RangesModelGeneratedField)
+
+ constraint = ExclusionConstraint(
+ name="ints_adjacent",
+ expressions=[("ints_generated", RangeOperators.ADJACENT_TO)],
+ violation_error_code="custom_code",
+ violation_error_message="Custom error message.",
+ )
+ RangesModelGeneratedField.objects.create(ints=(20, 50))
+
+ range_obj = RangesModelGeneratedField(ints=(3, 20))
+ with self.assertRaisesMessage(ValidationError, "Custom error message."):
+ constraint.validate(RangesModelGeneratedField, range_obj)
+
+ # Excluding referenced or generated field should skip validation.
+ constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"})
+ constraint.validate(
+ RangesModelGeneratedField, range_obj, exclude={"ints_generated"}
+ )
+
def test_validate_with_custom_code_and_condition(self):
constraint = ExclusionConstraint(
name="ints_adjacent",