From 3371c41bda0b20f73534a891bd1cf89f4b65fa02 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 24 Jul 2023 09:42:01 +0200 Subject: [PATCH] Improve `async_track_device_registry_updated_event` callback typing (#97125) --- homeassistant/components/mqtt/mixins.py | 19 +++++++++++++------ homeassistant/helpers/device_registry.py | 22 +++++++++++++++++----- homeassistant/helpers/entity.py | 6 ++++-- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 0a2ee68f7c4..ee7095bb3bc 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -34,7 +34,10 @@ from homeassistant.helpers import ( device_registry as dr, entity_registry as er, ) -from homeassistant.helpers.device_registry import DeviceEntry +from homeassistant.helpers.device_registry import ( + DeviceEntry, + EventDeviceRegistryUpdatedData, +) from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -720,7 +723,9 @@ class MqttDiscoveryDeviceUpdate(ABC): ) return - async def _async_device_removed(self, event: Event) -> None: + async def _async_device_removed( + self, event: EventType[EventDeviceRegistryUpdatedData] + ) -> None: """Handle the manual removal of a device.""" if self._skip_device_removal or not async_removed_from_device( self.hass, event, cast(str, self._device_id), self._config_entry_id @@ -1178,14 +1183,16 @@ def update_device( @callback def async_removed_from_device( - hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str + hass: HomeAssistant, + event: EventType[EventDeviceRegistryUpdatedData], + mqtt_device_id: str, + config_entry_id: str, ) -> bool: """Check if the passed event indicates MQTT was removed from a device.""" - action: str = event.data["action"] - if action not in ("remove", "update"): + if event.data["action"] not in ("remove", "update"): return False - if action == "update": + if event.data["action"] == "update": if "config_entries" not in event.data["changes"]: return False device_registry = dr.async_get(hass) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 45a4459b5d3..f1eed86f10c 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast from urllib.parse import urlparse import attr -from typing_extensions import NotRequired from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback @@ -97,12 +96,25 @@ DEVICE_INFO_TYPES = { DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values())) -class EventDeviceRegistryUpdatedData(TypedDict): - """EventDeviceRegistryUpdated data.""" +class _EventDeviceRegistryUpdatedData_CreateRemove(TypedDict): + """EventDeviceRegistryUpdated data for action type 'create' and 'remove'.""" - action: Literal["create", "remove", "update"] + action: Literal["create", "remove"] device_id: str - changes: NotRequired[dict[str, Any]] + + +class _EventDeviceRegistryUpdatedData_Update(TypedDict): + """EventDeviceRegistryUpdated data for action type 'update'.""" + + action: Literal["update"] + device_id: str + changes: dict[str, Any] + + +EventDeviceRegistryUpdatedData = ( + _EventDeviceRegistryUpdatedData_CreateRemove + | _EventDeviceRegistryUpdatedData_Update +) class DeviceEntryType(StrEnum): diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index a720c1831d7..acb5568ccb0 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -40,7 +40,7 @@ from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util, ensure_unique_string, slugify from . import device_registry as dr, entity_registry as er -from .device_registry import DeviceEntryType +from .device_registry import DeviceEntryType, EventDeviceRegistryUpdatedData from .event import ( async_track_device_registry_updated_event, async_track_entity_registry_updated_event, @@ -1146,7 +1146,9 @@ class Entity(ABC): self._unsub_device_updates = None @callback - def _async_device_registry_updated(self, event: Event) -> None: + def _async_device_registry_updated( + self, event: EventType[EventDeviceRegistryUpdatedData] + ) -> None: """Handle device registry update.""" data = event.data