diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index c16819235b9..239d1e66336 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -2,9 +2,11 @@ from __future__ import annotations +from collections.abc import Callable import dataclasses +import logging from logging import Logger -from typing import TypeGuard +from typing import Any, TypeGuard from homeassistant.const import ( ATTR_AREA_ID, @@ -14,7 +16,14 @@ from homeassistant.const import ( ATTR_LABEL_ID, ENTITY_MATCH_NONE, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + EventStateChangedData, + HomeAssistant, + callback, +) +from homeassistant.exceptions import HomeAssistantError from . import ( area_registry as ar, @@ -25,8 +34,11 @@ from . import ( group, label_registry as lr, ) +from .event import async_track_state_change_event from .typing import ConfigType +_LOGGER = logging.getLogger(__name__) + def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: """Check if ids can match anything.""" @@ -238,3 +250,102 @@ def async_extract_referenced_entity_ids( ) return selected + + +class TargetStateChangeTracker: + """Helper class to manage state change tracking for targets.""" + + def __init__( + self, + hass: HomeAssistant, + selector_data: TargetSelectorData, + action: Callable[[Event[EventStateChangedData]], Any], + ) -> None: + """Initialize the state change tracker.""" + self._hass = hass + self._selector_data = selector_data + self._action = action + + self._state_change_unsub: CALLBACK_TYPE | None = None + self._registry_unsubs: list[CALLBACK_TYPE] = [] + + def async_setup(self) -> Callable[[], None]: + """Set up the state change tracking.""" + self._setup_registry_listeners() + self._track_entities_state_change() + return self._unsubscribe + + def _track_entities_state_change(self) -> None: + """Set up state change tracking for currently selected entities.""" + selected = async_extract_referenced_entity_ids( + self._hass, self._selector_data, expand_group=False + ) + + @callback + def state_change_listener(event: Event[EventStateChangedData]) -> None: + """Handle state change events.""" + if ( + event.data["entity_id"] in selected.referenced + or event.data["entity_id"] in selected.indirectly_referenced + ): + self._action(event) + + tracked_entities = selected.referenced.union(selected.indirectly_referenced) + + _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) + self._state_change_unsub = async_track_state_change_event( + self._hass, tracked_entities, state_change_listener + ) + + def _setup_registry_listeners(self) -> None: + """Set up listeners for registry changes that require resubscription.""" + + @callback + def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: + """Resubscribe to state change events when registry changes.""" + if self._state_change_unsub: + self._state_change_unsub() + self._track_entities_state_change() + + # Subscribe to registry updates that can change the entities to track: + # - Entity registry: entity added/removed; entity labels changed; entity area changed. + # - Device registry: device labels changed; device area changed. + # - Area registry: area floor changed. + # + # We don't track other registries (like floor or label registries) because their + # changes don't affect which entities are tracked. + self._registry_unsubs = [ + self._hass.bus.async_listen( + er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event + ), + self._hass.bus.async_listen( + dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event + ), + self._hass.bus.async_listen( + ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event + ), + ] + + def _unsubscribe(self) -> None: + """Unsubscribe from all events.""" + for registry_unsub in self._registry_unsubs: + registry_unsub() + self._registry_unsubs.clear() + if self._state_change_unsub: + self._state_change_unsub() + self._state_change_unsub = None + + +def async_track_target_selector_state_change_event( + hass: HomeAssistant, + target_selector_config: ConfigType, + action: Callable[[Event[EventStateChangedData]], Any], +) -> CALLBACK_TYPE: + """Track state changes for entities referenced directly or indirectly in a target selector.""" + selector_data = TargetSelectorData(target_selector_config) + if not selector_data.has_any_selector: + raise HomeAssistantError( + f"Target selector {target_selector_config} does not have any selectors defined" + ) + tracker = TargetStateChangeTracker(hass, selector_data, action) + return tracker.async_setup() diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index ca38f316d89..c87a320e378 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -2,9 +2,6 @@ import pytest -# TODO(abmantis): is this import needed? -# To prevent circular import when running just this file -import homeassistant.components # noqa: F401 from homeassistant.components.group import Group from homeassistant.const import ( ATTR_AREA_ID, @@ -17,17 +14,21 @@ from homeassistant.const import ( STATE_ON, EntityCategory, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( area_registry as ar, device_registry as dr, entity_registry as er, + floor_registry as fr, + label_registry as lr, target, ) from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_setup_component from tests.common import ( + MockConfigEntry, RegistryEntryWithDefaults, mock_area_registry, mock_device_registry, @@ -457,3 +458,188 @@ async def test_extract_referenced_entity_ids( ) == expected_selected ) + + +async def test_async_track_target_selector_state_change_event_empty_selector( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test async_track_target_selector_state_change_event with empty selector.""" + + @callback + def state_change_callback(event): + """Handle state change events.""" + + with pytest.raises(HomeAssistantError) as excinfo: + target.async_track_target_selector_state_change_event( + hass, {}, state_change_callback + ) + assert str(excinfo.value) == ( + "Target selector {} does not have any selectors defined" + ) + + +async def test_async_track_target_selector_state_change_event( + hass: HomeAssistant, +) -> None: + """Test async_track_target_selector_state_change_event with multiple targets.""" + events: list[Event[EventStateChangedData]] = [] + + @callback + def state_change_callback(event: Event[EventStateChangedData]): + """Handle state change events.""" + events.append(event) + + last_state = STATE_OFF + + async def set_states_and_check_events( + entities_to_set_state: list[str], entities_to_assert_change: list[str] + ) -> None: + """Toggle the state entities and check for events.""" + nonlocal last_state + last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF + for entity_id in entities_to_set_state: + hass.states.async_set(entity_id, last_state) + await hass.async_block_till_done() + + assert len(events) == len(entities_to_assert_change) + entities_seen = set() + for event in events: + entities_seen.add(event.data["entity_id"]) + assert event.data["new_state"].state == last_state + assert entities_seen == set(entities_to_assert_change) + events.clear() + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device_reg = dr.async_get(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "device_1")}, + ) + + untargeted_device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "area_device")}, + ) + + entity_reg = er.async_get(hass) + device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light", + device_id=device_entry.id, + ).entity_id + + untargeted_device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="area_device_light", + device_id=untargeted_device_entry.id, + ).entity_id + + untargeted_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="untargeted_light", + ).entity_id + + targeted_entity = "light.test_light" + + targeted_entities = [targeted_entity, device_entity] + await set_states_and_check_events(targeted_entities, []) + + label = lr.async_get(hass).async_create("Test Label").name + area = ar.async_get(hass).async_create("Test Area").id + floor = fr.async_get(hass).async_create("Test Floor").floor_id + + selector_config = { + ATTR_ENTITY_ID: targeted_entity, + ATTR_DEVICE_ID: device_entry.id, + ATTR_AREA_ID: area, + ATTR_FLOOR_ID: floor, + ATTR_LABEL_ID: label, + } + unsub = target.async_track_target_selector_state_change_event( + hass, selector_config, state_change_callback + ) + + # Test directly targeted entity and device + await set_states_and_check_events(targeted_entities, targeted_entities) + + # Add new entity to the targeted device -> should trigger on state change + device_entity_2 = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light_2", + device_id=device_entry.id, + ).entity_id + + targeted_entities = [targeted_entity, device_entity, device_entity_2] + await set_states_and_check_events(targeted_entities, targeted_entities) + + # Test untargeted entity -> should not trigger + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) + + # Add label to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, labels={label}) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity] + ) + + # Remove label from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, labels={}) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) + + # Add area to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, area_id=area) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity] + ) + + # Remove area from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, area_id=None) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) + + # Add area to untargeted device -> should trigger on state change + device_reg.async_update_device(untargeted_device_entry.id, area_id=area) + await set_states_and_check_events( + [*targeted_entities, untargeted_device_entity], + [*targeted_entities, untargeted_device_entity], + ) + + # Remove area from untargeted device -> should not trigger anymore + device_reg.async_update_device(untargeted_device_entry.id, area_id=None) + await set_states_and_check_events( + [*targeted_entities, untargeted_device_entity], targeted_entities + ) + + # Set the untargeted area on the untargeted entity -> should not trigger + untracked_area = ar.async_get(hass).async_create("Untargeted Area").id + entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) + + # Set targeted floor on the untargeted area -> should trigger now + ar.async_get(hass).async_update(untracked_area, floor_id=floor) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], + [*targeted_entities, untargeted_entity], + ) + + # Remove untargeted area from targeted floor -> should not trigger anymore + ar.async_get(hass).async_update(untracked_area, floor_id=None) + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) + + # After unsubscribing, changes should not trigger + unsub() + await set_states_and_check_events(targeted_entities, [])