Enable strict typing on script helper (#122075)

This commit is contained in:
Erik Montnemery 2024-07-17 13:51:59 +02:00 committed by GitHub
parent a0f91d27a3
commit efb7bede40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 76 additions and 40 deletions

View File

@ -21,6 +21,7 @@ homeassistant.helpers.entity_platform
homeassistant.helpers.entity_values
homeassistant.helpers.event
homeassistant.helpers.reload
homeassistant.helpers.script
homeassistant.helpers.script_variables
homeassistant.helpers.singleton
homeassistant.helpers.sun

View File

@ -13,7 +13,7 @@ from functools import cached_property, partial
import itertools
import logging
from types import MappingProxyType
from typing import Any, Literal, TypedDict, cast
from typing import Any, Literal, TypedDict, cast, overload
import async_interrupt
import voluptuous as vol
@ -75,6 +75,7 @@ from homeassistant.core import (
HassJob,
HomeAssistant,
ServiceResponse,
State,
SupportsResponse,
callback,
)
@ -107,9 +108,7 @@ from .trace import (
trace_update_result,
)
from .trigger import async_initialize_triggers, async_validate_trigger_config
from .typing import UNDEFINED, ConfigType, UndefinedType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType
SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUED = "queued"
@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
future.set_result(None)
def action_trace_append(variables, path):
def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement:
"""Append a TraceElement to trace[path]."""
trace_element = TraceElement(variables, path)
trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN)
@ -430,7 +429,7 @@ class _ScriptRun:
if not self._stop.done():
self._script._changed() # noqa: SLF001
async def _async_get_condition(self, config):
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
return await self._script._async_get_condition(config) # noqa: SLF001
def _log(
@ -438,7 +437,7 @@ class _ScriptRun:
) -> None:
self._script._log(msg, *args, level=level, **kwargs) # noqa: SLF001
def _step_log(self, default_message, timeout=None):
def _step_log(self, default_message: str, timeout: float | None = None) -> None:
self._script.last_action = self._action.get(CONF_ALIAS, default_message)
_timeout = (
"" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})"
@ -580,7 +579,7 @@ class _ScriptRun:
if not isinstance(exception, exceptions.HomeAssistantError):
raise exception
def _log_exception(self, exception):
def _log_exception(self, exception: Exception) -> None:
action_type = cv.determine_script_action(self._action)
error = str(exception)
@ -629,7 +628,7 @@ class _ScriptRun:
)
raise _AbortScript from ex
async def _async_delay_step(self):
async def _async_delay_step(self) -> None:
"""Handle delay."""
delay_delta = self._get_pos_time_period_template(CONF_DELAY)
@ -661,7 +660,7 @@ class _ScriptRun:
return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
return None
async def _async_wait_template_step(self):
async def _async_wait_template_step(self) -> None:
"""Handle a wait template."""
timeout = self._get_timeout_seconds_from_action()
self._step_log("wait template", timeout)
@ -690,7 +689,9 @@ class _ScriptRun:
futures.append(done)
@callback
def async_script_wait(entity_id, from_s, to_s):
def async_script_wait(
entity_id: str, from_s: State | None, to_s: State | None
) -> None:
"""Handle script after template condition is true."""
self._async_set_remaining_time_var(timeout_handle)
self._variables["wait"]["completed"] = True
@ -727,7 +728,7 @@ class _ScriptRun:
except ScriptStoppedError as ex:
raise asyncio.CancelledError from ex
async def _async_call_service_step(self):
async def _async_call_service_step(self) -> None:
"""Call the service specified in the action."""
self._step_log("call service")
@ -774,14 +775,14 @@ class _ScriptRun:
if response_variable:
self._variables[response_variable] = response_data
async def _async_device_step(self):
async def _async_device_step(self) -> None:
"""Perform the device automation specified in the action."""
self._step_log("device automation")
await device_action.async_call_action_from_config(
self._hass, self._action, self._variables, self._context
)
async def _async_scene_step(self):
async def _async_scene_step(self) -> None:
"""Activate the scene specified in the action."""
self._step_log("activate scene")
trace_set_result(scene=self._action[CONF_SCENE])
@ -793,7 +794,7 @@ class _ScriptRun:
context=self._context,
)
async def _async_event_step(self):
async def _async_event_step(self) -> None:
"""Fire an event."""
self._step_log(self._action.get(CONF_ALIAS, self._action[CONF_EVENT]))
event_data = {}
@ -815,7 +816,7 @@ class _ScriptRun:
self._action[CONF_EVENT], event_data, context=self._context
)
async def _async_condition_step(self):
async def _async_condition_step(self) -> None:
"""Test if condition is matching."""
self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_CONDITION]
@ -835,12 +836,19 @@ class _ScriptRun:
if not check:
raise _ConditionFail
def _test_conditions(self, conditions, name, condition_path=None):
def _test_conditions(
self,
conditions: list[ConditionCheckerType],
name: str,
condition_path: str | None = None,
) -> bool | None:
if condition_path is None:
condition_path = name
@trace_condition_function
def traced_test_conditions(hass, variables):
def traced_test_conditions(
hass: HomeAssistant, variables: TemplateVarsType
) -> bool | None:
try:
with trace_path(condition_path):
for idx, cond in enumerate(conditions):
@ -856,7 +864,7 @@ class _ScriptRun:
return traced_test_conditions(self._hass, self._variables)
@async_trace_path("repeat")
async def _async_repeat_step(self): # noqa: C901
async def _async_repeat_step(self) -> None: # noqa: C901
"""Repeat a sequence."""
description = self._action.get(CONF_ALIAS, "sequence")
repeat = self._action[CONF_REPEAT]
@ -876,7 +884,7 @@ class _ScriptRun:
script = self._script._get_repeat_script(self._step) # noqa: SLF001
warned_too_many_loops = False
async def async_run_sequence(iteration, extra_msg=""):
async def async_run_sequence(iteration: int, extra_msg: str = "") -> None:
self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg)
with trace_path("sequence"):
await self._async_run_script(script)
@ -1052,7 +1060,7 @@ class _ScriptRun:
"""If sequence."""
if_data = await self._script._async_get_if_data(self._step) # noqa: SLF001
test_conditions = False
test_conditions: bool | None = False
try:
with trace_path("if"):
test_conditions = self._test_conditions(
@ -1072,6 +1080,26 @@ class _ScriptRun:
with trace_path("else"):
await self._async_run_script(if_data["if_else"])
@overload
def _async_futures_with_timeout(
self,
timeout: float,
) -> tuple[
list[asyncio.Future[None]],
asyncio.TimerHandle,
asyncio.Future[None],
]: ...
@overload
def _async_futures_with_timeout(
self,
timeout: None,
) -> tuple[
list[asyncio.Future[None]],
None,
None,
]: ...
def _async_futures_with_timeout(
self,
timeout: float | None,
@ -1098,7 +1126,7 @@ class _ScriptRun:
futures.append(timeout_future)
return futures, timeout_handle, timeout_future
async def _async_wait_for_trigger_step(self):
async def _async_wait_for_trigger_step(self) -> None:
"""Wait for a trigger event."""
timeout = self._get_timeout_seconds_from_action()
@ -1119,12 +1147,14 @@ class _ScriptRun:
done = self._hass.loop.create_future()
futures.append(done)
async def async_done(variables, context=None):
async def async_done(
variables: dict[str, Any], context: Context | None = None
) -> None:
self._async_set_remaining_time_var(timeout_handle)
self._variables["wait"]["trigger"] = variables["trigger"]
_set_result_unless_done(done)
def log_cb(level, msg, **kwargs):
def log_cb(level: int, msg: str, **kwargs: Any) -> None:
self._log(msg, level=level, **kwargs)
remove_triggers = await async_initialize_triggers(
@ -1168,14 +1198,14 @@ class _ScriptRun:
unsub()
async def _async_variables_step(self):
async def _async_variables_step(self) -> None:
"""Set a variable value."""
self._step_log("setting variables")
self._variables = self._action[CONF_VARIABLES].async_render(
self._hass, self._variables, render_as_defaults=False
)
async def _async_set_conversation_response_step(self):
async def _async_set_conversation_response_step(self) -> None:
"""Set conversation response."""
self._step_log("setting conversation response")
resp: template.Template | None = self._action[CONF_SET_CONVERSATION_RESPONSE]
@ -1187,7 +1217,7 @@ class _ScriptRun:
)
trace_set_result(conversation_response=self._conversation_response)
async def _async_stop_step(self):
async def _async_stop_step(self) -> None:
"""Stop script execution."""
stop = self._action[CONF_STOP]
error = self._action.get(CONF_ERROR, False)
@ -1320,7 +1350,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
)
type _VarsType = dict[str, Any] | MappingProxyType
type _VarsType = dict[str, Any] | MappingProxyType[str, Any]
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
@ -1358,7 +1388,7 @@ class ScriptRunResult:
conversation_response: str | None | UndefinedType
service_response: ServiceResponse
variables: dict
variables: dict[str, Any]
class Script:
@ -1413,7 +1443,7 @@ class Script:
self._set_logger(logger)
self._log_exceptions = log_exceptions
self.last_action = None
self.last_action: str | None = None
self.last_triggered: datetime | None = None
self._runs: list[_ScriptRun] = []
@ -1421,7 +1451,7 @@ class Script:
self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock()
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {}
@ -1714,9 +1744,11 @@ class Script:
variables["context"] = context
elif self._copy_variables_on_run:
variables = cast(dict, copy(run_variables))
# This is not the top level script, variables have been turned to a dict
variables = cast(dict[str, Any], copy(run_variables))
else:
variables = cast(dict, run_variables)
# This is not the top level script, variables have been turned to a dict
variables = cast(dict[str, Any], run_variables)
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
# stop (restart) or wait for (queued) our own script run.
@ -1745,9 +1777,7 @@ class Script:
cls = _ScriptRun
else:
cls = _QueuedScriptRun
run = cls(
self._hass, self, cast(dict, variables), context, self._log_exceptions
)
run = cls(self._hass, self, variables, context, self._log_exceptions)
has_existing_runs = bool(self._runs)
self._runs.append(run)
if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs:
@ -1772,7 +1802,9 @@ class Script:
self._changed()
raise
async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None:
async def _async_stop(
self, aws: list[asyncio.Task[None]], update_state: bool
) -> None:
await asyncio.wait(aws)
if update_state:
self._changed()
@ -1791,7 +1823,7 @@ class Script:
return
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
async def _async_get_condition(self, config):
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
if not (cond := self._config_cache.get(config_cache_key)):
cond = await condition.async_from_config(self._hass, config)

View File

@ -34,7 +34,7 @@ class TraceElement:
"""Container for trace data."""
self._child_key: str | None = None
self._child_run_id: str | None = None
self._error: Exception | None = None
self._error: BaseException | None = None
self.path: str = path
self._result: dict[str, Any] | None = None
self.reuse_by_child = False
@ -52,7 +52,7 @@ class TraceElement:
self._child_key = child_key
self._child_run_id = child_run_id
def set_error(self, ex: Exception) -> None:
def set_error(self, ex: BaseException | None) -> None:
"""Set error."""
self._error = ex

View File

@ -85,6 +85,9 @@ disallow_any_generics = true
[mypy-homeassistant.helpers.reload]
disallow_any_generics = true
[mypy-homeassistant.helpers.script]
disallow_any_generics = true
[mypy-homeassistant.helpers.script_variables]
disallow_any_generics = true