From 14ef0531f03411f6e015e9dc45888704aa1c7f53 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 29 Mar 2021 08:09:14 +0200 Subject: [PATCH] Address review comments from trace refactoring PRs (#48288) --- homeassistant/components/automation/trace.py | 66 +++++++++++++++++-- homeassistant/components/script/trace.py | 18 ++++- homeassistant/components/trace/__init__.py | 65 ------------------ homeassistant/components/trace/trace.py | 35 ---------- homeassistant/components/trace/utils.py | 2 +- .../components/trace/websocket_api.py | 40 +++++++---- tests/components/trace/test_websocket_api.py | 44 +++++++------ 7 files changed, 130 insertions(+), 140 deletions(-) delete mode 100644 homeassistant/components/trace/trace.py diff --git a/homeassistant/components/automation/trace.py b/homeassistant/components/automation/trace.py index a2c2e40c80c..31c20eb1402 100644 --- a/homeassistant/components/automation/trace.py +++ b/homeassistant/components/automation/trace.py @@ -2,25 +2,79 @@ from __future__ import annotations from contextlib import contextmanager +from typing import Any, Deque -from homeassistant.components.trace import AutomationTrace, async_store_trace +from homeassistant.components.trace import ActionTrace, async_store_trace +from homeassistant.core import Context +from homeassistant.helpers.trace import TraceElement # mypy: allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs, no-warn-return-any +class AutomationTrace(ActionTrace): + """Container for automation trace.""" + + def __init__( + self, + item_id: str, + config: dict[str, Any], + context: Context, + ): + """Container for automation trace.""" + key = ("automation", item_id) + super().__init__(key, config, context) + self._condition_trace: dict[str, Deque[TraceElement]] | None = None + + def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None: + """Set condition trace.""" + self._condition_trace = trace + + def as_dict(self) -> dict[str, Any]: + """Return dictionary version of this AutomationTrace.""" + + result = super().as_dict() + + condition_traces = {} + + if self._condition_trace: + for key, trace_list in self._condition_trace.items(): + condition_traces[key] = [item.as_dict() for item in trace_list] + result["condition_trace"] = condition_traces + + return result + + def as_short_dict(self) -> dict[str, Any]: + """Return a brief dictionary version of this AutomationTrace.""" + + result = super().as_short_dict() + + last_condition = None + trigger = None + + if self._condition_trace: + last_condition = list(self._condition_trace)[-1] + if self._variables: + trigger = self._variables.get("trigger", {}).get("description") + + result["trigger"] = trigger + result["last_condition"] = last_condition + + return result + + @contextmanager -def trace_automation(hass, item_id, config, context): - """Trace action execution of automation with item_id.""" - trace = AutomationTrace(item_id, config, context) +def trace_automation(hass, automation_id, config, context): + """Trace action execution of automation with automation_id.""" + trace = AutomationTrace(automation_id, config, context) async_store_trace(hass, trace) try: yield trace except Exception as ex: # pylint: disable=broad-except - if item_id: + if automation_id: trace.set_error(ex) raise ex finally: - if item_id: + if automation_id: trace.finished() diff --git a/homeassistant/components/script/trace.py b/homeassistant/components/script/trace.py index 09b22f98133..8183c882d3d 100644 --- a/homeassistant/components/script/trace.py +++ b/homeassistant/components/script/trace.py @@ -2,8 +2,24 @@ from __future__ import annotations from contextlib import contextmanager +from typing import Any -from homeassistant.components.trace import ScriptTrace, async_store_trace +from homeassistant.components.trace import ActionTrace, async_store_trace +from homeassistant.core import Context + + +class ScriptTrace(ActionTrace): + """Container for automation trace.""" + + def __init__( + self, + item_id: str, + config: dict[str, Any], + context: Context, + ): + """Container for automation trace.""" + key = ("script", item_id) + super().__init__(key, config, context) @contextmanager diff --git a/homeassistant/components/trace/__init__.py b/homeassistant/components/trace/__init__.py index b0211505e7c..9505c1c1264 100644 --- a/homeassistant/components/trace/__init__.py +++ b/homeassistant/components/trace/__init__.py @@ -128,68 +128,3 @@ class ActionTrace: result["last_action"] = last_action return result - - -class AutomationTrace(ActionTrace): - """Container for automation trace.""" - - def __init__( - self, - item_id: str, - config: dict[str, Any], - context: Context, - ): - """Container for automation trace.""" - key = ("automation", item_id) - super().__init__(key, config, context) - self._condition_trace: dict[str, Deque[TraceElement]] | None = None - - def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None: - """Set condition trace.""" - self._condition_trace = trace - - def as_dict(self) -> dict[str, Any]: - """Return dictionary version of this AutomationTrace.""" - - result = super().as_dict() - - condition_traces = {} - - if self._condition_trace: - for key, trace_list in self._condition_trace.items(): - condition_traces[key] = [item.as_dict() for item in trace_list] - result["condition_trace"] = condition_traces - - return result - - def as_short_dict(self) -> dict[str, Any]: - """Return a brief dictionary version of this AutomationTrace.""" - - result = super().as_short_dict() - - last_condition = None - trigger = None - - if self._condition_trace: - last_condition = list(self._condition_trace)[-1] - if self._variables: - trigger = self._variables.get("trigger", {}).get("description") - - result["trigger"] = trigger - result["last_condition"] = last_condition - - return result - - -class ScriptTrace(ActionTrace): - """Container for automation trace.""" - - def __init__( - self, - item_id: str, - config: dict[str, Any], - context: Context, - ): - """Container for automation trace.""" - key = ("script", item_id) - super().__init__(key, config, context) diff --git a/homeassistant/components/trace/trace.py b/homeassistant/components/trace/trace.py deleted file mode 100644 index e9255550ec5..00000000000 --- a/homeassistant/components/trace/trace.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Support for automation and script tracing and debugging.""" -from homeassistant.core import callback - -from .const import DATA_TRACE - - -@callback -def get_debug_trace(hass, key, run_id): - """Return a serializable debug trace.""" - return hass.data[DATA_TRACE][key][run_id] - - -@callback -def get_debug_traces(hass, key, summary=False): - """Return a serializable list of debug traces for an automation or script.""" - traces = [] - - for trace in hass.data[DATA_TRACE].get(key, {}).values(): - if summary: - traces.append(trace.as_short_dict()) - else: - traces.append(trace.as_dict()) - - return traces - - -@callback -def get_all_debug_traces(hass, summary=False): - """Return a serializable list of debug traces for all automations and scripts.""" - traces = [] - - for key in hass.data[DATA_TRACE]: - traces.extend(get_debug_traces(hass, key, summary)) - - return traces diff --git a/homeassistant/components/trace/utils.py b/homeassistant/components/trace/utils.py index 59bf8c98498..7e804724c55 100644 --- a/homeassistant/components/trace/utils.py +++ b/homeassistant/components/trace/utils.py @@ -1,4 +1,4 @@ -"""Helpers for automation and script tracing and debugging.""" +"""Helpers for script and automation tracing and debugging.""" from collections import OrderedDict from datetime import timedelta from typing import Any diff --git a/homeassistant/components/trace/websocket_api.py b/homeassistant/components/trace/websocket_api.py index 1b5270f6253..02d718a97ec 100644 --- a/homeassistant/components/trace/websocket_api.py +++ b/homeassistant/components/trace/websocket_api.py @@ -23,12 +23,12 @@ from homeassistant.helpers.script import ( debug_stop, ) -from .trace import DATA_TRACE, get_all_debug_traces, get_debug_trace, get_debug_traces +from .const import DATA_TRACE from .utils import TraceJSONEncoder # mypy: allow-untyped-calls, allow-untyped-defs -TRACE_DOMAINS = ["automation", "script"] +TRACE_DOMAINS = ("automation", "script") @callback @@ -57,33 +57,47 @@ def async_setup(hass: HomeAssistant) -> None: } ) def websocket_trace_get(hass, connection, msg): - """Get an automation or script trace.""" + """Get an script or automation trace.""" key = (msg["domain"], msg["item_id"]) run_id = msg["run_id"] - trace = get_debug_trace(hass, key, run_id) + trace = hass.data[DATA_TRACE][key][run_id] message = websocket_api.messages.result_message(msg["id"], trace) connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False)) +def get_debug_traces(hass, key): + """Return a serializable list of debug traces for an script or automation.""" + traces = [] + + for trace in hass.data[DATA_TRACE].get(key, {}).values(): + traces.append(trace.as_short_dict()) + + return traces + + @callback @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required("type"): "trace/list", - vol.Inclusive("domain", "id"): vol.In(TRACE_DOMAINS), - vol.Inclusive("item_id", "id"): str, + vol.Required("domain", "id"): vol.In(TRACE_DOMAINS), + vol.Optional("item_id", "id"): str, } ) def websocket_trace_list(hass, connection, msg): - """Summarize automation and script traces.""" - key = (msg["domain"], msg["item_id"]) if "item_id" in msg else None + """Summarize script and automation traces.""" + domain = msg["domain"] + key = (domain, msg["item_id"]) if "item_id" in msg else None if not key: - traces = get_all_debug_traces(hass, summary=True) + traces = [] + for key in hass.data[DATA_TRACE]: + if key[0] == domain: + traces.extend(get_debug_traces(hass, key)) else: - traces = get_debug_traces(hass, key, summary=True) + traces = get_debug_traces(hass, key) connection.send_result(msg["id"], traces) @@ -230,7 +244,7 @@ def websocket_subscribe_breakpoint_events(hass, connection, msg): } ) def websocket_debug_continue(hass, connection, msg): - """Resume execution of halted automation or script.""" + """Resume execution of halted script or automation.""" key = (msg["domain"], msg["item_id"]) run_id = msg["run_id"] @@ -250,7 +264,7 @@ def websocket_debug_continue(hass, connection, msg): } ) def websocket_debug_step(hass, connection, msg): - """Single step a halted automation or script.""" + """Single step a halted script or automation.""" key = (msg["domain"], msg["item_id"]) run_id = msg["run_id"] @@ -270,7 +284,7 @@ def websocket_debug_step(hass, connection, msg): } ) def websocket_debug_stop(hass, connection, msg): - """Stop a halted automation or script.""" + """Stop a halted script or automation.""" key = (msg["domain"], msg["item_id"]) run_id = msg["run_id"] diff --git a/tests/components/trace/test_websocket_api.py b/tests/components/trace/test_websocket_api.py index f198cf1b55a..16c5304327a 100644 --- a/tests/components/trace/test_websocket_api.py +++ b/tests/components/trace/test_websocket_api.py @@ -9,7 +9,7 @@ from tests.common import assert_lists_same def _find_run_id(traces, trace_type, item_id): - """Find newest run_id for an automation or script.""" + """Find newest run_id for an script or automation.""" for trace in reversed(traces): if trace["domain"] == trace_type and trace["item_id"] == item_id: return trace["run_id"] @@ -18,7 +18,7 @@ def _find_run_id(traces, trace_type, item_id): def _find_traces(traces, trace_type, item_id): - """Find traces for an automation or script.""" + """Find traces for an script or automation.""" return [ trace for trace in traces @@ -30,7 +30,7 @@ def _find_traces(traces, trace_type, item_id): "domain, prefix", [("automation", "action"), ("script", "sequence")] ) async def test_get_trace(hass, hass_ws_client, domain, prefix): - """Test tracing an automation or script.""" + """Test tracing an script or automation.""" id = 1 def next_id(): @@ -92,7 +92,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix): await hass.async_block_till_done() # List traces - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] run_id = _find_run_id(response["result"], domain, "sun") @@ -140,7 +140,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix): await hass.async_block_till_done() # List traces - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] run_id = _find_run_id(response["result"], domain, "moon") @@ -193,7 +193,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix): await hass.async_block_till_done() # List traces - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] run_id = _find_run_id(response["result"], "automation", "moon") @@ -233,7 +233,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix): await hass.async_block_till_done() # List traces - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] run_id = _find_run_id(response["result"], "automation", "moon") @@ -280,7 +280,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix): @pytest.mark.parametrize("domain", ["automation", "script"]) async def test_trace_overflow(hass, hass_ws_client, domain): - """Test the number of stored traces per automation or script is limited.""" + """Test the number of stored traces per script or automation is limited.""" id = 1 def next_id(): @@ -313,7 +313,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain): client = await hass_ws_client() - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] assert response["result"] == [] @@ -328,7 +328,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain): await hass.async_block_till_done() # List traces - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] assert len(_find_traces(response["result"], domain, "moon")) == 1 @@ -343,7 +343,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain): await hass.services.async_call("script", "moon") await hass.async_block_till_done() - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] moon_traces = _find_traces(response["result"], domain, "moon") @@ -358,7 +358,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain): "domain, prefix", [("automation", "action"), ("script", "sequence")] ) async def test_list_traces(hass, hass_ws_client, domain, prefix): - """Test listing automation and script traces.""" + """Test listing script and automation traces.""" id = 1 def next_id(): @@ -398,7 +398,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix): client = await hass_ws_client() - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] assert response["result"] == [] @@ -418,7 +418,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix): await hass.async_block_till_done() # Get trace - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] assert len(response["result"]) == 1 @@ -461,7 +461,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix): await hass.async_block_till_done() # Get trace - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain}) response = await client.receive_json() assert response["success"] assert len(_find_traces(response["result"], domain, "moon")) == 3 @@ -585,7 +585,7 @@ async def test_nested_traces(hass, hass_ws_client, domain, prefix): "domain, prefix", [("automation", "action"), ("script", "sequence")] ) async def test_breakpoints(hass, hass_ws_client, domain, prefix): - """Test automation and script breakpoints.""" + """Test script and automation breakpoints.""" id = 1 def next_id(): @@ -594,7 +594,9 @@ async def test_breakpoints(hass, hass_ws_client, domain, prefix): return id async def assert_last_action(item_id, expected_action, expected_state): - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json( + {"id": next_id(), "type": "trace/list", "domain": domain} + ) response = await client.receive_json() assert response["success"] trace = _find_traces(response["result"], domain, item_id)[-1] @@ -770,7 +772,9 @@ async def test_breakpoints_2(hass, hass_ws_client, domain, prefix): return id async def assert_last_action(item_id, expected_action, expected_state): - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json( + {"id": next_id(), "type": "trace/list", "domain": domain} + ) response = await client.receive_json() assert response["success"] trace = _find_traces(response["result"], domain, item_id)[-1] @@ -883,7 +887,9 @@ async def test_breakpoints_3(hass, hass_ws_client, domain, prefix): return id async def assert_last_action(item_id, expected_action, expected_state): - await client.send_json({"id": next_id(), "type": "trace/list"}) + await client.send_json( + {"id": next_id(), "type": "trace/list", "domain": domain} + ) response = await client.receive_json() assert response["success"] trace = _find_traces(response["result"], domain, item_id)[-1]