From b5a2df1951af23887d450208636c613e9a193f6f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 23 Feb 2024 08:55:02 -1000 Subject: [PATCH] Refactor keyed event trackers to reduce future refactoring risk (#111150) * Refactor keyed event trackers to avoid refactoring risk Follow to https://github.com/home-assistant/core/pull/110978#discussion_r1496771106 I had to do some type ignores because of the EventType vs Event which is hopefully not going to be needed after the next mypy * delete constants only used one in other const * no field * fixes * less refactoring later * less refactoring later * keep const --- homeassistant/helpers/event.py | 172 +++++++++++++++++++-------------- 1 file changed, 101 insertions(+), 71 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 35493dbcacb..0dc3115466a 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -10,7 +10,15 @@ import functools as ft import logging from random import randint import time -from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Generic, + ParamSpec, + TypedDict, + TypeVar, +) import attr @@ -80,6 +88,32 @@ _TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any]) _P = ParamSpec("_P") +@dataclass(slots=True, frozen=True) +class _KeyedEventTracker(Generic[_TypedDictT]): + """Class to track events by key.""" + + listeners_key: str + callbacks_key: str + event_type: str + dispatcher_callable: Callable[ + [ + HomeAssistant, + dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], + EventType[_TypedDictT], + ], + None, + ] + filter_callable: Callable[ + [ + HomeAssistant, + dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], + EventType[_TypedDictT], + ], + bool, + ] + run_immediately: bool + + @dataclass(slots=True) class TrackStates: """Class for keeping track of states being tracked. @@ -312,6 +346,16 @@ def _async_state_change_filter( return event.data["entity_id"] in callbacks +_KEYED_TRACK_STATE_CHANGE = _KeyedEventTracker( + listeners_key=TRACK_STATE_CHANGE_LISTENER, + callbacks_key=TRACK_STATE_CHANGE_CALLBACKS, + event_type=EVENT_STATE_CHANGED, + dispatcher_callable=_async_dispatch_entity_id_event, + filter_callable=_async_state_change_filter, + run_immediately=False, +) + + @bind_hass def _async_track_state_change_event( hass: HomeAssistant, @@ -319,17 +363,7 @@ def _async_track_state_change_event( action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """async_track_state_change_event without lowercasing.""" - return _async_track_event( - hass, - entity_ids, - TRACK_STATE_CHANGE_CALLBACKS, - TRACK_STATE_CHANGE_LISTENER, - EVENT_STATE_CHANGED, - _async_dispatch_entity_id_event, - _async_state_change_filter, - action, - False, - ) + return _async_track_event(_KEYED_TRACK_STATE_CHANGE, hass, entity_ids, action) @callback @@ -356,30 +390,13 @@ def _remove_listener( del hass.data[listeners_key] +# tracker, not hass is intentionally the first argument here since its +# constant and may be used in a partial in the future def _async_track_event( + tracker: _KeyedEventTracker[_TypedDictT], hass: HomeAssistant, keys: str | Iterable[str], - callbacks_key: str, - listeners_key: str, - event_type: str, - dispatcher_callable: Callable[ - [ - HomeAssistant, - dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], - EventType[_TypedDictT], - ], - None, - ], - filter_callable: Callable[ - [ - HomeAssistant, - dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], - EventType[_TypedDictT], - ], - bool, - ], action: Callable[[EventType[_TypedDictT]], None], - run_immediately: bool, ) -> CALLBACK_TYPE: """Track an event by a specific key. @@ -396,20 +413,23 @@ def _async_track_event( keys = [keys] hass_data = hass.data + callbacks_key = tracker.callbacks_key callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]] | None if not (callbacks := hass_data.get(callbacks_key)): callbacks = hass_data[callbacks_key] = {} + listeners_key = tracker.listeners_key + if listeners_key not in hass_data: hass_data[listeners_key] = hass.bus.async_listen( - event_type, - ft.partial(dispatcher_callable, hass, callbacks), - event_filter=ft.partial(filter_callable, hass, callbacks), - run_immediately=run_immediately, + tracker.event_type, + ft.partial(tracker.dispatcher_callable, hass, callbacks), + event_filter=ft.partial(tracker.filter_callable, hass, callbacks), + run_immediately=tracker.run_immediately, ) - job = HassJob(action, f"track {event_type} event {keys}") + job = HassJob(action, f"track {tracker.event_type} event {keys}") for key in keys: if callback_list := callbacks.get(key): @@ -458,6 +478,16 @@ def _async_entity_registry_updated_filter( return event.data.get("old_entity_id", event.data["entity_id"]) in callbacks +_KEYED_TRACK_ENTITY_REGISTRY_UPDATED = _KeyedEventTracker( + listeners_key=TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, + callbacks_key=TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, + event_type=EVENT_ENTITY_REGISTRY_UPDATED, + dispatcher_callable=_async_dispatch_old_entity_id_or_entity_id_event, + filter_callable=_async_entity_registry_updated_filter, + run_immediately=True, +) + + @bind_hass @callback def async_track_entity_registry_updated_event( @@ -472,15 +502,10 @@ def async_track_entity_registry_updated_event( Similar to async_track_state_change_event. """ return _async_track_event( + _KEYED_TRACK_ENTITY_REGISTRY_UPDATED, hass, entity_ids, - TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, - TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, - EVENT_ENTITY_REGISTRY_UPDATED, - _async_dispatch_old_entity_id_or_entity_id_event, - _async_entity_registry_updated_filter, action, - True, ) @@ -518,6 +543,16 @@ def _async_dispatch_device_id_event( ) +_KEYED_TRACK_DEVICE_REGISTRY_UPDATED = _KeyedEventTracker( + listeners_key=TRACK_DEVICE_REGISTRY_UPDATED_LISTENER, + callbacks_key=TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS, + event_type=EVENT_DEVICE_REGISTRY_UPDATED, + dispatcher_callable=_async_dispatch_device_id_event, + filter_callable=_async_device_registry_updated_filter, + run_immediately=True, +) + + @callback def async_track_device_registry_updated_event( hass: HomeAssistant, @@ -529,15 +564,10 @@ def async_track_device_registry_updated_event( Similar to async_track_entity_registry_updated_event. """ return _async_track_event( + _KEYED_TRACK_DEVICE_REGISTRY_UPDATED, hass, device_ids, - TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS, - TRACK_DEVICE_REGISTRY_UPDATED_LISTENER, - EVENT_DEVICE_REGISTRY_UPDATED, - _async_dispatch_device_id_event, - _async_device_registry_updated_filter, action, - True, ) @@ -583,6 +613,16 @@ def async_track_state_added_domain( return _async_track_state_added_domain(hass, domains, action) +_KEYED_TRACK_STATE_ADDED_DOMAIN = _KeyedEventTracker( + listeners_key=TRACK_STATE_ADDED_DOMAIN_LISTENER, + callbacks_key=TRACK_STATE_ADDED_DOMAIN_CALLBACKS, + event_type=EVENT_STATE_CHANGED, + dispatcher_callable=_async_dispatch_domain_event, + filter_callable=_async_domain_added_filter, + run_immediately=False, +) + + @bind_hass def _async_track_state_added_domain( hass: HomeAssistant, @@ -590,17 +630,7 @@ def _async_track_state_added_domain( action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """Track state change events when an entity is added to domains.""" - return _async_track_event( - hass, - domains, - TRACK_STATE_ADDED_DOMAIN_CALLBACKS, - TRACK_STATE_ADDED_DOMAIN_LISTENER, - EVENT_STATE_CHANGED, - _async_dispatch_domain_event, - _async_domain_added_filter, - action, - False, - ) + return _async_track_event(_KEYED_TRACK_STATE_ADDED_DOMAIN, hass, domains, action) @callback @@ -616,6 +646,16 @@ def _async_domain_removed_filter( ) +_KEYED_TRACK_STATE_REMOVED_DOMAIN = _KeyedEventTracker( + listeners_key=TRACK_STATE_REMOVED_DOMAIN_LISTENER, + callbacks_key=TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, + event_type=EVENT_STATE_CHANGED, + dispatcher_callable=_async_dispatch_domain_event, + filter_callable=_async_domain_removed_filter, + run_immediately=False, +) + + @bind_hass def async_track_state_removed_domain( hass: HomeAssistant, @@ -623,17 +663,7 @@ def async_track_state_removed_domain( action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """Track state change events when an entity is removed from domains.""" - return _async_track_event( - hass, - domains, - TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, - TRACK_STATE_REMOVED_DOMAIN_LISTENER, - EVENT_STATE_CHANGED, - _async_dispatch_domain_event, - _async_domain_removed_filter, - action, - False, - ) + return _async_track_event(_KEYED_TRACK_STATE_REMOVED_DOMAIN, hass, domains, action) @callback