Type hint improvements (#49320)

This commit is contained in:
Ville Skyttä 2021-04-17 09:35:21 +03:00 committed by GitHub
parent f7b7a805f5
commit 970cbcbe15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 58 additions and 44 deletions

View File

@ -784,7 +784,9 @@ class EventBus:
return remove_listener
def listen_once(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
def listen_once(
self, event_type: str, listener: Callable[[Event], None]
) -> CALLBACK_TYPE:
"""Listen once for event of a specific type.
To listen to all events specify the constant ``MATCH_ALL``

View File

@ -1200,7 +1200,7 @@ SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger"
SCRIPT_ACTION_VARIABLES = "variables"
def determine_script_action(action: dict) -> str:
def determine_script_action(action: dict[str, Any]) -> str:
"""Determine action type."""
if CONF_DELAY in action:
return SCRIPT_ACTION_DELAY

View File

@ -8,7 +8,7 @@ from datetime import datetime, timedelta
import functools as ft
import logging
import time
from typing import Any, Awaitable, Callable, Iterable, List
from typing import Any, Awaitable, Callable, Iterable, List, cast
import attr
@ -1453,10 +1453,10 @@ def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str]
@callback
def _entities_domains_from_render_infos(
render_infos: Iterable[RenderInfo],
) -> tuple[set, set]:
) -> tuple[set[str], set[str]]:
"""Combine from multiple RenderInfo."""
entities = set()
domains = set()
entities: set[str] = set()
domains: set[str] = set()
for render_info in render_infos:
if render_info.entities:
@ -1497,7 +1497,7 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt
@callback
def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
"""Determine if a template should be re-rendered from an event."""
entity_id = event.data.get(ATTR_ENTITY_ID)
entity_id = cast(str, event.data.get(ATTR_ENTITY_ID))
if info.filter(entity_id):
return True

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Sequence
from typing import Iterable
import voluptuous as vol
@ -25,7 +25,7 @@ def has_location(state: State) -> bool:
)
def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None:
def closest(latitude: float, longitude: float, states: Iterable[State]) -> State | None:
"""Return closest state to point.
Async friendly.

View File

@ -8,7 +8,7 @@ from functools import partial
import itertools
import logging
from types import MappingProxyType
from typing import Any, Callable, Dict, Sequence, Union, cast
from typing import Any, Callable, Dict, Sequence, TypedDict, Union, cast
import async_timeout
import voluptuous as vol
@ -56,7 +56,10 @@ from homeassistant.core import (
callback,
)
from homeassistant.helpers import condition, config_validation as cv, service, template
from homeassistant.helpers.condition import trace_condition_function
from homeassistant.helpers.condition import (
ConditionCheckerType,
trace_condition_function,
)
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
@ -492,7 +495,7 @@ class _ScriptRun:
task.cancel()
unsub()
async def _async_run_long_action(self, long_task):
async def _async_run_long_action(self, long_task: asyncio.tasks.Task) -> None:
"""Run a long task while monitoring for stop request."""
async def async_cancel_long_task() -> None:
@ -741,7 +744,7 @@ class _ScriptRun:
except exceptions.ConditionError as ex:
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
if choose_data["default"]:
if choose_data["default"] is not None:
trace_set_result(choice="default")
with trace_path(["default"]):
await self._async_run_script(choose_data["default"])
@ -808,7 +811,7 @@ class _ScriptRun:
self._hass, self._variables, render_as_defaults=False
)
async def _async_run_script(self, script):
async def _async_run_script(self, script: Script) -> None:
"""Execute a script."""
await self._async_run_long_action(
self._hass.async_create_task(
@ -912,6 +915,11 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) ->
found.add(item_id)
class _ChooseData(TypedDict):
choices: list[tuple[list[ConditionCheckerType], Script]]
default: Script | None
class Script:
"""Representation of a script."""
@ -973,7 +981,7 @@ class Script:
self._queue_lck = asyncio.Lock()
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, dict[str, Any]] = {}
self._choose_data: dict[int, _ChooseData] = {}
self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None
self._referenced_areas: set[str] | None = None
@ -1011,14 +1019,14 @@ class Script:
for choose_data in self._choose_data.values():
for _, script in choose_data["choices"]:
script.update_logger(self._logger)
if choose_data["default"]:
if choose_data["default"] is not None:
choose_data["default"].update_logger(self._logger)
def _changed(self) -> None:
if self._change_listener_job:
self._hass.async_run_hass_job(self._change_listener_job)
def _chain_change_listener(self, sub_script):
def _chain_change_listener(self, sub_script: Script) -> None:
if sub_script.is_running:
self.last_action = sub_script.last_action
self._changed()
@ -1203,7 +1211,9 @@ class Script:
self._changed()
raise
async def _async_stop(self, update_state, spare=None):
async def _async_stop(
self, update_state: bool, spare: _ScriptRun | None = None
) -> None:
aws = [
asyncio.create_task(run.async_stop()) for run in self._runs if run != spare
]
@ -1230,7 +1240,7 @@ class Script:
self._config_cache[config_cache_key] = cond
return cond
def _prep_repeat_script(self, step):
def _prep_repeat_script(self, step: int) -> Script:
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
sub_script = Script(
@ -1247,14 +1257,14 @@ class Script:
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
return sub_script
def _get_repeat_script(self, step):
def _get_repeat_script(self, step: int) -> Script:
sub_script = self._repeat_script.get(step)
if not sub_script:
sub_script = self._prep_repeat_script(step)
self._repeat_script[step] = sub_script
return sub_script
async def _async_prep_choose_data(self, step):
async def _async_prep_choose_data(self, step: int) -> _ChooseData:
action = self.sequence[step]
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
choices = []
@ -1280,6 +1290,7 @@ class Script:
)
choices.append((conditions, sub_script))
default_script: Script | None
if CONF_DEFAULT in action:
default_script = Script(
self._hass,
@ -1300,7 +1311,7 @@ class Script:
return {"choices": choices, "default": default_script}
async def _async_get_choose_data(self, step):
async def _async_get_choose_data(self, step: int) -> _ChooseData:
choose_data = self._choose_data.get(step)
if not choose_data:
choose_data = await self._async_prep_choose_data(step)
@ -1330,7 +1341,7 @@ def breakpoint_clear(hass, key, run_id, node):
@callback
def breakpoint_clear_all(hass):
def breakpoint_clear_all(hass: HomeAssistant) -> None:
"""Clear all breakpoints."""
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
@ -1348,7 +1359,7 @@ def breakpoint_set(hass, key, run_id, node):
@callback
def breakpoint_list(hass):
def breakpoint_list(hass: HomeAssistant) -> list[dict[str, Any]]:
"""List breakpoints."""
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]

View File

@ -9,7 +9,7 @@ import os
from typing import Any, Callable
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.loader import bind_hass
from homeassistant.util import json as json_util
@ -169,7 +169,7 @@ class Store:
)
@callback
def _async_ensure_final_write_listener(self):
def _async_ensure_final_write_listener(self) -> None:
"""Ensure that we write if we quit before delay has passed."""
if self._unsub_final_write_listener is None:
self._unsub_final_write_listener = self.hass.bus.async_listen_once(
@ -177,14 +177,14 @@ class Store:
)
@callback
def _async_cleanup_final_write_listener(self):
def _async_cleanup_final_write_listener(self) -> None:
"""Clean up a stop listener."""
if self._unsub_final_write_listener is not None:
self._unsub_final_write_listener()
self._unsub_final_write_listener = None
@callback
def _async_cleanup_delay_listener(self):
def _async_cleanup_delay_listener(self) -> None:
"""Clean up a delay listener."""
if self._unsub_delay_listener is not None:
self._unsub_delay_listener()
@ -198,7 +198,7 @@ class Store:
return
await self._async_handle_write_data()
async def _async_callback_final_write(self, _event):
async def _async_callback_final_write(self, _event: Event) -> None:
"""Handle a write because Home Assistant is in final write state."""
self._unsub_final_write_listener = None
await self._async_handle_write_data()
@ -239,7 +239,7 @@ class Store:
"""Migrate to the new version."""
raise NotImplementedError
async def async_remove(self):
async def async_remove(self) -> None:
"""Remove all data."""
self._async_cleanup_delay_listener()
self._async_cleanup_final_write_listener()

View File

@ -16,7 +16,7 @@ from operator import attrgetter
import random
import re
import sys
from typing import Any, Generator, Iterable, cast
from typing import Any, Callable, Generator, Iterable, cast
from urllib.parse import urlencode as urllib_urlencode
import weakref
@ -193,31 +193,32 @@ RESULT_WRAPPERS: dict[type, type] = {
RESULT_WRAPPERS[tuple] = TupleWrapper
def _true(arg: Any) -> bool:
def _true(arg: str) -> bool:
return True
def _false(arg: Any) -> bool:
def _false(arg: str) -> bool:
return False
class RenderInfo:
"""Holds information about a template render."""
def __init__(self, template):
def __init__(self, template: Template) -> None:
"""Initialise."""
self.template = template
# Will be set sensibly once frozen.
self.filter_lifecycle = _true
self.filter = _true
self.filter_lifecycle: Callable[[str], bool] = _true
self.filter: Callable[[str], bool] = _true
self._result: str | None = None
self.is_static = False
self.exception: TemplateError | None = None
self.all_states = False
self.all_states_lifecycle = False
self.domains = set()
self.domains_lifecycle = set()
self.entities = set()
# pylint: disable=unsubscriptable-object # for abc.Set, https://github.com/PyCQA/pylint/pull/4275
self.domains: collections.abc.Set[str] = set()
self.domains_lifecycle: collections.abc.Set[str] = set()
self.entities: collections.abc.Set[str] = set()
self.rate_limit: timedelta | None = None
self.has_time = False
@ -491,7 +492,7 @@ class Template:
"""Render the template and collect an entity filter."""
assert self.hass and _RENDER_INFO not in self.hass.data
render_info = RenderInfo(self) # type: ignore[no-untyped-call]
render_info = RenderInfo(self)
# pylint: disable=protected-access
if self.is_static:
@ -1039,13 +1040,13 @@ def is_state(hass: HomeAssistant, entity_id: str, state: State) -> bool:
return state_obj is not None and state_obj.state == state
def is_state_attr(hass, entity_id, name, value):
def is_state_attr(hass: HomeAssistant, entity_id: str, name: str, value: Any) -> bool:
"""Test if a state's attribute is a specific value."""
attr = state_attr(hass, entity_id, name)
return attr is not None and attr == value
def state_attr(hass, entity_id, name):
def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
"""Get a specific attribute from a state."""
state_obj = _get_state(hass, entity_id)
if state_obj is not None:
@ -1053,7 +1054,7 @@ def state_attr(hass, entity_id, name):
return None
def now(hass):
def now(hass: HomeAssistant) -> datetime:
"""Record fetching now."""
render_info = hass.data.get(_RENDER_INFO)
if render_info is not None:
@ -1062,7 +1063,7 @@ def now(hass):
return dt_util.now()
def utcnow(hass):
def utcnow(hass: HomeAssistant) -> datetime:
"""Record fetching utcnow."""
render_info = hass.data.get(_RENDER_INFO)
if render_info is not None: