Replace EventType with Event [helpers] (#112743)

This commit is contained in:
Marc Mueller
2024-03-08 19:41:50 +01:00
committed by GitHub
parent 3db28d46b2
commit b026b5d589
3 changed files with 114 additions and 121 deletions

View File

@@ -32,6 +32,7 @@ from homeassistant.const import (
)
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HassJobType,
HomeAssistant,
@@ -55,7 +56,7 @@ from .entity_registry import (
from .ratelimit import KeyedRateLimit
from .sun import get_astral_event_next
from .template import RenderInfo, Template, result_as_boolean
from .typing import EventType, TemplateVarsType
from .typing import TemplateVarsType
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
@@ -99,16 +100,16 @@ class _KeyedEventTracker(Generic[_TypedDictT]):
dispatcher_callable: Callable[
[
HomeAssistant,
dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]],
EventType[_TypedDictT],
dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
Event[_TypedDictT],
],
None,
]
filter_callable: Callable[
[
HomeAssistant,
dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]],
EventType[_TypedDictT],
dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
Event[_TypedDictT],
],
bool,
]
@@ -236,7 +237,7 @@ def async_track_state_change(
job = HassJob(action, f"track state change {entity_ids} {from_state} {to_state}")
@callback
def state_change_filter(event: EventType[EventStateChangedData]) -> bool:
def state_change_filter(event: Event[EventStateChangedData]) -> bool:
"""Handle specific state changes."""
if from_state is not None:
old_state_str: str | None = None
@@ -257,7 +258,7 @@ def async_track_state_change(
return True
@callback
def state_change_dispatcher(event: EventType[EventStateChangedData]) -> None:
def state_change_dispatcher(event: Event[EventStateChangedData]) -> None:
"""Handle specific state changes."""
hass.async_run_hass_job(
job,
@@ -267,7 +268,7 @@ def async_track_state_change(
)
@callback
def state_change_listener(event: EventType[EventStateChangedData]) -> None:
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle specific state changes."""
if not state_change_filter(event):
return
@@ -299,7 +300,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
def async_track_state_change_event(
hass: HomeAssistant,
entity_ids: str | Iterable[str],
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track specific state change events indexed by entity_id.
@@ -321,8 +322,8 @@ def async_track_state_change_event(
@callback
def _async_dispatch_entity_id_event(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
callbacks: dict[str, list[HassJob[[Event[EventStateChangedData]], Any]]],
event: Event[EventStateChangedData],
) -> None:
"""Dispatch to listeners."""
if not (callbacks_list := callbacks.get(event.data["entity_id"])):
@@ -341,8 +342,8 @@ def _async_dispatch_entity_id_event(
@callback
def _async_state_change_filter(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
callbacks: dict[str, list[HassJob[[Event[EventStateChangedData]], Any]]],
event: Event[EventStateChangedData],
) -> bool:
"""Filter state changes by entity_id."""
return event.data["entity_id"] in callbacks
@@ -362,7 +363,7 @@ _KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker(
def _async_track_state_change_event(
hass: HomeAssistant,
entity_ids: str | Iterable[str],
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None,
) -> CALLBACK_TYPE:
"""async_track_state_change_event without lowercasing."""
@@ -381,8 +382,8 @@ def _remove_listener(
hass: HomeAssistant,
listeners_key: str,
keys: Iterable[str],
job: HassJob[[EventType[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]],
job: HassJob[[Event[_TypedDictT]], Any],
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
) -> None:
"""Remove listener."""
for key in keys:
@@ -401,7 +402,7 @@ def _async_track_event(
tracker: _KeyedEventTracker[_TypedDictT],
hass: HomeAssistant,
keys: str | Iterable[str],
action: Callable[[EventType[_TypedDictT]], None],
action: Callable[[Event[_TypedDictT]], None],
job_type: HassJobType | None,
) -> CALLBACK_TYPE:
"""Track an event by a specific key.
@@ -421,7 +422,7 @@ def _async_track_event(
hass_data = hass.data
callbacks_key = tracker.callbacks_key
callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]] | None
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]] | None
if not (callbacks := hass_data.get(callbacks_key)):
callbacks = hass_data[callbacks_key] = {}
@@ -449,10 +450,8 @@ def _async_track_event(
@callback
def _async_dispatch_old_entity_id_or_entity_id_event(
hass: HomeAssistant,
callbacks: dict[
str, list[HassJob[[EventType[EventEntityRegistryUpdatedData]], Any]]
],
event: EventType[EventEntityRegistryUpdatedData],
callbacks: dict[str, list[HassJob[[Event[EventEntityRegistryUpdatedData]], Any]]],
event: Event[EventEntityRegistryUpdatedData],
) -> None:
"""Dispatch to listeners."""
if not (
@@ -475,10 +474,8 @@ def _async_dispatch_old_entity_id_or_entity_id_event(
@callback
def _async_entity_registry_updated_filter(
hass: HomeAssistant,
callbacks: dict[
str, list[HassJob[[EventType[EventEntityRegistryUpdatedData]], Any]]
],
event: EventType[EventEntityRegistryUpdatedData],
callbacks: dict[str, list[HassJob[[Event[EventEntityRegistryUpdatedData]], Any]]],
event: Event[EventEntityRegistryUpdatedData],
) -> bool:
"""Filter entity registry updates by entity_id."""
return event.data.get("old_entity_id", event.data["entity_id"]) in callbacks
@@ -499,7 +496,7 @@ _KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker(
def async_track_entity_registry_updated_event(
hass: HomeAssistant,
entity_ids: str | Iterable[str],
action: Callable[[EventType[EventEntityRegistryUpdatedData]], Any],
action: Callable[[Event[EventEntityRegistryUpdatedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track specific entity registry updated events indexed by entity_id.
@@ -516,10 +513,8 @@ def async_track_entity_registry_updated_event(
@callback
def _async_device_registry_updated_filter(
hass: HomeAssistant,
callbacks: dict[
str, list[HassJob[[EventType[EventDeviceRegistryUpdatedData]], Any]]
],
event: EventType[EventDeviceRegistryUpdatedData],
callbacks: dict[str, list[HassJob[[Event[EventDeviceRegistryUpdatedData]], Any]]],
event: Event[EventDeviceRegistryUpdatedData],
) -> bool:
"""Filter device registry updates by device_id."""
return event.data["device_id"] in callbacks
@@ -528,10 +523,8 @@ def _async_device_registry_updated_filter(
@callback
def _async_dispatch_device_id_event(
hass: HomeAssistant,
callbacks: dict[
str, list[HassJob[[EventType[EventDeviceRegistryUpdatedData]], Any]]
],
event: EventType[EventDeviceRegistryUpdatedData],
callbacks: dict[str, list[HassJob[[Event[EventDeviceRegistryUpdatedData]], Any]]],
event: Event[EventDeviceRegistryUpdatedData],
) -> None:
"""Dispatch to listeners."""
if not (callbacks_list := callbacks.get(event.data["device_id"])):
@@ -561,7 +554,7 @@ _KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker(
def async_track_device_registry_updated_event(
hass: HomeAssistant,
device_ids: str | Iterable[str],
action: Callable[[EventType[EventDeviceRegistryUpdatedData]], Any],
action: Callable[[Event[EventDeviceRegistryUpdatedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track specific device registry updated events indexed by device_id.
@@ -576,8 +569,8 @@ def async_track_device_registry_updated_event(
@callback
def _async_dispatch_domain_event(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
callbacks: dict[str, list[HassJob[[Event[EventStateChangedData]], Any]]],
event: Event[EventStateChangedData],
) -> None:
"""Dispatch domain event listeners."""
domain = split_entity_id(event.data["entity_id"])[0]
@@ -593,8 +586,8 @@ def _async_dispatch_domain_event(
@callback
def _async_domain_added_filter(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
callbacks: dict[str, list[HassJob[[Event[EventStateChangedData]], Any]]],
event: Event[EventStateChangedData],
) -> bool:
"""Filter state changes by entity_id."""
return event.data["old_state"] is None and (
@@ -607,7 +600,7 @@ def _async_domain_added_filter(
def async_track_state_added_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state change events when an entity is added to domains."""
@@ -630,7 +623,7 @@ _KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker(
def _async_track_state_added_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None,
) -> CALLBACK_TYPE:
"""Track state change events when an entity is added to domains."""
@@ -642,8 +635,8 @@ def _async_track_state_added_domain(
@callback
def _async_domain_removed_filter(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
callbacks: dict[str, list[HassJob[[Event[EventStateChangedData]], Any]]],
event: Event[EventStateChangedData],
) -> bool:
"""Filter state changes by entity_id."""
return event.data["new_state"] is None and (
@@ -666,7 +659,7 @@ _KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker(
def async_track_state_removed_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state change events when an entity is removed from domains."""
@@ -690,7 +683,7 @@ class _TrackStateChangeFiltered:
self,
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
) -> None:
"""Handle removal / refresh of tracker init."""
self.hass = hass
@@ -794,7 +787,7 @@ class _TrackStateChangeFiltered:
)
@callback
def _state_added(self, event: EventType[EventStateChangedData]) -> None:
def _state_added(self, event: Event[EventStateChangedData]) -> None:
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_entities_listener(
self._last_track_states.domains, self._last_track_states.entities
@@ -823,7 +816,7 @@ class _TrackStateChangeFiltered:
def async_track_state_change_filtered(
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[EventType[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData]], Any],
) -> _TrackStateChangeFiltered:
"""Track state changes with a TrackStates filter that can be updated.
@@ -897,7 +890,7 @@ def async_track_template(
@callback
def _template_changed_listener(
event: EventType[EventStateChangedData] | None,
event: Event[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
"""Check if condition is correct and run action."""
@@ -1088,7 +1081,7 @@ class TrackTemplateResultInfo:
self,
track_template_: TrackTemplate,
now: datetime,
event: EventType[EventStateChangedData] | None,
event: Event[EventStateChangedData] | None,
) -> bool | TrackTemplateResult:
"""Re-render the template if conditions match.
@@ -1177,7 +1170,7 @@ class TrackTemplateResultInfo:
@callback
def _refresh(
self,
event: EventType[EventStateChangedData] | None,
event: Event[EventStateChangedData] | None,
track_templates: Iterable[TrackTemplate] | None = None,
replayed: bool | None = False,
) -> None:
@@ -1273,7 +1266,7 @@ class TrackTemplateResultInfo:
TrackTemplateResultListener = Callable[
[
EventType[EventStateChangedData] | None,
Event[EventStateChangedData] | None,
list[TrackTemplateResult],
],
Coroutine[Any, Any, None] | None,
@@ -1381,7 +1374,7 @@ def async_track_same_state(
hass.async_run_hass_job(job)
@callback
def state_for_cancel_listener(event: EventType[EventStateChangedData]) -> None:
def state_for_cancel_listener(event: Event[EventStateChangedData]) -> None:
"""Fire on changes and cancel for listener if changed."""
entity = event.data["entity_id"]
from_state = event.data["old_state"]
@@ -1919,7 +1912,7 @@ def _render_infos_to_track_states(render_infos: Iterable[RenderInfo]) -> TrackSt
@callback
def _event_triggers_rerender(
event: EventType[EventStateChangedData], info: RenderInfo
event: Event[EventStateChangedData], info: RenderInfo
) -> bool:
"""Determine if a template should be re-rendered from an event."""
entity_id = event.data["entity_id"]
@@ -1935,7 +1928,7 @@ def _event_triggers_rerender(
@callback
def _rate_limit_for_event(
event: EventType[EventStateChangedData],
event: Event[EventStateChangedData],
info: RenderInfo,
track_template_: TrackTemplate,
) -> timedelta | None: