From bdd253328d01a9ea001c703d6ab9ffd40f527264 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sun, 23 Jul 2023 21:51:54 +0200 Subject: [PATCH] Add generic Event class (#97071) --- .../components/bayesian/binary_sensor.py | 13 +++-- homeassistant/components/bthome/logbook.py | 12 ++-- homeassistant/helpers/event.py | 55 ++++++++++++------- homeassistant/helpers/typing.py | 11 +++- 4 files changed, 58 insertions(+), 33 deletions(-) diff --git a/homeassistant/components/bayesian/binary_sensor.py b/homeassistant/components/bayesian/binary_sensor.py index 06baef1bd0e..43411e9ec0d 100644 --- a/homeassistant/components/bayesian/binary_sensor.py +++ b/homeassistant/components/bayesian/binary_sensor.py @@ -33,6 +33,7 @@ from homeassistant.helpers import condition import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import ( + EventStateChangedData, TrackTemplate, TrackTemplateResult, TrackTemplateResultInfo, @@ -41,7 +42,7 @@ from homeassistant.helpers.event import ( ) from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.template import Template, result_as_boolean -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType from . import DOMAIN, PLATFORMS from .const import ( @@ -231,16 +232,20 @@ class BayesianBinarySensor(BinarySensorEntity): """ @callback - def async_threshold_sensor_state_listener(event: Event) -> None: + def async_threshold_sensor_state_listener( + event: EventType[EventStateChangedData], + ) -> None: """Handle sensor state changes. When a state changes, we must update our list of current observations, then calculate the new probability. """ - entity: str = event.data[CONF_ENTITY_ID] + entity_id = event.data["entity_id"] - self.current_observations.update(self._record_entity_observations(entity)) + self.current_observations.update( + self._record_entity_observations(entity_id) + ) self.async_set_context(event.context) self._recalculate_and_write_state() diff --git a/homeassistant/components/bthome/logbook.py b/homeassistant/components/bthome/logbook.py index 703ad671799..4111777375d 100644 --- a/homeassistant/components/bthome/logbook.py +++ b/homeassistant/components/bthome/logbook.py @@ -2,11 +2,11 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, cast from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.device_registry import async_get +from homeassistant.helpers.typing import EventType from .const import ( BTHOME_BLE_EVENT, @@ -18,17 +18,17 @@ from .const import ( @callback def async_describe_events( hass: HomeAssistant, - async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None], + async_describe_event: Callable[ + [str, str, Callable[[EventType[BTHomeBleEvent]], dict[str, str]]], None + ], ) -> None: """Describe logbook events.""" dr = async_get(hass) @callback - def async_describe_bthome_event(event: Event) -> dict[str, str]: + def async_describe_bthome_event(event: EventType[BTHomeBleEvent]) -> dict[str, str]: """Describe bthome logbook event.""" data = event.data - if TYPE_CHECKING: - data = cast(BTHomeBleEvent, data) # type: ignore[assignment] device = dr.async_get(data["device_id"]) name = device and device.name or f'BTHome {data["address"]}' if properties := data["event_properties"]: diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index b7254c5c347..004a71fa810 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -10,7 +10,7 @@ import functools as ft import logging from random import randint import time -from typing import Any, Concatenate, ParamSpec, cast +from typing import Any, Concatenate, ParamSpec, TypedDict, cast import attr @@ -41,7 +41,7 @@ from .entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from .ratelimit import KeyedRateLimit from .sun import get_astral_event_next from .template import RenderInfo, Template, result_as_boolean -from .typing import TemplateVarsType +from .typing import EventType, TemplateVarsType TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" @@ -117,6 +117,14 @@ class TrackTemplateResult: result: Any +class EventStateChangedData(TypedDict): + """EventStateChanged data.""" + + entity_id: str + old_state: State | None + new_state: State | None + + def threaded_listener_factory( async_factory: Callable[Concatenate[HomeAssistant, _P], Any] ) -> Callable[Concatenate[HomeAssistant, _P], CALLBACK_TYPE]: @@ -183,36 +191,38 @@ def async_track_state_change( job = HassJob(action, f"track state change {entity_ids} {from_state} {to_state}") @callback - def state_change_filter(event: Event) -> bool: + def state_change_filter(event: EventType[EventStateChangedData]) -> bool: """Handle specific state changes.""" if from_state is not None: - if (old_state := event.data.get("old_state")) is not None: - old_state = old_state.state + old_state_str: str | None = None + if (old_state := event.data["old_state"]) is not None: + old_state_str = old_state.state - if not match_from_state(old_state): + if not match_from_state(old_state_str): return False if to_state is not None: - if (new_state := event.data.get("new_state")) is not None: - new_state = new_state.state + new_state_str: str | None = None + if (new_state := event.data["new_state"]) is not None: + new_state_str = new_state.state - if not match_to_state(new_state): + if not match_to_state(new_state_str): return False return True @callback - def state_change_dispatcher(event: Event) -> None: + def state_change_dispatcher(event: EventType[EventStateChangedData]) -> None: """Handle specific state changes.""" hass.async_run_hass_job( job, event.data["entity_id"], - event.data.get("old_state"), + event.data["old_state"], event.data["new_state"], ) @callback - def state_change_listener(event: Event) -> None: + def state_change_listener(event: EventType[EventStateChangedData]) -> None: """Handle specific state changes.""" if not state_change_filter(event): return @@ -231,7 +241,7 @@ def async_track_state_change( return async_track_state_change_event(hass, entity_ids, state_change_listener) return hass.bus.async_listen( - EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter + EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter # type: ignore[arg-type] ) @@ -242,7 +252,7 @@ track_state_change = threaded_listener_factory(async_track_state_change) def async_track_state_change_event( hass: HomeAssistant, entity_ids: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """Track specific state change events indexed by entity_id. @@ -263,8 +273,8 @@ def async_track_state_change_event( @callback def _async_dispatch_entity_id_event( hass: HomeAssistant, - callbacks: dict[str, list[HassJob[[Event], Any]]], - event: Event, + callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]], + event: EventType[EventStateChangedData], ) -> None: """Dispatch to listeners.""" if not (callbacks_list := callbacks.get(event.data["entity_id"])): @@ -282,7 +292,9 @@ def _async_dispatch_entity_id_event( @callback def _async_state_change_filter( - hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event + hass: HomeAssistant, + callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]], + event: EventType[EventStateChangedData], ) -> bool: """Filter state changes by entity_id.""" return event.data["entity_id"] in callbacks @@ -292,7 +304,7 @@ def _async_state_change_filter( def _async_track_state_change_event( hass: HomeAssistant, entity_ids: str | Iterable[str], - action: Callable[[Event], Any], + action: Callable[[EventType[EventStateChangedData]], Any], ) -> CALLBACK_TYPE: """async_track_state_change_event without lowercasing.""" return _async_track_event( @@ -301,9 +313,10 @@ def _async_track_state_change_event( TRACK_STATE_CHANGE_CALLBACKS, TRACK_STATE_CHANGE_LISTENER, EVENT_STATE_CHANGED, - _async_dispatch_entity_id_event, - _async_state_change_filter, - action, + # Remove type ignores when _async_track_event uses EventType + _async_dispatch_entity_id_event, # type: ignore[arg-type] + _async_state_change_filter, # type: ignore[arg-type] + action, # type: ignore[arg-type] ) diff --git a/homeassistant/helpers/typing.py b/homeassistant/helpers/typing.py index 5a76fd262a8..9e3f9de34fa 100644 --- a/homeassistant/helpers/typing.py +++ b/homeassistant/helpers/typing.py @@ -1,10 +1,12 @@ """Typing Helpers for Home Assistant.""" from collections.abc import Mapping from enum import Enum -from typing import Any +from typing import Any, Generic, TypeVar import homeassistant.core +_DataT = TypeVar("_DataT") + GPSType = tuple[float, float] ConfigType = dict[str, Any] ContextType = homeassistant.core.Context @@ -32,5 +34,10 @@ UNDEFINED = UndefinedType._singleton # pylint: disable=protected-access # that may rely on them. # In due time they will be removed. HomeAssistantType = homeassistant.core.HomeAssistant -EventType = homeassistant.core.Event ServiceCallType = homeassistant.core.ServiceCall + + +class EventType(homeassistant.core.Event, Generic[_DataT]): + """Generic Event class to better type data.""" + + data: _DataT # type: ignore[assignment]