diff options
| author | Natalia <124304+nessita@users.noreply.github.com> | 2025-06-03 15:54:16 -0300 |
|---|---|---|
| committer | nessita <124304+nessita@users.noreply.github.com> | 2025-06-16 17:41:24 -0300 |
| commit | 1a03a984ab3728253f964ba16cd8d806f76bddf9 (patch) | |
| tree | 4342bb19065426d7757aadde572ff6a77688a555 /django/test | |
| parent | 104cbfd44b9eff010daf0ef0e1ce434385855b13 (diff) | |
Fixed #36380 -- Deferred SQL formatting when running tests with --debug-sql.
Thanks to Jacob Walls for the report and previous iterations of this
fix, to Simon Charette for the logging formatter idea, and to Tim Graham
for testing and ensuring that 3rd party backends remain compatible.
This partially reverts d8f093908c504ae0dbc39d3f5231f7d7920dde37.
Refs #36112, #35448.
Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
Diffstat (limited to 'django/test')
| -rw-r--r-- | django/test/runner.py | 55 |
1 files changed, 38 insertions, 17 deletions
diff --git a/django/test/runner.py b/django/test/runner.py index c8bb16e7b3..3e5c319ade 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -16,7 +16,6 @@ import unittest.suite from collections import defaultdict from contextlib import contextmanager from importlib import import_module -from io import StringIO import django from django.core.management import call_command @@ -41,16 +40,47 @@ except ImportError: tblib = None +class QueryFormatter(logging.Formatter): + def format(self, record): + if (alias := getattr(record, "alias", None)) in connections: + format_sql = connections[alias].ops.format_debug_sql + + sql = None + formatted_sql = None + if args := record.args: + if isinstance(args, tuple) and len(args) > 1 and (sql := args[1]): + record.args = (args[0], formatted_sql := format_sql(sql), *args[2:]) + elif isinstance(record.args, dict) and (sql := record.args.get("sql")): + record.args["sql"] = formatted_sql = format_sql(sql) + + if extra_sql := getattr(record, "sql", None): + if extra_sql == sql: + record.sql = formatted_sql + else: + record.sql = format_sql(extra_sql) + + return super().format(record) + + class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) - self.debug_sql_stream = None + self.handler = None super().__init__(stream, descriptions, verbosity) + def _read_logger_stream(self): + if self.handler is None: + # Error before tests e.g. in setUpTestData(). + sql = "" + else: + self.handler.stream.seek(0) + sql = self.handler.stream.read() + return sql + def startTest(self, test): - self.debug_sql_stream = StringIO() - self.handler = logging.StreamHandler(self.debug_sql_stream) + self.handler = logging.StreamHandler(io.StringIO()) + self.handler.setFormatter(QueryFormatter()) self.logger.addHandler(self.handler) super().startTest(test) @@ -58,35 +88,26 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().stopTest(test) self.logger.removeHandler(self.handler) if self.showAll: - self.debug_sql_stream.seek(0) - self.stream.write(self.debug_sql_stream.read()) + self.stream.write(self._read_logger_stream()) self.stream.writeln(self.separator2) def addError(self, test, err): super().addError(test, err) - if self.debug_sql_stream is None: - # Error before tests e.g. in setUpTestData(). - sql = "" - else: - self.debug_sql_stream.seek(0) - sql = self.debug_sql_stream.read() - self.errors[-1] = self.errors[-1] + (sql,) + self.errors[-1] = self.errors[-1] + (self._read_logger_stream(),) def addFailure(self, test, err): super().addFailure(test, err) - self.debug_sql_stream.seek(0) - self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),) + self.failures[-1] = self.failures[-1] + (self._read_logger_stream(),) def addSubTest(self, test, subtest, err): super().addSubTest(test, subtest, err) if err is not None: - self.debug_sql_stream.seek(0) errors = ( self.failures if issubclass(err[0], test.failureException) else self.errors ) - errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) + errors[-1] = errors[-1] + (self._read_logger_stream(),) def printErrorList(self, flavour, errors): for test, err, sql_debug in errors: |
