summaryrefslogtreecommitdiff
path: root/tests/gis_tests/rasterapp/test_rasterfield.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/gis_tests/rasterapp/test_rasterfield.py')
-rw-r--r--tests/gis_tests/rasterapp/test_rasterfield.py47
1 files changed, 46 insertions, 1 deletions
diff --git a/tests/gis_tests/rasterapp/test_rasterfield.py b/tests/gis_tests/rasterapp/test_rasterfield.py
index 3f2ce770a9..89c4ec4856 100644
--- a/tests/gis_tests/rasterapp/test_rasterfield.py
+++ b/tests/gis_tests/rasterapp/test_rasterfield.py
@@ -2,7 +2,11 @@ import json
from django.contrib.gis.db.models.fields import BaseSpatialField
from django.contrib.gis.db.models.functions import Distance
-from django.contrib.gis.db.models.lookups import DistanceLookupBase, GISLookup
+from django.contrib.gis.db.models.lookups import (
+ DistanceLookupBase,
+ GISLookup,
+ RasterBandTransform,
+)
from django.contrib.gis.gdal import GDALRaster
from django.contrib.gis.geos import GEOSGeometry
from django.contrib.gis.measure import D
@@ -356,6 +360,47 @@ class RasterFieldTest(TransactionTestCase):
with self.assertRaisesMessage(ValueError, msg):
qs.count()
+ def test_lookup_invalid_band_rhs(self):
+ rast = GDALRaster(json.loads(JSON_RASTER))
+ qs = RasterModel.objects.filter(rast__contains=(rast, "evil"))
+ msg = "Band index must be an integer, but got 'str'."
+ with self.assertRaisesMessage(TypeError, msg):
+ qs.count()
+
+ def test_lookup_invalid_band_lhs(self):
+ """
+ Typical left-hand side usage is protected against non-integers, but for
+ defense-in-depth purposes, construct custom lookups that evade the
+ `int()` and `+ 1` checks in the lookups shipped by django.contrib.gis.
+ """
+
+ # Evade the int() call in RasterField.get_transform().
+ class MyRasterBandTransform(RasterBandTransform):
+ band_index = "evil"
+
+ def process_band_indices(self, *args, **kwargs):
+ self.band_lhs = self.lhs.band_index
+ self.band_rhs, *self.rhs_params = self.rhs_params
+
+ # Evade the `+ 1` call in BaseSpatialField.process_band_indices().
+ ContainsLookup = RasterModel._meta.get_field("rast").get_lookup("contains")
+
+ class MyContainsLookup(ContainsLookup):
+ def process_band_indices(self, *args, **kwargs):
+ self.band_lhs = self.lhs.band_index
+ self.band_rhs, *self.rhs_params = self.rhs_params
+
+ RasterField = RasterModel._meta.get_field("rast")
+ RasterField.register_lookup(MyContainsLookup, "contains")
+ self.addCleanup(RasterField.register_lookup, ContainsLookup, "contains")
+
+ qs = RasterModel.objects.annotate(
+ transformed=MyRasterBandTransform("rast")
+ ).filter(transformed__contains=(F("transformed"), 1))
+ msg = "Band index must be an integer, but got 'str'."
+ with self.assertRaisesMessage(TypeError, msg):
+ list(qs)
+
def test_isvalid_lookup_with_raster_error(self):
qs = RasterModel.objects.filter(rast__isvalid=True)
msg = (