mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Enable strict typing on script helper (#122075)
This commit is contained in:
parent
a0f91d27a3
commit
efb7bede40
@ -21,6 +21,7 @@ homeassistant.helpers.entity_platform
|
||||
homeassistant.helpers.entity_values
|
||||
homeassistant.helpers.event
|
||||
homeassistant.helpers.reload
|
||||
homeassistant.helpers.script
|
||||
homeassistant.helpers.script_variables
|
||||
homeassistant.helpers.singleton
|
||||
homeassistant.helpers.sun
|
||||
|
@ -13,7 +13,7 @@ from functools import cached_property, partial
|
||||
import itertools
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from typing import Any, Literal, TypedDict, cast, overload
|
||||
|
||||
import async_interrupt
|
||||
import voluptuous as vol
|
||||
@ -75,6 +75,7 @@ from homeassistant.core import (
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceResponse,
|
||||
State,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
@ -107,9 +108,7 @@ from .trace import (
|
||||
trace_update_result,
|
||||
)
|
||||
from .trigger import async_initialize_triggers, async_validate_trigger_config
|
||||
from .typing import UNDEFINED, ConfigType, UndefinedType
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType
|
||||
|
||||
SCRIPT_MODE_PARALLEL = "parallel"
|
||||
SCRIPT_MODE_QUEUED = "queued"
|
||||
@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
|
||||
future.set_result(None)
|
||||
|
||||
|
||||
def action_trace_append(variables, path):
|
||||
def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement:
|
||||
"""Append a TraceElement to trace[path]."""
|
||||
trace_element = TraceElement(variables, path)
|
||||
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
|
||||
@ -430,7 +429,7 @@ class _ScriptRun:
|
||||
if not self._stop.done():
|
||||
self._script._changed() # noqa: SLF001
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
|
||||
return await self._script._async_get_condition(config) # noqa: SLF001
|
||||
|
||||
def _log(
|
||||
@ -438,7 +437,7 @@ class _ScriptRun:
|
||||
) -> None:
|
||||
self._script._log(msg, *args, level=level, **kwargs) # noqa: SLF001
|
||||
|
||||
def _step_log(self, default_message, timeout=None):
|
||||
def _step_log(self, default_message: str, timeout: float | None = None) -> None:
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, default_message)
|
||||
_timeout = (
|
||||
"" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})"
|
||||
@ -580,7 +579,7 @@ class _ScriptRun:
|
||||
if not isinstance(exception, exceptions.HomeAssistantError):
|
||||
raise exception
|
||||
|
||||
def _log_exception(self, exception):
|
||||
def _log_exception(self, exception: Exception) -> None:
|
||||
action_type = cv.determine_script_action(self._action)
|
||||
|
||||
error = str(exception)
|
||||
@ -629,7 +628,7 @@ class _ScriptRun:
|
||||
)
|
||||
raise _AbortScript from ex
|
||||
|
||||
async def _async_delay_step(self):
|
||||
async def _async_delay_step(self) -> None:
|
||||
"""Handle delay."""
|
||||
delay_delta = self._get_pos_time_period_template(CONF_DELAY)
|
||||
|
||||
@ -661,7 +660,7 @@ class _ScriptRun:
|
||||
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
|
||||
return None
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
async def _async_wait_template_step(self) -> None:
|
||||
"""Handle a wait template."""
|
||||
timeout = self._get_timeout_seconds_from_action()
|
||||
self._step_log("wait template", timeout)
|
||||
@ -690,7 +689,9 @@ class _ScriptRun:
|
||||
futures.append(done)
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
def async_script_wait(
|
||||
entity_id: str, from_s: State | None, to_s: State | None
|
||||
) -> None:
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_set_remaining_time_var(timeout_handle)
|
||||
self._variables["wait"]["completed"] = True
|
||||
@ -727,7 +728,7 @@ class _ScriptRun:
|
||||
except ScriptStoppedError as ex:
|
||||
raise asyncio.CancelledError from ex
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
async def _async_call_service_step(self) -> None:
|
||||
"""Call the service specified in the action."""
|
||||
self._step_log("call service")
|
||||
|
||||
@ -774,14 +775,14 @@ class _ScriptRun:
|
||||
if response_variable:
|
||||
self._variables[response_variable] = response_data
|
||||
|
||||
async def _async_device_step(self):
|
||||
async def _async_device_step(self) -> None:
|
||||
"""Perform the device automation specified in the action."""
|
||||
self._step_log("device automation")
|
||||
await device_action.async_call_action_from_config(
|
||||
self._hass, self._action, self._variables, self._context
|
||||
)
|
||||
|
||||
async def _async_scene_step(self):
|
||||
async def _async_scene_step(self) -> None:
|
||||
"""Activate the scene specified in the action."""
|
||||
self._step_log("activate scene")
|
||||
trace_set_result(scene=self._action[CONF_SCENE])
|
||||
@ -793,7 +794,7 @@ class _ScriptRun:
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
async def _async_event_step(self):
|
||||
async def _async_event_step(self) -> None:
|
||||
"""Fire an event."""
|
||||
self._step_log(self._action.get(CONF_ALIAS, self._action[CONF_EVENT]))
|
||||
event_data = {}
|
||||
@ -815,7 +816,7 @@ class _ScriptRun:
|
||||
self._action[CONF_EVENT], event_data, context=self._context
|
||||
)
|
||||
|
||||
async def _async_condition_step(self):
|
||||
async def _async_condition_step(self) -> None:
|
||||
"""Test if condition is matching."""
|
||||
self._script.last_action = self._action.get(
|
||||
CONF_ALIAS, self._action[CONF_CONDITION]
|
||||
@ -835,12 +836,19 @@ class _ScriptRun:
|
||||
if not check:
|
||||
raise _ConditionFail
|
||||
|
||||
def _test_conditions(self, conditions, name, condition_path=None):
|
||||
def _test_conditions(
|
||||
self,
|
||||
conditions: list[ConditionCheckerType],
|
||||
name: str,
|
||||
condition_path: str | None = None,
|
||||
) -> bool | None:
|
||||
if condition_path is None:
|
||||
condition_path = name
|
||||
|
||||
@trace_condition_function
|
||||
def traced_test_conditions(hass, variables):
|
||||
def traced_test_conditions(
|
||||
hass: HomeAssistant, variables: TemplateVarsType
|
||||
) -> bool | None:
|
||||
try:
|
||||
with trace_path(condition_path):
|
||||
for idx, cond in enumerate(conditions):
|
||||
@ -856,7 +864,7 @@ class _ScriptRun:
|
||||
return traced_test_conditions(self._hass, self._variables)
|
||||
|
||||
@async_trace_path("repeat")
|
||||
async def _async_repeat_step(self): # noqa: C901
|
||||
async def _async_repeat_step(self) -> None: # noqa: C901
|
||||
"""Repeat a sequence."""
|
||||
description = self._action.get(CONF_ALIAS, "sequence")
|
||||
repeat = self._action[CONF_REPEAT]
|
||||
@ -876,7 +884,7 @@ class _ScriptRun:
|
||||
script = self._script._get_repeat_script(self._step) # noqa: SLF001
|
||||
warned_too_many_loops = False
|
||||
|
||||
async def async_run_sequence(iteration, extra_msg=""):
|
||||
async def async_run_sequence(iteration: int, extra_msg: str = "") -> None:
|
||||
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
|
||||
with trace_path("sequence"):
|
||||
await self._async_run_script(script)
|
||||
@ -1052,7 +1060,7 @@ class _ScriptRun:
|
||||
"""If sequence."""
|
||||
if_data = await self._script._async_get_if_data(self._step) # noqa: SLF001
|
||||
|
||||
test_conditions = False
|
||||
test_conditions: bool | None = False
|
||||
try:
|
||||
with trace_path("if"):
|
||||
test_conditions = self._test_conditions(
|
||||
@ -1072,6 +1080,26 @@ class _ScriptRun:
|
||||
with trace_path("else"):
|
||||
await self._async_run_script(if_data["if_else"])
|
||||
|
||||
@overload
|
||||
def _async_futures_with_timeout(
|
||||
self,
|
||||
timeout: float,
|
||||
) -> tuple[
|
||||
list[asyncio.Future[None]],
|
||||
asyncio.TimerHandle,
|
||||
asyncio.Future[None],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def _async_futures_with_timeout(
|
||||
self,
|
||||
timeout: None,
|
||||
) -> tuple[
|
||||
list[asyncio.Future[None]],
|
||||
None,
|
||||
None,
|
||||
]: ...
|
||||
|
||||
def _async_futures_with_timeout(
|
||||
self,
|
||||
timeout: float | None,
|
||||
@ -1098,7 +1126,7 @@ class _ScriptRun:
|
||||
futures.append(timeout_future)
|
||||
return futures, timeout_handle, timeout_future
|
||||
|
||||
async def _async_wait_for_trigger_step(self):
|
||||
async def _async_wait_for_trigger_step(self) -> None:
|
||||
"""Wait for a trigger event."""
|
||||
timeout = self._get_timeout_seconds_from_action()
|
||||
|
||||
@ -1119,12 +1147,14 @@ class _ScriptRun:
|
||||
done = self._hass.loop.create_future()
|
||||
futures.append(done)
|
||||
|
||||
async def async_done(variables, context=None):
|
||||
async def async_done(
|
||||
variables: dict[str, Any], context: Context | None = None
|
||||
) -> None:
|
||||
self._async_set_remaining_time_var(timeout_handle)
|
||||
self._variables["wait"]["trigger"] = variables["trigger"]
|
||||
_set_result_unless_done(done)
|
||||
|
||||
def log_cb(level, msg, **kwargs):
|
||||
def log_cb(level: int, msg: str, **kwargs: Any) -> None:
|
||||
self._log(msg, level=level, **kwargs)
|
||||
|
||||
remove_triggers = await async_initialize_triggers(
|
||||
@ -1168,14 +1198,14 @@ class _ScriptRun:
|
||||
|
||||
unsub()
|
||||
|
||||
async def _async_variables_step(self):
|
||||
async def _async_variables_step(self) -> None:
|
||||
"""Set a variable value."""
|
||||
self._step_log("setting variables")
|
||||
self._variables = self._action[CONF_VARIABLES].async_render(
|
||||
self._hass, self._variables, render_as_defaults=False
|
||||
)
|
||||
|
||||
async def _async_set_conversation_response_step(self):
|
||||
async def _async_set_conversation_response_step(self) -> None:
|
||||
"""Set conversation response."""
|
||||
self._step_log("setting conversation response")
|
||||
resp: template.Template | None = self._action[CONF_SET_CONVERSATION_RESPONSE]
|
||||
@ -1187,7 +1217,7 @@ class _ScriptRun:
|
||||
)
|
||||
trace_set_result(conversation_response=self._conversation_response)
|
||||
|
||||
async def _async_stop_step(self):
|
||||
async def _async_stop_step(self) -> None:
|
||||
"""Stop script execution."""
|
||||
stop = self._action[CONF_STOP]
|
||||
error = self._action.get(CONF_ERROR, False)
|
||||
@ -1320,7 +1350,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
|
||||
)
|
||||
|
||||
|
||||
type _VarsType = dict[str, Any] | MappingProxyType
|
||||
type _VarsType = dict[str, Any] | MappingProxyType[str, Any]
|
||||
|
||||
|
||||
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
|
||||
@ -1358,7 +1388,7 @@ class ScriptRunResult:
|
||||
|
||||
conversation_response: str | None | UndefinedType
|
||||
service_response: ServiceResponse
|
||||
variables: dict
|
||||
variables: dict[str, Any]
|
||||
|
||||
|
||||
class Script:
|
||||
@ -1413,7 +1443,7 @@ class Script:
|
||||
self._set_logger(logger)
|
||||
self._log_exceptions = log_exceptions
|
||||
|
||||
self.last_action = None
|
||||
self.last_action: str | None = None
|
||||
self.last_triggered: datetime | None = None
|
||||
|
||||
self._runs: list[_ScriptRun] = []
|
||||
@ -1421,7 +1451,7 @@ class Script:
|
||||
self._max_exceeded = max_exceeded
|
||||
if script_mode == SCRIPT_MODE_QUEUED:
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
|
||||
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
|
||||
self._repeat_script: dict[int, Script] = {}
|
||||
self._choose_data: dict[int, _ChooseData] = {}
|
||||
self._if_data: dict[int, _IfData] = {}
|
||||
@ -1714,9 +1744,11 @@ class Script:
|
||||
|
||||
variables["context"] = context
|
||||
elif self._copy_variables_on_run:
|
||||
variables = cast(dict, copy(run_variables))
|
||||
# This is not the top level script, variables have been turned to a dict
|
||||
variables = cast(dict[str, Any], copy(run_variables))
|
||||
else:
|
||||
variables = cast(dict, run_variables)
|
||||
# This is not the top level script, variables have been turned to a dict
|
||||
variables = cast(dict[str, Any], run_variables)
|
||||
|
||||
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
|
||||
# stop (restart) or wait for (queued) our own script run.
|
||||
@ -1745,9 +1777,7 @@ class Script:
|
||||
cls = _ScriptRun
|
||||
else:
|
||||
cls = _QueuedScriptRun
|
||||
run = cls(
|
||||
self._hass, self, cast(dict, variables), context, self._log_exceptions
|
||||
)
|
||||
run = cls(self._hass, self, variables, context, self._log_exceptions)
|
||||
has_existing_runs = bool(self._runs)
|
||||
self._runs.append(run)
|
||||
if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs:
|
||||
@ -1772,7 +1802,9 @@ class Script:
|
||||
self._changed()
|
||||
raise
|
||||
|
||||
async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None:
|
||||
async def _async_stop(
|
||||
self, aws: list[asyncio.Task[None]], update_state: bool
|
||||
) -> None:
|
||||
await asyncio.wait(aws)
|
||||
if update_state:
|
||||
self._changed()
|
||||
@ -1791,7 +1823,7 @@ class Script:
|
||||
return
|
||||
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
if not (cond := self._config_cache.get(config_cache_key)):
|
||||
cond = await condition.async_from_config(self._hass, config)
|
||||
|
@ -34,7 +34,7 @@ class TraceElement:
|
||||
"""Container for trace data."""
|
||||
self._child_key: str | None = None
|
||||
self._child_run_id: str | None = None
|
||||
self._error: Exception | None = None
|
||||
self._error: BaseException | None = None
|
||||
self.path: str = path
|
||||
self._result: dict[str, Any] | None = None
|
||||
self.reuse_by_child = False
|
||||
@ -52,7 +52,7 @@ class TraceElement:
|
||||
self._child_key = child_key
|
||||
self._child_run_id = child_run_id
|
||||
|
||||
def set_error(self, ex: Exception) -> None:
|
||||
def set_error(self, ex: BaseException | None) -> None:
|
||||
"""Set error."""
|
||||
self._error = ex
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user