diff --git a/homeassistant/components/trace/__init__.py b/homeassistant/components/trace/__init__.py index 43deefaa769..b0211505e7c 100644 --- a/homeassistant/components/trace/__init__.py +++ b/homeassistant/components/trace/__init__.py @@ -6,7 +6,12 @@ from itertools import count from typing import Any, Deque from homeassistant.core import Context -from homeassistant.helpers.trace import TraceElement, trace_id_set +from homeassistant.helpers.trace import ( + TraceElement, + trace_id_get, + trace_id_set, + trace_set_child_id, +) import homeassistant.util.dt as dt_util from . import websocket_api @@ -55,6 +60,8 @@ class ActionTrace: self._timestamp_start: dt.datetime = dt_util.utcnow() self.key: tuple[str, str] = key self._variables: dict[str, Any] | None = None + if trace_id_get(): + trace_set_child_id(self.key, self.run_id) trace_id_set((key, self.run_id)) def set_action_trace(self, trace: dict[str, Deque[TraceElement]]) -> None: diff --git a/homeassistant/helpers/trace.py b/homeassistant/helpers/trace.py index ba39e19943b..5d5a0f5ff03 100644 --- a/homeassistant/helpers/trace.py +++ b/homeassistant/helpers/trace.py @@ -16,6 +16,8 @@ class TraceElement: def __init__(self, variables: TemplateVarsType, path: str): """Container for trace data.""" + self._child_key: tuple[str, str] | None = None + self._child_run_id: str | None = None self._error: Exception | None = None self.path: str = path self._result: dict | None = None @@ -36,6 +38,11 @@ class TraceElement: """Container for trace data.""" return str(self.as_dict()) + def set_child_id(self, child_key: tuple[str, str], child_run_id: str) -> None: + """Set trace id of a nested script run.""" + self._child_key = child_key + self._child_run_id = child_run_id + def set_error(self, ex: Exception) -> None: """Set error.""" self._error = ex @@ -47,6 +54,12 @@ class TraceElement: def as_dict(self) -> dict[str, Any]: """Return dictionary version of this TraceElement.""" result: dict[str, Any] = {"path": self.path, "timestamp": self._timestamp} + if self._child_key is not None: + result["child_id"] = { + "domain": self._child_key[0], + "item_id": self._child_key[1], + "run_id": str(self._child_run_id), + } if self._variables: result["changed_variables"] = self._variables if self._error is not None: @@ -161,6 +174,13 @@ def trace_clear() -> None: variables_cv.set(None) +def trace_set_child_id(child_key: tuple[str, str], child_run_id: str) -> None: + """Set child trace_id of TraceElement at the top of the stack.""" + node = cast(TraceElement, trace_stack_top(trace_stack_cv)) + if node: + node.set_child_id(child_key, child_run_id) + + def trace_set_result(**kwargs: Any) -> None: """Set the result of TraceElement at the top of the stack.""" node = cast(TraceElement, trace_stack_top(trace_stack_cv)) diff --git a/tests/components/trace/test_websocket_api.py b/tests/components/trace/test_websocket_api.py index 8dc09731b79..f198cf1b55a 100644 --- a/tests/components/trace/test_websocket_api.py +++ b/tests/components/trace/test_websocket_api.py @@ -26,12 +26,6 @@ def _find_traces(traces, trace_type, item_id): ] -# TODO: Remove -def _find_traces_for_automation(traces, item_id): - """Find traces for an automation.""" - return [trace for trace in traces if trace["item_id"] == item_id] - - @pytest.mark.parametrize( "domain, prefix", [("automation", "action"), ("script", "sequence")] ) @@ -515,6 +509,78 @@ async def test_list_traces(hass, hass_ws_client, domain, prefix): assert trace["trigger"] == "event 'test_event2'" +@pytest.mark.parametrize( + "domain, prefix", [("automation", "action"), ("script", "sequence")] +) +async def test_nested_traces(hass, hass_ws_client, domain, prefix): + """Test nested automation and script traces.""" + id = 1 + + def next_id(): + nonlocal id + id += 1 + return id + + sun_config = { + "id": "sun", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": {"service": "script.moon"}, + } + moon_config = { + "sequence": {"event": "another_event"}, + } + if domain == "script": + sun_config = {"sequence": sun_config["action"]} + + if domain == "automation": + assert await async_setup_component(hass, domain, {domain: [sun_config]}) + assert await async_setup_component( + hass, "script", {"script": {"moon": moon_config}} + ) + else: + assert await async_setup_component( + hass, domain, {domain: {"sun": sun_config, "moon": moon_config}} + ) + + client = await hass_ws_client() + + # Trigger "sun" automation / run "sun" script + if domain == "automation": + hass.bus.async_fire("test_event") + else: + await hass.services.async_call("script", "sun") + await hass.async_block_till_done() + + # List traces + await client.send_json({"id": next_id(), "type": "trace/list"}) + response = await client.receive_json() + assert response["success"] + assert len(response["result"]) == 2 + assert len(_find_traces(response["result"], domain, "sun")) == 1 + assert len(_find_traces(response["result"], "script", "moon")) == 1 + sun_run_id = _find_run_id(response["result"], domain, "sun") + moon_run_id = _find_run_id(response["result"], "script", "moon") + assert sun_run_id != moon_run_id + + # Get trace + await client.send_json( + { + "id": next_id(), + "type": "trace/get", + "domain": domain, + "item_id": "sun", + "run_id": sun_run_id, + } + ) + response = await client.receive_json() + assert response["success"] + trace = response["result"] + assert len(trace["action_trace"]) == 1 + assert len(trace["action_trace"][f"{prefix}/0"]) == 1 + child_id = trace["action_trace"][f"{prefix}/0"][0]["child_id"] + assert child_id == {"domain": "script", "item_id": "moon", "run_id": moon_run_id} + + @pytest.mark.parametrize( "domain, prefix", [("automation", "action"), ("script", "sequence")] )