Add method to track entity state changes from target selectors

This commit is contained in:
abmantis 2025-07-03 19:17:04 +01:00
parent 7fbf25e862
commit 89a9ab699d
2 changed files with 544 additions and 4 deletions

View File

@ -6,24 +6,33 @@ import abc
import asyncio
from collections import defaultdict
from collections.abc import Callable, Coroutine, Iterable
import dataclasses
from dataclasses import dataclass, field
import functools
import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, TypeGuard, cast
import voluptuous as vol
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ALIAS,
CONF_ENABLED,
CONF_ID,
CONF_PLATFORM,
CONF_VARIABLES,
ENTITY_MATCH_NONE,
)
from homeassistant.core import (
CALLBACK_TYPE,
Context,
Event,
HassJob,
HassJobType,
HomeAssistant,
callback,
is_callback,
@ -40,7 +49,16 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE
from . import config_validation as cv
from . import (
area_registry,
config_validation as cv,
device_registry,
entity_registry,
floor_registry,
label_registry,
)
from .event import EventStateChangedData, async_track_state_change_event
from .group import expand_entity_ids
from .integration_platform import async_process_integration_platforms
from .template import Template
from .typing import ConfigType, TemplateVarsType
@ -617,3 +635,289 @@ async def async_get_all_descriptions(
hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache
return new_descriptions_cache
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places
class TargetSelectorData:
"""Class to hold data of target selector."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, config: ConfigType) -> None:
"""Extract ids from the config."""
entity_ids: str | list | None = config.get(ATTR_ENTITY_ID)
device_ids: str | list | None = config.get(ATTR_DEVICE_ID)
area_ids: str | list | None = config.get(ATTR_AREA_ID)
floor_ids: str | list | None = config.get(ATTR_FLOOR_ID)
label_ids: str | list | None = config.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places
@dataclasses.dataclass(slots=True)
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places
def async_extract_referenced_entity_ids(
hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a target selector."""
selected = SelectedEntities()
if not selector_data.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector_data.entity_ids
if expand_group:
entity_ids = expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector_data.device_ids
and not selector_data.area_ids
and not selector_data.floor_ids
and not selector_data.label_ids
):
return selected
entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
if selector_data.floor_ids:
floor_reg = floor_registry.async_get(hass)
for floor_id in selector_data.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector_data.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector_data.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector_data.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector_data.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector_data.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector_data.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector_data.area_ids)
selected.referenced_devices.update(selector_data.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
_LOGGER.warning(
"Target selector %s does not have any selectors defined",
target_selector_config,
)
return lambda: None
def track_entities_state_change() -> CALLBACK_TYPE:
selected = async_extract_referenced_entity_ids(
hass, selector_data, expand_group=False
)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
return async_track_state_change_event(
hass, tracked_entities, state_change_listener, job_type=job_type
)
unsub_state_change = track_entities_state_change()
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
# TODO(abmantis): Check if there is a better way to do this
nonlocal unsub_state_change
unsub_state_change()
unsub_state_change = track_entities_state_change()
unsub_registry_updates = [
hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
resubscribe_state_change_event,
# TODO(abmantis): filter for entities that match the target selector?
# event_filter=self._filter_entity_registry_changes,
),
hass.bus.async_listen(
device_registry.EVENT_DEVICE_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
hass.bus.async_listen(
area_registry.EVENT_AREA_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
]
def unsub() -> None:
"""Unsubscribe from state change and registry update events."""
for registry_unsub in unsub_registry_updates:
registry_unsub()
unsub_registry_updates.clear()
unsub_state_change()
return unsub

View File

@ -10,6 +10,15 @@ import voluptuous as vol
from homeassistant.components.sun import DOMAIN as DOMAIN_SUN
from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
from homeassistant.components.tag import DOMAIN as DOMAIN_TAG
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
STATE_OFF,
STATE_ON,
)
from homeassistant.core import (
CALLBACK_TYPE,
Context,
@ -18,7 +27,14 @@ from homeassistant.core import (
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import trigger
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
label_registry as lr,
trigger,
)
from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS,
PluggableAction,
@ -27,6 +43,7 @@ from homeassistant.helpers.trigger import (
TriggerInfo,
_async_get_trigger_platform,
async_initialize_triggers,
async_track_target_selector_state_change_event,
async_validate_trigger_config,
)
from homeassistant.helpers.typing import ConfigType
@ -34,7 +51,13 @@ from homeassistant.loader import Integration, async_get_integration
from homeassistant.setup import async_setup_component
from homeassistant.util.yaml.loader import parse_yaml
from tests.common import MockModule, MockPlatform, mock_integration, mock_platform
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_integration,
mock_platform,
)
async def test_bad_trigger_platform(hass: HomeAssistant) -> None:
@ -738,3 +761,216 @@ async def test_invalid_trigger_platform(
await async_setup_component(hass, "test", {})
assert "Integration test does not provide trigger support, skipping" in caplog.text
async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
unsub = async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert "Target selector {} does not have any selectors defined" in caplog.text
# Test that no state changes are tracked
hass.states.async_set("light.test", "on")
await hass.async_block_till_done()
assert len(calls) == 0
unsub()
async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
async def set_state(entity_id, state):
"""Set the state of an entity."""
hass.states.async_set(entity_id, state)
await hass.async_block_till_done()
def assert_entity_calls_and_reset(entity_id: str) -> None:
assert len(calls) == 1
assert calls[0].data["entity_id"] == entity_id
calls.clear()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)
untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)
entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id
untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id
untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id
targeted_entity = "light.test_light"
for entity_id in (targeted_entity, device_entity, untargeted_entity):
hass.states.async_set(entity_id, STATE_OFF)
await hass.async_block_till_done()
label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id
selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)
# Test directly targeted entity and device
await set_state(targeted_entity, STATE_ON)
await set_state(device_entity, STATE_ON)
assert len(calls) == 2
assert calls[0].data["entity_id"] == targeted_entity
assert calls[0].data["old_state"].state == STATE_OFF
assert calls[0].data["new_state"].state == STATE_ON
assert calls[1].data["entity_id"] == device_entity
assert calls[1].data["old_state"].state == STATE_OFF
assert calls[1].data["new_state"].state == STATE_ON
calls.clear()
# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id
await hass.async_block_till_done()
await set_state(device_entity_2, STATE_ON)
assert_entity_calls_and_reset(device_entity_2)
# Test untargeted entity -> should not trigger
await set_state(untargeted_entity, STATE_ON)
assert len(calls) == 0
calls.clear()
# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_OFF)
assert_entity_calls_and_reset(untargeted_entity)
# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
await set_state(untargeted_entity, STATE_OFF)
assert len(calls) == 0
# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
assert_entity_calls_and_reset(untargeted_entity)
# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
await set_state(untargeted_entity, STATE_OFF)
assert len(calls) == 0
# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await hass.async_block_till_done()
await set_state(untargeted_device_entity, STATE_ON)
assert_entity_calls_and_reset(untargeted_device_entity)
# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await hass.async_block_till_done()
await set_state(untargeted_device_entity, STATE_OFF)
await set_state(untargeted_device_entity, STATE_ON)
assert len(calls) == 0
# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
assert len(calls) == 0
# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_OFF)
assert_entity_calls_and_reset(untargeted_entity)
# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
await set_state(untargeted_entity, STATE_OFF)
assert len(calls) == 0
# After unsubscribing, changes should not trigger
unsub()
for entity_id in (targeted_entity, device_entity, untargeted_entity):
await set_state(entity_id, STATE_OFF)
await set_state(entity_id, STATE_ON)
assert len(calls) == 0