Improve Home Assistant Core WebSocket proxy implementation (#5790)

* Improve Home Assistant Core WebSocket proxy implementation

This change removes unnecessary task creation for every WebSocket
message and instead creates just two tasks, one for each direction.
This improves performance by about factor of 3 when measuring 1000
WebSocket requests to Core (from ~530ms to ~160ms).

While at it, also handle all WebSocket message related to closing the
WebSocket and report all other errors as warnings instead of just info.

* Improve logging and error handling

* Add WS client error test case

* Use asyncio.gather directly

* Use asyncio.wait to handle exceptions gracefully

* Drop cancellation handling and correctly wait for the other proxy task
This commit is contained in:
Stefan Agner 2025-03-28 10:35:49 +01:00 committed by GitHub
parent 0a684bdb12
commit 8fe17d9270
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 58 deletions

View File

@ -10,10 +10,11 @@ from aiohttp import WSMessageTypeError, web
from aiohttp.client_exceptions import ClientConnectorError
from aiohttp.client_ws import ClientWebSocketResponse
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
from aiohttp.http import WSMessage
from aiohttp.http_websocket import WSMsgType
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
from supervisor.utils.logging import AddonLoggerAdapter
from ..coresys import CoreSysAttributes
from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError
from ..utils.json import json_dumps
@ -179,23 +180,39 @@ class APIProxy(CoreSysAttributes):
async def _proxy_message(
self,
read_task: asyncio.Task,
source: web.WebSocketResponse | ClientWebSocketResponse,
target: web.WebSocketResponse | ClientWebSocketResponse,
logger: AddonLoggerAdapter,
) -> None:
"""Proxy a message from client to server or vice versa."""
msg: WSMessage = read_task.result()
match msg.type:
case WSMsgType.TEXT:
await target.send_str(msg.data)
case WSMsgType.BINARY:
await target.send_bytes(msg.data)
case WSMsgType.CLOSE:
_LOGGER.debug("Received close message from WebSocket.")
await target.close()
case _:
raise TypeError(
f"Cannot proxy websocket message of unsupported type: {msg.type}"
)
while not source.closed and not target.closed:
msg = await source.receive()
match msg.type:
case WSMsgType.TEXT:
await target.send_str(msg.data)
case WSMsgType.BINARY:
await target.send_bytes(msg.data)
case WSMsgType.CLOSE | WSMsgType.CLOSED:
logger.debug(
"Received WebSocket message type %r from %s.",
msg.type,
"add-on" if type(source) is web.WebSocketResponse else "Core",
)
await target.close()
case WSMsgType.CLOSING:
pass
case WSMsgType.ERROR:
logger.warning(
"Error WebSocket message received while proxying: %r", msg.data
)
await target.close(code=source.close_code)
case _:
logger.warning(
"Cannot proxy WebSocket message of unsupported type: %r",
msg.type,
)
await source.close()
await target.close()
async def websocket(self, request: web.Request):
"""Initialize a WebSocket API connection."""
@ -255,48 +272,32 @@ class APIProxy(CoreSysAttributes):
except APIError:
return server
_LOGGER.info("Home Assistant WebSocket API request running")
try:
client_read: asyncio.Task | None = None
server_read: asyncio.Task | None = None
while not server.closed and not client.closed:
if not client_read:
client_read = self.sys_create_task(client.receive())
if not server_read:
server_read = self.sys_create_task(server.receive())
logger = AddonLoggerAdapter(_LOGGER, {"addon_name": addon_name})
logger.info("Home Assistant WebSocket API proxy running")
# wait until data need to be processed
await asyncio.wait(
[client_read, server_read], return_when=asyncio.FIRST_COMPLETED
)
client_task = self.sys_create_task(self._proxy_message(client, server, logger))
server_task = self.sys_create_task(self._proxy_message(server, client, logger))
# server
if server_read.done() and not client.closed:
await self._proxy_message(server_read, client)
server_read = None
# Typically, this will return with an empty pending set. However, if one of
# the directions has an exception, make sure to close both connections and
# wait for the other proxy task to exit gracefully. Using this over try-except
# handling makes it easier to wait for the other direction to complete.
_, pending = await asyncio.wait(
(client_task, server_task), return_when=asyncio.FIRST_EXCEPTION
)
# client
if client_read.done() and not server.closed:
await self._proxy_message(client_read, server)
client_read = None
if not client.closed:
await client.close()
if not server.closed:
await server.close()
except asyncio.CancelledError:
pass
if pending:
_, pending = await asyncio.wait(
pending, timeout=10, return_when=asyncio.ALL_COMPLETED
)
for task in pending:
task.cancel()
logger.critical("WebSocket proxy task: %s did not end gracefully", task)
except (RuntimeError, ConnectionError, TypeError) as err:
_LOGGER.info("Home Assistant WebSocket API error: %s", err)
finally:
if client_read and not client_read.done():
client_read.cancel()
if server_read and not server_read.done():
server_read.cancel()
# close connections
if not client.closed:
await client.close()
if not server.closed:
await server.close()
_LOGGER.info("Home Assistant WebSocket API for %s closed", addon_name)
logger.info("Home Assistant WebSocket API closed")
return server

View File

@ -8,6 +8,14 @@ import queue
from typing import Any
class AddonLoggerAdapter(logging.LoggerAdapter):
"""Logging Adapter which prepends log entries with add-on name."""
def process(self, msg, kwargs):
"""Process the logging message by prepending the add-on name."""
return f"[{self.extra['addon_name']}] {msg}", kwargs
class SupervisorQueueHandler(logging.handlers.QueueHandler):
"""Process the log in another thread."""

View File

@ -9,7 +9,7 @@ import logging
from typing import Any, cast
from unittest.mock import patch
from aiohttp import ClientWebSocketResponse
from aiohttp import ClientWebSocketResponse, WSCloseCode
from aiohttp.http_websocket import WSMessage, WSMsgType
from aiohttp.test_utils import TestClient
import pytest
@ -37,6 +37,7 @@ class MockHAServerWebSocket:
"""Mock of HA Websocket server."""
closed: bool = False
close_code: int | None = None
def __init__(self) -> None:
"""Initialize object."""
@ -44,9 +45,12 @@ class MockHAServerWebSocket:
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
self._id_generator = id_generator()
def receive(self) -> Awaitable[WSMessage]:
async def receive(self) -> WSMessage:
"""Receive next message."""
return self.outgoing.get()
try:
return await self.outgoing.get()
except asyncio.QueueShutDown:
return WSMessage(WSMsgType.CLOSED, None, None)
def send_str(self, data: str) -> Awaitable[None]:
"""Incoming string message."""
@ -68,9 +72,11 @@ class MockHAServerWebSocket:
"""Respond with binary."""
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))
async def close(self) -> None:
async def close(self, code: int = WSCloseCode.OK) -> None:
"""Close connection."""
self.closed = True
self.outgoing.shutdown(immediate=True)
self.close_code = code
WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
@ -162,6 +168,26 @@ async def test_proxy_binary_message(
assert await client.close()
async def test_proxy_large_message(
proxy_ws_client: WebSocketGenerator,
ha_ws_server: MockHAServerWebSocket,
install_addon_ssh: Addon,
):
"""Test too large message handled gracefully."""
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
client: MockHAClientWebSocket = await proxy_ws_client(
install_addon_ssh.supervisor_token
)
# Test message over size limit of 4MB
await client.send_bytes(bytearray(1024 * 1024 * 4))
msg = await client.receive()
assert msg.type == WSMsgType.CLOSE
assert msg.data == WSCloseCode.MESSAGE_TOO_BIG
assert ha_ws_server.closed
@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
async def test_proxy_invalid_auth(
api_client: TestClient, install_addon_example: Addon, auth_token: str