diff options
Diffstat (limited to 'django')
24 files changed, 625 insertions, 176 deletions
diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index a92f0f0dc8..7d130bec24 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -12,7 +12,7 @@ from django.db import models, router, transaction, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.related import PathInfo -from django.db.models.sql.where import Constraint +from django.db.models.sql.datastructures import Col from django.forms import ModelForm, ALL_FIELDS from django.forms.models import (BaseModelFormSet, modelformset_factory, modelform_defines_fields) @@ -236,7 +236,8 @@ class GenericRelation(ForeignObject): field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0] contenttype_pk = self.get_content_type().pk cond = where_class() - cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND') + lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk) + cond.add(lookup, 'AND') return cond def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 81aa1e2fb6..989b5a1292 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -49,9 +49,7 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations): return placeholder def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn): - alias, col, db_type = lvalue - - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue lookup_info = self.geometry_functions.get(lookup_type, False) if lookup_info: diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 8212938f01..b3aa191ccc 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -231,10 +231,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn): "Returns the SQL WHERE clause for use in Oracle spatial SQL construction." - alias, col, db_type = lvalue - - # Getting the quoted table name as `geo_col`. - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue # See if a Oracle Geometry function matches the lookup type next lookup_info = self.geometry_functions.get(lookup_type, False) diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 40ca09f0fa..e8da6a52a6 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -478,10 +478,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): (alias, col, db_type), the lookup type string, lookup value, and the geometry field. """ - alias, col, db_type = lvalue - - # Getting the quoted geometry column. - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue if lookup_type in self.geometry_operators: if field.geography and not lookup_type in self.geography_operators: diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 5149b541a2..ad44b470b3 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -324,10 +324,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): [a tuple of (alias, column, db_type)], lookup type, lookup value, the model field, and the quoting function. """ - alias, col, db_type = lvalue - - # Getting the quoted field as `geo_col`. - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue if lookup_type in self.geometry_functions: # See if a SpatiaLite geometry function matches the lookup type. diff --git a/django/contrib/gis/db/models/constants.py b/django/contrib/gis/db/models/constants.py new file mode 100644 index 0000000000..4ece41415c --- /dev/null +++ b/django/contrib/gis/db/models/constants.py @@ -0,0 +1,15 @@ +from django.db.models.sql.constants import QUERY_TERMS + +GIS_LOOKUPS = { + 'bbcontains', 'bboverlaps', 'contained', 'contains', + 'contains_properly', 'coveredby', 'covers', 'crosses', 'disjoint', + 'distance_gt', 'distance_gte', 'distance_lt', 'distance_lte', + 'dwithin', 'equals', 'exact', + 'intersects', 'overlaps', 'relate', 'same_as', 'touches', 'within', + 'left', 'right', 'overlaps_left', 'overlaps_right', + 'overlaps_above', 'overlaps_below', + 'strictly_above', 'strictly_below' +} +ALL_TERMS = GIS_LOOKUPS | QUERY_TERMS + +__all__ = ['ALL_TERMS', 'GIS_LOOKUPS'] diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index 734f805c8d..2daf08147b 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -2,6 +2,8 @@ from django.db.models.fields import Field from django.db.models.sql.expressions import SQLEvaluator from django.utils.translation import ugettext_lazy as _ from django.contrib.gis import forms +from django.contrib.gis.db.models.constants import GIS_LOOKUPS +from django.contrib.gis.db.models.lookups import GISLookup from django.contrib.gis.db.models.proxy import GeometryProxy from django.contrib.gis.geometry.backend import Geometry, GeometryException from django.utils import six @@ -284,6 +286,10 @@ class GeometryField(Field): """ return connection.ops.get_geom_placeholder(self, value) +for lookup_name in GIS_LOOKUPS: + lookup = type(lookup_name, (GISLookup,), {'lookup_name': lookup_name}) + GeometryField.register_lookup(lookup) + # The OpenGIS Geometry Type Fields class PointField(GeometryField): diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py new file mode 100644 index 0000000000..cad6c69000 --- /dev/null +++ b/django/contrib/gis/db/models/lookups.py @@ -0,0 +1,28 @@ +from django.db.models.lookups import Lookup +from django.db.models.sql.expressions import SQLEvaluator + + +class GISLookup(Lookup): + def as_sql(self, qn, connection): + from django.contrib.gis.db.models.sql import GeoWhereNode + # We use the same approach as was used by GeoWhereNode. It would + # be a good idea to upgrade GIS to use similar code that is used + # for other lookups. + if isinstance(self.rhs, SQLEvaluator): + # Make sure the F Expression destination field exists, and + # set an `srid` attribute with the same as that of the + # destination. + geo_fld = GeoWhereNode._check_geo_field(self.rhs.opts, self.rhs.expression.name) + if not geo_fld: + raise ValueError('No geographic field found in expression.') + self.rhs.srid = geo_fld.srid + db_type = self.lhs.output_type.db_type(connection=connection) + params = self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, self.rhs, connection=connection) + lhs_sql, lhs_params = self.process_lhs(qn, connection) + # lhs_params not currently supported. + assert not lhs_params + data = (lhs_sql, db_type) + spatial_sql, spatial_params = connection.ops.spatial_lookup_sql( + data, self.lookup_name, self.rhs, self.lhs.output_type, qn) + return spatial_sql, spatial_params + params diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py index 3923bf460e..f3fa1f6322 100644 --- a/django/contrib/gis/db/models/sql/query.py +++ b/django/contrib/gis/db/models/sql/query.py @@ -1,6 +1,7 @@ from django.db import connections from django.db.models.query import sql +from django.contrib.gis.db.models.constants import ALL_TERMS from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models.sql import aggregates as gis_aggregates from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField @@ -9,19 +10,6 @@ from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import Area, Distance -ALL_TERMS = set([ - 'bbcontains', 'bboverlaps', 'contained', 'contains', - 'contains_properly', 'coveredby', 'covers', 'crosses', 'disjoint', - 'distance_gt', 'distance_gte', 'distance_lt', 'distance_lte', - 'dwithin', 'equals', 'exact', - 'intersects', 'overlaps', 'relate', 'same_as', 'touches', 'within', - 'left', 'right', 'overlaps_left', 'overlaps_right', - 'overlaps_above', 'overlaps_below', - 'strictly_above', 'strictly_below' -]) -ALL_TERMS.update(sql.constants.QUERY_TERMS) - - class GeoQuery(sql.Query): """ A single spatial SQL query. diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 2dbb8b3aae..99d68221e8 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -673,6 +673,9 @@ class BaseDatabaseFeatures(object): # What kind of error does the backend throw when accessing closed cursor? closed_cursor_error_class = ProgrammingError + # Does 'a' LIKE 'A' match? + has_case_insensitive_like = True + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 7725b0c7a0..33f885d50c 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -62,6 +62,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_combined_alters = True nulls_order_largest = True closed_cursor_error_class = InterfaceError + has_case_insensitive_like = False class DatabaseWrapper(BaseDatabaseWrapper): @@ -83,6 +84,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': 'LIKE UPPER(%s)', } + pattern_ops = { + 'startswith': "LIKE %s || '%%%%'", + 'istartswith': "LIKE UPPER(%s) || '%%%%'", + } + Database = Database def __init__(self, *args, **kwargs): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 6c9728889f..e55973ea39 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -334,6 +334,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': "LIKE %s ESCAPE '\\'", } + pattern_ops = { + 'startswith': "LIKE %s || '%%%%'", + 'istartswith': "LIKE UPPER(%s) || '%%%%'", + } + Database = Database def __init__(self, *args, **kwargs): diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 454a693be7..bdbdd5fd91 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -17,6 +17,7 @@ from django.db.models.fields.related import ( # NOQA from django.db.models.fields.proxy import OrderWrt # NOQA from django.db.models.deletion import ( # NOQA CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError) +from django.db.models.lookups import Lookup, Transform # NOQA from django.db.models import signals # NOQA diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 1ec11b4acb..e31d228aa5 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -15,10 +15,11 @@ def refs_aggregate(lookup_parts, aggregates): default annotation names we must check each prefix of the lookup_parts for match. """ - for i in range(len(lookup_parts) + 1): - if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates: - return True - return False + for n in range(len(lookup_parts) + 1): + level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) + if level_n_lookup in aggregates: + return aggregates[level_n_lookup], lookup_parts[n:] + return False, () class Aggregate(object): diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 7172cf1a55..7ace8878aa 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -11,6 +11,7 @@ from itertools import tee from django.apps import apps from django.db import connection +from django.db.models.lookups import default_lookups, RegisterLookupMixin from django.db.models.query_utils import QueryWrapper from django.conf import settings from django import forms @@ -80,7 +81,7 @@ def _empty(of_cls): @total_ordering -class Field(object): +class Field(RegisterLookupMixin): """Base class for all field types""" # Designates whether empty strings fundamentally are allowed at the @@ -101,6 +102,7 @@ class Field(object): 'unique': _('%(model_name)s with this %(field_label)s ' 'already exists.'), } + class_lookups = default_lookups.copy() # Generic field type description, usually overridden by subclasses def _description(self): @@ -514,8 +516,7 @@ class Field(object): except ValueError: raise ValueError("The __year lookup type requires an integer " "argument") - - raise TypeError("Field has invalid lookup: %s" % lookup_type) + return self.get_prep_value(value) def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): @@ -564,6 +565,8 @@ class Field(object): return connection.ops.year_lookup_bounds_for_date_field(value) else: return [value] # this isn't supposed to happen + else: + return [value] def has_default(self): """ diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 82e7725dac..69fb3f8492 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -5,9 +5,11 @@ from django.db.backends import utils from django.db.models import signals, Q from django.db.models.fields import (AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) +from django.db.models.lookups import IsNull from django.db.models.related import RelatedObject, PathInfo from django.db.models.query import QuerySet from django.db.models.deletion import CASCADE +from django.db.models.sql.datastructures import Col from django.utils.encoding import smart_text from django.utils import six from django.utils.deprecation import RenameMethodsBase @@ -987,6 +989,11 @@ class ForeignObjectRel(object): # example custom multicolumn joins currently have no remote field). self.field_name = None + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + raw_value): + return self.field.get_lookup_constraint(constraint_class, alias, targets, sources, + lookup_type, raw_value) + class ManyToOneRel(ForeignObjectRel): def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, @@ -1193,14 +1200,16 @@ class ForeignObject(RelatedField): pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] return pathinfos - def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, raw_value): - from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR + from django.db.models.sql.where import SubqueryConstraint, AND, OR root_constraint = constraint_class() assert len(targets) == len(sources) + if len(lookups) > 1: + raise exceptions.FieldError('Relation fields do not support nested lookups') + lookup_type = lookups[0] def get_normalized_value(value): - from django.db.models import Model if isinstance(value, Model): value_list = [] @@ -1221,28 +1230,27 @@ class ForeignObject(RelatedField): [source.name for source in sources], raw_value), AND) elif lookup_type == 'isnull': - root_constraint.add( - (Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND) + root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND) elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] and not is_multicolumn)): value = get_normalized_value(raw_value) - for index, source in enumerate(sources): + for target, source, val in zip(targets, sources, value): + lookup_class = target.get_lookup(lookup_type) root_constraint.add( - (Constraint(alias, targets[index].column, sources[index]), lookup_type, - value[index]), AND) + lookup_class(Col(alias, target, source), val), AND) elif lookup_type in ['range', 'in'] and not is_multicolumn: values = [get_normalized_value(value) for value in raw_value] value = [val[0] for val in values] - root_constraint.add( - (Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND) + lookup_class = targets[0].get_lookup(lookup_type) + root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND) elif lookup_type == 'in': values = [get_normalized_value(value) for value in raw_value] for value in values: value_constraint = constraint_class() - for index, target in enumerate(targets): - value_constraint.add( - (Constraint(alias, target.column, sources[index]), 'exact', value[index]), - AND) + for source, target, val in zip(sources, targets, value): + lookup_class = target.get_lookup('exact') + lookup = lookup_class(Col(alias, target, source), val) + value_constraint.add(lookup, AND) root_constraint.add(value_constraint, OR) else: raise TypeError('Related Field got invalid lookup: %s' % lookup_type) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py new file mode 100644 index 0000000000..5369994bbc --- /dev/null +++ b/django/db/models/lookups.py @@ -0,0 +1,317 @@ +from copy import copy +import inspect + +from django.conf import settings +from django.utils import timezone +from django.utils.functional import cached_property + + +class RegisterLookupMixin(object): + def get_lookup(self, lookup_name): + try: + return self.class_lookups[lookup_name] + except KeyError: + # To allow for inheritance, check parent class class lookups. + for parent in inspect.getmro(self.__class__): + if not 'class_lookups' in parent.__dict__: + continue + if lookup_name in parent.class_lookups: + return parent.class_lookups[lookup_name] + except AttributeError: + # This class didn't have any class_lookups + pass + if hasattr(self, 'output_type'): + return self.output_type.get_lookup(lookup_name) + return None + + @classmethod + def register_lookup(cls, lookup): + if not 'class_lookups' in cls.__dict__: + cls.class_lookups = {} + cls.class_lookups[lookup.lookup_name] = lookup + + @classmethod + def _unregister_lookup(cls, lookup): + """ + Removes given lookup from cls lookups. Meant to be used in + tests only. + """ + del cls.class_lookups[lookup.lookup_name] + + +class Transform(RegisterLookupMixin): + def __init__(self, lhs, lookups): + self.lhs = lhs + self.init_lookups = lookups[:] + + def as_sql(self, qn, connection): + raise NotImplementedError + + @cached_property + def output_type(self): + return self.lhs.output_type + + def relabeled_clone(self, relabels): + return self.__class__(self.lhs.relabeled_clone(relabels)) + + def get_group_by_cols(self): + return self.lhs.get_group_by_cols() + + +class Lookup(RegisterLookupMixin): + lookup_name = None + + def __init__(self, lhs, rhs): + self.lhs, self.rhs = lhs, rhs + self.rhs = self.get_prep_lookup() + + def get_prep_lookup(self): + return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) + + def get_db_prep_lookup(self, value, connection): + return ( + '%s', self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, value, connection, prepared=True)) + + def process_lhs(self, qn, connection, lhs=None): + lhs = lhs or self.lhs + return qn.compile(lhs) + + def process_rhs(self, qn, connection, rhs=None): + value = rhs or self.rhs + # Due to historical reasons there are a couple of different + # ways to produce sql here. get_compiler is likely a Query + # instance, _as_sql QuerySet and as_sql just something with + # as_sql. Finally the value can of course be just plain + # Python value. + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) + if hasattr(value, 'as_sql'): + sql, params = qn.compile(value) + return '(' + sql + ')', params + if hasattr(value, '_as_sql'): + sql, params = value._as_sql(connection=connection) + return '(' + sql + ')', params + else: + return self.get_db_prep_lookup(value, connection) + + def relabeled_clone(self, relabels): + new = copy(self) + new.lhs = new.lhs.relabeled_clone(relabels) + if hasattr(new.rhs, 'relabeled_clone'): + new.rhs = new.rhs.relabeled_clone(relabels) + return new + + def get_group_by_cols(self): + cols = self.lhs.get_group_by_cols() + if hasattr(self.rhs, 'get_group_by_cols'): + cols.extend(self.rhs.get_group_by_cols()) + return cols + + def as_sql(self, qn, connection): + raise NotImplementedError + + +class BuiltinLookup(Lookup): + def as_sql(self, qn, connection): + lhs_sql, params = self.process_lhs(qn, connection) + field_internal_type = self.lhs.output_type.get_internal_type() + db_type = self.lhs.output_type + lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql + lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + operator_plus_rhs = self.get_rhs_op(connection, rhs_sql) + return '%s %s' % (lhs_sql, operator_plus_rhs), params + + def get_rhs_op(self, connection, rhs): + return connection.operators[self.lookup_name] % rhs + + +default_lookups = {} + + +class Exact(BuiltinLookup): + lookup_name = 'exact' +default_lookups['exact'] = Exact + + +class IExact(BuiltinLookup): + lookup_name = 'iexact' +default_lookups['iexact'] = IExact + + +class Contains(BuiltinLookup): + lookup_name = 'contains' +default_lookups['contains'] = Contains + + +class IContains(BuiltinLookup): + lookup_name = 'icontains' +default_lookups['icontains'] = IContains + + +class GreaterThan(BuiltinLookup): + lookup_name = 'gt' +default_lookups['gt'] = GreaterThan + + +class GreaterThanOrEqual(BuiltinLookup): + lookup_name = 'gte' +default_lookups['gte'] = GreaterThanOrEqual + + +class LessThan(BuiltinLookup): + lookup_name = 'lt' +default_lookups['lt'] = LessThan + + +class LessThanOrEqual(BuiltinLookup): + lookup_name = 'lte' +default_lookups['lte'] = LessThanOrEqual + + +class In(BuiltinLookup): + lookup_name = 'in' + + def get_db_prep_lookup(self, value, connection): + params = self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, value, connection, prepared=True) + if not params: + # TODO: check why this leads to circular import + from django.db.models.sql.datastructures import EmptyResultSet + raise EmptyResultSet + placeholder = '(' + ', '.join('%s' for p in params) + ')' + return (placeholder, params) + + def get_rhs_op(self, connection, rhs): + return 'IN %s' % rhs +default_lookups['in'] = In + + +class PatternLookup(BuiltinLookup): + def get_rhs_op(self, connection, rhs): + # Assume we are in startswith. We need to produce SQL like: + # col LIKE %s, ['thevalue%'] + # For python values we can (and should) do that directly in Python, + # but if the value is for example reference to other column, then + # we need to add the % pattern match to the lookup by something like + # col LIKE othercol || '%%' + # So, for Python values we don't need any special pattern, but for + # SQL reference values we need the correct pattern added. + value = self.rhs + if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql') + or hasattr(value, '_as_sql')): + return connection.pattern_ops[self.lookup_name] % rhs + else: + return super(PatternLookup, self).get_rhs_op(connection, rhs) + + +class StartsWith(PatternLookup): + lookup_name = 'startswith' +default_lookups['startswith'] = StartsWith + + +class IStartsWith(PatternLookup): + lookup_name = 'istartswith' +default_lookups['istartswith'] = IStartsWith + + +class EndsWith(BuiltinLookup): + lookup_name = 'endswith' +default_lookups['endswith'] = EndsWith + + +class IEndsWith(BuiltinLookup): + lookup_name = 'iendswith' +default_lookups['iendswith'] = IEndsWith + + +class Between(BuiltinLookup): + def get_rhs_op(self, connection, rhs): + return "BETWEEN %s AND %s" % (rhs, rhs) + + +class Year(Between): + lookup_name = 'year' +default_lookups['year'] = Year + + +class Range(Between): + lookup_name = 'range' +default_lookups['range'] = Range + + +class DateLookup(BuiltinLookup): + + def process_lhs(self, qn, connection): + lhs, params = super(DateLookup, self).process_lhs(qn, connection) + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname) + return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params + + def get_rhs_op(self, connection, rhs): + return '= %s' % rhs + + +class Month(DateLookup): + lookup_name = 'month' + extract_type = 'month' +default_lookups['month'] = Month + + +class Day(DateLookup): + lookup_name = 'day' + extract_type = 'day' +default_lookups['day'] = Day + + +class WeekDay(DateLookup): + lookup_name = 'week_day' + extract_type = 'week_day' +default_lookups['week_day'] = WeekDay + + +class Hour(DateLookup): + lookup_name = 'hour' + extract_type = 'hour' +default_lookups['hour'] = Hour + + +class Minute(DateLookup): + lookup_name = 'minute' + extract_type = 'minute' +default_lookups['minute'] = Minute + + +class Second(DateLookup): + lookup_name = 'second' + extract_type = 'second' +default_lookups['second'] = Second + + +class IsNull(BuiltinLookup): + lookup_name = 'isnull' + + def as_sql(self, qn, connection): + sql, params = qn.compile(self.lhs) + if self.rhs: + return "%s IS NULL" % sql, params + else: + return "%s IS NOT NULL" % sql, params +default_lookups['isnull'] = IsNull + + +class Search(BuiltinLookup): + lookup_name = 'search' +default_lookups['search'] = Search + + +class Regex(BuiltinLookup): + lookup_name = 'regex' +default_lookups['regex'] = Regex + + +class IRegex(BuiltinLookup): + lookup_name = 'iregex' +default_lookups['iregex'] = IRegex diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 8542a330c6..aef8b493bb 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -4,6 +4,7 @@ Classes to represent the default SQL aggregate functions import copy from django.db.models.fields import IntegerField, FloatField +from django.db.models.lookups import RegisterLookupMixin __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] @@ -14,7 +15,7 @@ ordinal_aggregate_field = IntegerField() computed_aggregate_field = FloatField() -class Aggregate(object): +class Aggregate(RegisterLookupMixin): """ Default SQL Aggregate. """ @@ -93,6 +94,13 @@ class Aggregate(object): return self.sql_template % substitutions, params + def get_group_by_cols(self): + return [] + + @property + def output_type(self): + return self.field + class Avg(Aggregate): is_computed = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 41bba93206..123427cf8b 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -45,7 +45,7 @@ class SQLCompiler(object): if self.query.select_related and not self.query.related_select_cols: self.fill_related_selections() - def quote_name_unless_alias(self, name): + def __call__(self, name): """ A wrapper around connection.ops.quote_name that doesn't quote aliases for table names. This avoids problems with some SQL dialects that treat @@ -61,6 +61,22 @@ class SQLCompiler(object): self.quote_cache[name] = r return r + def quote_name_unless_alias(self, name): + """ + A wrapper around connection.ops.quote_name that doesn't quote aliases + for table names. This avoids problems with some SQL dialects that treat + quoted strings specially (e.g. PostgreSQL). + """ + return self(name) + + def compile(self, node): + vendor_impl = getattr( + node, 'as_' + self.connection.vendor, None) + if vendor_impl: + return vendor_impl(self, self.connection) + else: + return node.as_sql(self, self.connection) + def as_sql(self, with_limits=True, with_col_aliases=False): """ Creates the SQL for this query. Returns the SQL string and list of @@ -88,11 +104,9 @@ class SQLCompiler(object): # docstring of get_from_clause() for details. from_, f_params = self.get_from_clause() - qn = self.quote_name_unless_alias - - where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) - having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) - having_group_by = self.query.having.get_cols() + where, w_params = self.compile(self.query.where) + having, h_params = self.compile(self.query.having) + having_group_by = self.query.having.get_group_by_cols() params = [] for val in six.itervalues(self.query.extra_select): params.extend(val[1]) @@ -180,7 +194,7 @@ class SQLCompiler(object): (without the table names) are given unique aliases. This is needed in some cases to avoid ambiguity with nested queries. """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] params = [] @@ -213,7 +227,7 @@ class SQLCompiler(object): aliases.add(r) col_aliases.add(col[1]) else: - col_sql, col_params = col.as_sql(qn, self.connection) + col_sql, col_params = self.compile(col) result.append(col_sql) params.extend(col_params) @@ -229,7 +243,7 @@ class SQLCompiler(object): max_name_length = self.connection.ops.max_name_length() for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) if alias is None: result.append(agg_sql) else: @@ -267,7 +281,7 @@ class SQLCompiler(object): result = [] if opts is None: opts = self.query.get_meta() - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name aliases = set() only_load = self.deferred_to_columns() @@ -319,7 +333,7 @@ class SQLCompiler(object): Note that this method can alter the tables in the query, and thus it must be called before get_from_clause(). """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = [] opts = self.query.get_meta() @@ -352,7 +366,7 @@ class SQLCompiler(object): ordering = (self.query.order_by or self.query.get_meta().ordering or []) - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name distinct = self.query.distinct select_aliases = self._select_aliases @@ -490,7 +504,7 @@ class SQLCompiler(object): ordering and distinct must be done first. """ result = [] - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name first = True from_params = [] @@ -508,8 +522,7 @@ class SQLCompiler(object): extra_cond = join_field.get_extra_restriction( self.query.where_class, alias, lhs) if extra_cond: - extra_sql, extra_params = extra_cond.as_sql( - qn, self.connection) + extra_sql, extra_params = self.compile(extra_cond) extra_sql = 'AND (%s)' % extra_sql from_params.extend(extra_params) else: @@ -541,7 +554,7 @@ class SQLCompiler(object): """ Returns a tuple representing the SQL elements in the "group by" clause. """ - qn = self.quote_name_unless_alias + qn = self result, params = [], [] if self.query.group_by is not None: select_cols = self.query.select + self.query.related_select_cols @@ -560,7 +573,7 @@ class SQLCompiler(object): if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - sql, col_params = col.as_sql(qn, self.connection) + self.compile(col) else: sql = '(%s)' % str(col) if sql not in seen: @@ -784,7 +797,7 @@ class SQLCompiler(object): return result def as_subquery_condition(self, alias, columns, qn): - inner_qn = self.quote_name_unless_alias + inner_qn = self qn2 = self.connection.ops.quote_name if len(columns) == 1: sql, params = self.as_sql() @@ -895,9 +908,9 @@ class SQLDeleteCompiler(SQLCompiler): """ assert len(self.query.tables) == 1, \ "Can only delete from one table at a time." - qn = self.quote_name_unless_alias + qn = self result = ['DELETE FROM %s' % qn(self.query.tables[0])] - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(params) @@ -913,7 +926,7 @@ class SQLUpdateCompiler(SQLCompiler): if not self.query.values: return '', () table = self.query.tables[0] - qn = self.quote_name_unless_alias + qn = self result = ['UPDATE %s' % qn(table)] result.append('SET') values, update_params = [], [] @@ -933,7 +946,7 @@ class SQLUpdateCompiler(SQLCompiler): val = SQLEvaluator(val, self.query, allow_joins=False) name = field.column if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn, self.connection) + sql, params = self.compile(val) values.append('%s = %s' % (qn(name), sql)) update_params.extend(params) elif val is not None: @@ -944,7 +957,7 @@ class SQLUpdateCompiler(SQLCompiler): if not values: return '', () result.append(', '.join(values)) - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(update_params + params) @@ -1024,11 +1037,11 @@ class SQLAggregateCompiler(SQLCompiler): parameters. """ if qn is None: - qn = self.quote_name_unless_alias + qn = self sql, params = [], [] for aggregate in self.query.aggregate_select.values(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) sql.append(agg_sql) params.extend(agg_params) sql = ', '.join(sql) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index f45ecaf76d..421c3cd860 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -5,18 +5,27 @@ the SQL domain. class Col(object): - def __init__(self, alias, col): - self.alias = alias - self.col = col + def __init__(self, alias, target, source): + self.alias, self.target, self.source = alias, target, source def as_sql(self, qn, connection): - return '%s.%s' % (qn(self.alias), self.col), [] + return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] - def prepare(self): - return self + @property + def output_type(self): + return self.source def relabeled_clone(self, relabels): - return self.__class__(relabels.get(self.alias, self.alias), self.col) + return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source) + + def get_group_by_cols(self): + return [(self.alias, self.target.column)] + + def get_lookup(self, name): + return self.output_type.get_lookup(name) + + def prepare(self): + return self class EmptyResultSet(Exception): diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 9f29e2ace5..e31eaa8a2f 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -24,11 +24,11 @@ class SQLEvaluator(object): (change_map.get(col[0], col[0]), col[1]))) return clone - def get_cols(self): + def get_group_by_cols(self): cols = [] for node, col in self.cols: - if hasattr(node, 'get_cols'): - cols.extend(node.get_cols()) + if hasattr(node, 'get_group_by_cols'): + cols.extend(node.get_group_by_cols()) elif isinstance(col, tuple): cols.append(col) return cols diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c3c8e55793..db4e6744bf 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -19,6 +19,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.aggregates import refs_aggregate from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist +from django.db.models.lookups import Transform from django.db.models.query_utils import Q from django.db.models.related import PathInfo from django.db.models.sql import aggregates as base_aggregates_module @@ -1028,13 +1029,16 @@ class Query(object): # Add the aggregate to the query aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) - def prepare_lookup_value(self, value, lookup_type, can_reuse): + def prepare_lookup_value(self, value, lookups, can_reuse): + # Default lookup if none given is exact. + if len(lookups) == 0: + lookups = ['exact'] # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. if value is None: - if lookup_type not in ('exact', 'iexact'): + if lookups[-1] not in ('exact', 'iexact'): raise ValueError("Cannot use None as a query value") - lookup_type = 'isnull' + lookups[-1] = 'isnull' value = True elif callable(value): warnings.warn( @@ -1055,40 +1059,54 @@ class Query(object): # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we # can do here. Similar thing is done in is_nullable(), too. if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and - lookup_type == 'exact' and value == ''): + lookups[-1] == 'exact' and value == ''): value = True - lookup_type = 'isnull' - return value, lookup_type + lookups[-1] = ['isnull'] + return value, lookups def solve_lookup_type(self, lookup): """ Solve the lookup type from the lookup (eg: 'foobar__id__icontains') """ - lookup_type = 'exact' # Default lookup type - lookup_parts = lookup.split(LOOKUP_SEP) - num_parts = len(lookup_parts) - if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms - and (not self._aggregates or lookup not in self._aggregates)): - # Traverse the lookup query to distinguish related fields from - # lookup types. - lookup_model = self.model - for counter, field_name in enumerate(lookup_parts): - try: - lookup_field = lookup_model._meta.get_field(field_name) - except FieldDoesNotExist: - # Not a field. Bail out. - lookup_type = lookup_parts.pop() - break - # Unless we're at the end of the list of lookups, let's attempt - # to continue traversing relations. - if (counter + 1) < num_parts: - try: - lookup_model = lookup_field.rel.to - except AttributeError: - # Not a related field. Bail out. - lookup_type = lookup_parts.pop() - break - return lookup_type, lookup_parts + lookup_splitted = lookup.split(LOOKUP_SEP) + if self._aggregates: + aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) + if aggregate: + return aggregate_lookups, (), aggregate + _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) + field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] + if len(lookup_parts) == 0: + lookup_parts = ['exact'] + elif len(lookup_parts) > 1: + if not field_parts: + raise FieldError( + 'Invalid lookup "%s" for model %s".' % + (lookup, self.get_meta().model.__name__)) + return lookup_parts, field_parts, False + + def build_lookup(self, lookups, lhs, rhs): + lookups = lookups[:] + while lookups: + lookup = lookups[0] + next = lhs.get_lookup(lookup) + if next: + if len(lookups) == 1: + # This was the last lookup, so return value lookup. + if issubclass(next, Transform): + lookups.append('exact') + lhs = next(lhs, lookups) + else: + return next(lhs, rhs) + else: + lhs = next(lhs, lookups) + # A field's get_lookup() can return None to opt for backwards + # compatibility path. + elif len(lookups) > 2: + raise FieldError( + "Unsupported lookup for field '%s'" % lhs.output_type.name) + else: + return None + lookups = lookups[1:] def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, connector=AND): @@ -1118,21 +1136,24 @@ class Query(object): is responsible for unreffing the joins used. """ arg, value = filter_expr - lookup_type, parts = self.solve_lookup_type(arg) - if not parts: + if not arg: raise FieldError("Cannot parse keyword query %r" % arg) + lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) + value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) used_joins = getattr(value, '_used_joins', []) clause = self.where_class() - if self._aggregates: - for alias, aggregate in self.aggregates.items(): - if alias in (parts[0], LOOKUP_SEP.join(parts)): - clause.add((aggregate, lookup_type, value), AND) - return clause, [] + if reffed_aggregate: + condition = self.build_lookup(lookups, reffed_aggregate, value) + if not condition: + # Backwards compat for custom lookups + assert len(lookups) == 1 + condition = (reffed_aggregate, lookups[0], value) + clause.add(condition, AND) + return clause, [] opts = self.get_meta() alias = self.get_initial_alias() @@ -1154,11 +1175,31 @@ class Query(object): targets, alias, join_list = self.trim_joins(sources, join_list, path) if hasattr(field, 'get_lookup_constraint'): - constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources, - lookup_type, value) + # For now foreign keys get special treatment. This should be + # refactored when composite fields lands. + condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, + lookups, value) + lookup_type = lookups[-1] else: - constraint = (Constraint(alias, targets[0].column, field), lookup_type, value) - clause.add(constraint, AND) + assert(len(targets) == 1) + col = Col(alias, targets[0], field) + condition = self.build_lookup(lookups, col, value) + if not condition: + # Backwards compat for custom lookups + if lookups[0] not in self.query_terms: + raise FieldError( + "Join on field '%s' not permitted. Did you " + "misspell '%s' for the lookup type?" % + (col.output_type.name, lookups[0])) + if len(lookups) > 1: + raise FieldError("Nested lookup '%s' not supported." % + LOOKUP_SEP.join(lookups)) + condition = (Constraint(alias, targets[0].column, field), lookups[0], value) + lookup_type = lookups[-1] + else: + lookup_type = condition.lookup_name + + clause.add(condition, AND) require_outer = lookup_type == 'isnull' and value is True and not current_negated if current_negated and (lookup_type != 'isnull' or value is False): @@ -1175,7 +1216,8 @@ class Query(object): # (col IS NULL OR col != someval) # <=> # NOT (col IS NOT NULL AND col = someval). - clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND) + lookup_class = targets[0].get_lookup('isnull') + clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND) return clause, used_joins if not require_outer else () def add_filter(self, filter_clause): @@ -1189,7 +1231,7 @@ class Query(object): if not self._aggregates: return False if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates) + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] or (hasattr(obj[1], 'contains_aggregate') and obj[1].contains_aggregate(self.aggregates))) return any(self.need_having(c) for c in obj.children) @@ -1277,7 +1319,7 @@ class Query(object): needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner - def names_to_path(self, names, opts, allow_many): + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walks the names path and turns them PathInfo tuples. Note that a single name in 'names' can generate multiple PathInfos (m2m for @@ -1297,9 +1339,10 @@ class Query(object): try: field, model, direct, m2m = opts.get_field_by_name(name) except FieldDoesNotExist: - available = opts.get_all_field_names() + list(self.aggregate_select) - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (name, ", ".join(available))) + # We didn't found the current field, so move position back + # one step. + pos -= 1 + break # Check if we need any joins for concrete inheritance cases (the # field lives in parent, but we are currently in one of its # children) @@ -1334,15 +1377,14 @@ class Query(object): final_field = field targets = (field,) break + if pos == -1 or (fail_on_missing and pos + 1 != len(names)): + self.raise_field_error(opts, name) + return path, final_field, targets, names[pos + 1:] - if pos != len(names) - 1: - if pos == len(names) - 2: - raise FieldError( - "Join on field %r not permitted. Did you misspell %r for " - "the lookup type?" % (name, names[pos + 1])) - else: - raise FieldError("Join on field %r not permitted." % name) - return path, final_field, targets + def raise_field_error(self, opts, name): + available = opts.get_all_field_names() + list(self.aggregate_select) + raise FieldError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(available))) def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ @@ -1371,8 +1413,9 @@ class Query(object): """ joins = [alias] # First, generate the path for the names - path, final_field, targets = self.names_to_path( - names, opts, allow_many) + path, final_field, targets, rest = self.names_to_path( + names, opts, allow_many, fail_on_missing=True) + # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. @@ -1387,8 +1430,6 @@ class Query(object): alias = self.join( connection, reuse=reuse, nullable=nullable, join_field=join.join_field) joins.append(alias) - if hasattr(final_field, 'field'): - final_field = final_field.field return final_field, targets, opts, joins, path def trim_joins(self, targets, joins, path): @@ -1451,17 +1492,19 @@ class Query(object): # nothing alias, col = query.select[0].col if self.is_nullable(query.select[0].field): - query.where.add((Constraint(alias, col, query.select[0].field), 'isnull', False), AND) + lookup_class = query.select[0].field.get_lookup('isnull') + lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False) + query.where.add(lookup, AND) if alias in can_reuse: - pk = query.select[0].field.model._meta.pk + select_field = query.select[0].field + pk = select_field.model._meta.pk # Need to add a restriction so that outer query's filters are in effect for # the subquery, too. query.bump_prefix(self) - query.where.add( - (Constraint(query.select[0].col[0], pk.column, pk), - 'exact', Col(alias, pk.column)), - AND - ) + lookup_class = select_field.get_lookup('exact') + lookup = lookup_class(Col(query.select[0].col[0], pk, pk), + Col(alias, pk, pk)) + query.where.add(lookup, AND) condition, needed_inner = self.build_filter( ('%s__in' % trimmed_prefix, query), diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index e9e292e787..86b1efd3f8 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -5,12 +5,12 @@ Query subclasses which provide extra functionality beyond simple data retrieval. from django.conf import settings from django.core.exceptions import FieldError from django.db import connections +from django.db.models.query_utils import Q from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query -from django.db.models.sql.where import AND, Constraint from django.utils import six from django.utils import timezone @@ -42,10 +42,10 @@ class DeleteQuery(Query): if not field: field = self.get_meta().pk for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() - where.add((Constraint(None, field.column, field), 'in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(self.get_meta().db_table, where, using=using) + self.where = self.where_class() + self.add_q(Q( + **{field.attname + '__in': pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]})) + self.do_query(self.get_meta().db_table, self.where, using=using) def delete_qs(self, query, using): """ @@ -80,9 +80,8 @@ class DeleteQuery(Query): SelectInfo((self.get_initial_alias(), pk.column), None) ] values = innerq - where = self.where_class() - where.add((Constraint(None, pk.column, pk), 'in', values), AND) - self.where = where + self.where = self.where_class() + self.add_q(Q(pk__in=values)) self.get_compiler(using).execute_sql(None) @@ -113,13 +112,10 @@ class UpdateQuery(Query): related_updates=self.related_updates.copy(), **kwargs) def update_batch(self, pk_list, values, using): - pk_field = self.get_meta().pk self.add_update_values(values) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.where = self.where_class() - self.where.add((Constraint(None, pk_field.column, pk_field), 'in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), - AND) + self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])) self.get_compiler(using).execute_sql(None) def add_update_values(self, values): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 44a4ce9d1d..be0c559c1b 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -5,6 +5,7 @@ Code to manage the creation and SQL rendering of 'where' constraints. import collections import datetime from itertools import repeat +import warnings from django.conf import settings from django.db.models.fields import DateTimeField, Field @@ -101,7 +102,7 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn, connection=connection) + sql, params = qn.compile(child) else: # A leaf node in the tree. sql, params = self.make_atom(child, qn, connection) @@ -152,16 +153,16 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def get_cols(self): + def get_group_by_cols(self): cols = [] for child in self.children: - if hasattr(child, 'get_cols'): - cols.extend(child.get_cols()) + if hasattr(child, 'get_group_by_cols'): + cols.extend(child.get_group_by_cols()) else: if isinstance(child[0], Constraint): cols.append((child[0].alias, child[0].col)) - if hasattr(child[3], 'get_cols'): - cols.extend(child[3].get_cols()) + if hasattr(child[3], 'get_group_by_cols'): + cols.extend(child[3].get_group_by_cols()) return cols def make_atom(self, child, qn, connection): @@ -174,6 +175,9 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. """ + warnings.warn( + "The make_atom() method will be removed in Django 1.9. Use Lookup class instead.", + PendingDeprecationWarning) lvalue, lookup_type, value_annotation, params_or_value = child field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None @@ -193,13 +197,13 @@ class WhereNode(tree.Node): field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), [] else: # A smart object with an as_sql() method. - field_sql, field_params = lvalue.as_sql(qn, connection) + field_sql, field_params = qn.compile(lvalue) is_datetime_field = value_annotation is datetime.datetime cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn, connection) + extra, params = qn.compile(params) cast_sql = '' else: extra = '' @@ -282,6 +286,8 @@ class WhereNode(tree.Node): if hasattr(child, 'relabel_aliases'): # For example another WhereNode child.relabel_aliases(change_map) + elif hasattr(child, 'relabeled_clone'): + self.children[pos] = child.relabeled_clone(change_map) elif isinstance(child, (list, tuple)): # tuple starting with Constraint child = (child[0].relabeled_clone(change_map),) + child[1:] @@ -347,10 +353,13 @@ class Constraint(object): pre-process itself prior to including in the WhereNode. """ def __init__(self, alias, col, field): + warnings.warn( + "The Constraint class will be removed in Django 1.9. Use Lookup class instead.", + PendingDeprecationWarning) self.alias, self.col, self.field = alias, col, field def prepare(self, lookup_type, value): - if self.field: + if self.field and not hasattr(value, 'as_sql'): return self.field.get_prep_lookup(lookup_type, value) return value |
