diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index c6eeffb974f..a2dfcff7699 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -135,7 +135,9 @@ track_state_change = threaded_listener_factory(async_track_state_change) @bind_hass def async_track_state_change_event( - hass: HomeAssistant, entity_ids: Iterable[str], action: Callable[[Event], None] + hass: HomeAssistant, + entity_ids: Union[str, Iterable[str]], + action: Callable[[Event], None], ) -> Callable[[], None]: """Track specific state change events indexed by entity_id. @@ -161,7 +163,7 @@ def async_track_state_change_event( if entity_id not in entity_callbacks: return - for action in entity_callbacks[entity_id]: + for action in entity_callbacks[entity_id][:]: try: hass.async_run_job(action, event) except Exception: # pylint: disable=broad-except @@ -173,13 +175,13 @@ def async_track_state_change_event( EVENT_STATE_CHANGED, _async_state_change_dispatcher ) + if isinstance(entity_ids, str): + entity_ids = [entity_ids] + entity_ids = [entity_id.lower() for entity_id in entity_ids] for entity_id in entity_ids: - if entity_id not in entity_callbacks: - entity_callbacks[entity_id] = [] - - entity_callbacks[entity_id].append(action) + entity_callbacks.setdefault(entity_id, []).append(action) @callback def remove_listener() -> None: @@ -247,7 +249,7 @@ def async_track_same_state( hass: HomeAssistant, period: timedelta, action: Callable[..., None], - async_check_same_func: Callable[[str, State, State], bool], + async_check_same_func: Callable[[str, Optional[State], Optional[State]], bool], entity_ids: Union[str, Iterable[str]] = MATCH_ALL, ) -> CALLBACK_TYPE: """Track the state of entities for a period and run an action. @@ -279,10 +281,12 @@ def async_track_same_state( hass.async_run_job(action) @callback - def state_for_cancel_listener( - entity: str, from_state: State, to_state: State - ) -> None: + def state_for_cancel_listener(event: Event) -> None: """Fire on changes and cancel for listener if changed.""" + entity: str = event.data["entity_id"] + from_state: Optional[State] = event.data.get("old_state") + to_state: Optional[State] = event.data.get("new_state") + if not async_check_same_func(entity, from_state, to_state): clear_listener() @@ -290,9 +294,16 @@ def async_track_same_state( hass, state_for_listener, dt_util.utcnow() + period ) - async_remove_state_for_cancel = async_track_state_change( - hass, entity_ids, state_for_cancel_listener - ) + if entity_ids == MATCH_ALL: + async_remove_state_for_cancel = hass.bus.async_listen( + EVENT_STATE_CHANGED, state_for_cancel_listener + ) + else: + async_remove_state_for_cancel = async_track_state_change_event( + hass, + [entity_ids] if isinstance(entity_ids, str) else entity_ids, + state_for_cancel_listener, + ) return clear_listener diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 7724a80e8b4..b0034ebaaa6 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -1011,3 +1011,104 @@ async def test_async_call_later(hass): assert p_action is action assert p_point == now + timedelta(seconds=3) assert remove is mock() + + +async def test_track_state_change_event_chain_multple_entity(hass): + """Test that adding a new state tracker inside a tracker does not fire right away.""" + tracker_called = [] + chained_tracker_called = [] + + chained_tracker_unsub = [] + tracker_unsub = [] + + @ha.callback + def chained_single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + chained_tracker_called.append((old_state, new_state)) + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + tracker_called.append((old_state, new_state)) + + chained_tracker_unsub.append( + async_track_state_change_event( + hass, ["light.bowl", "light.top"], chained_single_run_callback + ) + ) + + tracker_unsub.append( + async_track_state_change_event( + hass, ["light.bowl", "light.top"], single_run_callback + ) + ) + + hass.states.async_set("light.bowl", "on") + hass.states.async_set("light.top", "on") + await hass.async_block_till_done() + + assert len(tracker_called) == 2 + assert len(chained_tracker_called) == 1 + assert len(tracker_unsub) == 1 + assert len(chained_tracker_unsub) == 2 + + hass.states.async_set("light.bowl", "off") + await hass.async_block_till_done() + + assert len(tracker_called) == 3 + assert len(chained_tracker_called) == 3 + assert len(tracker_unsub) == 1 + assert len(chained_tracker_unsub) == 3 + + +async def test_track_state_change_event_chain_single_entity(hass): + """Test that adding a new state tracker inside a tracker does not fire right away.""" + tracker_called = [] + chained_tracker_called = [] + + chained_tracker_unsub = [] + tracker_unsub = [] + + @ha.callback + def chained_single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + chained_tracker_called.append((old_state, new_state)) + + @ha.callback + def single_run_callback(event): + old_state = event.data.get("old_state") + new_state = event.data.get("new_state") + + tracker_called.append((old_state, new_state)) + + chained_tracker_unsub.append( + async_track_state_change_event( + hass, "light.bowl", chained_single_run_callback + ) + ) + + tracker_unsub.append( + async_track_state_change_event(hass, "light.bowl", single_run_callback) + ) + + hass.states.async_set("light.bowl", "on") + await hass.async_block_till_done() + + assert len(tracker_called) == 1 + assert len(chained_tracker_called) == 0 + assert len(tracker_unsub) == 1 + assert len(chained_tracker_unsub) == 1 + + hass.states.async_set("light.bowl", "off") + await hass.async_block_till_done() + + assert len(tracker_called) == 2 + assert len(chained_tracker_called) == 1 + assert len(tracker_unsub) == 1 + assert len(chained_tracker_unsub) == 2