From e208aac834155091f0cbfce2b3a9c14a23c19324 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Sep 2020 06:03:31 -0500 Subject: [PATCH] Add async_track_state_removed_domain to allow tracking when a state is removed from a domain (#39859) when a state is removed from a domain --- homeassistant/helpers/event.py | 102 +++++++++++++++++++------- tests/helpers/test_event.py | 127 +++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 24 deletions(-) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index d9f1b8d9681..4f30d255aec 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -55,6 +55,9 @@ TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks" TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener" +TRACK_STATE_REMOVED_DOMAIN_CALLBACKS = "track_state_removed_domain_callbacks" +TRACK_STATE_REMOVED_DOMAIN_LISTENER = "track_state_removed_domain_listener" + TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" @@ -235,10 +238,7 @@ def async_track_state_change_event( EVENT_STATE_CHANGED, _async_state_change_dispatcher ) - if isinstance(entity_ids, str): - entity_ids = [entity_ids] - - entity_ids = [entity_id.lower() for entity_id in entity_ids] + entity_ids = _async_string_to_lower_list(entity_ids) for entity_id in entity_ids: entity_callbacks.setdefault(entity_id, []).append(action) @@ -315,10 +315,7 @@ def async_track_entity_registry_updated_event( EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher ) - if isinstance(entity_ids, str): - entity_ids = [entity_ids] - - entity_ids = [entity_id.lower() for entity_id in entity_ids] + entity_ids = _async_string_to_lower_list(entity_ids) for entity_id in entity_ids: entity_callbacks.setdefault(entity_id, []).append(action) @@ -337,6 +334,26 @@ def async_track_entity_registry_updated_event( return remove_listener +@callback +def _async_dispatch_domain_event( + hass: HomeAssistant, event: Event, callbacks: Dict[str, List] +) -> None: + domain = split_entity_id(event.data["entity_id"])[0] + + if domain not in callbacks and MATCH_ALL not in callbacks: + return + + listeners = callbacks.get(domain, []) + callbacks.get(MATCH_ALL, []) + + for action in listeners: + try: + hass.async_run_job(action, event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Error while processing event %s for domain %s", event, domain + ) + + @bind_hass def async_track_state_added_domain( hass: HomeAssistant, @@ -355,27 +372,13 @@ def async_track_state_added_domain( if event.data.get("old_state") is not None: return - domain = split_entity_id(event.data["entity_id"])[0] - - if domain not in domain_callbacks: - return - - for action in domain_callbacks[domain][:]: - try: - hass.async_run_job(action, event) - except Exception: # pylint: disable=broad-except - _LOGGER.exception( - "Error while processing state added for %s", domain - ) + _async_dispatch_domain_event(hass, event, domain_callbacks) hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen( EVENT_STATE_CHANGED, _async_state_change_dispatcher ) - if isinstance(domains, str): - domains = [domains] - - domains = [domains.lower() for domains in domains] + domains = _async_string_to_lower_list(domains) for domain in domains: domain_callbacks.setdefault(domain, []).append(action) @@ -394,6 +397,57 @@ def async_track_state_added_domain( return remove_listener +@bind_hass +def async_track_state_removed_domain( + hass: HomeAssistant, + domains: Union[str, Iterable[str]], + action: Callable[[Event], Any], +) -> Callable[[], None]: + """Track state change events when an entity is removed from domains.""" + + domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {}) + + if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data: + + @callback + def _async_state_change_dispatcher(event: Event) -> None: + """Dispatch state changes by entity_id.""" + if event.data.get("new_state") is not None: + return + + _async_dispatch_domain_event(hass, event, domain_callbacks) + + hass.data[TRACK_STATE_REMOVED_DOMAIN_LISTENER] = hass.bus.async_listen( + EVENT_STATE_CHANGED, _async_state_change_dispatcher + ) + + domains = _async_string_to_lower_list(domains) + + for domain in domains: + domain_callbacks.setdefault(domain, []).append(action) + + @callback + def remove_listener() -> None: + """Remove state change listener.""" + _async_remove_indexed_listeners( + hass, + TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, + TRACK_STATE_REMOVED_DOMAIN_LISTENER, + domains, + action, + ) + + return remove_listener + + +@callback +def _async_string_to_lower_list(instr: Union[str, Iterable[str]]) -> List[str]: + if isinstance(instr, str): + return [instr.lower()] + + return [mstr.lower() for mstr in instr] + + @callback @bind_hass def async_track_template( diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index cc06c0fd19c..fcb8655804e 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -23,6 +23,7 @@ from homeassistant.helpers.event import ( async_track_state_added_domain, async_track_state_change, async_track_state_change_event, + async_track_state_removed_domain, async_track_sunrise, async_track_sunset, async_track_template, @@ -429,6 +430,132 @@ async def test_async_track_state_added_domain(hass): unsub_throws() +async def test_async_track_state_removed_domain(hass): + """Test async_track_state_removed_domain.""" + single_entity_id_tracker = [] + multiple_entity_id_tracker = [] + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + single_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def multiple_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + multiple_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def callback_that_throws(event): + raise ValueError + + unsub_single = async_track_state_removed_domain(hass, "light", single_run_callback) + unsub_multi = async_track_state_removed_domain( + hass, ["light", "switch"], multiple_run_callback + ) + unsub_throws = async_track_state_removed_domain( + hass, ["light", "switch"], callback_that_throws + ) + + # Adding state to state machine + hass.states.async_set("light.Bowl", "on") + hass.states.async_remove("light.Bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert single_entity_id_tracker[-1][1] is None + assert single_entity_id_tracker[-1][0] is not None + assert len(multiple_entity_id_tracker) == 1 + assert multiple_entity_id_tracker[-1][1] is None + assert multiple_entity_id_tracker[-1][0] is not None + + # Added and than removed (light) + hass.states.async_set("light.Bowl", "on") + hass.states.async_remove("light.Bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 2 + assert len(multiple_entity_id_tracker) == 2 + + # Added and than removed (light) + hass.states.async_set("light.Bowl", "off") + hass.states.async_remove("light.Bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 3 + assert len(multiple_entity_id_tracker) == 3 + + # Added and than removed (light) + hass.states.async_set("light.Bowl", "off", {"some_attr": 1}) + hass.states.async_remove("light.Bowl") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 4 + assert len(multiple_entity_id_tracker) == 4 + + # Added and than removed (switch) + hass.states.async_set("switch.kitchen", "on") + hass.states.async_remove("switch.kitchen") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 4 + assert len(multiple_entity_id_tracker) == 5 + + unsub_single() + # Ensure unsubing the listener works + hass.states.async_set("light.new", "off") + hass.states.async_remove("light.new") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 4 + assert len(multiple_entity_id_tracker) == 6 + + unsub_multi() + unsub_throws() + + +async def test_async_track_state_removed_domain_match_all(hass): + """Test async_track_state_removed_domain with a match_all.""" + single_entity_id_tracker = [] + match_all_entity_id_tracker = [] + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + single_entity_id_tracker.append((old_state, new_state)) + + @ha.callback + def match_all_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + match_all_entity_id_tracker.append((old_state, new_state)) + + unsub_single = async_track_state_removed_domain(hass, "light", single_run_callback) + unsub_match_all = async_track_state_removed_domain( + hass, MATCH_ALL, match_all_run_callback + ) + hass.states.async_set("light.new", "off") + hass.states.async_remove("light.new") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(match_all_entity_id_tracker) == 1 + + hass.states.async_set("switch.new", "off") + hass.states.async_remove("switch.new") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(match_all_entity_id_tracker) == 2 + + unsub_match_all() + unsub_single() + hass.states.async_set("switch.new", "off") + hass.states.async_remove("switch.new") + await hass.async_block_till_done() + assert len(single_entity_id_tracker) == 1 + assert len(match_all_entity_id_tracker) == 2 + + async def test_track_template(hass): """Test tracking template.""" specific_runs = []