From 1df99badcf04593bb98348c8669e0d066277ae27 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 4 May 2022 15:54:37 +0200 Subject: [PATCH] Allow scripts to turn themselves on (#71289) --- homeassistant/components/script/__init__.py | 9 ++- tests/components/automation/test_init.py | 6 -- tests/components/script/test_init.py | 89 ++++++++++++++++++++- 3 files changed, 95 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 660e7233b33..efad242fbd0 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -42,6 +42,7 @@ from homeassistant.helpers.script import ( CONF_MAX, CONF_MAX_EXCEEDED, Script, + script_stack_cv, ) from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.trace import trace_get, trace_path @@ -398,10 +399,14 @@ class ScriptEntity(ToggleEntity, RestoreEntity): return # Caller does not want to wait for called script to finish so let script run in - # separate Task. However, wait for first state change so we can guarantee that - # it is written to the State Machine before we return. + # separate Task. Make a new empty script stack; scripts are allowed to + # recursively turn themselves on when not waiting. + script_stack_cv.set([]) + self._changed.clear() self.hass.async_create_task(coro) + # Wait for first state change so we can guarantee that + # it is written to the State Machine before we return. await self._changed.wait() async def _async_run(self, variables, context): diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index ccb508c6acc..dbc8f0fc346 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1791,18 +1791,12 @@ async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog ) service_called = asyncio.Event() - service_called_late = [] async def async_service_handler(service): if service.service == "automation_done": service_called.set() - if service.service == "automation_started_late": - service_called_late.append(service) hass.services.async_register("test", "automation_done", async_service_handler) - hass.services.async_register( - "test", "automation_started_late", async_service_handler - ) hass.bus.async_fire("trigger_automation") await asyncio.wait_for(service_called.wait(), 1) diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index 8a5297786f5..ca0cdb97592 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -1,6 +1,7 @@ """The tests for the Script component.""" # pylint: disable=protected-access import asyncio +from datetime import timedelta from unittest.mock import Mock, patch import pytest @@ -33,12 +34,13 @@ from homeassistant.helpers.script import ( SCRIPT_MODE_QUEUED, SCRIPT_MODE_RESTART, SCRIPT_MODE_SINGLE, + _async_stop_scripts_at_shutdown, ) from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util -from tests.common import async_mock_service, mock_restore_cache +from tests.common import async_fire_time_changed, async_mock_service, mock_restore_cache from tests.components.logbook.test_init import MockLazyEventPartialState ENTITY_ID = "script.test" @@ -919,6 +921,91 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog) assert warning_msg in caplog.text +@pytest.mark.parametrize( + "script_mode", [SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED, SCRIPT_MODE_RESTART] +) +async def test_recursive_script_turn_on(hass: HomeAssistant, script_mode, caplog): + """Test script turning itself on. + + - Illegal recursion detection should not be triggered + - Home Assistant should not hang on shut down + - SCRIPT_MODE_SINGLE is not relevant because suca script can't turn itself on + """ + # Make sure we cover all script modes + assert SCRIPT_MODE_CHOICES == [ + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUED, + SCRIPT_MODE_RESTART, + SCRIPT_MODE_SINGLE, + ] + stop_scripts_at_shutdown_called = asyncio.Event() + real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown + + async def stop_scripts_at_shutdown(*args): + await real_stop_scripts_at_shutdown(*args) + stop_scripts_at_shutdown_called.set() + + with patch( + "homeassistant.helpers.script._async_stop_scripts_at_shutdown", + wraps=stop_scripts_at_shutdown, + ): + assert await async_setup_component( + hass, + script.DOMAIN, + { + script.DOMAIN: { + "script1": { + "mode": script_mode, + "sequence": [ + { + "choose": { + "conditions": { + "condition": "template", + "value_template": "{{ request == 'step_2' }}", + }, + "sequence": {"service": "test.script_done"}, + }, + "default": { + "service": "script.turn_on", + "data": { + "entity_id": "script.script1", + "variables": {"request": "step_2"}, + }, + }, + }, + { + "service": "script.turn_on", + "data": {"entity_id": "script.script1"}, + }, + ], + } + } + }, + ) + + service_called = asyncio.Event() + + async def async_service_handler(service): + if service.service == "script_done": + service_called.set() + + hass.services.async_register("test", "script_done", async_service_handler) + + await hass.services.async_call("script", "script1") + await asyncio.wait_for(service_called.wait(), 1) + + # Trigger 1st stage script shutdown + hass.state = CoreState.stopping + hass.bus.async_fire("homeassistant_stop") + await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1) + + # Trigger 2nd stage script shutdown + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90)) + await hass.async_block_till_done() + + assert "Disallowed recursion detected" not in caplog.text + + async def test_setup_with_duplicate_scripts( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: