summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorSergey Fedoseev <fedoseev.sergey@gmail.com>2017-09-11 20:56:39 +0500
committerTim Graham <timograham@gmail.com>2017-09-11 11:56:39 -0400
commit3905cfa1a578275323bfbfbef09f5aee05b33301 (patch)
tree9311789febe66894dbe92857b4a800644e550704 /tests
parent99e65d648842d4715f32682117adc01223fef316 (diff)
Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.
Reverted test for refs #27603 in favor of using FuncTestMixin.
Diffstat (limited to 'tests')
-rw-r--r--tests/gis_tests/distapp/tests.py6
-rw-r--r--tests/gis_tests/geo3d/tests.py3
-rw-r--r--tests/gis_tests/geoapp/test_functions.py9
-rw-r--r--tests/gis_tests/geogapp/tests.py4
-rw-r--r--tests/gis_tests/test_fields.py4
-rw-r--r--tests/gis_tests/test_gis_tests_utils.py52
-rw-r--r--tests/gis_tests/utils.py39
7 files changed, 104 insertions, 13 deletions
diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py
index 395f7226ef..d162759513 100644
--- a/tests/gis_tests/distapp/tests.py
+++ b/tests/gis_tests/distapp/tests.py
@@ -9,7 +9,9 @@ from django.db import connection
from django.db.models import F, Q
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
-from ..utils import mysql, no_oracle, oracle, postgis, spatialite
+from ..utils import (
+ FuncTestMixin, mysql, no_oracle, oracle, postgis, spatialite,
+)
from .models import (
AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt,
SouthTexasInterstate, SouthTexasZipcode,
@@ -262,7 +264,7 @@ Perimeter(geom1) | OK | :-(
''' # NOQA
-class DistanceFunctionsTests(TestCase):
+class DistanceFunctionsTests(FuncTestMixin, TestCase):
fixtures = ['initial']
@skipUnlessDBFeature("has_Area_function")
diff --git a/tests/gis_tests/geo3d/tests.py b/tests/gis_tests/geo3d/tests.py
index 39603d1249..d2e85f0607 100644
--- a/tests/gis_tests/geo3d/tests.py
+++ b/tests/gis_tests/geo3d/tests.py
@@ -8,6 +8,7 @@ from django.contrib.gis.db.models.functions import (
from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon
from django.test import TestCase, skipUnlessDBFeature
+from ..utils import FuncTestMixin
from .models import (
City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D,
MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D,
@@ -205,7 +206,7 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase):
@skipUnlessDBFeature("supports_3d_functions")
-class Geo3DFunctionsTests(Geo3DLoadingHelper, TestCase):
+class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase):
def test_kml(self):
"""
Test KML() function with Z values.
diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py
index bb13d9e37f..cdd05d78ff 100644
--- a/tests/gis_tests/geoapp/test_functions.py
+++ b/tests/gis_tests/geoapp/test_functions.py
@@ -12,11 +12,11 @@ from django.db import connection
from django.db.models import Sum
from django.test import TestCase, skipUnlessDBFeature
-from ..utils import mysql, oracle, postgis, spatialite
+from ..utils import FuncTestMixin, mysql, oracle, postgis, spatialite
from .models import City, Country, CountryWebMercator, State, Track
-class GISFunctionsTests(TestCase):
+class GISFunctionsTests(FuncTestMixin, TestCase):
"""
Testing functions from django/contrib/gis/db/models/functions.py.
Area/Distance/Length/Perimeter are tested in distapp/tests.
@@ -127,11 +127,8 @@ class GISFunctionsTests(TestCase):
City.objects.annotate(kml=functions.AsKML('name'))
# Ensuring the KML is as expected.
- qs = City.objects.annotate(kml=functions.AsKML('point', precision=9))
- ptown = qs.get(name='Pueblo')
+ ptown = City.objects.annotate(kml=functions.AsKML('point', precision=9)).get(name='Pueblo')
self.assertEqual('<Point><coordinates>-104.609252,38.255001</coordinates></Point>', ptown.kml)
- # Same result if the queryset is evaluated again.
- self.assertEqual(qs.get(name='Pueblo').kml, ptown.kml)
@skipUnlessDBFeature("has_AsSVG_function")
def test_assvg(self):
diff --git a/tests/gis_tests/geogapp/tests.py b/tests/gis_tests/geogapp/tests.py
index 2969ca1cc6..c9986fd78b 100644
--- a/tests/gis_tests/geogapp/tests.py
+++ b/tests/gis_tests/geogapp/tests.py
@@ -11,7 +11,7 @@ from django.db import connection
from django.db.models.functions import Cast
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
-from ..utils import oracle, postgis, spatialite
+from ..utils import FuncTestMixin, oracle, postgis, spatialite
from .models import City, County, Zipcode
@@ -86,7 +86,7 @@ class GeographyTest(TestCase):
self.assertEqual(state, c.state)
-class GeographyFunctionTests(TestCase):
+class GeographyFunctionTests(FuncTestMixin, TestCase):
fixtures = ['initial']
@skipUnlessDBFeature("supports_extent_aggr")
diff --git a/tests/gis_tests/test_fields.py b/tests/gis_tests/test_fields.py
index fb0c953f21..27db3e1dfa 100644
--- a/tests/gis_tests/test_fields.py
+++ b/tests/gis_tests/test_fields.py
@@ -7,9 +7,9 @@ from django.test import SimpleTestCase
class FieldsTests(SimpleTestCase):
def test_area_field_deepcopy(self):
- field = AreaField()
+ field = AreaField(None)
self.assertEqual(copy.deepcopy(field), field)
def test_distance_field_deepcopy(self):
- field = DistanceField()
+ field = DistanceField(None)
self.assertEqual(copy.deepcopy(field), field)
diff --git a/tests/gis_tests/test_gis_tests_utils.py b/tests/gis_tests/test_gis_tests_utils.py
new file mode 100644
index 0000000000..32d072fd9b
--- /dev/null
+++ b/tests/gis_tests/test_gis_tests_utils.py
@@ -0,0 +1,52 @@
+from django.db import connection, models
+from django.db.models.expressions import Func
+from django.test import SimpleTestCase
+
+from .utils import FuncTestMixin
+
+
+def test_mutation(raises=True):
+ def wrapper(mutation_func):
+ def test(test_case_instance, *args, **kwargs):
+ class TestFunc(Func):
+ output_field = models.IntegerField()
+
+ def __init__(self):
+ self.attribute = 'initial'
+ super().__init__('initial', ['initial'])
+
+ def as_sql(self, *args, **kwargs):
+ mutation_func(self)
+ return '', ()
+
+ if raises:
+ msg = 'TestFunc Func was mutated during compilation.'
+ with test_case_instance.assertRaisesMessage(AssertionError, msg):
+ getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
+ else:
+ getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
+
+ return test
+ return wrapper
+
+
+class FuncTestMixinTests(FuncTestMixin, SimpleTestCase):
+ @test_mutation()
+ def test_mutated_attribute(func):
+ func.attribute = 'mutated'
+
+ @test_mutation()
+ def test_mutated_expressions(func):
+ func.source_expressions.clear()
+
+ @test_mutation()
+ def test_mutated_expression(func):
+ func.source_expressions[0].name = 'mutated'
+
+ @test_mutation()
+ def test_mutated_expression_deep(func):
+ func.source_expressions[1].value[0] = 'mutated'
+
+ @test_mutation(raises=False)
+ def test_not_mutated(func):
+ pass
diff --git a/tests/gis_tests/utils.py b/tests/gis_tests/utils.py
index 6eb029c1d5..b30da7e40d 100644
--- a/tests/gis_tests/utils.py
+++ b/tests/gis_tests/utils.py
@@ -1,8 +1,11 @@
+import copy
import unittest
from functools import wraps
+from unittest import mock
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, connection
+from django.db.models.expressions import Func
def skipUnlessGISLookup(*gis_lookups):
@@ -56,3 +59,39 @@ elif spatialite:
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys
else:
SpatialRefSys = None
+
+
+class FuncTestMixin:
+ """Assert that Func expressions aren't mutated during their as_sql()."""
+ def setUp(self):
+ def as_sql_wrapper(original_as_sql):
+ def inner(*args, **kwargs):
+ func = original_as_sql.__self__
+ # Resolve output_field before as_sql() so touching it in
+ # as_sql() won't change __dict__.
+ func.output_field
+ __dict__original = copy.deepcopy(func.__dict__)
+ result = original_as_sql(*args, **kwargs)
+ msg = '%s Func was mutated during compilation.' % func.__class__.__name__
+ self.assertEqual(func.__dict__, __dict__original, msg)
+ return result
+ return inner
+
+ def __getattribute__(self, name):
+ if name != vendor_impl:
+ return __getattribute__original(self, name)
+ try:
+ as_sql = __getattribute__original(self, vendor_impl)
+ except AttributeError:
+ as_sql = __getattribute__original(self, 'as_sql')
+ return as_sql_wrapper(as_sql)
+
+ vendor_impl = 'as_' + connection.vendor
+ __getattribute__original = Func.__getattribute__
+ self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__)
+ self.func_patcher.start()
+ super().setUp()
+
+ def tearDown(self):
+ super().tearDown()
+ self.func_patcher.stop()