From 4acc6f333efe1d2426fb74510ce2de95971205b0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 Jun 2020 16:46:45 -0500 Subject: [PATCH] Improve scalability of state change event routing (#37174) --- homeassistant/components/automation/state.py | 5 +- homeassistant/helpers/event.py | 90 ++++++++++++++++++-- tests/components/group/test_init.py | 13 ++- tests/helpers/test_event.py | 86 +++++++++++++++++++ 4 files changed, 183 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 29aea64c9c5..fe49e1cf532 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -6,12 +6,13 @@ from typing import Dict import voluptuous as vol from homeassistant import exceptions -from homeassistant.const import CONF_FOR, CONF_PLATFORM, EVENT_STATE_CHANGED, MATCH_ALL +from homeassistant.const import CONF_FOR, CONF_PLATFORM, MATCH_ALL from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers.event import ( Event, async_track_same_state, + async_track_state_change_event, process_state_match, ) @@ -153,7 +154,7 @@ async def async_attach_trigger( hass, period[entity], call_action, _check_same_state, entity_ids=entity, ) - unsub = hass.bus.async_listen(EVENT_STATE_CHANGED, state_automation_listener) + unsub = async_track_state_change_event(hass, entity_id, state_automation_listener) @callback def async_remove(): diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 266cb150e0a..46c36205cee 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -1,7 +1,7 @@ """Helpers for listening to events.""" from datetime import datetime, timedelta import functools as ft -from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Union, cast +from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Union import attr @@ -21,6 +21,9 @@ from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe +TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" +TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" + # PyLint does not like the use of threaded_listener_factory # pylint: disable=invalid-name @@ -81,12 +84,6 @@ def async_track_state_change( @callback def state_change_listener(event: Event) -> None: """Handle specific state changes.""" - if ( - entity_ids != MATCH_ALL - and cast(str, event.data.get("entity_id")) not in entity_ids - ): - return - old_state = event.data.get("old_state") if old_state is not None: old_state = old_state.state @@ -103,12 +100,91 @@ def async_track_state_change( event.data.get("new_state"), ) + if entity_ids != MATCH_ALL: + # If we have a list of entity ids we use + # async_track_state_change_event to route + # by entity_id to avoid iterating though state change + # events and creating a jobs where the most + # common outcome is to return right away because + # the entity_id does not match since usually + # only one or two listeners want that specific + # entity_id. + return async_track_state_change_event(hass, entity_ids, state_change_listener) + return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener) track_state_change = threaded_listener_factory(async_track_state_change) +@bind_hass +def async_track_state_change_event( + hass: HomeAssistant, entity_ids: Iterable[str], action: Callable[[Event], None] +) -> Callable[[], None]: + """Track specific state change events indexed by entity_id. + + Unlike async_track_state_change, async_track_state_change_event + passes the full event to the callback. + + In order to avoid having to iterate a long list + of EVENT_STATE_CHANGED and fire and create a job + for each one, we keep a dict of entity ids that + care about the state change events so we can + do a fast dict lookup to route events. + """ + + entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {}) + + if TRACK_STATE_CHANGE_LISTENER not in hass.data: + + @callback + def _async_state_change_dispatcher(event: Event) -> None: + """Dispatch state changes by entity_id.""" + entity_id = event.data.get("entity_id") + + if entity_id not in entity_callbacks: + return + + for action in entity_callbacks[entity_id]: + hass.async_run_job(action, event) + + hass.data[TRACK_STATE_CHANGE_LISTENER] = hass.bus.async_listen( + EVENT_STATE_CHANGED, _async_state_change_dispatcher + ) + + entity_ids = [entity_id.lower() for entity_id in entity_ids] + + for entity_id in entity_ids: + if entity_id not in entity_callbacks: + entity_callbacks[entity_id] = [] + + entity_callbacks[entity_id].append(action) + + @callback + def remove_listener() -> None: + """Remove state change listener.""" + _async_remove_state_change_listeners(hass, entity_ids, action) + + return remove_listener + + +@callback +def _async_remove_state_change_listeners( + hass: HomeAssistant, entity_ids: Iterable[str], action: Callable[[Event], None] +) -> None: + """Remove a listener.""" + entity_callbacks = hass.data[TRACK_STATE_CHANGE_CALLBACKS] + + for entity_id in entity_ids: + entity_callbacks[entity_id].remove(action) + if len(entity_callbacks[entity_id]) == 0: + del entity_callbacks[entity_id] + + if not entity_callbacks: + hass.data[TRACK_STATE_CHANGE_LISTENER]() + del hass.data[TRACK_STATE_CHANGE_LISTENER] + + @callback @bind_hass def async_track_template( diff --git a/tests/components/group/test_init.py b/tests/components/group/test_init.py index ff5f3a30f75..921b810fe39 100644 --- a/tests/components/group/test_init.py +++ b/tests/components/group/test_init.py @@ -14,6 +14,7 @@ from homeassistant.const import ( STATE_ON, STATE_UNKNOWN, ) +from homeassistant.helpers.event import TRACK_STATE_CHANGE_CALLBACKS from homeassistant.setup import async_setup_component, setup_component from tests.async_mock import patch @@ -390,7 +391,12 @@ class TestComponentsGroup(unittest.TestCase): "group.second_group", "group.test_group", ] - assert self.hass.bus.listeners["state_changed"] == 3 + assert self.hass.bus.listeners["state_changed"] == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["hello.world"]) == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["sensor.happy"]) == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1 with patch( "homeassistant.config.load_yaml_config_file", @@ -405,7 +411,10 @@ class TestComponentsGroup(unittest.TestCase): "group.all_tests", "group.hello", ] - assert self.hass.bus.listeners["state_changed"] == 2 + assert self.hass.bus.listeners["state_changed"] == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["light.bowl"]) == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.one"]) == 1 + assert len(self.hass.data[TRACK_STATE_CHANGE_CALLBACKS]["test.two"]) == 1 def test_modify_group(self): """Test modifying a group.""" diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 654cf8483db..95a093d59ab 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -15,6 +15,7 @@ from homeassistant.helpers.event import ( async_track_point_in_utc_time, async_track_same_state, async_track_state_change, + async_track_state_change_event, async_track_sunrise, async_track_sunset, async_track_template, @@ -163,6 +164,91 @@ async def test_track_state_change(hass): assert len(wildercard_runs) == 6 +async def test_async_track_state_change_event(hass): + """Test async_track_state_change_event.""" + single_entity_id_tracker = [] + multiple_entity_id_tracker = [] + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + single_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def multiple_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + multiple_entity_id_tracker.append((old_state, new_state)) + + unsub_single = async_track_state_change_event( + hass, ["light.Bowl"], single_run_callback + ) + unsub_multi = async_track_state_change_event( + hass, ["light.Bowl", "switch.kitchen"], multiple_run_callback + ) + + # Adding state to state machine + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert single_entity_id_tracker[-1][0] is None + assert single_entity_id_tracker[-1][1] is not None + assert len(multiple_entity_id_tracker) == 1 + assert multiple_entity_id_tracker[-1][0] is None + assert multiple_entity_id_tracker[-1][1] is not None + + # Set same state should not trigger a state change/listener + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 1 + + # State change off -> on + hass.states.async_set("light.Bowl", "off") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 2 + assert len(multiple_entity_id_tracker) == 2 + + # State change off -> off + hass.states.async_set("light.Bowl", "off", {"some_attr": 1}) + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 3 + assert len(multiple_entity_id_tracker) == 3 + + # State change off -> on + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 4 + assert len(multiple_entity_id_tracker) == 4 + + hass.states.async_remove("light.bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 5 + assert single_entity_id_tracker[-1][0] is not None + assert single_entity_id_tracker[-1][1] is None + assert len(multiple_entity_id_tracker) == 5 + assert multiple_entity_id_tracker[-1][0] is not None + assert multiple_entity_id_tracker[-1][1] is None + + # Set state for different entity id + hass.states.async_set("switch.kitchen", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 5 + assert len(multiple_entity_id_tracker) == 6 + + unsub_single() + # Ensure unsubing the listener works + hass.states.async_set("light.Bowl", "off") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 5 + assert len(multiple_entity_id_tracker) == 7 + + unsub_multi() + + async def test_track_template(hass): """Test tracking template.""" specific_runs = []