From 0a9c4f15c4d8a0d3e030450fed44b748604dfc70 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 May 2023 19:04:09 -0500 Subject: [PATCH] Add event helper to dispatch device registry updates by device_id (#93602) * Add event helper to dispatch device registry updates by device_id * Update homeassistant/helpers/event.py * dry * dry * reduce * reduce * reorder * reduce * cleanup stack depth * dry * fix double lookup * remove unused * collapse --- homeassistant/helpers/event.py | 420 +++++++++++++++++---------------- tests/helpers/test_event.py | 102 ++++++++ 2 files changed, 315 insertions(+), 207 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 4c06b0f23e0..c251f0785a4 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -36,6 +36,7 @@ from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe +from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED from .entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from .ratelimit import KeyedRateLimit from .sun import get_astral_event_next @@ -54,6 +55,9 @@ TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener" TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" +TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS = "track_device_registry_updated_callbacks" +TRACK_DEVICE_REGISTRY_UPDATED_LISTENER = "track_device_registry_updated_listener" + _ALL_LISTENER = "all" _DOMAINS_LISTENER = "domains" _ENTITIES_LISTENER = "entities" @@ -256,6 +260,34 @@ def async_track_state_change_event( return _async_track_state_change_event(hass, entity_ids, action) +@callback +def _async_dispatch_entity_id_event( + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[Event], Any]]], + event: Event, +) -> None: + """Dispatch to listeners.""" + if not (callbacks_list := callbacks.get(event.data["entity_id"])): + return + for job in callbacks_list[:]: + try: + hass.async_run_hass_job(job, event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Error while dispatching event for %s to %s", + event.data["entity_id"], + job, + ) + + +@callback +def _async_state_change_filter( + hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event +) -> bool: + """Filter state changes by entity_id.""" + return event.data["entity_id"] in callbacks + + @bind_hass def _async_track_state_change_event( hass: HomeAssistant, @@ -263,82 +295,105 @@ def _async_track_state_change_event( action: Callable[[Event], Any], ) -> CALLBACK_TYPE: """async_track_state_change_event without lowercasing.""" - entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( - TRACK_STATE_CHANGE_CALLBACKS, {} + return _async_track_event( + hass, + entity_ids, + TRACK_STATE_CHANGE_CALLBACKS, + TRACK_STATE_CHANGE_LISTENER, + EVENT_STATE_CHANGED, + _async_dispatch_entity_id_event, + _async_state_change_filter, + action, ) - if TRACK_STATE_CHANGE_LISTENER not in hass.data: - - @callback - def _async_state_change_filter(event: Event) -> bool: - """Filter state changes by entity_id.""" - return event.data.get("entity_id") in entity_callbacks - - @callback - def _async_state_change_dispatcher(event: Event) -> None: - """Dispatch state changes by entity_id.""" - entity_id = event.data.get("entity_id") - - if entity_id not in entity_callbacks: - return - - for job in entity_callbacks[entity_id][:]: - try: - hass.async_run_hass_job(job, event) - except Exception: # pylint: disable=broad-except - _LOGGER.exception( - "Error while processing state change for %s", entity_id - ) - - hass.data[TRACK_STATE_CHANGE_LISTENER] = hass.bus.async_listen( - EVENT_STATE_CHANGED, - _async_state_change_dispatcher, - event_filter=_async_state_change_filter, - ) - - job = HassJob(action, f"track state change event {entity_ids}") - - for entity_id in entity_ids: - entity_callbacks.setdefault(entity_id, []).append(job) - - @callback - def remove_listener() -> None: - """Remove state change listener.""" - _async_remove_indexed_listeners( - hass, - TRACK_STATE_CHANGE_CALLBACKS, - TRACK_STATE_CHANGE_LISTENER, - entity_ids, - job, - ) - - return remove_listener - @callback def _remove_empty_listener() -> None: """Remove a listener that does nothing.""" -@callback -def _async_remove_indexed_listeners( +def _async_track_event( hass: HomeAssistant, - data_key: str, - listener_key: str, - storage_keys: Iterable[str], - job: HassJob[[Event], Any], + keys: str | Iterable[str], + callbacks_key: str, + listeners_key: str, + event_type: str, + dispatcher_callable: Callable[ + [HomeAssistant, dict[str, list[HassJob[[Event], Any]]], Event], None + ], + filter_callable: Callable[ + [HomeAssistant, dict[str, list[HassJob[[Event], Any]]], Event], bool + ], + action: Callable[[Event], None], +) -> CALLBACK_TYPE: + """Track an event by a specific key.""" + if not keys: + return _remove_empty_listener + + if isinstance(keys, str): + keys = [keys] + + callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( + callbacks_key, {} + ) + + if listeners_key not in hass.data: + hass.data[listeners_key] = hass.bus.async_listen( + event_type, + callback(ft.partial(dispatcher_callable, hass, callbacks)), + event_filter=callback(ft.partial(filter_callable, hass, callbacks)), + ) + + job = HassJob(action, f"track {event_type} event {keys}") + + for key in keys: + callbacks.setdefault(key, []).append(job) + + @callback + def remove_listener() -> None: + """Remove listener.""" + for key in keys: + callbacks[key].remove(job) + if len(callbacks[key]) == 0: + del callbacks[key] + + if not callbacks: + hass.data[listeners_key]() + del hass.data[listeners_key] + + return remove_listener + + +@callback +def _async_dispatch_old_entity_id_or_entity_id_event( + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[Event], Any]]], + event: Event, ) -> None: - """Remove a listener.""" - callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data[data_key] + """Dispatch to listeners.""" + if not ( + callbacks_list := callbacks.get( + event.data.get("old_entity_id", event.data["entity_id"]) + ) + ): + return + for job in callbacks_list[:]: + try: + hass.async_run_hass_job(job, event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Error while dispatching event for %s to %s", + event.data.get("old_entity_id", event.data["entity_id"]), + job, + ) - for storage_key in storage_keys: - callbacks[storage_key].remove(job) - if len(callbacks[storage_key]) == 0: - del callbacks[storage_key] - if not callbacks: - hass.data[listener_key]() - del hass.data[listener_key] +@callback +def _async_entity_registry_updated_filter( + hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event +) -> bool: + """Filter entity registry updates by entity_id.""" + return event.data.get("old_entity_id", event.data["entity_id"]) in callbacks @bind_hass @@ -353,76 +408,70 @@ def async_track_entity_registry_updated_event( Similar to async_track_state_change_event. """ - if not entity_ids: - return _remove_empty_listener - if isinstance(entity_ids, str): - entity_ids = [entity_ids] - - entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( - TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {} + return _async_track_event( + hass, + entity_ids, + TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, + TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, + EVENT_ENTITY_REGISTRY_UPDATED, + _async_dispatch_old_entity_id_or_entity_id_event, + _async_entity_registry_updated_filter, + action, ) - if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data: - - @callback - def _async_entity_registry_updated_filter(event: Event) -> bool: - """Filter entity registry updates by entity_id.""" - entity_id = event.data.get("old_entity_id", event.data["entity_id"]) - return entity_id in entity_callbacks - - @callback - def _async_entity_registry_updated_dispatcher(event: Event) -> None: - """Dispatch entity registry updates by entity_id.""" - entity_id = event.data.get("old_entity_id", event.data["entity_id"]) - - if entity_id not in entity_callbacks: - return - - for job in entity_callbacks[entity_id][:]: - try: - hass.async_run_hass_job(job, event) - except Exception: # pylint: disable=broad-except - _LOGGER.exception( - "Error while processing entity registry update for %s", - entity_id, - ) - - hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen( - EVENT_ENTITY_REGISTRY_UPDATED, - _async_entity_registry_updated_dispatcher, - event_filter=_async_entity_registry_updated_filter, - ) - - job = HassJob(action, f"track entity registry updated event {entity_ids}") - - for entity_id in entity_ids: - entity_callbacks.setdefault(entity_id, []).append(job) - - @callback - def remove_listener() -> None: - """Remove state change listener.""" - _async_remove_indexed_listeners( - hass, - TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, - TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, - entity_ids, - job, - ) - - return remove_listener - @callback -def _async_domain_has_listeners( - domain: str, callbacks: dict[str, list[HassJob[[Event], Any]]] +def _async_device_registry_updated_filter( + hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event ) -> bool: - """Check if the domain has any listeners.""" - return domain in callbacks or MATCH_ALL in callbacks + """Filter device registry updates by device_id.""" + return event.data["device_id"] in callbacks + + +@callback +def _async_dispatch_device_id_event( + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[Event], Any]]], + event: Event, +) -> None: + """Dispatch to listeners.""" + if not (callbacks_list := callbacks.get(event.data["device_id"])): + return + for job in callbacks_list[:]: + try: + hass.async_run_hass_job(job, event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Error while dispatching event for %s to %s", + event.data["device_id"], + job, + ) + + +def async_track_device_registry_updated_event( + hass: HomeAssistant, + device_ids: str | Iterable[str], + action: Callable[[Event], Any], +) -> CALLBACK_TYPE: + """Track specific device registry updated events indexed by device_id. + + Similar to async_track_entity_registry_updated_event. + """ + return _async_track_event( + hass, + device_ids, + TRACK_DEVICE_REGISTRY_UPDATED_CALLBACKS, + TRACK_DEVICE_REGISTRY_UPDATED_LISTENER, + EVENT_DEVICE_REGISTRY_UPDATED, + _async_dispatch_device_id_event, + _async_device_registry_updated_filter, + action, + ) @callback def _async_dispatch_domain_event( - hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[[Event], Any]]] + hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event ) -> None: """Dispatch domain event listeners.""" domain = split_entity_id(event.data["entity_id"])[0] @@ -435,6 +484,17 @@ def _async_dispatch_domain_event( ) +@callback +def _async_domain_added_filter( + hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event +) -> bool: + """Filter state changes by entity_id.""" + return event.data.get("old_state") is None and ( + MATCH_ALL in callbacks + or split_entity_id(event.data["entity_id"])[0] in callbacks + ) + + @bind_hass def async_track_state_added_domain( hass: HomeAssistant, @@ -453,48 +513,28 @@ def _async_track_state_added_domain( domains: str | Iterable[str], action: Callable[[Event], Any], ) -> CALLBACK_TYPE: - """async_track_state_added_domain without lowercasing.""" - domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( - TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {} + """Track state change events when an entity is added to domains.""" + return _async_track_event( + hass, + domains, + TRACK_STATE_ADDED_DOMAIN_CALLBACKS, + TRACK_STATE_ADDED_DOMAIN_LISTENER, + EVENT_STATE_CHANGED, + _async_dispatch_domain_event, + _async_domain_added_filter, + action, ) - if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: - @callback - def _async_state_change_filter(event: Event) -> bool: - """Filter state changes by entity_id.""" - return event.data.get("old_state") is None and _async_domain_has_listeners( - split_entity_id(event.data["entity_id"])[0], domain_callbacks - ) - - @callback - def _async_state_change_dispatcher(event: Event) -> None: - """Dispatch state changes by entity_id.""" - _async_dispatch_domain_event(hass, event, domain_callbacks) - - hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen( - EVENT_STATE_CHANGED, - _async_state_change_dispatcher, - event_filter=_async_state_change_filter, - ) - - job = HassJob(action, f"track state added domain event {domains}") - - for domain in domains: - domain_callbacks.setdefault(domain, []).append(job) - - @callback - def remove_listener() -> None: - """Remove state change listener.""" - _async_remove_indexed_listeners( - hass, - TRACK_STATE_ADDED_DOMAIN_CALLBACKS, - TRACK_STATE_ADDED_DOMAIN_LISTENER, - domains, - job, - ) - - return remove_listener +@callback +def _async_domain_removed_filter( + hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event +) -> bool: + """Filter state changes by entity_id.""" + return event.data.get("new_state") is None and ( + MATCH_ALL in callbacks + or split_entity_id(event.data["entity_id"])[0] in callbacks + ) @bind_hass @@ -504,51 +544,17 @@ def async_track_state_removed_domain( action: Callable[[Event], Any], ) -> CALLBACK_TYPE: """Track state change events when an entity is removed from domains.""" - if not (domains := _async_string_to_lower_list(domains)): - return _remove_empty_listener - - domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault( - TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {} + return _async_track_event( + hass, + domains, + TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, + TRACK_STATE_REMOVED_DOMAIN_LISTENER, + EVENT_STATE_CHANGED, + _async_dispatch_domain_event, + _async_domain_removed_filter, + action, ) - if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data: - - @callback - def _async_state_change_filter(event: Event) -> bool: - """Filter state changes by entity_id.""" - return event.data.get("new_state") is None and _async_domain_has_listeners( - split_entity_id(event.data["entity_id"])[0], domain_callbacks - ) - - @callback - def _async_state_change_dispatcher(event: Event) -> None: - """Dispatch state changes by entity_id.""" - _async_dispatch_domain_event(hass, event, domain_callbacks) - - hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen( - EVENT_STATE_CHANGED, - _async_state_change_dispatcher, - event_filter=_async_state_change_filter, - ) - - job = HassJob(action, f"track state removed domain event {domains}") - - for domain in domains: - domain_callbacks.setdefault(domain, []).append(job) - - @callback - def remove_listener() -> None: - """Remove state change listener.""" - _async_remove_indexed_listeners( - hass, - TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, - TRACK_STATE_REMOVED_DOMAIN_LISTENER, - domains, - job, - ) - - return remove_listener - @callback def _async_string_to_lower_list(instr: str | Iterable[str]) -> list[str]: diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 9d90ef1b26c..3740a6b177a 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -18,12 +18,14 @@ from homeassistant.const import MATCH_ALL import homeassistant.core as ha from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import TemplateError +from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.event import ( TrackStates, TrackTemplate, TrackTemplateResult, async_call_later, + async_track_device_registry_updated_event, async_track_entity_registry_updated_event, async_track_point_in_time, async_track_point_in_utc_time, @@ -4564,3 +4566,103 @@ async def test_async_track_entity_registry_updated_event_with_empty_list( unsub_single2() unsub_single() + + +async def test_async_track_device_registry_updated_event(hass: HomeAssistant) -> None: + """Test tracking device registry updates for an device_id.""" + + device_id = "b92c0f06fbc911edacc9eea8ae14f866" + device_id2 = "747bbf22fbca11ed843aeea8ae14f866" + untracked_device_id = "bda93f86fbc911edacc9eea8ae14f866" + + single_event_data = [] + multiple_event_data = [] + + @ha.callback + def single_device_id_callback(event: ha.Event) -> None: + single_event_data.append(event.data) + + @ha.callback + def multiple_device_id_callback(event: ha.Event) -> None: + multiple_event_data.append(event.data) + + unsub1 = async_track_device_registry_updated_event( + hass, device_id, single_device_id_callback + ) + unsub2 = async_track_device_registry_updated_event( + hass, [device_id, device_id2], multiple_device_id_callback + ) + hass.bus.async_fire( + EVENT_DEVICE_REGISTRY_UPDATED, {"action": "create", "device_id": device_id} + ) + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, + {"action": "create", "device_id": untracked_device_id}, + ) + await hass.async_block_till_done() + assert len(single_event_data) == 1 + assert len(multiple_event_data) == 1 + hass.bus.async_fire( + EVENT_DEVICE_REGISTRY_UPDATED, {"action": "create", "device_id": device_id2} + ) + await hass.async_block_till_done() + assert len(single_event_data) == 1 + assert len(multiple_event_data) == 2 + + unsub1() + unsub2() + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "device_id": device_id} + ) + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "device_id": device_id2} + ) + await hass.async_block_till_done() + assert len(single_event_data) == 1 + assert len(multiple_event_data) == 2 + + +async def test_async_track_device_registry_updated_event_with_empty_list( + hass: HomeAssistant, +) -> None: + """Test async_track_device_registry_updated_event passing an empty list of devices.""" + unsub_single = async_track_device_registry_updated_event( + hass, [], ha.callback(lambda event: None) + ) + unsub_single2 = async_track_device_registry_updated_event( + hass, [], ha.callback(lambda event: None) + ) + + unsub_single2() + unsub_single() + + +async def test_async_track_device_registry_updated_event_with_a_callback_that_throws( + hass: HomeAssistant, +) -> None: + """Test tracking device registry updates for an device when one callback throws.""" + + device_id = "b92c0f06fbc911edacc9eea8ae14f866" + + event_data = [] + + @ha.callback + def run_callback(event: ha.Event) -> None: + event_data.append(event.data) + + @ha.callback + def failing_callback(event: ha.Event) -> None: + raise ValueError + + unsub1 = async_track_device_registry_updated_event( + hass, device_id, failing_callback + ) + unsub2 = async_track_device_registry_updated_event(hass, device_id, run_callback) + hass.bus.async_fire( + EVENT_DEVICE_REGISTRY_UPDATED, {"action": "create", "device_id": device_id} + ) + await hass.async_block_till_done() + unsub1() + unsub2() + + assert event_data[0] == {"action": "create", "device_id": device_id}