Use EventType for more helper methods (#97107)

This commit is contained in:
Marc Mueller 2023-07-23 23:22:04 +02:00 committed by GitHub
parent 54d7ba72ee
commit 69d7b035e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 27 deletions

View File

@ -6,10 +6,11 @@ from collections.abc import Coroutine, ValuesView
from enum import StrEnum
import logging
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from urllib.parse import urlparse
import attr
from typing_extensions import NotRequired
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
@ -96,6 +97,14 @@ DEVICE_INFO_TYPES = {
DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values()))
class EventDeviceRegistryUpdatedData(TypedDict):
"""EventDeviceRegistryUpdated data."""
action: Literal["create", "remove", "update"]
device_id: str
changes: NotRequired[dict[str, Any]]
class DeviceEntryType(StrEnum):
"""Device entry type."""

View File

@ -15,9 +15,10 @@ from datetime import datetime, timedelta
from enum import StrEnum
import logging
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
import attr
from typing_extensions import NotRequired
import voluptuous as vol
from homeassistant.const import (
@ -107,6 +108,15 @@ class RegistryEntryHider(StrEnum):
USER = "user"
class EventEntityRegistryUpdatedData(TypedDict):
"""EventEntityRegistryUpdated data."""
action: Literal["create", "remove", "update"]
entity_id: str
changes: NotRequired[dict[str, Any]]
old_entity_id: NotRequired[str]
EntityOptionsType = Mapping[str, Mapping[str, Any]]
ReadOnlyEntityOptionsType = ReadOnlyDict[str, Mapping[str, Any]]

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable, Sequence
from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
import copy
from dataclasses import dataclass
from datetime import datetime, timedelta
@ -10,7 +10,7 @@ import functools as ft
import logging
from random import randint
import time
from typing import Any, Concatenate, ParamSpec, TypedDict, cast
from typing import Any, Concatenate, ParamSpec, TypedDict, TypeVar, cast
import attr
@ -36,8 +36,14 @@ from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe
from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from .entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from .device_registry import (
EVENT_DEVICE_REGISTRY_UPDATED,
EventDeviceRegistryUpdatedData,
)
from .entity_registry import (
EVENT_ENTITY_REGISTRY_UPDATED,
EventEntityRegistryUpdatedData,
)
from .ratelimit import KeyedRateLimit
from .sun import get_astral_event_next
from .template import RenderInfo, Template, result_as_boolean
@ -67,6 +73,7 @@ _LOGGER = logging.getLogger(__name__)
RANDOM_MICROSECOND_MIN = 50000
RANDOM_MICROSECOND_MAX = 500000
_TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any])
_P = ParamSpec("_P")
@ -313,10 +320,9 @@ def _async_track_state_change_event(
TRACK_STATE_CHANGE_CALLBACKS,
TRACK_STATE_CHANGE_LISTENER,
EVENT_STATE_CHANGED,
# Remove type ignores when _async_track_event uses EventType
_async_dispatch_entity_id_event, # type: ignore[arg-type]
_async_state_change_filter, # type: ignore[arg-type]
action, # type: ignore[arg-type]
_async_dispatch_entity_id_event,
_async_state_change_filter,
action,
)
@ -351,12 +357,22 @@ def _async_track_event(
listeners_key: str,
event_type: str,
dispatcher_callable: Callable[
[HomeAssistant, dict[str, list[HassJob[[Event], Any]]], Event], None
[
HomeAssistant,
dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]],
EventType[_TypedDictT],
],
None,
],
filter_callable: Callable[
[HomeAssistant, dict[str, list[HassJob[[Event], Any]]], Event], bool
[
HomeAssistant,
dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]],
EventType[_TypedDictT],
],
bool,
],
action: Callable[[Event], None],
action: Callable[[EventType[_TypedDictT]], None],
) -> CALLBACK_TYPE:
"""Track an event by a specific key."""
if not keys:
@ -367,9 +383,9 @@ def _async_track_event(
hass_data = hass.data
callbacks: dict[str, list[HassJob[[Event], Any]]] | None = hass_data.get(
callbacks_key
)
callbacks: dict[
str, list[HassJob[[EventType[_TypedDictT]], Any]]
] | None = hass_data.get(callbacks_key)
if not callbacks:
callbacks = hass_data[callbacks_key] = {}
@ -395,8 +411,10 @@ def _async_track_event(
@callback
def _async_dispatch_old_entity_id_or_entity_id_event(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[Event], Any]]],
event: Event,
callbacks: dict[
str, list[HassJob[[EventType[EventEntityRegistryUpdatedData]], Any]]
],
event: EventType[EventEntityRegistryUpdatedData],
) -> None:
"""Dispatch to listeners."""
if not (
@ -418,7 +436,11 @@ def _async_dispatch_old_entity_id_or_entity_id_event(
@callback
def _async_entity_registry_updated_filter(
hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event
hass: HomeAssistant,
callbacks: dict[
str, list[HassJob[[EventType[EventEntityRegistryUpdatedData]], Any]]
],
event: EventType[EventEntityRegistryUpdatedData],
) -> bool:
"""Filter entity registry updates by entity_id."""
return event.data.get("old_entity_id", event.data["entity_id"]) in callbacks
@ -451,7 +473,11 @@ def async_track_entity_registry_updated_event(
@callback
def _async_device_registry_updated_filter(
hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event
hass: HomeAssistant,
callbacks: dict[
str, list[HassJob[[EventType[EventDeviceRegistryUpdatedData]], Any]]
],
event: EventType[EventDeviceRegistryUpdatedData],
) -> bool:
"""Filter device registry updates by device_id."""
return event.data["device_id"] in callbacks
@ -460,8 +486,10 @@ def _async_device_registry_updated_filter(
@callback
def _async_dispatch_device_id_event(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[Event], Any]]],
event: Event,
callbacks: dict[
str, list[HassJob[[EventType[EventDeviceRegistryUpdatedData]], Any]]
],
event: EventType[EventDeviceRegistryUpdatedData],
) -> None:
"""Dispatch to listeners."""
if not (callbacks_list := callbacks.get(event.data["device_id"])):
@ -501,7 +529,9 @@ def async_track_device_registry_updated_event(
@callback
def _async_dispatch_domain_event(
hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
) -> None:
"""Dispatch domain event listeners."""
domain = split_entity_id(event.data["entity_id"])[0]
@ -516,10 +546,12 @@ def _async_dispatch_domain_event(
@callback
def _async_domain_added_filter(
hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("old_state") is None and (
return event.data["old_state"] is None and (
MATCH_ALL in callbacks
or split_entity_id(event.data["entity_id"])[0] in callbacks
)
@ -558,10 +590,12 @@ def _async_track_state_added_domain(
@callback
def _async_domain_removed_filter(
hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
) -> bool:
"""Filter state changes by entity_id."""
return event.data.get("new_state") is None and (
return event.data["new_state"] is None and (
MATCH_ALL in callbacks
or split_entity_id(event.data["entity_id"])[0] in callbacks
)