diff options
| author | Sergey Fedoseev <fedoseev.sergey@gmail.com> | 2017-09-11 20:56:39 +0500 |
|---|---|---|
| committer | Tim Graham <timograham@gmail.com> | 2017-09-11 11:56:39 -0400 |
| commit | 3905cfa1a578275323bfbfbef09f5aee05b33301 (patch) | |
| tree | 9311789febe66894dbe92857b4a800644e550704 /tests | |
| parent | 99e65d648842d4715f32682117adc01223fef316 (diff) | |
Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.
Reverted test for refs #27603 in favor of using FuncTestMixin.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/gis_tests/distapp/tests.py | 6 | ||||
| -rw-r--r-- | tests/gis_tests/geo3d/tests.py | 3 | ||||
| -rw-r--r-- | tests/gis_tests/geoapp/test_functions.py | 9 | ||||
| -rw-r--r-- | tests/gis_tests/geogapp/tests.py | 4 | ||||
| -rw-r--r-- | tests/gis_tests/test_fields.py | 4 | ||||
| -rw-r--r-- | tests/gis_tests/test_gis_tests_utils.py | 52 | ||||
| -rw-r--r-- | tests/gis_tests/utils.py | 39 |
7 files changed, 104 insertions, 13 deletions
diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index 395f7226ef..d162759513 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -9,7 +9,9 @@ from django.db import connection from django.db.models import F, Q from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature -from ..utils import mysql, no_oracle, oracle, postgis, spatialite +from ..utils import ( + FuncTestMixin, mysql, no_oracle, oracle, postgis, spatialite, +) from .models import ( AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt, SouthTexasInterstate, SouthTexasZipcode, @@ -262,7 +264,7 @@ Perimeter(geom1) | OK | :-( ''' # NOQA -class DistanceFunctionsTests(TestCase): +class DistanceFunctionsTests(FuncTestMixin, TestCase): fixtures = ['initial'] @skipUnlessDBFeature("has_Area_function") diff --git a/tests/gis_tests/geo3d/tests.py b/tests/gis_tests/geo3d/tests.py index 39603d1249..d2e85f0607 100644 --- a/tests/gis_tests/geo3d/tests.py +++ b/tests/gis_tests/geo3d/tests.py @@ -8,6 +8,7 @@ from django.contrib.gis.db.models.functions import ( from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon from django.test import TestCase, skipUnlessDBFeature +from ..utils import FuncTestMixin from .models import ( City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D, MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D, @@ -205,7 +206,7 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase): @skipUnlessDBFeature("supports_3d_functions") -class Geo3DFunctionsTests(Geo3DLoadingHelper, TestCase): +class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase): def test_kml(self): """ Test KML() function with Z values. diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py index bb13d9e37f..cdd05d78ff 100644 --- a/tests/gis_tests/geoapp/test_functions.py +++ b/tests/gis_tests/geoapp/test_functions.py @@ -12,11 +12,11 @@ from django.db import connection from django.db.models import Sum from django.test import TestCase, skipUnlessDBFeature -from ..utils import mysql, oracle, postgis, spatialite +from ..utils import FuncTestMixin, mysql, oracle, postgis, spatialite from .models import City, Country, CountryWebMercator, State, Track -class GISFunctionsTests(TestCase): +class GISFunctionsTests(FuncTestMixin, TestCase): """ Testing functions from django/contrib/gis/db/models/functions.py. Area/Distance/Length/Perimeter are tested in distapp/tests. @@ -127,11 +127,8 @@ class GISFunctionsTests(TestCase): City.objects.annotate(kml=functions.AsKML('name')) # Ensuring the KML is as expected. - qs = City.objects.annotate(kml=functions.AsKML('point', precision=9)) - ptown = qs.get(name='Pueblo') + ptown = City.objects.annotate(kml=functions.AsKML('point', precision=9)).get(name='Pueblo') self.assertEqual('<Point><coordinates>-104.609252,38.255001</coordinates></Point>', ptown.kml) - # Same result if the queryset is evaluated again. - self.assertEqual(qs.get(name='Pueblo').kml, ptown.kml) @skipUnlessDBFeature("has_AsSVG_function") def test_assvg(self): diff --git a/tests/gis_tests/geogapp/tests.py b/tests/gis_tests/geogapp/tests.py index 2969ca1cc6..c9986fd78b 100644 --- a/tests/gis_tests/geogapp/tests.py +++ b/tests/gis_tests/geogapp/tests.py @@ -11,7 +11,7 @@ from django.db import connection from django.db.models.functions import Cast from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature -from ..utils import oracle, postgis, spatialite +from ..utils import FuncTestMixin, oracle, postgis, spatialite from .models import City, County, Zipcode @@ -86,7 +86,7 @@ class GeographyTest(TestCase): self.assertEqual(state, c.state) -class GeographyFunctionTests(TestCase): +class GeographyFunctionTests(FuncTestMixin, TestCase): fixtures = ['initial'] @skipUnlessDBFeature("supports_extent_aggr") diff --git a/tests/gis_tests/test_fields.py b/tests/gis_tests/test_fields.py index fb0c953f21..27db3e1dfa 100644 --- a/tests/gis_tests/test_fields.py +++ b/tests/gis_tests/test_fields.py @@ -7,9 +7,9 @@ from django.test import SimpleTestCase class FieldsTests(SimpleTestCase): def test_area_field_deepcopy(self): - field = AreaField() + field = AreaField(None) self.assertEqual(copy.deepcopy(field), field) def test_distance_field_deepcopy(self): - field = DistanceField() + field = DistanceField(None) self.assertEqual(copy.deepcopy(field), field) diff --git a/tests/gis_tests/test_gis_tests_utils.py b/tests/gis_tests/test_gis_tests_utils.py new file mode 100644 index 0000000000..32d072fd9b --- /dev/null +++ b/tests/gis_tests/test_gis_tests_utils.py @@ -0,0 +1,52 @@ +from django.db import connection, models +from django.db.models.expressions import Func +from django.test import SimpleTestCase + +from .utils import FuncTestMixin + + +def test_mutation(raises=True): + def wrapper(mutation_func): + def test(test_case_instance, *args, **kwargs): + class TestFunc(Func): + output_field = models.IntegerField() + + def __init__(self): + self.attribute = 'initial' + super().__init__('initial', ['initial']) + + def as_sql(self, *args, **kwargs): + mutation_func(self) + return '', () + + if raises: + msg = 'TestFunc Func was mutated during compilation.' + with test_case_instance.assertRaisesMessage(AssertionError, msg): + getattr(TestFunc(), 'as_' + connection.vendor)(None, None) + else: + getattr(TestFunc(), 'as_' + connection.vendor)(None, None) + + return test + return wrapper + + +class FuncTestMixinTests(FuncTestMixin, SimpleTestCase): + @test_mutation() + def test_mutated_attribute(func): + func.attribute = 'mutated' + + @test_mutation() + def test_mutated_expressions(func): + func.source_expressions.clear() + + @test_mutation() + def test_mutated_expression(func): + func.source_expressions[0].name = 'mutated' + + @test_mutation() + def test_mutated_expression_deep(func): + func.source_expressions[1].value[0] = 'mutated' + + @test_mutation(raises=False) + def test_not_mutated(func): + pass diff --git a/tests/gis_tests/utils.py b/tests/gis_tests/utils.py index 6eb029c1d5..b30da7e40d 100644 --- a/tests/gis_tests/utils.py +++ b/tests/gis_tests/utils.py @@ -1,8 +1,11 @@ +import copy import unittest from functools import wraps +from unittest import mock from django.conf import settings from django.db import DEFAULT_DB_ALIAS, connection +from django.db.models.expressions import Func def skipUnlessGISLookup(*gis_lookups): @@ -56,3 +59,39 @@ elif spatialite: from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys else: SpatialRefSys = None + + +class FuncTestMixin: + """Assert that Func expressions aren't mutated during their as_sql().""" + def setUp(self): + def as_sql_wrapper(original_as_sql): + def inner(*args, **kwargs): + func = original_as_sql.__self__ + # Resolve output_field before as_sql() so touching it in + # as_sql() won't change __dict__. + func.output_field + __dict__original = copy.deepcopy(func.__dict__) + result = original_as_sql(*args, **kwargs) + msg = '%s Func was mutated during compilation.' % func.__class__.__name__ + self.assertEqual(func.__dict__, __dict__original, msg) + return result + return inner + + def __getattribute__(self, name): + if name != vendor_impl: + return __getattribute__original(self, name) + try: + as_sql = __getattribute__original(self, vendor_impl) + except AttributeError: + as_sql = __getattribute__original(self, 'as_sql') + return as_sql_wrapper(as_sql) + + vendor_impl = 'as_' + connection.vendor + __getattribute__original = Func.__getattribute__ + self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__) + self.func_patcher.start() + super().setUp() + + def tearDown(self): + super().tearDown() + self.func_patcher.stop() |
