Prevent recursive script calls from deadlocking (#67861)

* Prevent recursive script calls from deadlocking

* Address review comments, improve tests

* Tweak comment
This commit is contained in:
Erik Montnemery 2022-03-10 19:28:00 +01:00 committed by GitHub
parent 3d212b868e
commit 65fbcfa0ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 148 additions and 0 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import asynccontextmanager, suppress from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import itertools import itertools
@ -126,6 +127,8 @@ SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}" SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all" SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
def action_trace_append(variables, path): def action_trace_append(variables, path):
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
@ -340,6 +343,12 @@ class _ScriptRun:
async def async_run(self) -> None: async def async_run(self) -> None:
"""Run script.""" """Run script."""
# Push the script to the script execution stack
if (script_stack := script_stack_cv.get()) is None:
script_stack = []
script_stack_cv.set(script_stack)
script_stack.append(id(self._script))
try: try:
self._log("Running %s", self._script.running_description) self._log("Running %s", self._script.running_description)
for self._step, self._action in enumerate(self._script.sequence): for self._step, self._action in enumerate(self._script.sequence):
@ -355,6 +364,8 @@ class _ScriptRun:
script_execution_set("error") script_execution_set("error")
raise raise
finally: finally:
# Pop the script from the script execution stack
script_stack.pop()
self._finish() self._finish()
async def _async_step(self, log_exceptions): async def _async_step(self, log_exceptions):
@ -1218,6 +1229,18 @@ class Script:
else: else:
variables = cast(dict, run_variables) variables = cast(dict, run_variables)
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
# stop (restart) or wait for (queued) our own script run.
script_stack = script_stack_cv.get()
if (
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
and (script_stack := script_stack_cv.get()) is not None
and id(self) in script_stack
):
script_execution_set("disallowed_recursion_detected")
_LOGGER.warning("Disallowed recursion detected")
return
if self.script_mode != SCRIPT_MODE_QUEUED: if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun cls = _ScriptRun
else: else:

View File

@ -27,6 +27,13 @@ from homeassistant.core import (
from homeassistant.exceptions import ServiceNotFound from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers import template from homeassistant.helpers import template
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.script import (
SCRIPT_MODE_CHOICES,
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
)
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -790,3 +797,121 @@ async def test_script_restore_last_triggered(hass: HomeAssistant) -> None:
state = hass.states.get("script.last_triggered") state = hass.states.get("script.last_triggered")
assert state assert state
assert state.attributes["last_triggered"] == time assert state.attributes["last_triggered"] == time
@pytest.mark.parametrize(
"script_mode,warning_msg",
(
(SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
(SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
(SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
(SCRIPT_MODE_SINGLE, "Already running"),
),
)
async def test_recursive_script(hass, script_mode, warning_msg, caplog):
"""Test recursive script calls does not deadlock."""
# Make sure we cover all script modes
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"mode": script_mode,
"sequence": [
{"service": "script.script1"},
{"service": "test.script"},
],
},
}
},
)
service_called = asyncio.Event()
async def async_service_handler(service):
service_called.set()
hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")
await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)
assert warning_msg in caplog.text
@pytest.mark.parametrize(
"script_mode,warning_msg",
(
(SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
(SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
(SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
(SCRIPT_MODE_SINGLE, "Already running"),
),
)
async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog):
"""Test recursive script calls does not deadlock."""
# Make sure we cover all script modes
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"mode": script_mode,
"sequence": [
{"service": "script.script2"},
],
},
"script2": {
"mode": script_mode,
"sequence": [
{"service": "script.script3"},
],
},
"script3": {
"mode": script_mode,
"sequence": [
{"service": "script.script4"},
],
},
"script4": {
"mode": script_mode,
"sequence": [
{"service": "script.script1"},
{"service": "test.script"},
],
},
}
},
)
service_called = asyncio.Event()
async def async_service_handler(service):
service_called.set()
hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")
await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)
assert warning_msg in caplog.text