diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 24121359ac0..862a664976b 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -390,8 +390,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity): reason = f' by {run_variables["trigger"]["description"]}' self._logger.debug("Automation triggered%s", reason) + # Create a new context referring to the old context. + parent_id = None if context is None else context.id + trigger_context = Context(parent_id=parent_id) + with trace_automation( - self.hass, self.unique_id, self._raw_config, context + self.hass, self.unique_id, self._raw_config, trigger_context ) as automation_trace: if self._variables: try: @@ -421,10 +425,6 @@ class AutomationEntity(ToggleEntity, RestoreEntity): # Prepare tracing the execution of the automation's actions automation_trace.set_action_trace(trace_get()) - # Create a new context referring to the old context. - parent_id = None if context is None else context.id - trigger_context = Context(parent_id=parent_id) - self.async_set_context(trigger_context) event_data = { ATTR_NAME: self._name, diff --git a/tests/components/automation/test_websocket_api.py b/tests/components/automation/test_websocket_api.py index 106f687f4ee..99b9540b06e 100644 --- a/tests/components/automation/test_websocket_api.py +++ b/tests/components/automation/test_websocket_api.py @@ -3,6 +3,7 @@ from unittest.mock import patch from homeassistant.bootstrap import async_setup_component from homeassistant.components import automation, config +from homeassistant.core import Context from tests.common import assert_lists_same from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 @@ -52,7 +53,8 @@ async def test_get_automation_trace(hass, hass_ws_client): client = await hass_ws_client() # Trigger "sun" automation - hass.bus.async_fire("test_event") + context = Context() + hass.bus.async_fire("test_event", context=context) await hass.async_block_till_done() # List traces @@ -73,6 +75,7 @@ async def test_get_automation_trace(hass, hass_ws_client): response = await client.receive_json() assert response["success"] trace = response["result"] + assert trace["context"]["parent_id"] == context.id assert len(trace["action_trace"]) == 1 assert len(trace["action_trace"]["action/0"]) == 1 assert trace["action_trace"]["action/0"][0]["error"]