mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 12:47:08 +00:00
Prevent recursive script calls from deadlocking (#67861)
* Prevent recursive script calls from deadlocking * Address review comments, improve tests * Tweak comment
This commit is contained in:
parent
3d212b868e
commit
65fbcfa0ba
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from contextlib import asynccontextmanager, suppress
|
from contextlib import asynccontextmanager, suppress
|
||||||
|
from contextvars import ContextVar
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import itertools
|
import itertools
|
||||||
@ -126,6 +127,8 @@ SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
|
|||||||
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
|
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
|
||||||
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
||||||
|
|
||||||
|
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
||||||
|
|
||||||
|
|
||||||
def action_trace_append(variables, path):
|
def action_trace_append(variables, path):
|
||||||
"""Append a TraceElement to trace[path]."""
|
"""Append a TraceElement to trace[path]."""
|
||||||
@ -340,6 +343,12 @@ class _ScriptRun:
|
|||||||
|
|
||||||
async def async_run(self) -> None:
|
async def async_run(self) -> None:
|
||||||
"""Run script."""
|
"""Run script."""
|
||||||
|
# Push the script to the script execution stack
|
||||||
|
if (script_stack := script_stack_cv.get()) is None:
|
||||||
|
script_stack = []
|
||||||
|
script_stack_cv.set(script_stack)
|
||||||
|
script_stack.append(id(self._script))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._log("Running %s", self._script.running_description)
|
self._log("Running %s", self._script.running_description)
|
||||||
for self._step, self._action in enumerate(self._script.sequence):
|
for self._step, self._action in enumerate(self._script.sequence):
|
||||||
@ -355,6 +364,8 @@ class _ScriptRun:
|
|||||||
script_execution_set("error")
|
script_execution_set("error")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# Pop the script from the script execution stack
|
||||||
|
script_stack.pop()
|
||||||
self._finish()
|
self._finish()
|
||||||
|
|
||||||
async def _async_step(self, log_exceptions):
|
async def _async_step(self, log_exceptions):
|
||||||
@ -1218,6 +1229,18 @@ class Script:
|
|||||||
else:
|
else:
|
||||||
variables = cast(dict, run_variables)
|
variables = cast(dict, run_variables)
|
||||||
|
|
||||||
|
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
|
||||||
|
# stop (restart) or wait for (queued) our own script run.
|
||||||
|
script_stack = script_stack_cv.get()
|
||||||
|
if (
|
||||||
|
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
|
||||||
|
and (script_stack := script_stack_cv.get()) is not None
|
||||||
|
and id(self) in script_stack
|
||||||
|
):
|
||||||
|
script_execution_set("disallowed_recursion_detected")
|
||||||
|
_LOGGER.warning("Disallowed recursion detected")
|
||||||
|
return
|
||||||
|
|
||||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||||
cls = _ScriptRun
|
cls = _ScriptRun
|
||||||
else:
|
else:
|
||||||
|
@ -27,6 +27,13 @@ from homeassistant.core import (
|
|||||||
from homeassistant.exceptions import ServiceNotFound
|
from homeassistant.exceptions import ServiceNotFound
|
||||||
from homeassistant.helpers import template
|
from homeassistant.helpers import template
|
||||||
from homeassistant.helpers.event import async_track_state_change
|
from homeassistant.helpers.event import async_track_state_change
|
||||||
|
from homeassistant.helpers.script import (
|
||||||
|
SCRIPT_MODE_CHOICES,
|
||||||
|
SCRIPT_MODE_PARALLEL,
|
||||||
|
SCRIPT_MODE_QUEUED,
|
||||||
|
SCRIPT_MODE_RESTART,
|
||||||
|
SCRIPT_MODE_SINGLE,
|
||||||
|
)
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
@ -790,3 +797,121 @@ async def test_script_restore_last_triggered(hass: HomeAssistant) -> None:
|
|||||||
state = hass.states.get("script.last_triggered")
|
state = hass.states.get("script.last_triggered")
|
||||||
assert state
|
assert state
|
||||||
assert state.attributes["last_triggered"] == time
|
assert state.attributes["last_triggered"] == time
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"script_mode,warning_msg",
|
||||||
|
(
|
||||||
|
(SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
|
||||||
|
(SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
|
||||||
|
(SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
|
||||||
|
(SCRIPT_MODE_SINGLE, "Already running"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_recursive_script(hass, script_mode, warning_msg, caplog):
|
||||||
|
"""Test recursive script calls does not deadlock."""
|
||||||
|
# Make sure we cover all script modes
|
||||||
|
assert SCRIPT_MODE_CHOICES == [
|
||||||
|
SCRIPT_MODE_PARALLEL,
|
||||||
|
SCRIPT_MODE_QUEUED,
|
||||||
|
SCRIPT_MODE_RESTART,
|
||||||
|
SCRIPT_MODE_SINGLE,
|
||||||
|
]
|
||||||
|
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"script",
|
||||||
|
{
|
||||||
|
"script": {
|
||||||
|
"script1": {
|
||||||
|
"mode": script_mode,
|
||||||
|
"sequence": [
|
||||||
|
{"service": "script.script1"},
|
||||||
|
{"service": "test.script"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
service_called = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_service_handler(service):
|
||||||
|
service_called.set()
|
||||||
|
|
||||||
|
hass.services.async_register("test", "script", async_service_handler)
|
||||||
|
hass.states.async_set("input_boolean.test", "on")
|
||||||
|
hass.states.async_set("input_boolean.test2", "off")
|
||||||
|
|
||||||
|
await hass.services.async_call("script", "script1")
|
||||||
|
await asyncio.wait_for(service_called.wait(), 1)
|
||||||
|
|
||||||
|
assert warning_msg in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"script_mode,warning_msg",
|
||||||
|
(
|
||||||
|
(SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
|
||||||
|
(SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
|
||||||
|
(SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
|
||||||
|
(SCRIPT_MODE_SINGLE, "Already running"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog):
|
||||||
|
"""Test recursive script calls does not deadlock."""
|
||||||
|
# Make sure we cover all script modes
|
||||||
|
assert SCRIPT_MODE_CHOICES == [
|
||||||
|
SCRIPT_MODE_PARALLEL,
|
||||||
|
SCRIPT_MODE_QUEUED,
|
||||||
|
SCRIPT_MODE_RESTART,
|
||||||
|
SCRIPT_MODE_SINGLE,
|
||||||
|
]
|
||||||
|
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"script",
|
||||||
|
{
|
||||||
|
"script": {
|
||||||
|
"script1": {
|
||||||
|
"mode": script_mode,
|
||||||
|
"sequence": [
|
||||||
|
{"service": "script.script2"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"script2": {
|
||||||
|
"mode": script_mode,
|
||||||
|
"sequence": [
|
||||||
|
{"service": "script.script3"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"script3": {
|
||||||
|
"mode": script_mode,
|
||||||
|
"sequence": [
|
||||||
|
{"service": "script.script4"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"script4": {
|
||||||
|
"mode": script_mode,
|
||||||
|
"sequence": [
|
||||||
|
{"service": "script.script1"},
|
||||||
|
{"service": "test.script"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
service_called = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_service_handler(service):
|
||||||
|
service_called.set()
|
||||||
|
|
||||||
|
hass.services.async_register("test", "script", async_service_handler)
|
||||||
|
hass.states.async_set("input_boolean.test", "on")
|
||||||
|
hass.states.async_set("input_boolean.test2", "off")
|
||||||
|
|
||||||
|
await hass.services.async_call("script", "script1")
|
||||||
|
await asyncio.wait_for(service_called.wait(), 1)
|
||||||
|
|
||||||
|
assert warning_msg in caplog.text
|
||||||
|
Loading…
x
Reference in New Issue
Block a user