diff options
| author | Marc Tamlyn <marc.tamlyn@gmail.com> | 2015-01-10 16:14:20 +0000 |
|---|---|---|
| committer | Marc Tamlyn <marc.tamlyn@gmail.com> | 2015-01-10 16:18:19 +0000 |
| commit | 48ad288679a0cb2e2cfb17f128903e6c5b1c4870 (patch) | |
| tree | 75bacb810dbe071058b5c5cf7d8dcb8e20f3f500 /tests/postgres_tests/test_ranges.py | |
| parent | 916e38802f151b34aaca487dc7e928946e81be73 (diff) | |
Fixed #24001 -- Added range fields for PostgreSQL.
Added support for PostgreSQL range types to contrib.postgres.
- 5 new model fields
- 4 new form fields
- New validators
- Uses psycopg2's range type implementation in python
Diffstat (limited to 'tests/postgres_tests/test_ranges.py')
| -rw-r--r-- | tests/postgres_tests/test_ranges.py | 376 |
1 files changed, 376 insertions, 0 deletions
diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py new file mode 100644 index 0000000000..6d35e62cc5 --- /dev/null +++ b/tests/postgres_tests/test_ranges.py @@ -0,0 +1,376 @@ +import datetime +import json +import unittest + +from django import forms +from django.contrib.postgres import forms as pg_forms, fields as pg_fields +from django.contrib.postgres.validators import RangeMaxValueValidator, RangeMinValueValidator +from django.core import exceptions, serializers +from django.db import connection +from django.test import TestCase +from django.utils import timezone + +from psycopg2.extras import NumericRange, DateTimeTZRange, DateRange + +from .models import RangesModel + + +def skipUnlessPG92(test): + if not connection.vendor == 'postgresql': + return unittest.skip('PostgreSQL required')(test) + PG_VERSION = connection.pg_version + if PG_VERSION < 90200: + return unittest.skip('PostgreSQL >= 9.2 required')(test) + return test + + +@skipUnlessPG92 +class TestSaveLoad(TestCase): + + def test_all_fields(self): + now = timezone.now() + instance = RangesModel( + ints=NumericRange(0, 10), + bigints=NumericRange(10, 20), + floats=NumericRange(20, 30), + timestamps=DateTimeTZRange(now - datetime.timedelta(hours=1), now), + dates=DateRange(now.date() - datetime.timedelta(days=1), now.date()), + ) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(instance.ints, loaded.ints) + self.assertEqual(instance.bigints, loaded.bigints) + self.assertEqual(instance.floats, loaded.floats) + self.assertEqual(instance.timestamps, loaded.timestamps) + self.assertEqual(instance.dates, loaded.dates) + + def test_range_object(self): + r = NumericRange(0, 10) + instance = RangesModel(ints=r) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(r, loaded.ints) + + def test_tuple(self): + instance = RangesModel(ints=(0, 10)) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(NumericRange(0, 10), loaded.ints) + + def test_range_object_boundaries(self): + r = NumericRange(0, 10, '[]') + instance = RangesModel(floats=r) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(r, loaded.floats) + self.assertTrue(10 in loaded.floats) + + def test_unbounded(self): + r = NumericRange(None, None, '()') + instance = RangesModel(floats=r) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(r, loaded.floats) + + def test_empty(self): + r = NumericRange(empty=True) + instance = RangesModel(ints=r) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(r, loaded.ints) + + def test_null(self): + instance = RangesModel(ints=None) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(None, loaded.ints) + + +@skipUnlessPG92 +class TestQuerying(TestCase): + + @classmethod + def setUpTestData(cls): + cls.objs = [ + RangesModel.objects.create(ints=NumericRange(0, 10)), + RangesModel.objects.create(ints=NumericRange(5, 15)), + RangesModel.objects.create(ints=NumericRange(None, 0)), + RangesModel.objects.create(ints=NumericRange(empty=True)), + RangesModel.objects.create(ints=None), + ] + + def test_exact(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__exact=NumericRange(0, 10)), + [self.objs[0]], + ) + + def test_isnull(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__isnull=True), + [self.objs[4]], + ) + + def test_isempty(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__isempty=True), + [self.objs[3]], + ) + + def test_contains(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__contains=8), + [self.objs[0], self.objs[1]], + ) + + def test_contains_range(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__contains=NumericRange(3, 8)), + [self.objs[0]], + ) + + def test_contained_by(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__contained_by=NumericRange(0, 20)), + [self.objs[0], self.objs[1], self.objs[3]], + ) + + def test_overlap(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__overlap=NumericRange(3, 8)), + [self.objs[0], self.objs[1]], + ) + + def test_fully_lt(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__fully_lt=NumericRange(5, 10)), + [self.objs[2]], + ) + + def test_fully_gt(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__fully_gt=NumericRange(5, 10)), + [], + ) + + def test_not_lt(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__not_lt=NumericRange(5, 10)), + [self.objs[1]], + ) + + def test_not_gt(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__not_gt=NumericRange(5, 10)), + [self.objs[0], self.objs[2]], + ) + + def test_adjacent_to(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__adjacent_to=NumericRange(0, 5)), + [self.objs[1], self.objs[2]], + ) + + def test_startswith(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__startswith=0), + [self.objs[0]], + ) + + def test_endswith(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__endswith=0), + [self.objs[2]], + ) + + def test_startswith_chaining(self): + self.assertSequenceEqual( + RangesModel.objects.filter(ints__startswith__gte=0), + [self.objs[0], self.objs[1]], + ) + + +@skipUnlessPG92 +class TestSerialization(TestCase): + test_data = ( + '[{"fields": {"ints": "{\\"upper\\": 10, \\"lower\\": 0, ' + '\\"bounds\\": \\"[)\\"}", "floats": "{\\"empty\\": true}", ' + '"bigints": null, "timestamps": null, "dates": null}, ' + '"model": "postgres_tests.rangesmodel", "pk": null}]' + ) + + def test_dumping(self): + instance = RangesModel(ints=NumericRange(0, 10), floats=NumericRange(empty=True)) + data = serializers.serialize('json', [instance]) + dumped = json.loads(data) + dumped[0]['fields']['ints'] = json.loads(dumped[0]['fields']['ints']) + check = json.loads(self.test_data) + check[0]['fields']['ints'] = json.loads(check[0]['fields']['ints']) + self.assertEqual(dumped, check) + + def test_loading(self): + instance = list(serializers.deserialize('json', self.test_data))[0].object + self.assertEqual(instance.ints, NumericRange(0, 10)) + self.assertEqual(instance.floats, NumericRange(empty=True)) + self.assertEqual(instance.dates, None) + + +class TestValidators(TestCase): + + def test_max(self): + validator = RangeMaxValueValidator(5) + validator(NumericRange(0, 5)) + with self.assertRaises(exceptions.ValidationError) as cm: + validator(NumericRange(0, 10)) + self.assertEqual(cm.exception.messages[0], 'Ensure that this range is completely less than or equal to 5.') + self.assertEqual(cm.exception.code, 'max_value') + + def test_min(self): + validator = RangeMinValueValidator(5) + validator(NumericRange(10, 15)) + with self.assertRaises(exceptions.ValidationError) as cm: + validator(NumericRange(0, 10)) + self.assertEqual(cm.exception.messages[0], 'Ensure that this range is completely greater than or equal to 5.') + self.assertEqual(cm.exception.code, 'min_value') + + +class TestFormField(TestCase): + + def test_valid_integer(self): + field = pg_forms.IntegerRangeField() + value = field.clean(['1', '2']) + self.assertEqual(value, NumericRange(1, 2)) + + def test_valid_floats(self): + field = pg_forms.FloatRangeField() + value = field.clean(['1.12345', '2.001']) + self.assertEqual(value, NumericRange(1.12345, 2.001)) + + def test_valid_timestamps(self): + field = pg_forms.DateTimeRangeField() + value = field.clean(['01/01/2014 00:00:00', '02/02/2014 12:12:12']) + lower = datetime.datetime(2014, 1, 1, 0, 0, 0) + upper = datetime.datetime(2014, 2, 2, 12, 12, 12) + self.assertEqual(value, DateTimeTZRange(lower, upper)) + + def test_valid_dates(self): + field = pg_forms.DateRangeField() + value = field.clean(['01/01/2014', '02/02/2014']) + lower = datetime.date(2014, 1, 1) + upper = datetime.date(2014, 2, 2) + self.assertEqual(value, DateRange(lower, upper)) + + def test_using_split_datetime_widget(self): + class SplitDateTimeRangeField(pg_forms.DateTimeRangeField): + base_field = forms.SplitDateTimeField + + class SplitForm(forms.Form): + field = SplitDateTimeRangeField() + + form = SplitForm() + self.assertHTMLEqual(str(form), ''' + <tr> + <th> + <label for="id_field_0">Field:</label> + </th> + <td> + <input id="id_field_0_0" name="field_0_0" type="text" /> + <input id="id_field_0_1" name="field_0_1" type="text" /> + <input id="id_field_1_0" name="field_1_0" type="text" /> + <input id="id_field_1_1" name="field_1_1" type="text" /> + </td> + </tr> + ''') + form = SplitForm({ + 'field_0_0': '01/01/2014', + 'field_0_1': '00:00:00', + 'field_1_0': '02/02/2014', + 'field_1_1': '12:12:12', + }) + self.assertTrue(form.is_valid()) + lower = datetime.datetime(2014, 1, 1, 0, 0, 0) + upper = datetime.datetime(2014, 2, 2, 12, 12, 12) + self.assertEqual(form.cleaned_data['field'], DateTimeTZRange(lower, upper)) + + def test_none(self): + field = pg_forms.IntegerRangeField(required=False) + value = field.clean(['', '']) + self.assertEqual(value, None) + + def test_rendering(self): + class RangeForm(forms.Form): + ints = pg_forms.IntegerRangeField() + + self.assertHTMLEqual(str(RangeForm()), ''' + <tr> + <th><label for="id_ints_0">Ints:</label></th> + <td> + <input id="id_ints_0" name="ints_0" type="number" /> + <input id="id_ints_1" name="ints_1" type="number" /> + </td> + </tr> + ''') + + def test_lower_bound_higher(self): + field = pg_forms.IntegerRangeField() + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean(['10', '2']) + self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.') + self.assertEqual(cm.exception.code, 'bound_ordering') + + def test_open(self): + field = pg_forms.IntegerRangeField() + value = field.clean(['', '0']) + self.assertEqual(value, NumericRange(None, 0)) + + def test_incorrect_data_type(self): + field = pg_forms.IntegerRangeField() + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('1') + self.assertEqual(cm.exception.messages[0], 'Enter two valid values.') + self.assertEqual(cm.exception.code, 'invalid') + + def test_invalid_lower(self): + field = pg_forms.IntegerRangeField() + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean(['a', '2']) + self.assertEqual(cm.exception.messages[0], 'Enter a whole number.') + + def test_invalid_upper(self): + field = pg_forms.IntegerRangeField() + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean(['1', 'b']) + self.assertEqual(cm.exception.messages[0], 'Enter a whole number.') + + def test_required(self): + field = pg_forms.IntegerRangeField(required=True) + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean(['', '']) + self.assertEqual(cm.exception.messages[0], 'This field is required.') + value = field.clean([1, '']) + self.assertEqual(value, NumericRange(1, None)) + + def test_model_field_formfield_integer(self): + model_field = pg_fields.IntegerRangeField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, pg_forms.IntegerRangeField) + + def test_model_field_formfield_biginteger(self): + model_field = pg_fields.BigIntegerRangeField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, pg_forms.IntegerRangeField) + + def test_model_field_formfield_float(self): + model_field = pg_fields.FloatRangeField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, pg_forms.FloatRangeField) + + def test_model_field_formfield_date(self): + model_field = pg_fields.DateRangeField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, pg_forms.DateRangeField) + + def test_model_field_formfield_datetime(self): + model_field = pg_fields.DateTimeRangeField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) |
