From e663d4f602a40e3beb0f8417cb593c2acfb44262 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 May 2024 17:14:50 -1000 Subject: [PATCH] Refactor state_reported listener setup to avoid merge in async_fire_internal (#117953) * Refactor state_reported listener setup to avoid merge in async_fire_internal Instead of merging the listeners in async_fire_internal, setup the listener for state_changed at the same time so async_fire_internal can avoid having to copy another list * Refactor state_reported listener setup to avoid merge in async_fire_internal Instead of merging the listeners in async_fire_internal, setup the listener for state_changed at the same time so async_fire_internal can avoid having to copy another list * tweak * tweak * tweak * tweak * tweak --- homeassistant/core.py | 40 +++++++++++++++++++++++----------------- tests/test_core.py | 11 ++++++++++- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index 5d3433855df..48a600ae1c9 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -1500,7 +1500,6 @@ class EventBus: This method must be run in the event loop. """ - if self._debug: _LOGGER.debug( "Bus:Handling %s", _event_repr(event_type, origin, event_data) @@ -1511,17 +1510,9 @@ class EventBus: match_all_listeners = self._match_all_listeners else: match_all_listeners = EMPTY_LIST - if event_type == EVENT_STATE_CHANGED: - aliased_listeners = self._listeners.get(EVENT_STATE_REPORTED, EMPTY_LIST) - else: - aliased_listeners = EMPTY_LIST - listeners = listeners + match_all_listeners + aliased_listeners - if not listeners: - return event: Event[_DataT] | None = None - - for job, event_filter in listeners: + for job, event_filter in listeners + match_all_listeners: if event_filter is not None: try: if event_data is None or not event_filter(event_data): @@ -1599,18 +1590,32 @@ class EventBus: if event_filter is not None and not is_callback_check_partial(event_filter): raise HomeAssistantError(f"Event filter {event_filter} is not a callback") + filterable_job = (HassJob(listener, f"listen {event_type}"), event_filter) if event_type == EVENT_STATE_REPORTED: if not event_filter: raise HomeAssistantError( f"Event filter is required for event {event_type}" ) - return self._async_listen_filterable_job( - event_type, - ( - HassJob(listener, f"listen {event_type}"), - event_filter, - ), - ) + # Special case for EVENT_STATE_REPORTED, we also want to listen to + # EVENT_STATE_CHANGED + self._listeners[EVENT_STATE_REPORTED].append(filterable_job) + self._listeners[EVENT_STATE_CHANGED].append(filterable_job) + return functools.partial( + self._async_remove_multiple_listeners, + (EVENT_STATE_REPORTED, EVENT_STATE_CHANGED), + filterable_job, + ) + return self._async_listen_filterable_job(event_type, filterable_job) + + @callback + def _async_remove_multiple_listeners( + self, + keys: Iterable[EventType[_DataT] | str], + filterable_job: _FilterableJobType[Any], + ) -> None: + """Remove multiple listeners for specific event_types.""" + for key in keys: + self._async_remove_listener(key, filterable_job) @callback def _async_listen_filterable_job( @@ -1618,6 +1623,7 @@ class EventBus: event_type: EventType[_DataT] | str, filterable_job: _FilterableJobType[_DataT], ) -> CALLBACK_TYPE: + """Listen for all events or events of a specific type.""" self._listeners[event_type].append(filterable_job) return functools.partial( self._async_remove_listener, event_type, filterable_job diff --git a/tests/test_core.py b/tests/test_core.py index b7cdae1c6e5..2f2b3fd7453 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3346,7 +3346,9 @@ async def test_statemachine_report_state(hass: HomeAssistant) -> None: hass.states.async_set("light.bowl", "on", {}) state_changed_events = async_capture_events(hass, EVENT_STATE_CHANGED) state_reported_events = [] - hass.bus.async_listen(EVENT_STATE_REPORTED, listener, event_filter=mock_filter) + unsub = hass.bus.async_listen( + EVENT_STATE_REPORTED, listener, event_filter=mock_filter + ) hass.states.async_set("light.bowl", "on") await hass.async_block_till_done() @@ -3368,6 +3370,13 @@ async def test_statemachine_report_state(hass: HomeAssistant) -> None: assert len(state_changed_events) == 3 assert len(state_reported_events) == 4 + unsub() + + hass.states.async_set("light.bowl", "on") + await hass.async_block_till_done() + assert len(state_changed_events) == 4 + assert len(state_reported_events) == 4 + async def test_report_state_listener_restrictions(hass: HomeAssistant) -> None: """Test we enforce requirements for EVENT_STATE_REPORTED listeners."""