Add Event typing to websocket_api for entity subscriptions (#112786)

This commit is contained in:
J. Nick Koston 2024-03-08 17:29:46 -10:00 committed by GitHub
parent ed3ec85e55
commit a66399ad3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 9 deletions

View File

@ -367,7 +367,7 @@ def _forward_entity_changes(
entity_ids: set[str], entity_ids: set[str],
user: User, user: User,
msg_id: int, msg_id: int,
event: Event, event: Event[EventStateChangedData],
) -> None: ) -> None:
"""Forward entity state changed events to websocket.""" """Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"] entity_id = event.data["entity_id"]

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from functools import lru_cache from functools import lru_cache
import logging import logging
from typing import TYPE_CHECKING, Any, Final, cast from typing import Any, Final
import voluptuous as vol import voluptuous as vol
@ -17,6 +17,7 @@ from homeassistant.const import (
) )
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.event import EventStateChangedData
from homeassistant.helpers.json import ( from homeassistant.helpers.json import (
JSON_DUMP, JSON_DUMP,
find_paths_unserializable_data, find_paths_unserializable_data,
@ -141,7 +142,7 @@ def _partial_cached_event_message(event: Event) -> bytes:
) )
def cached_state_diff_message(iden: int, event: Event) -> bytes: def cached_state_diff_message(iden: int, event: Event[EventStateChangedData]) -> bytes:
"""Return an event message. """Return an event message.
Serialize to json once per message. Serialize to json once per message.
@ -161,7 +162,7 @@ def cached_state_diff_message(iden: int, event: Event) -> bytes:
@lru_cache(maxsize=128) @lru_cache(maxsize=128)
def _partial_cached_state_diff_message(event: Event) -> bytes: def _partial_cached_state_diff_message(event: Event[EventStateChangedData]) -> bytes:
"""Cache and serialize the event to json. """Cache and serialize the event to json.
The message is constructed without the id which The message is constructed without the id which
@ -175,7 +176,7 @@ def _partial_cached_state_diff_message(event: Event) -> bytes:
) )
def _state_diff_event(event: Event) -> dict: def _state_diff_event(event: Event[EventStateChangedData]) -> dict:
"""Convert a state_changed event to the minimal version. """Convert a state_changed event to the minimal version.
State update example State update example
@ -188,16 +189,12 @@ def _state_diff_event(event: Event) -> dict:
""" """
if (event_new_state := event.data["new_state"]) is None: if (event_new_state := event.data["new_state"]) is None:
return {ENTITY_EVENT_REMOVE: [event.data["entity_id"]]} return {ENTITY_EVENT_REMOVE: [event.data["entity_id"]]}
if TYPE_CHECKING:
event_new_state = cast(State, event_new_state)
if (event_old_state := event.data["old_state"]) is None: if (event_old_state := event.data["old_state"]) is None:
return { return {
ENTITY_EVENT_ADD: { ENTITY_EVENT_ADD: {
event_new_state.entity_id: event_new_state.as_compressed_state event_new_state.entity_id: event_new_state.as_compressed_state
} }
} }
if TYPE_CHECKING:
event_old_state = cast(State, event_old_state)
return _state_diff(event_old_state, event_new_state) return _state_diff(event_old_state, event_new_state)