mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Add method to track entity state changes from target selectors (#148086)
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
8421ca7802
commit
1753baf186
@ -2,9 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import logging
|
||||
from logging import Logger
|
||||
from typing import TypeGuard
|
||||
from typing import Any, TypeGuard
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_AREA_ID,
|
||||
@ -14,7 +16,14 @@ from homeassistant.const import (
|
||||
ATTR_LABEL_ID,
|
||||
ENTITY_MATCH_NONE,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
EventStateChangedData,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from . import (
|
||||
area_registry as ar,
|
||||
@ -25,8 +34,11 @@ from . import (
|
||||
group,
|
||||
label_registry as lr,
|
||||
)
|
||||
from .event import async_track_state_change_event
|
||||
from .typing import ConfigType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
|
||||
"""Check if ids can match anything."""
|
||||
@ -238,3 +250,102 @@ def async_extract_referenced_entity_ids(
|
||||
)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class TargetStateChangeTracker:
|
||||
"""Helper class to manage state change tracking for targets."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
selector_data: TargetSelectorData,
|
||||
action: Callable[[Event[EventStateChangedData]], Any],
|
||||
) -> None:
|
||||
"""Initialize the state change tracker."""
|
||||
self._hass = hass
|
||||
self._selector_data = selector_data
|
||||
self._action = action
|
||||
|
||||
self._state_change_unsub: CALLBACK_TYPE | None = None
|
||||
self._registry_unsubs: list[CALLBACK_TYPE] = []
|
||||
|
||||
def async_setup(self) -> Callable[[], None]:
|
||||
"""Set up the state change tracking."""
|
||||
self._setup_registry_listeners()
|
||||
self._track_entities_state_change()
|
||||
return self._unsubscribe
|
||||
|
||||
def _track_entities_state_change(self) -> None:
|
||||
"""Set up state change tracking for currently selected entities."""
|
||||
selected = async_extract_referenced_entity_ids(
|
||||
self._hass, self._selector_data, expand_group=False
|
||||
)
|
||||
|
||||
@callback
|
||||
def state_change_listener(event: Event[EventStateChangedData]) -> None:
|
||||
"""Handle state change events."""
|
||||
if (
|
||||
event.data["entity_id"] in selected.referenced
|
||||
or event.data["entity_id"] in selected.indirectly_referenced
|
||||
):
|
||||
self._action(event)
|
||||
|
||||
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
|
||||
|
||||
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
|
||||
self._state_change_unsub = async_track_state_change_event(
|
||||
self._hass, tracked_entities, state_change_listener
|
||||
)
|
||||
|
||||
def _setup_registry_listeners(self) -> None:
|
||||
"""Set up listeners for registry changes that require resubscription."""
|
||||
|
||||
@callback
|
||||
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
|
||||
"""Resubscribe to state change events when registry changes."""
|
||||
if self._state_change_unsub:
|
||||
self._state_change_unsub()
|
||||
self._track_entities_state_change()
|
||||
|
||||
# Subscribe to registry updates that can change the entities to track:
|
||||
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
|
||||
# - Device registry: device labels changed; device area changed.
|
||||
# - Area registry: area floor changed.
|
||||
#
|
||||
# We don't track other registries (like floor or label registries) because their
|
||||
# changes don't affect which entities are tracked.
|
||||
self._registry_unsubs = [
|
||||
self._hass.bus.async_listen(
|
||||
er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event
|
||||
),
|
||||
self._hass.bus.async_listen(
|
||||
dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event
|
||||
),
|
||||
self._hass.bus.async_listen(
|
||||
ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event
|
||||
),
|
||||
]
|
||||
|
||||
def _unsubscribe(self) -> None:
|
||||
"""Unsubscribe from all events."""
|
||||
for registry_unsub in self._registry_unsubs:
|
||||
registry_unsub()
|
||||
self._registry_unsubs.clear()
|
||||
if self._state_change_unsub:
|
||||
self._state_change_unsub()
|
||||
self._state_change_unsub = None
|
||||
|
||||
|
||||
def async_track_target_selector_state_change_event(
|
||||
hass: HomeAssistant,
|
||||
target_selector_config: ConfigType,
|
||||
action: Callable[[Event[EventStateChangedData]], Any],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Track state changes for entities referenced directly or indirectly in a target selector."""
|
||||
selector_data = TargetSelectorData(target_selector_config)
|
||||
if not selector_data.has_any_selector:
|
||||
raise HomeAssistantError(
|
||||
f"Target selector {target_selector_config} does not have any selectors defined"
|
||||
)
|
||||
tracker = TargetStateChangeTracker(hass, selector_data, action)
|
||||
return tracker.async_setup()
|
||||
|
@ -2,9 +2,6 @@
|
||||
|
||||
import pytest
|
||||
|
||||
# TODO(abmantis): is this import needed?
|
||||
# To prevent circular import when running just this file
|
||||
import homeassistant.components # noqa: F401
|
||||
from homeassistant.components.group import Group
|
||||
from homeassistant.const import (
|
||||
ATTR_AREA_ID,
|
||||
@ -17,17 +14,21 @@ from homeassistant.const import (
|
||||
STATE_ON,
|
||||
EntityCategory,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
floor_registry as fr,
|
||||
label_registry as lr,
|
||||
target,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
RegistryEntryWithDefaults,
|
||||
mock_area_registry,
|
||||
mock_device_registry,
|
||||
@ -457,3 +458,188 @@ async def test_extract_referenced_entity_ids(
|
||||
)
|
||||
== expected_selected
|
||||
)
|
||||
|
||||
|
||||
async def test_async_track_target_selector_state_change_event_empty_selector(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test async_track_target_selector_state_change_event with empty selector."""
|
||||
|
||||
@callback
|
||||
def state_change_callback(event):
|
||||
"""Handle state change events."""
|
||||
|
||||
with pytest.raises(HomeAssistantError) as excinfo:
|
||||
target.async_track_target_selector_state_change_event(
|
||||
hass, {}, state_change_callback
|
||||
)
|
||||
assert str(excinfo.value) == (
|
||||
"Target selector {} does not have any selectors defined"
|
||||
)
|
||||
|
||||
|
||||
async def test_async_track_target_selector_state_change_event(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test async_track_target_selector_state_change_event with multiple targets."""
|
||||
events: list[Event[EventStateChangedData]] = []
|
||||
|
||||
@callback
|
||||
def state_change_callback(event: Event[EventStateChangedData]):
|
||||
"""Handle state change events."""
|
||||
events.append(event)
|
||||
|
||||
last_state = STATE_OFF
|
||||
|
||||
async def set_states_and_check_events(
|
||||
entities_to_set_state: list[str], entities_to_assert_change: list[str]
|
||||
) -> None:
|
||||
"""Toggle the state entities and check for events."""
|
||||
nonlocal last_state
|
||||
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
|
||||
for entity_id in entities_to_set_state:
|
||||
hass.states.async_set(entity_id, last_state)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == len(entities_to_assert_change)
|
||||
entities_seen = set()
|
||||
for event in events:
|
||||
entities_seen.add(event.data["entity_id"])
|
||||
assert event.data["new_state"].state == last_state
|
||||
assert entities_seen == set(entities_to_assert_change)
|
||||
events.clear()
|
||||
|
||||
config_entry = MockConfigEntry(domain="test")
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
device_reg = dr.async_get(hass)
|
||||
device_entry = device_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("test", "device_1")},
|
||||
)
|
||||
|
||||
untargeted_device_entry = device_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("test", "area_device")},
|
||||
)
|
||||
|
||||
entity_reg = er.async_get(hass)
|
||||
device_entity = entity_reg.async_get_or_create(
|
||||
domain="light",
|
||||
platform="test",
|
||||
unique_id="device_light",
|
||||
device_id=device_entry.id,
|
||||
).entity_id
|
||||
|
||||
untargeted_device_entity = entity_reg.async_get_or_create(
|
||||
domain="light",
|
||||
platform="test",
|
||||
unique_id="area_device_light",
|
||||
device_id=untargeted_device_entry.id,
|
||||
).entity_id
|
||||
|
||||
untargeted_entity = entity_reg.async_get_or_create(
|
||||
domain="light",
|
||||
platform="test",
|
||||
unique_id="untargeted_light",
|
||||
).entity_id
|
||||
|
||||
targeted_entity = "light.test_light"
|
||||
|
||||
targeted_entities = [targeted_entity, device_entity]
|
||||
await set_states_and_check_events(targeted_entities, [])
|
||||
|
||||
label = lr.async_get(hass).async_create("Test Label").name
|
||||
area = ar.async_get(hass).async_create("Test Area").id
|
||||
floor = fr.async_get(hass).async_create("Test Floor").floor_id
|
||||
|
||||
selector_config = {
|
||||
ATTR_ENTITY_ID: targeted_entity,
|
||||
ATTR_DEVICE_ID: device_entry.id,
|
||||
ATTR_AREA_ID: area,
|
||||
ATTR_FLOOR_ID: floor,
|
||||
ATTR_LABEL_ID: label,
|
||||
}
|
||||
unsub = target.async_track_target_selector_state_change_event(
|
||||
hass, selector_config, state_change_callback
|
||||
)
|
||||
|
||||
# Test directly targeted entity and device
|
||||
await set_states_and_check_events(targeted_entities, targeted_entities)
|
||||
|
||||
# Add new entity to the targeted device -> should trigger on state change
|
||||
device_entity_2 = entity_reg.async_get_or_create(
|
||||
domain="light",
|
||||
platform="test",
|
||||
unique_id="device_light_2",
|
||||
device_id=device_entry.id,
|
||||
).entity_id
|
||||
|
||||
targeted_entities = [targeted_entity, device_entity, device_entity_2]
|
||||
await set_states_and_check_events(targeted_entities, targeted_entities)
|
||||
|
||||
# Test untargeted entity -> should not trigger
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], targeted_entities
|
||||
)
|
||||
|
||||
# Add label to untargeted entity -> should trigger now
|
||||
entity_reg.async_update_entity(untargeted_entity, labels={label})
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
|
||||
)
|
||||
|
||||
# Remove label from untargeted entity -> should not trigger anymore
|
||||
entity_reg.async_update_entity(untargeted_entity, labels={})
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], targeted_entities
|
||||
)
|
||||
|
||||
# Add area to untargeted entity -> should trigger now
|
||||
entity_reg.async_update_entity(untargeted_entity, area_id=area)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
|
||||
)
|
||||
|
||||
# Remove area from untargeted entity -> should not trigger anymore
|
||||
entity_reg.async_update_entity(untargeted_entity, area_id=None)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], targeted_entities
|
||||
)
|
||||
|
||||
# Add area to untargeted device -> should trigger on state change
|
||||
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_device_entity],
|
||||
[*targeted_entities, untargeted_device_entity],
|
||||
)
|
||||
|
||||
# Remove area from untargeted device -> should not trigger anymore
|
||||
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_device_entity], targeted_entities
|
||||
)
|
||||
|
||||
# Set the untargeted area on the untargeted entity -> should not trigger
|
||||
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
|
||||
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], targeted_entities
|
||||
)
|
||||
|
||||
# Set targeted floor on the untargeted area -> should trigger now
|
||||
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity],
|
||||
[*targeted_entities, untargeted_entity],
|
||||
)
|
||||
|
||||
# Remove untargeted area from targeted floor -> should not trigger anymore
|
||||
ar.async_get(hass).async_update(untracked_area, floor_id=None)
|
||||
await set_states_and_check_events(
|
||||
[*targeted_entities, untargeted_entity], targeted_entities
|
||||
)
|
||||
|
||||
# After unsubscribing, changes should not trigger
|
||||
unsub()
|
||||
await set_states_and_check_events(targeted_entities, [])
|
||||
|
Loading…
x
Reference in New Issue
Block a user