Add support for pre-filtering events to the event bus (#46371)

This commit is contained in:
J. Nick Koston 2021-02-14 09:42:55 -10:00 committed by GitHub
parent f8f86fbe48
commit c9df42b69a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 256 additions and 60 deletions

View File

@ -252,7 +252,22 @@ class Recorder(threading.Thread):
@callback
def async_initialize(self):
"""Initialize the recorder."""
self.hass.bus.async_listen(MATCH_ALL, self.event_listener)
self.hass.bus.async_listen(
MATCH_ALL, self.event_listener, event_filter=self._async_event_filter
)
@callback
def _async_event_filter(self, event):
"""Filter events."""
if event.event_type in self.exclude_t:
return False
entity_id = event.data.get(ATTR_ENTITY_ID)
if entity_id is not None:
if not self.entity_filter(entity_id):
return False
return True
def do_adhoc_purge(self, **kwargs):
"""Trigger an adhoc purge retaining keep_days worth of data."""
@ -378,13 +393,6 @@ class Recorder(threading.Thread):
self._timechanges_seen = 0
self._commit_event_session_or_retry()
continue
if event.event_type in self.exclude_t:
continue
entity_id = event.data.get(ATTR_ENTITY_ID)
if entity_id is not None:
if not self.entity_filter(entity_id):
continue
try:
if event.event_type == EVENT_STATE_CHANGED:

View File

@ -1139,17 +1139,13 @@ class EntityRegistryDisabledHandler:
def async_setup(self) -> None:
"""Set up the disable handler."""
self.hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
self._handle_entry_updated,
event_filter=_handle_entry_updated_filter,
)
async def _handle_entry_updated(self, event: Event) -> None:
"""Handle entity registry entry update."""
if (
event.data["action"] != "update"
or "disabled_by" not in event.data["changes"]
):
return
if self.registry is None:
self.registry = await entity_registry.async_get_registry(self.hass)
@ -1203,6 +1199,14 @@ class EntityRegistryDisabledHandler:
)
@callback
def _handle_entry_updated_filter(event: Event) -> bool:
"""Handle entity registry entry update filter."""
if event.data["action"] != "update" or "disabled_by" not in event.data["changes"]:
return False
return True
async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool:
"""Test if a domain supports entry unloading."""
integration = await loader.async_get_integration(hass, domain)

View File

@ -28,6 +28,7 @@ from typing import (
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
@ -661,7 +662,7 @@ class EventBus:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus."""
self._listeners: Dict[str, List[HassJob]] = {}
self._listeners: Dict[str, List[Tuple[HassJob, Optional[Callable]]]] = {}
self._hass = hass
@callback
@ -717,7 +718,14 @@ class EventBus:
if not listeners:
return
for job in listeners:
for job, event_filter in listeners:
if event_filter is not None:
try:
if not event_filter(event):
continue
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error in event filter")
continue
self._hass.async_add_hass_job(job, event)
def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
@ -737,23 +745,38 @@ class EventBus:
return remove_listener
@callback
def async_listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
def async_listen(
self,
event_type: str,
listener: Callable,
event_filter: Optional[Callable] = None,
) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type.
To listen to all events specify the constant ``MATCH_ALL``
as event_type.
An optional event_filter, which must be a callable decorated with
@callback that returns a boolean value, determines if the
listener callable should run.
This method must be run in the event loop.
"""
return self._async_listen_job(event_type, HassJob(listener))
if event_filter is not None and not is_callback(event_filter):
raise HomeAssistantError(f"Event filter {event_filter} is not a callback")
return self._async_listen_filterable_job(
event_type, (HassJob(listener), event_filter)
)
@callback
def _async_listen_job(self, event_type: str, hassjob: HassJob) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(hassjob)
def _async_listen_filterable_job(
self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]]
) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(filterable_job)
def remove_listener() -> None:
"""Remove the listener."""
self._async_remove_listener(event_type, hassjob)
self._async_remove_listener(event_type, filterable_job)
return remove_listener
@ -786,12 +809,12 @@ class EventBus:
This method must be run in the event loop.
"""
job: Optional[HassJob] = None
filterable_job: Optional[Tuple[HassJob, Optional[Callable]]] = None
@callback
def _onetime_listener(event: Event) -> None:
"""Remove listener from event bus and then fire listener."""
nonlocal job
nonlocal filterable_job
if hasattr(_onetime_listener, "run"):
return
# Set variable so that we will never run twice.
@ -800,22 +823,24 @@ class EventBus:
# multiple times as well.
# This will make sure the second time it does nothing.
setattr(_onetime_listener, "run", True)
assert job is not None
self._async_remove_listener(event_type, job)
assert filterable_job is not None
self._async_remove_listener(event_type, filterable_job)
self._hass.async_run_job(listener, event)
job = HassJob(_onetime_listener)
filterable_job = (HassJob(_onetime_listener), None)
return self._async_listen_job(event_type, job)
return self._async_listen_filterable_job(event_type, filterable_job)
@callback
def _async_remove_listener(self, event_type: str, hassjob: HassJob) -> None:
def _async_remove_listener(
self, event_type: str, filterable_job: Tuple[HassJob, Optional[Callable]]
) -> None:
"""Remove a listener of a specific event_type.
This method must be run in the event loop.
"""
try:
self._listeners[event_type].remove(hassjob)
self._listeners[event_type].remove(filterable_job)
# delete event_type list if empty
if not self._listeners[event_type]:
@ -823,7 +848,9 @@ class EventBus:
except (KeyError, ValueError):
# KeyError is key event_type listener did not exist
# ValueError if listener did not exist within event_type
_LOGGER.exception("Unable to remove unknown job listener %s", hassjob)
_LOGGER.exception(
"Unable to remove unknown job listener %s", filterable_job
)
class State:

View File

@ -686,25 +686,34 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
)
async def entity_registry_changed(event: Event) -> None:
"""Handle entity updated or removed."""
"""Handle entity updated or removed dispatch."""
await debounced_cleanup.async_call()
@callback
def entity_registry_changed_filter(event: Event) -> bool:
"""Handle entity updated or removed filter."""
if (
event.data["action"] == "update"
and "device_id" not in event.data["changes"]
) or event.data["action"] == "create":
return
return False
await debounced_cleanup.async_call()
return True
if hass.is_running:
hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
entity_registry_changed,
event_filter=entity_registry_changed_filter,
)
return
async def startup_clean(event: Event) -> None:
"""Clean up on startup."""
hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, entity_registry_changed
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
entity_registry_changed,
event_filter=entity_registry_changed_filter,
)
await debounced_cleanup.async_call()

View File

@ -641,12 +641,14 @@ def async_setup_entity_restore(
) -> None:
"""Set up the entity restore mechanism."""
@callback
def cleanup_restored_states_filter(event: Event) -> bool:
"""Clean up restored states filter."""
return bool(event.data["action"] == "remove")
@callback
def cleanup_restored_states(event: Event) -> None:
"""Clean up restored states."""
if event.data["action"] != "remove":
return
state = hass.states.get(event.data["entity_id"])
if state is None or not state.attributes.get(ATTR_RESTORED):
@ -654,7 +656,11 @@ def async_setup_entity_restore(
hass.states.async_remove(event.data["entity_id"], context=event.context)
hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states)
hass.bus.async_listen(
EVENT_ENTITY_REGISTRY_UPDATED,
cleanup_restored_states,
event_filter=cleanup_restored_states_filter,
)
if hass.is_running:
return

View File

@ -180,7 +180,7 @@ def async_track_state_change(
job = HassJob(action)
@callback
def state_change_listener(event: Event) -> None:
def state_change_filter(event: Event) -> bool:
"""Handle specific state changes."""
if from_state is not None:
old_state = event.data.get("old_state")
@ -188,15 +188,21 @@ def async_track_state_change(
old_state = old_state.state
if not match_from_state(old_state):
return
return False
if to_state is not None:
new_state = event.data.get("new_state")
if new_state is not None:
new_state = new_state.state
if not match_to_state(new_state):
return
return False
return True
@callback
def state_change_dispatcher(event: Event) -> None:
"""Handle specific state changes."""
hass.async_run_hass_job(
job,
event.data.get("entity_id"),
@ -204,6 +210,14 @@ def async_track_state_change(
event.data.get("new_state"),
)
@callback
def state_change_listener(event: Event) -> None:
"""Handle specific state changes."""
if not state_change_filter(event):
return
state_change_dispatcher(event)
if entity_ids != MATCH_ALL:
# If we have a list of entity ids we use
# async_track_state_change_event to route
@ -215,7 +229,9 @@ def async_track_state_change(
# entity_id.
return async_track_state_change_event(hass, entity_ids, state_change_listener)
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
return hass.bus.async_listen(
EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter
)
track_state_change = threaded_listener_factory(async_track_state_change)
@ -246,6 +262,11 @@ def async_track_state_change_event(
if TRACK_STATE_CHANGE_LISTENER not in hass.data:
@callback
def _async_state_change_filter(event: Event) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("entity_id") in entity_callbacks
@callback
def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id."""
@ -263,7 +284,9 @@ def async_track_state_change_event(
)
hass.data[TRACK_STATE_CHANGE_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
EVENT_STATE_CHANGED,
_async_state_change_dispatcher,
event_filter=_async_state_change_filter,
)
job = HassJob(action)
@ -329,6 +352,12 @@ def async_track_entity_registry_updated_event(
if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data:
@callback
def _async_entity_registry_updated_filter(event: Event) -> bool:
"""Filter entity registry updates by entity_id."""
entity_id = event.data.get("old_entity_id", event.data["entity_id"])
return entity_id in entity_callbacks
@callback
def _async_entity_registry_updated_dispatcher(event: Event) -> None:
"""Dispatch entity registry updates by entity_id."""
@ -347,7 +376,9 @@ def async_track_entity_registry_updated_event(
)
hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen(
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher
EVENT_ENTITY_REGISTRY_UPDATED,
_async_entity_registry_updated_dispatcher,
event_filter=_async_entity_registry_updated_filter,
)
job = HassJob(action)
@ -404,6 +435,11 @@ def async_track_state_added_domain(
if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:
@callback
def _async_state_change_filter(event: Event) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("old_state") is None
@callback
def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id."""
@ -413,7 +449,9 @@ def async_track_state_added_domain(
_async_dispatch_domain_event(hass, event, domain_callbacks)
hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
EVENT_STATE_CHANGED,
_async_state_change_dispatcher,
event_filter=_async_state_change_filter,
)
job = HassJob(action)
@ -450,6 +488,11 @@ def async_track_state_removed_domain(
if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data:
@callback
def _async_state_change_filter(event: Event) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("new_state") is None
@callback
def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id."""
@ -459,7 +502,9 @@ def async_track_state_removed_domain(
_async_dispatch_domain_event(hass, event, domain_callbacks)
hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
EVENT_STATE_CHANGED,
_async_state_change_dispatcher,
event_filter=_async_state_change_filter,
)
job = HassJob(action)

View File

@ -62,7 +62,7 @@ async def fire_events(hass):
"""Fire a million events."""
count = 0
event_name = "benchmark_event"
event = asyncio.Event()
events_to_fire = 10 ** 6
@core.callback
def listener(_):
@ -70,17 +70,48 @@ async def fire_events(hass):
nonlocal count
count += 1
if count == 10 ** 6:
event.set()
hass.bus.async_listen(event_name, listener)
for _ in range(10 ** 6):
for _ in range(events_to_fire):
hass.bus.async_fire(event_name)
start = timer()
await event.wait()
await hass.async_block_till_done()
assert count == events_to_fire
return timer() - start
@benchmark
async def fire_events_with_filter(hass):
"""Fire a million events with a filter that rejects them."""
count = 0
event_name = "benchmark_event"
events_to_fire = 10 ** 6
@core.callback
def event_filter(event):
"""Filter event."""
return False
@core.callback
def listener(_):
"""Handle event."""
nonlocal count
count += 1
hass.bus.async_listen(event_name, listener, event_filter=event_filter)
for _ in range(events_to_fire):
hass.bus.async_fire(event_name)
start = timer()
await hass.async_block_till_done()
assert count == 0
return timer() - start
@ -154,7 +185,7 @@ async def state_changed_event_helper(hass):
"""Run a million events through state changed event helper with 1000 entities."""
count = 0
entity_id = "light.kitchen"
event = asyncio.Event()
events_to_fire = 10 ** 6
@core.callback
def listener(*args):
@ -162,9 +193,6 @@ async def state_changed_event_helper(hass):
nonlocal count
count += 1
if count == 10 ** 6:
event.set()
hass.helpers.event.async_track_state_change_event(
[f"{entity_id}{idx}" for idx in range(1000)], listener
)
@ -175,12 +203,49 @@ async def state_changed_event_helper(hass):
"new_state": core.State(entity_id, "on"),
}
for _ in range(10 ** 6):
for _ in range(events_to_fire):
hass.bus.async_fire(EVENT_STATE_CHANGED, event_data)
start = timer()
await event.wait()
await hass.async_block_till_done()
assert count == events_to_fire
return timer() - start
@benchmark
async def state_changed_event_filter_helper(hass):
"""Run a million events through state changed event helper with 1000 entities that all get filtered."""
count = 0
entity_id = "light.kitchen"
events_to_fire = 10 ** 6
@core.callback
def listener(*args):
"""Handle event."""
nonlocal count
count += 1
hass.helpers.event.async_track_state_change_event(
[f"{entity_id}{idx}" for idx in range(1000)], listener
)
event_data = {
"entity_id": "switch.no_listeners",
"old_state": core.State(entity_id, "off"),
"new_state": core.State(entity_id, "on"),
}
for _ in range(events_to_fire):
hass.bus.async_fire(EVENT_STATE_CHANGED, event_data)
start = timer()
await hass.async_block_till_done()
assert count == 0
return timer() - start

View File

@ -194,6 +194,7 @@ async def test_cleanup_device_tracker(hass, device_reg, entity_reg, mqtt_mock):
device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done()
await hass.async_block_till_done()
# Verify device and registry entries are cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")})

View File

@ -411,6 +411,7 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock):
device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done()
await hass.async_block_till_done()
# Verify device and registry entries are cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")})

View File

@ -353,6 +353,7 @@ async def test_remove_clients(hass, aioclient_mock):
}
controller.api.session_handler(SIGNAL_DATA)
await hass.async_block_till_done()
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 1

View File

@ -379,6 +379,35 @@ async def test_eventbus_add_remove_listener(hass):
unsub()
async def test_eventbus_filtered_listener(hass):
"""Test we can prefilter events."""
calls = []
@ha.callback
def listener(event):
"""Mock listener."""
calls.append(event)
@ha.callback
def filter(event):
"""Mock filter."""
return not event.data["filtered"]
unsub = hass.bus.async_listen("test", listener, event_filter=filter)
hass.bus.async_fire("test", {"filtered": True})
await hass.async_block_till_done()
assert len(calls) == 0
hass.bus.async_fire("test", {"filtered": False})
await hass.async_block_till_done()
assert len(calls) == 1
unsub()
async def test_eventbus_unsubscribe_listener(hass):
"""Test unsubscribe listener from returned function."""
calls = []