summaryrefslogtreecommitdiff
path: root/django/test
diff options
context:
space:
mode:
authorNatalia <124304+nessita@users.noreply.github.com>2025-06-03 15:54:16 -0300
committernessita <124304+nessita@users.noreply.github.com>2025-06-16 17:41:24 -0300
commit1a03a984ab3728253f964ba16cd8d806f76bddf9 (patch)
tree4342bb19065426d7757aadde572ff6a77688a555 /django/test
parent104cbfd44b9eff010daf0ef0e1ce434385855b13 (diff)
Fixed #36380 -- Deferred SQL formatting when running tests with --debug-sql.
Thanks to Jacob Walls for the report and previous iterations of this fix, to Simon Charette for the logging formatter idea, and to Tim Graham for testing and ensuring that 3rd party backends remain compatible. This partially reverts d8f093908c504ae0dbc39d3f5231f7d7920dde37. Refs #36112, #35448. Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
Diffstat (limited to 'django/test')
-rw-r--r--django/test/runner.py55
1 files changed, 38 insertions, 17 deletions
diff --git a/django/test/runner.py b/django/test/runner.py
index c8bb16e7b3..3e5c319ade 100644
--- a/django/test/runner.py
+++ b/django/test/runner.py
@@ -16,7 +16,6 @@ import unittest.suite
from collections import defaultdict
from contextlib import contextmanager
from importlib import import_module
-from io import StringIO
import django
from django.core.management import call_command
@@ -41,16 +40,47 @@ except ImportError:
tblib = None
+class QueryFormatter(logging.Formatter):
+ def format(self, record):
+ if (alias := getattr(record, "alias", None)) in connections:
+ format_sql = connections[alias].ops.format_debug_sql
+
+ sql = None
+ formatted_sql = None
+ if args := record.args:
+ if isinstance(args, tuple) and len(args) > 1 and (sql := args[1]):
+ record.args = (args[0], formatted_sql := format_sql(sql), *args[2:])
+ elif isinstance(record.args, dict) and (sql := record.args.get("sql")):
+ record.args["sql"] = formatted_sql = format_sql(sql)
+
+ if extra_sql := getattr(record, "sql", None):
+ if extra_sql == sql:
+ record.sql = formatted_sql
+ else:
+ record.sql = format_sql(extra_sql)
+
+ return super().format(record)
+
+
class DebugSQLTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity):
self.logger = logging.getLogger("django.db.backends")
self.logger.setLevel(logging.DEBUG)
- self.debug_sql_stream = None
+ self.handler = None
super().__init__(stream, descriptions, verbosity)
+ def _read_logger_stream(self):
+ if self.handler is None:
+ # Error before tests e.g. in setUpTestData().
+ sql = ""
+ else:
+ self.handler.stream.seek(0)
+ sql = self.handler.stream.read()
+ return sql
+
def startTest(self, test):
- self.debug_sql_stream = StringIO()
- self.handler = logging.StreamHandler(self.debug_sql_stream)
+ self.handler = logging.StreamHandler(io.StringIO())
+ self.handler.setFormatter(QueryFormatter())
self.logger.addHandler(self.handler)
super().startTest(test)
@@ -58,35 +88,26 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
super().stopTest(test)
self.logger.removeHandler(self.handler)
if self.showAll:
- self.debug_sql_stream.seek(0)
- self.stream.write(self.debug_sql_stream.read())
+ self.stream.write(self._read_logger_stream())
self.stream.writeln(self.separator2)
def addError(self, test, err):
super().addError(test, err)
- if self.debug_sql_stream is None:
- # Error before tests e.g. in setUpTestData().
- sql = ""
- else:
- self.debug_sql_stream.seek(0)
- sql = self.debug_sql_stream.read()
- self.errors[-1] = self.errors[-1] + (sql,)
+ self.errors[-1] = self.errors[-1] + (self._read_logger_stream(),)
def addFailure(self, test, err):
super().addFailure(test, err)
- self.debug_sql_stream.seek(0)
- self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)
+ self.failures[-1] = self.failures[-1] + (self._read_logger_stream(),)
def addSubTest(self, test, subtest, err):
super().addSubTest(test, subtest, err)
if err is not None:
- self.debug_sql_stream.seek(0)
errors = (
self.failures
if issubclass(err[0], test.failureException)
else self.errors
)
- errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
+ errors[-1] = errors[-1] + (self._read_logger_stream(),)
def printErrorList(self, flavour, errors):
for test, err, sql_debug in errors: