mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Allow passing binary to the WS connection (#89882)
* Allow passing binary to the WS connection * Expand test coverage * Test non-existing handler * Allow signaling end of stream using empty payloads * Store handlers in a list * Handle binary handlers raising exceptions
This commit is contained in:
parent
19d56a7102
commit
0ca6723378
@ -25,6 +25,9 @@ current_connection = ContextVar["ActiveConnection | None"](
|
||||
"current_connection", default=None
|
||||
)
|
||||
|
||||
MessageHandler = Callable[[HomeAssistant, "ActiveConnection", dict[str, Any]], None]
|
||||
BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None]
|
||||
|
||||
|
||||
class ActiveConnection:
|
||||
"""Handle an active websocket client connection."""
|
||||
@ -46,7 +49,10 @@ class ActiveConnection:
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
self.supported_features: dict[str, float] = {}
|
||||
self.handlers = self.hass.data[const.DOMAIN]
|
||||
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
|
||||
const.DOMAIN
|
||||
]
|
||||
self.binary_handlers: list[BinaryHandler | None] = []
|
||||
current_connection.set(self)
|
||||
|
||||
def get_description(self, request: web.Request | None) -> str:
|
||||
@ -60,6 +66,38 @@ class ActiveConnection:
|
||||
"""Return a context."""
|
||||
return Context(user_id=self.user.id)
|
||||
|
||||
@callback
|
||||
def async_register_binary_handler(
|
||||
self, handler: BinaryHandler
|
||||
) -> tuple[int, Callable[[], None]]:
|
||||
"""Register a temporary binary handler for this connection.
|
||||
|
||||
Returns a binary handler_id (1 byte) and a callback to unregister the handler.
|
||||
"""
|
||||
if len(self.binary_handlers) < 255:
|
||||
index = len(self.binary_handlers)
|
||||
self.binary_handlers.append(None)
|
||||
else:
|
||||
# Once the list is full, we search for a None entry to reuse.
|
||||
index = None
|
||||
for idx, existing in enumerate(self.binary_handlers):
|
||||
if existing is None:
|
||||
index = idx
|
||||
break
|
||||
|
||||
if index is None:
|
||||
raise RuntimeError("Too many binary handlers registered")
|
||||
|
||||
self.binary_handlers[index] = handler
|
||||
|
||||
@callback
|
||||
def unsub() -> None:
|
||||
"""Unregister the handler."""
|
||||
assert index is not None
|
||||
self.binary_handlers[index] = None
|
||||
|
||||
return index + 1, unsub
|
||||
|
||||
@callback
|
||||
def send_result(self, msg_id: int, result: Any | None = None) -> None:
|
||||
"""Send a result message."""
|
||||
@ -75,6 +113,26 @@ class ActiveConnection:
|
||||
"""Send a error message."""
|
||||
self.send_message(messages.error_message(msg_id, code, message))
|
||||
|
||||
@callback
|
||||
def async_handle_binary(self, handler_id: int, payload: bytes) -> None:
|
||||
"""Handle a single incoming binary message."""
|
||||
index = handler_id - 1
|
||||
if (
|
||||
index < 0
|
||||
or index >= len(self.binary_handlers)
|
||||
or (handler := self.binary_handlers[index]) is None
|
||||
):
|
||||
self.logger.error(
|
||||
"Received binary message for non-existing handler %s", handler_id
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
handler(self.hass, self, payload)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception("Error handling binary message")
|
||||
self.binary_handlers[index] = None
|
||||
|
||||
@callback
|
||||
def async_handle(self, msg: dict[str, Any]) -> None:
|
||||
"""Handle a single incoming message."""
|
||||
|
@ -312,6 +312,15 @@ class WebSocketHandler:
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
break
|
||||
|
||||
if msg.type == WSMsgType.BINARY:
|
||||
if len(msg.data) < 1:
|
||||
disconnect_warn = "Received invalid binary message."
|
||||
break
|
||||
handler = msg.data[0]
|
||||
payload = msg.data[1:]
|
||||
connection.async_handle_binary(handler, payload)
|
||||
continue
|
||||
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
disconnect_warn = "Received non-Text message."
|
||||
break
|
||||
|
@ -101,3 +101,27 @@ async def test_exception_handling(
|
||||
assert send_messages[0]["error"]["code"] == code
|
||||
assert send_messages[0]["error"]["message"] == err
|
||||
assert log in caplog.text
|
||||
|
||||
|
||||
async def test_binary_handler_registration() -> None:
|
||||
"""Test binary handler registration."""
|
||||
connection = websocket_api.ActiveConnection(
|
||||
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock()
|
||||
)
|
||||
|
||||
# One filler to align indexes with prefix numbers
|
||||
unsubs = [None]
|
||||
fake_handler = object()
|
||||
for i in range(255):
|
||||
prefix, unsub = connection.async_register_binary_handler(fake_handler)
|
||||
assert prefix == i + 1
|
||||
unsubs.append(unsub)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
connection.async_register_binary_handler(None)
|
||||
|
||||
unsubs[15]()
|
||||
|
||||
# Verify we reuse an unsubscribed prefix
|
||||
prefix, unsub = connection.async_register_binary_handler(None)
|
||||
assert prefix == 15
|
||||
|
@ -1,13 +1,20 @@
|
||||
"""Test Websocket API http module."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import ServerDisconnectedError, WSMsgType, web
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.websocket_api import const, http
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.components.websocket_api import (
|
||||
async_register_command,
|
||||
const,
|
||||
http,
|
||||
websocket_command,
|
||||
)
|
||||
from homeassistant.components.websocket_api.connection import ActiveConnection
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
@ -155,3 +162,77 @@ async def test_prepare_fail(
|
||||
await hass_ws_client(hass)
|
||||
|
||||
assert "Timeout preparing request" in caplog.text
|
||||
|
||||
|
||||
async def test_binary_message(
|
||||
hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test binary messages."""
|
||||
binary_payloads = {
|
||||
104: ([], asyncio.Future()),
|
||||
105: ([], asyncio.Future()),
|
||||
}
|
||||
|
||||
# Register a handler
|
||||
@callback
|
||||
@websocket_command(
|
||||
{
|
||||
"type": "get_binary_message_handler",
|
||||
}
|
||||
)
|
||||
def get_binary_message_handler(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
):
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def binary_message_handler(
|
||||
hass: HomeAssistant, connection: ActiveConnection, payload: bytes
|
||||
):
|
||||
nonlocal unsub
|
||||
if msg["id"] == 103:
|
||||
raise ValueError("Boom")
|
||||
|
||||
if payload:
|
||||
binary_payloads[msg["id"]][0].append(payload)
|
||||
else:
|
||||
binary_payloads[msg["id"]][1].set_result(
|
||||
b"".join(binary_payloads[msg["id"]][0])
|
||||
)
|
||||
unsub()
|
||||
|
||||
prefix, unsub = connection.async_register_binary_handler(binary_message_handler)
|
||||
|
||||
connection.send_result(msg["id"], {"prefix": prefix})
|
||||
|
||||
async_register_command(hass, get_binary_message_handler)
|
||||
|
||||
# Register multiple binary handlers
|
||||
for i in range(101, 106):
|
||||
await websocket_client.send_json(
|
||||
{"id": i, "type": "get_binary_message_handler"}
|
||||
)
|
||||
result = await websocket_client.receive_json()
|
||||
assert result["id"] == i
|
||||
assert result["type"] == const.TYPE_RESULT
|
||||
assert result["success"]
|
||||
assert result["result"]["prefix"] == i - 100
|
||||
|
||||
# Send message to binary
|
||||
await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0")
|
||||
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
|
||||
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
|
||||
await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10")
|
||||
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4")
|
||||
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"")
|
||||
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5")
|
||||
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2")
|
||||
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"")
|
||||
|
||||
# Verify received
|
||||
assert await binary_payloads[104][1] == b"test4"
|
||||
assert await binary_payloads[105][1] == b"test5test5-2"
|
||||
assert "Error handling binary message" in caplog.text
|
||||
assert "Received binary message for non-existing handler 0" in caplog.text
|
||||
assert "Received binary message for non-existing handler 3" in caplog.text
|
||||
assert "Received binary message for non-existing handler 10" in caplog.text
|
||||
|
Loading…
x
Reference in New Issue
Block a user