From 235b98da8ab7c379e19218fce2aea12edcaf9ec4 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 24 Jul 2023 01:32:29 +0200 Subject: [PATCH] Use EventType for remaining event helper methods (#97121) --- homeassistant/helpers/event.py | 74 +++++++++++++++++----------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 830b6100111..12cf58eaa2b 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -10,12 +10,11 @@ import functools as ft import logging from random import randint import time -from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar, cast +from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar import attr from homeassistant.const import ( - ATTR_ENTITY_ID, EVENT_CORE_CONFIG_UPDATE, EVENT_STATE_CHANGED, MATCH_ALL, @@ -24,7 +23,6 @@ from homeassistant.const import ( ) from homeassistant.core import ( CALLBACK_TYPE, - Event, HassJob, HomeAssistant, State, @@ -331,13 +329,13 @@ def _remove_empty_listener() -> None: """Remove a listener that does nothing.""" -@callback +@callback # type: ignore[arg-type] # mypy bug? def _remove_listener( hass: HomeAssistant, listeners_key: str, keys: Iterable[str], - job: HassJob[[Event], Any], - callbacks: dict[str, list[HassJob[[Event], Any]]], + job: HassJob[[EventType[_TypedDictT]], Any], + callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], ) -> None: """Remove listener.""" for key in keys: @@ -451,7 +449,7 @@ def _async_entity_registry_updated_filter( def async_track_entity_registry_updated_event( hass: HomeAssistant, entity_ids: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventEntityRegistryUpdatedData]], Any], ) -> CALLBACK_TYPE: """Track specific entity registry updated events indexed by entity_id. @@ -509,7 +507,7 @@ def _async_dispatch_device_id_event( def async_track_device_registry_updated_event( hass: HomeAssistant, device_ids: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventDeviceRegistryUpdatedData]], Any], ) -> CALLBACK_TYPE: """Track specific device registry updated events indexed by device_id. @@ -561,7 +559,7 @@ def _async_domain_added_filter( def async_track_state_added_domain( hass: HomeAssistant, domains: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """Track state change events when an entity is added to domains.""" if not (domains := _async_string_to_lower_list(domains)): @@ -573,7 +571,7 @@ def async_track_state_added_domain( def _async_track_state_added_domain( hass: HomeAssistant, domains: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """Track state change events when an entity is added to domains.""" return _async_track_event( @@ -605,7 +603,7 @@ def _async_domain_removed_filter( def async_track_state_removed_domain( hass: HomeAssistant, domains: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """Track state change events when an entity is removed from domains.""" return _async_track_event( @@ -635,7 +633,7 @@ class _TrackStateChangeFiltered: self, hass: HomeAssistant, track_states: TrackStates, - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> None: """Handle removal / refresh of tracker init.""" self.hass = hass @@ -739,7 +737,7 @@ class _TrackStateChangeFiltered: ) @callback - def _state_added(self, event: Event) -> None: + def _state_added(self, event: EventType[EventStateChangedData]) -> None: self._cancel_listener(_ENTITIES_LISTENER) self._setup_entities_listener( self._last_track_states.domains, self._last_track_states.entities @@ -758,7 +756,7 @@ class _TrackStateChangeFiltered: @callback def _setup_all_listener(self) -> None: self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen( - EVENT_STATE_CHANGED, self._action + EVENT_STATE_CHANGED, self._action # type: ignore[arg-type] ) @@ -767,7 +765,7 @@ class _TrackStateChangeFiltered: def async_track_state_change_filtered( hass: HomeAssistant, track_states: TrackStates, - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> _TrackStateChangeFiltered: """Track state changes with a TrackStates filter that can be updated. @@ -841,7 +839,8 @@ def async_track_template( @callback def _template_changed_listener( - event: Event | None, updates: list[TrackTemplateResult] + event: EventType[EventStateChangedData] | None, + updates: list[TrackTemplateResult], ) -> None: """Check if condition is correct and run action.""" track_result = updates.pop() @@ -867,9 +866,9 @@ def async_track_template( hass.async_run_hass_job( job, - event and event.data.get("entity_id"), - event and event.data.get("old_state"), - event and event.data.get("new_state"), + event and event.data["entity_id"], + event and event.data["old_state"], + event and event.data["new_state"], ) info = async_track_template_result( @@ -889,7 +888,9 @@ class TrackTemplateResultInfo: self, hass: HomeAssistant, track_templates: Sequence[TrackTemplate], - action: Callable[[Event | None, list[TrackTemplateResult]], None], + action: Callable[ + [EventType[EventStateChangedData] | None, list[TrackTemplateResult]], None + ], has_super_template: bool = False, ) -> None: """Handle removal / refresh of tracker init.""" @@ -1026,7 +1027,7 @@ class TrackTemplateResultInfo: self, track_template_: TrackTemplate, now: datetime, - event: Event | None, + event: EventType[EventStateChangedData] | None, ) -> bool | TrackTemplateResult: """Re-render the template if conditions match. @@ -1097,7 +1098,7 @@ class TrackTemplateResultInfo: @callback def _refresh( self, - event: Event | None, + event: EventType[EventStateChangedData] | None, track_templates: Iterable[TrackTemplate] | None = None, replayed: bool | None = False, ) -> None: @@ -1205,7 +1206,7 @@ class TrackTemplateResultInfo: TrackTemplateResultListener = Callable[ [ - Event | None, + EventType[EventStateChangedData] | None, list[TrackTemplateResult], ], None, @@ -1315,11 +1316,11 @@ def async_track_same_state( hass.async_run_hass_job(job) @callback - def state_for_cancel_listener(event: Event) -> None: + def state_for_cancel_listener(event: EventType[EventStateChangedData]) -> None: """Fire on changes and cancel for listener if changed.""" - entity: str = event.data["entity_id"] - from_state: State | None = event.data.get("old_state") - to_state: State | None = event.data.get("new_state") + entity = event.data["entity_id"] + from_state = event.data["old_state"] + to_state = event.data["new_state"] if not async_check_same_func(entity, from_state, to_state): clear_listener() @@ -1330,7 +1331,7 @@ def async_track_same_state( if entity_ids == MATCH_ALL: async_remove_state_for_cancel = hass.bus.async_listen( - EVENT_STATE_CHANGED, state_for_cancel_listener + EVENT_STATE_CHANGED, state_for_cancel_listener # type: ignore[arg-type] ) else: async_remove_state_for_cancel = async_track_state_change_event( @@ -1761,17 +1762,16 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt @callback -def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool: +def _event_triggers_rerender( + event: EventType[EventStateChangedData], info: RenderInfo +) -> bool: """Determine if a template should be re-rendered from an event.""" - entity_id = cast(str, event.data.get(ATTR_ENTITY_ID)) + entity_id = event.data["entity_id"] if info.filter(entity_id): return True - if ( - event.data.get("new_state") is not None - and event.data.get("old_state") is not None - ): + if event.data["new_state"] is not None and event.data["old_state"] is not None: return False return bool(info.filter_lifecycle(entity_id)) @@ -1779,12 +1779,14 @@ def _event_triggers_rerender(event: Event, info: RenderInfo) -> bool: @callback def _rate_limit_for_event( - event: Event, info: RenderInfo, track_template_: TrackTemplate + event: EventType[EventStateChangedData], + info: RenderInfo, + track_template_: TrackTemplate, ) -> timedelta | None: """Determine the rate limit for an event.""" # Specifically referenced entities are excluded # from the rate limit - if event.data.get(ATTR_ENTITY_ID) in info.entities: + if event.data["entity_id"] in info.entities: return None if track_template_.rate_limit is not None: