Support proxy of binary messages from addons to HA (#4605)

* Support proxy of binary messages from addons to HA

* Added tests for proxy

* Move instantiation into init

* Mock close method on server

* Add invalid auth test and remove auth mock
This commit is contained in:
Mike Degatano 2023-10-14 12:07:49 -04:00 committed by GitHub
parent a70f81aa01
commit 012bfd7e6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 209 additions and 11 deletions

View File

@ -6,7 +6,10 @@ import logging
import aiohttp
from aiohttp import web
from aiohttp.client_exceptions import ClientConnectorError
from aiohttp.client_ws import ClientWebSocketResponse
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
from aiohttp.http import WSMessage
from aiohttp.http_websocket import WSMsgType
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
from ..coresys import CoreSysAttributes
@ -114,7 +117,7 @@ class APIProxy(CoreSysAttributes):
body=data, status=client.status, content_type=client.content_type
)
async def _websocket_client(self):
async def _websocket_client(self) -> ClientWebSocketResponse:
"""Initialize a WebSocket API connection."""
url = f"{self.sys_homeassistant.api_url}/api/websocket"
@ -167,6 +170,25 @@ class APIProxy(CoreSysAttributes):
raise APIError()
async def _proxy_message(
self,
read_task: asyncio.Task,
target: web.WebSocketResponse | ClientWebSocketResponse,
) -> None:
"""Proxy a message from client to server or vice versa."""
if read_task.exception():
raise read_task.exception()
msg: WSMessage = read_task.result()
if msg.type == WSMsgType.TEXT:
return await target.send_str(msg.data)
if msg.type == WSMsgType.BINARY:
return await target.send_bytes(msg.data)
raise TypeError(
f"Cannot proxy websocket message of unsupported type: {msg.type}"
)
async def websocket(self, request: web.Request):
"""Initialize a WebSocket API connection."""
if not await self.sys_homeassistant.api.check_api_state():
@ -214,13 +236,13 @@ class APIProxy(CoreSysAttributes):
_LOGGER.info("Home Assistant WebSocket API request running")
try:
client_read = None
server_read = None
client_read: asyncio.Task | None = None
server_read: asyncio.Task | None = None
while not server.closed and not client.closed:
if not client_read:
client_read = self.sys_create_task(client.receive_str())
client_read = self.sys_create_task(client.receive())
if not server_read:
server_read = self.sys_create_task(server.receive_str())
server_read = self.sys_create_task(server.receive())
# wait until data need to be processed
await asyncio.wait(
@ -229,14 +251,12 @@ class APIProxy(CoreSysAttributes):
# server
if server_read.done() and not client.closed:
server_read.exception()
await client.send_str(server_read.result())
await self._proxy_message(server_read, client)
server_read = None
# client
if client_read.done() and not server.closed:
client_read.exception()
await server.send_str(client_read.result())
await self._proxy_message(client_read, server)
client_read = None
except asyncio.CancelledError:
@ -246,9 +266,9 @@ class APIProxy(CoreSysAttributes):
_LOGGER.info("Home Assistant WebSocket API error: %s", err)
finally:
if client_read:
if client_read and not client_read.done():
client_read.cancel()
if server_read:
if server_read and not server_read.done():
server_read.cancel()
# close connections

177
tests/api/test_proxy.py Normal file
View File

@ -0,0 +1,177 @@
"""Test Home Assistant proxy."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Generator
from json import dumps
from typing import Any, cast
from unittest.mock import patch
from aiohttp import ClientWebSocketResponse
from aiohttp.http_websocket import WSMessage, WSMsgType
from aiohttp.test_utils import TestClient
import pytest
from supervisor.addons.addon import Addon
from supervisor.api.proxy import APIProxy
from supervisor.const import ATTR_ACCESS_TOKEN
def id_generator() -> Generator[int, None, None]:
"""Generate IDs for WS messages."""
i = 0
while True:
yield (i := i + 1)
class MockHAClientWebSocket(ClientWebSocketResponse):
"""Protocol for a wrapped ClientWebSocketResponse."""
client: TestClient
send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]]
class MockHAServerWebSocket:
"""Mock of HA Websocket server."""
closed: bool = False
def __init__(self) -> None:
"""Initialize object."""
self.outgoing: asyncio.Queue[WSMessage] = asyncio.Queue()
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
self._id_generator = id_generator()
def receive(self) -> Awaitable[WSMessage]:
"""Receive next message."""
return self.outgoing.get()
def send_str(self, data: str) -> Awaitable[None]:
"""Incoming string message."""
return self.incoming.put(WSMessage(WSMsgType.TEXT, data, None))
def send_bytes(self, data: bytes) -> Awaitable[None]:
"""Incoming string message."""
return self.incoming.put(WSMessage(WSMsgType.BINARY, data, None))
def respond_json(self, data: dict[str, Any]) -> Awaitable[None]:
"""Respond with JSON."""
return self.outgoing.put(
WSMessage(
WSMsgType.TEXT, dumps(data | {"id": next(self._id_generator)}), None
)
)
def respond_bytes(self, data: bytes) -> Awaitable[None]:
"""Respond with binary."""
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))
async def close(self) -> None:
"""Close connection."""
self.closed = True
WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
@pytest.fixture(name="ha_ws_server")
async def fixture_ha_ws_server() -> MockHAServerWebSocket:
"""Mock HA WS server for testing."""
with patch.object(
APIProxy,
"_websocket_client",
return_value=(mock_server := MockHAServerWebSocket()),
):
yield mock_server
@pytest.fixture(name="proxy_ws_client")
def fixture_proxy_ws_client(
api_client: TestClient, ha_ws_server: MockHAServerWebSocket
) -> WebSocketGenerator:
"""Websocket client fixture connected to websocket server."""
async def create_client(auth_token: str) -> MockHAClientWebSocket:
"""Create a websocket client."""
websocket = await api_client.ws_connect("/core/websocket")
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == "auth_required"
await websocket.send_json({"type": "auth", "access_token": auth_token})
auth_ok = await websocket.receive_json()
assert auth_ok["type"] == "auth_ok"
_id_generator = id_generator()
def _send_json_auto_id(data: dict[str, Any]) -> Coroutine[Any, Any, None]:
data["id"] = next(_id_generator)
return websocket.send_json(data)
# wrap in client
wrapped_websocket = cast(MockHAClientWebSocket, websocket)
wrapped_websocket.client = api_client
wrapped_websocket.send_json_auto_id = _send_json_auto_id
return wrapped_websocket
return create_client
async def test_proxy_message(
proxy_ws_client: WebSocketGenerator,
ha_ws_server: MockHAServerWebSocket,
install_addon_ssh: Addon,
):
"""Test proxy a message to and from Home Assistant."""
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
client: MockHAClientWebSocket = await proxy_ws_client(
install_addon_ssh.supervisor_token
)
await client.send_json_auto_id({"hello": "world"})
proxied_msg = await ha_ws_server.incoming.get()
assert proxied_msg.type == WSMsgType.TEXT
assert proxied_msg.data == '{"hello": "world", "id": 1}'
await ha_ws_server.respond_json({"world": "received"})
assert await client.receive_json() == {"world": "received", "id": 1}
assert await client.close()
async def test_proxy_binary_message(
proxy_ws_client: WebSocketGenerator,
ha_ws_server: MockHAServerWebSocket,
install_addon_ssh: Addon,
):
"""Test proxy a binary message to and from Home Assistant."""
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
client: MockHAClientWebSocket = await proxy_ws_client(
install_addon_ssh.supervisor_token
)
await client.send_bytes(b"hello world")
proxied_msg = await ha_ws_server.incoming.get()
assert proxied_msg.type == WSMsgType.BINARY
assert proxied_msg.data == b"hello world"
await ha_ws_server.respond_bytes(b"world received")
assert await client.receive_bytes() == b"world received"
assert await client.close()
@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
async def test_proxy_invalid_auth(
api_client: TestClient, install_addon_example: Addon, auth_token: str
):
"""Test invalid access token or addon with no access."""
install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123"
websocket = await api_client.ws_connect("/core/websocket")
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == "auth_required"
await websocket.send_json({"type": "auth", "access_token": auth_token})
auth_not_ok = await websocket.receive_json()
assert auth_not_ok["type"] == "auth_invalid"
assert auth_not_ok["message"] == "Invalid access"

View File

@ -17,6 +17,7 @@ panel_icon: "mdi:console"
panel_title: Terminal
hassio_api: true
hassio_role: manager
homeassistant_api: true
audio: true
uart: true
ports: