mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 10:17:51 +00:00
Type hint improvements (#49320)
This commit is contained in:
parent
f7b7a805f5
commit
970cbcbe15
@ -784,7 +784,9 @@ class EventBus:
|
|||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
def listen_once(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
|
def listen_once(
|
||||||
|
self, event_type: str, listener: Callable[[Event], None]
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
"""Listen once for event of a specific type.
|
"""Listen once for event of a specific type.
|
||||||
|
|
||||||
To listen to all events specify the constant ``MATCH_ALL``
|
To listen to all events specify the constant ``MATCH_ALL``
|
||||||
|
@ -1200,7 +1200,7 @@ SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger"
|
|||||||
SCRIPT_ACTION_VARIABLES = "variables"
|
SCRIPT_ACTION_VARIABLES = "variables"
|
||||||
|
|
||||||
|
|
||||||
def determine_script_action(action: dict) -> str:
|
def determine_script_action(action: dict[str, Any]) -> str:
|
||||||
"""Determine action type."""
|
"""Determine action type."""
|
||||||
if CONF_DELAY in action:
|
if CONF_DELAY in action:
|
||||||
return SCRIPT_ACTION_DELAY
|
return SCRIPT_ACTION_DELAY
|
||||||
|
@ -8,7 +8,7 @@ from datetime import datetime, timedelta
|
|||||||
import functools as ft
|
import functools as ft
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Awaitable, Callable, Iterable, List
|
from typing import Any, Awaitable, Callable, Iterable, List, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
@ -1453,10 +1453,10 @@ def process_state_match(parameter: None | str | Iterable[str]) -> Callable[[str]
|
|||||||
@callback
|
@callback
|
||||||
def _entities_domains_from_render_infos(
|
def _entities_domains_from_render_infos(
|
||||||
render_infos: Iterable[RenderInfo],
|
render_infos: Iterable[RenderInfo],
|
||||||
) -> tuple[set, set]:
|
) -> tuple[set[str], set[str]]:
|
||||||
"""Combine from multiple RenderInfo."""
|
"""Combine from multiple RenderInfo."""
|
||||||
entities = set()
|
entities: set[str] = set()
|
||||||
domains = set()
|
domains: set[str] = set()
|
||||||
|
|
||||||
for render_info in render_infos:
|
for render_info in render_infos:
|
||||||
if render_info.entities:
|
if render_info.entities:
|
||||||
@ -1497,7 +1497,7 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt
|
|||||||
@callback
|
@callback
|
||||||
def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
|
def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
|
||||||
"""Determine if a template should be re-rendered from an event."""
|
"""Determine if a template should be re-rendered from an event."""
|
||||||
entity_id = event.data.get(ATTR_ENTITY_ID)
|
entity_id = cast(str, event.data.get(ATTR_ENTITY_ID))
|
||||||
|
|
||||||
if info.filter(entity_id):
|
if info.filter(entity_id):
|
||||||
return True
|
return True
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Sequence
|
from typing import Iterable
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ def has_location(state: State) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def closest(latitude: float, longitude: float, states: Sequence[State]) -> State | None:
|
def closest(latitude: float, longitude: float, states: Iterable[State]) -> State | None:
|
||||||
"""Return closest state to point.
|
"""Return closest state to point.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -8,7 +8,7 @@ from functools import partial
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any, Callable, Dict, Sequence, Union, cast
|
from typing import Any, Callable, Dict, Sequence, TypedDict, Union, cast
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -56,7 +56,10 @@ from homeassistant.core import (
|
|||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers import condition, config_validation as cv, service, template
|
from homeassistant.helpers import condition, config_validation as cv, service, template
|
||||||
from homeassistant.helpers.condition import trace_condition_function
|
from homeassistant.helpers.condition import (
|
||||||
|
ConditionCheckerType,
|
||||||
|
trace_condition_function,
|
||||||
|
)
|
||||||
from homeassistant.helpers.dispatcher import (
|
from homeassistant.helpers.dispatcher import (
|
||||||
async_dispatcher_connect,
|
async_dispatcher_connect,
|
||||||
async_dispatcher_send,
|
async_dispatcher_send,
|
||||||
@ -492,7 +495,7 @@ class _ScriptRun:
|
|||||||
task.cancel()
|
task.cancel()
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
async def _async_run_long_action(self, long_task):
|
async def _async_run_long_action(self, long_task: asyncio.tasks.Task) -> None:
|
||||||
"""Run a long task while monitoring for stop request."""
|
"""Run a long task while monitoring for stop request."""
|
||||||
|
|
||||||
async def async_cancel_long_task() -> None:
|
async def async_cancel_long_task() -> None:
|
||||||
@ -741,7 +744,7 @@ class _ScriptRun:
|
|||||||
except exceptions.ConditionError as ex:
|
except exceptions.ConditionError as ex:
|
||||||
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
|
_LOGGER.warning("Error in 'choose' evaluation:\n%s", ex)
|
||||||
|
|
||||||
if choose_data["default"]:
|
if choose_data["default"] is not None:
|
||||||
trace_set_result(choice="default")
|
trace_set_result(choice="default")
|
||||||
with trace_path(["default"]):
|
with trace_path(["default"]):
|
||||||
await self._async_run_script(choose_data["default"])
|
await self._async_run_script(choose_data["default"])
|
||||||
@ -808,7 +811,7 @@ class _ScriptRun:
|
|||||||
self._hass, self._variables, render_as_defaults=False
|
self._hass, self._variables, render_as_defaults=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_run_script(self, script):
|
async def _async_run_script(self, script: Script) -> None:
|
||||||
"""Execute a script."""
|
"""Execute a script."""
|
||||||
await self._async_run_long_action(
|
await self._async_run_long_action(
|
||||||
self._hass.async_create_task(
|
self._hass.async_create_task(
|
||||||
@ -912,6 +915,11 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) ->
|
|||||||
found.add(item_id)
|
found.add(item_id)
|
||||||
|
|
||||||
|
|
||||||
|
class _ChooseData(TypedDict):
|
||||||
|
choices: list[tuple[list[ConditionCheckerType], Script]]
|
||||||
|
default: Script | None
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
"""Representation of a script."""
|
"""Representation of a script."""
|
||||||
|
|
||||||
@ -973,7 +981,7 @@ class Script:
|
|||||||
self._queue_lck = asyncio.Lock()
|
self._queue_lck = asyncio.Lock()
|
||||||
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
|
self._config_cache: dict[set[tuple], Callable[..., bool]] = {}
|
||||||
self._repeat_script: dict[int, Script] = {}
|
self._repeat_script: dict[int, Script] = {}
|
||||||
self._choose_data: dict[int, dict[str, Any]] = {}
|
self._choose_data: dict[int, _ChooseData] = {}
|
||||||
self._referenced_entities: set[str] | None = None
|
self._referenced_entities: set[str] | None = None
|
||||||
self._referenced_devices: set[str] | None = None
|
self._referenced_devices: set[str] | None = None
|
||||||
self._referenced_areas: set[str] | None = None
|
self._referenced_areas: set[str] | None = None
|
||||||
@ -1011,14 +1019,14 @@ class Script:
|
|||||||
for choose_data in self._choose_data.values():
|
for choose_data in self._choose_data.values():
|
||||||
for _, script in choose_data["choices"]:
|
for _, script in choose_data["choices"]:
|
||||||
script.update_logger(self._logger)
|
script.update_logger(self._logger)
|
||||||
if choose_data["default"]:
|
if choose_data["default"] is not None:
|
||||||
choose_data["default"].update_logger(self._logger)
|
choose_data["default"].update_logger(self._logger)
|
||||||
|
|
||||||
def _changed(self) -> None:
|
def _changed(self) -> None:
|
||||||
if self._change_listener_job:
|
if self._change_listener_job:
|
||||||
self._hass.async_run_hass_job(self._change_listener_job)
|
self._hass.async_run_hass_job(self._change_listener_job)
|
||||||
|
|
||||||
def _chain_change_listener(self, sub_script):
|
def _chain_change_listener(self, sub_script: Script) -> None:
|
||||||
if sub_script.is_running:
|
if sub_script.is_running:
|
||||||
self.last_action = sub_script.last_action
|
self.last_action = sub_script.last_action
|
||||||
self._changed()
|
self._changed()
|
||||||
@ -1203,7 +1211,9 @@ class Script:
|
|||||||
self._changed()
|
self._changed()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _async_stop(self, update_state, spare=None):
|
async def _async_stop(
|
||||||
|
self, update_state: bool, spare: _ScriptRun | None = None
|
||||||
|
) -> None:
|
||||||
aws = [
|
aws = [
|
||||||
asyncio.create_task(run.async_stop()) for run in self._runs if run != spare
|
asyncio.create_task(run.async_stop()) for run in self._runs if run != spare
|
||||||
]
|
]
|
||||||
@ -1230,7 +1240,7 @@ class Script:
|
|||||||
self._config_cache[config_cache_key] = cond
|
self._config_cache[config_cache_key] = cond
|
||||||
return cond
|
return cond
|
||||||
|
|
||||||
def _prep_repeat_script(self, step):
|
def _prep_repeat_script(self, step: int) -> Script:
|
||||||
action = self.sequence[step]
|
action = self.sequence[step]
|
||||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
|
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
|
||||||
sub_script = Script(
|
sub_script = Script(
|
||||||
@ -1247,14 +1257,14 @@ class Script:
|
|||||||
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
|
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
|
||||||
return sub_script
|
return sub_script
|
||||||
|
|
||||||
def _get_repeat_script(self, step):
|
def _get_repeat_script(self, step: int) -> Script:
|
||||||
sub_script = self._repeat_script.get(step)
|
sub_script = self._repeat_script.get(step)
|
||||||
if not sub_script:
|
if not sub_script:
|
||||||
sub_script = self._prep_repeat_script(step)
|
sub_script = self._prep_repeat_script(step)
|
||||||
self._repeat_script[step] = sub_script
|
self._repeat_script[step] = sub_script
|
||||||
return sub_script
|
return sub_script
|
||||||
|
|
||||||
async def _async_prep_choose_data(self, step):
|
async def _async_prep_choose_data(self, step: int) -> _ChooseData:
|
||||||
action = self.sequence[step]
|
action = self.sequence[step]
|
||||||
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
|
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
|
||||||
choices = []
|
choices = []
|
||||||
@ -1280,6 +1290,7 @@ class Script:
|
|||||||
)
|
)
|
||||||
choices.append((conditions, sub_script))
|
choices.append((conditions, sub_script))
|
||||||
|
|
||||||
|
default_script: Script | None
|
||||||
if CONF_DEFAULT in action:
|
if CONF_DEFAULT in action:
|
||||||
default_script = Script(
|
default_script = Script(
|
||||||
self._hass,
|
self._hass,
|
||||||
@ -1300,7 +1311,7 @@ class Script:
|
|||||||
|
|
||||||
return {"choices": choices, "default": default_script}
|
return {"choices": choices, "default": default_script}
|
||||||
|
|
||||||
async def _async_get_choose_data(self, step):
|
async def _async_get_choose_data(self, step: int) -> _ChooseData:
|
||||||
choose_data = self._choose_data.get(step)
|
choose_data = self._choose_data.get(step)
|
||||||
if not choose_data:
|
if not choose_data:
|
||||||
choose_data = await self._async_prep_choose_data(step)
|
choose_data = await self._async_prep_choose_data(step)
|
||||||
@ -1330,7 +1341,7 @@ def breakpoint_clear(hass, key, run_id, node):
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def breakpoint_clear_all(hass):
|
def breakpoint_clear_all(hass: HomeAssistant) -> None:
|
||||||
"""Clear all breakpoints."""
|
"""Clear all breakpoints."""
|
||||||
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
|
hass.data[DATA_SCRIPT_BREAKPOINTS] = {}
|
||||||
|
|
||||||
@ -1348,7 +1359,7 @@ def breakpoint_set(hass, key, run_id, node):
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def breakpoint_list(hass):
|
def breakpoint_list(hass: HomeAssistant) -> list[dict[str, Any]]:
|
||||||
"""List breakpoints."""
|
"""List breakpoints."""
|
||||||
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
breakpoints = hass.data[DATA_SCRIPT_BREAKPOINTS]
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import os
|
|||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
|
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
|
||||||
from homeassistant.core import CALLBACK_TYPE, CoreState, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.util import json as json_util
|
from homeassistant.util import json as json_util
|
||||||
@ -169,7 +169,7 @@ class Store:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_ensure_final_write_listener(self):
|
def _async_ensure_final_write_listener(self) -> None:
|
||||||
"""Ensure that we write if we quit before delay has passed."""
|
"""Ensure that we write if we quit before delay has passed."""
|
||||||
if self._unsub_final_write_listener is None:
|
if self._unsub_final_write_listener is None:
|
||||||
self._unsub_final_write_listener = self.hass.bus.async_listen_once(
|
self._unsub_final_write_listener = self.hass.bus.async_listen_once(
|
||||||
@ -177,14 +177,14 @@ class Store:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_cleanup_final_write_listener(self):
|
def _async_cleanup_final_write_listener(self) -> None:
|
||||||
"""Clean up a stop listener."""
|
"""Clean up a stop listener."""
|
||||||
if self._unsub_final_write_listener is not None:
|
if self._unsub_final_write_listener is not None:
|
||||||
self._unsub_final_write_listener()
|
self._unsub_final_write_listener()
|
||||||
self._unsub_final_write_listener = None
|
self._unsub_final_write_listener = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_cleanup_delay_listener(self):
|
def _async_cleanup_delay_listener(self) -> None:
|
||||||
"""Clean up a delay listener."""
|
"""Clean up a delay listener."""
|
||||||
if self._unsub_delay_listener is not None:
|
if self._unsub_delay_listener is not None:
|
||||||
self._unsub_delay_listener()
|
self._unsub_delay_listener()
|
||||||
@ -198,7 +198,7 @@ class Store:
|
|||||||
return
|
return
|
||||||
await self._async_handle_write_data()
|
await self._async_handle_write_data()
|
||||||
|
|
||||||
async def _async_callback_final_write(self, _event):
|
async def _async_callback_final_write(self, _event: Event) -> None:
|
||||||
"""Handle a write because Home Assistant is in final write state."""
|
"""Handle a write because Home Assistant is in final write state."""
|
||||||
self._unsub_final_write_listener = None
|
self._unsub_final_write_listener = None
|
||||||
await self._async_handle_write_data()
|
await self._async_handle_write_data()
|
||||||
@ -239,7 +239,7 @@ class Store:
|
|||||||
"""Migrate to the new version."""
|
"""Migrate to the new version."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_remove(self):
|
async def async_remove(self) -> None:
|
||||||
"""Remove all data."""
|
"""Remove all data."""
|
||||||
self._async_cleanup_delay_listener()
|
self._async_cleanup_delay_listener()
|
||||||
self._async_cleanup_final_write_listener()
|
self._async_cleanup_final_write_listener()
|
||||||
|
@ -16,7 +16,7 @@ from operator import attrgetter
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Generator, Iterable, cast
|
from typing import Any, Callable, Generator, Iterable, cast
|
||||||
from urllib.parse import urlencode as urllib_urlencode
|
from urllib.parse import urlencode as urllib_urlencode
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
@ -193,31 +193,32 @@ RESULT_WRAPPERS: dict[type, type] = {
|
|||||||
RESULT_WRAPPERS[tuple] = TupleWrapper
|
RESULT_WRAPPERS[tuple] = TupleWrapper
|
||||||
|
|
||||||
|
|
||||||
def _true(arg: Any) -> bool:
|
def _true(arg: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _false(arg: Any) -> bool:
|
def _false(arg: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RenderInfo:
|
class RenderInfo:
|
||||||
"""Holds information about a template render."""
|
"""Holds information about a template render."""
|
||||||
|
|
||||||
def __init__(self, template):
|
def __init__(self, template: Template) -> None:
|
||||||
"""Initialise."""
|
"""Initialise."""
|
||||||
self.template = template
|
self.template = template
|
||||||
# Will be set sensibly once frozen.
|
# Will be set sensibly once frozen.
|
||||||
self.filter_lifecycle = _true
|
self.filter_lifecycle: Callable[[str], bool] = _true
|
||||||
self.filter = _true
|
self.filter: Callable[[str], bool] = _true
|
||||||
self._result: str | None = None
|
self._result: str | None = None
|
||||||
self.is_static = False
|
self.is_static = False
|
||||||
self.exception: TemplateError | None = None
|
self.exception: TemplateError | None = None
|
||||||
self.all_states = False
|
self.all_states = False
|
||||||
self.all_states_lifecycle = False
|
self.all_states_lifecycle = False
|
||||||
self.domains = set()
|
# pylint: disable=unsubscriptable-object # for abc.Set, https://github.com/PyCQA/pylint/pull/4275
|
||||||
self.domains_lifecycle = set()
|
self.domains: collections.abc.Set[str] = set()
|
||||||
self.entities = set()
|
self.domains_lifecycle: collections.abc.Set[str] = set()
|
||||||
|
self.entities: collections.abc.Set[str] = set()
|
||||||
self.rate_limit: timedelta | None = None
|
self.rate_limit: timedelta | None = None
|
||||||
self.has_time = False
|
self.has_time = False
|
||||||
|
|
||||||
@ -491,7 +492,7 @@ class Template:
|
|||||||
"""Render the template and collect an entity filter."""
|
"""Render the template and collect an entity filter."""
|
||||||
assert self.hass and _RENDER_INFO not in self.hass.data
|
assert self.hass and _RENDER_INFO not in self.hass.data
|
||||||
|
|
||||||
render_info = RenderInfo(self) # type: ignore[no-untyped-call]
|
render_info = RenderInfo(self)
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if self.is_static:
|
if self.is_static:
|
||||||
@ -1039,13 +1040,13 @@ def is_state(hass: HomeAssistant, entity_id: str, state: State) -> bool:
|
|||||||
return state_obj is not None and state_obj.state == state
|
return state_obj is not None and state_obj.state == state
|
||||||
|
|
||||||
|
|
||||||
def is_state_attr(hass, entity_id, name, value):
|
def is_state_attr(hass: HomeAssistant, entity_id: str, name: str, value: Any) -> bool:
|
||||||
"""Test if a state's attribute is a specific value."""
|
"""Test if a state's attribute is a specific value."""
|
||||||
attr = state_attr(hass, entity_id, name)
|
attr = state_attr(hass, entity_id, name)
|
||||||
return attr is not None and attr == value
|
return attr is not None and attr == value
|
||||||
|
|
||||||
|
|
||||||
def state_attr(hass, entity_id, name):
|
def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
|
||||||
"""Get a specific attribute from a state."""
|
"""Get a specific attribute from a state."""
|
||||||
state_obj = _get_state(hass, entity_id)
|
state_obj = _get_state(hass, entity_id)
|
||||||
if state_obj is not None:
|
if state_obj is not None:
|
||||||
@ -1053,7 +1054,7 @@ def state_attr(hass, entity_id, name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def now(hass):
|
def now(hass: HomeAssistant) -> datetime:
|
||||||
"""Record fetching now."""
|
"""Record fetching now."""
|
||||||
render_info = hass.data.get(_RENDER_INFO)
|
render_info = hass.data.get(_RENDER_INFO)
|
||||||
if render_info is not None:
|
if render_info is not None:
|
||||||
@ -1062,7 +1063,7 @@ def now(hass):
|
|||||||
return dt_util.now()
|
return dt_util.now()
|
||||||
|
|
||||||
|
|
||||||
def utcnow(hass):
|
def utcnow(hass: HomeAssistant) -> datetime:
|
||||||
"""Record fetching utcnow."""
|
"""Record fetching utcnow."""
|
||||||
render_info = hass.data.get(_RENDER_INFO)
|
render_info = hass.data.get(_RENDER_INFO)
|
||||||
if render_info is not None:
|
if render_info is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user