diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 2b146d94472..dc6cc84c09c 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -715,7 +715,7 @@ def handle_supported_features( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Handle setting supported features.""" - connection.supported_features = msg["features"] + connection.set_supported_features(msg["features"]) connection.send_result(msg["id"]) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index f91cc3a827a..1f6fd302968 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -48,6 +48,7 @@ class ActiveConnection: self.refresh_token_id = refresh_token.id self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 + self.can_coalesce = False self.supported_features: dict[str, float] = {} self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[ const.DOMAIN @@ -55,6 +56,11 @@ class ActiveConnection: self.binary_handlers: list[BinaryHandler | None] = [] current_connection.set(self) + def set_supported_features(self, features: dict[str, float]) -> None: + """Set supported features.""" + self.supported_features = features + self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features + def get_description(self, request: web.Request | None) -> str: """Return a description of the connection.""" description = self.user.name or "" diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 75eccc7aba9..7671d0ceaef 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections import deque from collections.abc import Callable from contextlib import suppress import datetime as dt @@ -22,7 +23,6 @@ from .auth import AuthPhase, auth_required_message from .const import ( CANCELLATION_ERRORS, DATA_CONNECTIONS, - FEATURE_COALESCE_MESSAGES, MAX_PENDING_MSG, PENDING_MSG_PEAK, PENDING_MSG_PEAK_TIME, @@ -71,7 +71,6 @@ class WebSocketHandler: self.hass = hass self.request = request self.wsock = web.WebSocketResponse(heartbeat=55) - self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) self._handle_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None self._closing: bool = False @@ -79,6 +78,13 @@ class WebSocketHandler: self._peak_checker_unsub: Callable[[], None] | None = None self.connection: ActiveConnection | None = None + # The WebSocketHandler has a single consumer and path + # to where messages are queued. This allows the implementation + # to use a deque and an asyncio.Future to avoid the overhead of + # an asyncio.Queue. + self._message_queue: deque = deque() + self._ready_future: asyncio.Future[None] | None = None + @property def description(self) -> str: """Return a description of the connection.""" @@ -88,39 +94,52 @@ class WebSocketHandler: async def _writer(self) -> None: """Write outgoing messages.""" - # Exceptions if Socket disconnected or cancelled by connection handler - to_write = self._to_write + # Variables are set locally to avoid lookups in the loop + message_queue = self._message_queue logger = self._logger - wsock = self.wsock + send_str = self.wsock.send_str + loop = self.hass.loop + debug = logger.debug + # Exceptions if Socket disconnected or cancelled by connection handler try: with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): while not self.wsock.closed: - if (process := await to_write.get()) is None: + if (messages_remaining := len(message_queue)) == 0: + self._ready_future = loop.create_future() + await self._ready_future + messages_remaining = len(message_queue) + + # A None message is used to signal the end of the connection + if (process := message_queue.popleft()) is None: return + + messages_remaining -= 1 message = process if isinstance(process, str) else process() + if ( - to_write.empty() + not messages_remaining or not self.connection - or FEATURE_COALESCE_MESSAGES - not in self.connection.supported_features + or not self.connection.can_coalesce ): - logger.debug("Sending %s", message) - await wsock.send_str(message) + debug("Sending %s", message) + await send_str(message) continue messages: list[str] = [message] - while not to_write.empty(): - if (process := to_write.get_nowait()) is None: + while messages_remaining: + # A None message is used to signal the end of the connection + if (process := message_queue.popleft()) is None: return messages.append( process if isinstance(process, str) else process() ) + messages_remaining -= 1 coalesced_messages = "[" + ",".join(messages) + "]" - logger.debug("Sending %s", coalesced_messages) - await wsock.send_str(coalesced_messages) + debug("Sending %s", coalesced_messages) + await send_str(coalesced_messages) finally: - # Clean up the peaker checker when we shut down the writer + # Clean up the peak checker when we shut down the writer self._cancel_peak_checker() @callback @@ -146,11 +165,9 @@ class WebSocketHandler: if isinstance(message, dict): message = message_to_json(message) - to_write = self._to_write - - try: - to_write.put_nowait(message) - except asyncio.QueueFull: + message_queue = self._message_queue + queue_size_before_add = len(message_queue) + if queue_size_before_add >= MAX_PENDING_MSG: self._logger.error( ( "%s: Client unable to keep up with pending messages. Reached %s pending" @@ -162,10 +179,15 @@ class WebSocketHandler: message, ) self._cancel() + return + + message_queue.append(message) + if self._ready_future and not self._ready_future.done(): + self._ready_future.set_result(None) peak_checker_active = self._peak_checker_unsub is not None - if to_write.qsize() < PENDING_MSG_PEAK: + if queue_size_before_add <= PENDING_MSG_PEAK: if peak_checker_active: self._cancel_peak_checker() return @@ -180,7 +202,7 @@ class WebSocketHandler: """Check that we are no longer above the write peak.""" self._peak_checker_unsub = None - if self._to_write.qsize() < PENDING_MSG_PEAK: + if len(self._message_queue) < PENDING_MSG_PEAK: return self._logger.error( @@ -199,6 +221,7 @@ class WebSocketHandler: def _cancel(self) -> None: """Cancel the connection.""" self._closing = True + self._cancel_peak_checker() if self._handle_task is not None: self._handle_task.cancel() if self._writer_task is not None: @@ -356,14 +379,14 @@ class WebSocketHandler: self._closing = True + self._message_queue.append(None) + if self._ready_future and not self._ready_future.done(): + self._ready_future.set_result(None) + try: - self._to_write.put_nowait(None) # Make sure all error messages are written before closing await self._writer_task await wsock.close() - except asyncio.QueueFull: # can be raised by put_nowait - self._writer_task.cancel() - finally: if disconnect_warn is None: self._logger.debug("Disconnected") diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index 475fbeee765..8e47e7fca2e 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -1,7 +1,7 @@ """Test Websocket API http module.""" import asyncio from datetime import timedelta -from typing import Any +from typing import Any, cast from unittest.mock import patch from aiohttp import ServerDisconnectedError, WSMsgType, web @@ -53,12 +53,12 @@ async def test_pending_msg_peak( ) -> None: """Test pending msg overflow command.""" orig_handler = http.WebSocketHandler - instance = None + setup_instance: http.WebSocketHandler | None = None def instantiate_handler(*args): - nonlocal instance - instance = orig_handler(*args) - return instance + nonlocal setup_instance + setup_instance = orig_handler(*args) + return setup_instance with patch( "homeassistant.components.websocket_api.http.WebSocketHandler", @@ -66,12 +66,11 @@ async def test_pending_msg_peak( ): websocket_client = await hass_ws_client() - # Kill writer task and fill queue past peak - for _ in range(5): - instance._to_write.put_nowait(None) + instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance) - # Trigger the peak check - instance._send_message({}) + # Fill the queue past the allowed peak + for _ in range(10): + instance._send_message({}) async_fire_time_changed( hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1) @@ -79,8 +78,54 @@ async def test_pending_msg_peak( msg = await websocket_client.receive() assert msg.type == WSMsgType.close - assert "Client unable to keep up with pending messages" in caplog.text + assert "Stayed over 5 for 5 seconds" + + +async def test_pending_msg_peak_recovery( + hass: HomeAssistant, + mock_low_peak, + hass_ws_client: WebSocketGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test pending msg nears the peak but recovers.""" + orig_handler = http.WebSocketHandler + setup_instance: http.WebSocketHandler | None = None + + def instantiate_handler(*args): + nonlocal setup_instance + setup_instance = orig_handler(*args) + return setup_instance + + with patch( + "homeassistant.components.websocket_api.http.WebSocketHandler", + instantiate_handler, + ): + websocket_client = await hass_ws_client() + + instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance) + + # Make sure the call later is started + for _ in range(10): + instance._send_message({}) + + for _ in range(10): + msg = await websocket_client.receive() + assert msg.type == WSMsgType.TEXT + + instance._send_message({}) + msg = await websocket_client.receive() + assert msg.type == WSMsgType.TEXT + + # Cleanly shutdown + instance._send_message({}) + instance._handle_task.cancel() + + msg = await websocket_client.receive() + assert msg.type == WSMsgType.TEXT + msg = await websocket_client.receive() + assert msg.type == WSMsgType.close + assert "Client unable to keep up with pending messages" not in caplog.text async def test_pending_msg_peak_but_does_not_overflow( @@ -91,12 +136,12 @@ async def test_pending_msg_peak_but_does_not_overflow( ) -> None: """Test pending msg hits the low peak but recovers and does not overflow.""" orig_handler = http.WebSocketHandler - instance: http.WebSocketHandler | None = None + setup_instance: http.WebSocketHandler | None = None def instantiate_handler(*args): - nonlocal instance - instance = orig_handler(*args) - return instance + nonlocal setup_instance + setup_instance = orig_handler(*args) + return setup_instance with patch( "homeassistant.components.websocket_api.http.WebSocketHandler", @@ -104,18 +149,17 @@ async def test_pending_msg_peak_but_does_not_overflow( ): websocket_client = await hass_ws_client() - assert instance is not None + instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance) # Kill writer task and fill queue past peak for _ in range(5): - instance._to_write.put_nowait(None) + instance._message_queue.append(None) # Trigger the peak check instance._send_message({}) # Clear the queue - while instance._to_write.qsize() > 0: - instance._to_write.get_nowait() + instance._message_queue.clear() # Trigger the peak clear instance._send_message({})