diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 3f0c2db3b2f..f6c423a35af 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -17,7 +17,14 @@ from homeassistant.const import ( SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET, ) -from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + HomeAssistant, + State, + callback, + split_entity_id, +) from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.sun import get_astral_event_next from homeassistant.helpers.template import Template @@ -28,6 +35,9 @@ from homeassistant.util.async_ import run_callback_threadsafe TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" +TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks" +TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener" + TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" @@ -191,7 +201,7 @@ def async_track_state_change_event( @callback def remove_listener() -> None: """Remove state change listener.""" - _async_remove_entity_listeners( + _async_remove_indexed_listeners( hass, TRACK_STATE_CHANGE_CALLBACKS, TRACK_STATE_CHANGE_LISTENER, @@ -203,23 +213,23 @@ def async_track_state_change_event( @callback -def _async_remove_entity_listeners( +def _async_remove_indexed_listeners( hass: HomeAssistant, - storage_key: str, + data_key: str, listener_key: str, - entity_ids: Iterable[str], + storage_keys: Iterable[str], action: Callable[[Event], Any], ) -> None: """Remove a listener.""" - entity_callbacks = hass.data[storage_key] + callbacks = hass.data[data_key] - for entity_id in entity_ids: - entity_callbacks[entity_id].remove(action) - if len(entity_callbacks[entity_id]) == 0: - del entity_callbacks[entity_id] + for storage_key in storage_keys: + callbacks[storage_key].remove(action) + if len(callbacks[storage_key]) == 0: + del callbacks[storage_key] - if not entity_callbacks: + if not callbacks: hass.data[listener_key]() del hass.data[listener_key] @@ -271,7 +281,7 @@ def async_track_entity_registry_updated_event( @callback def remove_listener() -> None: """Remove state change listener.""" - _async_remove_entity_listeners( + _async_remove_indexed_listeners( hass, TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, @@ -282,6 +292,63 @@ def async_track_entity_registry_updated_event( return remove_listener +@bind_hass +def async_track_state_added_domain( + hass: HomeAssistant, + domains: Union[str, Iterable[str]], + action: Callable[[Event], Any], +) -> Callable[[], None]: + """Track state change events when an entity is added to domains.""" + + domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}) + + if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data: + + @callback + def _async_state_change_dispatcher(event: Event) -> None: + """Dispatch state changes by entity_id.""" + if event.data.get("old_state") is not None: + return + + domain = split_entity_id(event.data["entity_id"])[0] + + if domain not in domain_callbacks: + return + + for action in domain_callbacks[domain][:]: + try: + hass.async_run_job(action, event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Error while processing state added for %s", domain + ) + + hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen( + EVENT_STATE_CHANGED, _async_state_change_dispatcher + ) + + if isinstance(domains, str): + domains = [domains] + + domains = [domains.lower() for domains in domains] + + for domain in domains: + domain_callbacks.setdefault(domain, []).append(action) + + @callback + def remove_listener() -> None: + """Remove state change listener.""" + _async_remove_indexed_listeners( + hass, + TRACK_STATE_ADDED_DOMAIN_CALLBACKS, + TRACK_STATE_ADDED_DOMAIN_LISTENER, + domains, + action, + ) + + return remove_listener + + @callback @bind_hass def async_track_template( diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index aa0a69d1d67..e30f85c9c38 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -16,6 +16,7 @@ from homeassistant.helpers.event import ( async_track_point_in_time, async_track_point_in_utc_time, async_track_same_state, + async_track_state_added_domain, async_track_state_change, async_track_state_change_event, async_track_sunrise, @@ -341,6 +342,88 @@ async def test_async_track_state_change_event(hass): unsub_throws() +async def test_async_track_state_added_domain(hass): + """Test async_track_state_added_domain.""" + single_entity_id_tracker = [] + multiple_entity_id_tracker = [] + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + single_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def multiple_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + multiple_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def callback_that_throws(event): + raise ValueError + + unsub_single = async_track_state_added_domain(hass, "light", single_run_callback) + unsub_multi = async_track_state_added_domain( + hass, ["light", "switch"], multiple_run_callback + ) + unsub_throws = async_track_state_added_domain( + hass, ["light", "switch"], callback_that_throws + ) + + # Adding state to state machine + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert single_entity_id_tracker[-1][0] is None + assert single_entity_id_tracker[-1][1] is not None + assert len(multiple_entity_id_tracker) == 1 + assert multiple_entity_id_tracker[-1][0] is None + assert multiple_entity_id_tracker[-1][1] is not None + + # Set same state should not trigger a state change/listener + hass.states.async_set("light.Bowl", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 1 + + # State change off -> on - nothing added so no trigger + hass.states.async_set("light.Bowl", "off") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 1 + + # State change off -> off - nothing added so no trigger + hass.states.async_set("light.Bowl", "off", {"some_attr": 1}) + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 1 + + # Removing state does not trigger + hass.states.async_remove("light.bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 1 + + # Set state for different entity id + hass.states.async_set("switch.kitchen", "on") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 2 + + unsub_single() + # Ensure unsubing the listener works + hass.states.async_set("light.new", "off") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(multiple_entity_id_tracker) == 3 + + unsub_multi() + unsub_throws() + + async def test_track_template(hass): """Test tracking template.""" specific_runs = []