summaryrefslogtreecommitdiff
path: root/django
diff options
context:
space:
mode:
Diffstat (limited to 'django')
-rw-r--r--django/contrib/contenttypes/generic.py5
-rw-r--r--django/contrib/gis/db/backends/mysql/operations.py4
-rw-r--r--django/contrib/gis/db/backends/oracle/operations.py5
-rw-r--r--django/contrib/gis/db/backends/postgis/operations.py5
-rw-r--r--django/contrib/gis/db/backends/spatialite/operations.py5
-rw-r--r--django/contrib/gis/db/models/constants.py15
-rw-r--r--django/contrib/gis/db/models/fields.py6
-rw-r--r--django/contrib/gis/db/models/lookups.py28
-rw-r--r--django/contrib/gis/db/models/sql/query.py14
-rw-r--r--django/db/backends/__init__.py3
-rw-r--r--django/db/backends/postgresql_psycopg2/base.py6
-rw-r--r--django/db/backends/sqlite3/base.py5
-rw-r--r--django/db/models/__init__.py1
-rw-r--r--django/db/models/aggregates.py9
-rw-r--r--django/db/models/fields/__init__.py9
-rw-r--r--django/db/models/fields/related.py36
-rw-r--r--django/db/models/lookups.py317
-rw-r--r--django/db/models/sql/aggregates.py10
-rw-r--r--django/db/models/sql/compiler.py63
-rw-r--r--django/db/models/sql/datastructures.py23
-rw-r--r--django/db/models/sql/expressions.py6
-rw-r--r--django/db/models/sql/query.py179
-rw-r--r--django/db/models/sql/subqueries.py20
-rw-r--r--django/db/models/sql/where.py27
24 files changed, 625 insertions, 176 deletions
diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py
index a92f0f0dc8..7d130bec24 100644
--- a/django/contrib/contenttypes/generic.py
+++ b/django/contrib/contenttypes/generic.py
@@ -12,7 +12,7 @@ from django.db import models, router, transaction, DEFAULT_DB_ALIAS
from django.db.models import signals
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo
-from django.db.models.sql.where import Constraint
+from django.db.models.sql.datastructures import Col
from django.forms import ModelForm, ALL_FIELDS
from django.forms.models import (BaseModelFormSet, modelformset_factory,
modelform_defines_fields)
@@ -236,7 +236,8 @@ class GenericRelation(ForeignObject):
field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0]
contenttype_pk = self.get_content_type().pk
cond = where_class()
- cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND')
+ lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk)
+ cond.add(lookup, 'AND')
return cond
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py
index 81aa1e2fb6..989b5a1292 100644
--- a/django/contrib/gis/db/backends/mysql/operations.py
+++ b/django/contrib/gis/db/backends/mysql/operations.py
@@ -49,9 +49,7 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
return placeholder
def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
- alias, col, db_type = lvalue
-
- geo_col = '%s.%s' % (qn(alias), qn(col))
+ geo_col, db_type = lvalue
lookup_info = self.geometry_functions.get(lookup_type, False)
if lookup_info:
diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py
index 8212938f01..b3aa191ccc 100644
--- a/django/contrib/gis/db/backends/oracle/operations.py
+++ b/django/contrib/gis/db/backends/oracle/operations.py
@@ -231,10 +231,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
"Returns the SQL WHERE clause for use in Oracle spatial SQL construction."
- alias, col, db_type = lvalue
-
- # Getting the quoted table name as `geo_col`.
- geo_col = '%s.%s' % (qn(alias), qn(col))
+ geo_col, db_type = lvalue
# See if a Oracle Geometry function matches the lookup type next
lookup_info = self.geometry_functions.get(lookup_type, False)
diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py
index 40ca09f0fa..e8da6a52a6 100644
--- a/django/contrib/gis/db/backends/postgis/operations.py
+++ b/django/contrib/gis/db/backends/postgis/operations.py
@@ -478,10 +478,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
(alias, col, db_type), the lookup type string, lookup value, and
the geometry field.
"""
- alias, col, db_type = lvalue
-
- # Getting the quoted geometry column.
- geo_col = '%s.%s' % (qn(alias), qn(col))
+ geo_col, db_type = lvalue
if lookup_type in self.geometry_operators:
if field.geography and not lookup_type in self.geography_operators:
diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py
index 5149b541a2..ad44b470b3 100644
--- a/django/contrib/gis/db/backends/spatialite/operations.py
+++ b/django/contrib/gis/db/backends/spatialite/operations.py
@@ -324,10 +324,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
[a tuple of (alias, column, db_type)], lookup type, lookup
value, the model field, and the quoting function.
"""
- alias, col, db_type = lvalue
-
- # Getting the quoted field as `geo_col`.
- geo_col = '%s.%s' % (qn(alias), qn(col))
+ geo_col, db_type = lvalue
if lookup_type in self.geometry_functions:
# See if a SpatiaLite geometry function matches the lookup type.
diff --git a/django/contrib/gis/db/models/constants.py b/django/contrib/gis/db/models/constants.py
new file mode 100644
index 0000000000..4ece41415c
--- /dev/null
+++ b/django/contrib/gis/db/models/constants.py
@@ -0,0 +1,15 @@
+from django.db.models.sql.constants import QUERY_TERMS
+
+GIS_LOOKUPS = {
+ 'bbcontains', 'bboverlaps', 'contained', 'contains',
+ 'contains_properly', 'coveredby', 'covers', 'crosses', 'disjoint',
+ 'distance_gt', 'distance_gte', 'distance_lt', 'distance_lte',
+ 'dwithin', 'equals', 'exact',
+ 'intersects', 'overlaps', 'relate', 'same_as', 'touches', 'within',
+ 'left', 'right', 'overlaps_left', 'overlaps_right',
+ 'overlaps_above', 'overlaps_below',
+ 'strictly_above', 'strictly_below'
+}
+ALL_TERMS = GIS_LOOKUPS | QUERY_TERMS
+
+__all__ = ['ALL_TERMS', 'GIS_LOOKUPS']
diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py
index 734f805c8d..2daf08147b 100644
--- a/django/contrib/gis/db/models/fields.py
+++ b/django/contrib/gis/db/models/fields.py
@@ -2,6 +2,8 @@ from django.db.models.fields import Field
from django.db.models.sql.expressions import SQLEvaluator
from django.utils.translation import ugettext_lazy as _
from django.contrib.gis import forms
+from django.contrib.gis.db.models.constants import GIS_LOOKUPS
+from django.contrib.gis.db.models.lookups import GISLookup
from django.contrib.gis.db.models.proxy import GeometryProxy
from django.contrib.gis.geometry.backend import Geometry, GeometryException
from django.utils import six
@@ -284,6 +286,10 @@ class GeometryField(Field):
"""
return connection.ops.get_geom_placeholder(self, value)
+for lookup_name in GIS_LOOKUPS:
+ lookup = type(lookup_name, (GISLookup,), {'lookup_name': lookup_name})
+ GeometryField.register_lookup(lookup)
+
# The OpenGIS Geometry Type Fields
class PointField(GeometryField):
diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py
new file mode 100644
index 0000000000..cad6c69000
--- /dev/null
+++ b/django/contrib/gis/db/models/lookups.py
@@ -0,0 +1,28 @@
+from django.db.models.lookups import Lookup
+from django.db.models.sql.expressions import SQLEvaluator
+
+
+class GISLookup(Lookup):
+ def as_sql(self, qn, connection):
+ from django.contrib.gis.db.models.sql import GeoWhereNode
+ # We use the same approach as was used by GeoWhereNode. It would
+ # be a good idea to upgrade GIS to use similar code that is used
+ # for other lookups.
+ if isinstance(self.rhs, SQLEvaluator):
+ # Make sure the F Expression destination field exists, and
+ # set an `srid` attribute with the same as that of the
+ # destination.
+ geo_fld = GeoWhereNode._check_geo_field(self.rhs.opts, self.rhs.expression.name)
+ if not geo_fld:
+ raise ValueError('No geographic field found in expression.')
+ self.rhs.srid = geo_fld.srid
+ db_type = self.lhs.output_type.db_type(connection=connection)
+ params = self.lhs.output_type.get_db_prep_lookup(
+ self.lookup_name, self.rhs, connection=connection)
+ lhs_sql, lhs_params = self.process_lhs(qn, connection)
+ # lhs_params not currently supported.
+ assert not lhs_params
+ data = (lhs_sql, db_type)
+ spatial_sql, spatial_params = connection.ops.spatial_lookup_sql(
+ data, self.lookup_name, self.rhs, self.lhs.output_type, qn)
+ return spatial_sql, spatial_params + params
diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py
index 3923bf460e..f3fa1f6322 100644
--- a/django/contrib/gis/db/models/sql/query.py
+++ b/django/contrib/gis/db/models/sql/query.py
@@ -1,6 +1,7 @@
from django.db import connections
from django.db.models.query import sql
+from django.contrib.gis.db.models.constants import ALL_TERMS
from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.sql import aggregates as gis_aggregates
from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField
@@ -9,19 +10,6 @@ from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Area, Distance
-ALL_TERMS = set([
- 'bbcontains', 'bboverlaps', 'contained', 'contains',
- 'contains_properly', 'coveredby', 'covers', 'crosses', 'disjoint',
- 'distance_gt', 'distance_gte', 'distance_lt', 'distance_lte',
- 'dwithin', 'equals', 'exact',
- 'intersects', 'overlaps', 'relate', 'same_as', 'touches', 'within',
- 'left', 'right', 'overlaps_left', 'overlaps_right',
- 'overlaps_above', 'overlaps_below',
- 'strictly_above', 'strictly_below'
-])
-ALL_TERMS.update(sql.constants.QUERY_TERMS)
-
-
class GeoQuery(sql.Query):
"""
A single spatial SQL query.
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index 2dbb8b3aae..99d68221e8 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -673,6 +673,9 @@ class BaseDatabaseFeatures(object):
# What kind of error does the backend throw when accessing closed cursor?
closed_cursor_error_class = ProgrammingError
+ # Does 'a' LIKE 'A' match?
+ has_case_insensitive_like = True
+
def __init__(self, connection):
self.connection = connection
diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
index 7725b0c7a0..33f885d50c 100644
--- a/django/db/backends/postgresql_psycopg2/base.py
+++ b/django/db/backends/postgresql_psycopg2/base.py
@@ -62,6 +62,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_combined_alters = True
nulls_order_largest = True
closed_cursor_error_class = InterfaceError
+ has_case_insensitive_like = False
class DatabaseWrapper(BaseDatabaseWrapper):
@@ -83,6 +84,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': 'LIKE UPPER(%s)',
}
+ pattern_ops = {
+ 'startswith': "LIKE %s || '%%%%'",
+ 'istartswith': "LIKE UPPER(%s) || '%%%%'",
+ }
+
Database = Database
def __init__(self, *args, **kwargs):
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 6c9728889f..e55973ea39 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -334,6 +334,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': "LIKE %s ESCAPE '\\'",
}
+ pattern_ops = {
+ 'startswith': "LIKE %s || '%%%%'",
+ 'istartswith': "LIKE UPPER(%s) || '%%%%'",
+ }
+
Database = Database
def __init__(self, *args, **kwargs):
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
index 454a693be7..bdbdd5fd91 100644
--- a/django/db/models/__init__.py
+++ b/django/db/models/__init__.py
@@ -17,6 +17,7 @@ from django.db.models.fields.related import ( # NOQA
from django.db.models.fields.proxy import OrderWrt # NOQA
from django.db.models.deletion import ( # NOQA
CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError)
+from django.db.models.lookups import Lookup, Transform # NOQA
from django.db.models import signals # NOQA
diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
index 1ec11b4acb..e31d228aa5 100644
--- a/django/db/models/aggregates.py
+++ b/django/db/models/aggregates.py
@@ -15,10 +15,11 @@ def refs_aggregate(lookup_parts, aggregates):
default annotation names we must check each prefix of the lookup_parts
for match.
"""
- for i in range(len(lookup_parts) + 1):
- if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
- return True
- return False
+ for n in range(len(lookup_parts) + 1):
+ level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
+ if level_n_lookup in aggregates:
+ return aggregates[level_n_lookup], lookup_parts[n:]
+ return False, ()
class Aggregate(object):
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 7172cf1a55..7ace8878aa 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -11,6 +11,7 @@ from itertools import tee
from django.apps import apps
from django.db import connection
+from django.db.models.lookups import default_lookups, RegisterLookupMixin
from django.db.models.query_utils import QueryWrapper
from django.conf import settings
from django import forms
@@ -80,7 +81,7 @@ def _empty(of_cls):
@total_ordering
-class Field(object):
+class Field(RegisterLookupMixin):
"""Base class for all field types"""
# Designates whether empty strings fundamentally are allowed at the
@@ -101,6 +102,7 @@ class Field(object):
'unique': _('%(model_name)s with this %(field_label)s '
'already exists.'),
}
+ class_lookups = default_lookups.copy()
# Generic field type description, usually overridden by subclasses
def _description(self):
@@ -514,8 +516,7 @@ class Field(object):
except ValueError:
raise ValueError("The __year lookup type requires an integer "
"argument")
-
- raise TypeError("Field has invalid lookup: %s" % lookup_type)
+ return self.get_prep_value(value)
def get_db_prep_lookup(self, lookup_type, value, connection,
prepared=False):
@@ -564,6 +565,8 @@ class Field(object):
return connection.ops.year_lookup_bounds_for_date_field(value)
else:
return [value] # this isn't supposed to happen
+ else:
+ return [value]
def has_default(self):
"""
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 82e7725dac..69fb3f8492 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -5,9 +5,11 @@ from django.db.backends import utils
from django.db.models import signals, Q
from django.db.models.fields import (AutoField, Field, IntegerField,
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
+from django.db.models.lookups import IsNull
from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet
from django.db.models.deletion import CASCADE
+from django.db.models.sql.datastructures import Col
from django.utils.encoding import smart_text
from django.utils import six
from django.utils.deprecation import RenameMethodsBase
@@ -987,6 +989,11 @@ class ForeignObjectRel(object):
# example custom multicolumn joins currently have no remote field).
self.field_name = None
+ def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
+ raw_value):
+ return self.field.get_lookup_constraint(constraint_class, alias, targets, sources,
+ lookup_type, raw_value)
+
class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
@@ -1193,14 +1200,16 @@ class ForeignObject(RelatedField):
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos
- def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
+ def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups,
raw_value):
- from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
+ from django.db.models.sql.where import SubqueryConstraint, AND, OR
root_constraint = constraint_class()
assert len(targets) == len(sources)
+ if len(lookups) > 1:
+ raise exceptions.FieldError('Relation fields do not support nested lookups')
+ lookup_type = lookups[0]
def get_normalized_value(value):
-
from django.db.models import Model
if isinstance(value, Model):
value_list = []
@@ -1221,28 +1230,27 @@ class ForeignObject(RelatedField):
[source.name for source in sources], raw_value),
AND)
elif lookup_type == 'isnull':
- root_constraint.add(
- (Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND)
+ root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND)
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
and not is_multicolumn)):
value = get_normalized_value(raw_value)
- for index, source in enumerate(sources):
+ for target, source, val in zip(targets, sources, value):
+ lookup_class = target.get_lookup(lookup_type)
root_constraint.add(
- (Constraint(alias, targets[index].column, sources[index]), lookup_type,
- value[index]), AND)
+ lookup_class(Col(alias, target, source), val), AND)
elif lookup_type in ['range', 'in'] and not is_multicolumn:
values = [get_normalized_value(value) for value in raw_value]
value = [val[0] for val in values]
- root_constraint.add(
- (Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND)
+ lookup_class = targets[0].get_lookup(lookup_type)
+ root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND)
elif lookup_type == 'in':
values = [get_normalized_value(value) for value in raw_value]
for value in values:
value_constraint = constraint_class()
- for index, target in enumerate(targets):
- value_constraint.add(
- (Constraint(alias, target.column, sources[index]), 'exact', value[index]),
- AND)
+ for source, target, val in zip(sources, targets, value):
+ lookup_class = target.get_lookup('exact')
+ lookup = lookup_class(Col(alias, target, source), val)
+ value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
new file mode 100644
index 0000000000..5369994bbc
--- /dev/null
+++ b/django/db/models/lookups.py
@@ -0,0 +1,317 @@
+from copy import copy
+import inspect
+
+from django.conf import settings
+from django.utils import timezone
+from django.utils.functional import cached_property
+
+
+class RegisterLookupMixin(object):
+ def get_lookup(self, lookup_name):
+ try:
+ return self.class_lookups[lookup_name]
+ except KeyError:
+ # To allow for inheritance, check parent class class lookups.
+ for parent in inspect.getmro(self.__class__):
+ if not 'class_lookups' in parent.__dict__:
+ continue
+ if lookup_name in parent.class_lookups:
+ return parent.class_lookups[lookup_name]
+ except AttributeError:
+ # This class didn't have any class_lookups
+ pass
+ if hasattr(self, 'output_type'):
+ return self.output_type.get_lookup(lookup_name)
+ return None
+
+ @classmethod
+ def register_lookup(cls, lookup):
+ if not 'class_lookups' in cls.__dict__:
+ cls.class_lookups = {}
+ cls.class_lookups[lookup.lookup_name] = lookup
+
+ @classmethod
+ def _unregister_lookup(cls, lookup):
+ """
+ Removes given lookup from cls lookups. Meant to be used in
+ tests only.
+ """
+ del cls.class_lookups[lookup.lookup_name]
+
+
+class Transform(RegisterLookupMixin):
+ def __init__(self, lhs, lookups):
+ self.lhs = lhs
+ self.init_lookups = lookups[:]
+
+ def as_sql(self, qn, connection):
+ raise NotImplementedError
+
+ @cached_property
+ def output_type(self):
+ return self.lhs.output_type
+
+ def relabeled_clone(self, relabels):
+ return self.__class__(self.lhs.relabeled_clone(relabels))
+
+ def get_group_by_cols(self):
+ return self.lhs.get_group_by_cols()
+
+
+class Lookup(RegisterLookupMixin):
+ lookup_name = None
+
+ def __init__(self, lhs, rhs):
+ self.lhs, self.rhs = lhs, rhs
+ self.rhs = self.get_prep_lookup()
+
+ def get_prep_lookup(self):
+ return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
+
+ def get_db_prep_lookup(self, value, connection):
+ return (
+ '%s', self.lhs.output_type.get_db_prep_lookup(
+ self.lookup_name, value, connection, prepared=True))
+
+ def process_lhs(self, qn, connection, lhs=None):
+ lhs = lhs or self.lhs
+ return qn.compile(lhs)
+
+ def process_rhs(self, qn, connection, rhs=None):
+ value = rhs or self.rhs
+ # Due to historical reasons there are a couple of different
+ # ways to produce sql here. get_compiler is likely a Query
+ # instance, _as_sql QuerySet and as_sql just something with
+ # as_sql. Finally the value can of course be just plain
+ # Python value.
+ if hasattr(value, 'get_compiler'):
+ value = value.get_compiler(connection=connection)
+ if hasattr(value, 'as_sql'):
+ sql, params = qn.compile(value)
+ return '(' + sql + ')', params
+ if hasattr(value, '_as_sql'):
+ sql, params = value._as_sql(connection=connection)
+ return '(' + sql + ')', params
+ else:
+ return self.get_db_prep_lookup(value, connection)
+
+ def relabeled_clone(self, relabels):
+ new = copy(self)
+ new.lhs = new.lhs.relabeled_clone(relabels)
+ if hasattr(new.rhs, 'relabeled_clone'):
+ new.rhs = new.rhs.relabeled_clone(relabels)
+ return new
+
+ def get_group_by_cols(self):
+ cols = self.lhs.get_group_by_cols()
+ if hasattr(self.rhs, 'get_group_by_cols'):
+ cols.extend(self.rhs.get_group_by_cols())
+ return cols
+
+ def as_sql(self, qn, connection):
+ raise NotImplementedError
+
+
+class BuiltinLookup(Lookup):
+ def as_sql(self, qn, connection):
+ lhs_sql, params = self.process_lhs(qn, connection)
+ field_internal_type = self.lhs.output_type.get_internal_type()
+ db_type = self.lhs.output_type
+ lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
+ lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
+ return '%s %s' % (lhs_sql, operator_plus_rhs), params
+
+ def get_rhs_op(self, connection, rhs):
+ return connection.operators[self.lookup_name] % rhs
+
+
+default_lookups = {}
+
+
+class Exact(BuiltinLookup):
+ lookup_name = 'exact'
+default_lookups['exact'] = Exact
+
+
+class IExact(BuiltinLookup):
+ lookup_name = 'iexact'
+default_lookups['iexact'] = IExact
+
+
+class Contains(BuiltinLookup):
+ lookup_name = 'contains'
+default_lookups['contains'] = Contains
+
+
+class IContains(BuiltinLookup):
+ lookup_name = 'icontains'
+default_lookups['icontains'] = IContains
+
+
+class GreaterThan(BuiltinLookup):
+ lookup_name = 'gt'
+default_lookups['gt'] = GreaterThan
+
+
+class GreaterThanOrEqual(BuiltinLookup):
+ lookup_name = 'gte'
+default_lookups['gte'] = GreaterThanOrEqual
+
+
+class LessThan(BuiltinLookup):
+ lookup_name = 'lt'
+default_lookups['lt'] = LessThan
+
+
+class LessThanOrEqual(BuiltinLookup):
+ lookup_name = 'lte'
+default_lookups['lte'] = LessThanOrEqual
+
+
+class In(BuiltinLookup):
+ lookup_name = 'in'
+
+ def get_db_prep_lookup(self, value, connection):
+ params = self.lhs.output_type.get_db_prep_lookup(
+ self.lookup_name, value, connection, prepared=True)
+ if not params:
+ # TODO: check why this leads to circular import
+ from django.db.models.sql.datastructures import EmptyResultSet
+ raise EmptyResultSet
+ placeholder = '(' + ', '.join('%s' for p in params) + ')'
+ return (placeholder, params)
+
+ def get_rhs_op(self, connection, rhs):
+ return 'IN %s' % rhs
+default_lookups['in'] = In
+
+
+class PatternLookup(BuiltinLookup):
+ def get_rhs_op(self, connection, rhs):
+ # Assume we are in startswith. We need to produce SQL like:
+ # col LIKE %s, ['thevalue%']
+ # For python values we can (and should) do that directly in Python,
+ # but if the value is for example reference to other column, then
+ # we need to add the % pattern match to the lookup by something like
+ # col LIKE othercol || '%%'
+ # So, for Python values we don't need any special pattern, but for
+ # SQL reference values we need the correct pattern added.
+ value = self.rhs
+ if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
+ or hasattr(value, '_as_sql')):
+ return connection.pattern_ops[self.lookup_name] % rhs
+ else:
+ return super(PatternLookup, self).get_rhs_op(connection, rhs)
+
+
+class StartsWith(PatternLookup):
+ lookup_name = 'startswith'
+default_lookups['startswith'] = StartsWith
+
+
+class IStartsWith(PatternLookup):
+ lookup_name = 'istartswith'
+default_lookups['istartswith'] = IStartsWith
+
+
+class EndsWith(BuiltinLookup):
+ lookup_name = 'endswith'
+default_lookups['endswith'] = EndsWith
+
+
+class IEndsWith(BuiltinLookup):
+ lookup_name = 'iendswith'
+default_lookups['iendswith'] = IEndsWith
+
+
+class Between(BuiltinLookup):
+ def get_rhs_op(self, connection, rhs):
+ return "BETWEEN %s AND %s" % (rhs, rhs)
+
+
+class Year(Between):
+ lookup_name = 'year'
+default_lookups['year'] = Year
+
+
+class Range(Between):
+ lookup_name = 'range'
+default_lookups['range'] = Range
+
+
+class DateLookup(BuiltinLookup):
+
+ def process_lhs(self, qn, connection):
+ lhs, params = super(DateLookup, self).process_lhs(qn, connection)
+ tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
+ sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
+ return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
+
+ def get_rhs_op(self, connection, rhs):
+ return '= %s' % rhs
+
+
+class Month(DateLookup):
+ lookup_name = 'month'
+ extract_type = 'month'
+default_lookups['month'] = Month
+
+
+class Day(DateLookup):
+ lookup_name = 'day'
+ extract_type = 'day'
+default_lookups['day'] = Day
+
+
+class WeekDay(DateLookup):
+ lookup_name = 'week_day'
+ extract_type = 'week_day'
+default_lookups['week_day'] = WeekDay
+
+
+class Hour(DateLookup):
+ lookup_name = 'hour'
+ extract_type = 'hour'
+default_lookups['hour'] = Hour
+
+
+class Minute(DateLookup):
+ lookup_name = 'minute'
+ extract_type = 'minute'
+default_lookups['minute'] = Minute
+
+
+class Second(DateLookup):
+ lookup_name = 'second'
+ extract_type = 'second'
+default_lookups['second'] = Second
+
+
+class IsNull(BuiltinLookup):
+ lookup_name = 'isnull'
+
+ def as_sql(self, qn, connection):
+ sql, params = qn.compile(self.lhs)
+ if self.rhs:
+ return "%s IS NULL" % sql, params
+ else:
+ return "%s IS NOT NULL" % sql, params
+default_lookups['isnull'] = IsNull
+
+
+class Search(BuiltinLookup):
+ lookup_name = 'search'
+default_lookups['search'] = Search
+
+
+class Regex(BuiltinLookup):
+ lookup_name = 'regex'
+default_lookups['regex'] = Regex
+
+
+class IRegex(BuiltinLookup):
+ lookup_name = 'iregex'
+default_lookups['iregex'] = IRegex
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
index 8542a330c6..aef8b493bb 100644
--- a/django/db/models/sql/aggregates.py
+++ b/django/db/models/sql/aggregates.py
@@ -4,6 +4,7 @@ Classes to represent the default SQL aggregate functions
import copy
from django.db.models.fields import IntegerField, FloatField
+from django.db.models.lookups import RegisterLookupMixin
__all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
@@ -14,7 +15,7 @@ ordinal_aggregate_field = IntegerField()
computed_aggregate_field = FloatField()
-class Aggregate(object):
+class Aggregate(RegisterLookupMixin):
"""
Default SQL Aggregate.
"""
@@ -93,6 +94,13 @@ class Aggregate(object):
return self.sql_template % substitutions, params
+ def get_group_by_cols(self):
+ return []
+
+ @property
+ def output_type(self):
+ return self.field
+
class Avg(Aggregate):
is_computed = True
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 41bba93206..123427cf8b 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -45,7 +45,7 @@ class SQLCompiler(object):
if self.query.select_related and not self.query.related_select_cols:
self.fill_related_selections()
- def quote_name_unless_alias(self, name):
+ def __call__(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
@@ -61,6 +61,22 @@ class SQLCompiler(object):
self.quote_cache[name] = r
return r
+ def quote_name_unless_alias(self, name):
+ """
+ A wrapper around connection.ops.quote_name that doesn't quote aliases
+ for table names. This avoids problems with some SQL dialects that treat
+ quoted strings specially (e.g. PostgreSQL).
+ """
+ return self(name)
+
+ def compile(self, node):
+ vendor_impl = getattr(
+ node, 'as_' + self.connection.vendor, None)
+ if vendor_impl:
+ return vendor_impl(self, self.connection)
+ else:
+ return node.as_sql(self, self.connection)
+
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@@ -88,11 +104,9 @@ class SQLCompiler(object):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
- qn = self.quote_name_unless_alias
-
- where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
- having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
- having_group_by = self.query.having.get_cols()
+ where, w_params = self.compile(self.query.where)
+ having, h_params = self.compile(self.query.having)
+ having_group_by = self.query.having.get_group_by_cols()
params = []
for val in six.itervalues(self.query.extra_select):
params.extend(val[1])
@@ -180,7 +194,7 @@ class SQLCompiler(object):
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
@@ -213,7 +227,7 @@ class SQLCompiler(object):
aliases.add(r)
col_aliases.add(col[1])
else:
- col_sql, col_params = col.as_sql(qn, self.connection)
+ col_sql, col_params = self.compile(col)
result.append(col_sql)
params.extend(col_params)
@@ -229,7 +243,7 @@ class SQLCompiler(object):
max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
if alias is None:
result.append(agg_sql)
else:
@@ -267,7 +281,7 @@ class SQLCompiler(object):
result = []
if opts is None:
opts = self.query.get_meta()
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
@@ -319,7 +333,7 @@ class SQLCompiler(object):
Note that this method can alter the tables in the query, and thus it
must be called before get_from_clause().
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = []
opts = self.query.get_meta()
@@ -352,7 +366,7 @@ class SQLCompiler(object):
ordering = (self.query.order_by
or self.query.get_meta().ordering
or [])
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
distinct = self.query.distinct
select_aliases = self._select_aliases
@@ -490,7 +504,7 @@ class SQLCompiler(object):
ordering and distinct must be done first.
"""
result = []
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
first = True
from_params = []
@@ -508,8 +522,7 @@ class SQLCompiler(object):
extra_cond = join_field.get_extra_restriction(
self.query.where_class, alias, lhs)
if extra_cond:
- extra_sql, extra_params = extra_cond.as_sql(
- qn, self.connection)
+ extra_sql, extra_params = self.compile(extra_cond)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
@@ -541,7 +554,7 @@ class SQLCompiler(object):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
- qn = self.quote_name_unless_alias
+ qn = self
result, params = [], []
if self.query.group_by is not None:
select_cols = self.query.select + self.query.related_select_cols
@@ -560,7 +573,7 @@ class SQLCompiler(object):
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
- sql, col_params = col.as_sql(qn, self.connection)
+ self.compile(col)
else:
sql = '(%s)' % str(col)
if sql not in seen:
@@ -784,7 +797,7 @@ class SQLCompiler(object):
return result
def as_subquery_condition(self, alias, columns, qn):
- inner_qn = self.quote_name_unless_alias
+ inner_qn = self
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
@@ -895,9 +908,9 @@ class SQLDeleteCompiler(SQLCompiler):
"""
assert len(self.query.tables) == 1, \
"Can only delete from one table at a time."
- qn = self.quote_name_unless_alias
+ qn = self
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
@@ -913,7 +926,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not self.query.values:
return '', ()
table = self.query.tables[0]
- qn = self.quote_name_unless_alias
+ qn = self
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
@@ -933,7 +946,7 @@ class SQLUpdateCompiler(SQLCompiler):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
- sql, params = val.as_sql(qn, self.connection)
+ sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
@@ -944,7 +957,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not values:
return '', ()
result.append(', '.join(values))
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
@@ -1024,11 +1037,11 @@ class SQLAggregateCompiler(SQLCompiler):
parameters.
"""
if qn is None:
- qn = self.quote_name_unless_alias
+ qn = self
sql, params = [], []
for aggregate in self.query.aggregate_select.values():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index f45ecaf76d..421c3cd860 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -5,18 +5,27 @@ the SQL domain.
class Col(object):
- def __init__(self, alias, col):
- self.alias = alias
- self.col = col
+ def __init__(self, alias, target, source):
+ self.alias, self.target, self.source = alias, target, source
def as_sql(self, qn, connection):
- return '%s.%s' % (qn(self.alias), self.col), []
+ return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
- def prepare(self):
- return self
+ @property
+ def output_type(self):
+ return self.source
def relabeled_clone(self, relabels):
- return self.__class__(relabels.get(self.alias, self.alias), self.col)
+ return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
+
+ def get_group_by_cols(self):
+ return [(self.alias, self.target.column)]
+
+ def get_lookup(self, name):
+ return self.output_type.get_lookup(name)
+
+ def prepare(self):
+ return self
class EmptyResultSet(Exception):
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
index 9f29e2ace5..e31eaa8a2f 100644
--- a/django/db/models/sql/expressions.py
+++ b/django/db/models/sql/expressions.py
@@ -24,11 +24,11 @@ class SQLEvaluator(object):
(change_map.get(col[0], col[0]), col[1])))
return clone
- def get_cols(self):
+ def get_group_by_cols(self):
cols = []
for node, col in self.cols:
- if hasattr(node, 'get_cols'):
- cols.extend(node.get_cols())
+ if hasattr(node, 'get_group_by_cols'):
+ cols.extend(node.get_group_by_cols())
elif isinstance(col, tuple):
cols.append(col)
return cols
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index c3c8e55793..db4e6744bf 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -19,6 +19,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.aggregates import refs_aggregate
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
+from django.db.models.lookups import Transform
from django.db.models.query_utils import Q
from django.db.models.related import PathInfo
from django.db.models.sql import aggregates as base_aggregates_module
@@ -1028,13 +1029,16 @@ class Query(object):
# Add the aggregate to the query
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
- def prepare_lookup_value(self, value, lookup_type, can_reuse):
+ def prepare_lookup_value(self, value, lookups, can_reuse):
+ # Default lookup if none given is exact.
+ if len(lookups) == 0:
+ lookups = ['exact']
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
if value is None:
- if lookup_type not in ('exact', 'iexact'):
+ if lookups[-1] not in ('exact', 'iexact'):
raise ValueError("Cannot use None as a query value")
- lookup_type = 'isnull'
+ lookups[-1] = 'isnull'
value = True
elif callable(value):
warnings.warn(
@@ -1055,40 +1059,54 @@ class Query(object):
# stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
# can do here. Similar thing is done in is_nullable(), too.
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
- lookup_type == 'exact' and value == ''):
+ lookups[-1] == 'exact' and value == ''):
value = True
- lookup_type = 'isnull'
- return value, lookup_type
+ lookups[-1] = ['isnull']
+ return value, lookups
def solve_lookup_type(self, lookup):
"""
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
"""
- lookup_type = 'exact' # Default lookup type
- lookup_parts = lookup.split(LOOKUP_SEP)
- num_parts = len(lookup_parts)
- if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
- and (not self._aggregates or lookup not in self._aggregates)):
- # Traverse the lookup query to distinguish related fields from
- # lookup types.
- lookup_model = self.model
- for counter, field_name in enumerate(lookup_parts):
- try:
- lookup_field = lookup_model._meta.get_field(field_name)
- except FieldDoesNotExist:
- # Not a field. Bail out.
- lookup_type = lookup_parts.pop()
- break
- # Unless we're at the end of the list of lookups, let's attempt
- # to continue traversing relations.
- if (counter + 1) < num_parts:
- try:
- lookup_model = lookup_field.rel.to
- except AttributeError:
- # Not a related field. Bail out.
- lookup_type = lookup_parts.pop()
- break
- return lookup_type, lookup_parts
+ lookup_splitted = lookup.split(LOOKUP_SEP)
+ if self._aggregates:
+ aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
+ if aggregate:
+ return aggregate_lookups, (), aggregate
+ _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
+ field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
+ if len(lookup_parts) == 0:
+ lookup_parts = ['exact']
+ elif len(lookup_parts) > 1:
+ if not field_parts:
+ raise FieldError(
+ 'Invalid lookup "%s" for model %s".' %
+ (lookup, self.get_meta().model.__name__))
+ return lookup_parts, field_parts, False
+
+ def build_lookup(self, lookups, lhs, rhs):
+ lookups = lookups[:]
+ while lookups:
+ lookup = lookups[0]
+ next = lhs.get_lookup(lookup)
+ if next:
+ if len(lookups) == 1:
+ # This was the last lookup, so return value lookup.
+ if issubclass(next, Transform):
+ lookups.append('exact')
+ lhs = next(lhs, lookups)
+ else:
+ return next(lhs, rhs)
+ else:
+ lhs = next(lhs, lookups)
+ # A field's get_lookup() can return None to opt for backwards
+ # compatibility path.
+ elif len(lookups) > 2:
+ raise FieldError(
+ "Unsupported lookup for field '%s'" % lhs.output_type.name)
+ else:
+ return None
+ lookups = lookups[1:]
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, connector=AND):
@@ -1118,21 +1136,24 @@ class Query(object):
is responsible for unreffing the joins used.
"""
arg, value = filter_expr
- lookup_type, parts = self.solve_lookup_type(arg)
- if not parts:
+ if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
+ lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
- value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse)
+ value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
used_joins = getattr(value, '_used_joins', [])
clause = self.where_class()
- if self._aggregates:
- for alias, aggregate in self.aggregates.items():
- if alias in (parts[0], LOOKUP_SEP.join(parts)):
- clause.add((aggregate, lookup_type, value), AND)
- return clause, []
+ if reffed_aggregate:
+ condition = self.build_lookup(lookups, reffed_aggregate, value)
+ if not condition:
+ # Backwards compat for custom lookups
+ assert len(lookups) == 1
+ condition = (reffed_aggregate, lookups[0], value)
+ clause.add(condition, AND)
+ return clause, []
opts = self.get_meta()
alias = self.get_initial_alias()
@@ -1154,11 +1175,31 @@ class Query(object):
targets, alias, join_list = self.trim_joins(sources, join_list, path)
if hasattr(field, 'get_lookup_constraint'):
- constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
- lookup_type, value)
+ # For now foreign keys get special treatment. This should be
+ # refactored when composite fields lands.
+ condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
+ lookups, value)
+ lookup_type = lookups[-1]
else:
- constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
- clause.add(constraint, AND)
+ assert(len(targets) == 1)
+ col = Col(alias, targets[0], field)
+ condition = self.build_lookup(lookups, col, value)
+ if not condition:
+ # Backwards compat for custom lookups
+ if lookups[0] not in self.query_terms:
+ raise FieldError(
+ "Join on field '%s' not permitted. Did you "
+ "misspell '%s' for the lookup type?" %
+ (col.output_type.name, lookups[0]))
+ if len(lookups) > 1:
+ raise FieldError("Nested lookup '%s' not supported." %
+ LOOKUP_SEP.join(lookups))
+ condition = (Constraint(alias, targets[0].column, field), lookups[0], value)
+ lookup_type = lookups[-1]
+ else:
+ lookup_type = condition.lookup_name
+
+ clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated
if current_negated and (lookup_type != 'isnull' or value is False):
@@ -1175,7 +1216,8 @@ class Query(object):
# (col IS NULL OR col != someval)
# <=>
# NOT (col IS NOT NULL AND col = someval).
- clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND)
+ lookup_class = targets[0].get_lookup('isnull')
+ clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND)
return clause, used_joins if not require_outer else ()
def add_filter(self, filter_clause):
@@ -1189,7 +1231,7 @@ class Query(object):
if not self._aggregates:
return False
if not isinstance(obj, Node):
- return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)
+ return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
or (hasattr(obj[1], 'contains_aggregate')
and obj[1].contains_aggregate(self.aggregates)))
return any(self.need_having(c) for c in obj.children)
@@ -1277,7 +1319,7 @@ class Query(object):
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
- def names_to_path(self, names, opts, allow_many):
+ def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
"""
Walks the names path and turns them PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for
@@ -1297,9 +1339,10 @@ class Query(object):
try:
field, model, direct, m2m = opts.get_field_by_name(name)
except FieldDoesNotExist:
- available = opts.get_all_field_names() + list(self.aggregate_select)
- raise FieldError("Cannot resolve keyword %r into field. "
- "Choices are: %s" % (name, ", ".join(available)))
+ # We didn't found the current field, so move position back
+ # one step.
+ pos -= 1
+ break
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
# children)
@@ -1334,15 +1377,14 @@ class Query(object):
final_field = field
targets = (field,)
break
+ if pos == -1 or (fail_on_missing and pos + 1 != len(names)):
+ self.raise_field_error(opts, name)
+ return path, final_field, targets, names[pos + 1:]
- if pos != len(names) - 1:
- if pos == len(names) - 2:
- raise FieldError(
- "Join on field %r not permitted. Did you misspell %r for "
- "the lookup type?" % (name, names[pos + 1]))
- else:
- raise FieldError("Join on field %r not permitted." % name)
- return path, final_field, targets
+ def raise_field_error(self, opts, name):
+ available = opts.get_all_field_names() + list(self.aggregate_select)
+ raise FieldError("Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(available)))
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
@@ -1371,8 +1413,9 @@ class Query(object):
"""
joins = [alias]
# First, generate the path for the names
- path, final_field, targets = self.names_to_path(
- names, opts, allow_many)
+ path, final_field, targets, rest = self.names_to_path(
+ names, opts, allow_many, fail_on_missing=True)
+
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1387,8 +1430,6 @@ class Query(object):
alias = self.join(
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
joins.append(alias)
- if hasattr(final_field, 'field'):
- final_field = final_field.field
return final_field, targets, opts, joins, path
def trim_joins(self, targets, joins, path):
@@ -1451,17 +1492,19 @@ class Query(object):
# nothing
alias, col = query.select[0].col
if self.is_nullable(query.select[0].field):
- query.where.add((Constraint(alias, col, query.select[0].field), 'isnull', False), AND)
+ lookup_class = query.select[0].field.get_lookup('isnull')
+ lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False)
+ query.where.add(lookup, AND)
if alias in can_reuse:
- pk = query.select[0].field.model._meta.pk
+ select_field = query.select[0].field
+ pk = select_field.model._meta.pk
# Need to add a restriction so that outer query's filters are in effect for
# the subquery, too.
query.bump_prefix(self)
- query.where.add(
- (Constraint(query.select[0].col[0], pk.column, pk),
- 'exact', Col(alias, pk.column)),
- AND
- )
+ lookup_class = select_field.get_lookup('exact')
+ lookup = lookup_class(Col(query.select[0].col[0], pk, pk),
+ Col(alias, pk, pk))
+ query.where.add(lookup, AND)
condition, needed_inner = self.build_filter(
('%s__in' % trimmed_prefix, query),
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index e9e292e787..86b1efd3f8 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -5,12 +5,12 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import connections
+from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
-from django.db.models.sql.where import AND, Constraint
from django.utils import six
from django.utils import timezone
@@ -42,10 +42,10 @@ class DeleteQuery(Query):
if not field:
field = self.get_meta().pk
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- where = self.where_class()
- where.add((Constraint(None, field.column, field), 'in',
- pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND)
- self.do_query(self.get_meta().db_table, where, using=using)
+ self.where = self.where_class()
+ self.add_q(Q(
+ **{field.attname + '__in': pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]}))
+ self.do_query(self.get_meta().db_table, self.where, using=using)
def delete_qs(self, query, using):
"""
@@ -80,9 +80,8 @@ class DeleteQuery(Query):
SelectInfo((self.get_initial_alias(), pk.column), None)
]
values = innerq
- where = self.where_class()
- where.add((Constraint(None, pk.column, pk), 'in', values), AND)
- self.where = where
+ self.where = self.where_class()
+ self.add_q(Q(pk__in=values))
self.get_compiler(using).execute_sql(None)
@@ -113,13 +112,10 @@ class UpdateQuery(Query):
related_updates=self.related_updates.copy(), **kwargs)
def update_batch(self, pk_list, values, using):
- pk_field = self.get_meta().pk
self.add_update_values(values)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class()
- self.where.add((Constraint(None, pk_field.column, pk_field), 'in',
- pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]),
- AND)
+ self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
self.get_compiler(using).execute_sql(None)
def add_update_values(self, values):
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
index 44a4ce9d1d..be0c559c1b 100644
--- a/django/db/models/sql/where.py
+++ b/django/db/models/sql/where.py
@@ -5,6 +5,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
import collections
import datetime
from itertools import repeat
+import warnings
from django.conf import settings
from django.db.models.fields import DateTimeField, Field
@@ -101,7 +102,7 @@ class WhereNode(tree.Node):
for child in self.children:
try:
if hasattr(child, 'as_sql'):
- sql, params = child.as_sql(qn=qn, connection=connection)
+ sql, params = qn.compile(child)
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection)
@@ -152,16 +153,16 @@ class WhereNode(tree.Node):
sql_string = '(%s)' % sql_string
return sql_string, result_params
- def get_cols(self):
+ def get_group_by_cols(self):
cols = []
for child in self.children:
- if hasattr(child, 'get_cols'):
- cols.extend(child.get_cols())
+ if hasattr(child, 'get_group_by_cols'):
+ cols.extend(child.get_group_by_cols())
else:
if isinstance(child[0], Constraint):
cols.append((child[0].alias, child[0].col))
- if hasattr(child[3], 'get_cols'):
- cols.extend(child[3].get_cols())
+ if hasattr(child[3], 'get_group_by_cols'):
+ cols.extend(child[3].get_group_by_cols())
return cols
def make_atom(self, child, qn, connection):
@@ -174,6 +175,9 @@ class WhereNode(tree.Node):
Returns the string for the SQL fragment and the parameters to use for
it.
"""
+ warnings.warn(
+ "The make_atom() method will be removed in Django 1.9. Use Lookup class instead.",
+ PendingDeprecationWarning)
lvalue, lookup_type, value_annotation, params_or_value = child
field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None
@@ -193,13 +197,13 @@ class WhereNode(tree.Node):
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
else:
# A smart object with an as_sql() method.
- field_sql, field_params = lvalue.as_sql(qn, connection)
+ field_sql, field_params = qn.compile(lvalue)
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
if hasattr(params, 'as_sql'):
- extra, params = params.as_sql(qn, connection)
+ extra, params = qn.compile(params)
cast_sql = ''
else:
extra = ''
@@ -282,6 +286,8 @@ class WhereNode(tree.Node):
if hasattr(child, 'relabel_aliases'):
# For example another WhereNode
child.relabel_aliases(change_map)
+ elif hasattr(child, 'relabeled_clone'):
+ self.children[pos] = child.relabeled_clone(change_map)
elif isinstance(child, (list, tuple)):
# tuple starting with Constraint
child = (child[0].relabeled_clone(change_map),) + child[1:]
@@ -347,10 +353,13 @@ class Constraint(object):
pre-process itself prior to including in the WhereNode.
"""
def __init__(self, alias, col, field):
+ warnings.warn(
+ "The Constraint class will be removed in Django 1.9. Use Lookup class instead.",
+ PendingDeprecationWarning)
self.alias, self.col, self.field = alias, col, field
def prepare(self, lookup_type, value):
- if self.field:
+ if self.field and not hasattr(value, 'as_sql'):
return self.field.get_prep_lookup(lookup_type, value)
return value