From 46f27fdefdf38fff885a64105310875813e8587b Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 15 Mar 2022 18:48:54 +0100 Subject: [PATCH] Don't prevent automations from triggering themselves (#68178) --- .../components/automation/__init__.py | 5 + homeassistant/helpers/script.py | 2 +- tests/components/automation/test_init.py | 207 +++++++++++++++++- tests/components/script/test_init.py | 4 - 4 files changed, 212 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index a8d009ac2bb..3c9cd07a146 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -54,6 +54,7 @@ from homeassistant.helpers.script import ( CONF_MAX, CONF_MAX_EXCEEDED, Script, + script_stack_cv, ) from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.service import ( @@ -505,6 +506,10 @@ class AutomationEntity(ToggleEntity, RestoreEntity): EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context ) + # Make a new empty script stack; automations are allowed + # to recursively trigger themselves + script_stack_cv.set([]) + try: with trace_path("action"): await self.action_script.async_run( diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index b7e7c05478c..1ede1d10d89 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1247,7 +1247,7 @@ class Script: and id(self) in script_stack ): script_execution_set("disallowed_recursion_detected") - _LOGGER.warning("Disallowed recursion detected") + self._log("Disallowed recursion detected", level=logging.WARNING) return if self.script_mode != SCRIPT_MODE_QUEUED: diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index b90c6a90819..1c90abe72ca 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1,5 +1,6 @@ """The tests for the automation component.""" import asyncio +from datetime import timedelta import logging from unittest.mock import Mock, patch @@ -25,14 +26,30 @@ from homeassistant.const import ( STATE_OFF, STATE_ON, ) -from homeassistant.core import Context, CoreState, State, callback +from homeassistant.core import ( + Context, + CoreState, + HomeAssistant, + ServiceCall, + State, + callback, +) from homeassistant.exceptions import HomeAssistantError, Unauthorized +from homeassistant.helpers.script import ( + SCRIPT_MODE_CHOICES, + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUED, + SCRIPT_MODE_RESTART, + SCRIPT_MODE_SINGLE, + _async_stop_scripts_at_shutdown, +) from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from tests.common import ( assert_setup_component, async_capture_events, + async_fire_time_changed, async_mock_service, mock_restore_cache, ) @@ -1570,3 +1587,191 @@ async def test_trigger_condition_explicit_id(hass, calls): await hass.async_block_till_done() assert len(calls) == 2 assert calls[-1].data.get("param") == "two" + + +@pytest.mark.parametrize( + "automation_mode,automation_runs", + ( + (SCRIPT_MODE_PARALLEL, 2), + (SCRIPT_MODE_QUEUED, 2), + (SCRIPT_MODE_RESTART, 2), + (SCRIPT_MODE_SINGLE, 1), + ), +) +@pytest.mark.parametrize( + "script_mode,script_warning_msg", + ( + (SCRIPT_MODE_PARALLEL, "script1: Maximum number of runs exceeded"), + (SCRIPT_MODE_QUEUED, "script1: Disallowed recursion detected"), + (SCRIPT_MODE_RESTART, "script1: Disallowed recursion detected"), + (SCRIPT_MODE_SINGLE, "script1: Already running"), + ), +) +async def test_recursive_automation_starting_script( + hass: HomeAssistant, + automation_mode, + automation_runs, + script_mode, + script_warning_msg, + caplog, +): + """Test starting automations does not interfere with script deadlock prevention.""" + + # Fail if additional script modes are added to + # make sure we cover all script modes in tests + assert SCRIPT_MODE_CHOICES == [ + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUED, + SCRIPT_MODE_RESTART, + SCRIPT_MODE_SINGLE, + ] + + stop_scripts_at_shutdown_called = asyncio.Event() + real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown + + async def mock_stop_scripts_at_shutdown(*args): + await real_stop_scripts_at_shutdown(*args) + stop_scripts_at_shutdown_called.set() + + with patch( + "homeassistant.helpers.script._async_stop_scripts_at_shutdown", + wraps=mock_stop_scripts_at_shutdown, + ): + assert await async_setup_component( + hass, + "script", + { + "script": { + "script1": { + "mode": script_mode, + "sequence": [ + {"event": "trigger_automation"}, + { + "wait_template": f"{{{{ float(states('sensor.test'), 0) >= {automation_runs} }}}}" + }, + {"service": "script.script1"}, + {"service": "test.script_done"}, + ], + }, + } + }, + ) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "mode": automation_mode, + "trigger": [ + {"platform": "event", "event_type": "trigger_automation"}, + ], + "action": [ + {"service": "test.automation_started"}, + {"service": "script.script1"}, + ], + } + }, + ) + + script_done_event = asyncio.Event() + script_done = [] + automation_started = [] + automation_triggered = [] + + async def async_service_handler(service: ServiceCall): + if service.service == "automation_started": + automation_started.append(service) + elif service.service == "script_done": + script_done.append(service) + if len(script_done) == 1: + script_done_event.set() + + async def async_automation_triggered(event): + """Listen to automation_triggered event from the automation integration.""" + automation_triggered.append(event) + hass.states.async_set("sensor.test", str(len(automation_triggered))) + + hass.services.async_register("test", "script_done", async_service_handler) + hass.services.async_register( + "test", "automation_started", async_service_handler + ) + hass.bus.async_listen("automation_triggered", async_automation_triggered) + + hass.bus.async_fire("trigger_automation") + await asyncio.wait_for(script_done_event.wait(), 1) + + # Trigger 1st stage script shutdown + hass.state = CoreState.stopping + hass.bus.async_fire("homeassistant_stop") + await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1) + + # Trigger 2nd stage script shutdown + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=60)) + await hass.async_block_till_done() + + assert script_warning_msg in caplog.text + + +@pytest.mark.parametrize("automation_mode", SCRIPT_MODE_CHOICES) +async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog): + """Test automation triggering itself. + + - Illegal recursion detection should not be triggered + - Home Assistant should not hang on shut down + """ + stop_scripts_at_shutdown_called = asyncio.Event() + real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown + + async def stop_scripts_at_shutdown(*args): + await real_stop_scripts_at_shutdown(*args) + stop_scripts_at_shutdown_called.set() + + with patch( + "homeassistant.helpers.script._async_stop_scripts_at_shutdown", + wraps=stop_scripts_at_shutdown, + ): + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "mode": automation_mode, + "trigger": [ + {"platform": "event", "event_type": "trigger_automation"}, + ], + "action": [ + {"event": "trigger_automation"}, + {"service": "test.automation_done"}, + ], + } + }, + ) + + service_called = asyncio.Event() + service_called_late = [] + + async def async_service_handler(service): + if service.service == "automation_done": + service_called.set() + if service.service == "automation_started_late": + service_called_late.append(service) + + hass.services.async_register("test", "automation_done", async_service_handler) + hass.services.async_register( + "test", "automation_started_late", async_service_handler + ) + + hass.bus.async_fire("trigger_automation") + await asyncio.wait_for(service_called.wait(), 1) + + # Trigger 1st stage script shutdown + hass.state = CoreState.stopping + hass.bus.async_fire("homeassistant_stop") + await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1) + + # Trigger 2nd stage script shutdown + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90)) + await hass.async_block_till_done() + + assert "Disallowed recursion detected" not in caplog.text diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index 35875c6da12..3bd179286a9 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -840,8 +840,6 @@ async def test_recursive_script(hass, script_mode, warning_msg, caplog): service_called.set() hass.services.async_register("test", "script", async_service_handler) - hass.states.async_set("input_boolean.test", "on") - hass.states.async_set("input_boolean.test2", "off") await hass.services.async_call("script", "script1") await asyncio.wait_for(service_called.wait(), 1) @@ -908,8 +906,6 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog) service_called.set() hass.services.async_register("test", "script", async_service_handler) - hass.states.async_set("input_boolean.test", "on") - hass.states.async_set("input_boolean.test2", "off") await hass.services.async_call("script", "script1") await asyncio.wait_for(service_called.wait(), 1)