Add support for pre-filtering events to the event bus (#46371)

This commit is contained in:
J. Nick Koston 2021-02-14 09:42:55 -10:00 committed by GitHub
parent f8f86fbe48
commit c9df42b69a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 256 additions and 60 deletions

View File

@ -252,7 +252,22 @@ class Recorder(threading.Thread):
@callback @callback
def async_initialize(self): def async_initialize(self):
"""Initialize the recorder.""" """Initialize the recorder."""
self.hass.bus.async_listen(MATCH_ALL, self.event_listener) self.hass.bus.async_listen(
MATCH_ALL, self.event_listener, event_filter=self._async_event_filter
)
@callback
def _async_event_filter(self, event):
"""Filter events."""
if event.event_type in self.exclude_t:
return False
entity_id = event.data.get(ATTR_ENTITY_ID)
if entity_id is not None:
if not self.entity_filter(entity_id):
return False
return True
def do_adhoc_purge(self, **kwargs): def do_adhoc_purge(self, **kwargs):
"""Trigger an adhoc purge retaining keep_days worth of data.""" """Trigger an adhoc purge retaining keep_days worth of data."""
@ -378,13 +393,6 @@ class Recorder(threading.Thread):
self._timechanges_seen = 0 self._timechanges_seen = 0
self._commit_event_session_or_retry() self._commit_event_session_or_retry()
continue continue
if event.event_type in self.exclude_t:
continue
entity_id = event.data.get(ATTR_ENTITY_ID)
if entity_id is not None:
if not self.entity_filter(entity_id):
continue
try: try:
if event.event_type == EVENT_STATE_CHANGED: if event.event_type == EVENT_STATE_CHANGED:

View File

@ -1139,17 +1139,13 @@ class EntityRegistryDisabledHandler:
def async_setup(self) -> None: def async_setup(self) -> None:
"""Set up the disable handler.""" """Set up the disable handler."""
self.hass.bus.async_listen( self.hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
self._handle_entry_updated,
event_filter=_handle_entry_updated_filter,
) )
async def _handle_entry_updated(self, event: Event) -> None: async def _handle_entry_updated(self, event: Event) -> None:
"""Handle entity registry entry update.""" """Handle entity registry entry update."""
if (
event.data["action"] != "update"
or "disabled_by" not in event.data["changes"]
):
return
if self.registry is None: if self.registry is None:
self.registry = await entity_registry.async_get_registry(self.hass) self.registry = await entity_registry.async_get_registry(self.hass)
@ -1203,6 +1199,14 @@ class EntityRegistryDisabledHandler:
) )
@callback
def _handle_entry_updated_filter(event: Event) -> bool:
"""Handle entity registry entry update filter."""
if event.data["action"] != "update" or "disabled_by" not in event.data["changes"]:
return False
return True
async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool: async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool:
"""Test if a domain supports entry unloading.""" """Test if a domain supports entry unloading."""
integration = await loader.async_get_integration(hass, domain) integration = await loader.async_get_integration(hass, domain)

View File

@ -28,6 +28,7 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Set, Set,
Tuple,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -661,7 +662,7 @@ class EventBus:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus.""" """Initialize a new event bus."""
self._listeners: Dict[str, List[HassJob]] = {} self._listeners: Dict[str, List[Tuple[HassJob, Optional[Callable]]]] = {}
self._hass = hass self._hass = hass
@callback @callback
@ -717,7 +718,14 @@ class EventBus:
if not listeners: if not listeners:
return return
for job in listeners: for job, event_filter in listeners:
if event_filter is not None:
try:
if not event_filter(event):
continue
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error in event filter")
continue
self._hass.async_add_hass_job(job, event) self._hass.async_add_hass_job(job, event)
def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
@ -737,23 +745,38 @@ class EventBus:
return remove_listener return remove_listener
@callback @callback
def async_listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: def async_listen(
self,
event_type: str,
listener: Callable,
event_filter: Optional[Callable] = None,
) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type. """Listen for all events or events of a specific type.
To listen to all events specify the constant ``MATCH_ALL`` To listen to all events specify the constant ``MATCH_ALL``
as event_type. as event_type.
An optional event_filter, which must be a callable decorated with
@callback that returns a boolean value, determines if the
listener callable should run.
This method must be run in the event loop. This method must be run in the event loop.
""" """
return self._async_listen_job(event_type, HassJob(listener)) if event_filter is not None and not is_callback(event_filter):
raise HomeAssistantError(f"Event filter {event_filter} is not a callback")
return self._async_listen_filterable_job(
event_type, (HassJob(listener), event_filter)
)
@callback @callback
def _async_listen_job(self, event_type: str, hassjob: HassJob) -> CALLBACK_TYPE: def _async_listen_filterable_job(
self._listeners.setdefault(event_type, []).append(hassjob) self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]]
) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(filterable_job)
def remove_listener() -> None: def remove_listener() -> None:
"""Remove the listener.""" """Remove the listener."""
self._async_remove_listener(event_type, hassjob) self._async_remove_listener(event_type, filterable_job)
return remove_listener return remove_listener
@ -786,12 +809,12 @@ class EventBus:
This method must be run in the event loop. This method must be run in the event loop.
""" """
job: Optional[HassJob] = None filterable_job: Optional[Tuple[HassJob, Optional[Callable]]] = None
@callback @callback
def _onetime_listener(event: Event) -> None: def _onetime_listener(event: Event) -> None:
"""Remove listener from event bus and then fire listener.""" """Remove listener from event bus and then fire listener."""
nonlocal job nonlocal filterable_job
if hasattr(_onetime_listener, "run"): if hasattr(_onetime_listener, "run"):
return return
# Set variable so that we will never run twice. # Set variable so that we will never run twice.
@ -800,22 +823,24 @@ class EventBus:
# multiple times as well. # multiple times as well.
# This will make sure the second time it does nothing. # This will make sure the second time it does nothing.
setattr(_onetime_listener, "run", True) setattr(_onetime_listener, "run", True)
assert job is not None assert filterable_job is not None
self._async_remove_listener(event_type, job) self._async_remove_listener(event_type, filterable_job)
self._hass.async_run_job(listener, event) self._hass.async_run_job(listener, event)
job = HassJob(_onetime_listener) filterable_job = (HassJob(_onetime_listener), None)
return self._async_listen_job(event_type, job) return self._async_listen_filterable_job(event_type, filterable_job)
@callback @callback
def _async_remove_listener(self, event_type: str, hassjob: HassJob) -> None: def _async_remove_listener(
self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]]
) -> None:
"""Remove a listener of a specific event_type. """Remove a listener of a specific event_type.
This method must be run in the event loop. This method must be run in the event loop.
""" """
try: try:
self._listeners[event_type].remove(hassjob) self._listeners[event_type].remove(filterable_job)
# delete event_type list if empty # delete event_type list if empty
if not self._listeners[event_type]: if not self._listeners[event_type]:
@ -823,7 +848,9 @@ class EventBus:
except (KeyError, ValueError): except (KeyError, ValueError):
# KeyError is key event_type listener did not exist # KeyError is key event_type listener did not exist
# ValueError if listener did not exist within event_type # ValueError if listener did not exist within event_type
_LOGGER.exception("Unable to remove unknown job listener %s", hassjob) _LOGGER.exception(
"Unable to remove unknown job listener %s", filterable_job
)
class State: class State:

View File

@ -686,25 +686,34 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
) )
async def entity_registry_changed(event: Event) -> None: async def entity_registry_changed(event: Event) -> None:
"""Handle entity updated or removed.""" """Handle entity updated or removed dispatch."""
await debounced_cleanup.async_call()
@callback
def entity_registry_changed_filter(event: Event) -> bool:
"""Handle entity updated or removed filter."""
if ( if (
event.data["action"] == "update" event.data["action"] == "update"
and "device_id" not in event.data["changes"] and "device_id" not in event.data["changes"]
) or event.data["action"] == "create": ) or event.data["action"] == "create":
return return False
await debounced_cleanup.async_call() return True
if hass.is_running: if hass.is_running:
hass.bus.async_listen( hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
entity_registry_changed,
event_filter=entity_registry_changed_filter,
) )
return return
async def startup_clean(event: Event) -> None: async def startup_clean(event: Event) -> None:
"""Clean up on startup.""" """Clean up on startup."""
hass.bus.async_listen( hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
entity_registry_changed,
event_filter=entity_registry_changed_filter,
) )
await debounced_cleanup.async_call() await debounced_cleanup.async_call()

View File

@ -641,12 +641,14 @@ def async_setup_entity_restore(
) -> None: ) -> None:
"""Set up the entity restore mechanism.""" """Set up the entity restore mechanism."""
@callback
def cleanup_restored_states_filter(event: Event) -> bool:
"""Clean up restored states filter."""
return bool(event.data["action"] == "remove")
@callback @callback
def cleanup_restored_states(event: Event) -> None: def cleanup_restored_states(event: Event) -> None:
"""Clean up restored states.""" """Clean up restored states."""
if event.data["action"] != "remove":
return
state = hass.states.get(event.data["entity_id"]) state = hass.states.get(event.data["entity_id"])
if state is None or not state.attributes.get(ATTR_RESTORED): if state is None or not state.attributes.get(ATTR_RESTORED):
@ -654,7 +656,11 @@ def async_setup_entity_restore(
hass.states.async_remove(event.data["entity_id"], context=event.context) hass.states.async_remove(event.data["entity_id"], context=event.context)
hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states) hass.bus.async_listen(
EVENT_ENTITY_REGISTRY_UPDATED,
cleanup_restored_states,
event_filter=cleanup_restored_states_filter,
)
if hass.is_running: if hass.is_running:
return return

View File

@ -180,7 +180,7 @@ def async_track_state_change(
job = HassJob(action) job = HassJob(action)
@callback @callback
def state_change_listener(event: Event) -> None: def state_change_filter(event: Event) -> bool:
"""Handle specific state changes.""" """Handle specific state changes."""
if from_state is not None: if from_state is not None:
old_state = event.data.get("old_state") old_state = event.data.get("old_state")
@ -188,15 +188,21 @@ def async_track_state_change(
old_state = old_state.state old_state = old_state.state
if not match_from_state(old_state): if not match_from_state(old_state):
return return False
if to_state is not None: if to_state is not None:
new_state = event.data.get("new_state") new_state = event.data.get("new_state")
if new_state is not None: if new_state is not None:
new_state = new_state.state new_state = new_state.state
if not match_to_state(new_state): if not match_to_state(new_state):
return return False
return True
@callback
def state_change_dispatcher(event: Event) -> None:
"""Handle specific state changes."""
hass.async_run_hass_job( hass.async_run_hass_job(
job, job,
event.data.get("entity_id"), event.data.get("entity_id"),
@ -204,6 +210,14 @@ def async_track_state_change(
event.data.get("new_state"), event.data.get("new_state"),
) )
@callback
def state_change_listener(event: Event) -> None:
"""Handle specific state changes."""
if not state_change_filter(event):
return
state_change_dispatcher(event)
if entity_ids != MATCH_ALL: if entity_ids != MATCH_ALL:
# If we have a list of entity ids we use # If we have a list of entity ids we use
# async_track_state_change_event to route # async_track_state_change_event to route
@ -215,7 +229,9 @@ def async_track_state_change(
# entity_id. # entity_id.
return async_track_state_change_event(hass, entity_ids, state_change_listener) return async_track_state_change_event(hass, entity_ids, state_change_listener)
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener) return hass.bus.async_listen(
EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter
)
track_state_change = threaded_listener_factory(async_track_state_change) track_state_change = threaded_listener_factory(async_track_state_change)
@ -246,6 +262,11 @@ def async_track_state_change_event(
if TRACK_STATE_CHANGE_LISTENER not in hass.data: if TRACK_STATE_CHANGE_LISTENER not in hass.data:
@callback
def _async_state_change_filter(event: Event) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("entity_id") in entity_callbacks
@callback @callback
def _async_state_change_dispatcher(event: Event) -> None: def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id.""" """Dispatch state changes by entity_id."""
@ -263,7 +284,9 @@ def async_track_state_change_event(
) )
hass.data[TRACK_STATE_CHANGE_LISTENER] = hass.bus.async_listen( hass.data[TRACK_STATE_CHANGE_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher EVENT_STATE_CHANGED,
_async_state_change_dispatcher,
event_filter=_async_state_change_filter,
) )
job = HassJob(action) job = HassJob(action)
@ -329,6 +352,12 @@ def async_track_entity_registry_updated_event(
if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data: if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data:
@callback
def _async_entity_registry_updated_filter(event: Event) -> bool:
"""Filter entity registry updates by entity_id."""
entity_id = event.data.get("old_entity_id", event.data["entity_id"])
return entity_id in entity_callbacks
@callback @callback
def _async_entity_registry_updated_dispatcher(event: Event) -> None: def _async_entity_registry_updated_dispatcher(event: Event) -> None:
"""Dispatch entity registry updates by entity_id.""" """Dispatch entity registry updates by entity_id."""
@ -347,7 +376,9 @@ def async_track_entity_registry_updated_event(
) )
hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen( hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen(
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher EVENT_ENTITY_REGISTRY_UPDATED,
_async_entity_registry_updated_dispatcher,
event_filter=_async_entity_registry_updated_filter,
) )
job = HassJob(action) job = HassJob(action)
@ -404,6 +435,11 @@ def async_track_state_added_domain(
if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:
@callback
def _async_state_change_filter(event: Event) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("old_state") is None
@callback @callback
def _async_state_change_dispatcher(event: Event) -> None: def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id.""" """Dispatch state changes by entity_id."""
@ -413,7 +449,9 @@ def async_track_state_added_domain(
_async_dispatch_domain_event(hass, event, domain_callbacks) _async_dispatch_domain_event(hass, event, domain_callbacks)
hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen( hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher EVENT_STATE_CHANGED,
_async_state_change_dispatcher,
event_filter=_async_state_change_filter,
) )
job = HassJob(action) job = HassJob(action)
@ -450,6 +488,11 @@ def async_track_state_removed_domain(
if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data: if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data:
@callback
def _async_state_change_filter(event: Event) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("new_state") is None
@callback @callback
def _async_state_change_dispatcher(event: Event) -> None: def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id.""" """Dispatch state changes by entity_id."""
@ -459,7 +502,9 @@ def async_track_state_removed_domain(
_async_dispatch_domain_event(hass, event, domain_callbacks) _async_dispatch_domain_event(hass, event, domain_callbacks)
hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen( hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher EVENT_STATE_CHANGED,
_async_state_change_dispatcher,
event_filter=_async_state_change_filter,
) )
job = HassJob(action) job = HassJob(action)

View File

@ -62,7 +62,7 @@ async def fire_events(hass):
"""Fire a million events.""" """Fire a million events."""
count = 0 count = 0
event_name = "benchmark_event" event_name = "benchmark_event"
event = asyncio.Event() events_to_fire = 10 ** 6
@core.callback @core.callback
def listener(_): def listener(_):
@ -70,17 +70,48 @@ async def fire_events(hass):
nonlocal count nonlocal count
count += 1 count += 1
if count == 10 ** 6:
event.set()
hass.bus.async_listen(event_name, listener) hass.bus.async_listen(event_name, listener)
for _ in range(10 ** 6): for _ in range(events_to_fire):
hass.bus.async_fire(event_name) hass.bus.async_fire(event_name)
start = timer() start = timer()
await event.wait() await hass.async_block_till_done()
assert count == events_to_fire
return timer() - start
@benchmark
async def fire_events_with_filter(hass):
"""Fire a million events with a filter that rejects them."""
count = 0
event_name = "benchmark_event"
events_to_fire = 10 ** 6
@core.callback
def event_filter(event):
"""Filter event."""
return False
@core.callback
def listener(_):
"""Handle event."""
nonlocal count
count += 1
hass.bus.async_listen(event_name, listener, event_filter=event_filter)
for _ in range(events_to_fire):
hass.bus.async_fire(event_name)
start = timer()
await hass.async_block_till_done()
assert count == 0
return timer() - start return timer() - start
@ -154,7 +185,7 @@ async def state_changed_event_helper(hass):
"""Run a million events through state changed event helper with 1000 entities.""" """Run a million events through state changed event helper with 1000 entities."""
count = 0 count = 0
entity_id = "light.kitchen" entity_id = "light.kitchen"
event = asyncio.Event() events_to_fire = 10 ** 6
@core.callback @core.callback
def listener(*args): def listener(*args):
@ -162,9 +193,6 @@ async def state_changed_event_helper(hass):
nonlocal count nonlocal count
count += 1 count += 1
if count == 10 ** 6:
event.set()
hass.helpers.event.async_track_state_change_event( hass.helpers.event.async_track_state_change_event(
[f"{entity_id}{idx}" for idx in range(1000)], listener [f"{entity_id}{idx}" for idx in range(1000)], listener
) )
@ -175,12 +203,49 @@ async def state_changed_event_helper(hass):
"new_state": core.State(entity_id, "on"), "new_state": core.State(entity_id, "on"),
} }
for _ in range(10 ** 6): for _ in range(events_to_fire):
hass.bus.async_fire(EVENT_STATE_CHANGED, event_data) hass.bus.async_fire(EVENT_STATE_CHANGED, event_data)
start = timer() start = timer()
await event.wait() await hass.async_block_till_done()
assert count == events_to_fire
return timer() - start
@benchmark
async def state_changed_event_filter_helper(hass):
"""Run a million events through state changed event helper with 1000 entities that all get filtered."""
count = 0
entity_id = "light.kitchen"
events_to_fire = 10 ** 6
@core.callback
def listener(*args):
"""Handle event."""
nonlocal count
count += 1
hass.helpers.event.async_track_state_change_event(
[f"{entity_id}{idx}" for idx in range(1000)], listener
)
event_data = {
"entity_id": "switch.no_listeners",
"old_state": core.State(entity_id, "off"),
"new_state": core.State(entity_id, "on"),
}
for _ in range(events_to_fire):
hass.bus.async_fire(EVENT_STATE_CHANGED, event_data)
start = timer()
await hass.async_block_till_done()
assert count == 0
return timer() - start return timer() - start

View File

@ -194,6 +194,7 @@ async def test_cleanup_device_tracker(hass, device_reg, entity_reg, mqtt_mock):
device_reg.async_remove_device(device_entry.id) device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
# Verify device and registry entries are cleared # Verify device and registry entries are cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")})

View File

@ -411,6 +411,7 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock):
device_reg.async_remove_device(device_entry.id) device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
# Verify device and registry entries are cleared # Verify device and registry entries are cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")})

View File

@ -353,6 +353,7 @@ async def test_remove_clients(hass, aioclient_mock):
} }
controller.api.session_handler(SIGNAL_DATA) controller.api.session_handler(SIGNAL_DATA)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 1 assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 1

View File

@ -379,6 +379,35 @@ async def test_eventbus_add_remove_listener(hass):
unsub() unsub()
async def test_eventbus_filtered_listener(hass):
"""Test we can prefilter events."""
calls = []
@ha.callback
def listener(event):
"""Mock listener."""
calls.append(event)
@ha.callback
def filter(event):
"""Mock filter."""
return not event.data["filtered"]
unsub = hass.bus.async_listen("test", listener, event_filter=filter)
hass.bus.async_fire("test", {"filtered": True})
await hass.async_block_till_done()
assert len(calls) == 0
hass.bus.async_fire("test", {"filtered": False})
await hass.async_block_till_done()
assert len(calls) == 1
unsub()
async def test_eventbus_unsubscribe_listener(hass): async def test_eventbus_unsubscribe_listener(hass):
"""Test unsubscribe listener from returned function.""" """Test unsubscribe listener from returned function."""
calls = [] calls = []