diff options
| author | Ian Foote <python@ian.feete.org> | 2020-11-22 22:27:57 +0000 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2023-05-12 19:11:40 +0200 |
| commit | 7414704e88d73dafbcfbb85f9bc54cb6111439d3 (patch) | |
| tree | f51136b16e457d7f46e01ff3cc06308faf0923db /tests | |
| parent | 599f3e2cda50ab084915ffd08edb5ad6cad61415 (diff) | |
Fixed #470 -- Added support for database defaults on fields.
Special thanks to Hannes Ljungberg for finding multiple implementation
gaps.
Thanks also to Simon Charette, Adam Johnson, and Mariusz Felisiak for
reviews.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/basic/models.py | 4 | ||||
| -rw-r--r-- | tests/basic/tests.py | 6 | ||||
| -rw-r--r-- | tests/field_defaults/models.py | 44 | ||||
| -rw-r--r-- | tests/field_defaults/tests.py | 192 | ||||
| -rw-r--r-- | tests/invalid_models_tests/test_ordinary_fields.py | 107 | ||||
| -rw-r--r-- | tests/migrations/test_autodetector.py | 40 | ||||
| -rw-r--r-- | tests/migrations/test_base.py | 7 | ||||
| -rw-r--r-- | tests/migrations/test_operations.py | 317 | ||||
| -rw-r--r-- | tests/schema/tests.py | 27 |
9 files changed, 737 insertions, 7 deletions
diff --git a/tests/basic/models.py b/tests/basic/models.py index 59a6a8d67f..b71b60a213 100644 --- a/tests/basic/models.py +++ b/tests/basic/models.py @@ -49,5 +49,9 @@ class PrimaryKeyWithDefault(models.Model): uuid = models.UUIDField(primary_key=True, default=uuid.uuid4) +class PrimaryKeyWithDbDefault(models.Model): + uuid = models.IntegerField(primary_key=True, db_default=1) + + class ChildPrimaryKeyWithDefault(PrimaryKeyWithDefault): pass diff --git a/tests/basic/tests.py b/tests/basic/tests.py index ea9228376c..3c2d1dead9 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -20,6 +20,7 @@ from .models import ( ArticleSelectOnSave, ChildPrimaryKeyWithDefault, FeaturedArticle, + PrimaryKeyWithDbDefault, PrimaryKeyWithDefault, SelfRef, ) @@ -175,6 +176,11 @@ class ModelInstanceCreationTests(TestCase): with self.assertNumQueries(1): PrimaryKeyWithDefault().save() + def test_save_primary_with_db_default(self): + # An UPDATE attempt is skipped when a primary key has db_default. + with self.assertNumQueries(1): + PrimaryKeyWithDbDefault().save() + def test_save_parent_primary_with_default(self): # An UPDATE attempt is skipped when an inherited primary key has # default. diff --git a/tests/field_defaults/models.py b/tests/field_defaults/models.py index b95005192a..5f9c38a5a4 100644 --- a/tests/field_defaults/models.py +++ b/tests/field_defaults/models.py @@ -12,6 +12,8 @@ field. from datetime import datetime from django.db import models +from django.db.models.functions import Coalesce, ExtractYear, Now, Pi +from django.db.models.lookups import GreaterThan class Article(models.Model): @@ -20,3 +22,45 @@ class Article(models.Model): def __str__(self): return self.headline + + +class DBArticle(models.Model): + """ + Values or expressions can be passed as the db_default parameter to a field. + When the object is created without an explicit value passed in, the + database will insert the default value automatically. + """ + + headline = models.CharField(max_length=100, db_default="Default headline") + pub_date = models.DateTimeField(db_default=Now()) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + +class DBDefaults(models.Model): + both = models.IntegerField(default=1, db_default=2) + null = models.FloatField(null=True, db_default=1.1) + + +class DBDefaultsFunction(models.Model): + number = models.FloatField(db_default=Pi()) + year = models.IntegerField(db_default=ExtractYear(Now())) + added = models.FloatField(db_default=Pi() + 4.5) + multiple_subfunctions = models.FloatField(db_default=Coalesce(4.5, Pi())) + case_when = models.IntegerField( + db_default=models.Case(models.When(GreaterThan(2, 1), then=3), default=4) + ) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + +class DBDefaultsPK(models.Model): + language_code = models.CharField(primary_key=True, max_length=2, db_default="en") + + +class DBDefaultsFK(models.Model): + language_code = models.ForeignKey( + DBDefaultsPK, db_default="fr", on_delete=models.CASCADE + ) diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py index 19b05aa537..76d01f7a5a 100644 --- a/tests/field_defaults/tests.py +++ b/tests/field_defaults/tests.py @@ -1,8 +1,28 @@ from datetime import datetime +from math import pi -from django.test import TestCase +from django.db import connection +from django.db.models import Case, F, FloatField, Value, When +from django.db.models.expressions import ( + Expression, + ExpressionList, + ExpressionWrapper, + Func, + OrderByList, + RawSQL, +) +from django.db.models.functions import Collate +from django.db.models.lookups import GreaterThan +from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature -from .models import Article +from .models import ( + Article, + DBArticle, + DBDefaults, + DBDefaultsFK, + DBDefaultsFunction, + DBDefaultsPK, +) class DefaultTests(TestCase): @@ -14,3 +34,171 @@ class DefaultTests(TestCase): self.assertIsInstance(a.id, int) self.assertEqual(a.headline, "Default headline") self.assertLess((now - a.pub_date).seconds, 5) + + @skipUnlessDBFeature( + "can_return_columns_from_insert", "supports_expression_defaults" + ) + def test_field_db_defaults_returning(self): + a = DBArticle() + a.save() + self.assertIsInstance(a.id, int) + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + + @skipIfDBFeature("can_return_columns_from_insert") + @skipUnlessDBFeature("supports_expression_defaults") + def test_field_db_defaults_refresh(self): + a = DBArticle() + a.save() + a.refresh_from_db() + self.assertIsInstance(a.id, int) + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + + def test_null_db_default(self): + obj1 = DBDefaults.objects.create() + if not connection.features.can_return_columns_from_insert: + obj1.refresh_from_db() + self.assertEqual(obj1.null, 1.1) + + obj2 = DBDefaults.objects.create(null=None) + self.assertIsNone(obj2.null) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_db_default_function(self): + m = DBDefaultsFunction.objects.create() + if not connection.features.can_return_columns_from_insert: + m.refresh_from_db() + self.assertAlmostEqual(m.number, pi) + self.assertEqual(m.year, datetime.now().year) + self.assertAlmostEqual(m.added, pi + 4.5) + self.assertEqual(m.multiple_subfunctions, 4.5) + + @skipUnlessDBFeature("insert_test_table_with_defaults") + def test_both_default(self): + create_sql = connection.features.insert_test_table_with_defaults + with connection.cursor() as cursor: + cursor.execute(create_sql.format(DBDefaults._meta.db_table)) + obj1 = DBDefaults.objects.get() + self.assertEqual(obj1.both, 2) + + obj2 = DBDefaults.objects.create() + self.assertEqual(obj2.both, 1) + + def test_pk_db_default(self): + obj1 = DBDefaultsPK.objects.create() + if not connection.features.can_return_columns_from_insert: + # refresh_from_db() cannot be used because that needs the pk to + # already be known to Django. + obj1 = DBDefaultsPK.objects.get(pk="en") + self.assertEqual(obj1.pk, "en") + self.assertEqual(obj1.language_code, "en") + + obj2 = DBDefaultsPK.objects.create(language_code="de") + self.assertEqual(obj2.pk, "de") + self.assertEqual(obj2.language_code, "de") + + def test_foreign_key_db_default(self): + parent1 = DBDefaultsPK.objects.create(language_code="fr") + child1 = DBDefaultsFK.objects.create() + if not connection.features.can_return_columns_from_insert: + child1.refresh_from_db() + self.assertEqual(child1.language_code, parent1) + + parent2 = DBDefaultsPK.objects.create() + if not connection.features.can_return_columns_from_insert: + # refresh_from_db() cannot be used because that needs the pk to + # already be known to Django. + parent2 = DBDefaultsPK.objects.get(pk="en") + child2 = DBDefaultsFK.objects.create(language_code=parent2) + self.assertEqual(child2.language_code, parent2) + + @skipUnlessDBFeature( + "can_return_columns_from_insert", "supports_expression_defaults" + ) + def test_case_when_db_default_returning(self): + m = DBDefaultsFunction.objects.create() + self.assertEqual(m.case_when, 3) + + @skipIfDBFeature("can_return_columns_from_insert") + @skipUnlessDBFeature("supports_expression_defaults") + def test_case_when_db_default_no_returning(self): + m = DBDefaultsFunction.objects.create() + m.refresh_from_db() + self.assertEqual(m.case_when, 3) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_all_db_defaults(self): + articles = [DBArticle(), DBArticle()] + DBArticle.objects.bulk_create(articles) + + headlines = DBArticle.objects.values_list("headline", flat=True) + self.assertSequenceEqual(headlines, ["Default headline", "Default headline"]) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_all_db_defaults_one_field(self): + pub_date = datetime.now() + articles = [DBArticle(pub_date=pub_date), DBArticle(pub_date=pub_date)] + DBArticle.objects.bulk_create(articles) + + headlines = DBArticle.objects.values_list("headline", "pub_date") + self.assertSequenceEqual( + headlines, + [ + ("Default headline", pub_date), + ("Default headline", pub_date), + ], + ) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_mixed_db_defaults(self): + articles = [DBArticle(), DBArticle(headline="Something else")] + DBArticle.objects.bulk_create(articles) + + headlines = DBArticle.objects.values_list("headline", flat=True) + self.assertCountEqual(headlines, ["Default headline", "Something else"]) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_mixed_db_defaults_function(self): + instances = [DBDefaultsFunction(), DBDefaultsFunction(year=2000)] + DBDefaultsFunction.objects.bulk_create(instances) + + years = DBDefaultsFunction.objects.values_list("year", flat=True) + self.assertCountEqual(years, [2000, datetime.now().year]) + + +class AllowedDefaultTests(SimpleTestCase): + def test_allowed(self): + class Max(Func): + function = "MAX" + + tests = [ + Value(10), + Max(1, 2), + RawSQL("Now()", ()), + Value(10) + Value(7), # Combined expression. + ExpressionList(Value(1), Value(2)), + ExpressionWrapper(Value(1), output_field=FloatField()), + Case(When(GreaterThan(2, 1), then=3), default=4), + ] + for expression in tests: + with self.subTest(expression=expression): + self.assertIs(expression.allowed_default, True) + + def test_disallowed(self): + class Max(Func): + function = "MAX" + + tests = [ + Expression(), + F("field"), + Max(F("count"), 1), + Value(10) + F("count"), # Combined expression. + ExpressionList(F("count"), Value(2)), + ExpressionWrapper(F("count"), output_field=FloatField()), + Collate(Value("John"), "nocase"), + OrderByList("field"), + ] + for expression in tests: + with self.subTest(expression=expression): + self.assertIs(expression.allowed_default, False) diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index 4e37c48286..e9e8a702e0 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -4,6 +4,7 @@ import uuid from django.core.checks import Error from django.core.checks import Warning as DjangoWarning from django.db import connection, models +from django.db.models.functions import Coalesce, Pi from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import isolate_apps, override_settings from django.utils.functional import lazy @@ -1057,3 +1058,109 @@ class DbCommentTests(TestCase): errors = Model._meta.get_field("field").check(databases=self.databases) self.assertEqual(errors, []) + + +@isolate_apps("invalid_models_tests") +class InvalidDBDefaultTests(TestCase): + def test_db_default(self): + class Model(models.Model): + field = models.FloatField(db_default=Pi()) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + if connection.features.supports_expression_defaults: + expected_errors = [] + else: + msg = ( + f"{connection.display_name} does not support default database values " + "with expressions (db_default)." + ) + expected_errors = [Error(msg=msg, obj=field, id="fields.E011")] + self.assertEqual(errors, expected_errors) + + def test_db_default_literal(self): + class Model(models.Model): + field = models.IntegerField(db_default=1) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + self.assertEqual(errors, []) + + def test_db_default_required_db_features(self): + class Model(models.Model): + field = models.FloatField(db_default=Pi()) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + self.assertEqual(errors, []) + + def test_db_default_expression_invalid(self): + expression = models.F("field_name") + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + if connection.features.supports_expression_defaults: + msg = f"{expression} cannot be used in db_default." + expected_errors = [Error(msg=msg, obj=field, id="fields.E012")] + else: + msg = ( + f"{connection.display_name} does not support default database values " + "with expressions (db_default)." + ) + expected_errors = [Error(msg=msg, obj=field, id="fields.E011")] + self.assertEqual(errors, expected_errors) + + def test_db_default_expression_required_db_features(self): + expression = models.F("field_name") + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + if connection.features.supports_expression_defaults: + msg = f"{expression} cannot be used in db_default." + expected_errors = [Error(msg=msg, obj=field, id="fields.E012")] + else: + expected_errors = [] + self.assertEqual(errors, expected_errors) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_db_default_combined_invalid(self): + expression = models.Value(4.5) + models.F("field_name") + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + msg = f"{expression} cannot be used in db_default." + expected_error = Error(msg=msg, obj=field, id="fields.E012") + self.assertEqual(errors, [expected_error]) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_db_default_function_arguments_invalid(self): + expression = Coalesce(models.Value(4.5), models.F("field_name")) + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + msg = f"{expression} cannot be used in db_default." + expected_error = Error(msg=msg, obj=field, id="fields.E012") + self.assertEqual(errors, [expected_error]) diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index ee199fea68..74892bbf3d 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -269,6 +269,14 @@ class AutodetectorTests(BaseAutodetectorTests): ("name", models.CharField(max_length=200, default="Ada Lovelace")), ], ) + author_name_db_default = ModelState( + "testapp", + "Author", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=200, db_default="Ada Lovelace")), + ], + ) author_name_check_constraint = ModelState( "testapp", "Author", @@ -1293,6 +1301,21 @@ class AutodetectorTests(BaseAutodetectorTests): "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition", side_effect=AssertionError("Should not have prompted for not null addition"), ) + def test_add_not_null_field_with_db_default(self, mocked_ask_method): + changes = self.get_changes([self.author_empty], [self.author_name_db_default]) + self.assertNumberMigrations(changes, "testapp", 1) + self.assertOperationTypes(changes, "testapp", 0, ["AddField"]) + self.assertOperationAttributes( + changes, "testapp", 0, 0, name="name", preserve_default=True + ) + self.assertOperationFieldAttributes( + changes, "testapp", 0, 0, db_default=models.Value("Ada Lovelace") + ) + + @mock.patch( + "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition", + side_effect=AssertionError("Should not have prompted for not null addition"), + ) def test_add_date_fields_with_auto_now_not_asking_for_default( self, mocked_ask_method ): @@ -1480,6 +1503,23 @@ class AutodetectorTests(BaseAutodetectorTests): @mock.patch( "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration", + side_effect=AssertionError("Should not have prompted for not null alteration"), + ) + def test_alter_field_to_not_null_with_db_default(self, mocked_ask_method): + changes = self.get_changes( + [self.author_name_null], [self.author_name_db_default] + ) + self.assertNumberMigrations(changes, "testapp", 1) + self.assertOperationTypes(changes, "testapp", 0, ["AlterField"]) + self.assertOperationAttributes( + changes, "testapp", 0, 0, name="name", preserve_default=True + ) + self.assertOperationFieldAttributes( + changes, "testapp", 0, 0, db_default=models.Value("Ada Lovelace") + ) + + @mock.patch( + "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration", return_value=models.NOT_PROVIDED, ) def test_alter_field_to_not_null_without_default(self, mocked_ask_method): diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index f038cd7605..b5228ad445 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -292,6 +292,13 @@ class OperationTestBase(MigrationTestBase): ("id", models.AutoField(primary_key=True)), ("pink", models.IntegerField(default=3)), ("weight", models.FloatField()), + ("green", models.IntegerField(null=True)), + ( + "yellow", + models.CharField( + blank=True, null=True, db_default="Yellow", max_length=20 + ), + ), ], options=model_options, ) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index b67a871bc8..e377e4ca64 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,14 +1,18 @@ +import math + from django.core.exceptions import FieldDoesNotExist from django.db import IntegrityError, connection, migrations, models, transaction from django.db.migrations.migration import Migration from django.db.migrations.operations.fields import FieldOperation from django.db.migrations.state import ModelState, ProjectState -from django.db.models.functions import Abs +from django.db.models.expressions import Value +from django.db.models.functions import Abs, Pi from django.db.transaction import atomic from django.test import ( SimpleTestCase, ignore_warnings, override_settings, + skipIfDBFeature, skipUnlessDBFeature, ) from django.test.utils import CaptureQueriesContext @@ -1340,7 +1344,7 @@ class OperationTests(OperationTestBase): self.assertEqual(operation.describe(), "Add field height to Pony") self.assertEqual(operation.migration_name_fragment, "pony_height") project_state, new_state = self.make_test_state("test_adfl", operation) - self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4) + self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 6) field = new_state.models["test_adfl", "pony"].fields["height"] self.assertEqual(field.default, 5) # Test the database alteration @@ -1528,7 +1532,7 @@ class OperationTests(OperationTestBase): ) new_state = project_state.clone() operation.state_forwards("test_adflpd", new_state) - self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4) + self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 6) field = new_state.models["test_adflpd", "pony"].fields["height"] self.assertEqual(field.default, models.NOT_PROVIDED) # Test the database alteration @@ -1547,6 +1551,169 @@ class OperationTests(OperationTestBase): sorted(definition[2]), ["field", "model_name", "name", "preserve_default"] ) + def test_add_field_database_default(self): + """The AddField operation can set and unset a database default.""" + app_label = "test_adfldd" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + operation = migrations.AddField( + "Pony", "height", models.FloatField(null=True, db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["height"] + self.assertEqual(field.default, models.NOT_PROVIDED) + self.assertEqual(field.db_default, Value(4)) + project_state.apps.get_model(app_label, "pony").objects.create(weight=4) + self.assertColumnNotExists(table_name, "height") + # Add field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(table_name, "height") + new_model = new_state.apps.get_model(app_label, "pony") + old_pony = new_model.objects.get() + self.assertEqual(old_pony.height, 4) + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertEqual(new_pony.height, 4) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNotExists(table_name, "height") + # Deconstruction. + definition = operation.deconstruct() + self.assertEqual(definition[0], "AddField") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "field": field, + "model_name": "Pony", + "name": "height", + }, + ) + + def test_add_field_database_default_special_char_escaping(self): + app_label = "test_adflddsce" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + old_pony_pk = ( + project_state.apps.get_model(app_label, "pony").objects.create(weight=4).pk + ) + tests = ["%", "'", '"'] + for db_default in tests: + with self.subTest(db_default=db_default): + operation = migrations.AddField( + "Pony", + "special_char", + models.CharField(max_length=1, db_default=db_default), + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["special_char"] + self.assertEqual(field.default, models.NOT_PROVIDED) + self.assertEqual(field.db_default, Value(db_default)) + self.assertColumnNotExists(table_name, "special_char") + with connection.schema_editor() as editor: + operation.database_forwards( + app_label, editor, project_state, new_state + ) + self.assertColumnExists(table_name, "special_char") + new_model = new_state.apps.get_model(app_label, "pony") + try: + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertEqual(new_pony.special_char, db_default) + + old_pony = new_model.objects.get(pk=old_pony_pk) + if connection.vendor != "oracle" or db_default != "'": + # The single quotation mark ' is properly quoted and is + # set for new rows on Oracle, however it is not set on + # existing rows. Skip the assertion as it's probably a + # bug in Oracle. + self.assertEqual(old_pony.special_char, db_default) + finally: + with connection.schema_editor() as editor: + operation.database_backwards( + app_label, editor, new_state, project_state + ) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_add_field_database_default_function(self): + app_label = "test_adflddf" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + operation = migrations.AddField( + "Pony", "height", models.FloatField(db_default=Pi()) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["height"] + self.assertEqual(field.default, models.NOT_PROVIDED) + self.assertEqual(field.db_default, Pi()) + project_state.apps.get_model(app_label, "pony").objects.create(weight=4) + self.assertColumnNotExists(table_name, "height") + # Add field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(table_name, "height") + new_model = new_state.apps.get_model(app_label, "pony") + old_pony = new_model.objects.get() + self.assertAlmostEqual(old_pony.height, math.pi) + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertAlmostEqual(old_pony.height, math.pi) + + def test_add_field_both_defaults(self): + """The AddField operation with both default and db_default.""" + app_label = "test_adflbddd" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + operation = migrations.AddField( + "Pony", "height", models.FloatField(default=3, db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["height"] + self.assertEqual(field.default, 3) + self.assertEqual(field.db_default, Value(4)) + project_state.apps.get_model(app_label, "pony").objects.create(weight=4) + self.assertColumnNotExists(table_name, "height") + # Add field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(table_name, "height") + new_model = new_state.apps.get_model(app_label, "pony") + old_pony = new_model.objects.get() + self.assertEqual(old_pony.height, 4) + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertEqual(new_pony.height, 3) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNotExists(table_name, "height") + # Deconstruction. + definition = operation.deconstruct() + self.assertEqual(definition[0], "AddField") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "field": field, + "model_name": "Pony", + "name": "height", + }, + ) + def test_add_field_m2m(self): """ Tests the AddField operation with a ManyToManyField. @@ -1558,7 +1725,7 @@ class OperationTests(OperationTestBase): ) new_state = project_state.clone() operation.state_forwards("test_adflmm", new_state) - self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 4) + self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 6) # Test the database alteration self.assertTableNotExists("test_adflmm_pony_stables") with connection.schema_editor() as editor: @@ -1727,7 +1894,7 @@ class OperationTests(OperationTestBase): self.assertEqual(operation.migration_name_fragment, "remove_pony_pink") new_state = project_state.clone() operation.state_forwards("test_rmfl", new_state) - self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 2) + self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 4) # Test the database alteration self.assertColumnExists("test_rmfl_pony", "pink") with connection.schema_editor() as editor: @@ -1934,6 +2101,146 @@ class OperationTests(OperationTestBase): self.assertEqual(definition[1], []) self.assertEqual(sorted(definition[2]), ["field", "model_name", "name"]) + def test_alter_field_add_database_default(self): + app_label = "test_alfladd" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "weight", models.FloatField(db_default=4.5) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_weight = project_state.models[app_label, "pony"].fields["weight"] + self.assertIs(old_weight.db_default, models.NOT_PROVIDED) + new_weight = new_state.models[app_label, "pony"].fields["weight"] + self.assertEqual(new_weight.db_default, Value(4.5)) + with self.assertRaises(IntegrityError), transaction.atomic(): + project_state.apps.get_model(app_label, "pony").objects.create() + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + pony = new_state.apps.get_model(app_label, "pony").objects.create() + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.weight, 4.5) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + with self.assertRaises(IntegrityError), transaction.atomic(): + project_state.apps.get_model(app_label, "pony").objects.create() + # Deconstruction. + definition = operation.deconstruct() + self.assertEqual(definition[0], "AlterField") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "field": new_weight, + "model_name": "Pony", + "name": "weight", + }, + ) + + def test_alter_field_change_default_to_database_default(self): + """The AlterField operation changing default to db_default.""" + app_label = "test_alflcdtdd" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "pink", models.IntegerField(db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_pink = project_state.models[app_label, "pony"].fields["pink"] + self.assertEqual(old_pink.default, 3) + self.assertIs(old_pink.db_default, models.NOT_PROVIDED) + new_pink = new_state.models[app_label, "pony"].fields["pink"] + self.assertIs(new_pink.default, models.NOT_PROVIDED) + self.assertEqual(new_pink.db_default, Value(4)) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertEqual(pony.pink, 3) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1) + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.pink, 4) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertEqual(pony.pink, 3) + + def test_alter_field_change_nullable_to_database_default_not_null(self): + """ + The AlterField operation changing a null field to db_default. + """ + app_label = "test_alflcntddnn" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "green", models.IntegerField(db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_green = project_state.models[app_label, "pony"].fields["green"] + self.assertIs(old_green.db_default, models.NOT_PROVIDED) + new_green = new_state.models[app_label, "pony"].fields["green"] + self.assertEqual(new_green.db_default, Value(4)) + old_pony = project_state.apps.get_model(app_label, "pony").objects.create( + weight=1 + ) + self.assertIsNone(old_pony.green) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + old_pony.refresh_from_db() + self.assertEqual(old_pony.green, 4) + pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1) + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.green, 4) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertIsNone(pony.green) + + @skipIfDBFeature("interprets_empty_strings_as_nulls") + def test_alter_field_change_blank_nullable_database_default_to_not_null(self): + app_label = "test_alflcbnddnn" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + default = "Yellow" + operation = migrations.AlterField( + "Pony", + "yellow", + models.CharField(blank=True, db_default=default, max_length=20), + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertColumnNull(table_name, "yellow") + pony = project_state.apps.get_model(app_label, "pony").objects.create( + weight=1, yellow=None + ) + self.assertIsNone(pony.yellow) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnNotNull(table_name, "yellow") + pony.refresh_from_db() + self.assertEqual(pony.yellow, default) + pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1) + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.yellow, default) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNull(table_name, "yellow") + pony = project_state.apps.get_model(app_label, "pony").objects.create( + weight=1, yellow=None + ) + self.assertIsNone(pony.yellow) + def test_alter_field_add_db_column_noop(self): """ AlterField operation is a noop when adding only a db_column and the diff --git a/tests/schema/tests.py b/tests/schema/tests.py index d81a01b41d..688a9f1fcf 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -2102,6 +2102,33 @@ class SchemaTests(TransactionTestCase): with self.assertRaises(IntegrityError): NoteRename.objects.create(detail_info=None) + @isolate_apps("schema") + def test_rename_keep_db_default(self): + """Renaming a field shouldn't affect a database default.""" + + class AuthorDbDefault(Model): + birth_year = IntegerField(db_default=1985) + + class Meta: + app_label = "schema" + + self.isolated_local_models = [AuthorDbDefault] + with connection.schema_editor() as editor: + editor.create_model(AuthorDbDefault) + columns = self.column_classes(AuthorDbDefault) + self.assertEqual(columns["birth_year"][1].default, "1985") + + old_field = AuthorDbDefault._meta.get_field("birth_year") + new_field = IntegerField(db_default=1985) + new_field.set_attributes_from_name("renamed_year") + new_field.model = AuthorDbDefault + with connection.schema_editor( + atomic=connection.features.supports_atomic_references_rename + ) as editor: + editor.alter_field(AuthorDbDefault, old_field, new_field, strict=True) + columns = self.column_classes(AuthorDbDefault) + self.assertEqual(columns["renamed_year"][1].default, "1985") + @skipUnlessDBFeature( "supports_column_check_constraints", "can_introspect_check_constraints" ) |
