mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-07-19 07:06:30 +00:00
Improve Home Assistant Core WebSocket proxy implementation (#5790)
* Improve Home Assistant Core WebSocket proxy implementation This change removes unnecessary task creation for every WebSocket message and instead creates just two tasks, one for each direction. This improves performance by about factor of 3 when measuring 1000 WebSocket requests to Core (from ~530ms to ~160ms). While at it, also handle all WebSocket message related to closing the WebSocket and report all other errors as warnings instead of just info. * Improve logging and error handling * Add WS client error test case * Use asyncio.gather directly * Use asyncio.wait to handle exceptions gracefully * Drop cancellation handling and correctly wait for the other proxy task
This commit is contained in:
parent
0a684bdb12
commit
8fe17d9270
@ -10,10 +10,11 @@ from aiohttp import WSMessageTypeError, web
|
|||||||
from aiohttp.client_exceptions import ClientConnectorError
|
from aiohttp.client_exceptions import ClientConnectorError
|
||||||
from aiohttp.client_ws import ClientWebSocketResponse
|
from aiohttp.client_ws import ClientWebSocketResponse
|
||||||
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
|
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
|
||||||
from aiohttp.http import WSMessage
|
|
||||||
from aiohttp.http_websocket import WSMsgType
|
from aiohttp.http_websocket import WSMsgType
|
||||||
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
|
||||||
|
|
||||||
|
from supervisor.utils.logging import AddonLoggerAdapter
|
||||||
|
|
||||||
from ..coresys import CoreSysAttributes
|
from ..coresys import CoreSysAttributes
|
||||||
from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError
|
from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError
|
||||||
from ..utils.json import json_dumps
|
from ..utils.json import json_dumps
|
||||||
@ -179,23 +180,39 @@ class APIProxy(CoreSysAttributes):
|
|||||||
|
|
||||||
async def _proxy_message(
|
async def _proxy_message(
|
||||||
self,
|
self,
|
||||||
read_task: asyncio.Task,
|
source: web.WebSocketResponse | ClientWebSocketResponse,
|
||||||
target: web.WebSocketResponse | ClientWebSocketResponse,
|
target: web.WebSocketResponse | ClientWebSocketResponse,
|
||||||
|
logger: AddonLoggerAdapter,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Proxy a message from client to server or vice versa."""
|
"""Proxy a message from client to server or vice versa."""
|
||||||
msg: WSMessage = read_task.result()
|
while not source.closed and not target.closed:
|
||||||
match msg.type:
|
msg = await source.receive()
|
||||||
case WSMsgType.TEXT:
|
match msg.type:
|
||||||
await target.send_str(msg.data)
|
case WSMsgType.TEXT:
|
||||||
case WSMsgType.BINARY:
|
await target.send_str(msg.data)
|
||||||
await target.send_bytes(msg.data)
|
case WSMsgType.BINARY:
|
||||||
case WSMsgType.CLOSE:
|
await target.send_bytes(msg.data)
|
||||||
_LOGGER.debug("Received close message from WebSocket.")
|
case WSMsgType.CLOSE | WSMsgType.CLOSED:
|
||||||
await target.close()
|
logger.debug(
|
||||||
case _:
|
"Received WebSocket message type %r from %s.",
|
||||||
raise TypeError(
|
msg.type,
|
||||||
f"Cannot proxy websocket message of unsupported type: {msg.type}"
|
"add-on" if type(source) is web.WebSocketResponse else "Core",
|
||||||
)
|
)
|
||||||
|
await target.close()
|
||||||
|
case WSMsgType.CLOSING:
|
||||||
|
pass
|
||||||
|
case WSMsgType.ERROR:
|
||||||
|
logger.warning(
|
||||||
|
"Error WebSocket message received while proxying: %r", msg.data
|
||||||
|
)
|
||||||
|
await target.close(code=source.close_code)
|
||||||
|
case _:
|
||||||
|
logger.warning(
|
||||||
|
"Cannot proxy WebSocket message of unsupported type: %r",
|
||||||
|
msg.type,
|
||||||
|
)
|
||||||
|
await source.close()
|
||||||
|
await target.close()
|
||||||
|
|
||||||
async def websocket(self, request: web.Request):
|
async def websocket(self, request: web.Request):
|
||||||
"""Initialize a WebSocket API connection."""
|
"""Initialize a WebSocket API connection."""
|
||||||
@ -255,48 +272,32 @@ class APIProxy(CoreSysAttributes):
|
|||||||
except APIError:
|
except APIError:
|
||||||
return server
|
return server
|
||||||
|
|
||||||
_LOGGER.info("Home Assistant WebSocket API request running")
|
logger = AddonLoggerAdapter(_LOGGER, {"addon_name": addon_name})
|
||||||
try:
|
logger.info("Home Assistant WebSocket API proxy running")
|
||||||
client_read: asyncio.Task | None = None
|
|
||||||
server_read: asyncio.Task | None = None
|
|
||||||
while not server.closed and not client.closed:
|
|
||||||
if not client_read:
|
|
||||||
client_read = self.sys_create_task(client.receive())
|
|
||||||
if not server_read:
|
|
||||||
server_read = self.sys_create_task(server.receive())
|
|
||||||
|
|
||||||
# wait until data need to be processed
|
client_task = self.sys_create_task(self._proxy_message(client, server, logger))
|
||||||
await asyncio.wait(
|
server_task = self.sys_create_task(self._proxy_message(server, client, logger))
|
||||||
[client_read, server_read], return_when=asyncio.FIRST_COMPLETED
|
|
||||||
)
|
|
||||||
|
|
||||||
# server
|
# Typically, this will return with an empty pending set. However, if one of
|
||||||
if server_read.done() and not client.closed:
|
# the directions has an exception, make sure to close both connections and
|
||||||
await self._proxy_message(server_read, client)
|
# wait for the other proxy task to exit gracefully. Using this over try-except
|
||||||
server_read = None
|
# handling makes it easier to wait for the other direction to complete.
|
||||||
|
_, pending = await asyncio.wait(
|
||||||
|
(client_task, server_task), return_when=asyncio.FIRST_EXCEPTION
|
||||||
|
)
|
||||||
|
|
||||||
# client
|
if not client.closed:
|
||||||
if client_read.done() and not server.closed:
|
await client.close()
|
||||||
await self._proxy_message(client_read, server)
|
if not server.closed:
|
||||||
client_read = None
|
await server.close()
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
if pending:
|
||||||
pass
|
_, pending = await asyncio.wait(
|
||||||
|
pending, timeout=10, return_when=asyncio.ALL_COMPLETED
|
||||||
|
)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
logger.critical("WebSocket proxy task: %s did not end gracefully", task)
|
||||||
|
|
||||||
except (RuntimeError, ConnectionError, TypeError) as err:
|
logger.info("Home Assistant WebSocket API closed")
|
||||||
_LOGGER.info("Home Assistant WebSocket API error: %s", err)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if client_read and not client_read.done():
|
|
||||||
client_read.cancel()
|
|
||||||
if server_read and not server_read.done():
|
|
||||||
server_read.cancel()
|
|
||||||
|
|
||||||
# close connections
|
|
||||||
if not client.closed:
|
|
||||||
await client.close()
|
|
||||||
if not server.closed:
|
|
||||||
await server.close()
|
|
||||||
|
|
||||||
_LOGGER.info("Home Assistant WebSocket API for %s closed", addon_name)
|
|
||||||
return server
|
return server
|
||||||
|
@ -8,6 +8,14 @@ import queue
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class AddonLoggerAdapter(logging.LoggerAdapter):
|
||||||
|
"""Logging Adapter which prepends log entries with add-on name."""
|
||||||
|
|
||||||
|
def process(self, msg, kwargs):
|
||||||
|
"""Process the logging message by prepending the add-on name."""
|
||||||
|
return f"[{self.extra['addon_name']}] {msg}", kwargs
|
||||||
|
|
||||||
|
|
||||||
class SupervisorQueueHandler(logging.handlers.QueueHandler):
|
class SupervisorQueueHandler(logging.handlers.QueueHandler):
|
||||||
"""Process the log in another thread."""
|
"""Process the log in another thread."""
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import logging
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from aiohttp import ClientWebSocketResponse
|
from aiohttp import ClientWebSocketResponse, WSCloseCode
|
||||||
from aiohttp.http_websocket import WSMessage, WSMsgType
|
from aiohttp.http_websocket import WSMessage, WSMsgType
|
||||||
from aiohttp.test_utils import TestClient
|
from aiohttp.test_utils import TestClient
|
||||||
import pytest
|
import pytest
|
||||||
@ -37,6 +37,7 @@ class MockHAServerWebSocket:
|
|||||||
"""Mock of HA Websocket server."""
|
"""Mock of HA Websocket server."""
|
||||||
|
|
||||||
closed: bool = False
|
closed: bool = False
|
||||||
|
close_code: int | None = None
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize object."""
|
"""Initialize object."""
|
||||||
@ -44,9 +45,12 @@ class MockHAServerWebSocket:
|
|||||||
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
|
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
|
||||||
self._id_generator = id_generator()
|
self._id_generator = id_generator()
|
||||||
|
|
||||||
def receive(self) -> Awaitable[WSMessage]:
|
async def receive(self) -> WSMessage:
|
||||||
"""Receive next message."""
|
"""Receive next message."""
|
||||||
return self.outgoing.get()
|
try:
|
||||||
|
return await self.outgoing.get()
|
||||||
|
except asyncio.QueueShutDown:
|
||||||
|
return WSMessage(WSMsgType.CLOSED, None, None)
|
||||||
|
|
||||||
def send_str(self, data: str) -> Awaitable[None]:
|
def send_str(self, data: str) -> Awaitable[None]:
|
||||||
"""Incoming string message."""
|
"""Incoming string message."""
|
||||||
@ -68,9 +72,11 @@ class MockHAServerWebSocket:
|
|||||||
"""Respond with binary."""
|
"""Respond with binary."""
|
||||||
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))
|
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self, code: int = WSCloseCode.OK) -> None:
|
||||||
"""Close connection."""
|
"""Close connection."""
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
self.outgoing.shutdown(immediate=True)
|
||||||
|
self.close_code = code
|
||||||
|
|
||||||
|
|
||||||
WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
|
WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
|
||||||
@ -162,6 +168,26 @@ async def test_proxy_binary_message(
|
|||||||
assert await client.close()
|
assert await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_proxy_large_message(
|
||||||
|
proxy_ws_client: WebSocketGenerator,
|
||||||
|
ha_ws_server: MockHAServerWebSocket,
|
||||||
|
install_addon_ssh: Addon,
|
||||||
|
):
|
||||||
|
"""Test too large message handled gracefully."""
|
||||||
|
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||||
|
client: MockHAClientWebSocket = await proxy_ws_client(
|
||||||
|
install_addon_ssh.supervisor_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test message over size limit of 4MB
|
||||||
|
await client.send_bytes(bytearray(1024 * 1024 * 4))
|
||||||
|
msg = await client.receive()
|
||||||
|
assert msg.type == WSMsgType.CLOSE
|
||||||
|
assert msg.data == WSCloseCode.MESSAGE_TOO_BIG
|
||||||
|
|
||||||
|
assert ha_ws_server.closed
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
|
@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
|
||||||
async def test_proxy_invalid_auth(
|
async def test_proxy_invalid_auth(
|
||||||
api_client: TestClient, install_addon_example: Addon, auth_token: str
|
api_client: TestClient, install_addon_example: Addon, auth_token: str
|
||||||
|
Loading…
x
Reference in New Issue
Block a user