Improve async_track_state_added/removed_domain callback typing (#97126)

This commit is contained in:
Marc Mueller 2023-07-24 09:11:41 +02:00 committed by GitHub
parent 8c870a5683
commit 797a9c1ead
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 32 deletions

View File

@ -32,7 +32,11 @@ from homeassistant.helpers import (
template, template,
translation, translation,
) )
from homeassistant.helpers.event import async_track_state_added_domain from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_added_domain,
)
from homeassistant.helpers.typing import EventType
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
@ -95,7 +99,7 @@ def async_setup(hass: core.HomeAssistant) -> None:
async_should_expose(hass, DOMAIN, entity_id) async_should_expose(hass, DOMAIN, entity_id)
@core.callback @core.callback
def async_entity_state_listener(event: core.Event) -> None: def async_entity_state_listener(event: EventType[EventStateChangedData]) -> None:
"""Set expose flag on new entities.""" """Set expose flag on new entities."""
async_should_expose(hass, DOMAIN, event.data["entity_id"]) async_should_expose(hass, DOMAIN, event.data["entity_id"])

View File

@ -51,10 +51,11 @@ from homeassistant.helpers.device_registry import (
) )
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_added_domain, async_track_state_added_domain,
async_track_time_interval, async_track_time_interval,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, EventType
from homeassistant.loader import DHCPMatcher, async_get_dhcp from homeassistant.loader import DHCPMatcher, async_get_dhcp
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.network import is_invalid, is_link_local, is_loopback from homeassistant.util.network import is_invalid, is_link_local, is_loopback
@ -317,14 +318,16 @@ class DeviceTrackerWatcher(WatcherBase):
self._async_process_device_state(state) self._async_process_device_state(state)
@callback @callback
def _async_process_device_event(self, event: Event) -> None: def _async_process_device_event(
self, event: EventType[EventStateChangedData]
) -> None:
"""Process a device tracker state change event.""" """Process a device tracker state change event."""
self._async_process_device_state(event.data["new_state"]) self._async_process_device_state(event.data["new_state"])
@callback @callback
def _async_process_device_state(self, state: State) -> None: def _async_process_device_state(self, state: State | None) -> None:
"""Process a device tracker state.""" """Process a device tracker state."""
if state.state != STATE_HOME: if state is None or state.state != STATE_HOME:
return return
attributes = state.attributes attributes = state.attributes

View File

@ -15,13 +15,14 @@ from homeassistant.components import (
script, script,
) )
from homeassistant.const import CONF_ENTITIES, CONF_TYPE from homeassistant.const import CONF_ENTITIES, CONF_TYPE
from homeassistant.core import Event, HomeAssistant, State, callback, split_entity_id from homeassistant.core import HomeAssistant, State, callback, split_entity_id
from homeassistant.helpers import storage from homeassistant.helpers import storage
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_added_domain, async_track_state_added_domain,
async_track_state_removed_domain, async_track_state_removed_domain,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, EventType
SUPPORTED_DOMAINS = { SUPPORTED_DOMAINS = {
climate.DOMAIN, climate.DOMAIN,
@ -222,7 +223,7 @@ class Config:
return states return states
@callback @callback
def _clear_exposed_cache(self, event: Event) -> None: def _clear_exposed_cache(self, event: EventType[EventStateChangedData]) -> None:
"""Clear the cache of exposed states.""" """Clear the cache of exposed states."""
self.get_exposed_states.cache_clear() # pylint: disable=no-member self.get_exposed_states.cache_clear() # pylint: disable=no-member

View File

@ -37,7 +37,7 @@ from homeassistant.helpers import (
service, service,
storage, storage,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, EventType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.location import distance from homeassistant.util.location import distance
@ -155,15 +155,19 @@ def async_setup_track_zone_entity_ids(hass: HomeAssistant) -> None:
hass.data[ZONE_ENTITY_IDS] = zone_entity_ids hass.data[ZONE_ENTITY_IDS] = zone_entity_ids
@callback @callback
def _async_add_zone_entity_id(event_: Event) -> None: def _async_add_zone_entity_id(
event_: EventType[event.EventStateChangedData],
) -> None:
"""Add zone entity ID.""" """Add zone entity ID."""
zone_entity_ids.append(event_.data[ATTR_ENTITY_ID]) zone_entity_ids.append(event_.data["entity_id"])
zone_entity_ids.sort() zone_entity_ids.sort()
@callback @callback
def _async_remove_zone_entity_id(event_: Event) -> None: def _async_remove_zone_entity_id(
event_: EventType[event.EventStateChangedData],
) -> None:
"""Remove zone entity ID.""" """Remove zone entity ID."""
zone_entity_ids.remove(event_.data[ATTR_ENTITY_ID]) zone_entity_ids.remove(event_.data["entity_id"])
event.async_track_state_added_domain(hass, DOMAIN, _async_add_zone_entity_id) event.async_track_state_added_domain(hass, DOMAIN, _async_add_zone_entity_id)
event.async_track_state_removed_domain(hass, DOMAIN, _async_remove_zone_entity_id) event.async_track_state_removed_domain(hass, DOMAIN, _async_remove_zone_entity_id)

View File

@ -544,16 +544,16 @@ async def test_async_track_state_added_domain(hass: HomeAssistant) -> None:
multiple_entity_id_tracker = [] multiple_entity_id_tracker = []
@ha.callback @ha.callback
def single_run_callback(event): def single_run_callback(event: EventType[EventStateChangedData]):
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
single_entity_id_tracker.append((old_state, new_state)) single_entity_id_tracker.append((old_state, new_state))
@ha.callback @ha.callback
def multiple_run_callback(event): def multiple_run_callback(event: EventType[EventStateChangedData]):
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
multiple_entity_id_tracker.append((old_state, new_state)) multiple_entity_id_tracker.append((old_state, new_state))
@ -656,16 +656,16 @@ async def test_async_track_state_removed_domain(hass: HomeAssistant) -> None:
multiple_entity_id_tracker = [] multiple_entity_id_tracker = []
@ha.callback @ha.callback
def single_run_callback(event): def single_run_callback(event: EventType[EventStateChangedData]):
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
single_entity_id_tracker.append((old_state, new_state)) single_entity_id_tracker.append((old_state, new_state))
@ha.callback @ha.callback
def multiple_run_callback(event): def multiple_run_callback(event: EventType[EventStateChangedData]):
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
multiple_entity_id_tracker.append((old_state, new_state)) multiple_entity_id_tracker.append((old_state, new_state))
@ -738,16 +738,16 @@ async def test_async_track_state_removed_domain_match_all(hass: HomeAssistant) -
match_all_entity_id_tracker = [] match_all_entity_id_tracker = []
@ha.callback @ha.callback
def single_run_callback(event): def single_run_callback(event: EventType[EventStateChangedData]):
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
single_entity_id_tracker.append((old_state, new_state)) single_entity_id_tracker.append((old_state, new_state))
@ha.callback @ha.callback
def match_all_run_callback(event): def match_all_run_callback(event: EventType[EventStateChangedData]):
old_state = event.data.get("old_state") old_state = event.data["old_state"]
new_state = event.data.get("new_state") new_state = event.data["new_state"]
match_all_entity_id_tracker.append((old_state, new_state)) match_all_entity_id_tracker.append((old_state, new_state))