mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 15:17:35 +00:00
Allow passing binary to the WS connection (#89882)
* Allow passing binary to the WS connection * Expand test coverage * Test non-existing handler * Allow signaling end of stream using empty payloads * Store handlers in a list * Handle binary handlers raising exceptions
This commit is contained in:
parent
19d56a7102
commit
0ca6723378
@ -25,6 +25,9 @@ current_connection = ContextVar["ActiveConnection | None"](
|
|||||||
"current_connection", default=None
|
"current_connection", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MessageHandler = Callable[[HomeAssistant, "ActiveConnection", dict[str, Any]], None]
|
||||||
|
BinaryHandler = Callable[[HomeAssistant, "ActiveConnection", bytes], None]
|
||||||
|
|
||||||
|
|
||||||
class ActiveConnection:
|
class ActiveConnection:
|
||||||
"""Handle an active websocket client connection."""
|
"""Handle an active websocket client connection."""
|
||||||
@ -46,7 +49,10 @@ class ActiveConnection:
|
|||||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||||
self.last_id = 0
|
self.last_id = 0
|
||||||
self.supported_features: dict[str, float] = {}
|
self.supported_features: dict[str, float] = {}
|
||||||
self.handlers = self.hass.data[const.DOMAIN]
|
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
|
||||||
|
const.DOMAIN
|
||||||
|
]
|
||||||
|
self.binary_handlers: list[BinaryHandler | None] = []
|
||||||
current_connection.set(self)
|
current_connection.set(self)
|
||||||
|
|
||||||
def get_description(self, request: web.Request | None) -> str:
|
def get_description(self, request: web.Request | None) -> str:
|
||||||
@ -60,6 +66,38 @@ class ActiveConnection:
|
|||||||
"""Return a context."""
|
"""Return a context."""
|
||||||
return Context(user_id=self.user.id)
|
return Context(user_id=self.user.id)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_register_binary_handler(
|
||||||
|
self, handler: BinaryHandler
|
||||||
|
) -> tuple[int, Callable[[], None]]:
|
||||||
|
"""Register a temporary binary handler for this connection.
|
||||||
|
|
||||||
|
Returns a binary handler_id (1 byte) and a callback to unregister the handler.
|
||||||
|
"""
|
||||||
|
if len(self.binary_handlers) < 255:
|
||||||
|
index = len(self.binary_handlers)
|
||||||
|
self.binary_handlers.append(None)
|
||||||
|
else:
|
||||||
|
# Once the list is full, we search for a None entry to reuse.
|
||||||
|
index = None
|
||||||
|
for idx, existing in enumerate(self.binary_handlers):
|
||||||
|
if existing is None:
|
||||||
|
index = idx
|
||||||
|
break
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
raise RuntimeError("Too many binary handlers registered")
|
||||||
|
|
||||||
|
self.binary_handlers[index] = handler
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def unsub() -> None:
|
||||||
|
"""Unregister the handler."""
|
||||||
|
assert index is not None
|
||||||
|
self.binary_handlers[index] = None
|
||||||
|
|
||||||
|
return index + 1, unsub
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def send_result(self, msg_id: int, result: Any | None = None) -> None:
|
def send_result(self, msg_id: int, result: Any | None = None) -> None:
|
||||||
"""Send a result message."""
|
"""Send a result message."""
|
||||||
@ -75,6 +113,26 @@ class ActiveConnection:
|
|||||||
"""Send a error message."""
|
"""Send a error message."""
|
||||||
self.send_message(messages.error_message(msg_id, code, message))
|
self.send_message(messages.error_message(msg_id, code, message))
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_handle_binary(self, handler_id: int, payload: bytes) -> None:
|
||||||
|
"""Handle a single incoming binary message."""
|
||||||
|
index = handler_id - 1
|
||||||
|
if (
|
||||||
|
index < 0
|
||||||
|
or index >= len(self.binary_handlers)
|
||||||
|
or (handler := self.binary_handlers[index]) is None
|
||||||
|
):
|
||||||
|
self.logger.error(
|
||||||
|
"Received binary message for non-existing handler %s", handler_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
handler(self.hass, self, payload)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
self.logger.exception("Error handling binary message")
|
||||||
|
self.binary_handlers[index] = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_handle(self, msg: dict[str, Any]) -> None:
|
def async_handle(self, msg: dict[str, Any]) -> None:
|
||||||
"""Handle a single incoming message."""
|
"""Handle a single incoming message."""
|
||||||
|
@ -312,6 +312,15 @@ class WebSocketHandler:
|
|||||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if msg.type == WSMsgType.BINARY:
|
||||||
|
if len(msg.data) < 1:
|
||||||
|
disconnect_warn = "Received invalid binary message."
|
||||||
|
break
|
||||||
|
handler = msg.data[0]
|
||||||
|
payload = msg.data[1:]
|
||||||
|
connection.async_handle_binary(handler, payload)
|
||||||
|
continue
|
||||||
|
|
||||||
if msg.type != WSMsgType.TEXT:
|
if msg.type != WSMsgType.TEXT:
|
||||||
disconnect_warn = "Received non-Text message."
|
disconnect_warn = "Received non-Text message."
|
||||||
break
|
break
|
||||||
|
@ -101,3 +101,27 @@ async def test_exception_handling(
|
|||||||
assert send_messages[0]["error"]["code"] == code
|
assert send_messages[0]["error"]["code"] == code
|
||||||
assert send_messages[0]["error"]["message"] == err
|
assert send_messages[0]["error"]["message"] == err
|
||||||
assert log in caplog.text
|
assert log in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_binary_handler_registration() -> None:
|
||||||
|
"""Test binary handler registration."""
|
||||||
|
connection = websocket_api.ActiveConnection(
|
||||||
|
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock()
|
||||||
|
)
|
||||||
|
|
||||||
|
# One filler to align indexes with prefix numbers
|
||||||
|
unsubs = [None]
|
||||||
|
fake_handler = object()
|
||||||
|
for i in range(255):
|
||||||
|
prefix, unsub = connection.async_register_binary_handler(fake_handler)
|
||||||
|
assert prefix == i + 1
|
||||||
|
unsubs.append(unsub)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
connection.async_register_binary_handler(None)
|
||||||
|
|
||||||
|
unsubs[15]()
|
||||||
|
|
||||||
|
# Verify we reuse an unsubscribed prefix
|
||||||
|
prefix, unsub = connection.async_register_binary_handler(None)
|
||||||
|
assert prefix == 15
|
||||||
|
@ -1,13 +1,20 @@
|
|||||||
"""Test Websocket API http module."""
|
"""Test Websocket API http module."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from aiohttp import ServerDisconnectedError, WSMsgType, web
|
from aiohttp import ServerDisconnectedError, WSMsgType, web
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.websocket_api import const, http
|
from homeassistant.components.websocket_api import (
|
||||||
from homeassistant.core import HomeAssistant
|
async_register_command,
|
||||||
|
const,
|
||||||
|
http,
|
||||||
|
websocket_command,
|
||||||
|
)
|
||||||
|
from homeassistant.components.websocket_api.connection import ActiveConnection
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed
|
from tests.common import async_fire_time_changed
|
||||||
@ -155,3 +162,77 @@ async def test_prepare_fail(
|
|||||||
await hass_ws_client(hass)
|
await hass_ws_client(hass)
|
||||||
|
|
||||||
assert "Timeout preparing request" in caplog.text
|
assert "Timeout preparing request" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_binary_message(
|
||||||
|
hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test binary messages."""
|
||||||
|
binary_payloads = {
|
||||||
|
104: ([], asyncio.Future()),
|
||||||
|
105: ([], asyncio.Future()),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Register a handler
|
||||||
|
@callback
|
||||||
|
@websocket_command(
|
||||||
|
{
|
||||||
|
"type": "get_binary_message_handler",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def get_binary_message_handler(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
):
|
||||||
|
unsub = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def binary_message_handler(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, payload: bytes
|
||||||
|
):
|
||||||
|
nonlocal unsub
|
||||||
|
if msg["id"] == 103:
|
||||||
|
raise ValueError("Boom")
|
||||||
|
|
||||||
|
if payload:
|
||||||
|
binary_payloads[msg["id"]][0].append(payload)
|
||||||
|
else:
|
||||||
|
binary_payloads[msg["id"]][1].set_result(
|
||||||
|
b"".join(binary_payloads[msg["id"]][0])
|
||||||
|
)
|
||||||
|
unsub()
|
||||||
|
|
||||||
|
prefix, unsub = connection.async_register_binary_handler(binary_message_handler)
|
||||||
|
|
||||||
|
connection.send_result(msg["id"], {"prefix": prefix})
|
||||||
|
|
||||||
|
async_register_command(hass, get_binary_message_handler)
|
||||||
|
|
||||||
|
# Register multiple binary handlers
|
||||||
|
for i in range(101, 106):
|
||||||
|
await websocket_client.send_json(
|
||||||
|
{"id": i, "type": "get_binary_message_handler"}
|
||||||
|
)
|
||||||
|
result = await websocket_client.receive_json()
|
||||||
|
assert result["id"] == i
|
||||||
|
assert result["type"] == const.TYPE_RESULT
|
||||||
|
assert result["success"]
|
||||||
|
assert result["result"]["prefix"] == i - 100
|
||||||
|
|
||||||
|
# Send message to binary
|
||||||
|
await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0")
|
||||||
|
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
|
||||||
|
await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
|
||||||
|
await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10")
|
||||||
|
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4")
|
||||||
|
await websocket_client.send_bytes((4).to_bytes(1, "big") + b"")
|
||||||
|
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5")
|
||||||
|
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2")
|
||||||
|
await websocket_client.send_bytes((5).to_bytes(1, "big") + b"")
|
||||||
|
|
||||||
|
# Verify received
|
||||||
|
assert await binary_payloads[104][1] == b"test4"
|
||||||
|
assert await binary_payloads[105][1] == b"test5test5-2"
|
||||||
|
assert "Error handling binary message" in caplog.text
|
||||||
|
assert "Received binary message for non-existing handler 0" in caplog.text
|
||||||
|
assert "Received binary message for non-existing handler 3" in caplog.text
|
||||||
|
assert "Received binary message for non-existing handler 10" in caplog.text
|
||||||
|
Loading…
x
Reference in New Issue
Block a user