diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 66d1560ac70..5f8f45834cd 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -6,24 +6,33 @@ import abc import asyncio from collections import defaultdict from collections.abc import Callable, Coroutine, Iterable +import dataclasses from dataclasses import dataclass, field import functools import logging -from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, TypeGuard, cast import voluptuous as vol from homeassistant.const import ( + ATTR_AREA_ID, + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + ATTR_FLOOR_ID, + ATTR_LABEL_ID, CONF_ALIAS, CONF_ENABLED, CONF_ID, CONF_PLATFORM, CONF_VARIABLES, + ENTITY_MATCH_NONE, ) from homeassistant.core import ( CALLBACK_TYPE, Context, + Event, HassJob, + HassJobType, HomeAssistant, callback, is_callback, @@ -40,7 +49,16 @@ from homeassistant.util.hass_dict import HassKey from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml.loader import JSON_TYPE -from . import config_validation as cv +from . import ( + area_registry, + config_validation as cv, + device_registry, + entity_registry, + floor_registry, + label_registry, +) +from .event import EventStateChangedData, async_track_state_change_event +from .group import expand_entity_ids from .integration_platform import async_process_integration_platforms from .template import Template from .typing import ConfigType, TemplateVarsType @@ -617,3 +635,289 @@ async def async_get_all_descriptions( hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache return new_descriptions_cache + + +def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: + """Check if ids can match anything.""" + return ids not in (None, ENTITY_MATCH_NONE) + + +# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places +class TargetSelectorData: + """Class to hold data of target selector.""" + + __slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids") + + def __init__(self, config: ConfigType) -> None: + """Extract ids from the config.""" + entity_ids: str | list | None = config.get(ATTR_ENTITY_ID) + device_ids: str | list | None = config.get(ATTR_DEVICE_ID) + area_ids: str | list | None = config.get(ATTR_AREA_ID) + floor_ids: str | list | None = config.get(ATTR_FLOOR_ID) + label_ids: str | list | None = config.get(ATTR_LABEL_ID) + + self.entity_ids = ( + set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() + ) + self.device_ids = ( + set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() + ) + self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() + self.floor_ids = ( + set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() + ) + self.label_ids = ( + set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set() + ) + + @property + def has_any_selector(self) -> bool: + """Determine if any selectors are present.""" + return bool( + self.entity_ids + or self.device_ids + or self.area_ids + or self.floor_ids + or self.label_ids + ) + + +# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places +@dataclasses.dataclass(slots=True) +class SelectedEntities: + """Class to hold the selected entities.""" + + # Entities that were explicitly mentioned. + referenced: set[str] = dataclasses.field(default_factory=set) + + # Entities that were referenced via device/area/floor/label ID. + # Should not trigger a warning when they don't exist. + indirectly_referenced: set[str] = dataclasses.field(default_factory=set) + + # Referenced items that could not be found. + missing_devices: set[str] = dataclasses.field(default_factory=set) + missing_areas: set[str] = dataclasses.field(default_factory=set) + missing_floors: set[str] = dataclasses.field(default_factory=set) + missing_labels: set[str] = dataclasses.field(default_factory=set) + + referenced_devices: set[str] = dataclasses.field(default_factory=set) + referenced_areas: set[str] = dataclasses.field(default_factory=set) + + def log_missing(self, missing_entities: set[str]) -> None: + """Log about missing items.""" + parts = [] + for label, items in ( + ("floors", self.missing_floors), + ("areas", self.missing_areas), + ("devices", self.missing_devices), + ("entities", missing_entities), + ("labels", self.missing_labels), + ): + if items: + parts.append(f"{label} {', '.join(sorted(items))}") + + if not parts: + return + + _LOGGER.warning( + "Referenced %s are missing or not currently available", + ", ".join(parts), + ) + + +# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places +def async_extract_referenced_entity_ids( + hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True +) -> SelectedEntities: + """Extract referenced entity IDs from a target selector.""" + selected = SelectedEntities() + + if not selector_data.has_any_selector: + return selected + + entity_ids: set[str] | list[str] = selector_data.entity_ids + if expand_group: + entity_ids = expand_entity_ids(hass, entity_ids) + + selected.referenced.update(entity_ids) + + if ( + not selector_data.device_ids + and not selector_data.area_ids + and not selector_data.floor_ids + and not selector_data.label_ids + ): + return selected + + entities = entity_registry.async_get(hass).entities + dev_reg = device_registry.async_get(hass) + area_reg = area_registry.async_get(hass) + + if selector_data.floor_ids: + floor_reg = floor_registry.async_get(hass) + for floor_id in selector_data.floor_ids: + if floor_id not in floor_reg.floors: + selected.missing_floors.add(floor_id) + + for area_id in selector_data.area_ids: + if area_id not in area_reg.areas: + selected.missing_areas.add(area_id) + + for device_id in selector_data.device_ids: + if device_id not in dev_reg.devices: + selected.missing_devices.add(device_id) + + if selector_data.label_ids: + label_reg = label_registry.async_get(hass) + for label_id in selector_data.label_ids: + if label_id not in label_reg.labels: + selected.missing_labels.add(label_id) + + for entity_entry in entities.get_entries_for_label(label_id): + if ( + entity_entry.entity_category is None + and entity_entry.hidden_by is None + ): + selected.indirectly_referenced.add(entity_entry.entity_id) + + for device_entry in dev_reg.devices.get_devices_for_label(label_id): + selected.referenced_devices.add(device_entry.id) + + for area_entry in area_reg.areas.get_areas_for_label(label_id): + selected.referenced_areas.add(area_entry.id) + + # Find areas for targeted floors + if selector_data.floor_ids: + selected.referenced_areas.update( + area_entry.id + for floor_id in selector_data.floor_ids + for area_entry in area_reg.areas.get_areas_for_floor(floor_id) + ) + + selected.referenced_areas.update(selector_data.area_ids) + selected.referenced_devices.update(selector_data.device_ids) + + if not selected.referenced_areas and not selected.referenced_devices: + return selected + + # Add indirectly referenced by device + selected.indirectly_referenced.update( + entry.entity_id + for device_id in selected.referenced_devices + for entry in entities.get_entries_for_device_id(device_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if (entry.entity_category is None and entry.hidden_by is None) + ) + + # Find devices for targeted areas + referenced_devices_by_area: set[str] = set() + if selected.referenced_areas: + for area_id in selected.referenced_areas: + referenced_devices_by_area.update( + device_entry.id + for device_entry in dev_reg.devices.get_devices_for_area_id(area_id) + ) + selected.referenced_devices.update(referenced_devices_by_area) + + # Add indirectly referenced by area + selected.indirectly_referenced.update( + entry.entity_id + for area_id in selected.referenced_areas + # The entity's area matches a targeted area + for entry in entities.get_entries_for_area_id(area_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if entry.entity_category is None and entry.hidden_by is None + ) + # Add indirectly referenced by area through device + selected.indirectly_referenced.update( + entry.entity_id + for device_id in referenced_devices_by_area + for entry in entities.get_entries_for_device_id(device_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if ( + entry.entity_category is None + and entry.hidden_by is None + and ( + # The entity's device matches a device referenced + # by an area and the entity + # has no explicitly set area + not entry.area_id + ) + ) + ) + + return selected + + +def async_track_target_selector_state_change_event( + hass: HomeAssistant, + target_selector_config: ConfigType, + action: Callable[[Event[EventStateChangedData]], Any], + job_type: HassJobType | None = None, +) -> CALLBACK_TYPE: + """Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector.""" + selector_data = TargetSelectorData(target_selector_config) + if not selector_data.has_any_selector: + _LOGGER.warning( + "Target selector %s does not have any selectors defined", + target_selector_config, + ) + return lambda: None + + def track_entities_state_change() -> CALLBACK_TYPE: + selected = async_extract_referenced_entity_ids( + hass, selector_data, expand_group=False + ) + + @callback + def state_change_listener(event: Event[EventStateChangedData]) -> None: + """Handle state change events.""" + if ( + event.data["entity_id"] in selected.referenced + or event.data["entity_id"] in selected.indirectly_referenced + ): + action(event) + + tracked_entities = selected.referenced.union(selected.indirectly_referenced) + + _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) + return async_track_state_change_event( + hass, tracked_entities, state_change_listener, job_type=job_type + ) + + unsub_state_change = track_entities_state_change() + + def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: + # TODO(abmantis): Check if there is a better way to do this + nonlocal unsub_state_change + unsub_state_change() + unsub_state_change = track_entities_state_change() + + unsub_registry_updates = [ + hass.bus.async_listen( + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, + resubscribe_state_change_event, + # TODO(abmantis): filter for entities that match the target selector? + # event_filter=self._filter_entity_registry_changes, + ), + hass.bus.async_listen( + device_registry.EVENT_DEVICE_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + hass.bus.async_listen( + area_registry.EVENT_AREA_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + ] + + def unsub() -> None: + """Unsubscribe from state change and registry update events.""" + for registry_unsub in unsub_registry_updates: + registry_unsub() + unsub_registry_updates.clear() + unsub_state_change() + + return unsub diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 27cde92d14f..1f3219c538f 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -10,6 +10,15 @@ import voluptuous as vol from homeassistant.components.sun import DOMAIN as DOMAIN_SUN from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH from homeassistant.components.tag import DOMAIN as DOMAIN_TAG +from homeassistant.const import ( + ATTR_AREA_ID, + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + ATTR_FLOOR_ID, + ATTR_LABEL_ID, + STATE_OFF, + STATE_ON, +) from homeassistant.core import ( CALLBACK_TYPE, Context, @@ -18,7 +27,14 @@ from homeassistant.core import ( callback, ) from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import trigger +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, + floor_registry as fr, + label_registry as lr, + trigger, +) from homeassistant.helpers.trigger import ( DATA_PLUGGABLE_ACTIONS, PluggableAction, @@ -27,6 +43,7 @@ from homeassistant.helpers.trigger import ( TriggerInfo, _async_get_trigger_platform, async_initialize_triggers, + async_track_target_selector_state_change_event, async_validate_trigger_config, ) from homeassistant.helpers.typing import ConfigType @@ -34,7 +51,13 @@ from homeassistant.loader import Integration, async_get_integration from homeassistant.setup import async_setup_component from homeassistant.util.yaml.loader import parse_yaml -from tests.common import MockModule, MockPlatform, mock_integration, mock_platform +from tests.common import ( + MockConfigEntry, + MockModule, + MockPlatform, + mock_integration, + mock_platform, +) async def test_bad_trigger_platform(hass: HomeAssistant) -> None: @@ -738,3 +761,216 @@ async def test_invalid_trigger_platform( await async_setup_component(hass, "test", {}) assert "Integration test does not provide trigger support, skipping" in caplog.text + + +async def test_async_track_target_selector_state_change_event_empty_selector( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test async_track_target_selector_state_change_event with empty selector.""" + calls = [] + + @callback + def state_change_callback(event): + """Handle state change events.""" + calls.append(event) + + unsub = async_track_target_selector_state_change_event( + hass, {}, state_change_callback + ) + + assert "Target selector {} does not have any selectors defined" in caplog.text + + # Test that no state changes are tracked + hass.states.async_set("light.test", "on") + await hass.async_block_till_done() + + assert len(calls) == 0 + + unsub() + + +async def test_async_track_target_selector_state_change_event( + hass: HomeAssistant, +) -> None: + """Test async_track_target_selector_state_change_event with multiple targets.""" + calls = [] + + @callback + def state_change_callback(event): + """Handle state change events.""" + calls.append(event) + + async def set_state(entity_id, state): + """Set the state of an entity.""" + hass.states.async_set(entity_id, state) + await hass.async_block_till_done() + + def assert_entity_calls_and_reset(entity_id: str) -> None: + assert len(calls) == 1 + assert calls[0].data["entity_id"] == entity_id + calls.clear() + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device_reg = dr.async_get(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "device_1")}, + ) + + untargeted_device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "area_device")}, + ) + + entity_reg = er.async_get(hass) + device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light", + device_id=device_entry.id, + ).entity_id + + untargeted_device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="area_device_light", + device_id=untargeted_device_entry.id, + ).entity_id + + untargeted_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="untargeted_light", + ).entity_id + + targeted_entity = "light.test_light" + + for entity_id in (targeted_entity, device_entity, untargeted_entity): + hass.states.async_set(entity_id, STATE_OFF) + await hass.async_block_till_done() + + label = lr.async_get(hass).async_create("Test Label").name + area = ar.async_get(hass).async_create("Test Area").id + floor = fr.async_get(hass).async_create("Test Floor").floor_id + + selector_config = { + ATTR_ENTITY_ID: targeted_entity, + ATTR_DEVICE_ID: device_entry.id, + ATTR_AREA_ID: area, + ATTR_FLOOR_ID: floor, + ATTR_LABEL_ID: label, + } + unsub = async_track_target_selector_state_change_event( + hass, selector_config, state_change_callback + ) + + # Test directly targeted entity and device + await set_state(targeted_entity, STATE_ON) + await set_state(device_entity, STATE_ON) + + assert len(calls) == 2 + assert calls[0].data["entity_id"] == targeted_entity + assert calls[0].data["old_state"].state == STATE_OFF + assert calls[0].data["new_state"].state == STATE_ON + assert calls[1].data["entity_id"] == device_entity + assert calls[1].data["old_state"].state == STATE_OFF + assert calls[1].data["new_state"].state == STATE_ON + calls.clear() + + # Add new entity to the targeted device -> should trigger on state change + device_entity_2 = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light_2", + device_id=device_entry.id, + ).entity_id + await hass.async_block_till_done() + + await set_state(device_entity_2, STATE_ON) + + assert_entity_calls_and_reset(device_entity_2) + + # Test untargeted entity -> should not trigger + await set_state(untargeted_entity, STATE_ON) + + assert len(calls) == 0 + calls.clear() + + # Add label to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, labels={label}) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_OFF) + + assert_entity_calls_and_reset(untargeted_entity) + + # Remove label from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, labels={}) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_ON) + await set_state(untargeted_entity, STATE_OFF) + + assert len(calls) == 0 + + # Add area to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, area_id=area) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_ON) + + assert_entity_calls_and_reset(untargeted_entity) + + # Remove area from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, area_id=None) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_ON) + await set_state(untargeted_entity, STATE_OFF) + + assert len(calls) == 0 + + # Add area to untargeted device -> should trigger on state change + device_reg.async_update_device(untargeted_device_entry.id, area_id=area) + await hass.async_block_till_done() + + await set_state(untargeted_device_entity, STATE_ON) + + assert_entity_calls_and_reset(untargeted_device_entity) + + # Remove area from untargeted device -> should not trigger anymore + device_reg.async_update_device(untargeted_device_entry.id, area_id=None) + await hass.async_block_till_done() + await set_state(untargeted_device_entity, STATE_OFF) + await set_state(untargeted_device_entity, STATE_ON) + + assert len(calls) == 0 + + # Set the untargeted area on the untargeted entity -> should not trigger + untracked_area = ar.async_get(hass).async_create("Untargeted Area").id + entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) + await hass.async_block_till_done() + + await set_state(untargeted_entity, STATE_ON) + assert len(calls) == 0 + + # Set targeted floor on the untargeted area -> should trigger now + ar.async_get(hass).async_update(untracked_area, floor_id=floor) + await hass.async_block_till_done() + + await set_state(untargeted_entity, STATE_OFF) + assert_entity_calls_and_reset(untargeted_entity) + + # Remove untargeted area from targeted floor -> should not trigger anymore + ar.async_get(hass).async_update(untracked_area, floor_id=None) + await hass.async_block_till_done() + + await set_state(untargeted_entity, STATE_ON) + await set_state(untargeted_entity, STATE_OFF) + assert len(calls) == 0 + + # After unsubscribing, changes should not trigger + unsub() + + for entity_id in (targeted_entity, device_entity, untargeted_entity): + await set_state(entity_id, STATE_OFF) + await set_state(entity_id, STATE_ON) + assert len(calls) == 0