Add method to track entity state changes from target selectors (#148086)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Abílio Costa 2025-07-14 19:28:53 +01:00 committed by GitHub
parent 8421ca7802
commit 1753baf186
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 303 additions and 6 deletions

View File

@ -2,9 +2,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import dataclasses import dataclasses
import logging
from logging import Logger from logging import Logger
from typing import TypeGuard from typing import Any, TypeGuard
from homeassistant.const import ( from homeassistant.const import (
ATTR_AREA_ID, ATTR_AREA_ID,
@ -14,7 +16,14 @@ from homeassistant.const import (
ATTR_LABEL_ID, ATTR_LABEL_ID,
ENTITY_MATCH_NONE, ENTITY_MATCH_NONE,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from . import ( from . import (
area_registry as ar, area_registry as ar,
@ -25,8 +34,11 @@ from . import (
group, group,
label_registry as lr, label_registry as lr,
) )
from .event import async_track_state_change_event
from .typing import ConfigType from .typing import ConfigType
_LOGGER = logging.getLogger(__name__)
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything.""" """Check if ids can match anything."""
@ -238,3 +250,102 @@ def async_extract_referenced_entity_ids(
) )
return selected return selected
class TargetStateChangeTracker:
"""Helper class to manage state change tracking for targets."""
def __init__(
self,
hass: HomeAssistant,
selector_data: TargetSelectorData,
action: Callable[[Event[EventStateChangedData]], Any],
) -> None:
"""Initialize the state change tracker."""
self._hass = hass
self._selector_data = selector_data
self._action = action
self._state_change_unsub: CALLBACK_TYPE | None = None
self._registry_unsubs: list[CALLBACK_TYPE] = []
def async_setup(self) -> Callable[[], None]:
"""Set up the state change tracking."""
self._setup_registry_listeners()
self._track_entities_state_change()
return self._unsubscribe
def _track_entities_state_change(self) -> None:
"""Set up state change tracking for currently selected entities."""
selected = async_extract_referenced_entity_ids(
self._hass, self._selector_data, expand_group=False
)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
self._action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event(
self._hass, tracked_entities, state_change_listener
)
def _setup_registry_listeners(self) -> None:
"""Set up listeners for registry changes that require resubscription."""
@callback
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
"""Resubscribe to state change events when registry changes."""
if self._state_change_unsub:
self._state_change_unsub()
self._track_entities_state_change()
# Subscribe to registry updates that can change the entities to track:
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
# - Device registry: device labels changed; device area changed.
# - Area registry: area floor changed.
#
# We don't track other registries (like floor or label registries) because their
# changes don't affect which entities are tracked.
self._registry_unsubs = [
self._hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event
),
self._hass.bus.async_listen(
dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event
),
self._hass.bus.async_listen(
ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event
),
]
def _unsubscribe(self) -> None:
"""Unsubscribe from all events."""
for registry_unsub in self._registry_unsubs:
registry_unsub()
self._registry_unsubs.clear()
if self._state_change_unsub:
self._state_change_unsub()
self._state_change_unsub = None
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
raise HomeAssistantError(
f"Target selector {target_selector_config} does not have any selectors defined"
)
tracker = TargetStateChangeTracker(hass, selector_data, action)
return tracker.async_setup()

View File

@ -2,9 +2,6 @@
import pytest import pytest
# TODO(abmantis): is this import needed?
# To prevent circular import when running just this file
import homeassistant.components # noqa: F401
from homeassistant.components.group import Group from homeassistant.components.group import Group
from homeassistant.const import ( from homeassistant.const import (
ATTR_AREA_ID, ATTR_AREA_ID,
@ -17,17 +14,21 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
EntityCategory, EntityCategory,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
floor_registry as fr,
label_registry as lr,
target, target,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
MockConfigEntry,
RegistryEntryWithDefaults, RegistryEntryWithDefaults,
mock_area_registry, mock_area_registry,
mock_device_registry, mock_device_registry,
@ -457,3 +458,188 @@ async def test_extract_referenced_entity_ids(
) )
== expected_selected == expected_selected
) )
async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""
@callback
def state_change_callback(event):
"""Handle state change events."""
with pytest.raises(HomeAssistantError) as excinfo:
target.async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert str(excinfo.value) == (
"Target selector {} does not have any selectors defined"
)
async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
events: list[Event[EventStateChangedData]] = []
@callback
def state_change_callback(event: Event[EventStateChangedData]):
"""Handle state change events."""
events.append(event)
last_state = STATE_OFF
async def set_states_and_check_events(
entities_to_set_state: list[str], entities_to_assert_change: list[str]
) -> None:
"""Toggle the state entities and check for events."""
nonlocal last_state
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
for entity_id in entities_to_set_state:
hass.states.async_set(entity_id, last_state)
await hass.async_block_till_done()
assert len(events) == len(entities_to_assert_change)
entities_seen = set()
for event in events:
entities_seen.add(event.data["entity_id"])
assert event.data["new_state"].state == last_state
assert entities_seen == set(entities_to_assert_change)
events.clear()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)
untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)
entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id
untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id
untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id
targeted_entity = "light.test_light"
targeted_entities = [targeted_entity, device_entity]
await set_states_and_check_events(targeted_entities, [])
label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id
selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = target.async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)
# Test directly targeted entity and device
await set_states_and_check_events(targeted_entities, targeted_entities)
# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id
targeted_entities = [targeted_entity, device_entity, device_entity_2]
await set_states_and_check_events(targeted_entities, targeted_entities)
# Test untargeted entity -> should not trigger
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
)
# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
)
# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await set_states_and_check_events(
[*targeted_entities, untargeted_device_entity],
[*targeted_entities, untargeted_device_entity],
)
# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_device_entity], targeted_entities
)
# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity],
[*targeted_entities, untargeted_entity],
)
# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)
# After unsubscribing, changes should not trigger
unsub()
await set_states_and_check_events(targeted_entities, [])