Don't prevent automations from triggering themselves (#68178)

This commit is contained in:
Erik Montnemery 2022-03-15 18:48:54 +01:00 committed by GitHub
parent b99934f62f
commit 46f27fdefd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 212 additions and 6 deletions

View File

@ -54,6 +54,7 @@ from homeassistant.helpers.script import (
CONF_MAX, CONF_MAX,
CONF_MAX_EXCEEDED, CONF_MAX_EXCEEDED,
Script, Script,
script_stack_cv,
) )
from homeassistant.helpers.script_variables import ScriptVariables from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import ( from homeassistant.helpers.service import (
@ -505,6 +506,10 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
) )
# Make a new empty script stack; automations are allowed
# to recursively trigger themselves
script_stack_cv.set([])
try: try:
with trace_path("action"): with trace_path("action"):
await self.action_script.async_run( await self.action_script.async_run(

View File

@ -1247,7 +1247,7 @@ class Script:
and id(self) in script_stack and id(self) in script_stack
): ):
script_execution_set("disallowed_recursion_detected") script_execution_set("disallowed_recursion_detected")
_LOGGER.warning("Disallowed recursion detected") self._log("Disallowed recursion detected", level=logging.WARNING)
return return
if self.script_mode != SCRIPT_MODE_QUEUED: if self.script_mode != SCRIPT_MODE_QUEUED:

View File

@ -1,5 +1,6 @@
"""The tests for the automation component.""" """The tests for the automation component."""
import asyncio import asyncio
from datetime import timedelta
import logging import logging
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -25,14 +26,30 @@ from homeassistant.const import (
STATE_OFF, STATE_OFF,
STATE_ON, STATE_ON,
) )
from homeassistant.core import Context, CoreState, State, callback from homeassistant.core import (
Context,
CoreState,
HomeAssistant,
ServiceCall,
State,
callback,
)
from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.script import (
SCRIPT_MODE_CHOICES,
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
_async_stop_scripts_at_shutdown,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import ( from tests.common import (
assert_setup_component, assert_setup_component,
async_capture_events, async_capture_events,
async_fire_time_changed,
async_mock_service, async_mock_service,
mock_restore_cache, mock_restore_cache,
) )
@ -1570,3 +1587,191 @@ async def test_trigger_condition_explicit_id(hass, calls):
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 2 assert len(calls) == 2
assert calls[-1].data.get("param") == "two" assert calls[-1].data.get("param") == "two"
@pytest.mark.parametrize(
"automation_mode,automation_runs",
(
(SCRIPT_MODE_PARALLEL, 2),
(SCRIPT_MODE_QUEUED, 2),
(SCRIPT_MODE_RESTART, 2),
(SCRIPT_MODE_SINGLE, 1),
),
)
@pytest.mark.parametrize(
"script_mode,script_warning_msg",
(
(SCRIPT_MODE_PARALLEL, "script1: Maximum number of runs exceeded"),
(SCRIPT_MODE_QUEUED, "script1: Disallowed recursion detected"),
(SCRIPT_MODE_RESTART, "script1: Disallowed recursion detected"),
(SCRIPT_MODE_SINGLE, "script1: Already running"),
),
)
async def test_recursive_automation_starting_script(
hass: HomeAssistant,
automation_mode,
automation_runs,
script_mode,
script_warning_msg,
caplog,
):
"""Test starting automations does not interfere with script deadlock prevention."""
# Fail if additional script modes are added to
# make sure we cover all script modes in tests
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
async def mock_stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()
with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=mock_stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"mode": script_mode,
"sequence": [
{"event": "trigger_automation"},
{
"wait_template": f"{{{{ float(states('sensor.test'), 0) >= {automation_runs} }}}}"
},
{"service": "script.script1"},
{"service": "test.script_done"},
],
},
}
},
)
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": automation_mode,
"trigger": [
{"platform": "event", "event_type": "trigger_automation"},
],
"action": [
{"service": "test.automation_started"},
{"service": "script.script1"},
],
}
},
)
script_done_event = asyncio.Event()
script_done = []
automation_started = []
automation_triggered = []
async def async_service_handler(service: ServiceCall):
if service.service == "automation_started":
automation_started.append(service)
elif service.service == "script_done":
script_done.append(service)
if len(script_done) == 1:
script_done_event.set()
async def async_automation_triggered(event):
"""Listen to automation_triggered event from the automation integration."""
automation_triggered.append(event)
hass.states.async_set("sensor.test", str(len(automation_triggered)))
hass.services.async_register("test", "script_done", async_service_handler)
hass.services.async_register(
"test", "automation_started", async_service_handler
)
hass.bus.async_listen("automation_triggered", async_automation_triggered)
hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(script_done_event.wait(), 1)
# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=60))
await hass.async_block_till_done()
assert script_warning_msg in caplog.text
@pytest.mark.parametrize("automation_mode", SCRIPT_MODE_CHOICES)
async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog):
"""Test automation triggering itself.
- Illegal recursion detection should not be triggered
- Home Assistant should not hang on shut down
"""
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
async def stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()
with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": automation_mode,
"trigger": [
{"platform": "event", "event_type": "trigger_automation"},
],
"action": [
{"event": "trigger_automation"},
{"service": "test.automation_done"},
],
}
},
)
service_called = asyncio.Event()
service_called_late = []
async def async_service_handler(service):
if service.service == "automation_done":
service_called.set()
if service.service == "automation_started_late":
service_called_late.append(service)
hass.services.async_register("test", "automation_done", async_service_handler)
hass.services.async_register(
"test", "automation_started_late", async_service_handler
)
hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(service_called.wait(), 1)
# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
await hass.async_block_till_done()
assert "Disallowed recursion detected" not in caplog.text

View File

@ -840,8 +840,6 @@ async def test_recursive_script(hass, script_mode, warning_msg, caplog):
service_called.set() service_called.set()
hass.services.async_register("test", "script", async_service_handler) hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")
await hass.services.async_call("script", "script1") await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1) await asyncio.wait_for(service_called.wait(), 1)
@ -908,8 +906,6 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog)
service_called.set() service_called.set()
hass.services.async_register("test", "script", async_service_handler) hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")
await hass.services.async_call("script", "script1") await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1) await asyncio.wait_for(service_called.wait(), 1)