diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index 239d1e66336..3ef78ee7f5e 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -40,6 +40,14 @@ from .typing import ConfigType _LOGGER = logging.getLogger(__name__) +@dataclasses.dataclass +class TargetStateChangedData: + """Data for state change events related to targets.""" + + state_change_event: Event[EventStateChangedData] + targeted_entity_ids: set[str] + + def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: """Check if ids can match anything.""" return ids not in (None, ENTITY_MATCH_NONE) @@ -259,7 +267,7 @@ class TargetStateChangeTracker: self, hass: HomeAssistant, selector_data: TargetSelectorData, - action: Callable[[Event[EventStateChangedData]], Any], + action: Callable[[TargetStateChangedData], Any], ) -> None: """Initialize the state change tracker.""" self._hass = hass @@ -281,6 +289,8 @@ class TargetStateChangeTracker: self._hass, self._selector_data, expand_group=False ) + tracked_entities = selected.referenced.union(selected.indirectly_referenced) + @callback def state_change_listener(event: Event[EventStateChangedData]) -> None: """Handle state change events.""" @@ -288,9 +298,7 @@ class TargetStateChangeTracker: event.data["entity_id"] in selected.referenced or event.data["entity_id"] in selected.indirectly_referenced ): - self._action(event) - - tracked_entities = selected.referenced.union(selected.indirectly_referenced) + self._action(TargetStateChangedData(event, tracked_entities)) _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) self._state_change_unsub = async_track_state_change_event( @@ -339,7 +347,7 @@ class TargetStateChangeTracker: def async_track_target_selector_state_change_event( hass: HomeAssistant, target_selector_config: ConfigType, - action: Callable[[Event[EventStateChangedData]], Any], + action: Callable[[TargetStateChangedData], Any], ) -> CALLBACK_TYPE: """Track state changes for entities referenced directly or indirectly in a target selector.""" selector_data = TargetSelectorData(target_selector_config) diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index c87a320e378..fa31ef375fd 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -14,7 +14,7 @@ from homeassistant.const import ( STATE_ON, EntityCategory, ) -from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( area_registry as ar, @@ -482,10 +482,10 @@ async def test_async_track_target_selector_state_change_event( hass: HomeAssistant, ) -> None: """Test async_track_target_selector_state_change_event with multiple targets.""" - events: list[Event[EventStateChangedData]] = [] + events: list[target.TargetStateChangedData] = [] @callback - def state_change_callback(event: Event[EventStateChangedData]): + def state_change_callback(event: target.TargetStateChangedData): """Handle state change events.""" events.append(event) @@ -504,8 +504,10 @@ async def test_async_track_target_selector_state_change_event( assert len(events) == len(entities_to_assert_change) entities_seen = set() for event in events: - entities_seen.add(event.data["entity_id"]) - assert event.data["new_state"].state == last_state + state_change_event = event.state_change_event + entities_seen.add(state_change_event.data["entity_id"]) + assert state_change_event.data["new_state"].state == last_state + assert event.targeted_entity_ids == set(entities_to_assert_change) assert entities_seen == set(entities_to_assert_change) events.clear()