Add support for breakpoints in scripts (#47632)

This commit is contained in:
Erik Montnemery
2021-03-10 06:23:11 +01:00
committed by GitHub
parent bf64421be9
commit 704000c049
6 changed files with 961 additions and 32 deletions

View File

@@ -1,6 +1,6 @@
"""Helpers to execute scripts."""
import asyncio
from contextlib import contextmanager
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import partial
import itertools
@@ -65,6 +65,10 @@ from homeassistant.core import (
)
from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.condition import trace_condition_function
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.trigger import (
@@ -78,6 +82,7 @@ from homeassistant.util.dt import utcnow
from .trace import (
TraceElement,
trace_append_element,
trace_id_get,
trace_path,
trace_path_get,
trace_set_result,
@@ -111,6 +116,9 @@ ATTR_CUR = "current"
ATTR_MAX = "max"
DATA_SCRIPTS = "helpers.script"
DATA_SCRIPT_BREAKPOINTS = "helpers.script_breakpoints"
RUN_ID_ANY = "*"
NODE_ANY = "*"
_LOGGER = logging.getLogger(__name__)
@@ -122,6 +130,10 @@ _SHUTDOWN_MAX_WAIT = 60
ACTION_TRACE_NODE_MAX_LEN = 20 # Max length of a trace node for repeated actions
SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
def action_trace_append(variables, path):
"""Append a TraceElement to trace[path]."""
@@ -130,11 +142,57 @@ def action_trace_append(variables, path):
return trace_element
@contextmanager
def trace_action(variables):
@asynccontextmanager
async def trace_action(hass, script_run, stop, variables):
"""Trace action execution."""
trace_element = action_trace_append(variables, trace_path_get())
path = trace_path_get()
trace_element = action_trace_append(variables, path)
trace_stack_push(trace_stack_cv, trace_element)
trace_id = trace_id_get()
if trace_id:
unique_id = trace_id[0]
run_id = trace_id[1]
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
if unique_id in breakpoints and (
(
run_id in breakpoints[unique_id]
and (
path in breakpoints[unique_id][run_id]
or NODE_ANY in breakpoints[unique_id][run_id]
)
)
or (
RUN_ID_ANY in breakpoints[unique_id]
and (
path in breakpoints[unique_id][RUN_ID_ANY]
or NODE_ANY in breakpoints[unique_id][RUN_ID_ANY]
)
)
):
async_dispatcher_send(hass, SCRIPT_BREAKPOINT_HIT, unique_id, run_id, path)
done = asyncio.Event()
@callback
def async_continue_stop(command=None):
if command == "stop":
stop.set()
done.set()
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
remove_signal1 = async_dispatcher_connect(hass, signal, async_continue_stop)
remove_signal2 = async_dispatcher_connect(
hass, SCRIPT_DEBUG_CONTINUE_ALL, async_continue_stop
)
tasks = [hass.async_create_task(flag.wait()) for flag in (stop, done)]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in tasks:
task.cancel()
remove_signal1()
remove_signal2()
try:
yield trace_element
except Exception as ex: # pylint: disable=broad-except
@@ -294,16 +352,19 @@ class _ScriptRun:
self._finish()
async def _async_step(self, log_exceptions):
with trace_path(str(self._step)), trace_action(self._variables):
try:
handler = f"_async_{cv.determine_script_action(self._action)}_step"
await getattr(self, handler)()
except Exception as ex:
if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and (
self._log_exceptions or log_exceptions
):
self._log_exception(ex)
raise
with trace_path(str(self._step)):
async with trace_action(self._hass, self, self._stop, self._variables):
if self._stop.is_set():
return
try:
handler = f"_async_{cv.determine_script_action(self._action)}_step"
await getattr(self, handler)()
except Exception as ex:
if not isinstance(ex, (_StopScript, asyncio.CancelledError)) and (
self._log_exceptions or log_exceptions
):
self._log_exception(ex)
raise
def _finish(self) -> None:
self._script._runs.remove(self) # pylint: disable=protected-access
@@ -876,6 +937,8 @@ class Script:
all_scripts.append(
{"instance": self, "started_before_shutdown": not hass.is_stopping}
)
if DATA_SCRIPT_BREAKPOINTS not in hass.data:
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
self._hass = hass
self.sequence = sequence
@@ -1213,3 +1276,71 @@ class Script:
self._logger.exception(msg, *args, **kwargs)
else:
self._logger.log(level, msg, *args, **kwargs)
@callback
def breakpoint_clear(hass, unique_id, run_id, node):
"""Clear a breakpoint."""
run_id = run_id or RUN_ID_ANY
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
if unique_id not in breakpoints or run_id not in breakpoints[unique_id]:
return
breakpoints[unique_id][run_id].discard(node)
@callback
def breakpoint_clear_all(hass):
"""Clear all breakpoints."""
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
@callback
def breakpoint_set(hass, unique_id, run_id, node):
"""Set a breakpoint."""
run_id = run_id or RUN_ID_ANY
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
if unique_id not in breakpoints:
breakpoints[unique_id] = {}
if run_id not in breakpoints[unique_id]:
breakpoints[unique_id][run_id] = set()
breakpoints[unique_id][run_id].add(node)
@callback
def breakpoint_list(hass):
"""List breakpoints."""
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
return [
{"unique_id": unique_id, "run_id": run_id, "node": node}
for unique_id in breakpoints
for run_id in breakpoints[unique_id]
for node in breakpoints[unique_id][run_id]
]
@callback
def debug_continue(hass, unique_id, run_id):
"""Continue execution of a halted script."""
# Clear any wildcard breakpoint
breakpoint_clear(hass, unique_id, run_id, NODE_ANY)
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
async_dispatcher_send(hass, signal, "continue")
@callback
def debug_step(hass, unique_id, run_id):
"""Single step a halted script."""
# Set a wildcard breakpoint
breakpoint_set(hass, unique_id, run_id, NODE_ANY)
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
async_dispatcher_send(hass, signal, "continue")
@callback
def debug_stop(hass, unique_id, run_id):
"""Stop execution of a running or halted script."""
signal = SCRIPT_DEBUG_CONTINUE_STOP.format(unique_id, run_id)
async_dispatcher_send(hass, signal, "stop")