diff options
| author | Adam Johnson <me@adamj.eu> | 2025-05-12 17:49:38 -0300 |
|---|---|---|
| committer | nessita <124304+nessita@users.noreply.github.com> | 2025-05-13 21:42:19 -0300 |
| commit | 57fdc104d26df0a060f637f2128d830bfcc8e4f8 (patch) | |
| tree | 478bf51a5b59df95f39e4eb2a6aa45d8d4b28aa3 /tests/migrations | |
| parent | 4647e2b8663cbd22a07af70bf0f8540946763851 (diff) | |
Refs #36383 -- Added extra tests for serializing functools.partial in tests/migrations/test_writer.py.
This includes a test helper to better assert over the expected output.
Co-authored-by: Natalia <124304+nessita@users.noreply.github.com>
Diffstat (limited to 'tests/migrations')
| -rw-r--r-- | tests/migrations/test_writer.py | 68 |
1 files changed, 60 insertions, 8 deletions
diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index fc3d3bc909..dcc5c9e62f 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -300,6 +300,18 @@ class WriterTests(SimpleTestCase): self.assertEqual(value.null, new_value.null) self.assertEqual(value.unique, new_value.unique) + def assertSerializedFunctoolsPartialEqual( + self, value, expected_string, expected_imports + ): + string, imports = MigrationWriter.serialize(value) + self.assertEqual(string, expected_string) + self.assertEqual(imports, expected_imports) + result = self.serialize_round_trip(value) + self.assertEqual(result.func, value.func) + self.assertEqual(result.args, value.args) + self.assertEqual(result.keywords, value.keywords) + return result + def test_serialize_numbers(self): self.assertSerializedEqual(1) self.assertSerializedEqual(1.2) @@ -895,19 +907,59 @@ class WriterTests(SimpleTestCase): self.assertSerializedEqual(datetime.timedelta(minutes=42)) def test_serialize_functools_partial(self): + value = functools.partial(datetime.timedelta) + string, imports = MigrationWriter.serialize(value) + self.assertSerializedFunctoolsPartialEqual( + value, + "functools.partial(datetime.timedelta, *(), **{})", + {"import datetime", "import functools"}, + ) + + def test_serialize_functools_partial_posarg(self): + value = functools.partial(datetime.timedelta, 1) + string, imports = MigrationWriter.serialize(value) + self.assertSerializedFunctoolsPartialEqual( + value, + "functools.partial(datetime.timedelta, *(1,), **{})", + {"import datetime", "import functools"}, + ) + + def test_serialize_functools_partial_kwarg(self): + value = functools.partial(datetime.timedelta, seconds=2) + string, imports = MigrationWriter.serialize(value) + self.assertSerializedFunctoolsPartialEqual( + value, + "functools.partial(datetime.timedelta, *(), **{'seconds': 2})", + {"import datetime", "import functools"}, + ) + + def test_serialize_functools_partial_mixed(self): value = functools.partial(datetime.timedelta, 1, seconds=2) - result = self.serialize_round_trip(value) - self.assertEqual(result.func, value.func) - self.assertEqual(result.args, value.args) - self.assertEqual(result.keywords, value.keywords) + string, imports = MigrationWriter.serialize(value) + self.assertSerializedFunctoolsPartialEqual( + value, + "functools.partial(datetime.timedelta, *(1,), **{'seconds': 2})", + {"import datetime", "import functools"}, + ) + + def test_serialize_functools_partial_non_identifier_keyword(self): + value = functools.partial(datetime.timedelta, **{"kebab-case": 1}) + string, imports = MigrationWriter.serialize(value) + self.assertSerializedFunctoolsPartialEqual( + value, + "functools.partial(datetime.timedelta, *(), **{'kebab-case': 1})", + {"import datetime", "import functools"}, + ) def test_serialize_functools_partialmethod(self): value = functools.partialmethod(datetime.timedelta, 1, seconds=2) - result = self.serialize_round_trip(value) + string, imports = MigrationWriter.serialize(value) + result = self.assertSerializedFunctoolsPartialEqual( + value, + "functools.partialmethod(datetime.timedelta, *(1,), **{'seconds': 2})", + {"import datetime", "import functools"}, + ) self.assertIsInstance(result, functools.partialmethod) - self.assertEqual(result.func, value.func) - self.assertEqual(result.args, value.args) - self.assertEqual(result.keywords, value.keywords) def test_serialize_type_none(self): self.assertSerializedEqual(NoneType) |
