summaryrefslogtreecommitdiff
path: root/django/db/backends/postgresql
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2022-12-01 20:23:43 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-12-15 06:17:57 +0100
commit09ffc5c1212d4ced58b708cbbf3dfbfb77b782ca (patch)
tree15bb8bb049f9339f30d637e78b340473c2038126 /django/db/backends/postgresql
parentd44ee518c4c110af25bebdbedbbf9fba04d197aa (diff)
Fixed #33308 -- Added support for psycopg version 3.
Thanks Simon Charette, Tim Graham, and Adam Johnson for reviews. Co-authored-by: Florian Apolloner <florian@apolloner.eu> Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Diffstat (limited to 'django/db/backends/postgresql')
-rw-r--r--django/db/backends/postgresql/base.py162
-rw-r--r--django/db/backends/postgresql/features.py11
-rw-r--r--django/db/backends/postgresql/operations.py90
-rw-r--r--django/db/backends/postgresql/psycopg_any.py113
-rw-r--r--django/db/backends/postgresql/schema.py8
5 files changed, 291 insertions, 93 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index 0aee39aa5c..ceea1bebad 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -1,7 +1,7 @@
"""
PostgreSQL database backend for Django.
-Requires psycopg 2: https://www.psycopg.org/
+Requires psycopg2 >= 2.8.4 or psycopg >= 3.1
"""
import asyncio
@@ -21,48 +21,63 @@ from django.utils.safestring import SafeString
from django.utils.version import get_version_tuple
try:
- import psycopg2 as Database
- import psycopg2.extensions
- import psycopg2.extras
-except ImportError as e:
- raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
+ try:
+ import psycopg as Database
+ except ImportError:
+ import psycopg2 as Database
+except ImportError:
+ raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
-def psycopg2_version():
- version = psycopg2.__version__.split(" ", 1)[0]
+def psycopg_version():
+ version = Database.__version__.split(" ", 1)[0]
return get_version_tuple(version)
-PSYCOPG2_VERSION = psycopg2_version()
-
-if PSYCOPG2_VERSION < (2, 8, 4):
+if psycopg_version() < (2, 8, 4):
+ raise ImproperlyConfigured(
+ f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
+ )
+if (3,) <= psycopg_version() < (3, 1):
raise ImproperlyConfigured(
- "psycopg2 version 2.8.4 or newer is required; you have %s"
- % psycopg2.__version__
+ f"psycopg version 3.1 or newer is required; you have {Database.__version__}"
)
-# Some of these import psycopg2, so import them after checking if it's installed.
-from .client import DatabaseClient # NOQA
-from .creation import DatabaseCreation # NOQA
-from .features import DatabaseFeatures # NOQA
-from .introspection import DatabaseIntrospection # NOQA
-from .operations import DatabaseOperations # NOQA
-from .psycopg_any import IsolationLevel # NOQA
-from .schema import DatabaseSchemaEditor # NOQA
+from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
+
+if is_psycopg3:
+ from psycopg import adapters, sql
+ from psycopg.pq import Format
-psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
-psycopg2.extras.register_uuid()
+ from .psycopg_any import get_adapters_template, register_tzloader
-# Register support for inet[] manually so we don't have to handle the Inet()
-# object on load all the time.
-INETARRAY_OID = 1041
-INETARRAY = psycopg2.extensions.new_array_type(
- (INETARRAY_OID,),
- "INETARRAY",
- psycopg2.extensions.UNICODE,
-)
-psycopg2.extensions.register_type(INETARRAY)
+ TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
+
+else:
+ import psycopg2.extensions
+ import psycopg2.extras
+
+ psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
+ psycopg2.extras.register_uuid()
+
+ # Register support for inet[] manually so we don't have to handle the Inet()
+ # object on load all the time.
+ INETARRAY_OID = 1041
+ INETARRAY = psycopg2.extensions.new_array_type(
+ (INETARRAY_OID,),
+ "INETARRAY",
+ psycopg2.extensions.UNICODE,
+ )
+ psycopg2.extensions.register_type(INETARRAY)
+
+# Some of these import psycopg, so import them after checking if it's installed.
+from .client import DatabaseClient # NOQA isort:skip
+from .creation import DatabaseCreation # NOQA isort:skip
+from .features import DatabaseFeatures # NOQA isort:skip
+from .introspection import DatabaseIntrospection # NOQA isort:skip
+from .operations import DatabaseOperations # NOQA isort:skip
+from .schema import DatabaseSchemaEditor # NOQA isort:skip
class DatabaseWrapper(BaseDatabaseWrapper):
@@ -209,6 +224,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn_params["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
conn_params["port"] = settings_dict["PORT"]
+ if is_psycopg3:
+ conn_params["context"] = get_adapters_template(
+ settings.USE_TZ, self.timezone
+ )
+ # Disable prepared statements by default to keep connection poolers
+ # working. Can be reenabled via OPTIONS in the settings dict.
+ conn_params["prepare_threshold"] = conn_params.pop(
+ "prepare_threshold", None
+ )
return conn_params
@async_unsafe
@@ -232,17 +256,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
except ValueError:
raise ImproperlyConfigured(
f"Invalid transaction isolation level {isolation_level_value} "
- f"specified. Use one of the IsolationLevel values."
+ f"specified. Use one of the psycopg.IsolationLevel values."
)
- connection = Database.connect(**conn_params)
+ connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = self.isolation_level
- # Register dummy loads() to avoid a round trip from psycopg2's decode
- # to json.dumps() to json.loads(), when using a custom decoder in
- # JSONField.
- psycopg2.extras.register_default_jsonb(
- conn_or_curs=connection, loads=lambda x: x
- )
+ if not is_psycopg3:
+ # Register dummy loads() to avoid a round trip from psycopg2's
+ # decode to json.dumps() to json.loads(), when using a custom
+ # decoder in JSONField.
+ psycopg2.extras.register_default_jsonb(
+ conn_or_curs=connection, loads=lambda x: x
+ )
+ connection.cursor_factory = Cursor
return connection
def ensure_timezone(self):
@@ -275,7 +301,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
else:
cursor = self.connection.cursor()
- cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
+
+ if is_psycopg3:
+ # Register the cursor timezone only if the connection disagrees, to
+ # avoid copying the adapter map.
+ tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
+ if self.timezone != tzloader.timezone:
+ register_tzloader(self.timezone, cursor)
+ else:
+ cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
return cursor
def tzinfo_factory(self, offset):
@@ -379,11 +413,43 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return CursorDebugWrapper(cursor, self)
-class CursorDebugWrapper(BaseCursorDebugWrapper):
- def copy_expert(self, sql, file, *args):
- with self.debug_sql(sql):
- return self.cursor.copy_expert(sql, file, *args)
+if is_psycopg3:
+
+ class Cursor(Database.Cursor):
+ """
+ A subclass of psycopg cursor implementing callproc.
+ """
+
+ def callproc(self, name, args=None):
+ if not isinstance(name, sql.Identifier):
+ name = sql.Identifier(name)
+
+ qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
+ if args:
+ for item in args:
+ qparts.append(sql.Literal(item))
+ qparts.append(sql.SQL(","))
+ del qparts[-1]
+
+ qparts.append(sql.SQL(")"))
+ stmt = sql.Composed(qparts)
+ self.execute(stmt)
+ return args
+
+ class CursorDebugWrapper(BaseCursorDebugWrapper):
+ def copy(self, statement):
+ with self.debug_sql(statement):
+ return self.cursor.copy(statement)
+
+else:
+
+ Cursor = psycopg2.extensions.cursor
+
+ class CursorDebugWrapper(BaseCursorDebugWrapper):
+ def copy_expert(self, sql, file, *args):
+ with self.debug_sql(sql):
+ return self.cursor.copy_expert(sql, file, *args)
- def copy_to(self, file, table, *args, **kwargs):
- with self.debug_sql(sql="COPY %s TO STDOUT" % table):
- return self.cursor.copy_to(file, table, *args, **kwargs)
+ def copy_to(self, file, table, *args, **kwargs):
+ with self.debug_sql(sql="COPY %s TO STDOUT" % table):
+ return self.cursor.copy_to(file, table, *args, **kwargs)
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 0eed8c8d63..fd5b05aad4 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -1,7 +1,8 @@
import operator
-from django.db import InterfaceError
+from django.db import DataError, InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
+from django.db.backends.postgresql.psycopg_any import is_psycopg3
from django.utils.functional import cached_property
@@ -26,6 +27,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_introspect_materialized_views = True
can_distinct_on_fields = True
can_rollback_ddl = True
+ schema_editor_uses_clientside_param_binding = True
supports_combined_alters = True
nulls_order_largest = True
closed_cursor_error_class = InterfaceError
@@ -82,6 +84,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
}
@cached_property
+ def prohibits_null_characters_in_text_exception(self):
+ if is_psycopg3:
+ return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes"
+ else:
+ return ValueError, "A string literal cannot contain NUL (0x00) characters."
+
+ @cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 824e0c3e4b..18cfcb29cb 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -3,9 +3,16 @@ from functools import lru_cache, partial
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
-from django.db.backends.postgresql.psycopg_any import Inet, Jsonb, mogrify
+from django.db.backends.postgresql.psycopg_any import (
+ Inet,
+ Jsonb,
+ errors,
+ is_psycopg3,
+ mogrify,
+)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
+from django.utils.regex_helper import _lazy_re_compile
@lru_cache
@@ -36,6 +43,18 @@ class DatabaseOperations(BaseDatabaseOperations):
"SmallAutoField": "smallint",
}
+ if is_psycopg3:
+ from psycopg.types import numeric
+
+ integerfield_type_map = {
+ "SmallIntegerField": numeric.Int2,
+ "IntegerField": numeric.Int4,
+ "BigIntegerField": numeric.Int8,
+ "PositiveSmallIntegerField": numeric.Int2,
+ "PositiveIntegerField": numeric.Int4,
+ "PositiveBigIntegerField": numeric.Int8,
+ }
+
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in (
@@ -56,19 +75,23 @@ class DatabaseOperations(BaseDatabaseOperations):
)
return "%s"
+ # EXTRACT format cannot be passed in parameters.
+ _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
+
def date_extract_sql(self, lookup_type, sql, params):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
- extract_sql = f"EXTRACT(%s FROM {sql})"
- extract_param = lookup_type
if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7.
- extract_sql = f"EXTRACT(%s FROM {sql}) + 1"
- extract_param = "dow"
+ return f"EXTRACT(DOW FROM {sql}) + 1", params
elif lookup_type == "iso_week_day":
- extract_param = "isodow"
+ return f"EXTRACT(ISODOW FROM {sql})", params
elif lookup_type == "iso_year":
- extract_param = "isoyear"
- return extract_sql, (extract_param, *params)
+ return f"EXTRACT(ISOYEAR FROM {sql})", params
+
+ lookup_type = lookup_type.upper()
+ if not self._extract_format_re.fullmatch(lookup_type):
+ raise ValueError(f"Invalid lookup type: {lookup_type!r}")
+ return f"EXTRACT({lookup_type} FROM {sql})", params
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
@@ -100,10 +123,7 @@ class DatabaseOperations(BaseDatabaseOperations):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
if lookup_type == "second":
# Truncate fractional seconds.
- return (
- f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
- ("second", "second", *params),
- )
+ return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
@@ -114,10 +134,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def time_extract_sql(self, lookup_type, sql, params):
if lookup_type == "second":
# Truncate fractional seconds.
- return (
- f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
- ("second", "second", *params),
- )
+ return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
return self.date_extract_sql(lookup_type, sql, params)
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
@@ -137,6 +154,16 @@ class DatabaseOperations(BaseDatabaseOperations):
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
+ if lookup_type == "isnull" and internal_type in (
+ "CharField",
+ "EmailField",
+ "TextField",
+ "CICharField",
+ "CIEmailField",
+ "CITextField",
+ ):
+ return "%s::text"
+
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in (
"iexact",
@@ -178,7 +205,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return mogrify(sql, params, self.connection)
def set_time_zone_sql(self):
- return "SET TIME ZONE %s"
+ return "SELECT set_config('TimeZone', %s, false)"
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
@@ -278,12 +305,22 @@ class DatabaseOperations(BaseDatabaseOperations):
else:
return ["DISTINCT"], []
- def last_executed_query(self, cursor, sql, params):
- # https://www.psycopg.org/docs/cursor.html#cursor.query
- # The query attribute is a Psycopg extension to the DB API 2.0.
- if cursor.query is not None:
- return cursor.query.decode()
- return None
+ if is_psycopg3:
+
+ def last_executed_query(self, cursor, sql, params):
+ try:
+ return self.compose_sql(sql, params)
+ except errors.DataError:
+ return None
+
+ else:
+
+ def last_executed_query(self, cursor, sql, params):
+ # https://www.psycopg.org/docs/cursor.html#cursor.query
+ # The query attribute is a Psycopg extension to the DB API 2.0.
+ if cursor.query is not None:
+ return cursor.query.decode()
+ return None
def return_insert_columns(self, fields):
if not fields:
@@ -303,6 +340,13 @@ class DatabaseOperations(BaseDatabaseOperations):
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
+ if is_psycopg3:
+
+ def adapt_integerfield_value(self, value, internal_type):
+ if value is None or hasattr(value, "resolve_expression"):
+ return value
+ return self.integerfield_type_map[internal_type](value)
+
def adapt_datefield_value(self, value):
return value
diff --git a/django/db/backends/postgresql/psycopg_any.py b/django/db/backends/postgresql/psycopg_any.py
index e9bb84f313..579104dead 100644
--- a/django/db/backends/postgresql/psycopg_any.py
+++ b/django/db/backends/postgresql/psycopg_any.py
@@ -1,31 +1,102 @@
-from enum import IntEnum
+import ipaddress
+from functools import lru_cache
-from psycopg2 import errors, extensions, sql # NOQA
-from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet # NOQA
-from psycopg2.extras import Json as Jsonb # NOQA
-from psycopg2.extras import NumericRange, Range # NOQA
+try:
+ from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
+ from psycopg.postgres import types
+ from psycopg.types.datetime import TimestamptzLoader
+ from psycopg.types.json import Jsonb
+ from psycopg.types.range import Range, RangeDumper
+ from psycopg.types.string import TextLoader
-RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
+ Inet = ipaddress.ip_address
+ DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
+ RANGE_TYPES = (Range,)
-class IsolationLevel(IntEnum):
- READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
- READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
- REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
- SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
+ TSRANGE_OID = types["tsrange"].oid
+ TSTZRANGE_OID = types["tstzrange"].oid
+ def mogrify(sql, params, connection):
+ return ClientCursor(connection.connection).mogrify(sql, params)
-def _quote(value, connection=None):
- adapted = extensions.adapt(value)
- if hasattr(adapted, "encoding"):
- adapted.encoding = "utf8"
- # getquoted() returns a quoted bytestring of the adapted value.
- return adapted.getquoted().decode()
+ # Adapters.
+ class BaseTzLoader(TimestamptzLoader):
+ """
+ Load a PostgreSQL timestamptz using the a specific timezone.
+ The timezone can be None too, in which case it will be chopped.
+ """
+ timezone = None
-sql.quote = _quote
+ def load(self, data):
+ res = super().load(data)
+ return res.replace(tzinfo=self.timezone)
+ def register_tzloader(tz, context):
+ class SpecificTzLoader(BaseTzLoader):
+ timezone = tz
-def mogrify(sql, params, connection):
- with connection.cursor() as cursor:
- return cursor.mogrify(sql, params).decode()
+ context.adapters.register_loader("timestamptz", SpecificTzLoader)
+
+ class DjangoRangeDumper(RangeDumper):
+ """A Range dumper customized for Django."""
+
+ def upgrade(self, obj, format):
+ # Dump ranges containing naive datetimes as tstzrange, because
+ # Django doesn't use tz-aware ones.
+ dumper = super().upgrade(obj, format)
+ if dumper is not self and dumper.oid == TSRANGE_OID:
+ dumper.oid = TSTZRANGE_OID
+ return dumper
+
+ @lru_cache
+ def get_adapters_template(use_tz, timezone):
+ # Create at adapters map extending the base one.
+ ctx = adapt.AdaptersMap(adapters)
+ # Register a no-op dumper to avoid a round trip from psycopg version 3
+ # decode to json.dumps() to json.loads(), when using a custom decoder
+ # in JSONField.
+ ctx.register_loader("jsonb", TextLoader)
+ # Don't convert automatically from PostgreSQL network types to Python
+ # ipaddress.
+ ctx.register_loader("inet", TextLoader)
+ ctx.register_loader("cidr", TextLoader)
+ ctx.register_dumper(Range, DjangoRangeDumper)
+ # Register a timestamptz loader configured on self.timezone.
+ # This, however, can be overridden by create_cursor.
+ register_tzloader(timezone, ctx)
+ return ctx
+
+ is_psycopg3 = True
+
+except ImportError:
+ from enum import IntEnum
+
+ from psycopg2 import errors, extensions, sql # NOQA
+ from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet # NOQA
+ from psycopg2.extras import Json as Jsonb # NOQA
+ from psycopg2.extras import NumericRange, Range # NOQA
+
+ RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
+
+ class IsolationLevel(IntEnum):
+ READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
+ READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
+ REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
+ SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
+
+ def _quote(value, connection=None):
+ adapted = extensions.adapt(value)
+ if hasattr(adapted, "encoding"):
+ adapted.encoding = "utf8"
+ # getquoted() returns a quoted bytestring of the adapted value.
+ return adapted.getquoted().decode()
+
+ sql.quote = _quote
+
+ def mogrify(sql, params, connection):
+ with connection.cursor() as cursor:
+ return cursor.mogrify(sql, params).decode()
+
+ is_psycopg3 = False
diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py
index cc0da85817..1bd72bc0cb 100644
--- a/django/db/backends/postgresql/schema.py
+++ b/django/db/backends/postgresql/schema.py
@@ -40,6 +40,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
)
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
+ def execute(self, sql, params=()):
+ # Merge the query client-side, as PostgreSQL won't do it server-side.
+ if params is None:
+ return super().execute(sql, params)
+ sql = self.connection.ops.compose_sql(str(sql), params)
+ # Don't let the superclass touch anything.
+ return super().execute(sql, None)
+
sql_add_identity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
"GENERATED BY DEFAULT AS IDENTITY"