Address review comments from trace refactoring PRs (#48288)

This commit is contained in:
Erik Montnemery 2021-03-29 08:09:14 +02:00 committed by GitHub
parent ee81869c05
commit 14ef0531f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 130 additions and 140 deletions

View File

@ -2,25 +2,79 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Deque
from homeassistant.components.trace import AutomationTrace, async_store_trace from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.core import Context
from homeassistant.helpers.trace import TraceElement
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any # mypy: no-check-untyped-defs, no-warn-return-any
class AutomationTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("automation", item_id)
super().__init__(key, config, context)
self._condition_trace: dict[str, Deque[TraceElement]] | None = None
def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None:
"""Set condition trace."""
self._condition_trace = trace
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this AutomationTrace."""
result = super().as_dict()
condition_traces = {}
if self._condition_trace:
for key, trace_list in self._condition_trace.items():
condition_traces[key] = [item.as_dict() for item in trace_list]
result["condition_trace"] = condition_traces
return result
def as_short_dict(self) -> dict[str, Any]:
"""Return a brief dictionary version of this AutomationTrace."""
result = super().as_short_dict()
last_condition = None
trigger = None
if self._condition_trace:
last_condition = list(self._condition_trace)[-1]
if self._variables:
trigger = self._variables.get("trigger", {}).get("description")
result["trigger"] = trigger
result["last_condition"] = last_condition
return result
@contextmanager @contextmanager
def trace_automation(hass, item_id, config, context): def trace_automation(hass, automation_id, config, context):
"""Trace action execution of automation with item_id.""" """Trace action execution of automation with automation_id."""
trace = AutomationTrace(item_id, config, context) trace = AutomationTrace(automation_id, config, context)
async_store_trace(hass, trace) async_store_trace(hass, trace)
try: try:
yield trace yield trace
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
if item_id: if automation_id:
trace.set_error(ex) trace.set_error(ex)
raise ex raise ex
finally: finally:
if item_id: if automation_id:
trace.finished() trace.finished()

View File

@ -2,8 +2,24 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any
from homeassistant.components.trace import ScriptTrace, async_store_trace from homeassistant.components.trace import ActionTrace, async_store_trace
from homeassistant.core import Context
class ScriptTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("script", item_id)
super().__init__(key, config, context)
@contextmanager @contextmanager

View File

@ -128,68 +128,3 @@ class ActionTrace:
result["last_action"] = last_action result["last_action"] = last_action
return result return result
class AutomationTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("automation", item_id)
super().__init__(key, config, context)
self._condition_trace: dict[str, Deque[TraceElement]] | None = None
def set_condition_trace(self, trace: dict[str, Deque[TraceElement]]) -> None:
"""Set condition trace."""
self._condition_trace = trace
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this AutomationTrace."""
result = super().as_dict()
condition_traces = {}
if self._condition_trace:
for key, trace_list in self._condition_trace.items():
condition_traces[key] = [item.as_dict() for item in trace_list]
result["condition_trace"] = condition_traces
return result
def as_short_dict(self) -> dict[str, Any]:
"""Return a brief dictionary version of this AutomationTrace."""
result = super().as_short_dict()
last_condition = None
trigger = None
if self._condition_trace:
last_condition = list(self._condition_trace)[-1]
if self._variables:
trigger = self._variables.get("trigger", {}).get("description")
result["trigger"] = trigger
result["last_condition"] = last_condition
return result
class ScriptTrace(ActionTrace):
"""Container for automation trace."""
def __init__(
self,
item_id: str,
config: dict[str, Any],
context: Context,
):
"""Container for automation trace."""
key = ("script", item_id)
super().__init__(key, config, context)

View File

@ -1,35 +0,0 @@
"""Support for automation and script tracing and debugging."""
from homeassistant.core import callback
from .const import DATA_TRACE
@callback
def get_debug_trace(hass, key, run_id):
"""Return a serializable debug trace."""
return hass.data[DATA_TRACE][key][run_id]
@callback
def get_debug_traces(hass, key, summary=False):
"""Return a serializable list of debug traces for an automation or script."""
traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values():
if summary:
traces.append(trace.as_short_dict())
else:
traces.append(trace.as_dict())
return traces
@callback
def get_all_debug_traces(hass, summary=False):
"""Return a serializable list of debug traces for all automations and scripts."""
traces = []
for key in hass.data[DATA_TRACE]:
traces.extend(get_debug_traces(hass, key, summary))
return traces

View File

@ -1,4 +1,4 @@
"""Helpers for automation and script tracing and debugging.""" """Helpers for script and automation tracing and debugging."""
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any

View File

@ -23,12 +23,12 @@ from homeassistant.helpers.script import (
debug_stop, debug_stop,
) )
from .trace import DATA_TRACE, get_all_debug_traces, get_debug_trace, get_debug_traces from .const import DATA_TRACE
from .utils import TraceJSONEncoder from .utils import TraceJSONEncoder
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
TRACE_DOMAINS = ["automation", "script"] TRACE_DOMAINS = ("automation", "script")
@callback @callback
@ -57,33 +57,47 @@ def async_setup(hass: HomeAssistant) -> None:
} }
) )
def websocket_trace_get(hass, connection, msg): def websocket_trace_get(hass, connection, msg):
"""Get an automation or script trace.""" """Get an script or automation trace."""
key = (msg["domain"], msg["item_id"]) key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"] run_id = msg["run_id"]
trace = get_debug_trace(hass, key, run_id) trace = hass.data[DATA_TRACE][key][run_id]
message = websocket_api.messages.result_message(msg["id"], trace) message = websocket_api.messages.result_message(msg["id"], trace)
connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False)) connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False))
def get_debug_traces(hass, key):
"""Return a serializable list of debug traces for an script or automation."""
traces = []
for trace in hass.data[DATA_TRACE].get(key, {}).values():
traces.append(trace.as_short_dict())
return traces
@callback @callback
@websocket_api.require_admin @websocket_api.require_admin
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "trace/list", vol.Required("type"): "trace/list",
vol.Inclusive("domain", "id"): vol.In(TRACE_DOMAINS), vol.Required("domain", "id"): vol.In(TRACE_DOMAINS),
vol.Inclusive("item_id", "id"): str, vol.Optional("item_id", "id"): str,
} }
) )
def websocket_trace_list(hass, connection, msg): def websocket_trace_list(hass, connection, msg):
"""Summarize automation and script traces.""" """Summarize script and automation traces."""
key = (msg["domain"], msg["item_id"]) if "item_id" in msg else None domain = msg["domain"]
key = (domain, msg["item_id"]) if "item_id" in msg else None
if not key: if not key:
traces = get_all_debug_traces(hass, summary=True) traces = []
for key in hass.data[DATA_TRACE]:
if key[0] == domain:
traces.extend(get_debug_traces(hass, key))
else: else:
traces = get_debug_traces(hass, key, summary=True) traces = get_debug_traces(hass, key)
connection.send_result(msg["id"], traces) connection.send_result(msg["id"], traces)
@ -230,7 +244,7 @@ def websocket_subscribe_breakpoint_events(hass, connection, msg):
} }
) )
def websocket_debug_continue(hass, connection, msg): def websocket_debug_continue(hass, connection, msg):
"""Resume execution of halted automation or script.""" """Resume execution of halted script or automation."""
key = (msg["domain"], msg["item_id"]) key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"] run_id = msg["run_id"]
@ -250,7 +264,7 @@ def websocket_debug_continue(hass, connection, msg):
} }
) )
def websocket_debug_step(hass, connection, msg): def websocket_debug_step(hass, connection, msg):
"""Single step a halted automation or script.""" """Single step a halted script or automation."""
key = (msg["domain"], msg["item_id"]) key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"] run_id = msg["run_id"]
@ -270,7 +284,7 @@ def websocket_debug_step(hass, connection, msg):
} }
) )
def websocket_debug_stop(hass, connection, msg): def websocket_debug_stop(hass, connection, msg):
"""Stop a halted automation or script.""" """Stop a halted script or automation."""
key = (msg["domain"], msg["item_id"]) key = (msg["domain"], msg["item_id"])
run_id = msg["run_id"] run_id = msg["run_id"]

View File

@ -9,7 +9,7 @@ from tests.common import assert_lists_same
def _find_run_id(traces, trace_type, item_id): def _find_run_id(traces, trace_type, item_id):
"""Find newest run_id for an automation or script.""" """Find newest run_id for an script or automation."""
for trace in reversed(traces): for trace in reversed(traces):
if trace["domain"] == trace_type and trace["item_id"] == item_id: if trace["domain"] == trace_type and trace["item_id"] == item_id:
return trace["run_id"] return trace["run_id"]
@ -18,7 +18,7 @@ def _find_run_id(traces, trace_type, item_id):
def _find_traces(traces, trace_type, item_id): def _find_traces(traces, trace_type, item_id):
"""Find traces for an automation or script.""" """Find traces for an script or automation."""
return [ return [
trace trace
for trace in traces for trace in traces
@ -30,7 +30,7 @@ def _find_traces(traces, trace_type, item_id):
"domain, prefix", [("automation", "action"), ("script", "sequence")] "domain, prefix", [("automation", "action"), ("script", "sequence")]
) )
async def test_get_trace(hass, hass_ws_client, domain, prefix): async def test_get_trace(hass, hass_ws_client, domain, prefix):
"""Test tracing an automation or script.""" """Test tracing an script or automation."""
id = 1 id = 1
def next_id(): def next_id():
@ -92,7 +92,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done() await hass.async_block_till_done()
# List traces # List traces
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
run_id = _find_run_id(response["result"], domain, "sun") run_id = _find_run_id(response["result"], domain, "sun")
@ -140,7 +140,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done() await hass.async_block_till_done()
# List traces # List traces
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
run_id = _find_run_id(response["result"], domain, "moon") run_id = _find_run_id(response["result"], domain, "moon")
@ -193,7 +193,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done() await hass.async_block_till_done()
# List traces # List traces
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
run_id = _find_run_id(response["result"], "automation", "moon") run_id = _find_run_id(response["result"], "automation", "moon")
@ -233,7 +233,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done() await hass.async_block_till_done()
# List traces # List traces
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
run_id = _find_run_id(response["result"], "automation", "moon") run_id = _find_run_id(response["result"], "automation", "moon")
@ -280,7 +280,7 @@ async def test_get_trace(hass, hass_ws_client, domain, prefix):
@pytest.mark.parametrize("domain", ["automation", "script"]) @pytest.mark.parametrize("domain", ["automation", "script"])
async def test_trace_overflow(hass, hass_ws_client, domain): async def test_trace_overflow(hass, hass_ws_client, domain):
"""Test the number of stored traces per automation or script is limited.""" """Test the number of stored traces per script or automation is limited."""
id = 1 id = 1
def next_id(): def next_id():
@ -313,7 +313,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
client = await hass_ws_client() client = await hass_ws_client()
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert response["result"] == [] assert response["result"] == []
@ -328,7 +328,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
await hass.async_block_till_done() await hass.async_block_till_done()
# List traces # List traces
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert len(_find_traces(response["result"], domain, "moon")) == 1 assert len(_find_traces(response["result"], domain, "moon")) == 1
@ -343,7 +343,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
await hass.services.async_call("script", "moon") await hass.services.async_call("script", "moon")
await hass.async_block_till_done() await hass.async_block_till_done()
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
moon_traces = _find_traces(response["result"], domain, "moon") moon_traces = _find_traces(response["result"], domain, "moon")
@ -358,7 +358,7 @@ async def test_trace_overflow(hass, hass_ws_client, domain):
"domain, prefix", [("automation", "action"), ("script", "sequence")] "domain, prefix", [("automation", "action"), ("script", "sequence")]
) )
async def test_list_traces(hass, hass_ws_client, domain, prefix): async def test_list_traces(hass, hass_ws_client, domain, prefix):
"""Test listing automation and script traces.""" """Test listing script and automation traces."""
id = 1 id = 1
def next_id(): def next_id():
@ -398,7 +398,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
client = await hass_ws_client() client = await hass_ws_client()
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert response["result"] == [] assert response["result"] == []
@ -418,7 +418,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done() await hass.async_block_till_done()
# Get trace # Get trace
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert len(response["result"]) == 1 assert len(response["result"]) == 1
@ -461,7 +461,7 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix):
await hass.async_block_till_done() await hass.async_block_till_done()
# Get trace # Get trace
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json({"id": next_id(), "type": "trace/list", "domain": domain})
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert len(_find_traces(response["result"], domain, "moon")) == 3 assert len(_find_traces(response["result"], domain, "moon")) == 3
@ -585,7 +585,7 @@ async def test_nested_traces(hass, hass_ws_client, domain, prefix):
"domain, prefix", [("automation", "action"), ("script", "sequence")] "domain, prefix", [("automation", "action"), ("script", "sequence")]
) )
async def test_breakpoints(hass, hass_ws_client, domain, prefix): async def test_breakpoints(hass, hass_ws_client, domain, prefix):
"""Test automation and script breakpoints.""" """Test script and automation breakpoints."""
id = 1 id = 1
def next_id(): def next_id():
@ -594,7 +594,9 @@ async def test_breakpoints(hass, hass_ws_client, domain, prefix):
return id return id
async def assert_last_action(item_id, expected_action, expected_state): async def assert_last_action(item_id, expected_action, expected_state):
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json(
{"id": next_id(), "type": "trace/list", "domain": domain}
)
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
trace = _find_traces(response["result"], domain, item_id)[-1] trace = _find_traces(response["result"], domain, item_id)[-1]
@ -770,7 +772,9 @@ async def test_breakpoints_2(hass, hass_ws_client, domain, prefix):
return id return id
async def assert_last_action(item_id, expected_action, expected_state): async def assert_last_action(item_id, expected_action, expected_state):
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json(
{"id": next_id(), "type": "trace/list", "domain": domain}
)
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
trace = _find_traces(response["result"], domain, item_id)[-1] trace = _find_traces(response["result"], domain, item_id)[-1]
@ -883,7 +887,9 @@ async def test_breakpoints_3(hass, hass_ws_client, domain, prefix):
return id return id
async def assert_last_action(item_id, expected_action, expected_state): async def assert_last_action(item_id, expected_action, expected_state):
await client.send_json({"id": next_id(), "type": "trace/list"}) await client.send_json(
{"id": next_id(), "type": "trace/list", "domain": domain}
)
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
trace = _find_traces(response["result"], domain, item_id)[-1] trace = _find_traces(response["result"], domain, item_id)[-1]