Small improvements to websocket api performance (#95693)

This commit is contained in:
J. Nick Koston 2023-07-02 12:33:25 -05:00 committed by GitHub
parent 65f67669d2
commit 2aff138b92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 40 deletions

View File

@ -3,12 +3,13 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import datetime as dt import datetime as dt
from functools import lru_cache from functools import lru_cache, partial
import json import json
from typing import Any, cast from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.models import User
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.const import ( from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
@ -88,6 +89,32 @@ def pong_message(iden: int) -> dict[str, Any]:
return {"id": iden, "type": "pong"} return {"id": iden, "type": "pong"}
def _forward_events_check_permissions(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
user: User,
msg_id: int,
event: Event,
) -> None:
"""Forward state changed events to websocket."""
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
send_message(messages.cached_event_message(msg_id, event))
def _forward_events_unconditional(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
msg_id: int,
event: Event,
) -> None:
"""Forward events to websocket."""
send_message(messages.cached_event_message(msg_id, event))
@callback @callback
@decorators.websocket_command( @decorators.websocket_command(
{ {
@ -109,26 +136,18 @@ def handle_subscribe_events(
raise Unauthorized raise Unauthorized
if event_type == EVENT_STATE_CHANGED: if event_type == EVENT_STATE_CHANGED:
user = connection.user forward_events = callback(
partial(
@callback _forward_events_check_permissions,
def forward_events(event: Event) -> None: connection.send_message,
"""Forward state changed events to websocket.""" connection.user,
# We have to lookup the permissions again because the user might have msg["id"],
# changed since the subscription was created. )
permissions = user.permissions )
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_event_message(msg["id"], event))
else: else:
forward_events = callback(
@callback partial(_forward_events_unconditional, connection.send_message, msg["id"])
def forward_events(event: Event) -> None: )
"""Forward events to websocket."""
connection.send_message(messages.cached_event_message(msg["id"], event))
connection.subscriptions[msg["id"]] = hass.bus.async_listen( connection.subscriptions[msg["id"]] = hass.bus.async_listen(
event_type, forward_events, run_immediately=True event_type, forward_events, run_immediately=True
@ -280,6 +299,27 @@ def _send_handle_get_states_response(
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]")) connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))
def _forward_entity_changes(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
entity_ids: set[str],
user: User,
msg_id: int,
event: Event,
) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
send_message(messages.cached_state_diff_message(msg_id, event))
@callback @callback
@decorators.websocket_command( @decorators.websocket_command(
{ {
@ -292,29 +332,22 @@ def handle_subscribe_entities(
) -> None: ) -> None:
"""Handle subscribe entities command.""" """Handle subscribe entities command."""
entity_ids = set(msg.get("entity_ids", [])) entity_ids = set(msg.get("entity_ids", []))
user = connection.user
@callback
def forward_entity_changes(event: Event) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_state_diff_message(msg["id"], event))
# We must never await between sending the states and listening for # We must never await between sending the states and listening for
# state changed events or we will introduce a race condition # state changed events or we will introduce a race condition
# where some states are missed # where some states are missed
states = _async_get_allowed_states(hass, connection) states = _async_get_allowed_states(hass, connection)
connection.subscriptions[msg["id"]] = hass.bus.async_listen( connection.subscriptions[msg["id"]] = hass.bus.async_listen(
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True EVENT_STATE_CHANGED,
callback(
partial(
_forward_entity_changes,
connection.send_message,
entity_ids,
connection.user,
msg["id"],
)
),
run_immediately=True,
) )
connection.send_result(msg["id"]) connection.send_result(msg["id"])

View File

@ -33,6 +33,20 @@ BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None]
class ActiveConnection: class ActiveConnection:
"""Handle an active websocket client connection.""" """Handle an active websocket client connection."""
__slots__ = (
"logger",
"hass",
"send_message",
"user",
"refresh_token_id",
"subscriptions",
"last_id",
"can_coalesce",
"supported_features",
"handlers",
"binary_handlers",
)
def __init__( def __init__(
self, self,
logger: WebSocketAdapter, logger: WebSocketAdapter,

View File

@ -63,6 +63,21 @@ class WebSocketAdapter(logging.LoggerAdapter):
class WebSocketHandler: class WebSocketHandler:
"""Handle an active websocket client connection.""" """Handle an active websocket client connection."""
__slots__ = (
"_hass",
"_request",
"_wsock",
"_handle_task",
"_writer_task",
"_closing",
"_authenticated",
"_logger",
"_peak_checker_unsub",
"_connection",
"_message_queue",
"_ready_future",
)
def __init__(self, hass: HomeAssistant, request: web.Request) -> None: def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection.""" """Initialize an active connection."""
self._hass = hass self._hass = hass
@ -201,8 +216,9 @@ class WebSocketHandler:
return return
message_queue.append(message) message_queue.append(message)
if self._ready_future and not self._ready_future.done(): ready_future = self._ready_future
self._ready_future.set_result(None) if ready_future and not ready_future.done():
ready_future.set_result(None)
peak_checker_active = self._peak_checker_unsub is not None peak_checker_active = self._peak_checker_unsub is not None