From 797a9c1eadb33c3a69d2539222d9dc1996334f2a Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 24 Jul 2023 09:11:41 +0200 Subject: [PATCH] Improve `async_track_state_added/removed_domain` callback typing (#97126) --- .../components/conversation/default_agent.py | 8 +++-- homeassistant/components/dhcp/__init__.py | 11 +++--- .../components/emulated_hue/config.py | 7 ++-- homeassistant/components/zone/__init__.py | 14 +++++--- tests/helpers/test_event.py | 36 +++++++++---------- 5 files changed, 44 insertions(+), 32 deletions(-) diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 336d6287f18..b0a3702b5c9 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -32,7 +32,11 @@ from homeassistant.helpers import ( template, translation, ) -from homeassistant.helpers.event import async_track_state_added_domain +from homeassistant.helpers.event import ( + EventStateChangedData, + async_track_state_added_domain, +) +from homeassistant.helpers.typing import EventType from homeassistant.util.json import JsonObjectType, json_loads_object from .agent import AbstractConversationAgent, ConversationInput, ConversationResult @@ -95,7 +99,7 @@ def async_setup(hass: core.HomeAssistant) -> None: async_should_expose(hass, DOMAIN, entity_id) @core.callback - def async_entity_state_listener(event: core.Event) -> None: + def async_entity_state_listener(event: EventType[EventStateChangedData]) -> None: """Set expose flag on new entities.""" async_should_expose(hass, DOMAIN, event.data["entity_id"]) diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index 9f9ec48f347..b3cfd1b65f2 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -51,10 +51,11 @@ from homeassistant.helpers.device_registry import ( ) from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.event import ( + EventStateChangedData, async_track_state_added_domain, async_track_time_interval, ) -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, EventType from homeassistant.loader import DHCPMatcher, async_get_dhcp from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.network import is_invalid, is_link_local, is_loopback @@ -317,14 +318,16 @@ class DeviceTrackerWatcher(WatcherBase): self._async_process_device_state(state) @callback - def _async_process_device_event(self, event: Event) -> None: + def _async_process_device_event( + self, event: EventType[EventStateChangedData] + ) -> None: """Process a device tracker state change event.""" self._async_process_device_state(event.data["new_state"]) @callback - def _async_process_device_state(self, state: State) -> None: + def _async_process_device_state(self, state: State | None) -> None: """Process a device tracker state.""" - if state.state != STATE_HOME: + if state is None or state.state != STATE_HOME: return attributes = state.attributes diff --git a/homeassistant/components/emulated_hue/config.py b/homeassistant/components/emulated_hue/config.py index 1de6ec98520..104e05605cb 100644 --- a/homeassistant/components/emulated_hue/config.py +++ b/homeassistant/components/emulated_hue/config.py @@ -15,13 +15,14 @@ from homeassistant.components import ( script, ) from homeassistant.const import CONF_ENTITIES, CONF_TYPE -from homeassistant.core import Event, HomeAssistant, State, callback, split_entity_id +from homeassistant.core import HomeAssistant, State, callback, split_entity_id from homeassistant.helpers import storage from homeassistant.helpers.event import ( + EventStateChangedData, async_track_state_added_domain, async_track_state_removed_domain, ) -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, EventType SUPPORTED_DOMAINS = { climate.DOMAIN, @@ -222,7 +223,7 @@ class Config: return states @callback - def _clear_exposed_cache(self, event: Event) -> None: + def _clear_exposed_cache(self, event: EventType[EventStateChangedData]) -> None: """Clear the cache of exposed states.""" self.get_exposed_states.cache_clear() # pylint: disable=no-member diff --git a/homeassistant/components/zone/__init__.py b/homeassistant/components/zone/__init__.py index 8d04987d4fa..77c225d72ec 100644 --- a/homeassistant/components/zone/__init__.py +++ b/homeassistant/components/zone/__init__.py @@ -37,7 +37,7 @@ from homeassistant.helpers import ( service, storage, ) -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, EventType from homeassistant.loader import bind_hass from homeassistant.util.location import distance @@ -155,15 +155,19 @@ def async_setup_track_zone_entity_ids(hass: HomeAssistant) -> None: hass.data[ZONE_ENTITY_IDS] = zone_entity_ids @callback - def _async_add_zone_entity_id(event_: Event) -> None: + def _async_add_zone_entity_id( + event_: EventType[event.EventStateChangedData], + ) -> None: """Add zone entity ID.""" - zone_entity_ids.append(event_.data[ATTR_ENTITY_ID]) + zone_entity_ids.append(event_.data["entity_id"]) zone_entity_ids.sort() @callback - def _async_remove_zone_entity_id(event_: Event) -> None: + def _async_remove_zone_entity_id( + event_: EventType[event.EventStateChangedData], + ) -> None: """Remove zone entity ID.""" - zone_entity_ids.remove(event_.data[ATTR_ENTITY_ID]) + zone_entity_ids.remove(event_.data["entity_id"]) event.async_track_state_added_domain(hass, DOMAIN, _async_add_zone_entity_id) event.async_track_state_removed_domain(hass, DOMAIN, _async_remove_zone_entity_id) diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 434957dc131..ee33e20173c 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -544,16 +544,16 @@ async def test_async_track_state_added_domain(hass: HomeAssistant) -> None: multiple_entity_id_tracker = [] @ha.callback - def single_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def single_run_callback(event: EventType[EventStateChangedData]): + old_state = event.data["old_state"] + new_state = event.data["new_state"] single_entity_id_tracker.append((old_state, new_state)) @ha.callback - def multiple_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def multiple_run_callback(event: EventType[EventStateChangedData]): + old_state = event.data["old_state"] + new_state = event.data["new_state"] multiple_entity_id_tracker.append((old_state, new_state)) @@ -656,16 +656,16 @@ async def test_async_track_state_removed_domain(hass: HomeAssistant) -> None: multiple_entity_id_tracker = [] @ha.callback - def single_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def single_run_callback(event: EventType[EventStateChangedData]): + old_state = event.data["old_state"] + new_state = event.data["new_state"] single_entity_id_tracker.append((old_state, new_state)) @ha.callback - def multiple_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def multiple_run_callback(event: EventType[EventStateChangedData]): + old_state = event.data["old_state"] + new_state = event.data["new_state"] multiple_entity_id_tracker.append((old_state, new_state)) @@ -738,16 +738,16 @@ async def test_async_track_state_removed_domain_match_all(hass: HomeAssistant) - match_all_entity_id_tracker = [] @ha.callback - def single_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def single_run_callback(event: EventType[EventStateChangedData]): + old_state = event.data["old_state"] + new_state = event.data["new_state"] single_entity_id_tracker.append((old_state, new_state)) @ha.callback - def match_all_run_callback(event): - old_state = event.data.get("old_state") - new_state = event.data.get("new_state") + def match_all_run_callback(event: EventType[EventStateChangedData]): + old_state = event.data["old_state"] + new_state = event.data["new_state"] match_all_entity_id_tracker.append((old_state, new_state))