Type hint improvements (#49320)

This commit is contained in:
Ville Skyttä 2021-04-17 09:35:21 +03:00 committed by GitHub
parent f7b7a805f5
commit 970cbcbe15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 58 additions and 44 deletions

View File

@ -784,7 +784,9 @@ class EventBus:
return remove_listener return remove_listener
def listen_once(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: def listen_once(
self, event_type: str, listener: Callable[[Event], None]
) -> CALLBACK_TYPE:
"""Listen once for event of a specific type. """Listen once for event of a specific type.
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``

View File

@ -1200,7 +1200,7 @@ SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger"
SCRIPT_ACTION_VARIABLES = "variables" SCRIPT_ACTION_VARIABLES = "variables"
def determine_script_action(action: dict) -> str: def determine_script_action(action: dict[str, Any]) -> str:
"""Determine action type.""" """Determine action type."""
if CONF_DELAY in action: if CONF_DELAY in action:
return SCRIPT_ACTION_DELAY return SCRIPT_ACTION_DELAY

View File

@ -8,7 +8,7 @@ from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import time import time
from typing import Any, Awaitable, Callable, Iterable, List from typing import Any, Awaitable, Callable, Iterable, List, cast
import attr import attr
@ -1453,10 +1453,10 @@ def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str]
@callback @callback
def _entities_domains_from_render_infos( def _entities_domains_from_render_infos(
render_infos: Iterable[RenderInfo], render_infos: Iterable[RenderInfo],
) -> tuple[set, set]: ) -> tuple[set[str], set[str]]:
"""Combine from multiple RenderInfo.""" """Combine from multiple RenderInfo."""
entities = set() entities: set[str] = set()
domains = set() domains: set[str] = set()
for render_info in render_infos: for render_info in render_infos:
if render_info.entities: if render_info.entities:
@ -1497,7 +1497,7 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt
@callback @callback
def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool: def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
"""Determine if a template should be re-rendered from an event.""" """Determine if a template should be re-rendered from an event."""
entity_id = event.data.get(ATTR_ENTITY_ID) entity_id = cast(str, event.data.get(ATTR_ENTITY_ID))
if info.filter(entity_id): if info.filter(entity_id):
return True return True

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Sequence from typing import Iterable
import voluptuous as vol import voluptuous as vol
@ -25,7 +25,7 @@ def has_location(state: State) -> bool:
) )
def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None: def closest(latitude: float, longitude: float, states: Iterable[State]) -> State | None:
"""Return closest state to point. """Return closest state to point.
Async friendly. Async friendly.

View File

@ -8,7 +8,7 @@ from functools import partial
import itertools import itertools
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Dict, Sequence, Union, cast from typing import Any, Callable, Dict, Sequence, TypedDict, Union, cast
import async_timeout import async_timeout
import voluptuous as vol import voluptuous as vol
@ -56,7 +56,10 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.helpers import condition, config_validation as cv, service, template from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.condition import trace_condition_function from homeassistant.helpers.condition import (
ConditionCheckerType,
trace_condition_function,
)
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
@ -492,7 +495,7 @@ class _ScriptRun:
task.cancel() task.cancel()
unsub() unsub()
async def _async_run_long_action(self, long_task): async def _async_run_long_action(self, long_task: asyncio.tasks.Task) -> None:
"""Run a long task while monitoring for stop request.""" """Run a long task while monitoring for stop request."""
async def async_cancel_long_task() -> None: async def async_cancel_long_task() -> None:
@ -741,7 +744,7 @@ class _ScriptRun:
except exceptions.ConditionError as ex: except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex) _LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
if choose_data["default"]: if choose_data["default"] is not None:
trace_set_result(choice="default") trace_set_result(choice="default")
with trace_path(["default"]): with trace_path(["default"]):
await self._async_run_script(choose_data["default"]) await self._async_run_script(choose_data["default"])
@ -808,7 +811,7 @@ class _ScriptRun:
self._hass, self._variables, render_as_defaults=False self._hass, self._variables, render_as_defaults=False
) )
async def _async_run_script(self, script): async def _async_run_script(self, script: Script) -> None:
"""Execute a script.""" """Execute a script."""
await self._async_run_long_action( await self._async_run_long_action(
self._hass.async_create_task( self._hass.async_create_task(
@ -912,6 +915,11 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) ->
found.add(item_id) found.add(item_id)
class _ChooseData(TypedDict):
choices: list[tuple[list[ConditionCheckerType], Script]]
default: Script | None
class Script: class Script:
"""Representation of a script.""" """Representation of a script."""
@ -973,7 +981,7 @@ class Script:
self._queue_lck = asyncio.Lock() self._queue_lck = asyncio.Lock()
self._config_cache: dict[set[tuple], Callable[..., bool]] = {} self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
self._repeat_script: dict[int, Script] = {} self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, dict[str, Any]] = {} self._choose_data: dict[int, _ChooseData] = {}
self._referenced_entities: set[str] | None = None self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None self._referenced_devices: set[str] | None = None
self._referenced_areas: set[str] | None = None self._referenced_areas: set[str] | None = None
@ -1011,14 +1019,14 @@ class Script:
for choose_data in self._choose_data.values(): for choose_data in self._choose_data.values():
for _, script in choose_data["choices"]: for _, script in choose_data["choices"]:
script.update_logger(self._logger) script.update_logger(self._logger)
if choose_data["default"]: if choose_data["default"] is not None:
choose_data["default"].update_logger(self._logger) choose_data["default"].update_logger(self._logger)
def _changed(self) -> None: def _changed(self) -> None:
if self._change_listener_job: if self._change_listener_job:
self._hass.async_run_hass_job(self._change_listener_job) self._hass.async_run_hass_job(self._change_listener_job)
def _chain_change_listener(self, sub_script): def _chain_change_listener(self, sub_script: Script) -> None:
if sub_script.is_running: if sub_script.is_running:
self.last_action = sub_script.last_action self.last_action = sub_script.last_action
self._changed() self._changed()
@ -1203,7 +1211,9 @@ class Script:
self._changed() self._changed()
raise raise
async def _async_stop(self, update_state, spare=None): async def _async_stop(
self, update_state: bool, spare: _ScriptRun | None = None
) -> None:
aws = [ aws = [
asyncio.create_task(run.async_stop()) for run in self._runs if run != spare asyncio.create_task(run.async_stop()) for run in self._runs if run != spare
] ]
@ -1230,7 +1240,7 @@ class Script:
self._config_cache[config_cache_key] = cond self._config_cache[config_cache_key] = cond
return cond return cond
def _prep_repeat_script(self, step): def _prep_repeat_script(self, step: int) -> Script:
action = self.sequence[step] action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}") step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
sub_script = Script( sub_script = Script(
@ -1247,14 +1257,14 @@ class Script:
sub_script.change_listener = partial(self._chain_change_listener, sub_script) sub_script.change_listener = partial(self._chain_change_listener, sub_script)
return sub_script return sub_script
def _get_repeat_script(self, step): def _get_repeat_script(self, step: int) -> Script:
sub_script = self._repeat_script.get(step) sub_script = self._repeat_script.get(step)
if not sub_script: if not sub_script:
sub_script = self._prep_repeat_script(step) sub_script = self._prep_repeat_script(step)
self._repeat_script[step] = sub_script self._repeat_script[step] = sub_script
return sub_script return sub_script
async def _async_prep_choose_data(self, step): async def _async_prep_choose_data(self, step: int) -> _ChooseData:
action = self.sequence[step] action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}") step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
choices = [] choices = []
@ -1280,6 +1290,7 @@ class Script:
) )
choices.append((conditions, sub_script)) choices.append((conditions, sub_script))
default_script: Script | None
if CONF_DEFAULT in action: if CONF_DEFAULT in action:
default_script = Script( default_script = Script(
self._hass, self._hass,
@ -1300,7 +1311,7 @@ class Script:
return {"choices": choices, "default": default_script} return {"choices": choices, "default": default_script}
async def _async_get_choose_data(self, step): async def _async_get_choose_data(self, step: int) -> _ChooseData:
choose_data = self._choose_data.get(step) choose_data = self._choose_data.get(step)
if not choose_data: if not choose_data:
choose_data = await self._async_prep_choose_data(step) choose_data = await self._async_prep_choose_data(step)
@ -1330,7 +1341,7 @@ def breakpoint_clear(hass, key, run_id, node):
@callback @callback
def breakpoint_clear_all(hass): def breakpoint_clear_all(hass: HomeAssistant) -> None:
"""Clear all breakpoints.""" """Clear all breakpoints."""
hass.data[DATA_SCRIPT_BREAKPOINTS] = {} hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
@ -1348,7 +1359,7 @@ def breakpoint_set(hass, key, run_id, node):
@callback @callback
def breakpoint_list(hass): def breakpoint_list(hass: HomeAssistant) -> list[dict[str, Any]]:
"""List breakpoints.""" """List breakpoints."""
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS] breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]

View File

@ -9,7 +9,7 @@ import os
from typing import Any, Callable from typing import Any, Callable
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import json as json_util from homeassistant.util import json as json_util
@ -169,7 +169,7 @@ class Store:
) )
@callback @callback
def _async_ensure_final_write_listener(self): def _async_ensure_final_write_listener(self) -> None:
"""Ensure that we write if we quit before delay has passed.""" """Ensure that we write if we quit before delay has passed."""
if self._unsub_final_write_listener is None: if self._unsub_final_write_listener is None:
self._unsub_final_write_listener = self.hass.bus.async_listen_once( self._unsub_final_write_listener = self.hass.bus.async_listen_once(
@ -177,14 +177,14 @@ class Store:
) )
@callback @callback
def _async_cleanup_final_write_listener(self): def _async_cleanup_final_write_listener(self) -> None:
"""Clean up a stop listener.""" """Clean up a stop listener."""
if self._unsub_final_write_listener is not None: if self._unsub_final_write_listener is not None:
self._unsub_final_write_listener() self._unsub_final_write_listener()
self._unsub_final_write_listener = None self._unsub_final_write_listener = None
@callback @callback
def _async_cleanup_delay_listener(self): def _async_cleanup_delay_listener(self) -> None:
"""Clean up a delay listener.""" """Clean up a delay listener."""
if self._unsub_delay_listener is not None: if self._unsub_delay_listener is not None:
self._unsub_delay_listener() self._unsub_delay_listener()
@ -198,7 +198,7 @@ class Store:
return return
await self._async_handle_write_data() await self._async_handle_write_data()
async def _async_callback_final_write(self, _event): async def _async_callback_final_write(self, _event: Event) -> None:
"""Handle a write because Home Assistant is in final write state.""" """Handle a write because Home Assistant is in final write state."""
self._unsub_final_write_listener = None self._unsub_final_write_listener = None
await self._async_handle_write_data() await self._async_handle_write_data()
@ -239,7 +239,7 @@ class Store:
"""Migrate to the new version.""" """Migrate to the new version."""
raise NotImplementedError raise NotImplementedError
async def async_remove(self): async def async_remove(self) -> None:
"""Remove all data.""" """Remove all data."""
self._async_cleanup_delay_listener() self._async_cleanup_delay_listener()
self._async_cleanup_final_write_listener() self._async_cleanup_final_write_listener()

View File

@ -16,7 +16,7 @@ from operator import attrgetter
import random import random
import re import re
import sys import sys
from typing import Any, Generator, Iterable, cast from typing import Any, Callable, Generator, Iterable, cast
from urllib.parse import urlencode as urllib_urlencode from urllib.parse import urlencode as urllib_urlencode
import weakref import weakref
@ -193,31 +193,32 @@ RESULT_WRAPPERS: dict[type, type] = {
RESULT_WRAPPERS[tuple] = TupleWrapper RESULT_WRAPPERS[tuple] = TupleWrapper
def _true(arg: Any) -> bool: def _true(arg: str) -> bool:
return True return True
def _false(arg: Any) -> bool: def _false(arg: str) -> bool:
return False return False
class RenderInfo: class RenderInfo:
"""Holds information about a template render.""" """Holds information about a template render."""
def __init__(self, template): def __init__(self, template: Template) -> None:
"""Initialise.""" """Initialise."""
self.template = template self.template = template
# Will be set sensibly once frozen. # Will be set sensibly once frozen.
self.filter_lifecycle = _true self.filter_lifecycle: Callable[[str], bool] = _true
self.filter = _true self.filter: Callable[[str], bool] = _true
self._result: str | None = None self._result: str | None = None
self.is_static = False self.is_static = False
self.exception: TemplateError | None = None self.exception: TemplateError | None = None
self.all_states = False self.all_states = False
self.all_states_lifecycle = False self.all_states_lifecycle = False
self.domains = set() # pylint: disable=unsubscriptable-object # for abc.Set, https://github.com/PyCQA/pylint/pull/4275
self.domains_lifecycle = set() self.domains: collections.abc.Set[str] = set()
self.entities = set() self.domains_lifecycle: collections.abc.Set[str] = set()
self.entities: collections.abc.Set[str] = set()
self.rate_limit: timedelta | None = None self.rate_limit: timedelta | None = None
self.has_time = False self.has_time = False
@ -491,7 +492,7 @@ class Template:
"""Render the template and collect an entity filter.""" """Render the template and collect an entity filter."""
assert self.hass and _RENDER_INFO not in self.hass.data assert self.hass and _RENDER_INFO not in self.hass.data
render_info = RenderInfo(self) # type: ignore[no-untyped-call] render_info = RenderInfo(self)
# pylint: disable=protected-access # pylint: disable=protected-access
if self.is_static: if self.is_static:
@ -1039,13 +1040,13 @@ def is_state(hass: HomeAssistant, entity_id: str, state: State) -> bool:
return state_obj is not None and state_obj.state == state return state_obj is not None and state_obj.state == state
def is_state_attr(hass, entity_id, name, value): def is_state_attr(hass: HomeAssistant, entity_id: str, name: str, value: Any) -> bool:
"""Test if a state's attribute is a specific value.""" """Test if a state's attribute is a specific value."""
attr = state_attr(hass, entity_id, name) attr = state_attr(hass, entity_id, name)
return attr is not None and attr == value return attr is not None and attr == value
def state_attr(hass, entity_id, name): def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
"""Get a specific attribute from a state.""" """Get a specific attribute from a state."""
state_obj = _get_state(hass, entity_id) state_obj = _get_state(hass, entity_id)
if state_obj is not None: if state_obj is not None:
@ -1053,7 +1054,7 @@ def state_attr(hass, entity_id, name):
return None return None
def now(hass): def now(hass: HomeAssistant) -> datetime:
"""Record fetching now.""" """Record fetching now."""
render_info = hass.data.get(_RENDER_INFO) render_info = hass.data.get(_RENDER_INFO)
if render_info is not None: if render_info is not None:
@ -1062,7 +1063,7 @@ def now(hass):
return dt_util.now() return dt_util.now()
def utcnow(hass): def utcnow(hass: HomeAssistant) -> datetime:
"""Record fetching utcnow.""" """Record fetching utcnow."""
render_info = hass.data.get(_RENDER_INFO) render_info = hass.data.get(_RENDER_INFO)
if render_info is not None: if render_info is not None: