summaryrefslogtreecommitdiff
path: root/tests/migrations
diff options
context:
space:
mode:
authorAdam Johnson <me@adamj.eu>2025-05-12 17:49:38 -0300
committernessita <124304+nessita@users.noreply.github.com>2025-05-13 21:42:19 -0300
commit57fdc104d26df0a060f637f2128d830bfcc8e4f8 (patch)
tree478bf51a5b59df95f39e4eb2a6aa45d8d4b28aa3 /tests/migrations
parent4647e2b8663cbd22a07af70bf0f8540946763851 (diff)
Refs #36383 -- Added extra tests for serializing functools.partial in tests/migrations/test_writer.py.
This includes a test helper to better assert over the expected output. Co-authored-by: Natalia <124304+nessita@users.noreply.github.com>
Diffstat (limited to 'tests/migrations')
-rw-r--r--tests/migrations/test_writer.py68
1 files changed, 60 insertions, 8 deletions
diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py
index fc3d3bc909..dcc5c9e62f 100644
--- a/tests/migrations/test_writer.py
+++ b/tests/migrations/test_writer.py
@@ -300,6 +300,18 @@ class WriterTests(SimpleTestCase):
self.assertEqual(value.null, new_value.null)
self.assertEqual(value.unique, new_value.unique)
+ def assertSerializedFunctoolsPartialEqual(
+ self, value, expected_string, expected_imports
+ ):
+ string, imports = MigrationWriter.serialize(value)
+ self.assertEqual(string, expected_string)
+ self.assertEqual(imports, expected_imports)
+ result = self.serialize_round_trip(value)
+ self.assertEqual(result.func, value.func)
+ self.assertEqual(result.args, value.args)
+ self.assertEqual(result.keywords, value.keywords)
+ return result
+
def test_serialize_numbers(self):
self.assertSerializedEqual(1)
self.assertSerializedEqual(1.2)
@@ -895,19 +907,59 @@ class WriterTests(SimpleTestCase):
self.assertSerializedEqual(datetime.timedelta(minutes=42))
def test_serialize_functools_partial(self):
+ value = functools.partial(datetime.timedelta)
+ string, imports = MigrationWriter.serialize(value)
+ self.assertSerializedFunctoolsPartialEqual(
+ value,
+ "functools.partial(datetime.timedelta, *(), **{})",
+ {"import datetime", "import functools"},
+ )
+
+ def test_serialize_functools_partial_posarg(self):
+ value = functools.partial(datetime.timedelta, 1)
+ string, imports = MigrationWriter.serialize(value)
+ self.assertSerializedFunctoolsPartialEqual(
+ value,
+ "functools.partial(datetime.timedelta, *(1,), **{})",
+ {"import datetime", "import functools"},
+ )
+
+ def test_serialize_functools_partial_kwarg(self):
+ value = functools.partial(datetime.timedelta, seconds=2)
+ string, imports = MigrationWriter.serialize(value)
+ self.assertSerializedFunctoolsPartialEqual(
+ value,
+ "functools.partial(datetime.timedelta, *(), **{'seconds': 2})",
+ {"import datetime", "import functools"},
+ )
+
+ def test_serialize_functools_partial_mixed(self):
value = functools.partial(datetime.timedelta, 1, seconds=2)
- result = self.serialize_round_trip(value)
- self.assertEqual(result.func, value.func)
- self.assertEqual(result.args, value.args)
- self.assertEqual(result.keywords, value.keywords)
+ string, imports = MigrationWriter.serialize(value)
+ self.assertSerializedFunctoolsPartialEqual(
+ value,
+ "functools.partial(datetime.timedelta, *(1,), **{'seconds': 2})",
+ {"import datetime", "import functools"},
+ )
+
+ def test_serialize_functools_partial_non_identifier_keyword(self):
+ value = functools.partial(datetime.timedelta, **{"kebab-case": 1})
+ string, imports = MigrationWriter.serialize(value)
+ self.assertSerializedFunctoolsPartialEqual(
+ value,
+ "functools.partial(datetime.timedelta, *(), **{'kebab-case': 1})",
+ {"import datetime", "import functools"},
+ )
def test_serialize_functools_partialmethod(self):
value = functools.partialmethod(datetime.timedelta, 1, seconds=2)
- result = self.serialize_round_trip(value)
+ string, imports = MigrationWriter.serialize(value)
+ result = self.assertSerializedFunctoolsPartialEqual(
+ value,
+ "functools.partialmethod(datetime.timedelta, *(1,), **{'seconds': 2})",
+ {"import datetime", "import functools"},
+ )
self.assertIsInstance(result, functools.partialmethod)
- self.assertEqual(result.func, value.func)
- self.assertEqual(result.args, value.args)
- self.assertEqual(result.keywords, value.keywords)
def test_serialize_type_none(self):
self.assertSerializedEqual(NoneType)