summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcan <cansarigol@derinbilgi.com.tr>2019-01-10 21:05:19 +0300
committerTim Graham <timograham@gmail.com>2019-01-11 18:13:16 -0500
commit7d3b3897c1d7b1ae4dfea6ae0d4f431d3e3dec1c (patch)
tree00c28282b95df023394c985151b20118b7aba572
parent3c01fe30f3dd4dc1c8bb4fec816bd277d1ae5fa6 (diff)
Refs #29738 -- Allowed registering serializers with MigrationWriter.
-rw-r--r--django/db/migrations/serializer.py72
-rw-r--r--django/db/migrations/writer.py10
-rw-r--r--docs/releases/2.2.txt3
-rw-r--r--docs/topics/migrations.txt29
-rw-r--r--tests/migrations/test_writer.py16
5 files changed, 93 insertions, 37 deletions
diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py
index d395313ff4..ace0a860c4 100644
--- a/django/db/migrations/serializer.py
+++ b/django/db/migrations/serializer.py
@@ -8,6 +8,7 @@ import math
import re
import types
import uuid
+from collections import OrderedDict
from django.conf import SettingsReference
from django.db import models
@@ -271,6 +272,38 @@ class UUIDSerializer(BaseSerializer):
return "uuid.%s" % repr(self.value), {"import uuid"}
+class Serializer:
+ _registry = OrderedDict([
+ (frozenset, FrozensetSerializer),
+ (list, SequenceSerializer),
+ (set, SetSerializer),
+ (tuple, TupleSerializer),
+ (dict, DictionarySerializer),
+ (enum.Enum, EnumSerializer),
+ (datetime.datetime, DatetimeDatetimeSerializer),
+ ((datetime.date, datetime.timedelta, datetime.time), DateTimeSerializer),
+ (SettingsReference, SettingsReferenceSerializer),
+ (float, FloatSerializer),
+ ((bool, int, type(None), bytes, str), BaseSimpleSerializer),
+ (decimal.Decimal, DecimalSerializer),
+ ((functools.partial, functools.partialmethod), FunctoolsPartialSerializer),
+ ((types.FunctionType, types.BuiltinFunctionType, types.MethodType), FunctionTypeSerializer),
+ (collections.abc.Iterable, IterableSerializer),
+ ((COMPILED_REGEX_TYPE, RegexObject), RegexSerializer),
+ (uuid.UUID, UUIDSerializer),
+ ])
+
+ @classmethod
+ def register(cls, type_, serializer):
+ if not issubclass(serializer, BaseSerializer):
+ raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
+ cls._registry[type_] = serializer
+
+ @classmethod
+ def unregister(cls, type_):
+ cls._registry.pop(type_)
+
+
def serializer_factory(value):
if isinstance(value, Promise):
value = str(value)
@@ -290,42 +323,9 @@ def serializer_factory(value):
# Anything that knows how to deconstruct itself.
if hasattr(value, 'deconstruct'):
return DeconstructableSerializer(value)
-
- # Unfortunately some of these are order-dependent.
- if isinstance(value, frozenset):
- return FrozensetSerializer(value)
- if isinstance(value, list):
- return SequenceSerializer(value)
- if isinstance(value, set):
- return SetSerializer(value)
- if isinstance(value, tuple):
- return TupleSerializer(value)
- if isinstance(value, dict):
- return DictionarySerializer(value)
- if isinstance(value, enum.Enum):
- return EnumSerializer(value)
- if isinstance(value, datetime.datetime):
- return DatetimeDatetimeSerializer(value)
- if isinstance(value, (datetime.date, datetime.timedelta, datetime.time)):
- return DateTimeSerializer(value)
- if isinstance(value, SettingsReference):
- return SettingsReferenceSerializer(value)
- if isinstance(value, float):
- return FloatSerializer(value)
- if isinstance(value, (bool, int, type(None), bytes, str)):
- return BaseSimpleSerializer(value)
- if isinstance(value, decimal.Decimal):
- return DecimalSerializer(value)
- if isinstance(value, (functools.partial, functools.partialmethod)):
- return FunctoolsPartialSerializer(value)
- if isinstance(value, (types.FunctionType, types.BuiltinFunctionType, types.MethodType)):
- return FunctionTypeSerializer(value)
- if isinstance(value, collections.abc.Iterable):
- return IterableSerializer(value)
- if isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)):
- return RegexSerializer(value)
- if isinstance(value, uuid.UUID):
- return UUIDSerializer(value)
+ for type_, serializer_cls in Serializer._registry.items():
+ if isinstance(value, type_):
+ return serializer_cls(value)
raise ValueError(
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py
index 1e001da4e6..047436ffab 100644
--- a/django/db/migrations/writer.py
+++ b/django/db/migrations/writer.py
@@ -8,7 +8,7 @@ from django.apps import apps
from django.conf import SettingsReference # NOQA
from django.db import migrations
from django.db.migrations.loader import MigrationLoader
-from django.db.migrations.serializer import serializer_factory
+from django.db.migrations.serializer import Serializer, serializer_factory
from django.utils.inspect import get_func_args
from django.utils.module_loading import module_dir
from django.utils.timezone import now
@@ -270,6 +270,14 @@ class MigrationWriter:
def serialize(cls, value):
return serializer_factory(value).serialize()
+ @classmethod
+ def register_serializer(cls, type_, serializer):
+ Serializer.register(type_, serializer)
+
+ @classmethod
+ def unregister_serializer(cls, type_):
+ Serializer.unregister(type_)
+
MIGRATION_HEADER_TEMPLATE = """\
# Generated by Django %(version)s on %(timestamp)s
diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt
index 87a7ef4931..c371d50281 100644
--- a/docs/releases/2.2.txt
+++ b/docs/releases/2.2.txt
@@ -211,6 +211,9 @@ Migrations
* ``NoneType`` can now be serialized in migrations.
+* You can now :ref:`register custom serializers <custom-migration-serializers>`
+ for migrations.
+
Models
~~~~~~
diff --git a/docs/topics/migrations.txt b/docs/topics/migrations.txt
index b44f78cc69..2f33b97878 100644
--- a/docs/topics/migrations.txt
+++ b/docs/topics/migrations.txt
@@ -697,6 +697,35 @@ Django cannot serialize:
- Arbitrary class instances (e.g. ``MyClass(4.3, 5.7)``)
- Lambdas
+.. _custom-migration-serializers:
+
+Custom serializers
+------------------
+
+.. versionadded:: 2.2
+
+You can serialize other types by writing a custom serializer. For example, if
+Django didn't serialize :class:`~decimal.Decimal` by default, you could do
+this::
+
+ from decimal import Decimal
+
+ from django.db.migrations.serializer import BaseSerializer
+ from django.db.migrations.writer import MigrationWriter
+
+ class DecimalSerializer(BaseSerializer):
+ def serialize(self):
+ return repr(self.value), {'from decimal import Decimal'}
+
+ MigrationWriter.register_serializer(Decimal, DecimalSerializer)
+
+The first argument of ``MigrationWriter.register_serializer()`` is a type or
+iterable of types that should use the serializer.
+
+The ``serialize()`` method of your serializer must return a string of how the
+value should appear in migrations and a set of any imports that are needed in
+the migration.
+
.. _custom-deconstruct-method:
Adding a ``deconstruct()`` method
diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py
index 8e30342763..abeeaf5182 100644
--- a/tests/migrations/test_writer.py
+++ b/tests/migrations/test_writer.py
@@ -15,6 +15,7 @@ from django import get_version
from django.conf import SettingsReference, settings
from django.core.validators import EmailValidator, RegexValidator
from django.db import migrations, models
+from django.db.migrations.serializer import BaseSerializer
from django.db.migrations.writer import MigrationWriter, OperationWriter
from django.test import SimpleTestCase
from django.utils.deconstruct import deconstructible
@@ -653,3 +654,18 @@ class WriterTests(SimpleTestCase):
string = MigrationWriter.serialize(models.CharField(default=DeconstructibleInstances))[0]
self.assertEqual(string, "models.CharField(default=migrations.test_writer.DeconstructibleInstances)")
+
+ def test_register_serializer(self):
+ class ComplexSerializer(BaseSerializer):
+ def serialize(self):
+ return 'complex(%r)' % self.value, {}
+
+ MigrationWriter.register_serializer(complex, ComplexSerializer)
+ self.assertSerializedEqual(complex(1, 2))
+ MigrationWriter.unregister_serializer(complex)
+ with self.assertRaisesMessage(ValueError, 'Cannot serialize: (1+2j)'):
+ self.assertSerializedEqual(complex(1, 2))
+
+ def test_register_non_serializer(self):
+ with self.assertRaisesMessage(ValueError, "'TestModel1' must inherit from 'BaseSerializer'."):
+ MigrationWriter.register_serializer(complex, TestModel1)