Add method to track entity state changes from target selectors (#148086)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Abílio Costa 2025-07-14 19:28:53 +01:00 committed by GitHub
parent 8421ca7802
commit 1753baf186
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 303 additions and 6 deletions

View File

@ -2,9 +2,11 @@
from __future__ import annotations
from collections.abc import Callable
import dataclasses
import logging
from logging import Logger
from typing import TypeGuard
from typing import Any, TypeGuard
from homeassistant.const import (
ATTR_AREA_ID,
@ -14,7 +16,14 @@ from homeassistant.const import (
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from . import (
area_registry as ar,
@ -25,8 +34,11 @@ from . import (
group,
label_registry as lr,
)
from .event import async_track_state_change_event
from .typing import ConfigType
_LOGGER = logging.getLogger(__name__)
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
@ -238,3 +250,102 @@ def async_extract_referenced_entity_ids(
)
return selected
class TargetStateChangeTracker:
"""Helper class to manage state change tracking for targets."""
def __init__(
self,
hass: HomeAssistant,
selector_data: TargetSelectorData,
action: Callable[[Event[EventStateChangedData]], Any],
) -> None:
"""Initialize the state change tracker."""
self._hass = hass
self._selector_data = selector_data
self._action = action
self._state_change_unsub: CALLBACK_TYPE | None = None
self._registry_unsubs: list[CALLBACK_TYPE] = []
def async_setup(self) -> Callable[[], None]:
"""Set up the state change tracking."""
self._setup_registry_listeners()
self._track_entities_state_change()
return self._unsubscribe
def _track_entities_state_change(self) -> None:
"""Set up state change tracking for currently selected entities."""
selected = async_extract_referenced_entity_ids(
self._hass, self._selector_data, expand_group=False
)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
self._action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event(
self._hass, tracked_entities, state_change_listener
)
def _setup_registry_listeners(self) -> None:
"""Set up listeners for registry changes that require resubscription."""
@callback
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
"""Resubscribe to state change events when registry changes."""
if self._state_change_unsub:
self._state_change_unsub()
self._track_entities_state_change()
# Subscribe to registry updates that can change the entities to track:
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
# - Device registry: device labels changed; device area changed.
# - Area registry: area floor changed.
#
# We don't track other registries (like floor or label registries) because their
# changes don't affect which entities are tracked.
self._registry_unsubs = [
self._hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event
),
self._hass.bus.async_listen(
dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event
),
self._hass.bus.async_listen(
ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event
),
]
def _unsubscribe(self) -> None:
"""Unsubscribe from all events."""
for registry_unsub in self._registry_unsubs:
registry_unsub()
self._registry_unsubs.clear()
if self._state_change_unsub:
self._state_change_unsub()
self._state_change_unsub = None
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
raise HomeAssistantError(
f"Target selector {target_selector_config} does not have any selectors defined"
)
tracker = TargetStateChangeTracker(hass, selector_data, action)
return tracker.async_setup()

View File

@ -2,9 +2,6 @@
import pytest
# TODO(abmantis): is this import needed?
# To prevent circular import when running just this file
import homeassistant.components # noqa: F401
from homeassistant.components.group import Group
from homeassistant.const import (
ATTR_AREA_ID,
@ -17,17 +14,21 @@ from homeassistant.const import (
STATE_ON,
EntityCategory,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
label_registry as lr,
target,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
RegistryEntryWithDefaults,
mock_area_registry,
mock_device_registry,
@ -457,3 +458,188 @@ async def test_extract_referenced_entity_ids(
)
== expected_selected
)
async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""
@callback
def state_change_callback(event):
"""Handle state change events."""
with pytest.raises(HomeAssistantError) as excinfo:
target.async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert str(excinfo.value) == (
"Target selector {} does not have any selectors defined"
)
async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
events: list[Event[EventStateChangedData]] = []
@callback
def state_change_callback(event: Event[EventStateChangedData]):
"""Handle state change events."""
events.append(event)
last_state = STATE_OFF
async def set_states_and_check_events(
entities_to_set_state: list[str], entities_to_assert_change: list[str]
) -> None:
"""Toggle the state entities and check for events."""
nonlocal last_state
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
for entity_id in entities_to_set_state:
hass.states.async_set(entity_id, last_state)
await hass.async_block_till_done()
assert len(events) == len(entities_to_assert_change)
entities_seen = set()
for event in events:
entities_seen.add(event.data["entity_id"])
assert event.data["new_state"].state == last_state
assert entities_seen == set(entities_to_assert_change)
events.clear()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)
untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)
entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id
untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id
untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id
targeted_entity = "light.test_light"
targeted_entities = [targeted_entity, device_entity]
await set_states_and_check_events(targeted_entities, [])
label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id
selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = target.async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)
# Test directly targeted entity and device
await set_states_and_check_events(targeted_entities, targeted_entities)
# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id
targeted_entities = [targeted_entity, device_entity, device_entity_2]
await set_states_and_check_events(targeted_entities, targeted_entities)
# Test untargeted entity -> should not trigger
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
)
# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
)
# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await set_states_and_check_events(
[*targeted_entities, untargeted_device_entity],
[*targeted_entities, untargeted_device_entity],
)
# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_device_entity], targeted_entities
)
# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity],
[*targeted_entities, untargeted_entity],
)
# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# After unsubscribing, changes should not trigger
unsub()
await set_states_and_check_events(targeted_entities, [])