diff --git a/supervisor/api/proxy.py b/supervisor/api/proxy.py index d47efa8eb..26e85bdef 100644 --- a/supervisor/api/proxy.py +++ b/supervisor/api/proxy.py @@ -10,10 +10,11 @@ from aiohttp import WSMessageTypeError, web from aiohttp.client_exceptions import ClientConnectorError from aiohttp.client_ws import ClientWebSocketResponse from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE -from aiohttp.http import WSMessage from aiohttp.http_websocket import WSMsgType from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized +from supervisor.utils.logging import AddonLoggerAdapter + from ..coresys import CoreSysAttributes from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError from ..utils.json import json_dumps @@ -179,23 +180,39 @@ class APIProxy(CoreSysAttributes): async def _proxy_message( self, - read_task: asyncio.Task, + source: web.WebSocketResponse | ClientWebSocketResponse, target: web.WebSocketResponse | ClientWebSocketResponse, + logger: AddonLoggerAdapter, ) -> None: """Proxy a message from client to server or vice versa.""" - msg: WSMessage = read_task.result() - match msg.type: - case WSMsgType.TEXT: - await target.send_str(msg.data) - case WSMsgType.BINARY: - await target.send_bytes(msg.data) - case WSMsgType.CLOSE: - _LOGGER.debug("Received close message from WebSocket.") - await target.close() - case _: - raise TypeError( - f"Cannot proxy websocket message of unsupported type: {msg.type}" - ) + while not source.closed and not target.closed: + msg = await source.receive() + match msg.type: + case WSMsgType.TEXT: + await target.send_str(msg.data) + case WSMsgType.BINARY: + await target.send_bytes(msg.data) + case WSMsgType.CLOSE | WSMsgType.CLOSED: + logger.debug( + "Received WebSocket message type %r from %s.", + msg.type, + "add-on" if type(source) is web.WebSocketResponse else "Core", + ) + await target.close() + case WSMsgType.CLOSING: + pass + case WSMsgType.ERROR: + logger.warning( + "Error WebSocket message received while proxying: %r", msg.data + ) + await target.close(code=source.close_code) + case _: + logger.warning( + "Cannot proxy WebSocket message of unsupported type: %r", + msg.type, + ) + await source.close() + await target.close() async def websocket(self, request: web.Request): """Initialize a WebSocket API connection.""" @@ -255,48 +272,32 @@ class APIProxy(CoreSysAttributes): except APIError: return server - _LOGGER.info("Home Assistant WebSocket API request running") - try: - client_read: asyncio.Task | None = None - server_read: asyncio.Task | None = None - while not server.closed and not client.closed: - if not client_read: - client_read = self.sys_create_task(client.receive()) - if not server_read: - server_read = self.sys_create_task(server.receive()) + logger = AddonLoggerAdapter(_LOGGER, {"addon_name": addon_name}) + logger.info("Home Assistant WebSocket API proxy running") - # wait until data need to be processed - await asyncio.wait( - [client_read, server_read], return_when=asyncio.FIRST_COMPLETED - ) + client_task = self.sys_create_task(self._proxy_message(client, server, logger)) + server_task = self.sys_create_task(self._proxy_message(server, client, logger)) - # server - if server_read.done() and not client.closed: - await self._proxy_message(server_read, client) - server_read = None + # Typically, this will return with an empty pending set. However, if one of + # the directions has an exception, make sure to close both connections and + # wait for the other proxy task to exit gracefully. Using this over try-except + # handling makes it easier to wait for the other direction to complete. + _, pending = await asyncio.wait( + (client_task, server_task), return_when=asyncio.FIRST_EXCEPTION + ) - # client - if client_read.done() and not server.closed: - await self._proxy_message(client_read, server) - client_read = None + if not client.closed: + await client.close() + if not server.closed: + await server.close() - except asyncio.CancelledError: - pass + if pending: + _, pending = await asyncio.wait( + pending, timeout=10, return_when=asyncio.ALL_COMPLETED + ) + for task in pending: + task.cancel() + logger.critical("WebSocket proxy task: %s did not end gracefully", task) - except (RuntimeError, ConnectionError, TypeError) as err: - _LOGGER.info("Home Assistant WebSocket API error: %s", err) - - finally: - if client_read and not client_read.done(): - client_read.cancel() - if server_read and not server_read.done(): - server_read.cancel() - - # close connections - if not client.closed: - await client.close() - if not server.closed: - await server.close() - - _LOGGER.info("Home Assistant WebSocket API for %s closed", addon_name) + logger.info("Home Assistant WebSocket API closed") return server diff --git a/supervisor/utils/logging.py b/supervisor/utils/logging.py index 344b46e44..e43d8c45d 100644 --- a/supervisor/utils/logging.py +++ b/supervisor/utils/logging.py @@ -8,6 +8,14 @@ import queue from typing import Any +class AddonLoggerAdapter(logging.LoggerAdapter): + """Logging Adapter which prepends log entries with add-on name.""" + + def process(self, msg, kwargs): + """Process the logging message by prepending the add-on name.""" + return f"[{self.extra['addon_name']}] {msg}", kwargs + + class SupervisorQueueHandler(logging.handlers.QueueHandler): """Process the log in another thread.""" diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py index fb9c0b53b..d031f7b71 100644 --- a/tests/api/test_proxy.py +++ b/tests/api/test_proxy.py @@ -9,7 +9,7 @@ import logging from typing import Any, cast from unittest.mock import patch -from aiohttp import ClientWebSocketResponse +from aiohttp import ClientWebSocketResponse, WSCloseCode from aiohttp.http_websocket import WSMessage, WSMsgType from aiohttp.test_utils import TestClient import pytest @@ -37,6 +37,7 @@ class MockHAServerWebSocket: """Mock of HA Websocket server.""" closed: bool = False + close_code: int | None = None def __init__(self) -> None: """Initialize object.""" @@ -44,9 +45,12 @@ class MockHAServerWebSocket: self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue() self._id_generator = id_generator() - def receive(self) -> Awaitable[WSMessage]: + async def receive(self) -> WSMessage: """Receive next message.""" - return self.outgoing.get() + try: + return await self.outgoing.get() + except asyncio.QueueShutDown: + return WSMessage(WSMsgType.CLOSED, None, None) def send_str(self, data: str) -> Awaitable[None]: """Incoming string message.""" @@ -68,9 +72,11 @@ class MockHAServerWebSocket: """Respond with binary.""" return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None)) - async def close(self) -> None: + async def close(self, code: int = WSCloseCode.OK) -> None: """Close connection.""" self.closed = True + self.outgoing.shutdown(immediate=True) + self.close_code = code WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]] @@ -162,6 +168,26 @@ async def test_proxy_binary_message( assert await client.close() +async def test_proxy_large_message( + proxy_ws_client: WebSocketGenerator, + ha_ws_server: MockHAServerWebSocket, + install_addon_ssh: Addon, +): + """Test too large message handled gracefully.""" + install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123" + client: MockHAClientWebSocket = await proxy_ws_client( + install_addon_ssh.supervisor_token + ) + + # Test message over size limit of 4MB + await client.send_bytes(bytearray(1024 * 1024 * 4)) + msg = await client.receive() + assert msg.type == WSMsgType.CLOSE + assert msg.data == WSCloseCode.MESSAGE_TOO_BIG + + assert ha_ws_server.closed + + @pytest.mark.parametrize("auth_token", ["abc123", "bad"]) async def test_proxy_invalid_auth( api_client: TestClient, install_addon_example: Addon, auth_token: str