| author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2022-12-01 20:23:43 +0100 |
|---|---|---|
| committer | Mariusz Felisiak <felisiak.mariusz@gmail.com> | 2022-12-15 06:17:57 +0100 |
| commit | 09ffc5c1212d4ced58b708cbbf3dfbfb77b782ca | |
| tree | 15bb8bb049f9339f30d637e78b340473c2038126 /django/db/backends/postgresql | |
| parent | d44ee518c4c110af25bebdbedbbf9fba04d197aa | |
Fixed #33308 -- Added support for psycopg version 3.
Thanks Simon Charette, Tim Graham, and Adam Johnson for reviews.
Co-authored-by: Florian Apolloner <florian@apolloner.eu>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Diffstat (limited to 'django/db/backends/postgresql')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | django/db/backends/postgresql/base.py | 162 |
| -rw-r--r-- | django/db/backends/postgresql/features.py | 11 |
| -rw-r--r-- | django/db/backends/postgresql/operations.py | 90 |
| -rw-r--r-- | django/db/backends/postgresql/psycopg_any.py | 113 |
| -rw-r--r-- | django/db/backends/postgresql/schema.py | 8 |
5 files changed, 291 insertions, 93 deletions
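
A note for anyone trying the patch: the settings `ENGINE` does not change. The backend picks the driver at import time, preferring psycopg 3 and falling back to psycopg2 (see `base.py` below), and under psycopg 3 it disables server-side prepared statements by default for connection-pooler compatibility. A hedged sketch of a settings entry, with hypothetical credentials and an illustrative opt-in threshold:

```python
# settings.py -- a sketch, not from the patch. NAME/USER/PASSWORD are
# hypothetical; ENGINE is unchanged regardless of which driver is installed.
DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": "mydb",
        "USER": "myuser",
        "PASSWORD": "s3cret",
        "OPTIONS": {
            # Only meaningful under psycopg 3: re-enable automatic server-side
            # prepared statements (the patch defaults prepare_threshold to
            # None, i.e. disabled, to keep connection poolers working).
            "prepare_threshold": 5,
        },
    }
}
```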
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index 0aee39aa5c..ceea1bebad 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -1,7 +1,7 @@
 """
 PostgreSQL database backend for Django.
 
-Requires psycopg 2: https://www.psycopg.org/
+Requires psycopg2 >= 2.8.4 or psycopg >= 3.1
 """
 
 import asyncio
@@ -21,48 +21,63 @@ from django.utils.safestring import SafeString
 from django.utils.version import get_version_tuple
 
 try:
-    import psycopg2 as Database
-    import psycopg2.extensions
-    import psycopg2.extras
-except ImportError as e:
-    raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
+    try:
+        import psycopg as Database
+    except ImportError:
+        import psycopg2 as Database
+except ImportError:
+    raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
 
 
-def psycopg2_version():
-    version = psycopg2.__version__.split(" ", 1)[0]
+def psycopg_version():
+    version = Database.__version__.split(" ", 1)[0]
     return get_version_tuple(version)
 
 
-PSYCOPG2_VERSION = psycopg2_version()
-
-if PSYCOPG2_VERSION < (2, 8, 4):
+if psycopg_version() < (2, 8, 4):
+    raise ImproperlyConfigured(
+        f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
+    )
+if (3,) <= psycopg_version() < (3, 1):
     raise ImproperlyConfigured(
-        "psycopg2 version 2.8.4 or newer is required; you have %s"
-        % psycopg2.__version__
+        f"psycopg version 3.1 or newer is required; you have {Database.__version__}"
     )
 
-# Some of these import psycopg2, so import them after checking if it's installed.
-from .client import DatabaseClient  # NOQA
-from .creation import DatabaseCreation  # NOQA
-from .features import DatabaseFeatures  # NOQA
-from .introspection import DatabaseIntrospection  # NOQA
-from .operations import DatabaseOperations  # NOQA
-from .psycopg_any import IsolationLevel  # NOQA
-from .schema import DatabaseSchemaEditor  # NOQA
+from .psycopg_any import IsolationLevel, is_psycopg3  # NOQA isort:skip
+
+if is_psycopg3:
+    from psycopg import adapters, sql
+    from psycopg.pq import Format
 
-psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
-psycopg2.extras.register_uuid()
+    from .psycopg_any import get_adapters_template, register_tzloader
 
-# Register support for inet[] manually so we don't have to handle the Inet()
-# object on load all the time.
-INETARRAY_OID = 1041
-INETARRAY = psycopg2.extensions.new_array_type(
-    (INETARRAY_OID,),
-    "INETARRAY",
-    psycopg2.extensions.UNICODE,
-)
-psycopg2.extensions.register_type(INETARRAY)
+    TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
+
+else:
+    import psycopg2.extensions
+    import psycopg2.extras
+
+    psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
+    psycopg2.extras.register_uuid()
+
+    # Register support for inet[] manually so we don't have to handle the Inet()
+    # object on load all the time.
+    INETARRAY_OID = 1041
+    INETARRAY = psycopg2.extensions.new_array_type(
+        (INETARRAY_OID,),
+        "INETARRAY",
+        psycopg2.extensions.UNICODE,
+    )
+    psycopg2.extensions.register_type(INETARRAY)
+
+# Some of these import psycopg, so import them after checking if it's installed.
+from .client import DatabaseClient  # NOQA isort:skip
+from .creation import DatabaseCreation  # NOQA isort:skip
+from .features import DatabaseFeatures  # NOQA isort:skip
+from .introspection import DatabaseIntrospection  # NOQA isort:skip
+from .operations import DatabaseOperations  # NOQA isort:skip
+from .schema import DatabaseSchemaEditor  # NOQA isort:skip
 
 
 class DatabaseWrapper(BaseDatabaseWrapper):
@@ -209,6 +224,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
             conn_params["host"] = settings_dict["HOST"]
         if settings_dict["PORT"]:
             conn_params["port"] = settings_dict["PORT"]
+        if is_psycopg3:
+            conn_params["context"] = get_adapters_template(
+                settings.USE_TZ, self.timezone
+            )
+            # Disable prepared statements by default to keep connection poolers
+            # working. Can be reenabled via OPTIONS in the settings dict.
+            conn_params["prepare_threshold"] = conn_params.pop(
+                "prepare_threshold", None
+            )
         return conn_params
 
     @async_unsafe
@@ -232,17 +256,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
             except ValueError:
                 raise ImproperlyConfigured(
                     f"Invalid transaction isolation level {isolation_level_value} "
-                    f"specified. Use one of the IsolationLevel values."
+                    f"specified. Use one of the psycopg.IsolationLevel values."
                 )
-        connection = Database.connect(**conn_params)
+        connection = self.Database.connect(**conn_params)
         if set_isolation_level:
             connection.isolation_level = self.isolation_level
-        # Register dummy loads() to avoid a round trip from psycopg2's decode
-        # to json.dumps() to json.loads(), when using a custom decoder in
-        # JSONField.
-        psycopg2.extras.register_default_jsonb(
-            conn_or_curs=connection, loads=lambda x: x
-        )
+        if not is_psycopg3:
+            # Register dummy loads() to avoid a round trip from psycopg2's
+            # decode to json.dumps() to json.loads(), when using a custom
+            # decoder in JSONField.
+            psycopg2.extras.register_default_jsonb(
+                conn_or_curs=connection, loads=lambda x: x
+            )
+        connection.cursor_factory = Cursor
         return connection
 
     def ensure_timezone(self):
@@ -275,7 +301,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
             )
         else:
             cursor = self.connection.cursor()
-        cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
+
+        if is_psycopg3:
+            # Register the cursor timezone only if the connection disagrees, to
+            # avoid copying the adapter map.
+            tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
+            if self.timezone != tzloader.timezone:
+                register_tzloader(self.timezone, cursor)
+        else:
+            cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
         return cursor
 
     def tzinfo_factory(self, offset):
@@ -379,11 +413,43 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         return CursorDebugWrapper(cursor, self)
 
 
-class CursorDebugWrapper(BaseCursorDebugWrapper):
-    def copy_expert(self, sql, file, *args):
-        with self.debug_sql(sql):
-            return self.cursor.copy_expert(sql, file, *args)
+if is_psycopg3:
+
+    class Cursor(Database.Cursor):
+        """
+        A subclass of psycopg cursor implementing callproc.
+        """
+
+        def callproc(self, name, args=None):
+            if not isinstance(name, sql.Identifier):
+                name = sql.Identifier(name)
+
+            qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
+            if args:
+                for item in args:
+                    qparts.append(sql.Literal(item))
+                    qparts.append(sql.SQL(","))
+                del qparts[-1]
+
+            qparts.append(sql.SQL(")"))
+            stmt = sql.Composed(qparts)
+            self.execute(stmt)
+            return args
+
+    class CursorDebugWrapper(BaseCursorDebugWrapper):
+        def copy(self, statement):
+            with self.debug_sql(statement):
+                return self.cursor.copy(statement)
+
+else:
+
+    Cursor = psycopg2.extensions.cursor
+
+    class CursorDebugWrapper(BaseCursorDebugWrapper):
+        def copy_expert(self, sql, file, *args):
+            with self.debug_sql(sql):
+                return self.cursor.copy_expert(sql, file, *args)
 
-    def copy_to(self, file, table, *args, **kwargs):
-        with self.debug_sql(sql="COPY %s TO STDOUT" % table):
-            return self.cursor.copy_to(file, table, *args, **kwargs)
+        def copy_to(self, file, table, *args, **kwargs):
+            with self.debug_sql(sql="COPY %s TO STDOUT" % table):
+                return self.cursor.copy_to(file, table, *args, **kwargs)
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 0eed8c8d63..fd5b05aad4 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -1,7 +1,8 @@
 import operator
 
-from django.db import InterfaceError
+from django.db import DataError, InterfaceError
 from django.db.backends.base.features import BaseDatabaseFeatures
+from django.db.backends.postgresql.psycopg_any import is_psycopg3
 from django.utils.functional import cached_property
 
 
@@ -26,6 +27,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     can_introspect_materialized_views = True
     can_distinct_on_fields = True
     can_rollback_ddl = True
+    schema_editor_uses_clientside_param_binding = True
     supports_combined_alters = True
     nulls_order_largest = True
     closed_cursor_error_class = InterfaceError
@@ -82,6 +84,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     }
 
     @cached_property
+    def prohibits_null_characters_in_text_exception(self):
+        if is_psycopg3:
+            return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes"
+        else:
+            return ValueError, "A string literal cannot contain NUL (0x00) characters."
+
+    @cached_property
     def introspected_field_types(self):
         return {
             **super().introspected_field_types,
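
The new `prohibits_null_characters_in_text_exception` property pairs an exception class with its message because the two drivers reject NUL bytes differently: psycopg 3 surfaces a database `DataError`, while psycopg2 raises `ValueError` while composing the query client-side. A hedged sketch of a driver-agnostic assertion built on it (the model is hypothetical):

```python
# Sketch: assert the driver-appropriate error, whichever driver is in use.
# Assumes a configured Django project; Note is a hypothetical model.
from django.db import connection

exc_class, message = connection.features.prohibits_null_characters_in_text_exception
try:
    Note.objects.create(body="abc\x00def")
except exc_class as exc:
    assert message in str(exc)
```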
- return ( - f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))", - ("second", "second", *params), - ) + return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params) return self.date_extract_sql(lookup_type, sql, params) def time_trunc_sql(self, lookup_type, sql, params, tzname=None): @@ -137,6 +154,16 @@ class DatabaseOperations(BaseDatabaseOperations): def lookup_cast(self, lookup_type, internal_type=None): lookup = "%s" + if lookup_type == "isnull" and internal_type in ( + "CharField", + "EmailField", + "TextField", + "CICharField", + "CIEmailField", + "CITextField", + ): + return "%s::text" + # Cast text lookups to text to allow things like filter(x__contains=4) if lookup_type in ( "iexact", @@ -178,7 +205,7 @@ class DatabaseOperations(BaseDatabaseOperations): return mogrify(sql, params, self.connection) def set_time_zone_sql(self): - return "SET TIME ZONE %s" + return "SELECT set_config('TimeZone', %s, false)" def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): if not tables: @@ -278,12 +305,22 @@ class DatabaseOperations(BaseDatabaseOperations): else: return ["DISTINCT"], [] - def last_executed_query(self, cursor, sql, params): - # https://www.psycopg.org/docs/cursor.html#cursor.query - # The query attribute is a Psycopg extension to the DB API 2.0. - if cursor.query is not None: - return cursor.query.decode() - return None + if is_psycopg3: + + def last_executed_query(self, cursor, sql, params): + try: + return self.compose_sql(sql, params) + except errors.DataError: + return None + + else: + + def last_executed_query(self, cursor, sql, params): + # https://www.psycopg.org/docs/cursor.html#cursor.query + # The query attribute is a Psycopg extension to the DB API 2.0. + if cursor.query is not None: + return cursor.query.decode() + return None def return_insert_columns(self, fields): if not fields: @@ -303,6 +340,13 @@ class DatabaseOperations(BaseDatabaseOperations): values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) return "VALUES " + values_sql + if is_psycopg3: + + def adapt_integerfield_value(self, value, internal_type): + if value is None or hasattr(value, "resolve_expression"): + return value + return self.integerfield_type_map[internal_type](value) + def adapt_datefield_value(self, value): return value diff --git a/django/db/backends/postgresql/psycopg_any.py b/django/db/backends/postgresql/psycopg_any.py index e9bb84f313..579104dead 100644 --- a/django/db/backends/postgresql/psycopg_any.py +++ b/django/db/backends/postgresql/psycopg_any.py @@ -1,31 +1,102 @@ -from enum import IntEnum +import ipaddress +from functools import lru_cache -from psycopg2 import errors, extensions, sql # NOQA -from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet # NOQA -from psycopg2.extras import Json as Jsonb # NOQA -from psycopg2.extras import NumericRange, Range # NOQA +try: + from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql + from psycopg.postgres import types + from psycopg.types.datetime import TimestamptzLoader + from psycopg.types.json import Jsonb + from psycopg.types.range import Range, RangeDumper + from psycopg.types.string import TextLoader -RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange) + Inet = ipaddress.ip_address + DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range + RANGE_TYPES = (Range,) -class IsolationLevel(IntEnum): - READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED - READ_COMMITTED = 
diff --git a/django/db/backends/postgresql/psycopg_any.py b/django/db/backends/postgresql/psycopg_any.py
index e9bb84f313..579104dead 100644
--- a/django/db/backends/postgresql/psycopg_any.py
+++ b/django/db/backends/postgresql/psycopg_any.py
@@ -1,31 +1,102 @@
-from enum import IntEnum
+import ipaddress
+from functools import lru_cache
 
-from psycopg2 import errors, extensions, sql  # NOQA
-from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet  # NOQA
-from psycopg2.extras import Json as Jsonb  # NOQA
-from psycopg2.extras import NumericRange, Range  # NOQA
+try:
+    from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
+    from psycopg.postgres import types
+    from psycopg.types.datetime import TimestamptzLoader
+    from psycopg.types.json import Jsonb
+    from psycopg.types.range import Range, RangeDumper
+    from psycopg.types.string import TextLoader
 
-RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
+    Inet = ipaddress.ip_address
 
+    DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
 
-class IsolationLevel(IntEnum):
-    READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
-    READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
-    REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
-    SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
+    RANGE_TYPES = (Range,)
 
+    TSRANGE_OID = types["tsrange"].oid
+    TSTZRANGE_OID = types["tstzrange"].oid
 
-def _quote(value, connection=None):
-    adapted = extensions.adapt(value)
-    if hasattr(adapted, "encoding"):
-        adapted.encoding = "utf8"
-    # getquoted() returns a quoted bytestring of the adapted value.
-    return adapted.getquoted().decode()
+    def mogrify(sql, params, connection):
+        return ClientCursor(connection.connection).mogrify(sql, params)
 
+    # Adapters.
+    class BaseTzLoader(TimestamptzLoader):
+        """
+        Load a PostgreSQL timestamptz using a specific timezone.
+        The timezone can be None too, in which case it will be chopped.
+        """
 
-sql.quote = _quote
+        timezone = None
 
+        def load(self, data):
+            res = super().load(data)
+            return res.replace(tzinfo=self.timezone)
 
-def mogrify(sql, params, connection):
-    with connection.cursor() as cursor:
-        return cursor.mogrify(sql, params).decode()
+    def register_tzloader(tz, context):
+        class SpecificTzLoader(BaseTzLoader):
+            timezone = tz
+
+        context.adapters.register_loader("timestamptz", SpecificTzLoader)
+
+    class DjangoRangeDumper(RangeDumper):
+        """A Range dumper customized for Django."""
+
+        def upgrade(self, obj, format):
+            # Dump ranges containing naive datetimes as tstzrange, because
+            # Django doesn't use tz-aware ones.
+            dumper = super().upgrade(obj, format)
+            if dumper is not self and dumper.oid == TSRANGE_OID:
+                dumper.oid = TSTZRANGE_OID
+            return dumper
+
+    @lru_cache
+    def get_adapters_template(use_tz, timezone):
+        # Create an adapters map extending the base one.
+        ctx = adapt.AdaptersMap(adapters)
+        # Register a no-op loader to avoid a round trip from psycopg version 3
+        # decode to json.dumps() to json.loads(), when using a custom decoder
+        # in JSONField.
+        ctx.register_loader("jsonb", TextLoader)
+        # Don't convert automatically from PostgreSQL network types to Python
+        # ipaddress.
+        ctx.register_loader("inet", TextLoader)
+        ctx.register_loader("cidr", TextLoader)
+        ctx.register_dumper(Range, DjangoRangeDumper)
+        # Register a timestamptz loader configured on self.timezone.
+        # This, however, can be overridden by create_cursor.
+        register_tzloader(timezone, ctx)
+        return ctx
+
+    is_psycopg3 = True
+
+except ImportError:
+    from enum import IntEnum
+
+    from psycopg2 import errors, extensions, sql  # NOQA
+    from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, Inet  # NOQA
+    from psycopg2.extras import Json as Jsonb  # NOQA
+    from psycopg2.extras import NumericRange, Range  # NOQA
+
+    RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
+
+    class IsolationLevel(IntEnum):
+        READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
+        READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
+        REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
+        SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
+
+    def _quote(value, connection=None):
+        adapted = extensions.adapt(value)
+        if hasattr(adapted, "encoding"):
+            adapted.encoding = "utf8"
+        # getquoted() returns a quoted bytestring of the adapted value.
+        return adapted.getquoted().decode()
+
+    sql.quote = _quote
+
+    def mogrify(sql, params, connection):
+        with connection.cursor() as cursor:
+            return cursor.mogrify(sql, params).decode()
+
+    is_psycopg3 = False
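
`psycopg_any` gives the rest of the backend a single import surface over both drivers. For instance, `mogrify()` binds parameters client-side via `ClientCursor.mogrify()` on psycopg 3 and `cursor.mogrify()` on psycopg2. A hedged sketch, assuming a configured project and a reachable database:

```python
from django.db import connection
from django.db.backends.postgresql.psycopg_any import is_psycopg3, mogrify

connection.ensure_connection()  # mogrify needs an established connection
print(is_psycopg3)
print(mogrify("SELECT %s + %s", [1, 2], connection))
# -> SELECT 1 + 2   (exact quoting may differ between the drivers)
```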
diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py
index cc0da85817..1bd72bc0cb 100644
--- a/django/db/backends/postgresql/schema.py
+++ b/django/db/backends/postgresql/schema.py
@@ -40,6 +40,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
     )
 
     sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
 
+    def execute(self, sql, params=()):
+        # Merge the query client-side, as PostgreSQL won't do it server-side.
+        if params is None:
+            return super().execute(sql, params)
+        sql = self.connection.ops.compose_sql(str(sql), params)
+        # Don't let the superclass touch anything.
+        return super().execute(sql, None)
+
     sql_add_identity = (
         "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
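
The `execute()` override exists because PostgreSQL will not bind parameters in most DDL statements; with the `schema_editor_uses_clientside_param_binding` flag set in `features.py` above, parameters are merged into the statement client-side through `compose_sql()` from `operations.py`. An illustrative sketch of that merge step (table, column, and default value are hypothetical):

```python
# Illustrative: the client-side merge the schema editor performs before
# executing DDL. Assumes a configured project and an open connection.
from django.db import connection

connection.ensure_connection()
ddl = "ALTER TABLE app_note ALTER COLUMN body SET DEFAULT %s"
print(connection.ops.compose_sql(ddl, ["draft"]))
# -> ALTER TABLE app_note ALTER COLUMN body SET DEFAULT 'draft'
```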
