summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaolo Melchiorre <paolo@melchiorre.org>2023-09-13 22:11:08 +0200
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2023-09-14 21:17:12 +0200
commit68d769e691eb0d765228defddb3ba982eabdc761 (patch)
tree51410c961f6d24b4871e52192d753fcc43d2cc1f
parent969ecb8236f033d183108fb28849974da188da50 (diff)
Fixed #34838 -- Corrected output_field of resolved columns for GeneratedFields.
Thanks Simon Charette for the implementation idea.
-rw-r--r--django/db/models/fields/generated.py12
-rw-r--r--tests/model_fields/test_generatedfield.py36
2 files changed, 47 insertions, 1 deletions
diff --git a/django/db/models/fields/generated.py b/django/db/models/fields/generated.py
index 0980be98af..deb5875638 100644
--- a/django/db/models/fields/generated.py
+++ b/django/db/models/fields/generated.py
@@ -1,6 +1,7 @@
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
+from django.utils.functional import cached_property
from . import NOT_PROVIDED, Field
@@ -32,6 +33,17 @@ class GeneratedField(Field):
self.db_persist = db_persist
super().__init__(**kwargs)
+ @cached_property
+ def cached_col(self):
+ from django.db.models.expressions import Col
+
+ return Col(self.model._meta.db_table, self, self.output_field)
+
+ def get_col(self, alias, output_field=None):
+ if alias != self.model._meta.db_table and output_field is None:
+ output_field = self.output_field
+ return super().get_col(alias, output_field)
+
def contribute_to_class(self, *args, **kwargs):
super().contribute_to_class(*args, **kwargs)
diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py
index e2746bdd0c..dec1f3a31f 100644
--- a/tests/model_fields/test_generatedfield.py
+++ b/tests/model_fields/test_generatedfield.py
@@ -1,6 +1,6 @@
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection
-from django.db.models import F, GeneratedField, IntegerField
+from django.db.models import F, FloatField, GeneratedField, IntegerField, Model
from django.db.models.functions import Lower
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
@@ -49,6 +49,40 @@ class BaseGeneratedFieldTests(SimpleTestCase):
self.assertEqual(args, [])
self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
+ def test_get_col(self):
+ class Square(Model):
+ side = IntegerField()
+ area = GeneratedField(expression=F("side") * F("side"), db_persist=True)
+
+ col = Square._meta.get_field("area").get_col("alias")
+ self.assertIsInstance(col.output_field, IntegerField)
+
+ class FloatSquare(Model):
+ side = IntegerField()
+ area = GeneratedField(
+ expression=F("side") * F("side"),
+ db_persist=True,
+ output_field=FloatField(),
+ )
+
+ col = FloatSquare._meta.get_field("area").get_col("alias")
+ self.assertIsInstance(col.output_field, FloatField)
+
+ def test_cached_col(self):
+ class Sum(Model):
+ a = IntegerField()
+ b = IntegerField()
+ total = GeneratedField(expression=F("a") + F("b"), db_persist=True)
+
+ field = Sum._meta.get_field("total")
+ cached_col = field.cached_col
+ self.assertIs(field.get_col(Sum._meta.db_table), cached_col)
+ self.assertIs(field.get_col(Sum._meta.db_table, field), cached_col)
+ self.assertIsNot(field.get_col("alias"), cached_col)
+ self.assertIsNot(field.get_col(Sum._meta.db_table, IntegerField()), cached_col)
+ self.assertIs(cached_col.target, field)
+ self.assertIsInstance(cached_col.output_field, IntegerField)
+
class GeneratedFieldTestMixin:
def _refresh_if_needed(self, m):