mirror of
https://github.com/home-assistant/core.git
synced 2025-11-10 03:19:34 +00:00
Add support for breakpoints in scripts (#47632)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Helpers to execute scripts."""
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import itertools
|
||||
@@ -65,6 +65,10 @@ from homeassistant.core import (
|
||||
)
|
||||
from homeassistant.helpers import condition, config_validation as cv, service, template
|
||||
from homeassistant.helpers.condition import trace_condition_function
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.event import async_call_later, async_track_template
|
||||
from homeassistant.helpers.script_variables import ScriptVariables
|
||||
from homeassistant.helpers.trigger import (
|
||||
@@ -78,6 +82,7 @@ from homeassistant.util.dt import utcnow
|
||||
from .trace import (
|
||||
TraceElement,
|
||||
trace_append_element,
|
||||
trace_id_get,
|
||||
trace_path,
|
||||
trace_path_get,
|
||||
trace_set_result,
|
||||
@@ -111,6 +116,9 @@ ATTR_CUR = "current"
|
||||
ATTR_MAX = "max"
|
||||
|
||||
DATA_SCRIPTS = "helpers.script"
|
||||
DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints"
|
||||
RUN_ID_ANY = "*"
|
||||
NODE_ANY = "*"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -122,6 +130,10 @@ _SHUTDOWN_MAX_WAIT = 60
|
||||
|
||||
ACTION_TRACE_NODE_MAX_LEN = 20 # Max length of a trace node for repeated actions
|
||||
|
||||
SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
|
||||
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
|
||||
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
||||
|
||||
|
||||
def action_trace_append(variables, path):
|
||||
"""Append a TraceElement to trace[path]."""
|
||||
@@ -130,11 +142,57 @@ def action_trace_append(variables, path):
|
||||
return trace_element
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_action(variables):
|
||||
@asynccontextmanager
|
||||
async def trace_action(hass, script_run, stop, variables):
|
||||
"""Trace action execution."""
|
||||
trace_element = action_trace_append(variables, trace_path_get())
|
||||
path = trace_path_get()
|
||||
trace_element = action_trace_append(variables, path)
|
||||
trace_stack_push(trace_stack_cv, trace_element)
|
||||
|
||||
trace_id = trace_id_get()
|
||||
if trace_id:
|
||||
unique_id = trace_id[0]
|
||||
run_id = trace_id[1]
|
||||
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
||||
if unique_id in breakpoints and (
|
||||
(
|
||||
run_id in breakpoints[unique_id]
|
||||
and (
|
||||
path in breakpoints[unique_id][run_id]
|
||||
or NODE_ANY in breakpoints[unique_id][run_id]
|
||||
)
|
||||
)
|
||||
or (
|
||||
RUN_ID_ANY in breakpoints[unique_id]
|
||||
and (
|
||||
path in breakpoints[unique_id][RUN_ID_ANY]
|
||||
or NODE_ANY in breakpoints[unique_id][RUN_ID_ANY]
|
||||
)
|
||||
)
|
||||
):
|
||||
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, unique_id, run_id, path)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def async_continue_stop(command=None):
|
||||
if command == "stop":
|
||||
stop.set()
|
||||
done.set()
|
||||
|
||||
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
|
||||
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
|
||||
remove_signal2 = async_dispatcher_connect(
|
||||
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
|
||||
)
|
||||
|
||||
tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)]
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
remove_signal1()
|
||||
remove_signal2()
|
||||
|
||||
try:
|
||||
yield trace_element
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
@@ -294,16 +352,19 @@ class _ScriptRun:
|
||||
self._finish()
|
||||
|
||||
async def _async_step(self, log_exceptions):
|
||||
with trace_path(str(self._step)), trace_action(self._variables):
|
||||
try:
|
||||
handler = f"_async_{cv.determine_script_action(self._action)}_step"
|
||||
await getattr(self, handler)()
|
||||
except Exception as ex:
|
||||
if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and (
|
||||
self._log_exceptions or log_exceptions
|
||||
):
|
||||
self._log_exception(ex)
|
||||
raise
|
||||
with trace_path(str(self._step)):
|
||||
async with trace_action(self._hass, self, self._stop, self._variables):
|
||||
if self._stop.is_set():
|
||||
return
|
||||
try:
|
||||
handler = f"_async_{cv.determine_script_action(self._action)}_step"
|
||||
await getattr(self, handler)()
|
||||
except Exception as ex:
|
||||
if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and (
|
||||
self._log_exceptions or log_exceptions
|
||||
):
|
||||
self._log_exception(ex)
|
||||
raise
|
||||
|
||||
def _finish(self) -> None:
|
||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||
@@ -876,6 +937,8 @@ class Script:
|
||||
all_scripts.append(
|
||||
{"instance": self, "started_before_shutdown": not hass.is_stopping}
|
||||
)
|
||||
if DATA_SCRIPT_BREAKPOINTS not in hass.data:
|
||||
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
|
||||
|
||||
self._hass = hass
|
||||
self.sequence = sequence
|
||||
@@ -1213,3 +1276,71 @@ class Script:
|
||||
self._logger.exception(msg, *args, **kwargs)
|
||||
else:
|
||||
self._logger.log(level, msg, *args, **kwargs)
|
||||
|
||||
|
||||
@callback
|
||||
def breakpoint_clear(hass, unique_id, run_id, node):
|
||||
"""Clear a breakpoint."""
|
||||
run_id = run_id or RUN_ID_ANY
|
||||
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
||||
if unique_id not in breakpoints or run_id not in breakpoints[unique_id]:
|
||||
return
|
||||
breakpoints[unique_id][run_id].discard(node)
|
||||
|
||||
|
||||
@callback
|
||||
def breakpoint_clear_all(hass):
|
||||
"""Clear all breakpoints."""
|
||||
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
|
||||
|
||||
|
||||
@callback
|
||||
def breakpoint_set(hass, unique_id, run_id, node):
|
||||
"""Set a breakpoint."""
|
||||
run_id = run_id or RUN_ID_ANY
|
||||
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
||||
if unique_id not in breakpoints:
|
||||
breakpoints[unique_id] = {}
|
||||
if run_id not in breakpoints[unique_id]:
|
||||
breakpoints[unique_id][run_id] = set()
|
||||
breakpoints[unique_id][run_id].add(node)
|
||||
|
||||
|
||||
@callback
|
||||
def breakpoint_list(hass):
|
||||
"""List breakpoints."""
|
||||
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
||||
|
||||
return [
|
||||
{"unique_id": unique_id, "run_id": run_id, "node": node}
|
||||
for unique_id in breakpoints
|
||||
for run_id in breakpoints[unique_id]
|
||||
for node in breakpoints[unique_id][run_id]
|
||||
]
|
||||
|
||||
|
||||
@callback
|
||||
def debug_continue(hass, unique_id, run_id):
|
||||
"""Continue execution of a halted script."""
|
||||
# Clear any wildcard breakpoint
|
||||
breakpoint_clear(hass, unique_id, run_id, NODE_ANY)
|
||||
|
||||
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
|
||||
async_dispatcher_send(hass, signal, "continue")
|
||||
|
||||
|
||||
@callback
|
||||
def debug_step(hass, unique_id, run_id):
|
||||
"""Single step a halted script."""
|
||||
# Set a wildcard breakpoint
|
||||
breakpoint_set(hass, unique_id, run_id, NODE_ANY)
|
||||
|
||||
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
|
||||
async_dispatcher_send(hass, signal, "continue")
|
||||
|
||||
|
||||
@callback
|
||||
def debug_stop(hass, unique_id, run_id):
|
||||
"""Stop execution of a running or halted script."""
|
||||
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
|
||||
async_dispatcher_send(hass, signal, "stop")
|
||||
|
||||
Reference in New Issue
Block a user