From 4161f53beaaabee7f863b57df87b0ee31f42ba60 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 24 Jul 2023 12:42:29 +0200 Subject: [PATCH] Improve `async_track_state_change_filtered` callback typing (#97134) --- .../components/geo_location/trigger.py | 17 ++++++++++------- homeassistant/components/zone/__init__.py | 9 +++++---- tests/helpers/test_event.py | 14 +++++++------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/geo_location/trigger.py b/homeassistant/components/geo_location/trigger.py index 24632e78454..5527f5ec9f1 100644 --- a/homeassistant/components/geo_location/trigger.py +++ b/homeassistant/components/geo_location/trigger.py @@ -9,7 +9,6 @@ import voluptuous as vol from homeassistant.const import CONF_EVENT, CONF_PLATFORM, CONF_SOURCE, CONF_ZONE from homeassistant.core import ( CALLBACK_TYPE, - Event, HassJob, HomeAssistant, State, @@ -17,9 +16,13 @@ from homeassistant.core import ( ) from homeassistant.helpers import condition, config_validation as cv from homeassistant.helpers.config_validation import entity_domain -from homeassistant.helpers.event import TrackStates, async_track_state_change_filtered +from homeassistant.helpers.event import ( + EventStateChangedData, + TrackStates, + async_track_state_change_filtered, +) from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, EventType from . import DOMAIN @@ -60,11 +63,11 @@ async def async_attach_trigger( job = HassJob(action) @callback - def state_change_listener(event: Event) -> None: + def state_change_listener(event: EventType[EventStateChangedData]) -> None: """Handle specific state changes.""" # Skip if the event's source does not match the trigger's source. - from_state = event.data.get("old_state") - to_state = event.data.get("new_state") + from_state = event.data["old_state"] + to_state = event.data["new_state"] if not source_match(from_state, source) and not source_match(to_state, source): return @@ -96,7 +99,7 @@ async def async_attach_trigger( **trigger_data, "platform": "geo_location", "source": source, - "entity_id": event.data.get("entity_id"), + "entity_id": event.data["entity_id"], "from_state": from_state, "to_state": to_state, "zone": zone_state, diff --git a/homeassistant/components/zone/__init__.py b/homeassistant/components/zone/__init__.py index 77c225d72ec..bfc9c2fce09 100644 --- a/homeassistant/components/zone/__init__.py +++ b/homeassistant/components/zone/__init__.py @@ -11,7 +11,6 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.const import ( ATTR_EDITABLE, - ATTR_ENTITY_ID, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_PERSONS, @@ -378,10 +377,12 @@ class Zone(collection.CollectionEntity): self.async_write_ha_state() @callback - def _person_state_change_listener(self, evt: Event) -> None: - person_entity_id = evt.data[ATTR_ENTITY_ID] + def _person_state_change_listener( + self, evt: EventType[event.EventStateChangedData] + ) -> None: + person_entity_id = evt.data["entity_id"] cur_count = len(self._persons_in_zone) - if self._state_is_in_zone(evt.data.get("new_state")): + if self._state_is_in_zone(evt.data["new_state"]): self._persons_in_zone.add(person_entity_id) elif person_entity_id in self._persons_in_zone: self._persons_in_zone.remove(person_entity_id) diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 3c81977c393..2b77da09778 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -300,21 +300,21 @@ async def test_async_track_state_change_filtered(hass: HomeAssistant) -> None: multiple_entity_id_tracker = [] @ha.callback - def single_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def single_run_callback(event: EventType[EventStateChangedData]) -> None: + old_state = event.data["old_state"] + new_state = event.data["new_state"] single_entity_id_tracker.append((old_state, new_state)) @ha.callback - def multiple_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def multiple_run_callback(event: EventType[EventStateChangedData]) -> None: + old_state = event.data["old_state"] + new_state = event.data["new_state"] multiple_entity_id_tracker.append((old_state, new_state)) @ha.callback - def callback_that_throws(event): + def callback_that_throws(event: EventType[EventStateChangedData]) -> None: raise ValueError track_single = async_track_state_change_filtered(