Use EventType for remaining event helper methods (#97121)

This commit is contained in:
Marc Mueller 2023-07-24 01:32:29 +02:00 committed by GitHub
parent f8c3aa7bec
commit 235b98da8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,12 +10,11 @@ import functools as ft
import logging
from random import randint
import time
from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar, cast
from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar
import attr
from homeassistant.const import (
ATTR_ENTITY_ID,
EVENT_CORE_CONFIG_UPDATE,
EVENT_STATE_CHANGED,
MATCH_ALL,
@ -24,7 +23,6 @@ from homeassistant.const import (
)
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HomeAssistant,
State,
@ -331,13 +329,13 @@ def _remove_empty_listener() -> None:
"""Remove a listener that does nothing."""
@callback
@callback # type: ignore[arg-type] # mypy bug?
def _remove_listener(
hass: HomeAssistant,
listeners_key: str,
keys: Iterable[str],
job: HassJob[[Event], Any],
callbacks: dict[str, list[HassJob[[Event], Any]]],
job: HassJob[[EventType[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]],
) -> None:
"""Remove listener."""
for key in keys:
@ -451,7 +449,7 @@ def _async_entity_registry_updated_filter(
def async_track_entity_registry_updated_event(
hass: HomeAssistant,
entity_ids: str | Iterable[str],
action: Callable[[Event], Any],
action: Callable[[EventType[EventEntityRegistryUpdatedData]], Any],
) -> CALLBACK_TYPE:
"""Track specific entity registry updated events indexed by entity_id.
@ -509,7 +507,7 @@ def _async_dispatch_device_id_event(
def async_track_device_registry_updated_event(
hass: HomeAssistant,
device_ids: str | Iterable[str],
action: Callable[[Event], Any],
action: Callable[[EventType[EventDeviceRegistryUpdatedData]], Any],
) -> CALLBACK_TYPE:
"""Track specific device registry updated events indexed by device_id.
@ -561,7 +559,7 @@ def _async_domain_added_filter(
def async_track_state_added_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[Event], Any],
action: Callable[[EventType[EventStateChangedData]], Any],
) -> CALLBACK_TYPE:
"""Track state change events when an entity is added to domains."""
if not (domains := _async_string_to_lower_list(domains)):
@ -573,7 +571,7 @@ def async_track_state_added_domain(
def _async_track_state_added_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[Event], Any],
action: Callable[[EventType[EventStateChangedData]], Any],
) -> CALLBACK_TYPE:
"""Track state change events when an entity is added to domains."""
return _async_track_event(
@ -605,7 +603,7 @@ def _async_domain_removed_filter(
def async_track_state_removed_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[Event], Any],
action: Callable[[EventType[EventStateChangedData]], Any],
) -> CALLBACK_TYPE:
"""Track state change events when an entity is removed from domains."""
return _async_track_event(
@ -635,7 +633,7 @@ class _TrackStateChangeFiltered:
self,
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[Event], Any],
action: Callable[[EventType[EventStateChangedData]], Any],
) -> None:
"""Handle removal / refresh of tracker init."""
self.hass = hass
@ -739,7 +737,7 @@ class _TrackStateChangeFiltered:
)
@callback
def _state_added(self, event: Event) -> None:
def _state_added(self, event: EventType[EventStateChangedData]) -> None:
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_entities_listener(
self._last_track_states.domains, self._last_track_states.entities
@ -758,7 +756,7 @@ class _TrackStateChangeFiltered:
@callback
def _setup_all_listener(self) -> None:
self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen(
EVENT_STATE_CHANGED, self._action
EVENT_STATE_CHANGED, self._action # type: ignore[arg-type]
)
@ -767,7 +765,7 @@ class _TrackStateChangeFiltered:
def async_track_state_change_filtered(
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[Event], Any],
action: Callable[[EventType[EventStateChangedData]], Any],
) -> _TrackStateChangeFiltered:
"""Track state changes with a TrackStates filter that can be updated.
@ -841,7 +839,8 @@ def async_track_template(
@callback
def _template_changed_listener(
event: Event | None, updates: list[TrackTemplateResult]
event: EventType[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
"""Check if condition is correct and run action."""
track_result = updates.pop()
@ -867,9 +866,9 @@ def async_track_template(
hass.async_run_hass_job(
job,
event and event.data.get("entity_id"),
event and event.data.get("old_state"),
event and event.data.get("new_state"),
event and event.data["entity_id"],
event and event.data["old_state"],
event and event.data["new_state"],
)
info = async_track_template_result(
@ -889,7 +888,9 @@ class TrackTemplateResultInfo:
self,
hass: HomeAssistant,
track_templates: Sequence[TrackTemplate],
action: Callable[[Event | None, list[TrackTemplateResult]], None],
action: Callable[
[EventType[EventStateChangedData] | None, list[TrackTemplateResult]], None
],
has_super_template: bool = False,
) -> None:
"""Handle removal / refresh of tracker init."""
@ -1026,7 +1027,7 @@ class TrackTemplateResultInfo:
self,
track_template_: TrackTemplate,
now: datetime,
event: Event | None,
event: EventType[EventStateChangedData] | None,
) -> bool | TrackTemplateResult:
"""Re-render the template if conditions match.
@ -1097,7 +1098,7 @@ class TrackTemplateResultInfo:
@callback
def _refresh(
self,
event: Event | None,
event: EventType[EventStateChangedData] | None,
track_templates: Iterable[TrackTemplate] | None = None,
replayed: bool | None = False,
) -> None:
@ -1205,7 +1206,7 @@ class TrackTemplateResultInfo:
TrackTemplateResultListener = Callable[
[
Event | None,
EventType[EventStateChangedData] | None,
list[TrackTemplateResult],
],
None,
@ -1315,11 +1316,11 @@ def async_track_same_state(
hass.async_run_hass_job(job)
@callback
def state_for_cancel_listener(event: Event) -> None:
def state_for_cancel_listener(event: EventType[EventStateChangedData]) -> None:
"""Fire on changes and cancel for listener if changed."""
entity: str = event.data["entity_id"]
from_state: State | None = event.data.get("old_state")
to_state: State | None = event.data.get("new_state")
entity = event.data["entity_id"]
from_state = event.data["old_state"]
to_state = event.data["new_state"]
if not async_check_same_func(entity, from_state, to_state):
clear_listener()
@ -1330,7 +1331,7 @@ def async_track_same_state(
if entity_ids == MATCH_ALL:
async_remove_state_for_cancel = hass.bus.async_listen(
EVENT_STATE_CHANGED, state_for_cancel_listener
EVENT_STATE_CHANGED, state_for_cancel_listener # type: ignore[arg-type]
)
else:
async_remove_state_for_cancel = async_track_state_change_event(
@ -1761,17 +1762,16 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt
@callback
def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
def _event_triggers_rerender(
event: EventType[EventStateChangedData], info: RenderInfo
) -> bool:
"""Determine if a template should be re-rendered from an event."""
entity_id = cast(str, event.data.get(ATTR_ENTITY_ID))
entity_id = event.data["entity_id"]
if info.filter(entity_id):
return True
if (
event.data.get("new_state") is not None
and event.data.get("old_state") is not None
):
if event.data["new_state"] is not None and event.data["old_state"] is not None:
return False
return bool(info.filter_lifecycle(entity_id))
@ -1779,12 +1779,14 @@ def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool:
@callback
def _rate_limit_for_event(
event: Event, info: RenderInfo, track_template_: TrackTemplate
event: EventType[EventStateChangedData],
info: RenderInfo,
track_template_: TrackTemplate,
) -> timedelta | None:
"""Determine the rate limit for an event."""
# Specifically referenced entities are excluded
# from the rate limit
if event.data.get(ATTR_ENTITY_ID) in info.entities:
if event.data["entity_id"] in info.entities:
return None
if track_template_.rate_limit is not None: