diff --git a/supervisor/api/proxy.py b/supervisor/api/proxy.py index 8f03e8af4..74a9346a8 100644 --- a/supervisor/api/proxy.py +++ b/supervisor/api/proxy.py @@ -6,7 +6,10 @@ import logging import aiohttp from aiohttp import web from aiohttp.client_exceptions import ClientConnectorError +from aiohttp.client_ws import ClientWebSocketResponse from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE +from aiohttp.http import WSMessage +from aiohttp.http_websocket import WSMsgType from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized from ..coresys import CoreSysAttributes @@ -114,7 +117,7 @@ class APIProxy(CoreSysAttributes): body=data, status=client.status, content_type=client.content_type ) - async def _websocket_client(self): + async def _websocket_client(self) -> ClientWebSocketResponse: """Initialize a WebSocket API connection.""" url = f"{self.sys_homeassistant.api_url}/api/websocket" @@ -167,6 +170,25 @@ class APIProxy(CoreSysAttributes): raise APIError() + async def _proxy_message( + self, + read_task: asyncio.Task, + target: web.WebSocketResponse | ClientWebSocketResponse, + ) -> None: + """Proxy a message from client to server or vice versa.""" + if read_task.exception(): + raise read_task.exception() + + msg: WSMessage = read_task.result() + if msg.type == WSMsgType.TEXT: + return await target.send_str(msg.data) + if msg.type == WSMsgType.BINARY: + return await target.send_bytes(msg.data) + + raise TypeError( + f"Cannot proxy websocket message of unsupported type: {msg.type}" + ) + async def websocket(self, request: web.Request): """Initialize a WebSocket API connection.""" if not await self.sys_homeassistant.api.check_api_state(): @@ -214,13 +236,13 @@ class APIProxy(CoreSysAttributes): _LOGGER.info("Home Assistant WebSocket API request running") try: - client_read = None - server_read = None + client_read: asyncio.Task | None = None + server_read: asyncio.Task | None = None while not server.closed and not client.closed: if not client_read: - client_read = self.sys_create_task(client.receive_str()) + client_read = self.sys_create_task(client.receive()) if not server_read: - server_read = self.sys_create_task(server.receive_str()) + server_read = self.sys_create_task(server.receive()) # wait until data need to be processed await asyncio.wait( @@ -229,14 +251,12 @@ class APIProxy(CoreSysAttributes): # server if server_read.done() and not client.closed: - server_read.exception() - await client.send_str(server_read.result()) + await self._proxy_message(server_read, client) server_read = None # client if client_read.done() and not server.closed: - client_read.exception() - await server.send_str(client_read.result()) + await self._proxy_message(client_read, server) client_read = None except asyncio.CancelledError: @@ -246,9 +266,9 @@ class APIProxy(CoreSysAttributes): _LOGGER.info("Home Assistant WebSocket API error: %s", err) finally: - if client_read: + if client_read and not client_read.done(): client_read.cancel() - if server_read: + if server_read and not server_read.done(): server_read.cancel() # close connections diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py new file mode 100644 index 000000000..0da8675cc --- /dev/null +++ b/tests/api/test_proxy.py @@ -0,0 +1,177 @@ +"""Test Home Assistant proxy.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable, Coroutine, Generator +from json import dumps +from typing import Any, cast +from unittest.mock import patch + +from aiohttp import ClientWebSocketResponse +from aiohttp.http_websocket import WSMessage, WSMsgType +from aiohttp.test_utils import TestClient +import pytest + +from supervisor.addons.addon import Addon +from supervisor.api.proxy import APIProxy +from supervisor.const import ATTR_ACCESS_TOKEN + + +def id_generator() -> Generator[int, None, None]: + """Generate IDs for WS messages.""" + i = 0 + while True: + yield (i := i + 1) + + +class MockHAClientWebSocket(ClientWebSocketResponse): + """Protocol for a wrapped ClientWebSocketResponse.""" + + client: TestClient + send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]] + + +class MockHAServerWebSocket: + """Mock of HA Websocket server.""" + + closed: bool = False + + def __init__(self) -> None: + """Initialize object.""" + self.outgoing: asyncio.Queue[WSMessage] = asyncio.Queue() + self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue() + self._id_generator = id_generator() + + def receive(self) -> Awaitable[WSMessage]: + """Receive next message.""" + return self.outgoing.get() + + def send_str(self, data: str) -> Awaitable[None]: + """Incoming string message.""" + return self.incoming.put(WSMessage(WSMsgType.TEXT, data, None)) + + def send_bytes(self, data: bytes) -> Awaitable[None]: + """Incoming string message.""" + return self.incoming.put(WSMessage(WSMsgType.BINARY, data, None)) + + def respond_json(self, data: dict[str, Any]) -> Awaitable[None]: + """Respond with JSON.""" + return self.outgoing.put( + WSMessage( + WSMsgType.TEXT, dumps(data | {"id": next(self._id_generator)}), None + ) + ) + + def respond_bytes(self, data: bytes) -> Awaitable[None]: + """Respond with binary.""" + return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None)) + + async def close(self) -> None: + """Close connection.""" + self.closed = True + + +WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]] + + +@pytest.fixture(name="ha_ws_server") +async def fixture_ha_ws_server() -> MockHAServerWebSocket: + """Mock HA WS server for testing.""" + with patch.object( + APIProxy, + "_websocket_client", + return_value=(mock_server := MockHAServerWebSocket()), + ): + yield mock_server + + +@pytest.fixture(name="proxy_ws_client") +def fixture_proxy_ws_client( + api_client: TestClient, ha_ws_server: MockHAServerWebSocket +) -> WebSocketGenerator: + """Websocket client fixture connected to websocket server.""" + + async def create_client(auth_token: str) -> MockHAClientWebSocket: + """Create a websocket client.""" + websocket = await api_client.ws_connect("/core/websocket") + auth_resp = await websocket.receive_json() + assert auth_resp["type"] == "auth_required" + await websocket.send_json({"type": "auth", "access_token": auth_token}) + + auth_ok = await websocket.receive_json() + assert auth_ok["type"] == "auth_ok" + + _id_generator = id_generator() + + def _send_json_auto_id(data: dict[str, Any]) -> Coroutine[Any, Any, None]: + data["id"] = next(_id_generator) + return websocket.send_json(data) + + # wrap in client + wrapped_websocket = cast(MockHAClientWebSocket, websocket) + wrapped_websocket.client = api_client + wrapped_websocket.send_json_auto_id = _send_json_auto_id + return wrapped_websocket + + return create_client + + +async def test_proxy_message( + proxy_ws_client: WebSocketGenerator, + ha_ws_server: MockHAServerWebSocket, + install_addon_ssh: Addon, +): + """Test proxy a message to and from Home Assistant.""" + install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123" + client: MockHAClientWebSocket = await proxy_ws_client( + install_addon_ssh.supervisor_token + ) + + await client.send_json_auto_id({"hello": "world"}) + proxied_msg = await ha_ws_server.incoming.get() + assert proxied_msg.type == WSMsgType.TEXT + assert proxied_msg.data == '{"hello": "world", "id": 1}' + + await ha_ws_server.respond_json({"world": "received"}) + assert await client.receive_json() == {"world": "received", "id": 1} + + assert await client.close() + + +async def test_proxy_binary_message( + proxy_ws_client: WebSocketGenerator, + ha_ws_server: MockHAServerWebSocket, + install_addon_ssh: Addon, +): + """Test proxy a binary message to and from Home Assistant.""" + install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123" + client: MockHAClientWebSocket = await proxy_ws_client( + install_addon_ssh.supervisor_token + ) + + await client.send_bytes(b"hello world") + proxied_msg = await ha_ws_server.incoming.get() + assert proxied_msg.type == WSMsgType.BINARY + assert proxied_msg.data == b"hello world" + + await ha_ws_server.respond_bytes(b"world received") + assert await client.receive_bytes() == b"world received" + + assert await client.close() + + +@pytest.mark.parametrize("auth_token", ["abc123", "bad"]) +async def test_proxy_invalid_auth( + api_client: TestClient, install_addon_example: Addon, auth_token: str +): + """Test invalid access token or addon with no access.""" + install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123" + websocket = await api_client.ws_connect("/core/websocket") + auth_resp = await websocket.receive_json() + assert auth_resp["type"] == "auth_required" + await websocket.send_json({"type": "auth", "access_token": auth_token}) + + auth_not_ok = await websocket.receive_json() + assert auth_not_ok["type"] == "auth_invalid" + assert auth_not_ok["message"] == "Invalid access" diff --git a/tests/fixtures/addons/local/ssh/config.yaml b/tests/fixtures/addons/local/ssh/config.yaml index 3e67c70d5..38dac441c 100644 --- a/tests/fixtures/addons/local/ssh/config.yaml +++ b/tests/fixtures/addons/local/ssh/config.yaml @@ -17,6 +17,7 @@ panel_icon: "mdi:console" panel_title: Terminal hassio_api: true hassio_role: manager +homeassistant_api: true audio: true uart: true ports: