diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 6fb617671b2..4d315f428c3 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -157,7 +157,7 @@ SCRIPT_DEBUG_CONTINUE_STOP: SignalTypeFormat[Literal["continue", "stop"]] = ( ) SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all" -script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None) +script_stack_cv: ContextVar[list[str] | None] = ContextVar("script_stack", default=None) class ScriptData(TypedDict): @@ -452,7 +452,7 @@ class _ScriptRun: if (script_stack := script_stack_cv.get()) is None: script_stack = [] script_stack_cv.set(script_stack) - script_stack.append(id(self._script)) + script_stack.append(self._script.unique_id) response = None try: @@ -1401,6 +1401,7 @@ class Script: self.sequence = sequence template.attach(hass, self.sequence) self.name = name + self.unique_id = f"{domain}.{name}-{id(self)}" self.domain = domain self.running_description = running_description or f"{domain} script" self._change_listener = change_listener @@ -1723,10 +1724,21 @@ class Script: if ( self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED) and script_stack is not None - and id(self) in script_stack + and self.unique_id in script_stack ): script_execution_set("disallowed_recursion_detected") - self._log("Disallowed recursion detected", level=logging.WARNING) + formatted_stack = [ + f"- {name_id.partition('-')[0]}" for name_id in script_stack + ] + self._log( + "Disallowed recursion detected, " + f"{script_stack[-1].partition('-')[0]} tried to start " + f"{self.domain}.{self.name} which is already running " + "in the current execution path; " + "Traceback (most recent call last):\n" + f"{"\n".join(formatted_stack)}", + level=logging.WARNING, + ) return None if self.script_mode != SCRIPT_MODE_QUEUED: diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 948255ccea5..47221a77cee 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -6247,3 +6247,72 @@ async def test_stopping_run_before_starting( # would hang indefinitely. run = script._ScriptRun(hass, script_obj, {}, None, True) await run.async_stop() + + +async def test_disallowed_recursion( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test a queued mode script disallowed recursion.""" + context = Context() + calls = 0 + alias = "event step" + sequence1 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_2"}) + script1_obj = script.Script( + hass, + sequence1, + "Test Name1", + "test_domain1", + script_mode="queued", + running_description="test script1", + ) + + sequence2 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_3"}) + script2_obj = script.Script( + hass, + sequence2, + "Test Name2", + "test_domain2", + script_mode="queued", + running_description="test script2", + ) + + sequence3 = cv.SCRIPT_SCHEMA({"alias": alias, "service": "test.call_script_1"}) + script3_obj = script.Script( + hass, + sequence3, + "Test Name3", + "test_domain3", + script_mode="queued", + running_description="test script3", + ) + + async def _async_service_handler_1(*args, **kwargs) -> None: + await script1_obj.async_run(context=context) + + hass.services.async_register("test", "call_script_1", _async_service_handler_1) + + async def _async_service_handler_2(*args, **kwargs) -> None: + await script2_obj.async_run(context=context) + + hass.services.async_register("test", "call_script_2", _async_service_handler_2) + + async def _async_service_handler_3(*args, **kwargs) -> None: + await script3_obj.async_run(context=context) + + hass.services.async_register("test", "call_script_3", _async_service_handler_3) + + await script1_obj.async_run(context=context) + await hass.async_block_till_done() + + assert calls == 0 + assert ( + "Test Name1: Disallowed recursion detected, " + "test_domain3.Test Name3 tried to start test_domain1.Test Name1" + " which is already running in the current execution path; " + "Traceback (most recent call last):" + ) in caplog.text + assert ( + "- test_domain1.Test Name1\n" + "- test_domain2.Test Name2\n" + "- test_domain3.Test Name3" + ) in caplog.text