summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/aggregation/tests.py2
-rw-r--r--tests/custom_lookups/__init__.py0
-rw-r--r--tests/custom_lookups/models.py13
-rw-r--r--tests/custom_lookups/tests.py279
-rw-r--r--tests/expressions/tests.py21
-rw-r--r--tests/null_queries/tests.py3
-rw-r--r--tests/queries/tests.py13
7 files changed, 323 insertions, 8 deletions
diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py
index eee61654bc..6ea10278f2 100644
--- a/tests/aggregation/tests.py
+++ b/tests/aggregation/tests.py
@@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase):
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
self.assertEqual(vals, {"friends__id__count": 2})
- books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk")
+ books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk")
self.assertQuerysetEqual(
books, [
"The Definitive Guide to Django: Web Development Done Right",
diff --git a/tests/custom_lookups/__init__.py b/tests/custom_lookups/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/custom_lookups/__init__.py
diff --git a/tests/custom_lookups/models.py b/tests/custom_lookups/models.py
new file mode 100644
index 0000000000..9841b36ce5
--- /dev/null
+++ b/tests/custom_lookups/models.py
@@ -0,0 +1,13 @@
+from django.db import models
+from django.utils.encoding import python_2_unicode_compatible
+
+
+@python_2_unicode_compatible
+class Author(models.Model):
+ name = models.CharField(max_length=20)
+ age = models.IntegerField(null=True)
+ birthdate = models.DateField(null=True)
+ average_rating = models.FloatField(null=True)
+
+ def __str__(self):
+ return self.name
diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py
new file mode 100644
index 0000000000..9f1e7fd44a
--- /dev/null
+++ b/tests/custom_lookups/tests.py
@@ -0,0 +1,279 @@
+from datetime import date
+import unittest
+
+from django.test import TestCase
+from .models import Author
+from django.db import models
+from django.db import connection
+
+
+class Div3Lookup(models.Lookup):
+ lookup_name = 'div3'
+
+ def as_sql(self, qn, connection):
+ lhs, params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ return '%s %%%% 3 = %s' % (lhs, rhs), params
+
+
+class Div3Transform(models.Transform):
+ lookup_name = 'div3'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = qn.compile(self.lhs)
+ return '%s %%%% 3' % (lhs,), lhs_params
+
+
+class YearTransform(models.Transform):
+ lookup_name = 'year'
+
+ def as_sql(self, qn, connection):
+ lhs_sql, params = qn.compile(self.lhs)
+ return connection.ops.date_extract_sql('year', lhs_sql), params
+
+ @property
+ def output_type(self):
+ return models.IntegerField()
+
+
+class YearExact(models.lookups.Lookup):
+ lookup_name = 'exact'
+
+ def as_sql(self, qn, connection):
+ # We will need to skip the extract part, and instead go
+ # directly with the originating field, that is self.lhs.lhs
+ lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ # Note that we must be careful so that we have params in the
+ # same order as we have the parts in the SQL.
+ params = lhs_params + rhs_params + lhs_params + rhs_params
+ # We use PostgreSQL specific SQL here. Note that we must do the
+ # conversions in SQL instead of in Python to support F() references.
+ return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
+ "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+YearTransform.register_lookup(YearExact)
+
+
+class YearLte(models.lookups.LessThanOrEqual):
+ """
+ The purpose of this lookup is to efficiently compare the year of the field.
+ """
+
+ def as_sql(self, qn, connection):
+ # Skip the YearTransform above us (no possibility for efficient
+ # lookup otherwise).
+ real_lhs = self.lhs.lhs
+ lhs_sql, params = self.process_lhs(qn, connection, real_lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ # Build SQL where the integer year is concatenated with last month
+ # and day, then convert that to date. (We try to have SQL like:
+ # WHERE somecol <= '2013-12-31')
+ # but also make it work if the rhs_sql is field reference.
+ return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
+YearTransform.register_lookup(YearLte)
+
+
+# We will register this class temporarily in the test method.
+
+
+class InMonth(models.lookups.Lookup):
+ """
+ InMonth matches if the column's month is the same as value's month.
+ """
+ lookup_name = 'inmonth'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ # We need to be careful so that we get the params in right
+ # places.
+ params = lhs_params + rhs_params + lhs_params + rhs_params
+ return ("%s >= date_trunc('month', %s) and "
+ "%s < date_trunc('month', %s) + interval '1 months'" %
+ (lhs, rhs, lhs, rhs), params)
+
+
+class LookupTests(TestCase):
+ def test_basic_lookup(self):
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+ models.IntegerField.register_lookup(Div3Lookup)
+ try:
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=0),
+ [a3], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=1).order_by('age'),
+ [a1, a4], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=2),
+ [a2], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=3),
+ [], lambda x: x
+ )
+ finally:
+ models.IntegerField._unregister_lookup(Div3Lookup)
+
+ @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ def test_birthdate_month(self):
+ a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
+ a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
+ a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
+ a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
+ models.DateField.register_lookup(InMonth)
+ try:
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
+ [a3], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
+ [a2], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
+ [a1], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
+ [a4], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
+ [], lambda x: x
+ )
+ finally:
+ models.DateField._unregister_lookup(InMonth)
+
+ def test_div3_extract(self):
+ models.IntegerField.register_lookup(Div3Transform)
+ try:
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3=2),
+ [a2], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__lte=3),
+ [a1, a2, a3, a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__in=[0, 2]),
+ [a2, a3], lambda x: x)
+ finally:
+ models.IntegerField._unregister_lookup(Div3Transform)
+
+
+class YearLteTests(TestCase):
+ def setUp(self):
+ models.DateField.register_lookup(YearTransform)
+ self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
+ self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
+ self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
+ self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
+
+ def tearDown(self):
+ models.DateField._unregister_lookup(YearTransform)
+
+ @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ def test_year_lte(self):
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lte=2012),
+ [self.a1, self.a2, self.a3, self.a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year=2012),
+ [self.a2, self.a3, self.a4], lambda x: x)
+
+ self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lte=2011),
+ [self.a1], lambda x: x)
+ # The non-optimized version works, too.
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lt=2012),
+ [self.a1], lambda x: x)
+
+ @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ def test_year_lte_fexpr(self):
+ self.a2.age = 2011
+ self.a2.save()
+ self.a3.age = 2012
+ self.a3.save()
+ self.a4.age = 2013
+ self.a4.save()
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lte=models.F('age')),
+ [self.a3, self.a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lt=models.F('age')),
+ [self.a4], lambda x: x)
+
+ def test_year_lte_sql(self):
+ # This test will just check the generated SQL for __lte. This
+ # doesn't require running on PostgreSQL and spots the most likely
+ # error - not running YearLte SQL at all.
+ baseqs = Author.objects.order_by('name')
+ self.assertIn(
+ '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
+ self.assertIn(
+ '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
+
+ def test_postgres_year_exact(self):
+ baseqs = Author.objects.order_by('name')
+ self.assertIn(
+ '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query))
+ self.assertIn(
+ '-12-31', str(baseqs.filter(birthdate__year=2011).query))
+
+ def test_custom_implementation_year_exact(self):
+ try:
+ # Two ways to add a customized implementation for different backends:
+ # First is MonkeyPatch of the class.
+ def as_custom_sql(self, qn, connection):
+ lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params = lhs_params + rhs_params + lhs_params + rhs_params
+ return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
+ "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+ setattr(YearExact, 'as_' + connection.vendor, as_custom_sql)
+ self.assertIn(
+ 'concat(',
+ str(Author.objects.filter(birthdate__year=2012).query))
+ finally:
+ delattr(YearExact, 'as_' + connection.vendor)
+ try:
+ # The other way is to subclass the original lookup and register the subclassed
+ # lookup instead of the original.
+ class CustomYearExact(YearExact):
+ # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres
+ # and so on, but as we don't know which DB we are running on, we need to use
+ # setattr.
+ def as_custom_sql(self, qn, connection):
+ lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params = lhs_params + rhs_params + lhs_params + rhs_params
+ return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
+ "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+ setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql)
+ YearTransform.register_lookup(CustomYearExact)
+ self.assertIn(
+ 'CONCAT(',
+ str(Author.objects.filter(birthdate__year=2012).query))
+ finally:
+ YearTransform._unregister_lookup(CustomYearExact)
+ YearTransform.register_lookup(YearExact)
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index 99f41024f8..320271b1dc 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -3,7 +3,7 @@ from __future__ import unicode_literals
from django.core.exceptions import FieldError
from django.db.models import F
from django.db import transaction
-from django.test import TestCase
+from django.test import TestCase, skipIfDBFeature
from django.utils import six
from .models import Company, Employee
@@ -224,6 +224,25 @@ class ExpressionsTests(TestCase):
acme.num_employees = F("num_employees") + 16
self.assertRaises(TypeError, acme.save)
+ def test_ticket_11722_iexact_lookup(self):
+ Employee.objects.create(firstname="John", lastname="Doe")
+ Employee.objects.create(firstname="Test", lastname="test")
+
+ queryset = Employee.objects.filter(firstname__iexact=F('lastname'))
+ self.assertQuerysetEqual(queryset, ["<Employee: Test test>"])
+
+ @skipIfDBFeature('has_case_insensitive_like')
+ def test_ticket_16731_startswith_lookup(self):
+ Employee.objects.create(firstname="John", lastname="Doe")
+ e2 = Employee.objects.create(firstname="Jack", lastname="Jackson")
+ e3 = Employee.objects.create(firstname="Jack", lastname="jackson")
+ self.assertQuerysetEqual(
+ Employee.objects.filter(lastname__startswith=F('firstname')),
+ [e2], lambda x: x)
+ self.assertQuerysetEqual(
+ Employee.objects.filter(lastname__istartswith=F('firstname')).order_by('pk'),
+ [e2, e3], lambda x: x)
+
def test_ticket_18375_join_reuse(self):
# Test that reverse multijoin F() references and the lookup target
# the same join. Pre #18375 the F() join was generated first, and the
diff --git a/tests/null_queries/tests.py b/tests/null_queries/tests.py
index f807ad88ce..1b73c977b4 100644
--- a/tests/null_queries/tests.py
+++ b/tests/null_queries/tests.py
@@ -45,9 +45,6 @@ class NullQueriesTests(TestCase):
# Can't use None on anything other than __exact and __iexact
self.assertRaises(ValueError, Choice.objects.filter, id__gt=None)
- # Can't use None on anything other than __exact and __iexact
- self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None)
-
# Related managers use __exact=None implicitly if the object hasn't been saved.
p2 = Poll(question="How?")
self.assertEqual(repr(p2.choice_set.all()), '[]')
diff --git a/tests/queries/tests.py b/tests/queries/tests.py
index 338ec06921..03cfc71afe 100644
--- a/tests/queries/tests.py
+++ b/tests/queries/tests.py
@@ -2632,8 +2632,15 @@ class WhereNodeTest(TestCase):
def as_sql(self, qn, connection):
return 'dummy', []
+ class MockCompiler(object):
+ def compile(self, node):
+ return node.as_sql(self, connection)
+
+ def __call__(self, name):
+ return connection.ops.quote_name(name)
+
def test_empty_full_handling_conjunction(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()])
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@@ -2658,7 +2665,7 @@ class WhereNodeTest(TestCase):
self.assertEqual(w.as_sql(qn, connection), ('', []))
def test_empty_full_handling_disjunction(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@@ -2685,7 +2692,7 @@ class WhereNodeTest(TestCase):
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
def test_empty_nodes(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(qn, connection), (None, []))