summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Dufresne <jon.dufresne@gmail.com>2019-02-14 07:04:55 -0800
committerTim Graham <timograham@gmail.com>2019-02-14 10:05:13 -0500
commit37cc6a9dce3354cd37f23ee972bc25b0e5cebd5c (patch)
tree90f49316a7b3104b924c5556491ba56e803b04f2
parent07b44a251a41ca93a7f5593761fcf808249665f0 (diff)
[2.2.x] Fixed #30171 -- Fixed DatabaseError in servers tests.
Made DatabaseWrapper thread sharing logic reentrant. Used a reference counting like scheme to allow nested uses. The error appeared after 8c775391b78b2a4a2b57c5e89ed4888f36aada4b. Backport of 76990cbbda5d93fda560c8a5ab019860f7efaab7 from master.
-rw-r--r--django/db/backends/base/base.py38
-rw-r--r--django/db/backends/postgresql/base.py1
-rw-r--r--django/test/testcases.py9
-rw-r--r--docs/releases/2.2.txt3
-rw-r--r--tests/backends/tests.py107
-rw-r--r--tests/servers/test_liveserverthread.py5
-rw-r--r--tests/staticfiles_tests/test_liveserver.py3
7 files changed, 100 insertions, 66 deletions
diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py
index f97d171c96..9fa03cc0ee 100644
--- a/django/db/backends/base/base.py
+++ b/django/db/backends/base/base.py
@@ -1,4 +1,5 @@
import copy
+import threading
import time
import warnings
from collections import deque
@@ -43,8 +44,7 @@ class BaseDatabaseWrapper:
queries_limit = 9000
- def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
- allow_thread_sharing=False):
+ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
# Connection related attributes.
# The underlying database connection.
self.connection = None
@@ -80,7 +80,8 @@ class BaseDatabaseWrapper:
self.errors_occurred = False
# Thread-safety related attributes.
- self.allow_thread_sharing = allow_thread_sharing
+ self._thread_sharing_lock = threading.Lock()
+ self._thread_sharing_count = 0
self._thread_ident = _thread.get_ident()
# A list of no-argument functions to run when the transaction commits.
@@ -515,12 +516,27 @@ class BaseDatabaseWrapper:
# ##### Thread safety handling #####
+ @property
+ def allow_thread_sharing(self):
+ with self._thread_sharing_lock:
+ return self._thread_sharing_count > 0
+
+ def inc_thread_sharing(self):
+ with self._thread_sharing_lock:
+ self._thread_sharing_count += 1
+
+ def dec_thread_sharing(self):
+ with self._thread_sharing_lock:
+ if self._thread_sharing_count <= 0:
+ raise RuntimeError('Cannot decrement the thread sharing count below zero.')
+ self._thread_sharing_count -= 1
+
def validate_thread_sharing(self):
"""
Validate that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
- authorized to be shared between threads (via the `allow_thread_sharing`
- property). Raise an exception if the validation fails.
+ authorized to be shared between threads (via the `inc_thread_sharing()`
+ method). Raise an exception if the validation fails.
"""
if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
raise DatabaseError(
@@ -589,11 +605,7 @@ class BaseDatabaseWrapper:
potential child threads while (or after) the test database is destroyed.
Refs #10868, #17786, #16969.
"""
- return self.__class__(
- {**self.settings_dict, 'NAME': None},
- alias=NO_DB_ALIAS,
- allow_thread_sharing=False,
- )
+ return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
def schema_editor(self, *args, **kwargs):
"""
@@ -635,7 +647,7 @@ class BaseDatabaseWrapper:
finally:
self.execute_wrappers.pop()
- def copy(self, alias=None, allow_thread_sharing=None):
+ def copy(self, alias=None):
"""
Return a copy of this connection.
@@ -644,6 +656,4 @@ class BaseDatabaseWrapper:
settings_dict = copy.deepcopy(self.settings_dict)
if alias is None:
alias = self.alias
- if allow_thread_sharing is None:
- allow_thread_sharing = self.allow_thread_sharing
- return type(self)(settings_dict, alias, allow_thread_sharing)
+ return type(self)(settings_dict, alias)
diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py
index e9db668a4d..4376b8609c 100644
--- a/django/db/backends/postgresql/base.py
+++ b/django/db/backends/postgresql/base.py
@@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self.__class__(
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias,
- allow_thread_sharing=False,
)
return nodb_connection
diff --git a/django/test/testcases.py b/django/test/testcases.py
index 991165c04d..dea7fedbcc 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase):
# the server thread.
if conn.vendor == 'sqlite' and conn.is_in_memory_db():
# Explicitly enable thread-shareability for this connection
- conn.allow_thread_sharing = True
+ conn.inc_thread_sharing()
connections_override[conn.alias] = conn
cls._live_server_modified_settings = modify_settings(
@@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase):
# Terminate the live server's thread
cls.server_thread.terminate()
- # Restore sqlite in-memory database connections' non-shareability
- for conn in connections.all():
- if conn.vendor == 'sqlite' and conn.is_in_memory_db():
- conn.allow_thread_sharing = False
+ # Restore sqlite in-memory database connections' non-shareability.
+ for conn in cls.server_thread.connections_override.values():
+ conn.dec_thread_sharing()
@classmethod
def tearDownClass(cls):
diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt
index 1fe5af93fe..ec6639280a 100644
--- a/docs/releases/2.2.txt
+++ b/docs/releases/2.2.txt
@@ -286,6 +286,9 @@ backends.
* ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``)
* ``_create_check_sql()`` and ``_delete_check_sql()``
+* The third argument of ``DatabaseWrapper.__init__()``,
+ ``allow_thread_sharing``, is removed.
+
Admin actions are no longer collected from base ``ModelAdmin`` classes
----------------------------------------------------------------------
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index d1b89950c0..a523fa67fd 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
connection = connections[DEFAULT_DB_ALIAS]
# Allow thread sharing so the connection can be closed by the
# main thread.
- connection.allow_thread_sharing = True
+ connection.inc_thread_sharing()
connection.cursor()
connections_dict[id(connection)] = connection
- for x in range(2):
- t = threading.Thread(target=runner)
- t.start()
- t.join()
- # Each created connection got different inner connection.
- self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
- # Finish by closing the connections opened by the other threads (the
- # connection opened in the main thread will automatically be closed on
- # teardown).
- for conn in connections_dict.values():
- if conn is not connection:
- conn.close()
+ try:
+ for x in range(2):
+ t = threading.Thread(target=runner)
+ t.start()
+ t.join()
+ # Each created connection got different inner connection.
+ self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
+ finally:
+ # Finish by closing the connections opened by the other threads
+ # (the connection opened in the main thread will automatically be
+ # closed on teardown).
+ for conn in connections_dict.values():
+ if conn is not connection:
+ if conn.allow_thread_sharing:
+ conn.close()
+ conn.dec_thread_sharing()
def test_connections_thread_local(self):
"""
@@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
for conn in connections.all():
# Allow thread sharing so the connection can be closed by the
# main thread.
- conn.allow_thread_sharing = True
+ conn.inc_thread_sharing()
connections_dict[id(conn)] = conn
- for x in range(2):
- t = threading.Thread(target=runner)
- t.start()
- t.join()
- self.assertEqual(len(connections_dict), 6)
- # Finish by closing the connections opened by the other threads (the
- # connection opened in the main thread will automatically be closed on
- # teardown).
- for conn in connections_dict.values():
- if conn is not connection:
- conn.close()
+ try:
+ for x in range(2):
+ t = threading.Thread(target=runner)
+ t.start()
+ t.join()
+ self.assertEqual(len(connections_dict), 6)
+ finally:
+ # Finish by closing the connections opened by the other threads
+ # (the connection opened in the main thread will automatically be
+ # closed on teardown).
+ for conn in connections_dict.values():
+ if conn is not connection:
+ if conn.allow_thread_sharing:
+ conn.close()
+ conn.dec_thread_sharing()
def test_pass_connection_between_threads(self):
"""
@@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
t.start()
t.join()
- # Without touching allow_thread_sharing, which should be False by default.
- exceptions = []
- do_thread()
- # Forbidden!
- self.assertIsInstance(exceptions[0], DatabaseError)
-
- # If explicitly setting allow_thread_sharing to False
- connections['default'].allow_thread_sharing = False
+ # Without touching thread sharing, which should be False by default.
exceptions = []
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
- # If explicitly setting allow_thread_sharing to True
- connections['default'].allow_thread_sharing = True
- exceptions = []
- do_thread()
- # All good
- self.assertEqual(exceptions, [])
+ # After calling inc_thread_sharing() on the connection.
+ connections['default'].inc_thread_sharing()
+ try:
+ exceptions = []
+ do_thread()
+ # All good
+ self.assertEqual(exceptions, [])
+ finally:
+ connections['default'].dec_thread_sharing()
def test_closing_non_shared_connections(self):
"""
@@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
except DatabaseError as e:
exceptions.add(e)
# Enable thread sharing
- connections['default'].allow_thread_sharing = True
- t2 = threading.Thread(target=runner2, args=[connections['default']])
- t2.start()
- t2.join()
+ connections['default'].inc_thread_sharing()
+ try:
+ t2 = threading.Thread(target=runner2, args=[connections['default']])
+ t2.start()
+ t2.join()
+ finally:
+ connections['default'].dec_thread_sharing()
t1 = threading.Thread(target=runner1)
t1.start()
t1.join()
# No exception was raised
self.assertEqual(len(exceptions), 0)
+ def test_thread_sharing_count(self):
+ self.assertIs(connection.allow_thread_sharing, False)
+ connection.inc_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, True)
+ connection.inc_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, True)
+ connection.dec_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, True)
+ connection.dec_thread_sharing()
+ self.assertIs(connection.allow_thread_sharing, False)
+ msg = 'Cannot decrement the thread sharing count below zero.'
+ with self.assertRaisesMessage(RuntimeError, msg):
+ connection.dec_thread_sharing()
+
class MySQLPKZeroTests(TestCase):
"""
diff --git a/tests/servers/test_liveserverthread.py b/tests/servers/test_liveserverthread.py
index d39aac8183..9762b53791 100644
--- a/tests/servers/test_liveserverthread.py
+++ b/tests/servers/test_liveserverthread.py
@@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase):
# Pass a connection to the thread to check they are being closed.
connections_override = {DEFAULT_DB_ALIAS: conn}
- saved_sharing = conn.allow_thread_sharing
+ conn.inc_thread_sharing()
try:
- conn.allow_thread_sharing = True
self.assertTrue(conn.is_usable())
self.run_live_server_thread(connections_override)
self.assertFalse(conn.is_usable())
finally:
- conn.allow_thread_sharing = saved_sharing
+ conn.dec_thread_sharing()
diff --git a/tests/staticfiles_tests/test_liveserver.py b/tests/staticfiles_tests/test_liveserver.py
index 264242bbae..820fa5bc89 100644
--- a/tests/staticfiles_tests/test_liveserver.py
+++ b/tests/staticfiles_tests/test_liveserver.py
@@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase):
# app without having set the required STATIC_URL setting.")
pass
finally:
+ # Use del to avoid decrementing the database thread sharing count a
+ # second time.
+ del cls.server_thread
super().tearDownClass()
def test_test_test(self):