Small improvements to websocket api performance (#95693)

This commit is contained in:
J. Nick Koston 2023-07-02 12:33:25 -05:00 committed by GitHub
parent 65f67669d2
commit 2aff138b92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 40 deletions

View File

@ -3,12 +3,13 @@ from __future__ import annotations
from collections.abc import Callable
import datetime as dt
from functools import lru_cache
from functools import lru_cache, partial
import json
from typing import Any, cast
import voluptuous as vol
from homeassistant.auth.models import User
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.const import (
EVENT_STATE_CHANGED,
@ -88,6 +89,32 @@ def pong_message(iden: int) -> dict[str, Any]:
return {"id": iden, "type": "pong"}
def _forward_events_check_permissions(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
user: User,
msg_id: int,
event: Event,
) -> None:
"""Forward state changed events to websocket."""
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
send_message(messages.cached_event_message(msg_id, event))
def _forward_events_unconditional(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
msg_id: int,
event: Event,
) -> None:
"""Forward events to websocket."""
send_message(messages.cached_event_message(msg_id, event))
@callback
@decorators.websocket_command(
{
@ -109,26 +136,18 @@ def handle_subscribe_events(
raise Unauthorized
if event_type == EVENT_STATE_CHANGED:
user = connection.user
@callback
def forward_events(event: Event) -> None:
"""Forward state changed events to websocket."""
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_event_message(msg["id"], event))
forward_events = callback(
partial(
_forward_events_check_permissions,
connection.send_message,
connection.user,
msg["id"],
)
)
else:
@callback
def forward_events(event: Event) -> None:
"""Forward events to websocket."""
connection.send_message(messages.cached_event_message(msg["id"], event))
forward_events = callback(
partial(_forward_events_unconditional, connection.send_message, msg["id"])
)
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
event_type, forward_events, run_immediately=True
@ -280,6 +299,27 @@ def _send_handle_get_states_response(
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
def _forward_entity_changes(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
entity_ids: set[str],
user: User,
msg_id: int,
event: Event,
) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
send_message(messages.cached_state_diff_message(msg_id, event))
@callback
@decorators.websocket_command(
{
@ -292,29 +332,22 @@ def handle_subscribe_entities(
) -> None:
"""Handle subscribe entities command."""
entity_ids = set(msg.get("entity_ids", []))
user = connection.user
@callback
def forward_entity_changes(event: Event) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
# We must never await between sending the states and listening for
# state changed events or we will introduce a race condition
# where some states are missed
states = _async_get_allowed_states(hass, connection)
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
EVENT_STATE_CHANGED,
callback(
partial(
_forward_entity_changes,
connection.send_message,
entity_ids,
connection.user,
msg["id"],
)
),
run_immediately=True,
)
connection.send_result(msg["id"])

View File

@ -33,6 +33,20 @@ BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None]
class ActiveConnection:
"""Handle an active websocket client connection."""
__slots__ = (
"logger",
"hass",
"send_message",
"user",
"refresh_token_id",
"subscriptions",
"last_id",
"can_coalesce",
"supported_features",
"handlers",
"binary_handlers",
)
def __init__(
self,
logger: WebSocketAdapter,

View File

@ -63,6 +63,21 @@ class WebSocketAdapter(logging.LoggerAdapter):
class WebSocketHandler:
"""Handle an active websocket client connection."""
__slots__ = (
"_hass",
"_request",
"_wsock",
"_handle_task",
"_writer_task",
"_closing",
"_authenticated",
"_logger",
"_peak_checker_unsub",
"_connection",
"_message_queue",
"_ready_future",
)
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection."""
self._hass = hass
@ -201,8 +216,9 @@ class WebSocketHandler:
return
message_queue.append(message)
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
ready_future = self._ready_future
if ready_future and not ready_future.done():
ready_future.set_result(None)
peak_checker_active = self._peak_checker_unsub is not None