Move to target.py

This commit is contained in:
abmantis 2025-07-07 17:50:24 +01:00
parent 67a7cf83c2
commit 937671561f
4 changed files with 337 additions and 349 deletions

View File

@ -2,9 +2,11 @@
from __future__ import annotations
from collections.abc import Callable
import dataclasses
import logging
from logging import Logger
from typing import TypeGuard
from typing import Any, TypeGuard
from homeassistant.const import (
ATTR_AREA_ID,
@ -14,7 +16,14 @@ from homeassistant.const import (
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HassJobType,
HomeAssistant,
callback,
)
from . import (
area_registry as ar,
@ -25,8 +34,11 @@ from . import (
group,
label_registry as lr,
)
from .event import async_track_state_change_event
from .typing import ConfigType
_LOGGER = logging.getLogger(__name__)
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
@ -238,3 +250,110 @@ def async_extract_referenced_entity_ids(
)
return selected
class TargetStateChangeTracker:
"""Helper class to manage state change tracking for targets."""
def __init__(
self,
hass: HomeAssistant,
selector_data: TargetSelectorData,
job_type: HassJobType | None,
action: Callable[[Event[EventStateChangedData]], Any],
) -> None:
"""Initialize the state change tracker."""
self._hass = hass
self._selector_data = selector_data
self._job_type = job_type
self._action = action
self._state_change_unsub: CALLBACK_TYPE | None = None
self._registry_unsubs: list[CALLBACK_TYPE] = []
self._setup_registry_listeners()
self._track_entities_state_change()
def _track_entities_state_change(self) -> None:
"""Set up state change tracking for currently selected entities."""
selected = async_extract_referenced_entity_ids(
self._hass, self._selector_data, expand_group=False
)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
self._action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event(
self._hass, tracked_entities, state_change_listener, job_type=self._job_type
)
def _setup_registry_listeners(self) -> None:
"""Set up listeners for registry changes that require resubscription."""
@callback
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
"""Resubscribe to state change events when registry changes."""
if self._state_change_unsub:
self._state_change_unsub()
self._track_entities_state_change()
# Subscribe to registry updates that can change the entities to track:
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
# - Device registry: device labels changed; device area changed.
# - Area registry: area floor changed.
#
# We don't track other registries (like floor or label registries) because their
# changes don't affect which entities are tracked.
self._registry_unsubs = [
self._hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
resubscribe_state_change_event,
# TODO(abmantis): filter for entities that match the target selector?
# event_filter=self._filter_entity_registry_changes,
),
self._hass.bus.async_listen(
dr.EVENT_DEVICE_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
self._hass.bus.async_listen(
ar.EVENT_AREA_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
]
def unsub(self) -> None:
"""Unsubscribe from all events."""
for registry_unsub in self._registry_unsubs:
registry_unsub()
self._registry_unsubs.clear()
if self._state_change_unsub:
self._state_change_unsub()
self._state_change_unsub = None
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
_LOGGER.warning(
"Target selector %s does not have any selectors defined",
target_selector_config,
)
return lambda: None
tracker = TargetStateChangeTracker(hass, selector_data, job_type, action)
return tracker.unsub

View File

@ -23,9 +23,7 @@ from homeassistant.const import (
from homeassistant.core import (
CALLBACK_TYPE,
Context,
Event,
HassJob,
HassJobType,
HomeAssistant,
callback,
is_callback,
@ -42,10 +40,8 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE
from . import area_registry, config_validation as cv, device_registry, entity_registry
from .event import EventStateChangedData, async_track_state_change_event
from . import config_validation as cv
from .integration_platform import async_process_integration_platforms
from .target import TargetSelectorData, async_extract_referenced_entity_ids
from .template import Template
from .typing import ConfigType, TemplateVarsType
@ -625,110 +621,3 @@ async def async_get_all_descriptions(
hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache
return new_descriptions_cache
class TargetStateChangeTracker:
"""Helper class to manage state change tracking for targets."""
def __init__(
self,
hass: HomeAssistant,
selector_data: TargetSelectorData,
job_type: HassJobType | None,
action: Callable[[Event[EventStateChangedData]], Any],
) -> None:
"""Initialize the state change tracker."""
self._hass = hass
self._selector_data = selector_data
self._job_type = job_type
self._action = action
self._state_change_unsub: CALLBACK_TYPE | None = None
self._registry_unsubs: list[CALLBACK_TYPE] = []
self._setup_registry_listeners()
self._track_entities_state_change()
def _track_entities_state_change(self) -> None:
"""Set up state change tracking for currently selected entities."""
selected = async_extract_referenced_entity_ids(
self._hass, self._selector_data, expand_group=False
)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
self._action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event(
self._hass, tracked_entities, state_change_listener, job_type=self._job_type
)
def _setup_registry_listeners(self) -> None:
"""Set up listeners for registry changes that require resubscription."""
@callback
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
"""Resubscribe to state change events when registry changes."""
if self._state_change_unsub:
self._state_change_unsub()
self._track_entities_state_change()
# Subscribe to registry updates that can change the entities to track:
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
# - Device registry: device labels changed; device area changed.
# - Area registry: area floor changed.
#
# We don't track other registries (like floor or label registries) because their
# changes don't affect which entities are tracked.
self._registry_unsubs = [
self._hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
resubscribe_state_change_event,
# TODO(abmantis): filter for entities that match the target selector?
# event_filter=self._filter_entity_registry_changes,
),
self._hass.bus.async_listen(
device_registry.EVENT_DEVICE_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
self._hass.bus.async_listen(
area_registry.EVENT_AREA_REGISTRY_UPDATED,
resubscribe_state_change_event,
),
]
def unsub(self) -> None:
"""Unsubscribe from all events."""
for registry_unsub in self._registry_unsubs:
registry_unsub()
self._registry_unsubs.clear()
if self._state_change_unsub:
self._state_change_unsub()
self._state_change_unsub = None
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
_LOGGER.warning(
"Target selector %s does not have any selectors defined",
target_selector_config,
)
return lambda: None
tracker = TargetStateChangeTracker(hass, selector_data, job_type, action)
return tracker.unsub

View File

@ -17,17 +17,20 @@ from homeassistant.const import (
STATE_ON,
EntityCategory,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
label_registry as lr,
target,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
RegistryEntryWithDefaults,
mock_area_registry,
mock_device_registry,
@ -457,3 +460,212 @@ async def test_extract_referenced_entity_ids(
)
== expected_selected
)
async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
unsub = target.async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert "Target selector {} does not have any selectors defined" in caplog.text
# Test that no state changes are tracked
hass.states.async_set("light.test", "on")
await hass.async_block_till_done()
assert len(calls) == 0
unsub()
async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
# List of entities to toggle state during the test. This list should be insert-only
# so that all entities are changed every time.
entities_to_set_state = []
# List of entities that should assert a state change when toggled. Contrary to
# entities_to_set_state, entities should be added and removed.
entities_to_assert_change = []
last_state = STATE_OFF
async def toggle_states():
"""Toggle the state of all the entities in test."""
nonlocal last_state
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
for entity_id in entities_to_set_state:
hass.states.async_set(entity_id, last_state)
await hass.async_block_till_done()
def assert_entity_calls_and_reset() -> None:
assert len(calls) == len(entities_to_assert_change)
for change_call in calls:
assert change_call.data["entity_id"] in entities_to_assert_change
assert change_call.data["new_state"].state == last_state
calls.clear()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)
untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)
entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id
untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id
untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id
targeted_entity = "light.test_light"
entities_to_set_state.extend([targeted_entity, device_entity, untargeted_entity])
await toggle_states()
label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id
selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = target.async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)
# Test directly targeted entity and device
entities_to_assert_change.extend([targeted_entity, device_entity])
await toggle_states()
assert_entity_calls_and_reset()
# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id
await hass.async_block_till_done()
entities_to_set_state.append(device_entity_2)
entities_to_assert_change.append(device_entity_2)
await toggle_states()
assert_entity_calls_and_reset()
# Test untargeted entity -> should not trigger
entities_to_set_state.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await hass.async_block_till_done()
entities_to_assert_change.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await hass.async_block_till_done()
entities_to_assert_change.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await hass.async_block_till_done()
entities_to_set_state.append(untargeted_device_entity)
entities_to_assert_change.append(untargeted_device_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_device_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await hass.async_block_till_done()
await toggle_states()
assert_entity_calls_and_reset()
# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await hass.async_block_till_done()
entities_to_assert_change.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# After unsubscribing, changes should not trigger
unsub()
await toggle_states()
assert len(calls) == 0

View File

@ -10,15 +10,6 @@ import voluptuous as vol
from homeassistant.components.sun import DOMAIN as DOMAIN_SUN
from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH
from homeassistant.components.tag import DOMAIN as DOMAIN_TAG
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
STATE_OFF,
STATE_ON,
)
from homeassistant.core import (
CALLBACK_TYPE,
Context,
@ -27,14 +18,7 @@ from homeassistant.core import (
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
label_registry as lr,
trigger,
)
from homeassistant.helpers import trigger
from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS,
PluggableAction,
@ -43,7 +27,6 @@ from homeassistant.helpers.trigger import (
TriggerInfo,
_async_get_trigger_platform,
async_initialize_triggers,
async_track_target_selector_state_change_event,
async_validate_trigger_config,
)
from homeassistant.helpers.typing import ConfigType
@ -51,13 +34,7 @@ from homeassistant.loader import Integration, async_get_integration
from homeassistant.setup import async_setup_component
from homeassistant.util.yaml.loader import parse_yaml
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_integration,
mock_platform,
)
from tests.common import MockModule, MockPlatform, mock_integration, mock_platform
async def test_bad_trigger_platform(hass: HomeAssistant) -> None:
@ -803,212 +780,3 @@ async def test_invalid_trigger_platform(
await async_setup_component(hass, "test", {})
assert "Integration test does not provide trigger support, skipping" in caplog.text
async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
unsub = async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert "Target selector {} does not have any selectors defined" in caplog.text
# Test that no state changes are tracked
hass.states.async_set("light.test", "on")
await hass.async_block_till_done()
assert len(calls) == 0
unsub()
async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
calls = []
@callback
def state_change_callback(event):
"""Handle state change events."""
calls.append(event)
# List of entities to toggle state during the test. This list should be insert-only
# so that all entities are changed every time.
entities_to_set_state = []
# List of entities that should assert a state change when toggled. Contrary to
# entities_to_set_state, entities should be added and removed.
entities_to_assert_change = []
last_state = STATE_OFF
async def toggle_states():
"""Toggle the state of all the entities in test."""
nonlocal last_state
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
for entity_id in entities_to_set_state:
hass.states.async_set(entity_id, last_state)
await hass.async_block_till_done()
def assert_entity_calls_and_reset() -> None:
assert len(calls) == len(entities_to_assert_change)
for change_call in calls:
assert change_call.data["entity_id"] in entities_to_assert_change
assert change_call.data["new_state"].state == last_state
calls.clear()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)
untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)
entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id
untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id
untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id
targeted_entity = "light.test_light"
entities_to_set_state.extend([targeted_entity, device_entity, untargeted_entity])
await toggle_states()
label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id
selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)
# Test directly targeted entity and device
entities_to_assert_change.extend([targeted_entity, device_entity])
await toggle_states()
assert_entity_calls_and_reset()
# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id
await hass.async_block_till_done()
entities_to_set_state.append(device_entity_2)
entities_to_assert_change.append(device_entity_2)
await toggle_states()
assert_entity_calls_and_reset()
# Test untargeted entity -> should not trigger
entities_to_set_state.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await hass.async_block_till_done()
entities_to_assert_change.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await hass.async_block_till_done()
entities_to_assert_change.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await hass.async_block_till_done()
entities_to_set_state.append(untargeted_device_entity)
entities_to_assert_change.append(untargeted_device_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_device_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await hass.async_block_till_done()
await toggle_states()
assert_entity_calls_and_reset()
# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await hass.async_block_till_done()
entities_to_assert_change.append(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await hass.async_block_till_done()
entities_to_assert_change.remove(untargeted_entity)
await toggle_states()
assert_entity_calls_and_reset()
# After unsubscribing, changes should not trigger
unsub()
await toggle_states()
assert len(calls) == 0