Add list of targeted entities to target state event

This commit is contained in:
abmantis 2025-07-21 19:33:18 +01:00
parent 941d3c2be4
commit 8203308ab4
2 changed files with 20 additions and 10 deletions

View File

@ -40,6 +40,14 @@ from .typing import ConfigType
_LOGGER = logging.getLogger(__name__)
@dataclasses.dataclass
class TargetStateChangedData:
"""Data for state change events related to targets."""
state_change_event: Event[EventStateChangedData]
targeted_entity_ids: set[str]
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
@ -259,7 +267,7 @@ class TargetStateChangeTracker:
self,
hass: HomeAssistant,
selector_data: TargetSelectorData,
action: Callable[[Event[EventStateChangedData]], Any],
action: Callable[[TargetStateChangedData], Any],
) -> None:
"""Initialize the state change tracker."""
self._hass = hass
@ -281,6 +289,8 @@ class TargetStateChangeTracker:
self._hass, self._selector_data, expand_group=False
)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
@ -288,9 +298,7 @@ class TargetStateChangeTracker:
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
self._action(event)
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
self._action(TargetStateChangedData(event, tracked_entities))
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event(
@ -339,7 +347,7 @@ class TargetStateChangeTracker:
def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
action: Callable[[TargetStateChangedData], Any],
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector."""
selector_data = TargetSelectorData(target_selector_config)

View File

@ -14,7 +14,7 @@ from homeassistant.const import (
STATE_ON,
EntityCategory,
)
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
@ -482,10 +482,10 @@ async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
events: list[Event[EventStateChangedData]] = []
events: list[target.TargetStateChangedData] = []
@callback
def state_change_callback(event: Event[EventStateChangedData]):
def state_change_callback(event: target.TargetStateChangedData):
"""Handle state change events."""
events.append(event)
@ -504,8 +504,10 @@ async def test_async_track_target_selector_state_change_event(
assert len(events) == len(entities_to_assert_change)
entities_seen = set()
for event in events:
entities_seen.add(event.data["entity_id"])
assert event.data["new_state"].state == last_state
state_change_event = event.state_change_event
entities_seen.add(state_change_event.data["entity_id"])
assert state_change_event.data["new_state"].state == last_state
assert event.targeted_entity_ids == set(entities_to_assert_change)
assert entities_seen == set(entities_to_assert_change)
events.clear()