diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index c65e87a2119..45a4459b5d3 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -6,10 +6,11 @@ from collections.abc import Coroutine, ValuesView from enum import StrEnum import logging import time -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast from urllib.parse import urlparse import attr +from typing_extensions import NotRequired from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback @@ -96,6 +97,14 @@ DEVICE_INFO_TYPES = { DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values())) +class EventDeviceRegistryUpdatedData(TypedDict): + """EventDeviceRegistryUpdated data.""" + + action: Literal["create", "remove", "update"] + device_id: str + changes: NotRequired[dict[str, Any]] + + class DeviceEntryType(StrEnum): """Device entry type.""" diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 5fc4609d812..248db9d5180 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -15,9 +15,10 @@ from datetime import datetime, timedelta from enum import StrEnum import logging import time -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast import attr +from typing_extensions import NotRequired import voluptuous as vol from homeassistant.const import ( @@ -107,6 +108,15 @@ class RegistryEntryHider(StrEnum): USER = "user" +class EventEntityRegistryUpdatedData(TypedDict): + """EventEntityRegistryUpdated data.""" + + action: Literal["create", "remove", "update"] + entity_id: str + changes: NotRequired[dict[str, Any]] + old_entity_id: NotRequired[str] + + EntityOptionsType = Mapping[str, Mapping[str, Any]] ReadOnlyEntityOptionsType = ReadOnlyDict[str, Mapping[str, Any]] diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 004a71fa810..830b6100111 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Iterable, Sequence +from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence import copy from dataclasses import dataclass from datetime import datetime, timedelta @@ -10,7 +10,7 @@ import functools as ft import logging from random import randint import time -from typing import Any, Concatenate, ParamSpec, TypedDict, cast +from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar, cast import attr @@ -36,8 +36,14 @@ from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe -from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED -from .entity_registry import EVENT_ENTITY_REGISTRY_UPDATED +from .device_registry import ( + EVENT_DEVICE_REGISTRY_UPDATED, + EventDeviceRegistryUpdatedData, +) +from .entity_registry import ( + EVENT_ENTITY_REGISTRY_UPDATED, + EventEntityRegistryUpdatedData, +) from .ratelimit import KeyedRateLimit from .sun import get_astral_event_next from .template import RenderInfo, Template, result_as_boolean @@ -67,6 +73,7 @@ _LOGGER = logging.getLogger(__name__) RANDOM_MICROSECOND_MIN = 50000 RANDOM_MICROSECOND_MAX = 500000 +_TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any]) _P = ParamSpec("_P") @@ -313,10 +320,9 @@ def _async_track_state_change_event( TRACK_STATE_CHANGE_CALLBACKS, TRACK_STATE_CHANGE_LISTENER, EVENT_STATE_CHANGED, - # Remove type ignores when _async_track_event uses EventType - _async_dispatch_entity_id_event, # type: ignore[arg-type] - _async_state_change_filter, # type: ignore[arg-type] - action, # type: ignore[arg-type] + _async_dispatch_entity_id_event, + _async_state_change_filter, + action, ) @@ -351,12 +357,22 @@ def _async_track_event( listeners_key: str, event_type: str, dispatcher_callable: Callable[ - [HomeAssistant, dict[str, list[HassJob[[Event], Any]]], Event], None + [ + HomeAssistant, + dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], + EventType[_TypedDictT], + ], + None, ], filter_callable: Callable[ - [HomeAssistant, dict[str, list[HassJob[[Event], Any]]], Event], bool + [ + HomeAssistant, + dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]], + EventType[_TypedDictT], + ], + bool, ], - action: Callable[[Event], None], + action: Callable[[EventType[_TypedDictT]], None], ) -> CALLBACK_TYPE: """Track an event by a specific key.""" if not keys: @@ -367,9 +383,9 @@ def _async_track_event( hass_data = hass.data - callbacks: dict[str, list[HassJob[[Event], Any]]] | None = hass_data.get( - callbacks_key - ) + callbacks: dict[ + str, list[HassJob[[EventType[_TypedDictT]], Any]] + ] | None = hass_data.get(callbacks_key) if not callbacks: callbacks = hass_data[callbacks_key] = {} @@ -395,8 +411,10 @@ def _async_track_event( @callback def _async_dispatch_old_entity_id_or_entity_id_event( hass: HomeAssistant, - callbacks: dict[str, list[HassJob[[Event], Any]]], - event: Event, + callbacks: dict[ + str, list[HassJob[[EventType[EventEntityRegistryUpdatedData]], Any]] + ], + event: EventType[EventEntityRegistryUpdatedData], ) -> None: """Dispatch to listeners.""" if not ( @@ -418,7 +436,11 @@ def _async_dispatch_old_entity_id_or_entity_id_event( @callback def _async_entity_registry_updated_filter( - hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event + hass: HomeAssistant, + callbacks: dict[ + str, list[HassJob[[EventType[EventEntityRegistryUpdatedData]], Any]] + ], + event: EventType[EventEntityRegistryUpdatedData], ) -> bool: """Filter entity registry updates by entity_id.""" return event.data.get("old_entity_id", event.data["entity_id"]) in callbacks @@ -451,7 +473,11 @@ def async_track_entity_registry_updated_event( @callback def _async_device_registry_updated_filter( - hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event + hass: HomeAssistant, + callbacks: dict[ + str, list[HassJob[[EventType[EventDeviceRegistryUpdatedData]], Any]] + ], + event: EventType[EventDeviceRegistryUpdatedData], ) -> bool: """Filter device registry updates by device_id.""" return event.data["device_id"] in callbacks @@ -460,8 +486,10 @@ def _async_device_registry_updated_filter( @callback def _async_dispatch_device_id_event( hass: HomeAssistant, - callbacks: dict[str, list[HassJob[[Event], Any]]], - event: Event, + callbacks: dict[ + str, list[HassJob[[EventType[EventDeviceRegistryUpdatedData]], Any]] + ], + event: EventType[EventDeviceRegistryUpdatedData], ) -> None: """Dispatch to listeners.""" if not (callbacks_list := callbacks.get(event.data["device_id"])): @@ -501,7 +529,9 @@ def async_track_device_registry_updated_event( @callback def _async_dispatch_domain_event( - hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]], + event: EventType[EventStateChangedData], ) -> None: """Dispatch domain event listeners.""" domain = split_entity_id(event.data["entity_id"])[0] @@ -516,10 +546,12 @@ def _async_dispatch_domain_event( @callback def _async_domain_added_filter( - hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]], + event: EventType[EventStateChangedData], ) -> bool: """Filter state changes by entity_id.""" - return event.data.get("old_state") is None and ( + return event.data["old_state"] is None and ( MATCH_ALL in callbacks or split_entity_id(event.data["entity_id"])[0] in callbacks ) @@ -558,10 +590,12 @@ def _async_track_state_added_domain( @callback def _async_domain_removed_filter( - hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]], + event: EventType[EventStateChangedData], ) -> bool: """Filter state changes by entity_id.""" - return event.data.get("new_state") is None and ( + return event.data["new_state"] is None and ( MATCH_ALL in callbacks or split_entity_id(event.data["entity_id"])[0] in callbacks )