Improve async_track_state_change_filtered callback typing (#97134)

This commit is contained in:
Marc Mueller 2023-07-24 12:42:29 +02:00 committed by GitHub
parent 582499a260
commit 4161f53bea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 18 deletions

View File

@ -9,7 +9,6 @@ import voluptuous as vol
from homeassistant.const import CONF_EVENT, CONF_PLATFORM, CONF_SOURCE, CONF_ZONE
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HomeAssistant,
State,
@ -17,9 +16,13 @@ from homeassistant.core import (
)
from homeassistant.helpers import condition, config_validation as cv
from homeassistant.helpers.config_validation import entity_domain
from homeassistant.helpers.event import TrackStates, async_track_state_change_filtered
from homeassistant.helpers.event import (
EventStateChangedData,
TrackStates,
async_track_state_change_filtered,
)
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, EventType
from . import DOMAIN
@ -60,11 +63,11 @@ async def async_attach_trigger(
job = HassJob(action)
@callback
def state_change_listener(event: Event) -> None:
def state_change_listener(event: EventType[EventStateChangedData]) -> None:
"""Handle specific state changes."""
# Skip if the event's source does not match the trigger's source.
from_state = event.data.get("old_state")
to_state = event.data.get("new_state")
from_state = event.data["old_state"]
to_state = event.data["new_state"]
if not source_match(from_state, source) and not source_match(to_state, source):
return
@ -96,7 +99,7 @@ async def async_attach_trigger(
**trigger_data,
"platform": "geo_location",
"source": source,
"entity_id": event.data.get("entity_id"),
"entity_id": event.data["entity_id"],
"from_state": from_state,
"to_state": to_state,
"zone": zone_state,

View File

@ -11,7 +11,6 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import (
ATTR_EDITABLE,
ATTR_ENTITY_ID,
ATTR_LATITUDE,
ATTR_LONGITUDE,
ATTR_PERSONS,
@ -378,10 +377,12 @@ class Zone(collection.CollectionEntity):
self.async_write_ha_state()
@callback
def _person_state_change_listener(self, evt: Event) -> None:
person_entity_id = evt.data[ATTR_ENTITY_ID]
def _person_state_change_listener(
self, evt: EventType[event.EventStateChangedData]
) -> None:
person_entity_id = evt.data["entity_id"]
cur_count = len(self._persons_in_zone)
if self._state_is_in_zone(evt.data.get("new_state")):
if self._state_is_in_zone(evt.data["new_state"]):
self._persons_in_zone.add(person_entity_id)
elif person_entity_id in self._persons_in_zone:
self._persons_in_zone.remove(person_entity_id)

View File

@ -300,21 +300,21 @@ async def test_async_track_state_change_filtered(hass: HomeAssistant) -> None:
multiple_entity_id_tracker = []
@ha.callback
def single_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
def single_run_callback(event: EventType[EventStateChangedData]) -> None:
old_state = event.data["old_state"]
new_state = event.data["new_state"]
single_entity_id_tracker.append((old_state, new_state))
@ha.callback
def multiple_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
def multiple_run_callback(event: EventType[EventStateChangedData]) -> None:
old_state = event.data["old_state"]
new_state = event.data["new_state"]
multiple_entity_id_tracker.append((old_state, new_state))
@ha.callback
def callback_that_throws(event):
def callback_that_throws(event: EventType[EventStateChangedData]) -> None:
raise ValueError
track_single = async_track_state_change_filtered(