From 970cbcbe15926de9b3c8cb975732963d02fed1ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Sat, 17 Apr 2021 09:35:21 +0300 Subject: [PATCH] Type hint improvements (#49320) --- homeassistant/core.py | 4 ++- homeassistant/helpers/config_validation.py | 2 +- homeassistant/helpers/event.py | 10 +++--- homeassistant/helpers/location.py | 4 +-- homeassistant/helpers/script.py | 41 ++++++++++++++-------- homeassistant/helpers/storage.py | 12 +++---- homeassistant/helpers/template.py | 29 +++++++-------- 7 files changed, 58 insertions(+), 44 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index d172b3445e8..3b7fad883da 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -784,7 +784,9 @@ class EventBus: return remove_listener - def listen_once(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: + def listen_once( + self, event_type: str, listener: Callable[[Event], None] + ) -> CALLBACK_TYPE: """Listen once for event of a specific type. To listen to all events specify the constant ``MATCH_ALL`` diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 21d04f11551..bbac18ab839 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1200,7 +1200,7 @@ SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger" SCRIPT_ACTION_VARIABLES = "variables" -def determine_script_action(action: dict) -> str: +def determine_script_action(action: dict[str, Any]) -> str: """Determine action type.""" if CONF_DELAY in action: return SCRIPT_ACTION_DELAY diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index d52ebdb551f..abba6f12a25 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta import functools as ft import logging import time -from typing import Any, Awaitable, Callable, Iterable, List +from typing import Any, Awaitable, Callable, Iterable, List, cast import attr @@ -1453,10 +1453,10 @@ def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str] @callback def _entities_domains_from_render_infos( render_infos: Iterable[RenderInfo], -) -> tuple[set, set]: +) -> tuple[set[str], set[str]]: """Combine from multiple RenderInfo.""" - entities = set() - domains = set() + entities: set[str] = set() + domains: set[str] = set() for render_info in render_infos: if render_info.entities: @@ -1497,7 +1497,7 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt @callback def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool: """Determine if a template should be re-rendered from an event.""" - entity_id = event.data.get(ATTR_ENTITY_ID) + entity_id = cast(str, event.data.get(ATTR_ENTITY_ID)) if info.filter(entity_id): return True diff --git a/homeassistant/helpers/location.py b/homeassistant/helpers/location.py index ff27c580d23..597787ac173 100644 --- a/homeassistant/helpers/location.py +++ b/homeassistant/helpers/location.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import Sequence +from typing import Iterable import voluptuous as vol @@ -25,7 +25,7 @@ def has_location(state: State) -> bool: ) -def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None: +def closest(latitude: float, longitude: float, states: Iterable[State]) -> State | None: """Return closest state to point. Async friendly. diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 6ecb25dfff1..7103fe17ac9 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -8,7 +8,7 @@ from functools import partial import itertools import logging from types import MappingProxyType -from typing import Any, Callable, Dict, Sequence, Union, cast +from typing import Any, Callable, Dict, Sequence, TypedDict, Union, cast import async_timeout import voluptuous as vol @@ -56,7 +56,10 @@ from homeassistant.core import ( callback, ) from homeassistant.helpers import condition, config_validation as cv, service, template -from homeassistant.helpers.condition import trace_condition_function +from homeassistant.helpers.condition import ( + ConditionCheckerType, + trace_condition_function, +) from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -492,7 +495,7 @@ class _ScriptRun: task.cancel() unsub() - async def _async_run_long_action(self, long_task): + async def _async_run_long_action(self, long_task: asyncio.tasks.Task) -> None: """Run a long task while monitoring for stop request.""" async def async_cancel_long_task() -> None: @@ -741,7 +744,7 @@ class _ScriptRun: except exceptions.ConditionError as ex: _LOGGER.warning("Error in 'choose' evaluation:\n%s", ex) - if choose_data["default"]: + if choose_data["default"] is not None: trace_set_result(choice="default") with trace_path(["default"]): await self._async_run_script(choose_data["default"]) @@ -808,7 +811,7 @@ class _ScriptRun: self._hass, self._variables, render_as_defaults=False ) - async def _async_run_script(self, script): + async def _async_run_script(self, script: Script) -> None: """Execute a script.""" await self._async_run_long_action( self._hass.async_create_task( @@ -912,6 +915,11 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> found.add(item_id) +class _ChooseData(TypedDict): + choices: list[tuple[list[ConditionCheckerType], Script]] + default: Script | None + + class Script: """Representation of a script.""" @@ -973,7 +981,7 @@ class Script: self._queue_lck = asyncio.Lock() self._config_cache: dict[set[tuple], Callable[..., bool]] = {} self._repeat_script: dict[int, Script] = {} - self._choose_data: dict[int, dict[str, Any]] = {} + self._choose_data: dict[int, _ChooseData] = {} self._referenced_entities: set[str] | None = None self._referenced_devices: set[str] | None = None self._referenced_areas: set[str] | None = None @@ -1011,14 +1019,14 @@ class Script: for choose_data in self._choose_data.values(): for _, script in choose_data["choices"]: script.update_logger(self._logger) - if choose_data["default"]: + if choose_data["default"] is not None: choose_data["default"].update_logger(self._logger) def _changed(self) -> None: if self._change_listener_job: self._hass.async_run_hass_job(self._change_listener_job) - def _chain_change_listener(self, sub_script): + def _chain_change_listener(self, sub_script: Script) -> None: if sub_script.is_running: self.last_action = sub_script.last_action self._changed() @@ -1203,7 +1211,9 @@ class Script: self._changed() raise - async def _async_stop(self, update_state, spare=None): + async def _async_stop( + self, update_state: bool, spare: _ScriptRun | None = None + ) -> None: aws = [ asyncio.create_task(run.async_stop()) for run in self._runs if run != spare ] @@ -1230,7 +1240,7 @@ class Script: self._config_cache[config_cache_key] = cond return cond - def _prep_repeat_script(self, step): + def _prep_repeat_script(self, step: int) -> Script: action = self.sequence[step] step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}") sub_script = Script( @@ -1247,14 +1257,14 @@ class Script: sub_script.change_listener = partial(self._chain_change_listener, sub_script) return sub_script - def _get_repeat_script(self, step): + def _get_repeat_script(self, step: int) -> Script: sub_script = self._repeat_script.get(step) if not sub_script: sub_script = self._prep_repeat_script(step) self._repeat_script[step] = sub_script return sub_script - async def _async_prep_choose_data(self, step): + async def _async_prep_choose_data(self, step: int) -> _ChooseData: action = self.sequence[step] step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}") choices = [] @@ -1280,6 +1290,7 @@ class Script: ) choices.append((conditions, sub_script)) + default_script: Script | None if CONF_DEFAULT in action: default_script = Script( self._hass, @@ -1300,7 +1311,7 @@ class Script: return {"choices": choices, "default": default_script} - async def _async_get_choose_data(self, step): + async def _async_get_choose_data(self, step: int) -> _ChooseData: choose_data = self._choose_data.get(step) if not choose_data: choose_data = await self._async_prep_choose_data(step) @@ -1330,7 +1341,7 @@ def breakpoint_clear(hass, key, run_id, node): @callback -def breakpoint_clear_all(hass): +def breakpoint_clear_all(hass: HomeAssistant) -> None: """Clear all breakpoints.""" hass.data[DATA_SCRIPT_BREAKPOINTS] = {} @@ -1348,7 +1359,7 @@ def breakpoint_set(hass, key, run_id, node): @callback -def breakpoint_list(hass): +def breakpoint_list(hass: HomeAssistant) -> list[dict[str, Any]]: """List breakpoints.""" breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index 5a08a97a210..456e9b04709 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -9,7 +9,7 @@ import os from typing import Any, Callable from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE -from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback from homeassistant.helpers.event import async_call_later from homeassistant.loader import bind_hass from homeassistant.util import json as json_util @@ -169,7 +169,7 @@ class Store: ) @callback - def _async_ensure_final_write_listener(self): + def _async_ensure_final_write_listener(self) -> None: """Ensure that we write if we quit before delay has passed.""" if self._unsub_final_write_listener is None: self._unsub_final_write_listener = self.hass.bus.async_listen_once( @@ -177,14 +177,14 @@ class Store: ) @callback - def _async_cleanup_final_write_listener(self): + def _async_cleanup_final_write_listener(self) -> None: """Clean up a stop listener.""" if self._unsub_final_write_listener is not None: self._unsub_final_write_listener() self._unsub_final_write_listener = None @callback - def _async_cleanup_delay_listener(self): + def _async_cleanup_delay_listener(self) -> None: """Clean up a delay listener.""" if self._unsub_delay_listener is not None: self._unsub_delay_listener() @@ -198,7 +198,7 @@ class Store: return await self._async_handle_write_data() - async def _async_callback_final_write(self, _event): + async def _async_callback_final_write(self, _event: Event) -> None: """Handle a write because Home Assistant is in final write state.""" self._unsub_final_write_listener = None await self._async_handle_write_data() @@ -239,7 +239,7 @@ class Store: """Migrate to the new version.""" raise NotImplementedError - async def async_remove(self): + async def async_remove(self) -> None: """Remove all data.""" self._async_cleanup_delay_listener() self._async_cleanup_final_write_listener() diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index b024c8f2656..06fe5d288f5 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -16,7 +16,7 @@ from operator import attrgetter import random import re import sys -from typing import Any, Generator, Iterable, cast +from typing import Any, Callable, Generator, Iterable, cast from urllib.parse import urlencode as urllib_urlencode import weakref @@ -193,31 +193,32 @@ RESULT_WRAPPERS: dict[type, type] = { RESULT_WRAPPERS[tuple] = TupleWrapper -def _true(arg: Any) -> bool: +def _true(arg: str) -> bool: return True -def _false(arg: Any) -> bool: +def _false(arg: str) -> bool: return False class RenderInfo: """Holds information about a template render.""" - def __init__(self, template): + def __init__(self, template: Template) -> None: """Initialise.""" self.template = template # Will be set sensibly once frozen. - self.filter_lifecycle = _true - self.filter = _true + self.filter_lifecycle: Callable[[str], bool] = _true + self.filter: Callable[[str], bool] = _true self._result: str | None = None self.is_static = False self.exception: TemplateError | None = None self.all_states = False self.all_states_lifecycle = False - self.domains = set() - self.domains_lifecycle = set() - self.entities = set() + # pylint: disable=unsubscriptable-object # for abc.Set, https://github.com/PyCQA/pylint/pull/4275 + self.domains: collections.abc.Set[str] = set() + self.domains_lifecycle: collections.abc.Set[str] = set() + self.entities: collections.abc.Set[str] = set() self.rate_limit: timedelta | None = None self.has_time = False @@ -491,7 +492,7 @@ class Template: """Render the template and collect an entity filter.""" assert self.hass and _RENDER_INFO not in self.hass.data - render_info = RenderInfo(self) # type: ignore[no-untyped-call] + render_info = RenderInfo(self) # pylint: disable=protected-access if self.is_static: @@ -1039,13 +1040,13 @@ def is_state(hass: HomeAssistant, entity_id: str, state: State) -> bool: return state_obj is not None and state_obj.state == state -def is_state_attr(hass, entity_id, name, value): +def is_state_attr(hass: HomeAssistant, entity_id: str, name: str, value: Any) -> bool: """Test if a state's attribute is a specific value.""" attr = state_attr(hass, entity_id, name) return attr is not None and attr == value -def state_attr(hass, entity_id, name): +def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any: """Get a specific attribute from a state.""" state_obj = _get_state(hass, entity_id) if state_obj is not None: @@ -1053,7 +1054,7 @@ def state_attr(hass, entity_id, name): return None -def now(hass): +def now(hass: HomeAssistant) -> datetime: """Record fetching now.""" render_info = hass.data.get(_RENDER_INFO) if render_info is not None: @@ -1062,7 +1063,7 @@ def now(hass): return dt_util.now() -def utcnow(hass): +def utcnow(hass: HomeAssistant) -> datetime: """Record fetching utcnow.""" render_info = hass.data.get(_RENDER_INFO) if render_info is not None: