mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Type hint improvements (#49320)
This commit is contained in:
parent
f7b7a805f5
commit
970cbcbe15
@ -784,7 +784,9 @@ class EventBus:
|
||||
|
||||
return remove_listener
|
||||
|
||||
def listen_once(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
||||
def listen_once(
|
||||
self, event_type: str, listener: Callable[[Event], None]
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen once for event of a specific type.
|
||||
|
||||
To listen to all events specify the constant ``MATCH_ALL``
|
||||
|
@ -1200,7 +1200,7 @@ SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger"
|
||||
SCRIPT_ACTION_VARIABLES = "variables"
|
||||
|
||||
|
||||
def determine_script_action(action: dict) -> str:
|
||||
def determine_script_action(action: dict[str, Any]) -> str:
|
||||
"""Determine action type."""
|
||||
if CONF_DELAY in action:
|
||||
return SCRIPT_ACTION_DELAY
|
||||
|
@ -8,7 +8,7 @@ from datetime import datetime, timedelta
|
||||
import functools as ft
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Iterable, List
|
||||
from typing import Any, Awaitable, Callable, Iterable, List, cast
|
||||
|
||||
import attr
|
||||
|
||||
@ -1453,10 +1453,10 @@ def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str]
|
||||
@callback
|
||||
def _entities_domains_from_render_infos(
|
||||
render_infos: Iterable[RenderInfo],
|
||||
) -> tuple[set, set]:
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""Combine from multiple RenderInfo."""
|
||||
entities = set()
|
||||
domains = set()
|
||||
entities: set[str] = set()
|
||||
domains: set[str] = set()
|
||||
|
||||
for render_info in render_infos:
|
||||
if render_info.entities:
|
||||
@ -1497,7 +1497,7 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt
|
||||
@callback
|
||||
def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
|
||||
"""Determine if a template should be re-rendered from an event."""
|
||||
entity_id = event.data.get(ATTR_ENTITY_ID)
|
||||
entity_id = cast(str, event.data.get(ATTR_ENTITY_ID))
|
||||
|
||||
if info.filter(entity_id):
|
||||
return True
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
from typing import Iterable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -25,7 +25,7 @@ def has_location(state: State) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None:
|
||||
def closest(latitude: float, longitude: float, states: Iterable[State]) -> State | None:
|
||||
"""Return closest state to point.
|
||||
|
||||
Async friendly.
|
||||
|
@ -8,7 +8,7 @@ from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Callable, Dict, Sequence, Union, cast
|
||||
from typing import Any, Callable, Dict, Sequence, TypedDict, Union, cast
|
||||
|
||||
import async_timeout
|
||||
import voluptuous as vol
|
||||
@ -56,7 +56,10 @@ from homeassistant.core import (
|
||||
callback,
|
||||
)
|
||||
from homeassistant.helpers import condition, config_validation as cv, service, template
|
||||
from homeassistant.helpers.condition import trace_condition_function
|
||||
from homeassistant.helpers.condition import (
|
||||
ConditionCheckerType,
|
||||
trace_condition_function,
|
||||
)
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
@ -492,7 +495,7 @@ class _ScriptRun:
|
||||
task.cancel()
|
||||
unsub()
|
||||
|
||||
async def _async_run_long_action(self, long_task):
|
||||
async def _async_run_long_action(self, long_task: asyncio.tasks.Task) -> None:
|
||||
"""Run a long task while monitoring for stop request."""
|
||||
|
||||
async def async_cancel_long_task() -> None:
|
||||
@ -741,7 +744,7 @@ class _ScriptRun:
|
||||
except exceptions.ConditionError as ex:
|
||||
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
|
||||
|
||||
if choose_data["default"]:
|
||||
if choose_data["default"] is not None:
|
||||
trace_set_result(choice="default")
|
||||
with trace_path(["default"]):
|
||||
await self._async_run_script(choose_data["default"])
|
||||
@ -808,7 +811,7 @@ class _ScriptRun:
|
||||
self._hass, self._variables, render_as_defaults=False
|
||||
)
|
||||
|
||||
async def _async_run_script(self, script):
|
||||
async def _async_run_script(self, script: Script) -> None:
|
||||
"""Execute a script."""
|
||||
await self._async_run_long_action(
|
||||
self._hass.async_create_task(
|
||||
@ -912,6 +915,11 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) ->
|
||||
found.add(item_id)
|
||||
|
||||
|
||||
class _ChooseData(TypedDict):
|
||||
choices: list[tuple[list[ConditionCheckerType], Script]]
|
||||
default: Script | None
|
||||
|
||||
|
||||
class Script:
|
||||
"""Representation of a script."""
|
||||
|
||||
@ -973,7 +981,7 @@ class Script:
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
|
||||
self._repeat_script: dict[int, Script] = {}
|
||||
self._choose_data: dict[int, dict[str, Any]] = {}
|
||||
self._choose_data: dict[int, _ChooseData] = {}
|
||||
self._referenced_entities: set[str] | None = None
|
||||
self._referenced_devices: set[str] | None = None
|
||||
self._referenced_areas: set[str] | None = None
|
||||
@ -1011,14 +1019,14 @@ class Script:
|
||||
for choose_data in self._choose_data.values():
|
||||
for _, script in choose_data["choices"]:
|
||||
script.update_logger(self._logger)
|
||||
if choose_data["default"]:
|
||||
if choose_data["default"] is not None:
|
||||
choose_data["default"].update_logger(self._logger)
|
||||
|
||||
def _changed(self) -> None:
|
||||
if self._change_listener_job:
|
||||
self._hass.async_run_hass_job(self._change_listener_job)
|
||||
|
||||
def _chain_change_listener(self, sub_script):
|
||||
def _chain_change_listener(self, sub_script: Script) -> None:
|
||||
if sub_script.is_running:
|
||||
self.last_action = sub_script.last_action
|
||||
self._changed()
|
||||
@ -1203,7 +1211,9 @@ class Script:
|
||||
self._changed()
|
||||
raise
|
||||
|
||||
async def _async_stop(self, update_state, spare=None):
|
||||
async def _async_stop(
|
||||
self, update_state: bool, spare: _ScriptRun | None = None
|
||||
) -> None:
|
||||
aws = [
|
||||
asyncio.create_task(run.async_stop()) for run in self._runs if run != spare
|
||||
]
|
||||
@ -1230,7 +1240,7 @@ class Script:
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
def _prep_repeat_script(self, step):
|
||||
def _prep_repeat_script(self, step: int) -> Script:
|
||||
action = self.sequence[step]
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
|
||||
sub_script = Script(
|
||||
@ -1247,14 +1257,14 @@ class Script:
|
||||
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
|
||||
return sub_script
|
||||
|
||||
def _get_repeat_script(self, step):
|
||||
def _get_repeat_script(self, step: int) -> Script:
|
||||
sub_script = self._repeat_script.get(step)
|
||||
if not sub_script:
|
||||
sub_script = self._prep_repeat_script(step)
|
||||
self._repeat_script[step] = sub_script
|
||||
return sub_script
|
||||
|
||||
async def _async_prep_choose_data(self, step):
|
||||
async def _async_prep_choose_data(self, step: int) -> _ChooseData:
|
||||
action = self.sequence[step]
|
||||
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
|
||||
choices = []
|
||||
@ -1280,6 +1290,7 @@ class Script:
|
||||
)
|
||||
choices.append((conditions, sub_script))
|
||||
|
||||
default_script: Script | None
|
||||
if CONF_DEFAULT in action:
|
||||
default_script = Script(
|
||||
self._hass,
|
||||
@ -1300,7 +1311,7 @@ class Script:
|
||||
|
||||
return {"choices": choices, "default": default_script}
|
||||
|
||||
async def _async_get_choose_data(self, step):
|
||||
async def _async_get_choose_data(self, step: int) -> _ChooseData:
|
||||
choose_data = self._choose_data.get(step)
|
||||
if not choose_data:
|
||||
choose_data = await self._async_prep_choose_data(step)
|
||||
@ -1330,7 +1341,7 @@ def breakpoint_clear(hass, key, run_id, node):
|
||||
|
||||
|
||||
@callback
|
||||
def breakpoint_clear_all(hass):
|
||||
def breakpoint_clear_all(hass: HomeAssistant) -> None:
|
||||
"""Clear all breakpoints."""
|
||||
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
|
||||
|
||||
@ -1348,7 +1359,7 @@ def breakpoint_set(hass, key, run_id, node):
|
||||
|
||||
|
||||
@callback
|
||||
def breakpoint_list(hass):
|
||||
def breakpoint_list(hass: HomeAssistant) -> list[dict[str, Any]]:
|
||||
"""List breakpoints."""
|
||||
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
||||
|
||||
|
@ -9,7 +9,7 @@ import os
|
||||
from typing import Any, Callable
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
|
||||
from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback
|
||||
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import json as json_util
|
||||
@ -169,7 +169,7 @@ class Store:
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_ensure_final_write_listener(self):
|
||||
def _async_ensure_final_write_listener(self) -> None:
|
||||
"""Ensure that we write if we quit before delay has passed."""
|
||||
if self._unsub_final_write_listener is None:
|
||||
self._unsub_final_write_listener = self.hass.bus.async_listen_once(
|
||||
@ -177,14 +177,14 @@ class Store:
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_cleanup_final_write_listener(self):
|
||||
def _async_cleanup_final_write_listener(self) -> None:
|
||||
"""Clean up a stop listener."""
|
||||
if self._unsub_final_write_listener is not None:
|
||||
self._unsub_final_write_listener()
|
||||
self._unsub_final_write_listener = None
|
||||
|
||||
@callback
|
||||
def _async_cleanup_delay_listener(self):
|
||||
def _async_cleanup_delay_listener(self) -> None:
|
||||
"""Clean up a delay listener."""
|
||||
if self._unsub_delay_listener is not None:
|
||||
self._unsub_delay_listener()
|
||||
@ -198,7 +198,7 @@ class Store:
|
||||
return
|
||||
await self._async_handle_write_data()
|
||||
|
||||
async def _async_callback_final_write(self, _event):
|
||||
async def _async_callback_final_write(self, _event: Event) -> None:
|
||||
"""Handle a write because Home Assistant is in final write state."""
|
||||
self._unsub_final_write_listener = None
|
||||
await self._async_handle_write_data()
|
||||
@ -239,7 +239,7 @@ class Store:
|
||||
"""Migrate to the new version."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_remove(self):
|
||||
async def async_remove(self) -> None:
|
||||
"""Remove all data."""
|
||||
self._async_cleanup_delay_listener()
|
||||
self._async_cleanup_final_write_listener()
|
||||
|
@ -16,7 +16,7 @@ from operator import attrgetter
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Generator, Iterable, cast
|
||||
from typing import Any, Callable, Generator, Iterable, cast
|
||||
from urllib.parse import urlencode as urllib_urlencode
|
||||
import weakref
|
||||
|
||||
@ -193,31 +193,32 @@ RESULT_WRAPPERS: dict[type, type] = {
|
||||
RESULT_WRAPPERS[tuple] = TupleWrapper
|
||||
|
||||
|
||||
def _true(arg: Any) -> bool:
|
||||
def _true(arg: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _false(arg: Any) -> bool:
|
||||
def _false(arg: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class RenderInfo:
|
||||
"""Holds information about a template render."""
|
||||
|
||||
def __init__(self, template):
|
||||
def __init__(self, template: Template) -> None:
|
||||
"""Initialise."""
|
||||
self.template = template
|
||||
# Will be set sensibly once frozen.
|
||||
self.filter_lifecycle = _true
|
||||
self.filter = _true
|
||||
self.filter_lifecycle: Callable[[str], bool] = _true
|
||||
self.filter: Callable[[str], bool] = _true
|
||||
self._result: str | None = None
|
||||
self.is_static = False
|
||||
self.exception: TemplateError | None = None
|
||||
self.all_states = False
|
||||
self.all_states_lifecycle = False
|
||||
self.domains = set()
|
||||
self.domains_lifecycle = set()
|
||||
self.entities = set()
|
||||
# pylint: disable=unsubscriptable-object # for abc.Set, https://github.com/PyCQA/pylint/pull/4275
|
||||
self.domains: collections.abc.Set[str] = set()
|
||||
self.domains_lifecycle: collections.abc.Set[str] = set()
|
||||
self.entities: collections.abc.Set[str] = set()
|
||||
self.rate_limit: timedelta | None = None
|
||||
self.has_time = False
|
||||
|
||||
@ -491,7 +492,7 @@ class Template:
|
||||
"""Render the template and collect an entity filter."""
|
||||
assert self.hass and _RENDER_INFO not in self.hass.data
|
||||
|
||||
render_info = RenderInfo(self) # type: ignore[no-untyped-call]
|
||||
render_info = RenderInfo(self)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if self.is_static:
|
||||
@ -1039,13 +1040,13 @@ def is_state(hass: HomeAssistant, entity_id: str, state: State) -> bool:
|
||||
return state_obj is not None and state_obj.state == state
|
||||
|
||||
|
||||
def is_state_attr(hass, entity_id, name, value):
|
||||
def is_state_attr(hass: HomeAssistant, entity_id: str, name: str, value: Any) -> bool:
|
||||
"""Test if a state's attribute is a specific value."""
|
||||
attr = state_attr(hass, entity_id, name)
|
||||
return attr is not None and attr == value
|
||||
|
||||
|
||||
def state_attr(hass, entity_id, name):
|
||||
def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
|
||||
"""Get a specific attribute from a state."""
|
||||
state_obj = _get_state(hass, entity_id)
|
||||
if state_obj is not None:
|
||||
@ -1053,7 +1054,7 @@ def state_attr(hass, entity_id, name):
|
||||
return None
|
||||
|
||||
|
||||
def now(hass):
|
||||
def now(hass: HomeAssistant) -> datetime:
|
||||
"""Record fetching now."""
|
||||
render_info = hass.data.get(_RENDER_INFO)
|
||||
if render_info is not None:
|
||||
@ -1062,7 +1063,7 @@ def now(hass):
|
||||
return dt_util.now()
|
||||
|
||||
|
||||
def utcnow(hass):
|
||||
def utcnow(hass: HomeAssistant) -> datetime:
|
||||
"""Record fetching utcnow."""
|
||||
render_info = hass.data.get(_RENDER_INFO)
|
||||
if render_info is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user