diff options
Diffstat (limited to 'tests/test_runner/test_debug_sql.py')
| -rw-r--r-- | tests/test_runner/test_debug_sql.py | 176 |
1 files changed, 174 insertions, 2 deletions
diff --git a/tests/test_runner/test_debug_sql.py b/tests/test_runner/test_debug_sql.py index 27fc4001c2..acf66633ef 100644 --- a/tests/test_runner/test_debug_sql.py +++ b/tests/test_runner/test_debug_sql.py @@ -1,12 +1,184 @@ +import logging import unittest from io import StringIO +from time import time +from unittest import mock -from django.db import connection +from django.db import DEFAULT_DB_ALIAS, connection, connections from django.test import TestCase -from django.test.runner import DiscoverRunner +from django.test.runner import DiscoverRunner, QueryFormatter from .models import Person +logger = logging.getLogger(__name__) + + +class QueryFormatterTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.format_sql_calls = [] + + def new_format_sql(self, sql): + # Use time() to introduce some uniqueness. + formatted = "Formatted! %s at %s" % (sql.upper(), time()) + self.format_sql_calls.append({sql: formatted}) + return formatted + + def make_handler(self, **formatter_kwargs): + formatter = QueryFormatter(**formatter_kwargs) + + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + + original_level = logger.getEffectiveLevel() + logger.setLevel(logging.DEBUG) + self.addCleanup(logger.setLevel, original_level) + logger.addHandler(handler) + self.addCleanup(logger.removeHandler, handler) + + return handler + + def do_log(self, msg, *logger_args, alias=DEFAULT_DB_ALIAS, extra=None): + if extra is None: + extra = {} + if alias and "alias" not in extra: + extra["alias"] = alias + # Patch connection's format_debug_sql to ensure it was properly called. + with mock.patch.object( + connections[alias].ops, "format_debug_sql", side_effect=self.new_format_sql + ): + logger.info(msg, *logger_args, extra=extra) + + def assertLogRecord(self, handler, expected): + handler.stream.seek(0) + self.assertEqual(handler.stream.read().strip(), expected) + + def assertSQLFormatted(self, handler, sql, total_calls=1): + self.assertEqual(len(self.format_sql_calls), total_calls) + formatted_sql = self.format_sql_calls[0][sql] + expected = f"=> Executing query duration=3.142 sql={formatted_sql}" + self.assertLogRecord(handler, expected) + + def test_formats_sql_bracket_format_style(self): + handler = self.make_handler( + fmt="{message} duration={duration:.3f} sql={sql}", style="{" + ) + msg = "=> Executing query" + sql = "select * from foo" + + self.do_log(msg, extra={"sql": sql, "duration": 3.1416}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_named_fmt_format_style(self): + handler = self.make_handler( + fmt="%(message)s duration=%(duration).3f sql=%(sql)s" + ) + msg = "=> Executing query" + sql = "select * from foo" + + self.do_log(msg, extra={"sql": sql, "duration": 3.1416}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_named_percent_format_style(self): + handler = self.make_handler() + msg = "=> Executing query duration=%(duration).3f sql=%(sql)s" + sql = "select * from foo" + + self.do_log(msg, {"duration": 3.1416, "sql": sql}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_default_percent_format_style(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql = "select * from foo" + + self.do_log(msg, 3.1416, sql) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_multiple_matching_sql(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql = "select * from foo" + + self.do_log(msg, 3.1416, sql, extra={"duration": 3.1416, "sql": sql}) + self.assertSQLFormatted(handler, sql) + + def test_formats_sql_multiple_non_matching_sql(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + sql1 = "select * from foo" + sql2 = "select * from other" + + self.do_log(msg, 3.1416, sql1, extra={"duration": 3.1416, "sql": sql2}) + self.assertSQLFormatted(handler, sql1, total_calls=2) + # Second format call is triggered since the sql are different. + self.assertEqual(list(self.format_sql_calls[1].keys()), [sql2]) + + def test_log_record_no_args(self): + handler = self.make_handler() + msg = "=> Executing query no args" + + self.do_log(msg) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg) + + def test_log_record_not_enough_args(self): + handler = self.make_handler() + msg = "=> Executing query one args %r" + args = "not formatted" + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_not_key_in_dict_args(self): + handler = self.make_handler() + msg = "=> Executing query missing sql key %(foo)r" + args = {"foo": "bar"} + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_no_alias(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + args = (3.1416, "select * from foo") + + self.do_log(msg, *args, extra={"alias": "does not exist"}) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_arg_none(self): + handler = self.make_handler() + msg = "=> Executing query duration=%.3f sql=%s" + args = (3.1416, None) + + self.do_log(msg, *args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_key_none(self): + handler = self.make_handler() + msg = "=> Executing query duration=%(duration).3f sql=%(sql)s" + args = {"duration": 3.1416, "sql": None} + + self.do_log(msg, args) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, msg % args) + + def test_log_record_sql_extra_none(self): + handler = self.make_handler( + fmt="{message} duration={duration:.3f} sql={sql}", style="{" + ) + msg = "=> Executing query" + + self.do_log(msg, extra={"sql": None, "duration": 3.1416}) + self.assertEqual(self.format_sql_calls, []) + self.assertLogRecord(handler, f"{msg} duration=3.142 sql=None") + @unittest.skipUnless( connection.vendor == "sqlite", "Only run on sqlite so we can check output SQL." |
