summaryrefslogtreecommitdiff
path: root/django/db/backends/postgresql
diff options
context:
space:
mode:
authorSarah Boyce <42296566+sarahboyce@users.noreply.github.com>2023-12-11 11:37:54 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2024-03-01 09:01:18 +0100
commitfad334e1a9b54ea1acb8cce02a25934c5acfe99f (patch)
tree4fc84e9981d6cdc83175d88eaa59c251cda009e7 /django/db/backends/postgresql
parentbcccea3ef31c777b73cba41a6255cd866bf87237 (diff)
Refs #33497 -- Added connection pool support for PostgreSQL.
Co-authored-by: Florian Apolloner <florian@apolloner.eu> Co-authored-by: Ran Benita <ran@unusedvar.com>
Diffstat (limited to 'django/db/backends/postgresql')
-rw-r--r--django/db/backends/postgresql/base.py145
-rw-r--r--django/db/backends/postgresql/creation.py5
-rw-r--r--django/db/backends/postgresql/features.py32
3 files changed, 150 insertions, 32 deletions
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index 793a7bf3bc..e97ab6aa89 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -13,7 +13,7 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
-from django.db.backends.base.base import BaseDatabaseWrapper
+from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
@@ -86,6 +86,24 @@ def _get_varchar_column(data):
return "varchar(%(max_length)s)" % data
+def ensure_timezone(connection, ops, timezone_name):
+ conn_timezone_name = connection.info.parameter_status("TimeZone")
+ if timezone_name and conn_timezone_name != timezone_name:
+ with connection.cursor() as cursor:
+ cursor.execute(ops.set_time_zone_sql(), [timezone_name])
+ return True
+ return False
+
+
+def ensure_role(connection, ops, role_name):
+ if role_name:
+ with connection.cursor() as cursor:
+ sql = ops.compose_sql("SET ROLE %s", [role_name])
+ cursor.execute(sql)
+ return True
+ return False
+
+
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
@@ -179,6 +197,53 @@ class DatabaseWrapper(BaseDatabaseWrapper):
ops_class = DatabaseOperations
# PostgreSQL backend-specific attributes.
_named_cursor_idx = 0
+ _connection_pools = {}
+
+ @property
+ def pool(self):
+ pool_options = self.settings_dict["OPTIONS"].get("pool")
+ if self.alias == NO_DB_ALIAS or not pool_options:
+ return None
+
+ if self.alias not in self._connection_pools:
+ if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
+ raise ImproperlyConfigured(
+ "Pooling doesn't support persistent connections."
+ )
+ # Set the default options.
+ if pool_options is True:
+ pool_options = {}
+
+ try:
+ from psycopg_pool import ConnectionPool
+ except ImportError as err:
+ raise ImproperlyConfigured(
+ "Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
+ ) from err
+
+ connect_kwargs = self.get_connection_params()
+ # Ensure we run in autocommit, Django properly sets it later on.
+ connect_kwargs["autocommit"] = True
+ enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
+ pool = ConnectionPool(
+ kwargs=connect_kwargs,
+ open=False, # Do not open the pool during startup.
+ configure=self._configure_connection,
+ check=ConnectionPool.check_connection if enable_checks else None,
+ **pool_options,
+ )
+ # setdefault() ensures that multiple threads don't set this in
+ # parallel. Since we do not open the pool during it's init above,
+ # this means that at worst during startup multiple threads generate
+ # pool objects and the first to set it wins.
+ self._connection_pools.setdefault(self.alias, pool)
+
+ return self._connection_pools[self.alias]
+
+ def close_pool(self):
+ if self.pool:
+ self.pool.close()
+ del self._connection_pools[self.alias]
def get_database_version(self):
"""
@@ -221,6 +286,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn_params.pop("assume_role", None)
conn_params.pop("isolation_level", None)
+
+ pool_options = conn_params.pop("pool", None)
+ if pool_options and not is_psycopg3:
+ raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
+
server_side_binding = conn_params.pop("server_side_binding", None)
conn_params.setdefault(
"cursor_factory",
@@ -272,7 +342,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
f"Invalid transaction isolation level {isolation_level_value} "
f"specified. Use one of the psycopg.IsolationLevel values."
)
- connection = self.Database.connect(**conn_params)
+ if self.pool:
+ # If nothing else has opened the pool, open it now.
+ self.pool.open()
+ connection = self.pool.getconn()
+ else:
+ connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = self.isolation_level
if not is_psycopg3:
@@ -285,36 +360,52 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return connection
def ensure_timezone(self):
+ # Close the pool so new connections pick up the correct timezone.
+ self.close_pool()
if self.connection is None:
return False
- conn_timezone_name = self.connection.info.parameter_status("TimeZone")
- timezone_name = self.timezone_name
- if timezone_name and conn_timezone_name != timezone_name:
- with self.connection.cursor() as cursor:
- cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
- return True
- return False
+ return ensure_timezone(self.connection, self.ops, self.timezone_name)
- def ensure_role(self):
- if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
- with self.connection.cursor() as cursor:
- sql = self.ops.compose_sql("SET ROLE %s", [new_role])
- cursor.execute(sql)
- return True
- return False
-
- def init_connection_state(self):
- super().init_connection_state()
+ def _configure_connection(self, connection):
+ # This function is called from init_connection_state and from the
+ # psycopg pool itself after a connection is opened. Make sure that
+ # whatever is done here does not access anything on self aside from
+ # variables.
# Commit after setting the time zone.
- commit_tz = self.ensure_timezone()
+ commit_tz = ensure_timezone(connection, self.ops, self.timezone_name)
# Set the role on the connection. This is useful if the credential used
# to login is not the same as the role that owns database resources. As
# can be the case when using temporary or ephemeral credentials.
- commit_role = self.ensure_role()
+ role_name = self.settings_dict["OPTIONS"].get("assume_role")
+ commit_role = ensure_role(connection, self.ops, role_name)
+
+ return commit_role or commit_tz
- if (commit_role or commit_tz) and not self.get_autocommit():
- self.connection.commit()
+ def _close(self):
+ if self.connection is not None:
+ # `wrap_database_errors` only works for `putconn` as long as there
+ # is no `reset` function set in the pool because it is deferred
+ # into a thread and not directly executed.
+ with self.wrap_database_errors:
+ if self.pool:
+ # Ensure the correct pool is returned. This is a workaround
+ # for tests so a pool can be changed on setting changes
+ # (e.g. USE_TZ, TIME_ZONE).
+ self.connection._pool.putconn(self.connection)
+ # Connection can no longer be used.
+ self.connection = None
+ else:
+ return self.connection.close()
+
+ def init_connection_state(self):
+ super().init_connection_state()
+
+ if self.connection is not None and not self.pool:
+ commit = self._configure_connection(self.connection)
+
+ if commit and not self.get_autocommit():
+ self.connection.commit()
@async_unsafe
def create_cursor(self, name=None):
@@ -396,6 +487,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
def is_usable(self):
+ if self.connection is None:
+ return False
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
@@ -405,6 +498,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
return True
+ def close_if_health_check_failed(self):
+ if self.pool:
+ # The pool only returns healthy connections.
+ return
+ return super().close_if_health_check_failed()
+
@contextmanager
def _nodb_cursor(self):
cursor = None
diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py
index 9b562cec18..938be0f56f 100644
--- a/django/db/backends/postgresql/creation.py
+++ b/django/db/backends/postgresql/creation.py
@@ -58,6 +58,7 @@ class DatabaseCreation(BaseDatabaseCreation):
# CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
# to the template database.
self.connection.close()
+ self.connection.close_pool()
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
@@ -84,3 +85,7 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
sys.exit(2)
+
+ def _destroy_test_db(self, test_database_name, verbosity):
+ self.connection.close_pool()
+ return super()._destroy_test_db(test_database_name, verbosity)
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 7bcc356407..809466fc7f 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -83,15 +83,29 @@ class DatabaseFeatures(BaseDatabaseFeatures):
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES"
- django_test_skips = {
- "opclasses are PostgreSQL only.": {
- "indexes.tests.SchemaIndexesNotPostgreSQLTests."
- "test_create_index_ignores_opclasses",
- },
- "PostgreSQL requires casting to text.": {
- "lookup.tests.LookupTests.test_textfield_exact_null",
- },
- }
+ @cached_property
+ def django_test_skips(self):
+ skips = {
+ "opclasses are PostgreSQL only.": {
+ "indexes.tests.SchemaIndexesNotPostgreSQLTests."
+ "test_create_index_ignores_opclasses",
+ },
+ "PostgreSQL requires casting to text.": {
+ "lookup.tests.LookupTests.test_textfield_exact_null",
+ },
+ }
+ if self.connection.settings_dict["OPTIONS"].get("pool"):
+ skips.update(
+ {
+ "Pool does implicit health checks": {
+ "backends.base.test_base.ConnectionHealthChecksTests."
+ "test_health_checks_enabled",
+ "backends.base.test_base.ConnectionHealthChecksTests."
+ "test_set_autocommit_health_checks_enabled",
+ },
+ }
+ )
+ return skips
@cached_property
def django_test_expected_failures(self):