summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/contrib/postgres/fields/array.py4
-rw-r--r--django/db/backends/base/features.py3
-rw-r--r--django/db/backends/base/operations.py17
-rw-r--r--django/db/backends/oracle/operations.py2
-rw-r--r--django/db/backends/postgresql_psycopg2/base.py10
-rw-r--r--django/db/backends/postgresql_psycopg2/features.py1
-rw-r--r--django/db/backends/postgresql_psycopg2/operations.py28
-rw-r--r--django/db/models/fields/__init__.py6
-rw-r--r--django/db/models/lookups.py2
-rw-r--r--tests/model_fields/tests.py5
-rw-r--r--tests/postgres_tests/migrations/0002_create_test_models.py16
-rw-r--r--tests/postgres_tests/models.py10
-rw-r--r--tests/postgres_tests/test_array.py30
13 files changed, 111 insertions, 23 deletions
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
index 318afabd2c..b0850b92e7 100644
--- a/django/contrib/postgres/fields/array.py
+++ b/django/contrib/postgres/fields/array.py
@@ -70,9 +70,9 @@ class ArrayField(Field):
size = self.size or ''
return '%s[%s]' % (self.base_field.db_type(connection), size)
- def get_prep_value(self, value):
+ def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, list) or isinstance(value, tuple):
- return [self.base_field.get_prep_value(i) for i in value]
+ return [self.base_field.get_db_prep_value(i, connection, prepared) for i in value]
return value
def deconstruct(self):
diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index fe13827b77..0f6ee0efe3 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -59,6 +59,9 @@ class BaseDatabaseFeatures(object):
supports_subqueries_in_group_by = True
supports_bitwise_or = True
+ # Is there a true datatype for uuid?
+ has_native_uuid_field = False
+
# Is there a true datatype for timedeltas?
has_native_duration_field = False
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index c4e78e719e..24bcbb3d08 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -219,7 +219,7 @@ class BaseDatabaseOperations(object):
"""
return cursor.lastrowid
- def lookup_cast(self, lookup_type):
+ def lookup_cast(self, lookup_type, internal_type=None):
"""
Returns the string to use in a query when performing lookups
("contains", "like", etc). The resulting string should contain a '%s'
@@ -442,7 +442,7 @@ class BaseDatabaseOperations(object):
def value_to_db_date(self, value):
"""
- Transform a date value to an object compatible with what is expected
+ Transforms a date value to an object compatible with what is expected
by the backend driver for date columns.
"""
if value is None:
@@ -451,7 +451,7 @@ class BaseDatabaseOperations(object):
def value_to_db_datetime(self, value):
"""
- Transform a datetime value to an object compatible with what is expected
+ Transforms a datetime value to an object compatible with what is expected
by the backend driver for datetime columns.
"""
if value is None:
@@ -460,7 +460,7 @@ class BaseDatabaseOperations(object):
def value_to_db_time(self, value):
"""
- Transform a time value to an object compatible with what is expected
+ Transforms a time value to an object compatible with what is expected
by the backend driver for time columns.
"""
if value is None:
@@ -471,11 +471,18 @@ class BaseDatabaseOperations(object):
def value_to_db_decimal(self, value, max_digits, decimal_places):
"""
- Transform a decimal.Decimal value to an object compatible with what is
+ Transforms a decimal.Decimal value to an object compatible with what is
expected by the backend driver for decimal (numeric) columns.
"""
return utils.format_number(value, max_digits, decimal_places)
+ def value_to_db_ipaddress(self, value):
+ """
+ Transforms a string representation of an IP address into the expected
+ type for the backend driver.
+ """
+ return value
+
def year_lookup_bounds_for_date_field(self, value):
"""
Returns a two-elements list with the lower and upper bound to be used
diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py
index f00fd3fbea..fe9c93ba90 100644
--- a/django/db/backends/oracle/operations.py
+++ b/django/db/backends/oracle/operations.py
@@ -246,7 +246,7 @@ WHEN (new.%(col_name)s IS NULL)
cursor.execute('SELECT "%s".currval FROM dual' % sq_name)
return cursor.fetchone()[0]
- def lookup_cast(self, lookup_type):
+ def lookup_cast(self, lookup_type, internal_type=None):
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
return "UPPER(%s)"
return "%s"
diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
index 34b2870773..37433c3987 100644
--- a/django/db/backends/postgresql_psycopg2/base.py
+++ b/django/db/backends/postgresql_psycopg2/base.py
@@ -38,6 +38,16 @@ psycopg2.extensions.register_adapter(SafeBytes, psycopg2.extensions.QuotedString
psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString)
psycopg2.extras.register_uuid()
+# Register support for inet[] manually so we don't have to handle the Inet()
+# object on load all the time.
+INETARRAY_OID = 1041
+INETARRAY = psycopg2.extensions.new_array_type(
+ (INETARRAY_OID,),
+ 'INETARRAY',
+ psycopg2.extensions.UNICODE,
+)
+psycopg2.extensions.register_type(INETARRAY)
+
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql'
diff --git a/django/db/backends/postgresql_psycopg2/features.py b/django/db/backends/postgresql_psycopg2/features.py
index 64acd0570a..6bb6de1a96 100644
--- a/django/db/backends/postgresql_psycopg2/features.py
+++ b/django/db/backends/postgresql_psycopg2/features.py
@@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
needs_datetime_string_cast = False
can_return_id_from_insert = True
has_real_datatype = True
+ has_native_uuid_field = True
has_native_duration_field = True
driver_supports_timedelta_args = True
can_defer_constraint_checks = True
diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py
index 8e90a4020b..27b19db459 100644
--- a/django/db/backends/postgresql_psycopg2/operations.py
+++ b/django/db/backends/postgresql_psycopg2/operations.py
@@ -3,6 +3,8 @@ from __future__ import unicode_literals
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
+from psycopg2.extras import Inet
+
class DatabaseOperations(BaseDatabaseOperations):
def unification_cast_sql(self, output_field):
@@ -57,13 +59,16 @@ class DatabaseOperations(BaseDatabaseOperations):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
- def lookup_cast(self, lookup_type):
+ def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
- lookup = "%s::text"
+ if internal_type in ('IPAddressField', 'GenericIPAddressField'):
+ lookup = "HOST(%s)"
+ else:
+ lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
@@ -71,11 +76,6 @@ class DatabaseOperations(BaseDatabaseOperations):
return lookup
- def field_cast_sql(self, db_type, internal_type):
- if internal_type == "GenericIPAddressField" or internal_type == "IPAddressField":
- return 'HOST(%s)'
- return '%s'
-
def last_insert_id(self, cursor, table_name, pk_name):
# Use pg_get_serial_sequence to get the underlying sequence name
# from the table name and column name (available since PostgreSQL 8)
@@ -224,3 +224,17 @@ class DatabaseOperations(BaseDatabaseOperations):
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
+
+ def value_to_db_date(self, value):
+ return value
+
+ def value_to_db_datetime(self, value):
+ return value
+
+ def value_to_db_time(self, value):
+ return value
+
+ def value_to_db_ipaddress(self, value):
+ if value:
+ return Inet(value)
+ return None
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 03c7eafac6..d5dfac733f 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -1983,7 +1983,7 @@ class GenericIPAddressField(Field):
def get_db_prep_value(self, value, connection, prepared=False):
if not prepared:
value = self.get_prep_value(value)
- return value or None
+ return connection.ops.value_to_db_ipaddress(value)
def get_prep_value(self, value):
value = super(GenericIPAddressField, self).get_prep_value(value)
@@ -2366,8 +2366,10 @@ class UUIDField(Field):
def get_internal_type(self):
return "UUIDField"
- def get_prep_value(self, value):
+ def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, uuid.UUID):
+ if connection.features.has_native_uuid_field:
+ return value
return value.hex
if isinstance(value, six.string_types):
return value.replace('-', '')
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index d7423762f3..7610c0dde4 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -198,7 +198,7 @@ class BuiltinLookup(Lookup):
db_type = self.lhs.output_field.db_type(connection=connection)
lhs_sql = connection.ops.field_cast_sql(
db_type, field_internal_type) % lhs_sql
- lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
+ lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
return lhs_sql, params
def as_sql(self, compiler, connection):
diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py
index 897359c0cf..a9ce43cae6 100644
--- a/tests/model_fields/tests.py
+++ b/tests/model_fields/tests.py
@@ -695,6 +695,11 @@ class GenericIPAddressFieldTests(test.TestCase):
o = GenericIPAddress.objects.get()
self.assertIsNone(o.ip)
+ def test_save_load(self):
+ instance = GenericIPAddress.objects.create(ip='::1')
+ loaded = GenericIPAddress.objects.get()
+ self.assertEqual(loaded.ip, instance.ip)
+
class PromiseTest(test.TestCase):
def test_AutoField(self):
diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py
index bdde4a9bf6..841953d351 100644
--- a/tests/postgres_tests/migrations/0002_create_test_models.py
+++ b/tests/postgres_tests/migrations/0002_create_test_models.py
@@ -27,7 +27,9 @@ class Migration(migrations.Migration):
name='DateTimeArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
- ('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)),
+ ('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)),
+ ('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)),
+ ('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)),
],
options={
},
@@ -44,6 +46,18 @@ class Migration(migrations.Migration):
bases=(models.Model,),
),
migrations.CreateModel(
+ name='OtherTypesArrayModel',
+ fields=[
+ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+ ('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)),
+ ('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)),
+ ('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
+ ],
+ options={
+ },
+ bases=(models.Model,),
+ ),
+ migrations.CreateModel(
name='IntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py
index 74af39dd04..0422aba6a0 100644
--- a/tests/postgres_tests/models.py
+++ b/tests/postgres_tests/models.py
@@ -18,13 +18,21 @@ class CharArrayModel(models.Model):
class DateTimeArrayModel(models.Model):
- field = ArrayField(models.DateTimeField())
+ datetimes = ArrayField(models.DateTimeField())
+ dates = ArrayField(models.DateField())
+ times = ArrayField(models.TimeField())
class NestedIntegerArrayModel(models.Model):
field = ArrayField(ArrayField(models.IntegerField()))
+class OtherTypesArrayModel(models.Model):
+ ips = ArrayField(models.GenericIPAddressField())
+ uuids = ArrayField(models.UUIDField())
+ decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
+
+
class HStoreModel(models.Model):
field = HStoreField(blank=True, null=True)
diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py
index 90f4c246c6..5c300f7ea3 100644
--- a/tests/postgres_tests/test_array.py
+++ b/tests/postgres_tests/test_array.py
@@ -1,5 +1,7 @@
+import decimal
import json
import unittest
+import uuid
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField
@@ -10,7 +12,11 @@ from django import forms
from django.test import TestCase, override_settings
from django.utils import timezone
-from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel, ArrayFieldSubclass
+from .models import (
+ IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel,
+ DateTimeArrayModel, NestedIntegerArrayModel, OtherTypesArrayModel,
+ ArrayFieldSubclass,
+)
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
@@ -29,10 +35,16 @@ class TestSaveLoad(TestCase):
self.assertEqual(instance.field, loaded.field)
def test_dates(self):
- instance = DateTimeArrayModel(field=[timezone.now()])
+ instance = DateTimeArrayModel(
+ datetimes=[timezone.now()],
+ dates=[timezone.now().date()],
+ times=[timezone.now().time()],
+ )
instance.save()
loaded = DateTimeArrayModel.objects.get()
- self.assertEqual(instance.field, loaded.field)
+ self.assertEqual(instance.datetimes, loaded.datetimes)
+ self.assertEqual(instance.dates, loaded.dates)
+ self.assertEqual(instance.times, loaded.times)
def test_tuples(self):
instance = IntegerArrayModel(field=(1,))
@@ -70,6 +82,18 @@ class TestSaveLoad(TestCase):
loaded = NestedIntegerArrayModel.objects.get()
self.assertEqual(instance.field, loaded.field)
+ def test_other_array_types(self):
+ instance = OtherTypesArrayModel(
+ ips=['192.168.0.1', '::1'],
+ uuids=[uuid.uuid4()],
+ decimals=[decimal.Decimal(1.25), 1.75],
+ )
+ instance.save()
+ loaded = OtherTypesArrayModel.objects.get()
+ self.assertEqual(instance.ips, loaded.ips)
+ self.assertEqual(instance.uuids, loaded.uuids)
+ self.assertEqual(instance.decimals, loaded.decimals)
+
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
class TestQuerying(TestCase):