summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/db/backends/utils.py2
-rw-r--r--django/test/runner.py55
-rw-r--r--tests/backends/tests.py11
-rw-r--r--tests/test_runner/test_debug_sql.py176
4 files changed, 216 insertions, 28 deletions
diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py
index 568f510a67..ab0ea8258b 100644
--- a/django/db/backends/utils.py
+++ b/django/db/backends/utils.py
@@ -151,7 +151,7 @@ class CursorDebugWrapper(CursorWrapper):
logger.debug(
"(%.3f) %s; args=%s; alias=%s",
duration,
- self.db.ops.format_debug_sql(sql),
+ sql,
params,
self.db.alias,
extra={
diff --git a/django/test/runner.py b/django/test/runner.py
index c8bb16e7b3..3e5c319ade 100644
--- a/django/test/runner.py
+++ b/django/test/runner.py
@@ -16,7 +16,6 @@ import unittest.suite
from collections import defaultdict
from contextlib import contextmanager
from importlib import import_module
-from io import StringIO
import django
from django.core.management import call_command
@@ -41,16 +40,47 @@ except ImportError:
tblib = None
+class QueryFormatter(logging.Formatter):
+ def format(self, record):
+ if (alias := getattr(record, "alias", None)) in connections:
+ format_sql = connections[alias].ops.format_debug_sql
+
+ sql = None
+ formatted_sql = None
+ if args := record.args:
+ if isinstance(args, tuple) and len(args) > 1 and (sql := args[1]):
+ record.args = (args[0], formatted_sql := format_sql(sql), *args[2:])
+ elif isinstance(record.args, dict) and (sql := record.args.get("sql")):
+ record.args["sql"] = formatted_sql = format_sql(sql)
+
+ if extra_sql := getattr(record, "sql", None):
+ if extra_sql == sql:
+ record.sql = formatted_sql
+ else:
+ record.sql = format_sql(extra_sql)
+
+ return super().format(record)
+
+
class DebugSQLTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity):
self.logger = logging.getLogger("django.db.backends")
self.logger.setLevel(logging.DEBUG)
- self.debug_sql_stream = None
+ self.handler = None
super().__init__(stream, descriptions, verbosity)
+ def _read_logger_stream(self):
+ if self.handler is None:
+ # Error before tests e.g. in setUpTestData().
+ sql = ""
+ else:
+ self.handler.stream.seek(0)
+ sql = self.handler.stream.read()
+ return sql
+
def startTest(self, test):
- self.debug_sql_stream = StringIO()
- self.handler = logging.StreamHandler(self.debug_sql_stream)
+ self.handler = logging.StreamHandler(io.StringIO())
+ self.handler.setFormatter(QueryFormatter())
self.logger.addHandler(self.handler)
super().startTest(test)
@@ -58,35 +88,26 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
super().stopTest(test)
self.logger.removeHandler(self.handler)
if self.showAll:
- self.debug_sql_stream.seek(0)
- self.stream.write(self.debug_sql_stream.read())
+ self.stream.write(self._read_logger_stream())
self.stream.writeln(self.separator2)
def addError(self, test, err):
super().addError(test, err)
- if self.debug_sql_stream is None:
- # Error before tests e.g. in setUpTestData().
- sql = ""
- else:
- self.debug_sql_stream.seek(0)
- sql = self.debug_sql_stream.read()
- self.errors[-1] = self.errors[-1] + (sql,)
+ self.errors[-1] = self.errors[-1] + (self._read_logger_stream(),)
def addFailure(self, test, err):
super().addFailure(test, err)
- self.debug_sql_stream.seek(0)
- self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)
+ self.failures[-1] = self.failures[-1] + (self._read_logger_stream(),)
def addSubTest(self, test, subtest, err):
super().addSubTest(test, subtest, err)
if err is not None:
- self.debug_sql_stream.seek(0)
errors = (
self.failures
if issubclass(err[0], test.failureException)
else self.errors
)
- errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
+ errors[-1] = errors[-1] + (self._read_logger_stream(),)
def printErrorList(self, flavour, errors):
for test, err, sql_debug in errors:
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index 0e5348e248..5e38f0112d 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -83,12 +83,7 @@ class LastExecutedQueryTest(TestCase):
connection.ops.last_executed_query(cursor, "SELECT %s" + suffix, (1,))
def test_debug_sql(self):
- qs = Reporter.objects.filter(first_name="test")
- ops = connections[qs.db].ops
- with mock.patch.object(ops, "format_debug_sql") as format_debug_sql:
- list(qs)
- # Queries are formatted with DatabaseOperations.format_debug_sql().
- format_debug_sql.assert_called()
+ list(Reporter.objects.filter(first_name="test"))
sql = connection.queries[-1]["sql"].lower()
self.assertIn("select", sql)
self.assertIn(Reporter._meta.db_table, sql)
@@ -580,13 +575,13 @@ class BackendTestCase(TransactionTestCase):
@mock.patch("django.db.backends.utils.logger")
@override_settings(DEBUG=True)
def test_queries_logger(self, mocked_logger):
- sql = "SELECT 1" + connection.features.bare_select_suffix
- sql = connection.ops.format_debug_sql(sql)
+ sql = "select 1" + connection.features.bare_select_suffix
with connection.cursor() as cursor:
cursor.execute(sql)
params, kwargs = mocked_logger.debug.call_args
self.assertIn("; alias=%s", params[0])
self.assertEqual(params[2], sql)
+ self.assertNotEqual(params[2], connection.ops.format_debug_sql(sql))
self.assertIsNone(params[3])
self.assertEqual(params[4], connection.alias)
self.assertEqual(
diff --git a/tests/test_runner/test_debug_sql.py b/tests/test_runner/test_debug_sql.py
index 27fc4001c2..acf66633ef 100644
--- a/tests/test_runner/test_debug_sql.py
+++ b/tests/test_runner/test_debug_sql.py
@@ -1,12 +1,184 @@
+import logging
import unittest
from io import StringIO
+from time import time
+from unittest import mock
-from django.db import connection
+from django.db import DEFAULT_DB_ALIAS, connection, connections
from django.test import TestCase
-from django.test.runner import DiscoverRunner
+from django.test.runner import DiscoverRunner, QueryFormatter
from .models import Person
+logger = logging.getLogger(__name__)
+
+
+class QueryFormatterTests(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.format_sql_calls = []
+
+ def new_format_sql(self, sql):
+ # Use time() to introduce some uniqueness.
+ formatted = "Formatted! %s at %s" % (sql.upper(), time())
+ self.format_sql_calls.append({sql: formatted})
+ return formatted
+
+ def make_handler(self, **formatter_kwargs):
+ formatter = QueryFormatter(**formatter_kwargs)
+
+ handler = logging.StreamHandler(StringIO())
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(formatter)
+
+ original_level = logger.getEffectiveLevel()
+ logger.setLevel(logging.DEBUG)
+ self.addCleanup(logger.setLevel, original_level)
+ logger.addHandler(handler)
+ self.addCleanup(logger.removeHandler, handler)
+
+ return handler
+
+ def do_log(self, msg, *logger_args, alias=DEFAULT_DB_ALIAS, extra=None):
+ if extra is None:
+ extra = {}
+ if alias and "alias" not in extra:
+ extra["alias"] = alias
+ # Patch connection's format_debug_sql to ensure it was properly called.
+ with mock.patch.object(
+ connections[alias].ops, "format_debug_sql", side_effect=self.new_format_sql
+ ):
+ logger.info(msg, *logger_args, extra=extra)
+
+ def assertLogRecord(self, handler, expected):
+ handler.stream.seek(0)
+ self.assertEqual(handler.stream.read().strip(), expected)
+
+ def assertSQLFormatted(self, handler, sql, total_calls=1):
+ self.assertEqual(len(self.format_sql_calls), total_calls)
+ formatted_sql = self.format_sql_calls[0][sql]
+ expected = f"=> Executing query duration=3.142 sql={formatted_sql}"
+ self.assertLogRecord(handler, expected)
+
+ def test_formats_sql_bracket_format_style(self):
+ handler = self.make_handler(
+ fmt="{message} duration={duration:.3f} sql={sql}", style="{"
+ )
+ msg = "=> Executing query"
+ sql = "select * from foo"
+
+ self.do_log(msg, extra={"sql": sql, "duration": 3.1416})
+ self.assertSQLFormatted(handler, sql)
+
+ def test_formats_sql_named_fmt_format_style(self):
+ handler = self.make_handler(
+ fmt="%(message)s duration=%(duration).3f sql=%(sql)s"
+ )
+ msg = "=> Executing query"
+ sql = "select * from foo"
+
+ self.do_log(msg, extra={"sql": sql, "duration": 3.1416})
+ self.assertSQLFormatted(handler, sql)
+
+ def test_formats_sql_named_percent_format_style(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%(duration).3f sql=%(sql)s"
+ sql = "select * from foo"
+
+ self.do_log(msg, {"duration": 3.1416, "sql": sql})
+ self.assertSQLFormatted(handler, sql)
+
+ def test_formats_sql_default_percent_format_style(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%.3f sql=%s"
+ sql = "select * from foo"
+
+ self.do_log(msg, 3.1416, sql)
+ self.assertSQLFormatted(handler, sql)
+
+ def test_formats_sql_multiple_matching_sql(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%.3f sql=%s"
+ sql = "select * from foo"
+
+ self.do_log(msg, 3.1416, sql, extra={"duration": 3.1416, "sql": sql})
+ self.assertSQLFormatted(handler, sql)
+
+ def test_formats_sql_multiple_non_matching_sql(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%.3f sql=%s"
+ sql1 = "select * from foo"
+ sql2 = "select * from other"
+
+ self.do_log(msg, 3.1416, sql1, extra={"duration": 3.1416, "sql": sql2})
+ self.assertSQLFormatted(handler, sql1, total_calls=2)
+ # Second format call is triggered since the sql are different.
+ self.assertEqual(list(self.format_sql_calls[1].keys()), [sql2])
+
+ def test_log_record_no_args(self):
+ handler = self.make_handler()
+ msg = "=> Executing query no args"
+
+ self.do_log(msg)
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, msg)
+
+ def test_log_record_not_enough_args(self):
+ handler = self.make_handler()
+ msg = "=> Executing query one args %r"
+ args = "not formatted"
+
+ self.do_log(msg, args)
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, msg % args)
+
+ def test_log_record_not_key_in_dict_args(self):
+ handler = self.make_handler()
+ msg = "=> Executing query missing sql key %(foo)r"
+ args = {"foo": "bar"}
+
+ self.do_log(msg, args)
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, msg % args)
+
+ def test_log_record_no_alias(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%.3f sql=%s"
+ args = (3.1416, "select * from foo")
+
+ self.do_log(msg, *args, extra={"alias": "does not exist"})
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, msg % args)
+
+ def test_log_record_sql_arg_none(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%.3f sql=%s"
+ args = (3.1416, None)
+
+ self.do_log(msg, *args)
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, msg % args)
+
+ def test_log_record_sql_key_none(self):
+ handler = self.make_handler()
+ msg = "=> Executing query duration=%(duration).3f sql=%(sql)s"
+ args = {"duration": 3.1416, "sql": None}
+
+ self.do_log(msg, args)
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, msg % args)
+
+ def test_log_record_sql_extra_none(self):
+ handler = self.make_handler(
+ fmt="{message} duration={duration:.3f} sql={sql}", style="{"
+ )
+ msg = "=> Executing query"
+
+ self.do_log(msg, extra={"sql": None, "duration": 3.1416})
+ self.assertEqual(self.format_sql_calls, [])
+ self.assertLogRecord(handler, f"{msg} duration=3.142 sql=None")
+
@unittest.skipUnless(
connection.vendor == "sqlite", "Only run on sqlite so we can check output SQL."