Add method to track entity state changes from target selectors

This commit is contained in:
abmantis 2025-07-03 19:17:04 +01:00
parent 7fbf25e862
commit 89a9ab699d
2 changed files with 544 additions and 4 deletions

View File

@ -6,24 +6,33 @@ import abc
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
import dataclasses
from dataclasses import dataclass, field from dataclasses import dataclass, field
import functools import functools
import logging import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast from typing import TYPE_CHECKING, Any, Protocol, TypedDict, TypeGuard, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ALIAS, CONF_ALIAS,
CONF_ENABLED, CONF_ENABLED,
CONF_ID, CONF_ID,
CONF_PLATFORM, CONF_PLATFORM,
CONF_VARIABLES, CONF_VARIABLES,
ENTITY_MATCH_NONE,
) )
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
Context, Context,
Event,
HassJob, HassJob,
HassJobType,
HomeAssistant, HomeAssistant,
callback, callback,
is_callback, is_callback,
@ -40,7 +49,16 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE from homeassistant.util.yaml.loader import JSON_TYPE
from . import config_validation as cv from . import (
area_registry,
config_validation as cv,
device_registry,
entity_registry,
floor_registry,
label_registry,
)
from .event import EventStateChangedData, async_track_state_change_event
from .group import expand_entity_ids
from .integration_platform import async_process_integration_platforms from .integration_platform import async_process_integration_platforms
from .template import Template from .template import Template
from .typing import ConfigType, TemplateVarsType from .typing import ConfigType, TemplateVarsType
@ -617,3 +635,289 @@ async def async_get_all_descriptions(
hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache
return new_descriptions_cache return new_descriptions_cache
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places
class TargetSelectorData:
"""Class to hold data of target selector."""
__slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")
def __init__(self, config: ConfigType) -> None:
"""Extract ids from the config."""
entity_ids: str | list | None = config.get(ATTR_ENTITY_ID)
device_ids: str | list | None = config.get(ATTR_DEVICE_ID)
area_ids: str | list | None = config.get(ATTR_AREA_ID)
floor_ids: str | list | None = config.get(ATTR_FLOOR_ID)
label_ids: str | list | None = config.get(ATTR_LABEL_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)
# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places
@dataclasses.dataclass(slots=True)
class SelectedEntities:
"""Class to hold the selected entities."""
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)
referenced_devices: set[str] = dataclasses.field(default_factory=set)
referenced_areas: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items."""
parts = []
for label, items in (
("floors", self.missing_floors),
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
if not parts:
return
_LOGGER.warning(
"Referenced %s are missing or not currently available",
", ".join(parts),
)
# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places
def async_extract_referenced_entity_ids(
hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a target selector."""
selected = SelectedEntities()
if not selector_data.has_any_selector:
return selected
entity_ids: set[str] | list[str] = selector_data.entity_ids
if expand_group:
entity_ids = expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)
if (
not selector_data.device_ids
and not selector_data.area_ids
and not selector_data.floor_ids
and not selector_data.label_ids
):
return selected
entities = entity_registry.async_get(hass).entities
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
if selector_data.floor_ids:
floor_reg = floor_registry.async_get(hass)
for floor_id in selector_data.floor_ids:
if floor_id not in floor_reg.floors:
selected.missing_floors.add(floor_id)
for area_id in selector_data.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in selector_data.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selector_data.label_ids:
label_reg = label_registry.async_get(hass)
for label_id in selector_data.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)
for entity_entry in entities.get_entries_for_label(label_id):
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
):
selected.indirectly_referenced.add(entity_entry.entity_id)
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)
# Find areas for targeted floors
if selector_data.floor_ids:
selected.referenced_areas.update(
area_entry.id
for floor_id in selector_data.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)
selected.referenced_areas.update(selector_data.area_ids)
selected.referenced_devices.update(selector_data.device_ids)
if not selected.referenced_areas and not selected.referenced_devices:
return selected
# Add indirectly referenced by device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in selected.referenced_devices
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (entry.entity_category is None and entry.hidden_by is None)
)
# Find devices for targeted areas
referenced_devices_by_area: set[str] = set()
if selected.referenced_areas:
for area_id in selected.referenced_areas:
referenced_devices_by_area.update(
device_entry.id
for device_entry in dev_reg.devices.get_devices_for_area_id(area_id)
)
selected.referenced_devices.update(referenced_devices_by_area)
# Add indirectly referenced by area
selected.indirectly_referenced.update(
entry.entity_id
for area_id in selected.referenced_areas
# The entity's area matches a targeted area
for entry in entities.get_entries_for_area_id(area_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if entry.entity_category is None and entry.hidden_by is None
)
# Add indirectly referenced by area through device
selected.indirectly_referenced.update(
entry.entity_id
for device_id in referenced_devices_by_area
for entry in entities.get_entries_for_device_id(device_id)
# Do not add entities which are hidden or which are config
# or diagnostic entities.
if (
entry.entity_category is None
and entry.hidden_by is None
and (
# The entity's device matches a device referenced
# by an area and the entity
# has no explicitly set area
not entry.area_id
)
)
)
return selected
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
_LOGGER.warning(
"Target selector %s does not have any selectors defined",
target_selector_config,
)
return lambda: None
def track_entities_state_change() -> CALLBACK_TYPE:
selected = async_extract_referenced_entity_ids(
hass, selector_data, expand_group=False
)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
return async_track_state_change_event(
hass, tracked_entities, state_change_listener, job_type=job_type
)
unsub_state_change = track_entities_state_change()
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
# TODO(abmantis): Check if there is a better way to do this
nonlocal unsub_state_change
unsub_state_change()
unsub_state_change = track_entities_state_change()
unsub_registry_updates = [
hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
resubscribe_state_change_event,
# TODO(abmantis): filter for entities that match the target selector?
# event_filter=self._filter_entity_registry_changes,
),
hass.bus.async_listen(
device_registry.EVENT_DEVICE_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
hass.bus.async_listen(
area_registry.EVENT_AREA_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
]
def unsub() -> None:
"""Unsubscribe from state change and registry update events."""
for registry_unsub in unsub_registry_updates:
registry_unsub()
unsub_registry_updates.clear()
unsub_state_change()
return unsub

View File

@ -10,6 +10,15 @@ import voluptuous as vol
from homeassistant.components.sun import DOMAIN as DOMAIN_SUN from homeassistant.components.sun import DOMAIN as DOMAIN_SUN
from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
from homeassistant.components.tag import DOMAIN as DOMAIN_TAG from homeassistant.components.tag import DOMAIN as DOMAIN_TAG
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
STATE_OFF,
STATE_ON,
)
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
Context, Context,
@ -18,7 +27,14 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import trigger from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
label_registry as lr,
trigger,
)
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS, DATA_PLUGGABLE_ACTIONS,
PluggableAction, PluggableAction,
@ -27,6 +43,7 @@ from homeassistant.helpers.trigger import (
TriggerInfo, TriggerInfo,
_async_get_trigger_platform, _async_get_trigger_platform,
async_initialize_triggers, async_initialize_triggers,
async_track_target_selector_state_change_event,
async_validate_trigger_config, async_validate_trigger_config,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -34,7 +51,13 @@ from homeassistant.loader import Integration, async_get_integration
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.yaml.loader import parse_yaml from homeassistant.util.yaml.loader import parse_yaml
from tests.common import MockModule, MockPlatform, mock_integration, mock_platform from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_integration,
mock_platform,
)
async def test_bad_trigger_platform(hass: HomeAssistant) -> None: async def test_bad_trigger_platform(hass: HomeAssistant) -> None:
@ -738,3 +761,216 @@ async def test_invalid_trigger_platform(
await async_setup_component(hass, "test", {}) await async_setup_component(hass, "test", {})
assert "Integration test does not provide trigger support, skipping" in caplog.text assert "Integration test does not provide trigger support, skipping" in caplog.text
async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
unsub = async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert "Target selector {} does not have any selectors defined" in caplog.text
# Test that no state changes are tracked
hass.states.async_set("light.test", "on")
await hass.async_block_till_done()
assert len(calls) == 0
unsub()
async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
async def set_state(entity_id, state):
"""Set the state of an entity."""
hass.states.async_set(entity_id, state)
await hass.async_block_till_done()
def assert_entity_calls_and_reset(entity_id: str) -> None:
assert len(calls) == 1
assert calls[0].data["entity_id"] == entity_id
calls.clear()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)
untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)
entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id
untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id
untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id
targeted_entity = "light.test_light"
for entity_id in (targeted_entity, device_entity, untargeted_entity):
hass.states.async_set(entity_id, STATE_OFF)
await hass.async_block_till_done()
label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id
selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)
# Test directly targeted entity and device
await set_state(targeted_entity, STATE_ON)
await set_state(device_entity, STATE_ON)
assert len(calls) == 2
assert calls[0].data["entity_id"] == targeted_entity
assert calls[0].data["old_state"].state == STATE_OFF
assert calls[0].data["new_state"].state == STATE_ON
assert calls[1].data["entity_id"] == device_entity
assert calls[1].data["old_state"].state == STATE_OFF
assert calls[1].data["new_state"].state == STATE_ON
calls.clear()
# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id
await hass.async_block_till_done()
await set_state(device_entity_2, STATE_ON)
assert_entity_calls_and_reset(device_entity_2)
# Test untargeted entity -> should not trigger
await set_state(untargeted_entity, STATE_ON)
assert len(calls) == 0
calls.clear()
# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_OFF)
assert_entity_calls_and_reset(untargeted_entity)
# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
await set_state(untargeted_entity, STATE_OFF)
assert len(calls) == 0
# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
assert_entity_calls_and_reset(untargeted_entity)
# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
await set_state(untargeted_entity, STATE_OFF)
assert len(calls) == 0
# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await hass.async_block_till_done()
await set_state(untargeted_device_entity, STATE_ON)
assert_entity_calls_and_reset(untargeted_device_entity)
# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await hass.async_block_till_done()
await set_state(untargeted_device_entity, STATE_OFF)
await set_state(untargeted_device_entity, STATE_ON)
assert len(calls) == 0
# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
assert len(calls) == 0
# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_OFF)
assert_entity_calls_and_reset(untargeted_entity)
# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await hass.async_block_till_done()
await set_state(untargeted_entity, STATE_ON)
await set_state(untargeted_entity, STATE_OFF)
assert len(calls) == 0
# After unsubscribing, changes should not trigger
unsub()
for entity_id in (targeted_entity, device_entity, untargeted_entity):
await set_state(entity_id, STATE_OFF)
await set_state(entity_id, STATE_ON)
assert len(calls) == 0