Improve async_track_state_change_filtered callback typing (#97134)

This commit is contained in:
Marc Mueller 2023-07-24 12:42:29 +02:00 committed by GitHub
parent 582499a260
commit 4161f53bea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 18 deletions

View File

@ -9,7 +9,6 @@ import voluptuous as vol
from homeassistant.const import CONF_EVENT, CONF_PLATFORM, CONF_SOURCE, CONF_ZONE from homeassistant.const import CONF_EVENT, CONF_PLATFORM, CONF_SOURCE, CONF_ZONE
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
Event,
HassJob, HassJob,
HomeAssistant, HomeAssistant,
State, State,
@ -17,9 +16,13 @@ from homeassistant.core import (
) )
from homeassistant.helpers import condition, config_validation as cv from homeassistant.helpers import condition, config_validation as cv
from homeassistant.helpers.config_validation import entity_domain from homeassistant.helpers.config_validation import entity_domain
from homeassistant.helpers.event import TrackStates, async_track_state_change_filtered from homeassistant.helpers.event import (
EventStateChangedData,
TrackStates,
async_track_state_change_filtered,
)
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, EventType
from . import DOMAIN from . import DOMAIN
@ -60,11 +63,11 @@ async def async_attach_trigger(
job = HassJob(action) job = HassJob(action)
@callback @callback
def state_change_listener(event: Event) -> None: def state_change_listener(event: EventType[EventStateChangedData]) -> None:
"""Handle specific state changes.""" """Handle specific state changes."""
# Skip if the event's source does not match the trigger's source. # Skip if the event's source does not match the trigger's source.
from_state = event.data.get("old_state") from_state = event.data["old_state"]
to_state = event.data.get("new_state") to_state = event.data["new_state"]
if not source_match(from_state, source) and not source_match(to_state, source): if not source_match(from_state, source) and not source_match(to_state, source):
return return
@ -96,7 +99,7 @@ async def async_attach_trigger(
**trigger_data, **trigger_data,
"platform": "geo_location", "platform": "geo_location",
"source": source, "source": source,
"entity_id": event.data.get("entity_id"), "entity_id": event.data["entity_id"],
"from_state": from_state, "from_state": from_state,
"to_state": to_state, "to_state": to_state,
"zone": zone_state, "zone": zone_state,

View File

@ -11,7 +11,6 @@ import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import ( from homeassistant.const import (
ATTR_EDITABLE, ATTR_EDITABLE,
ATTR_ENTITY_ID,
ATTR_LATITUDE, ATTR_LATITUDE,
ATTR_LONGITUDE, ATTR_LONGITUDE,
ATTR_PERSONS, ATTR_PERSONS,
@ -378,10 +377,12 @@ class Zone(collection.CollectionEntity):
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def _person_state_change_listener(self, evt: Event) -> None: def _person_state_change_listener(
person_entity_id = evt.data[ATTR_ENTITY_ID] self, evt: EventType[event.EventStateChangedData]
) -> None:
person_entity_id = evt.data["entity_id"]
cur_count = len(self._persons_in_zone) cur_count = len(self._persons_in_zone)
if self._state_is_in_zone(evt.data.get("new_state")): if self._state_is_in_zone(evt.data["new_state"]):
self._persons_in_zone.add(person_entity_id) self._persons_in_zone.add(person_entity_id)
elif person_entity_id in self._persons_in_zone: elif person_entity_id in self._persons_in_zone:
self._persons_in_zone.remove(person_entity_id) self._persons_in_zone.remove(person_entity_id)

View File

@ -300,21 +300,21 @@ async def test_async_track_state_change_filtered(hass: HomeAssistant) -> None:
multiple_entity_id_tracker = [] multiple_entity_id_tracker = []
@ha.callback @ha.callback
def single_run_callback(event): def single_run_callback(event: EventType[EventStateChangedData]) -> None:
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
single_entity_id_tracker.append((old_state, new_state)) single_entity_id_tracker.append((old_state, new_state))
@ha.callback @ha.callback
def multiple_run_callback(event): def multiple_run_callback(event: EventType[EventStateChangedData]) -> None:
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
multiple_entity_id_tracker.append((old_state, new_state)) multiple_entity_id_tracker.append((old_state, new_state))
@ha.callback @ha.callback
def callback_that_throws(event): def callback_that_throws(event: EventType[EventStateChangedData]) -> None:
raise ValueError raise ValueError
track_single = async_track_state_change_filtered( track_single = async_track_state_change_filtered(