Improve async_track_device_registry_updated_event callback typing (#97125)

This commit is contained in:
Marc Mueller 2023-07-24 09:42:01 +02:00 committed by GitHub
parent daa76bbab6
commit 3371c41bda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 13 deletions

View File

@ -34,7 +34,10 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.device_registry import (
DeviceEntry,
EventDeviceRegistryUpdatedData,
)
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
@ -720,7 +723,9 @@ class MqttDiscoveryDeviceUpdate(ABC):
) )
return return
async def _async_device_removed(self, event: Event) -> None: async def _async_device_removed(
self, event: EventType[EventDeviceRegistryUpdatedData]
) -> None:
"""Handle the manual removal of a device.""" """Handle the manual removal of a device."""
if self._skip_device_removal or not async_removed_from_device( if self._skip_device_removal or not async_removed_from_device(
self.hass, event, cast(str, self._device_id), self._config_entry_id self.hass, event, cast(str, self._device_id), self._config_entry_id
@ -1178,14 +1183,16 @@ def update_device(
@callback @callback
def async_removed_from_device( def async_removed_from_device(
hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str hass: HomeAssistant,
event: EventType[EventDeviceRegistryUpdatedData],
mqtt_device_id: str,
config_entry_id: str,
) -> bool: ) -> bool:
"""Check if the passed event indicates MQTT was removed from a device.""" """Check if the passed event indicates MQTT was removed from a device."""
action: str = event.data["action"] if event.data["action"] not in ("remove", "update"):
if action not in ("remove", "update"):
return False return False
if action == "update": if event.data["action"] == "update":
if "config_entries" not in event.data["changes"]: if "config_entries" not in event.data["changes"]:
return False return False
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)

View File

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import attr import attr
from typing_extensions import NotRequired
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
@ -97,12 +96,25 @@ DEVICE_INFO_TYPES = {
DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values())) DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values()))
class EventDeviceRegistryUpdatedData(TypedDict): class _EventDeviceRegistryUpdatedData_CreateRemove(TypedDict):
"""EventDeviceRegistryUpdated data.""" """EventDeviceRegistryUpdated data for action type 'create' and 'remove'."""
action: Literal["create", "remove", "update"] action: Literal["create", "remove"]
device_id: str device_id: str
changes: NotRequired[dict[str, Any]]
class _EventDeviceRegistryUpdatedData_Update(TypedDict):
"""EventDeviceRegistryUpdated data for action type 'update'."""
action: Literal["update"]
device_id: str
changes: dict[str, Any]
EventDeviceRegistryUpdatedData = (
_EventDeviceRegistryUpdatedData_CreateRemove
| _EventDeviceRegistryUpdatedData_Update
)
class DeviceEntryType(StrEnum): class DeviceEntryType(StrEnum):

View File

@ -40,7 +40,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
from . import device_registry as dr, entity_registry as er from . import device_registry as dr, entity_registry as er
from .device_registry import DeviceEntryType from .device_registry import DeviceEntryType, EventDeviceRegistryUpdatedData
from .event import ( from .event import (
async_track_device_registry_updated_event, async_track_device_registry_updated_event,
async_track_entity_registry_updated_event, async_track_entity_registry_updated_event,
@ -1146,7 +1146,9 @@ class Entity(ABC):
self._unsub_device_updates = None self._unsub_device_updates = None
@callback @callback
def _async_device_registry_updated(self, event: Event) -> None: def _async_device_registry_updated(
self, event: EventType[EventDeviceRegistryUpdatedData]
) -> None:
"""Handle device registry update.""" """Handle device registry update."""
data = event.data data = event.data