summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Chaumeny <t.chaumeny@gmail.com>2014-09-14 12:34:41 +0200
committerAnssi Kääriäinen <akaariai@gmail.com>2014-10-28 10:02:10 +0200
commit00aa562884a418c4ee20e223ab82c3455997ee7d (patch)
tree934393c39e5087bad689217003a1de59484e0000
parent6b39401bafa955f4891700996aa666349fcdef74 (diff)
Fixed #23493 -- Added bilateral attribute to Transform
-rw-r--r--django/db/models/lookups.py103
-rw-r--r--django/db/models/sql/query.py5
-rw-r--r--docs/howto/custom-lookups.txt48
-rw-r--r--docs/ref/models/lookups.txt9
-rw-r--r--docs/releases/1.8.txt5
-rw-r--r--tests/custom_lookups/tests.py126
6 files changed, 268 insertions, 28 deletions
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index 66bdde54b4..abb5645147 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -1,5 +1,4 @@
from copy import copy
-from itertools import repeat
import inspect
from django.conf import settings
@@ -7,6 +6,8 @@ from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.six.moves import xrange
+from .query_utils import QueryWrapper
+
class RegisterLookupMixin(object):
def _get_lookup(self, lookup_name):
@@ -57,6 +58,9 @@ class RegisterLookupMixin(object):
class Transform(RegisterLookupMixin):
+
+ bilateral = False
+
def __init__(self, lhs, lookups):
self.lhs = lhs
self.init_lookups = lookups[:]
@@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
class Lookup(RegisterLookupMixin):
lookup_name = None
- def __init__(self, lhs, rhs):
+ def __init__(self, lhs, rhs, bilateral_transforms=None):
self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup()
+ if bilateral_transforms is None:
+ bilateral_transforms = []
+ if bilateral_transforms:
+ # We should warn the user as soon as possible if he is trying to apply
+ # a bilateral transformation on a nested QuerySet: that won't work.
+ # We need to import QuerySet here so as to avoid circular
+ from django.db.models.query import QuerySet
+ if isinstance(rhs, QuerySet):
+ raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
+ self.bilateral_transforms = bilateral_transforms
+
+ def apply_bilateral_transforms(self, value):
+ for transform, lookups in self.bilateral_transforms:
+ value = transform(value, lookups)
+ return value
+
+ def batch_process_rhs(self, qn, connection, rhs=None):
+ if rhs is None:
+ rhs = self.rhs
+ if self.bilateral_transforms:
+ sqls, sqls_params = [], []
+ for p in rhs:
+ value = QueryWrapper('%s',
+ [self.lhs.output_field.get_db_prep_value(p, connection)])
+ value = self.apply_bilateral_transforms(value)
+ sql, sql_params = qn.compile(value)
+ sqls.append(sql)
+ sqls_params.extend(sql_params)
+ else:
+ params = self.lhs.output_field.get_db_prep_lookup(
+ self.lookup_name, rhs, connection, prepared=True)
+ sqls, sqls_params = ['%s'] * len(params), params
+ return sqls, sqls_params
def get_prep_lookup(self):
return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
@@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):
def process_rhs(self, qn, connection):
value = self.rhs
+ if self.bilateral_transforms:
+ if self.rhs_is_direct_value():
+ # Do not call get_db_prep_lookup here as the value will be
+ # transformed before being used for lookup
+ value = QueryWrapper("%s",
+ [self.lhs.output_field.get_db_prep_value(value, connection)])
+ value = self.apply_bilateral_transforms(value)
# Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with
@@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
class In(BuiltinLookup):
lookup_name = 'in'
- def get_db_prep_lookup(self, value, connection):
- params = self.lhs.output_field.get_db_prep_lookup(
- self.lookup_name, value, connection, prepared=True)
- if not params:
- # TODO: check why this leads to circular import
- from django.db.models.sql.datastructures import EmptyResultSet
- raise EmptyResultSet
- placeholder = '(' + ', '.join('%s' for p in params) + ')'
- return (placeholder, params)
+ def process_rhs(self, qn, connection):
+ if self.rhs_is_direct_value():
+ # rhs should be an iterable, we use batch_process_rhs
+ # to prepare/transform those values
+ rhs = list(self.rhs)
+ if not rhs:
+ from django.db.models.sql.datastructures import EmptyResultSet
+ raise EmptyResultSet
+ sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
+ placeholder = '(' + ', '.join(sqls) + ')'
+ return (placeholder, sqls_params)
+ else:
+ return super(In, self).process_rhs(qn, connection)
def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs
@@ -220,8 +268,10 @@ class In(BuiltinLookup):
max_in_list_size = connection.ops.max_in_list_size()
if self.rhs_is_direct_value() and (max_in_list_size and
len(self.rhs) > max_in_list_size):
- rhs, rhs_params = self.process_rhs(qn, connection)
+ # This is a special case for Oracle which limits the number of elements
+ # which can appear in an 'IN' clause.
lhs, lhs_params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.batch_process_rhs(qn, connection)
in_clause_elements = ['(']
params = []
for offset in xrange(0, len(rhs_params), max_in_list_size):
@@ -229,11 +279,12 @@ class In(BuiltinLookup):
in_clause_elements.append(' OR ')
in_clause_elements.append('%s IN (' % lhs)
params.extend(lhs_params)
- group_size = min(len(rhs_params) - offset, max_in_list_size)
- param_group = ', '.join(repeat('%s', group_size))
+ sqls = rhs[offset: offset + max_in_list_size]
+ sqls_params = rhs_params[offset: offset + max_in_list_size]
+ param_group = ', '.join(sqls)
in_clause_elements.append(param_group)
in_clause_elements.append(')')
- params.extend(rhs_params[offset: offset + max_in_list_size])
+ params.extend(sqls_params)
in_clause_elements.append(')')
return ''.join(in_clause_elements), params
else:
@@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
# we need to add the % pattern match to the lookup by something like
# col LIKE othercol || '%%'
# So, for Python values we don't need any special pattern, but for
- # SQL reference values we need the correct pattern added.
- value = self.rhs
- if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
- or hasattr(value, '_as_sql')):
+ # SQL reference values or SQL transformations we need the correct
+ # pattern added.
+ if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
+ or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
return connection.pattern_ops[self.lookup_name] % rhs
else:
return super(PatternLookup, self).get_rhs_op(connection, rhs)
@@ -291,8 +342,20 @@ class Year(Between):
default_lookups['year'] = Year
-class Range(Between):
+class Range(BuiltinLookup):
lookup_name = 'range'
+
+ def get_rhs_op(self, connection, rhs):
+ return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
+
+ def process_rhs(self, qn, connection):
+ if self.rhs_is_direct_value():
+ # rhs should be an iterable of 2 values, we use batch_process_rhs
+ # to prepare/transform those values
+ return self.batch_process_rhs(qn, connection)
+ else:
+ return super(Range, self).process_rhs(qn, connection)
+
default_lookups['range'] = Range
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 08af1fb008..b6690e4526 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -1111,18 +1111,21 @@ class Query(object):
def build_lookup(self, lookups, lhs, rhs):
lookups = lookups[:]
+ bilaterals = []
while lookups:
lookup = lookups[0]
if len(lookups) == 1:
final_lookup = lhs.get_lookup(lookup)
if final_lookup:
- return final_lookup(lhs, rhs)
+ return final_lookup(lhs, rhs, bilaterals)
# We didn't find a lookup, so we are going to try get_transform
# + get_lookup('exact').
lookups.append('exact')
next = lhs.get_transform(lookup)
if next:
lhs = next(lhs, lookups)
+ if getattr(next, 'bilateral', False):
+ bilaterals.append((next, lookups))
else:
raise FieldError(
"Unsupported lookup '%s' for %s or join on the field not "
diff --git a/docs/howto/custom-lookups.txt b/docs/howto/custom-lookups.txt
index 820a2ef574..d3ed726ba3 100644
--- a/docs/howto/custom-lookups.txt
+++ b/docs/howto/custom-lookups.txt
@@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison::
lhs, params = qn.compile(self.lhs)
return "ABS(%s)" % lhs, params
-Next, lets register it for ``IntegerField``::
+Next, let's register it for ``IntegerField``::
from django.db.models import IntegerField
IntegerField.register_lookup(AbsoluteValue)
@@ -144,9 +144,7 @@ SQL::
SELECT ... WHERE ABS("experiments"."change") < 27
-Subclasses of ``Transform`` usually only operate on the left-hand side of the
-expression. Further lookups will work on the transformed value. Note that in
-this case where there is no other lookup specified, Django interprets
+Note that in case there is no other lookup specified, Django interprets
``change__abs=27`` as ``change__abs__exact=27``.
When looking for which lookups are allowable after the ``Transform`` has been
@@ -197,7 +195,7 @@ Notice also that as both sides are used multiple times in the query the params
need to contain ``lhs_params`` and ``rhs_params`` multiple times.
The final query does the inversion (``27`` to ``-27``) directly in the
-database. The reason for doing this is that if the self.rhs is something else
+database. The reason for doing this is that if the ``self.rhs`` is something else
than a plain integer value (for example an ``F()`` reference) we can't do the
transformations in Python.
@@ -208,6 +206,46 @@ transformations in Python.
want to add an index on ``abs(change)`` which would allow these queries to
be very efficient.
+A bilateral transformer example
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``AbsoluteValue`` example we discussed previously is a transformation which
+applies to the left-hand side of the lookup. There may be some cases where you
+want the transformation to be applied to both the left-hand side and the
+right-hand side. For instance, if you want to filter a queryset based on the
+equality of the left and right-hand side insensitively to some SQL function.
+
+Let's examine the simple example of case-insensitive transformation here. This
+transformation isn't very useful in practice as Django already comes with a bunch
+of built-in case-insensitive lookups, but it will be a nice demonstration of
+bilateral transformations in a database-agnostic way.
+
+We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to
+transform the values before comparison. We define
+:attr:`bilateral = True <django.db.models.Transform.bilateral>` to indicate that
+this transformation should apply to both ``lhs`` and ``rhs``::
+
+ from django.db.models import Transform
+
+ class UpperCase(Transform):
+ lookup_name = 'upper'
+ bilateral = True
+
+ def as_sql(self, qn, connection):
+ lhs, params = qn.compile(self.lhs)
+ return "UPPER(%s)" % lhs, params
+
+Next, let's register it::
+
+ from django.db.models import CharField, TextField
+ CharField.register_lookup(UpperCase)
+ TextField.register_lookup(UpperCase)
+
+Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a case
+insensitive query like this::
+
+ SELECT ... WHERE UPPER("author"."name") = UPPER('doe')
+
Writing alternative implementations for existing lookups
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt
index d3f64c07a9..da338b7cb2 100644
--- a/docs/ref/models/lookups.txt
+++ b/docs/ref/models/lookups.txt
@@ -129,6 +129,15 @@ Transform reference
This class follows the :ref:`Query Expression API <query-expression>`, which
implies that you can use ``<expression>__<transform1>__<transform2>``.
+ .. attribute:: bilateral
+
+ .. versionadded:: 1.8
+
+ A boolean indicating whether this transformation should apply to both
+ ``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in
+ the same order as they appear in the lookup expression. By default it is set
+ to ``False``. For example usage, see :doc:`/howto/custom-lookups`.
+
.. attribute:: lhs
The left-hand side - what is being transformed. It must follow the
diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt
index 6b08e2b5a1..e5d3282874 100644
--- a/docs/releases/1.8.txt
+++ b/docs/releases/1.8.txt
@@ -306,6 +306,11 @@ Models
* :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using
a decorator pattern.
+* The new :attr:`Transform.bilateral <django.db.models.Transform.bilateral>`
+ attribute allows creating bilateral transformations. These transformations
+ are applied to both ``lhs`` and ``rhs`` when used in a lookup expression,
+ providing opportunities for more sophisticated lookups.
+
Signals
^^^^^^^
diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py
index a965e5a4e8..d0f18c5d7b 100644
--- a/tests/custom_lookups/tests.py
+++ b/tests/custom_lookups/tests.py
@@ -17,7 +17,7 @@ class Div3Lookup(models.Lookup):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
- return '%s %%%% 3 = %s' % (lhs, rhs), params
+ return '(%s) %%%% 3 = %s' % (lhs, rhs), params
def as_oracle(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
@@ -31,12 +31,32 @@ class Div3Transform(models.Transform):
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
- return '%s %%%% 3' % (lhs,), lhs_params
+ return '(%s) %%%% 3' % lhs, lhs_params
def as_oracle(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return 'mod(%s, 3)' % lhs, lhs_params
+class Div3BilateralTransform(Div3Transform):
+ bilateral = True
+
+
+class Mult3BilateralTransform(models.Transform):
+ bilateral = True
+ lookup_name = 'mult3'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = qn.compile(self.lhs)
+ return '3 * (%s)' % lhs, lhs_params
+
+class UpperBilateralTransform(models.Transform):
+ bilateral = True
+ lookup_name = 'upper'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = qn.compile(self.lhs)
+ return 'UPPER(%s)' % lhs, lhs_params
+
class YearTransform(models.Transform):
lookup_name = 'year'
@@ -225,10 +245,112 @@ class LookupTests(TestCase):
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__in=[2, 4]),
+ [a2], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__gte=3),
+ [], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__range=(1, 2)),
+ [a1, a2, a4], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3Transform)
+class BilateralTransformTests(TestCase):
+
+ def test_bilateral_upper(self):
+ models.CharField.register_lookup(UpperBilateralTransform)
+ try:
+ Author.objects.bulk_create([
+ Author(name='Doe'),
+ Author(name='doe'),
+ Author(name='Foo'),
+ ])
+ self.assertQuerysetEqual(
+ Author.objects.filter(name__upper='doe'),
+ ["<Author: Doe>", "<Author: doe>"], ordered=False)
+ finally:
+ models.CharField._unregister_lookup(UpperBilateralTransform)
+
+ def test_bilateral_inner_qs(self):
+ models.CharField.register_lookup(UpperBilateralTransform)
+ try:
+ with self.assertRaises(NotImplementedError):
+ Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
+ finally:
+ models.CharField._unregister_lookup(UpperBilateralTransform)
+
+ def test_div3_bilateral_extract(self):
+ models.IntegerField.register_lookup(Div3BilateralTransform)
+ try:
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3=2),
+ [a2], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__lte=3),
+ [a3], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__in=[0, 2]),
+ [a2, a3], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__in=[2, 4]),
+ [a1, a2, a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__gte=3),
+ [a1, a2, a3, a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__range=(1, 2)),
+ [a1, a2, a4], lambda x: x)
+ finally:
+ models.IntegerField._unregister_lookup(Div3BilateralTransform)
+
+ def test_bilateral_order(self):
+ models.IntegerField.register_lookup(Mult3BilateralTransform)
+ models.IntegerField.register_lookup(Div3BilateralTransform)
+ try:
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+ baseqs = Author.objects.order_by('name')
+
+ self.assertQuerysetEqual(
+ baseqs.filter(age__mult3__div3=42),
+ # mult3__div3 always leads to 0
+ [a1, a2, a3, a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__mult3=42),
+ [a3], lambda x: x)
+ finally:
+ models.IntegerField._unregister_lookup(Mult3BilateralTransform)
+ models.IntegerField._unregister_lookup(Div3BilateralTransform)
+
+ def test_bilateral_fexpr(self):
+ models.IntegerField.register_lookup(Mult3BilateralTransform)
+ try:
+ a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
+ a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
+ a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
+ a4 = Author.objects.create(name='a4', age=4)
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(age__mult3=models.F('age')),
+ [a1, a2, a3, a4], lambda x: x)
+ self.assertQuerysetEqual(
+ # Same as age >= average_rating
+ baseqs.filter(age__mult3__gte=models.F('average_rating')),
+ [a2, a3], lambda x: x)
+ finally:
+ models.IntegerField._unregister_lookup(Mult3BilateralTransform)
+
+
class YearLteTests(TestCase):
def setUp(self):
models.DateField.register_lookup(YearTransform)