mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-07-13 12:16:29 +00:00
Support proxy of binary messages from addons to HA (#4605)
* Support proxy of binary messages from addons to HA * Added tests for proxy * Move instantiation into init * Mock close method on server * Add invalid auth test and remove auth mock
This commit is contained in:
parent
a70f81aa01
commit
012bfd7e6c
@ -6,7 +6,10 @@ import logging
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from aiohttp.client_exceptions import ClientConnectorError
|
||||
from aiohttp.client_ws import ClientWebSocketResponse
|
||||
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
|
||||
from aiohttp.http import WSMessage
|
||||
from aiohttp.http_websocket import WSMsgType
|
||||
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
|
||||
|
||||
from ..coresys import CoreSysAttributes
|
||||
@ -114,7 +117,7 @@ class APIProxy(CoreSysAttributes):
|
||||
body=data, status=client.status, content_type=client.content_type
|
||||
)
|
||||
|
||||
async def _websocket_client(self):
|
||||
async def _websocket_client(self) -> ClientWebSocketResponse:
|
||||
"""Initialize a WebSocket API connection."""
|
||||
url = f"{self.sys_homeassistant.api_url}/api/websocket"
|
||||
|
||||
@ -167,6 +170,25 @@ class APIProxy(CoreSysAttributes):
|
||||
|
||||
raise APIError()
|
||||
|
||||
async def _proxy_message(
|
||||
self,
|
||||
read_task: asyncio.Task,
|
||||
target: web.WebSocketResponse | ClientWebSocketResponse,
|
||||
) -> None:
|
||||
"""Proxy a message from client to server or vice versa."""
|
||||
if read_task.exception():
|
||||
raise read_task.exception()
|
||||
|
||||
msg: WSMessage = read_task.result()
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
return await target.send_str(msg.data)
|
||||
if msg.type == WSMsgType.BINARY:
|
||||
return await target.send_bytes(msg.data)
|
||||
|
||||
raise TypeError(
|
||||
f"Cannot proxy websocket message of unsupported type: {msg.type}"
|
||||
)
|
||||
|
||||
async def websocket(self, request: web.Request):
|
||||
"""Initialize a WebSocket API connection."""
|
||||
if not await self.sys_homeassistant.api.check_api_state():
|
||||
@ -214,13 +236,13 @@ class APIProxy(CoreSysAttributes):
|
||||
|
||||
_LOGGER.info("Home Assistant WebSocket API request running")
|
||||
try:
|
||||
client_read = None
|
||||
server_read = None
|
||||
client_read: asyncio.Task | None = None
|
||||
server_read: asyncio.Task | None = None
|
||||
while not server.closed and not client.closed:
|
||||
if not client_read:
|
||||
client_read = self.sys_create_task(client.receive_str())
|
||||
client_read = self.sys_create_task(client.receive())
|
||||
if not server_read:
|
||||
server_read = self.sys_create_task(server.receive_str())
|
||||
server_read = self.sys_create_task(server.receive())
|
||||
|
||||
# wait until data need to be processed
|
||||
await asyncio.wait(
|
||||
@ -229,14 +251,12 @@ class APIProxy(CoreSysAttributes):
|
||||
|
||||
# server
|
||||
if server_read.done() and not client.closed:
|
||||
server_read.exception()
|
||||
await client.send_str(server_read.result())
|
||||
await self._proxy_message(server_read, client)
|
||||
server_read = None
|
||||
|
||||
# client
|
||||
if client_read.done() and not server.closed:
|
||||
client_read.exception()
|
||||
await server.send_str(client_read.result())
|
||||
await self._proxy_message(client_read, server)
|
||||
client_read = None
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@ -246,9 +266,9 @@ class APIProxy(CoreSysAttributes):
|
||||
_LOGGER.info("Home Assistant WebSocket API error: %s", err)
|
||||
|
||||
finally:
|
||||
if client_read:
|
||||
if client_read and not client_read.done():
|
||||
client_read.cancel()
|
||||
if server_read:
|
||||
if server_read and not server_read.done():
|
||||
server_read.cancel()
|
||||
|
||||
# close connections
|
||||
|
177
tests/api/test_proxy.py
Normal file
177
tests/api/test_proxy.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Test Home Assistant proxy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Generator
|
||||
from json import dumps
|
||||
from typing import Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import ClientWebSocketResponse
|
||||
from aiohttp.http_websocket import WSMessage, WSMsgType
|
||||
from aiohttp.test_utils import TestClient
|
||||
import pytest
|
||||
|
||||
from supervisor.addons.addon import Addon
|
||||
from supervisor.api.proxy import APIProxy
|
||||
from supervisor.const import ATTR_ACCESS_TOKEN
|
||||
|
||||
|
||||
def id_generator() -> Generator[int, None, None]:
|
||||
"""Generate IDs for WS messages."""
|
||||
i = 0
|
||||
while True:
|
||||
yield (i := i + 1)
|
||||
|
||||
|
||||
class MockHAClientWebSocket(ClientWebSocketResponse):
|
||||
"""Protocol for a wrapped ClientWebSocketResponse."""
|
||||
|
||||
client: TestClient
|
||||
send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class MockHAServerWebSocket:
|
||||
"""Mock of HA Websocket server."""
|
||||
|
||||
closed: bool = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize object."""
|
||||
self.outgoing: asyncio.Queue[WSMessage] = asyncio.Queue()
|
||||
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
|
||||
self._id_generator = id_generator()
|
||||
|
||||
def receive(self) -> Awaitable[WSMessage]:
|
||||
"""Receive next message."""
|
||||
return self.outgoing.get()
|
||||
|
||||
def send_str(self, data: str) -> Awaitable[None]:
|
||||
"""Incoming string message."""
|
||||
return self.incoming.put(WSMessage(WSMsgType.TEXT, data, None))
|
||||
|
||||
def send_bytes(self, data: bytes) -> Awaitable[None]:
|
||||
"""Incoming string message."""
|
||||
return self.incoming.put(WSMessage(WSMsgType.BINARY, data, None))
|
||||
|
||||
def respond_json(self, data: dict[str, Any]) -> Awaitable[None]:
|
||||
"""Respond with JSON."""
|
||||
return self.outgoing.put(
|
||||
WSMessage(
|
||||
WSMsgType.TEXT, dumps(data | {"id": next(self._id_generator)}), None
|
||||
)
|
||||
)
|
||||
|
||||
def respond_bytes(self, data: bytes) -> Awaitable[None]:
|
||||
"""Respond with binary."""
|
||||
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close connection."""
|
||||
self.closed = True
|
||||
|
||||
|
||||
WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
|
||||
|
||||
|
||||
@pytest.fixture(name="ha_ws_server")
|
||||
async def fixture_ha_ws_server() -> MockHAServerWebSocket:
|
||||
"""Mock HA WS server for testing."""
|
||||
with patch.object(
|
||||
APIProxy,
|
||||
"_websocket_client",
|
||||
return_value=(mock_server := MockHAServerWebSocket()),
|
||||
):
|
||||
yield mock_server
|
||||
|
||||
|
||||
@pytest.fixture(name="proxy_ws_client")
|
||||
def fixture_proxy_ws_client(
|
||||
api_client: TestClient, ha_ws_server: MockHAServerWebSocket
|
||||
) -> WebSocketGenerator:
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
|
||||
async def create_client(auth_token: str) -> MockHAClientWebSocket:
|
||||
"""Create a websocket client."""
|
||||
websocket = await api_client.ws_connect("/core/websocket")
|
||||
auth_resp = await websocket.receive_json()
|
||||
assert auth_resp["type"] == "auth_required"
|
||||
await websocket.send_json({"type": "auth", "access_token": auth_token})
|
||||
|
||||
auth_ok = await websocket.receive_json()
|
||||
assert auth_ok["type"] == "auth_ok"
|
||||
|
||||
_id_generator = id_generator()
|
||||
|
||||
def _send_json_auto_id(data: dict[str, Any]) -> Coroutine[Any, Any, None]:
|
||||
data["id"] = next(_id_generator)
|
||||
return websocket.send_json(data)
|
||||
|
||||
# wrap in client
|
||||
wrapped_websocket = cast(MockHAClientWebSocket, websocket)
|
||||
wrapped_websocket.client = api_client
|
||||
wrapped_websocket.send_json_auto_id = _send_json_auto_id
|
||||
return wrapped_websocket
|
||||
|
||||
return create_client
|
||||
|
||||
|
||||
async def test_proxy_message(
|
||||
proxy_ws_client: WebSocketGenerator,
|
||||
ha_ws_server: MockHAServerWebSocket,
|
||||
install_addon_ssh: Addon,
|
||||
):
|
||||
"""Test proxy a message to and from Home Assistant."""
|
||||
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||
client: MockHAClientWebSocket = await proxy_ws_client(
|
||||
install_addon_ssh.supervisor_token
|
||||
)
|
||||
|
||||
await client.send_json_auto_id({"hello": "world"})
|
||||
proxied_msg = await ha_ws_server.incoming.get()
|
||||
assert proxied_msg.type == WSMsgType.TEXT
|
||||
assert proxied_msg.data == '{"hello": "world", "id": 1}'
|
||||
|
||||
await ha_ws_server.respond_json({"world": "received"})
|
||||
assert await client.receive_json() == {"world": "received", "id": 1}
|
||||
|
||||
assert await client.close()
|
||||
|
||||
|
||||
async def test_proxy_binary_message(
|
||||
proxy_ws_client: WebSocketGenerator,
|
||||
ha_ws_server: MockHAServerWebSocket,
|
||||
install_addon_ssh: Addon,
|
||||
):
|
||||
"""Test proxy a binary message to and from Home Assistant."""
|
||||
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||
client: MockHAClientWebSocket = await proxy_ws_client(
|
||||
install_addon_ssh.supervisor_token
|
||||
)
|
||||
|
||||
await client.send_bytes(b"hello world")
|
||||
proxied_msg = await ha_ws_server.incoming.get()
|
||||
assert proxied_msg.type == WSMsgType.BINARY
|
||||
assert proxied_msg.data == b"hello world"
|
||||
|
||||
await ha_ws_server.respond_bytes(b"world received")
|
||||
assert await client.receive_bytes() == b"world received"
|
||||
|
||||
assert await client.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
|
||||
async def test_proxy_invalid_auth(
|
||||
api_client: TestClient, install_addon_example: Addon, auth_token: str
|
||||
):
|
||||
"""Test invalid access token or addon with no access."""
|
||||
install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||
websocket = await api_client.ws_connect("/core/websocket")
|
||||
auth_resp = await websocket.receive_json()
|
||||
assert auth_resp["type"] == "auth_required"
|
||||
await websocket.send_json({"type": "auth", "access_token": auth_token})
|
||||
|
||||
auth_not_ok = await websocket.receive_json()
|
||||
assert auth_not_ok["type"] == "auth_invalid"
|
||||
assert auth_not_ok["message"] == "Invalid access"
|
1
tests/fixtures/addons/local/ssh/config.yaml
vendored
1
tests/fixtures/addons/local/ssh/config.yaml
vendored
@ -17,6 +17,7 @@ panel_icon: "mdi:console"
|
||||
panel_title: Terminal
|
||||
hassio_api: true
|
||||
hassio_role: manager
|
||||
homeassistant_api: true
|
||||
audio: true
|
||||
uart: true
|
||||
ports:
|
||||
|
Loading…
x
Reference in New Issue
Block a user