diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 619fc913e09..c733a96ca9d 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -3,12 +3,13 @@ from __future__ import annotations from collections.abc import Callable import datetime as dt -from functools import lru_cache +from functools import lru_cache, partial import json from typing import Any, cast import voluptuous as vol +from homeassistant.auth.models import User from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ from homeassistant.const import ( EVENT_STATE_CHANGED, @@ -88,6 +89,32 @@ def pong_message(iden: int) -> dict[str, Any]: return {"id": iden, "type": "pong"} +def _forward_events_check_permissions( + send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], + user: User, + msg_id: int, + event: Event, +) -> None: + """Forward state changed events to websocket.""" + # We have to lookup the permissions again because the user might have + # changed since the subscription was created. + permissions = user.permissions + if not permissions.access_all_entities( + POLICY_READ + ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ): + return + send_message(messages.cached_event_message(msg_id, event)) + + +def _forward_events_unconditional( + send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], + msg_id: int, + event: Event, +) -> None: + """Forward events to websocket.""" + send_message(messages.cached_event_message(msg_id, event)) + + @callback @decorators.websocket_command( { @@ -109,26 +136,18 @@ def handle_subscribe_events( raise Unauthorized if event_type == EVENT_STATE_CHANGED: - user = connection.user - - @callback - def forward_events(event: Event) -> None: - """Forward state changed events to websocket.""" - # We have to lookup the permissions again because the user might have - # changed since the subscription was created. - permissions = user.permissions - if not permissions.access_all_entities( - POLICY_READ - ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ): - return - connection.send_message(messages.cached_event_message(msg["id"], event)) - + forward_events = callback( + partial( + _forward_events_check_permissions, + connection.send_message, + connection.user, + msg["id"], + ) + ) else: - - @callback - def forward_events(event: Event) -> None: - """Forward events to websocket.""" - connection.send_message(messages.cached_event_message(msg["id"], event)) + forward_events = callback( + partial(_forward_events_unconditional, connection.send_message, msg["id"]) + ) connection.subscriptions[msg["id"]] = hass.bus.async_listen( event_type, forward_events, run_immediately=True @@ -280,6 +299,27 @@ def _send_handle_get_states_response( connection.send_message(construct_result_message(msg_id, f"[{joined_states}]")) +def _forward_entity_changes( + send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], + entity_ids: set[str], + user: User, + msg_id: int, + event: Event, +) -> None: + """Forward entity state changed events to websocket.""" + entity_id = event.data["entity_id"] + if entity_ids and entity_id not in entity_ids: + return + # We have to lookup the permissions again because the user might have + # changed since the subscription was created. + permissions = user.permissions + if not permissions.access_all_entities( + POLICY_READ + ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ): + return + send_message(messages.cached_state_diff_message(msg_id, event)) + + @callback @decorators.websocket_command( { @@ -292,29 +332,22 @@ def handle_subscribe_entities( ) -> None: """Handle subscribe entities command.""" entity_ids = set(msg.get("entity_ids", [])) - user = connection.user - - @callback - def forward_entity_changes(event: Event) -> None: - """Forward entity state changed events to websocket.""" - entity_id = event.data["entity_id"] - if entity_ids and entity_id not in entity_ids: - return - # We have to lookup the permissions again because the user might have - # changed since the subscription was created. - permissions = user.permissions - if not permissions.access_all_entities( - POLICY_READ - ) and not permissions.check_entity(event.data["entity_id"], POLICY_READ): - return - connection.send_message(messages.cached_state_diff_message(msg["id"], event)) - # We must never await between sending the states and listening for # state changed events or we will introduce a race condition # where some states are missed states = _async_get_allowed_states(hass, connection) connection.subscriptions[msg["id"]] = hass.bus.async_listen( - EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True + EVENT_STATE_CHANGED, + callback( + partial( + _forward_entity_changes, + connection.send_message, + entity_ids, + connection.user, + msg["id"], + ) + ), + run_immediately=True, ) connection.send_result(msg["id"]) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 319188dae21..a554001970b 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -33,6 +33,20 @@ BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None] class ActiveConnection: """Handle an active websocket client connection.""" + __slots__ = ( + "logger", + "hass", + "send_message", + "user", + "refresh_token_id", + "subscriptions", + "last_id", + "can_coalesce", + "supported_features", + "handlers", + "binary_handlers", + ) + def __init__( self, logger: WebSocketAdapter, diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 6ac0e10a76c..728405b5d96 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -63,6 +63,21 @@ class WebSocketAdapter(logging.LoggerAdapter): class WebSocketHandler: """Handle an active websocket client connection.""" + __slots__ = ( + "_hass", + "_request", + "_wsock", + "_handle_task", + "_writer_task", + "_closing", + "_authenticated", + "_logger", + "_peak_checker_unsub", + "_connection", + "_message_queue", + "_ready_future", + ) + def __init__(self, hass: HomeAssistant, request: web.Request) -> None: """Initialize an active connection.""" self._hass = hass @@ -201,8 +216,9 @@ class WebSocketHandler: return message_queue.append(message) - if self._ready_future and not self._ready_future.done(): - self._ready_future.set_result(None) + ready_future = self._ready_future + if ready_future and not ready_future.done(): + ready_future.set_result(None) peak_checker_active = self._peak_checker_unsub is not None