Add generic Event class (#97071)

This commit is contained in:
Marc Mueller 2023-07-23 21:51:54 +02:00 committed by GitHub
parent 860a37aa65
commit bdd253328d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 33 deletions

View File

@ -33,6 +33,7 @@ from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
EventStateChangedData,
TrackTemplate, TrackTemplate,
TrackTemplateResult, TrackTemplateResult,
TrackTemplateResultInfo, TrackTemplateResultInfo,
@ -41,7 +42,7 @@ from homeassistant.helpers.event import (
) )
from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import Template, result_as_boolean from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import DOMAIN, PLATFORMS from . import DOMAIN, PLATFORMS
from .const import ( from .const import (
@ -231,16 +232,20 @@ class BayesianBinarySensor(BinarySensorEntity):
""" """
@callback @callback
def async_threshold_sensor_state_listener(event: Event) -> None: def async_threshold_sensor_state_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle sensor state changes. """Handle sensor state changes.
When a state changes, we must update our list of current observations, When a state changes, we must update our list of current observations,
then calculate the new probability. then calculate the new probability.
""" """
entity: str = event.data[CONF_ENTITY_ID] entity_id = event.data["entity_id"]
self.current_observations.update(self._record_entity_observations(entity)) self.current_observations.update(
self._record_entity_observations(entity_id)
)
self.async_set_context(event.context) self.async_set_context(event.context)
self._recalculate_and_write_state() self._recalculate_and_write_state()

View File

@ -2,11 +2,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.device_registry import async_get from homeassistant.helpers.device_registry import async_get
from homeassistant.helpers.typing import EventType
from .const import ( from .const import (
BTHOME_BLE_EVENT, BTHOME_BLE_EVENT,
@ -18,17 +18,17 @@ from .const import (
@callback @callback
def async_describe_events( def async_describe_events(
hass: HomeAssistant, hass: HomeAssistant,
async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None], async_describe_event: Callable[
[str, str, Callable[[EventType[BTHomeBleEvent]], dict[str, str]]], None
],
) -> None: ) -> None:
"""Describe logbook events.""" """Describe logbook events."""
dr = async_get(hass) dr = async_get(hass)
@callback @callback
def async_describe_bthome_event(event: Event) -> dict[str, str]: def async_describe_bthome_event(event: EventType[BTHomeBleEvent]) -> dict[str, str]:
"""Describe bthome logbook event.""" """Describe bthome logbook event."""
data = event.data data = event.data
if TYPE_CHECKING:
data = cast(BTHomeBleEvent, data) # type: ignore[assignment]
device = dr.async_get(data["device_id"]) device = dr.async_get(data["device_id"])
name = device and device.name or f'BTHome {data["address"]}' name = device and device.name or f'BTHome {data["address"]}'
if properties := data["event_properties"]: if properties := data["event_properties"]:

View File

@ -10,7 +10,7 @@ import functools as ft
import logging import logging
from random import randint from random import randint
import time import time
from typing import Any, Concatenate, ParamSpec, cast from typing import Any, Concatenate, ParamSpec, TypedDict, cast
import attr import attr
@ -41,7 +41,7 @@ from .entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from .ratelimit import KeyedRateLimit from .ratelimit import KeyedRateLimit
from .sun import get_astral_event_next from .sun import get_astral_event_next
from .template import RenderInfo, Template, result_as_boolean from .template import RenderInfo, Template, result_as_boolean
from .typing import TemplateVarsType from .typing import EventType, TemplateVarsType
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
@ -117,6 +117,14 @@ class TrackTemplateResult:
result: Any result: Any
class EventStateChangedData(TypedDict):
"""EventStateChanged data."""
entity_id: str
old_state: State | None
new_state: State | None
def threaded_listener_factory( def threaded_listener_factory(
async_factory: Callable[Concatenate[HomeAssistant, _P], Any] async_factory: Callable[Concatenate[HomeAssistant, _P], Any]
) -> Callable[Concatenate[HomeAssistant, _P], CALLBACK_TYPE]: ) -> Callable[Concatenate[HomeAssistant, _P], CALLBACK_TYPE]:
@ -183,36 +191,38 @@ def async_track_state_change(
job = HassJob(action, f"track state change {entity_ids} {from_state} {to_state}") job = HassJob(action, f"track state change {entity_ids} {from_state} {to_state}")
@callback @callback
def state_change_filter(event: Event) -> bool: def state_change_filter(event: EventType[EventStateChangedData]) -> bool:
"""Handle specific state changes.""" """Handle specific state changes."""
if from_state is not None: if from_state is not None:
if (old_state := event.data.get("old_state")) is not None: old_state_str: str | None = None
old_state = old_state.state if (old_state := event.data["old_state"]) is not None:
old_state_str = old_state.state
if not match_from_state(old_state): if not match_from_state(old_state_str):
return False return False
if to_state is not None: if to_state is not None:
if (new_state := event.data.get("new_state")) is not None: new_state_str: str | None = None
new_state = new_state.state if (new_state := event.data["new_state"]) is not None:
new_state_str = new_state.state
if not match_to_state(new_state): if not match_to_state(new_state_str):
return False return False
return True return True
@callback @callback
def state_change_dispatcher(event: Event) -> None: def state_change_dispatcher(event: EventType[EventStateChangedData]) -> None:
"""Handle specific state changes.""" """Handle specific state changes."""
hass.async_run_hass_job( hass.async_run_hass_job(
job, job,
event.data["entity_id"], event.data["entity_id"],
event.data.get("old_state"), event.data["old_state"],
event.data["new_state"], event.data["new_state"],
) )
@callback @callback
def state_change_listener(event: Event) -> None: def state_change_listener(event: EventType[EventStateChangedData]) -> None:
"""Handle specific state changes.""" """Handle specific state changes."""
if not state_change_filter(event): if not state_change_filter(event):
return return
@ -231,7 +241,7 @@ def async_track_state_change(
return async_track_state_change_event(hass, entity_ids, state_change_listener) return async_track_state_change_event(hass, entity_ids, state_change_listener)
return hass.bus.async_listen( return hass.bus.async_listen(
EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter EVENT_STATE_CHANGED, state_change_dispatcher, event_filter=state_change_filter # type: ignore[arg-type]
) )
@ -242,7 +252,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
def async_track_state_change_event( def async_track_state_change_event(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: str | Iterable[str], entity_ids: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[EventType[EventStateChangedData]], Any],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Track specific state change events indexed by entity_id. """Track specific state change events indexed by entity_id.
@ -263,8 +273,8 @@ def async_track_state_change_event(
@callback @callback
def _async_dispatch_entity_id_event( def _async_dispatch_entity_id_event(
hass: HomeAssistant, hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[Event], Any]]], callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: Event, event: EventType[EventStateChangedData],
) -> None: ) -> None:
"""Dispatch to listeners.""" """Dispatch to listeners."""
if not (callbacks_list := callbacks.get(event.data["entity_id"])): if not (callbacks_list := callbacks.get(event.data["entity_id"])):
@ -282,7 +292,9 @@ def _async_dispatch_entity_id_event(
@callback @callback
def _async_state_change_filter( def _async_state_change_filter(
hass: HomeAssistant, callbacks: dict[str, list[HassJob[[Event], Any]]], event: Event hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[EventType[EventStateChangedData]], Any]]],
event: EventType[EventStateChangedData],
) -> bool: ) -> bool:
"""Filter state changes by entity_id.""" """Filter state changes by entity_id."""
return event.data["entity_id"] in callbacks return event.data["entity_id"] in callbacks
@ -292,7 +304,7 @@ def _async_state_change_filter(
def _async_track_state_change_event( def _async_track_state_change_event(
hass: HomeAssistant, hass: HomeAssistant,
entity_ids: str | Iterable[str], entity_ids: str | Iterable[str],
action: Callable[[Event], Any], action: Callable[[EventType[EventStateChangedData]], Any],
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""async_track_state_change_event without lowercasing.""" """async_track_state_change_event without lowercasing."""
return _async_track_event( return _async_track_event(
@ -301,9 +313,10 @@ def _async_track_state_change_event(
TRACK_STATE_CHANGE_CALLBACKS, TRACK_STATE_CHANGE_CALLBACKS,
TRACK_STATE_CHANGE_LISTENER, TRACK_STATE_CHANGE_LISTENER,
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
_async_dispatch_entity_id_event, # Remove type ignores when _async_track_event uses EventType
_async_state_change_filter, _async_dispatch_entity_id_event, # type: ignore[arg-type]
action, _async_state_change_filter, # type: ignore[arg-type]
action, # type: ignore[arg-type]
) )

View File

@ -1,10 +1,12 @@
"""Typing Helpers for Home Assistant.""" """Typing Helpers for Home Assistant."""
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import Enum
from typing import Any from typing import Any, Generic, TypeVar
import homeassistant.core import homeassistant.core
_DataT = TypeVar("_DataT")
GPSType = tuple[float, float] GPSType = tuple[float, float]
ConfigType = dict[str, Any] ConfigType = dict[str, Any]
ContextType = homeassistant.core.Context ContextType = homeassistant.core.Context
@ -32,5 +34,10 @@ UNDEFINED = UndefinedType._singleton # pylint: disable=protected-access
# that may rely on them. # that may rely on them.
# In due time they will be removed. # In due time they will be removed.
HomeAssistantType = homeassistant.core.HomeAssistant HomeAssistantType = homeassistant.core.HomeAssistant
EventType = homeassistant.core.Event
ServiceCallType = homeassistant.core.ServiceCall ServiceCallType = homeassistant.core.ServiceCall
class EventType(homeassistant.core.Event, Generic[_DataT]):
"""Generic Event class to better type data."""
data: _DataT # type: ignore[assignment]