diff --git a/.strict-typing b/.strict-typing index 1707f0ca9c3..7d85af77868 100644 --- a/.strict-typing +++ b/.strict-typing @@ -21,6 +21,7 @@ homeassistant.helpers.entity_platform homeassistant.helpers.entity_values homeassistant.helpers.event homeassistant.helpers.reload +homeassistant.helpers.script homeassistant.helpers.script_variables homeassistant.helpers.singleton homeassistant.helpers.sun diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 70f8d460e29..a1b885d0c52 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -13,7 +13,7 @@ from functools import cached_property, partial import itertools import logging from types import MappingProxyType -from typing import Any, Literal, TypedDict, cast +from typing import Any, Literal, TypedDict, cast, overload import async_interrupt import voluptuous as vol @@ -75,6 +75,7 @@ from homeassistant.core import ( HassJob, HomeAssistant, ServiceResponse, + State, SupportsResponse, callback, ) @@ -107,9 +108,7 @@ from .trace import ( trace_update_result, ) from .trigger import async_initialize_triggers, async_validate_trigger_config -from .typing import UNDEFINED, ConfigType, UndefinedType - -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs +from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType SCRIPT_MODE_PARALLEL = "parallel" SCRIPT_MODE_QUEUED = "queued" @@ -177,7 +176,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None: future.set_result(None) -def action_trace_append(variables, path): +def action_trace_append(variables: dict[str, Any], path: str) -> TraceElement: """Append a TraceElement to trace[path].""" trace_element = TraceElement(variables, path) trace_append_element(trace_element, ACTION_TRACE_NODE_MAX_LEN) @@ -430,7 +429,7 @@ class _ScriptRun: if not self._stop.done(): self._script._changed() # noqa: SLF001 - async def _async_get_condition(self, config): + async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType: return await self._script._async_get_condition(config) # noqa: SLF001 def _log( @@ -438,7 +437,7 @@ class _ScriptRun: ) -> None: self._script._log(msg, *args, level=level, **kwargs) # noqa: SLF001 - def _step_log(self, default_message, timeout=None): + def _step_log(self, default_message: str, timeout: float | None = None) -> None: self._script.last_action = self._action.get(CONF_ALIAS, default_message) _timeout = ( "" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})" @@ -580,7 +579,7 @@ class _ScriptRun: if not isinstance(exception, exceptions.HomeAssistantError): raise exception - def _log_exception(self, exception): + def _log_exception(self, exception: Exception) -> None: action_type = cv.determine_script_action(self._action) error = str(exception) @@ -629,7 +628,7 @@ class _ScriptRun: ) raise _AbortScript from ex - async def _async_delay_step(self): + async def _async_delay_step(self) -> None: """Handle delay.""" delay_delta = self._get_pos_time_period_template(CONF_DELAY) @@ -661,7 +660,7 @@ class _ScriptRun: return self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds() return None - async def _async_wait_template_step(self): + async def _async_wait_template_step(self) -> None: """Handle a wait template.""" timeout = self._get_timeout_seconds_from_action() self._step_log("wait template", timeout) @@ -690,7 +689,9 @@ class _ScriptRun: futures.append(done) @callback - def async_script_wait(entity_id, from_s, to_s): + def async_script_wait( + entity_id: str, from_s: State | None, to_s: State | None + ) -> None: """Handle script after template condition is true.""" self._async_set_remaining_time_var(timeout_handle) self._variables["wait"]["completed"] = True @@ -727,7 +728,7 @@ class _ScriptRun: except ScriptStoppedError as ex: raise asyncio.CancelledError from ex - async def _async_call_service_step(self): + async def _async_call_service_step(self) -> None: """Call the service specified in the action.""" self._step_log("call service") @@ -774,14 +775,14 @@ class _ScriptRun: if response_variable: self._variables[response_variable] = response_data - async def _async_device_step(self): + async def _async_device_step(self) -> None: """Perform the device automation specified in the action.""" self._step_log("device automation") await device_action.async_call_action_from_config( self._hass, self._action, self._variables, self._context ) - async def _async_scene_step(self): + async def _async_scene_step(self) -> None: """Activate the scene specified in the action.""" self._step_log("activate scene") trace_set_result(scene=self._action[CONF_SCENE]) @@ -793,7 +794,7 @@ class _ScriptRun: context=self._context, ) - async def _async_event_step(self): + async def _async_event_step(self) -> None: """Fire an event.""" self._step_log(self._action.get(CONF_ALIAS, self._action[CONF_EVENT])) event_data = {} @@ -815,7 +816,7 @@ class _ScriptRun: self._action[CONF_EVENT], event_data, context=self._context ) - async def _async_condition_step(self): + async def _async_condition_step(self) -> None: """Test if condition is matching.""" self._script.last_action = self._action.get( CONF_ALIAS, self._action[CONF_CONDITION] @@ -835,12 +836,19 @@ class _ScriptRun: if not check: raise _ConditionFail - def _test_conditions(self, conditions, name, condition_path=None): + def _test_conditions( + self, + conditions: list[ConditionCheckerType], + name: str, + condition_path: str | None = None, + ) -> bool | None: if condition_path is None: condition_path = name @trace_condition_function - def traced_test_conditions(hass, variables): + def traced_test_conditions( + hass: HomeAssistant, variables: TemplateVarsType + ) -> bool | None: try: with trace_path(condition_path): for idx, cond in enumerate(conditions): @@ -856,7 +864,7 @@ class _ScriptRun: return traced_test_conditions(self._hass, self._variables) @async_trace_path("repeat") - async def _async_repeat_step(self): # noqa: C901 + async def _async_repeat_step(self) -> None: # noqa: C901 """Repeat a sequence.""" description = self._action.get(CONF_ALIAS, "sequence") repeat = self._action[CONF_REPEAT] @@ -876,7 +884,7 @@ class _ScriptRun: script = self._script._get_repeat_script(self._step) # noqa: SLF001 warned_too_many_loops = False - async def async_run_sequence(iteration, extra_msg=""): + async def async_run_sequence(iteration: int, extra_msg: str = "") -> None: self._log("Repeating %s: Iteration %i%s", description, iteration, extra_msg) with trace_path("sequence"): await self._async_run_script(script) @@ -1052,7 +1060,7 @@ class _ScriptRun: """If sequence.""" if_data = await self._script._async_get_if_data(self._step) # noqa: SLF001 - test_conditions = False + test_conditions: bool | None = False try: with trace_path("if"): test_conditions = self._test_conditions( @@ -1072,6 +1080,26 @@ class _ScriptRun: with trace_path("else"): await self._async_run_script(if_data["if_else"]) + @overload + def _async_futures_with_timeout( + self, + timeout: float, + ) -> tuple[ + list[asyncio.Future[None]], + asyncio.TimerHandle, + asyncio.Future[None], + ]: ... + + @overload + def _async_futures_with_timeout( + self, + timeout: None, + ) -> tuple[ + list[asyncio.Future[None]], + None, + None, + ]: ... + def _async_futures_with_timeout( self, timeout: float | None, @@ -1098,7 +1126,7 @@ class _ScriptRun: futures.append(timeout_future) return futures, timeout_handle, timeout_future - async def _async_wait_for_trigger_step(self): + async def _async_wait_for_trigger_step(self) -> None: """Wait for a trigger event.""" timeout = self._get_timeout_seconds_from_action() @@ -1119,12 +1147,14 @@ class _ScriptRun: done = self._hass.loop.create_future() futures.append(done) - async def async_done(variables, context=None): + async def async_done( + variables: dict[str, Any], context: Context | None = None + ) -> None: self._async_set_remaining_time_var(timeout_handle) self._variables["wait"]["trigger"] = variables["trigger"] _set_result_unless_done(done) - def log_cb(level, msg, **kwargs): + def log_cb(level: int, msg: str, **kwargs: Any) -> None: self._log(msg, level=level, **kwargs) remove_triggers = await async_initialize_triggers( @@ -1168,14 +1198,14 @@ class _ScriptRun: unsub() - async def _async_variables_step(self): + async def _async_variables_step(self) -> None: """Set a variable value.""" self._step_log("setting variables") self._variables = self._action[CONF_VARIABLES].async_render( self._hass, self._variables, render_as_defaults=False ) - async def _async_set_conversation_response_step(self): + async def _async_set_conversation_response_step(self) -> None: """Set conversation response.""" self._step_log("setting conversation response") resp: template.Template | None = self._action[CONF_SET_CONVERSATION_RESPONSE] @@ -1187,7 +1217,7 @@ class _ScriptRun: ) trace_set_result(conversation_response=self._conversation_response) - async def _async_stop_step(self): + async def _async_stop_step(self) -> None: """Stop script execution.""" stop = self._action[CONF_STOP] error = self._action.get(CONF_ERROR, False) @@ -1320,7 +1350,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) -> ) -type _VarsType = dict[str, Any] | MappingProxyType +type _VarsType = dict[str, Any] | MappingProxyType[str, Any] def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None: @@ -1358,7 +1388,7 @@ class ScriptRunResult: conversation_response: str | None | UndefinedType service_response: ServiceResponse - variables: dict + variables: dict[str, Any] class Script: @@ -1413,7 +1443,7 @@ class Script: self._set_logger(logger) self._log_exceptions = log_exceptions - self.last_action = None + self.last_action: str | None = None self.last_triggered: datetime | None = None self._runs: list[_ScriptRun] = [] @@ -1421,7 +1451,7 @@ class Script: self._max_exceeded = max_exceeded if script_mode == SCRIPT_MODE_QUEUED: self._queue_lck = asyncio.Lock() - self._config_cache: dict[set[tuple], Callable[..., bool]] = {} + self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {} self._repeat_script: dict[int, Script] = {} self._choose_data: dict[int, _ChooseData] = {} self._if_data: dict[int, _IfData] = {} @@ -1714,9 +1744,11 @@ class Script: variables["context"] = context elif self._copy_variables_on_run: - variables = cast(dict, copy(run_variables)) + # This is not the top level script, variables have been turned to a dict + variables = cast(dict[str, Any], copy(run_variables)) else: - variables = cast(dict, run_variables) + # This is not the top level script, variables have been turned to a dict + variables = cast(dict[str, Any], run_variables) # Prevent non-allowed recursive calls which will cause deadlocks when we try to # stop (restart) or wait for (queued) our own script run. @@ -1745,9 +1777,7 @@ class Script: cls = _ScriptRun else: cls = _QueuedScriptRun - run = cls( - self._hass, self, cast(dict, variables), context, self._log_exceptions - ) + run = cls(self._hass, self, variables, context, self._log_exceptions) has_existing_runs = bool(self._runs) self._runs.append(run) if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs: @@ -1772,7 +1802,9 @@ class Script: self._changed() raise - async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None: + async def _async_stop( + self, aws: list[asyncio.Task[None]], update_state: bool + ) -> None: await asyncio.wait(aws) if update_state: self._changed() @@ -1791,7 +1823,7 @@ class Script: return await asyncio.shield(create_eager_task(self._async_stop(aws, update_state))) - async def _async_get_condition(self, config): + async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType: config_cache_key = frozenset((k, str(v)) for k, v in config.items()) if not (cond := self._config_cache.get(config_cache_key)): cond = await condition.async_from_config(self._hass, config) diff --git a/homeassistant/helpers/trace.py b/homeassistant/helpers/trace.py index a36939a0f60..431a7a7d1f8 100644 --- a/homeassistant/helpers/trace.py +++ b/homeassistant/helpers/trace.py @@ -34,7 +34,7 @@ class TraceElement: """Container for trace data.""" self._child_key: str | None = None self._child_run_id: str | None = None - self._error: Exception | None = None + self._error: BaseException | None = None self.path: str = path self._result: dict[str, Any] | None = None self.reuse_by_child = False @@ -52,7 +52,7 @@ class TraceElement: self._child_key = child_key self._child_run_id = child_run_id - def set_error(self, ex: Exception) -> None: + def set_error(self, ex: BaseException | None) -> None: """Set error.""" self._error = ex diff --git a/mypy.ini b/mypy.ini index cf16c4f5f63..b9b4eb2f469 100644 --- a/mypy.ini +++ b/mypy.ini @@ -85,6 +85,9 @@ disallow_any_generics = true [mypy-homeassistant.helpers.reload] disallow_any_generics = true +[mypy-homeassistant.helpers.script] +disallow_any_generics = true + [mypy-homeassistant.helpers.script_variables] disallow_any_generics = true