From 0ca67233788038d0d00dfa4c3cd4de8e707f5b85 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 22 Mar 2023 08:36:36 -0400 Subject: [PATCH] Allow passing binary to the WS connection (#89882) * Allow passing binary to the WS connection * Expand test coverage * Test non-existing handler * Allow signaling end of stream using empty payloads * Store handlers in a list * Handle binary handlers raising exceptions --- .../components/websocket_api/connection.py | 60 ++++++++++++- .../components/websocket_api/http.py | 9 ++ .../websocket_api/test_connection.py | 24 ++++++ tests/components/websocket_api/test_http.py | 85 ++++++++++++++++++- 4 files changed, 175 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 08d05314521..f91cc3a827a 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -25,6 +25,9 @@ current_connection = ContextVar["ActiveConnection | None"]( "current_connection", default=None ) +MessageHandler = Callable[[HomeAssistant, "ActiveConnection", dict[str, Any]], None] +BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None] + class ActiveConnection: """Handle an active websocket client connection.""" @@ -46,7 +49,10 @@ class ActiveConnection: self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 self.supported_features: dict[str, float] = {} - self.handlers = self.hass.data[const.DOMAIN] + self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[ + const.DOMAIN + ] + self.binary_handlers: list[BinaryHandler | None] = [] current_connection.set(self) def get_description(self, request: web.Request | None) -> str: @@ -60,6 +66,38 @@ class ActiveConnection: """Return a context.""" return Context(user_id=self.user.id) + @callback + def async_register_binary_handler( + self, handler: BinaryHandler + ) -> tuple[int, Callable[[], None]]: + """Register a temporary binary handler for this connection. + + Returns a binary handler_id (1 byte) and a callback to unregister the handler. + """ + if len(self.binary_handlers) < 255: + index = len(self.binary_handlers) + self.binary_handlers.append(None) + else: + # Once the list is full, we search for a None entry to reuse. + index = None + for idx, existing in enumerate(self.binary_handlers): + if existing is None: + index = idx + break + + if index is None: + raise RuntimeError("Too many binary handlers registered") + + self.binary_handlers[index] = handler + + @callback + def unsub() -> None: + """Unregister the handler.""" + assert index is not None + self.binary_handlers[index] = None + + return index + 1, unsub + @callback def send_result(self, msg_id: int, result: Any | None = None) -> None: """Send a result message.""" @@ -75,6 +113,26 @@ class ActiveConnection: """Send a error message.""" self.send_message(messages.error_message(msg_id, code, message)) + @callback + def async_handle_binary(self, handler_id: int, payload: bytes) -> None: + """Handle a single incoming binary message.""" + index = handler_id - 1 + if ( + index < 0 + or index >= len(self.binary_handlers) + or (handler := self.binary_handlers[index]) is None + ): + self.logger.error( + "Received binary message for non-existing handler %s", handler_id + ) + return + + try: + handler(self.hass, self, payload) + except Exception: # pylint: disable=broad-except + self.logger.exception("Error handling binary message") + self.binary_handlers[index] = None + @callback def async_handle(self, msg: dict[str, Any]) -> None: """Handle a single incoming message.""" diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index de0b23e4957..75eccc7aba9 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -312,6 +312,15 @@ class WebSocketHandler: if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): break + if msg.type == WSMsgType.BINARY: + if len(msg.data) < 1: + disconnect_warn = "Received invalid binary message." + break + handler = msg.data[0] + payload = msg.data[1:] + connection.async_handle_binary(handler, payload) + continue + if msg.type != WSMsgType.TEXT: disconnect_warn = "Received non-Text message." break diff --git a/tests/components/websocket_api/test_connection.py b/tests/components/websocket_api/test_connection.py index 53baab98b4f..da435d64d58 100644 --- a/tests/components/websocket_api/test_connection.py +++ b/tests/components/websocket_api/test_connection.py @@ -101,3 +101,27 @@ async def test_exception_handling( assert send_messages[0]["error"]["code"] == code assert send_messages[0]["error"]["message"] == err assert log in caplog.text + + +async def test_binary_handler_registration() -> None: + """Test binary handler registration.""" + connection = websocket_api.ActiveConnection( + None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock() + ) + + # One filler to align indexes with prefix numbers + unsubs = [None] + fake_handler = object() + for i in range(255): + prefix, unsub = connection.async_register_binary_handler(fake_handler) + assert prefix == i + 1 + unsubs.append(unsub) + + with pytest.raises(RuntimeError): + connection.async_register_binary_handler(None) + + unsubs[15]() + + # Verify we reuse an unsubscribed prefix + prefix, unsub = connection.async_register_binary_handler(None) + assert prefix == 15 diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index fce6eb428ae..475fbeee765 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -1,13 +1,20 @@ """Test Websocket API http module.""" import asyncio from datetime import timedelta +from typing import Any from unittest.mock import patch from aiohttp import ServerDisconnectedError, WSMsgType, web import pytest -from homeassistant.components.websocket_api import const, http -from homeassistant.core import HomeAssistant +from homeassistant.components.websocket_api import ( + async_register_command, + const, + http, + websocket_command, +) +from homeassistant.components.websocket_api.connection import ActiveConnection +from homeassistant.core import HomeAssistant, callback from homeassistant.util.dt import utcnow from tests.common import async_fire_time_changed @@ -155,3 +162,77 @@ async def test_prepare_fail( await hass_ws_client(hass) assert "Timeout preparing request" in caplog.text + + +async def test_binary_message( + hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture +) -> None: + """Test binary messages.""" + binary_payloads = { + 104: ([], asyncio.Future()), + 105: ([], asyncio.Future()), + } + + # Register a handler + @callback + @websocket_command( + { + "type": "get_binary_message_handler", + } + ) + def get_binary_message_handler( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ): + unsub = None + + @callback + def binary_message_handler( + hass: HomeAssistant, connection: ActiveConnection, payload: bytes + ): + nonlocal unsub + if msg["id"] == 103: + raise ValueError("Boom") + + if payload: + binary_payloads[msg["id"]][0].append(payload) + else: + binary_payloads[msg["id"]][1].set_result( + b"".join(binary_payloads[msg["id"]][0]) + ) + unsub() + + prefix, unsub = connection.async_register_binary_handler(binary_message_handler) + + connection.send_result(msg["id"], {"prefix": prefix}) + + async_register_command(hass, get_binary_message_handler) + + # Register multiple binary handlers + for i in range(101, 106): + await websocket_client.send_json( + {"id": i, "type": "get_binary_message_handler"} + ) + result = await websocket_client.receive_json() + assert result["id"] == i + assert result["type"] == const.TYPE_RESULT + assert result["success"] + assert result["result"]["prefix"] == i - 100 + + # Send message to binary + await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0") + await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3") + await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3") + await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10") + await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4") + await websocket_client.send_bytes((4).to_bytes(1, "big") + b"") + await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5") + await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2") + await websocket_client.send_bytes((5).to_bytes(1, "big") + b"") + + # Verify received + assert await binary_payloads[104][1] == b"test4" + assert await binary_payloads[105][1] == b"test5test5-2" + assert "Error handling binary message" in caplog.text + assert "Received binary message for non-existing handler 0" in caplog.text + assert "Received binary message for non-existing handler 3" in caplog.text + assert "Received binary message for non-existing handler 10" in caplog.text