Allow passing binary to the WS connection (#89882)

* Allow passing binary to the WS connection

* Expand test coverage

* Test non-existing handler

* Allow signaling end of stream using empty payloads

* Store handlers in a list

* Handle binary handlers raising exceptions
This commit is contained in:
Paulus Schoutsen 2023-03-22 08:36:36 -04:00 committed by GitHub
parent 19d56a7102
commit 0ca6723378
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 175 additions and 3 deletions

View File

@ -25,6 +25,9 @@ current_connection = ContextVar["ActiveConnection | None"](
"current_connection", default=None
)
MessageHandler = Callable[[HomeAssistant, "ActiveConnection", dict[str, Any]], None]
BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None]
class ActiveConnection:
"""Handle an active websocket client connection."""
@ -46,7 +49,10 @@ class ActiveConnection:
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0
self.supported_features: dict[str, float] = {}
self.handlers = self.hass.data[const.DOMAIN]
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
const.DOMAIN
]
self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self)
def get_description(self, request: web.Request | None) -> str:
@ -60,6 +66,38 @@ class ActiveConnection:
"""Return a context."""
return Context(user_id=self.user.id)
@callback
def async_register_binary_handler(
self, handler: BinaryHandler
) -> tuple[int, Callable[[], None]]:
"""Register a temporary binary handler for this connection.
Returns a binary handler_id (1 byte) and a callback to unregister the handler.
"""
if len(self.binary_handlers) < 255:
index = len(self.binary_handlers)
self.binary_handlers.append(None)
else:
# Once the list is full, we search for a None entry to reuse.
index = None
for idx, existing in enumerate(self.binary_handlers):
if existing is None:
index = idx
break
if index is None:
raise RuntimeError("Too many binary handlers registered")
self.binary_handlers[index] = handler
@callback
def unsub() -> None:
"""Unregister the handler."""
assert index is not None
self.binary_handlers[index] = None
return index + 1, unsub
@callback
def send_result(self, msg_id: int, result: Any | None = None) -> None:
"""Send a result message."""
@ -75,6 +113,26 @@ class ActiveConnection:
"""Send a error message."""
self.send_message(messages.error_message(msg_id, code, message))
@callback
def async_handle_binary(self, handler_id: int, payload: bytes) -> None:
"""Handle a single incoming binary message."""
index = handler_id - 1
if (
index < 0
or index >= len(self.binary_handlers)
or (handler := self.binary_handlers[index]) is None
):
self.logger.error(
"Received binary message for non-existing handler %s", handler_id
)
return
try:
handler(self.hass, self, payload)
except Exception: # pylint: disable=broad-except
self.logger.exception("Error handling binary message")
self.binary_handlers[index] = None
@callback
def async_handle(self, msg: dict[str, Any]) -> None:
"""Handle a single incoming message."""

View File

@ -312,6 +312,15 @@ class WebSocketHandler:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
break
if msg.type == WSMsgType.BINARY:
if len(msg.data) < 1:
disconnect_warn = "Received invalid binary message."
break
handler = msg.data[0]
payload = msg.data[1:]
connection.async_handle_binary(handler, payload)
continue
if msg.type != WSMsgType.TEXT:
disconnect_warn = "Received non-Text message."
break

View File

@ -101,3 +101,27 @@ async def test_exception_handling(
assert send_messages[0]["error"]["code"] == code
assert send_messages[0]["error"]["message"] == err
assert log in caplog.text
async def test_binary_handler_registration() -> None:
"""Test binary handler registration."""
connection = websocket_api.ActiveConnection(
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock()
)
# One filler to align indexes with prefix numbers
unsubs = [None]
fake_handler = object()
for i in range(255):
prefix, unsub = connection.async_register_binary_handler(fake_handler)
assert prefix == i + 1
unsubs.append(unsub)
with pytest.raises(RuntimeError):
connection.async_register_binary_handler(None)
unsubs[15]()
# Verify we reuse an unsubscribed prefix
prefix, unsub = connection.async_register_binary_handler(None)
assert prefix == 15

View File

@ -1,13 +1,20 @@
"""Test Websocket API http module."""
import asyncio
from datetime import timedelta
from typing import Any
from unittest.mock import patch
from aiohttp import ServerDisconnectedError, WSMsgType, web
import pytest
from homeassistant.components.websocket_api import const, http
from homeassistant.core import HomeAssistant
from homeassistant.components.websocket_api import (
async_register_command,
const,
http,
websocket_command,
)
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.dt import utcnow
from tests.common import async_fire_time_changed
@ -155,3 +162,77 @@ async def test_prepare_fail(
await hass_ws_client(hass)
assert "Timeout preparing request" in caplog.text
async def test_binary_message(
hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
) -> None:
"""Test binary messages."""
binary_payloads = {
104: ([], asyncio.Future()),
105: ([], asyncio.Future()),
}
# Register a handler
@callback
@websocket_command(
{
"type": "get_binary_message_handler",
}
)
def get_binary_message_handler(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
):
unsub = None
@callback
def binary_message_handler(
hass: HomeAssistant, connection: ActiveConnection, payload: bytes
):
nonlocal unsub
if msg["id"] == 103:
raise ValueError("Boom")
if payload:
binary_payloads[msg["id"]][0].append(payload)
else:
binary_payloads[msg["id"]][1].set_result(
b"".join(binary_payloads[msg["id"]][0])
)
unsub()
prefix, unsub = connection.async_register_binary_handler(binary_message_handler)
connection.send_result(msg["id"], {"prefix": prefix})
async_register_command(hass, get_binary_message_handler)
# Register multiple binary handlers
for i in range(101, 106):
await websocket_client.send_json(
{"id": i, "type": "get_binary_message_handler"}
)
result = await websocket_client.receive_json()
assert result["id"] == i
assert result["type"] == const.TYPE_RESULT
assert result["success"]
assert result["result"]["prefix"] == i - 100
# Send message to binary
await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0")
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10")
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4")
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"")
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5")
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2")
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"")
# Verify received
assert await binary_payloads[104][1] == b"test4"
assert await binary_payloads[105][1] == b"test5test5-2"
assert "Error handling binary message" in caplog.text
assert "Received binary message for non-existing handler 0" in caplog.text
assert "Received binary message for non-existing handler 3" in caplog.text
assert "Received binary message for non-existing handler 10" in caplog.text