From fe67703e13b4be487497eeea2147f2a6dac2513d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 21 Oct 2022 09:52:03 -0500 Subject: [PATCH] Log invalid messages instead of raising in system_log (#80645) --- .../components/system_log/__init__.py | 50 +++++++++++++++---- tests/components/system_log/test_init.py | 33 ++++++++++++ 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/system_log/__init__.py b/homeassistant/components/system_log/__init__.py index fae8598407e..9f00009b322 100644 --- a/homeassistant/components/system_log/__init__.py +++ b/homeassistant/components/system_log/__init__.py @@ -1,9 +1,11 @@ """Support for system log.""" +from __future__ import annotations + from collections import OrderedDict, deque import logging import re import traceback -from typing import Any +from typing import Any, cast import voluptuous as vol @@ -56,7 +58,9 @@ SERVICE_WRITE_SCHEMA = vol.Schema( ) -def _figure_out_source(record, call_stack, paths_re): +def _figure_out_source( + record: logging.LogRecord, call_stack: list[tuple[str, int]], paths_re: re.Pattern +) -> tuple[str, int]: # If a stack trace exists, extract file names from the entire call stack. # The other case is when a regular "log" is made (without an attached @@ -81,20 +85,44 @@ def _figure_out_source(record, call_stack, paths_re): # Try to match with a file within Home Assistant if match := paths_re.match(pathname[0]): - return [match.group(1), pathname[1]] + return (cast(str, match.group(1)), pathname[1]) # Ok, we don't know what this is return (record.pathname, record.lineno) +def _safe_get_message(record: logging.LogRecord) -> str: + """Get message from record and handle exceptions. + + This code will be unreachable during a pytest run + because pytest installs a logging handler that + will prevent this code from being reached. + + Calling record.getMessage() can raise an exception + if the log message does not contain sufficient arguments. + + As there is no guarantees about which exceptions + that can be raised, we catch all exceptions and + return a generic message. + + This must be manually tested when changing the code. + """ + try: + return record.getMessage() + except Exception: # pylint: disable=broad-except + return f"Bad logger message: {record.msg} ({record.args})" + + class LogEntry: """Store HA log entries.""" - def __init__(self, record, stack, source): + def __init__(self, record: logging.LogRecord, source: tuple[str, int]) -> None: """Initialize a log entry.""" self.first_occurred = self.timestamp = record.created self.name = record.name self.level = record.levelname - self.message = deque([record.getMessage()], maxlen=5) + # See the docstring of _safe_get_message for why we need to do this. + # This must be manually tested when changing the code. + self.message = deque([_safe_get_message(record)], maxlen=5) self.exception = "" self.root_cause = None if record.exc_info: @@ -129,7 +157,7 @@ class DedupStore(OrderedDict): super().__init__() self.maxlen = maxlen - def add_entry(self, entry): + def add_entry(self, entry: LogEntry) -> None: """Add a new entry.""" key = entry.hash @@ -158,7 +186,9 @@ class DedupStore(OrderedDict): class LogErrorHandler(logging.Handler): """Log handler for error messages.""" - def __init__(self, hass, maxlen, fire_event, paths_re): + def __init__( + self, hass: HomeAssistant, maxlen: int, fire_event: bool, paths_re: re.Pattern + ) -> None: """Initialize a new LogErrorHandler.""" super().__init__() self.hass = hass @@ -166,7 +196,7 @@ class LogErrorHandler(logging.Handler): self.fire_event = fire_event self.paths_re = paths_re - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: """Save error and warning logs. Everything logged with error or warning is saved in local buffer. A @@ -177,9 +207,7 @@ class LogErrorHandler(logging.Handler): if not record.exc_info: stack = [(f[0], f[1]) for f in traceback.extract_stack()] - entry = LogEntry( - record, stack, _figure_out_source(record, stack, self.paths_re) - ) + entry = LogEntry(record, _figure_out_source(record, stack, self.paths_re)) self.records.add_entry(entry) if self.fire_event: self.hass.bus.fire(EVENT_SYSTEM_LOG, entry.to_dict()) diff --git a/tests/components/system_log/test_init.py b/tests/components/system_log/test_init.py index 96e5480acb5..18693aee448 100644 --- a/tests/components/system_log/test_init.py +++ b/tests/components/system_log/test_init.py @@ -136,6 +136,28 @@ async def test_warning(hass, hass_ws_client): assert_log(log, "", "warning message", "WARNING") +async def test_warning_good_format(hass, hass_ws_client): + """Test that warning with good format arguments are logged and retrieved correctly.""" + await async_setup_component(hass, system_log.DOMAIN, BASIC_CONFIG) + await hass.async_block_till_done() + _LOGGER.warning("warning message: %s", "test") + await hass.async_block_till_done() + + log = find_log(await get_error_log(hass_ws_client), "WARNING") + assert_log(log, "", "warning message: test", "WARNING") + + +async def test_warning_missing_format_args(hass, hass_ws_client): + """Test that warning with missing format arguments are logged and retrieved correctly.""" + await async_setup_component(hass, system_log.DOMAIN, BASIC_CONFIG) + await hass.async_block_till_done() + _LOGGER.warning("warning message missing a format arg %s") + await hass.async_block_till_done() + + log = find_log(await get_error_log(hass_ws_client), "WARNING") + assert_log(log, "", ["warning message missing a format arg %s"], "WARNING") + + async def test_error(hass, hass_ws_client): """Test that errors are logged and retrieved correctly.""" await async_setup_component(hass, system_log.DOMAIN, BASIC_CONFIG) @@ -195,6 +217,17 @@ async def test_critical(hass, hass_ws_client): assert_log(log, "", "critical message", "CRITICAL") +async def test_critical_with_missing_format_args(hass, hass_ws_client): + """Test that critical messages with missing format args are logged and retrieved correctly.""" + await async_setup_component(hass, system_log.DOMAIN, BASIC_CONFIG) + await hass.async_block_till_done() + + try: + _LOGGER.critical("critical message %s = %s", "one_but_needs_two") + except TypeError: + pass + + async def test_remove_older_logs(hass, hass_ws_client): """Test that older logs are rotated out.""" await async_setup_component(hass, system_log.DOMAIN, BASIC_CONFIG)