Only extract traceback once in system_log (#113201)

This commit is contained in:
J. Nick Koston 2024-03-13 00:58:34 -10:00 committed by GitHub
parent 546e5f607f
commit bbef3f7f68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 37 additions and 17 deletions

View File

@ -63,14 +63,19 @@ SERVICE_WRITE_SCHEMA = vol.Schema(
def _figure_out_source( def _figure_out_source(
record: logging.LogRecord, paths_re: re.Pattern[str] record: logging.LogRecord,
paths_re: re.Pattern[str],
extracted_tb: traceback.StackSummary | None = None,
) -> tuple[str, int]: ) -> tuple[str, int]:
"""Figure out where a log message came from.""" """Figure out where a log message came from."""
# If a stack trace exists, extract file names from the entire call stack. # If a stack trace exists, extract file names from the entire call stack.
# The other case is when a regular "log" is made (without an attached # The other case is when a regular "log" is made (without an attached
# exception). In that case, just use the file where the log was made from. # exception). In that case, just use the file where the log was made from.
if record.exc_info: if record.exc_info:
stack = [(x[0], x[1]) for x in traceback.extract_tb(record.exc_info[2])] stack = [
(x[0], x[1])
for x in (extracted_tb or traceback.extract_tb(record.exc_info[2]))
]
for i, (filename, _) in enumerate(stack): for i, (filename, _) in enumerate(stack):
# Slice the stack to the first frame that matches # Slice the stack to the first frame that matches
# the record pathname. # the record pathname.
@ -161,13 +166,19 @@ class LogEntry:
"level", "level",
"message", "message",
"exception", "exception",
"extracted_tb",
"root_cause", "root_cause",
"source", "source",
"count", "count",
"key", "key",
) )
def __init__(self, record: logging.LogRecord, source: tuple[str, int]) -> None: def __init__(
self,
record: logging.LogRecord,
paths_re: re.Pattern,
figure_out_source: bool = False,
) -> None:
"""Initialize a log entry.""" """Initialize a log entry."""
self.first_occurred = self.timestamp = record.created self.first_occurred = self.timestamp = record.created
self.name = record.name self.name = record.name
@ -176,16 +187,21 @@ class LogEntry:
# This must be manually tested when changing the code. # This must be manually tested when changing the code.
self.message = deque([_safe_get_message(record)], maxlen=5) self.message = deque([_safe_get_message(record)], maxlen=5)
self.exception = "" self.exception = ""
self.root_cause = None self.root_cause: str | None = None
extracted_tb: traceback.StackSummary | None = None
if record.exc_info: if record.exc_info:
self.exception = "".join(traceback.format_exception(*record.exc_info)) self.exception = "".join(traceback.format_exception(*record.exc_info))
_, _, tb = record.exc_info if extracted := traceback.extract_tb(record.exc_info[2]):
# Last line of traceback contains the root cause of the exception # Last line of traceback contains the root cause of the exception
if extracted := traceback.extract_tb(tb): extracted_tb = extracted
self.root_cause = str(extracted[-1]) self.root_cause = str(extracted[-1])
self.source = source if figure_out_source:
self.source = _figure_out_source(record, paths_re, extracted_tb)
else:
self.source = (record.pathname, record.lineno)
self.count = 1 self.count = 1
self.key = (self.name, source, self.root_cause) self.extracted_tb = extracted_tb
self.key = (self.name, self.source, self.root_cause)
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Convert object into dict to maintain backward compatibility.""" """Convert object into dict to maintain backward compatibility."""
@ -259,7 +275,7 @@ class LogErrorHandler(logging.Handler):
default upper limit is set to 50 (older entries are discarded) but can default upper limit is set to 50 (older entries are discarded) but can
be changed if needed. be changed if needed.
""" """
entry = LogEntry(record, _figure_out_source(record, self.paths_re)) entry = LogEntry(record, self.paths_re, figure_out_source=True)
self.records.add_entry(entry) self.records.add_entry(entry)
if self.fire_event: if self.fire_event:
self.hass.bus.fire(EVENT_SYSTEM_LOG, entry.to_dict()) self.hass.bus.fire(EVENT_SYSTEM_LOG, entry.to_dict())

View File

@ -30,7 +30,7 @@ from zigpy.state import State
from zigpy.types.named import EUI64 from zigpy.types.named import EUI64
from homeassistant import __path__ as HOMEASSISTANT_PATH from homeassistant import __path__ as HOMEASSISTANT_PATH
from homeassistant.components.system_log import LogEntry, _figure_out_source from homeassistant.components.system_log import LogEntry
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
@ -871,10 +871,9 @@ class LogRelayHandler(logging.Handler):
def emit(self, record: LogRecord) -> None: def emit(self, record: LogRecord) -> None:
"""Relay log message via dispatcher.""" """Relay log message via dispatcher."""
if record.levelno >= logging.WARN: entry = LogEntry(
entry = LogEntry(record, _figure_out_source(record, self.paths_re)) record, self.paths_re, figure_out_source=record.levelno >= logging.WARN
else: )
entry = LogEntry(record, (record.pathname, record.lineno))
async_dispatcher_send( async_dispatcher_send(
self.hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,

View File

@ -459,14 +459,19 @@ async def test__figure_out_source(hass: HomeAssistant) -> None:
except ValueError as ex: except ValueError as ex:
exc_info = (type(ex), ex, ex.__traceback__) exc_info = (type(ex), ex, ex.__traceback__)
mock_record = MagicMock( mock_record = MagicMock(
pathname="should not hit", pathname="figure_out_source is False",
lineno=5, lineno=5,
exc_info=exc_info, exc_info=exc_info,
) )
regex_str = f"({__file__})" regex_str = f"({__file__})"
paths_re = re.compile(regex_str)
file, line_no = system_log._figure_out_source( file, line_no = system_log._figure_out_source(
mock_record, mock_record,
re.compile(regex_str), paths_re,
traceback.extract_tb(exc_info[2]),
) )
assert file == __file__ assert file == __file__
assert line_no != 5 assert line_no != 5
entry = system_log.LogEntry(mock_record, paths_re, figure_out_source=False)
assert entry.source == ("figure_out_source is False", 5)