Allow passing binary to the WS connection (#89882)

* Allow passing binary to the WS connection

* Expand test coverage

* Test non-existing handler

* Allow signaling end of stream using empty payloads

* Store handlers in a list

* Handle binary handlers raising exceptions
This commit is contained in:
Paulus Schoutsen 2023-03-22 08:36:36 -04:00 committed by GitHub
parent 19d56a7102
commit 0ca6723378
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 175 additions and 3 deletions

View File

@ -25,6 +25,9 @@ current_connection = ContextVar["ActiveConnection | None"](
"current_connection", default=None "current_connection", default=None
) )
MessageHandler = Callable[[HomeAssistant, "ActiveConnection", dict[str, Any]], None]
BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None]
class ActiveConnection: class ActiveConnection:
"""Handle an active websocket client connection.""" """Handle an active websocket client connection."""
@ -46,7 +49,10 @@ class ActiveConnection:
self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0 self.last_id = 0
self.supported_features: dict[str, float] = {} self.supported_features: dict[str, float] = {}
self.handlers = self.hass.data[const.DOMAIN] self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
const.DOMAIN
]
self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self) current_connection.set(self)
def get_description(self, request: web.Request | None) -> str: def get_description(self, request: web.Request | None) -> str:
@ -60,6 +66,38 @@ class ActiveConnection:
"""Return a context.""" """Return a context."""
return Context(user_id=self.user.id) return Context(user_id=self.user.id)
@callback
def async_register_binary_handler(
self, handler: BinaryHandler
) -> tuple[int, Callable[[], None]]:
"""Register a temporary binary handler for this connection.
Returns a binary handler_id (1 byte) and a callback to unregister the handler.
"""
if len(self.binary_handlers) < 255:
index = len(self.binary_handlers)
self.binary_handlers.append(None)
else:
# Once the list is full, we search for a None entry to reuse.
index = None
for idx, existing in enumerate(self.binary_handlers):
if existing is None:
index = idx
break
if index is None:
raise RuntimeError("Too many binary handlers registered")
self.binary_handlers[index] = handler
@callback
def unsub() -> None:
"""Unregister the handler."""
assert index is not None
self.binary_handlers[index] = None
return index + 1, unsub
@callback @callback
def send_result(self, msg_id: int, result: Any | None = None) -> None: def send_result(self, msg_id: int, result: Any | None = None) -> None:
"""Send a result message.""" """Send a result message."""
@ -75,6 +113,26 @@ class ActiveConnection:
"""Send a error message.""" """Send a error message."""
self.send_message(messages.error_message(msg_id, code, message)) self.send_message(messages.error_message(msg_id, code, message))
@callback
def async_handle_binary(self, handler_id: int, payload: bytes) -> None:
"""Handle a single incoming binary message."""
index = handler_id - 1
if (
index < 0
or index >= len(self.binary_handlers)
or (handler := self.binary_handlers[index]) is None
):
self.logger.error(
"Received binary message for non-existing handler %s", handler_id
)
return
try:
handler(self.hass, self, payload)
except Exception: # pylint: disable=broad-except
self.logger.exception("Error handling binary message")
self.binary_handlers[index] = None
@callback @callback
def async_handle(self, msg: dict[str, Any]) -> None: def async_handle(self, msg: dict[str, Any]) -> None:
"""Handle a single incoming message.""" """Handle a single incoming message."""

View File

@ -312,6 +312,15 @@ class WebSocketHandler:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
break break
if msg.type == WSMsgType.BINARY:
if len(msg.data) < 1:
disconnect_warn = "Received invalid binary message."
break
handler = msg.data[0]
payload = msg.data[1:]
connection.async_handle_binary(handler, payload)
continue
if msg.type != WSMsgType.TEXT: if msg.type != WSMsgType.TEXT:
disconnect_warn = "Received non-Text message." disconnect_warn = "Received non-Text message."
break break

View File

@ -101,3 +101,27 @@ async def test_exception_handling(
assert send_messages[0]["error"]["code"] == code assert send_messages[0]["error"]["code"] == code
assert send_messages[0]["error"]["message"] == err assert send_messages[0]["error"]["message"] == err
assert log in caplog.text assert log in caplog.text
async def test_binary_handler_registration() -> None:
"""Test binary handler registration."""
connection = websocket_api.ActiveConnection(
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock()
)
# One filler to align indexes with prefix numbers
unsubs = [None]
fake_handler = object()
for i in range(255):
prefix, unsub = connection.async_register_binary_handler(fake_handler)
assert prefix == i + 1
unsubs.append(unsub)
with pytest.raises(RuntimeError):
connection.async_register_binary_handler(None)
unsubs[15]()
# Verify we reuse an unsubscribed prefix
prefix, unsub = connection.async_register_binary_handler(None)
assert prefix == 15

View File

@ -1,13 +1,20 @@
"""Test Websocket API http module.""" """Test Websocket API http module."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from typing import Any
from unittest.mock import patch from unittest.mock import patch
from aiohttp import ServerDisconnectedError, WSMsgType, web from aiohttp import ServerDisconnectedError, WSMsgType, web
import pytest import pytest
from homeassistant.components.websocket_api import const, http from homeassistant.components.websocket_api import (
from homeassistant.core import HomeAssistant async_register_command,
const,
http,
websocket_command,
)
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
@ -155,3 +162,77 @@ async def test_prepare_fail(
await hass_ws_client(hass) await hass_ws_client(hass)
assert "Timeout preparing request" in caplog.text assert "Timeout preparing request" in caplog.text
async def test_binary_message(
hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
) -> None:
"""Test binary messages."""
binary_payloads = {
104: ([], asyncio.Future()),
105: ([], asyncio.Future()),
}
# Register a handler
@callback
@websocket_command(
{
"type": "get_binary_message_handler",
}
)
def get_binary_message_handler(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
):
unsub = None
@callback
def binary_message_handler(
hass: HomeAssistant, connection: ActiveConnection, payload: bytes
):
nonlocal unsub
if msg["id"] == 103:
raise ValueError("Boom")
if payload:
binary_payloads[msg["id"]][0].append(payload)
else:
binary_payloads[msg["id"]][1].set_result(
b"".join(binary_payloads[msg["id"]][0])
)
unsub()
prefix, unsub = connection.async_register_binary_handler(binary_message_handler)
connection.send_result(msg["id"], {"prefix": prefix})
async_register_command(hass, get_binary_message_handler)
# Register multiple binary handlers
for i in range(101, 106):
await websocket_client.send_json(
{"id": i, "type": "get_binary_message_handler"}
)
result = await websocket_client.receive_json()
assert result["id"] == i
assert result["type"] == const.TYPE_RESULT
assert result["success"]
assert result["result"]["prefix"] == i - 100
# Send message to binary
await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0")
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10")
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4")
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"")
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5")
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2")
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"")
# Verify received
assert await binary_payloads[104][1] == b"test4"
assert await binary_payloads[105][1] == b"test5test5-2"
assert "Error handling binary message" in caplog.text
assert "Received binary message for non-existing handler 0" in caplog.text
assert "Received binary message for non-existing handler 3" in caplog.text
assert "Received binary message for non-existing handler 10" in caplog.text