Add list of targeted entities to target state event

This commit is contained in:
abmantis 2025-07-21 19:33:18 +01:00
parent 941d3c2be4
commit 8203308ab4
2 changed files with 20 additions and 10 deletions

View File

@ -40,6 +40,14 @@ from .typing import ConfigType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@dataclasses.dataclass
class TargetStateChangedData:
"""Data for state change events related to targets."""
state_change_event: Event[EventStateChangedData]
targeted_entity_ids: set[str]
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything.""" """Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE) return ids not in (None, ENTITY_MATCH_NONE)
@ -259,7 +267,7 @@ class TargetStateChangeTracker:
self, self,
hass: HomeAssistant, hass: HomeAssistant,
selector_data: TargetSelectorData, selector_data: TargetSelectorData,
action: Callable[[Event[EventStateChangedData]], Any], action: Callable[[TargetStateChangedData], Any],
) -> None: ) -> None:
"""Initialize the state change tracker.""" """Initialize the state change tracker."""
self._hass = hass self._hass = hass
@ -281,6 +289,8 @@ class TargetStateChangeTracker:
self._hass, self._selector_data, expand_group=False self._hass, self._selector_data, expand_group=False
) )
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
@callback @callback
def state_change_listener(event: Event[EventStateChangedData]) -> None: def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events.""" """Handle state change events."""
@ -288,9 +298,7 @@ class TargetStateChangeTracker:
event.data["entity_id"] in selected.referenced event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced or event.data["entity_id"] in selected.indirectly_referenced
): ):
self._action(event) self._action(TargetStateChangedData(event, tracked_entities))
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event( self._state_change_unsub = async_track_state_change_event(
@ -339,7 +347,7 @@ class TargetStateChangeTracker:
def async_track_target_selector_state_change_event( def async_track_target_selector_state_change_event(
hass: HomeAssistant, hass: HomeAssistant,
target_selector_config: ConfigType, target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any], action: Callable[[TargetStateChangedData], Any],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector.""" """Track state changes for entities referenced directly or indirectly in a target selector."""
selector_data = TargetSelectorData(target_selector_config) selector_data = TargetSelectorData(target_selector_config)

View File

@ -14,7 +14,7 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
EntityCategory, EntityCategory,
) )
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
@ -482,10 +482,10 @@ async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets.""" """Test async_track_target_selector_state_change_event with multiple targets."""
events: list[Event[EventStateChangedData]] = [] events: list[target.TargetStateChangedData] = []
@callback @callback
def state_change_callback(event: Event[EventStateChangedData]): def state_change_callback(event: target.TargetStateChangedData):
"""Handle state change events.""" """Handle state change events."""
events.append(event) events.append(event)
@ -504,8 +504,10 @@ async def test_async_track_target_selector_state_change_event(
assert len(events) == len(entities_to_assert_change) assert len(events) == len(entities_to_assert_change)
entities_seen = set() entities_seen = set()
for event in events: for event in events:
entities_seen.add(event.data["entity_id"]) state_change_event = event.state_change_event
assert event.data["new_state"].state == last_state entities_seen.add(state_change_event.data["entity_id"])
assert state_change_event.data["new_state"].state == last_state
assert event.targeted_entity_ids == set(entities_to_assert_change)
assert entities_seen == set(entities_to_assert_change) assert entities_seen == set(entities_to_assert_change)
events.clear() events.clear()