mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Implement websocket message coalescing (#77238)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
2161b6f049
commit
f6a03625ba
@ -74,6 +74,7 @@ def async_register_commands(
|
||||
async_reg(hass, handle_validate_config)
|
||||
async_reg(hass, handle_subscribe_entities)
|
||||
async_reg(hass, handle_supported_brands)
|
||||
async_reg(hass, handle_supported_features)
|
||||
|
||||
|
||||
def pong_message(iden: int) -> dict[str, Any]:
|
||||
@ -723,3 +724,18 @@ async def handle_supported_brands(
|
||||
raise int_or_exc
|
||||
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
|
||||
connection.send_result(msg["id"], data)
|
||||
|
||||
|
||||
@callback
|
||||
@decorators.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "supported_features",
|
||||
vol.Required("features"): {str: int},
|
||||
}
|
||||
)
|
||||
def handle_supported_features(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle setting supported features."""
|
||||
connection.supported_features = msg["features"]
|
||||
connection.send_result(msg["id"])
|
||||
|
@ -42,6 +42,7 @@ class ActiveConnection:
|
||||
self.refresh_token_id = refresh_token.id
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
self.supported_features: dict[str, float] = {}
|
||||
current_connection.set(self)
|
||||
|
||||
def context(self, msg: dict[str, Any]) -> Context:
|
||||
|
@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a"
|
||||
COMPRESSED_STATE_CONTEXT = "c"
|
||||
COMPRESSED_STATE_LAST_CHANGED = "lc"
|
||||
COMPRESSED_STATE_LAST_UPDATED = "lu"
|
||||
|
||||
FEATURE_COALESCE_MESSAGES = "coalesce_messages"
|
||||
|
@ -6,7 +6,7 @@ from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
from aiohttp import WSMsgType, web
|
||||
import async_timeout
|
||||
@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.json import json_loads
|
||||
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
from .const import (
|
||||
CANCELLATION_ERRORS,
|
||||
DATA_CONNECTIONS,
|
||||
FEATURE_COALESCE_MESSAGES,
|
||||
MAX_PENDING_MSG,
|
||||
PENDING_MSG_PEAK,
|
||||
PENDING_MSG_PEAK_TIME,
|
||||
@ -31,6 +33,10 @@ from .const import (
|
||||
from .error import Disconnect
|
||||
from .messages import message_to_json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import ActiveConnection
|
||||
|
||||
|
||||
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
||||
|
||||
|
||||
@ -67,26 +73,47 @@ class WebSocketHandler:
|
||||
self._writer_task: asyncio.Task | None = None
|
||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||
self.connection: ActiveConnection | None = None
|
||||
|
||||
async def _writer(self) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
if (process := await self._to_write.get()) is None:
|
||||
break
|
||||
to_write = self._to_write
|
||||
logger = self._logger
|
||||
wsock = self.wsock
|
||||
try:
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
if (process := await to_write.get()) is None:
|
||||
return
|
||||
message = process if isinstance(process, str) else process()
|
||||
|
||||
if not isinstance(process, str):
|
||||
message: str = process()
|
||||
else:
|
||||
message = process
|
||||
self._logger.debug("Sending %s", message)
|
||||
await self.wsock.send_str(message)
|
||||
if (
|
||||
to_write.empty()
|
||||
or not self.connection
|
||||
or FEATURE_COALESCE_MESSAGES
|
||||
not in self.connection.supported_features
|
||||
):
|
||||
logger.debug("Sending %s", message)
|
||||
await wsock.send_str(message)
|
||||
continue
|
||||
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
if self._peak_checker_unsub is not None:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
messages: list[str] = [message]
|
||||
while not to_write.empty():
|
||||
if (process := to_write.get_nowait()) is None:
|
||||
return
|
||||
messages.append(
|
||||
process if isinstance(process, str) else process()
|
||||
)
|
||||
|
||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||
self._logger.debug("Sending %s", coalesced_messages)
|
||||
await self.wsock.send_str(coalesced_messages)
|
||||
finally:
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
if self._peak_checker_unsub is not None:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
@callback
|
||||
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
|
||||
@ -194,13 +221,13 @@ class WebSocketHandler:
|
||||
raise Disconnect
|
||||
|
||||
try:
|
||||
msg_data = msg.json()
|
||||
msg_data = msg.json(loads=json_loads)
|
||||
except ValueError as err:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
raise Disconnect from err
|
||||
|
||||
self._logger.debug("Received %s", msg_data)
|
||||
connection = await auth.async_handle(msg_data)
|
||||
self.connection = connection = await auth.async_handle(msg_data)
|
||||
self.hass.data[DATA_CONNECTIONS] = (
|
||||
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
)
|
||||
@ -218,13 +245,18 @@ class WebSocketHandler:
|
||||
break
|
||||
|
||||
try:
|
||||
msg_data = msg.json()
|
||||
msg_data = msg.json(loads=json_loads)
|
||||
except ValueError:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
break
|
||||
|
||||
self._logger.debug("Received %s", msg_data)
|
||||
connection.async_handle(msg_data)
|
||||
if not isinstance(msg_data, list):
|
||||
connection.async_handle(msg_data)
|
||||
continue
|
||||
|
||||
for split_msg in msg_data:
|
||||
connection.async_handle(split_msg)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._logger.info("Connection closed by client")
|
||||
@ -257,6 +289,8 @@ class WebSocketHandler:
|
||||
|
||||
if connection is not None:
|
||||
self.hass.data[DATA_CONNECTIONS] -= 1
|
||||
self.connection = None
|
||||
|
||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||
|
||||
return wsock
|
||||
|
@ -13,12 +13,13 @@ from homeassistant.components.websocket_api.auth import (
|
||||
TYPE_AUTH_OK,
|
||||
TYPE_AUTH_REQUIRED,
|
||||
)
|
||||
from homeassistant.components.websocket_api.const import URL
|
||||
from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL
|
||||
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS
|
||||
from homeassistant.core import Context, HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.json import json_loads
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
|
||||
|
||||
@ -1788,3 +1789,186 @@ async def test_supported_brands(hass, websocket_client):
|
||||
"hello": "World",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def test_message_coalescing(hass, websocket_client, hass_admin_user):
|
||||
"""Test enabling message coalescing."""
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 1,
|
||||
"type": "supported_features",
|
||||
"features": {FEATURE_COALESCE_MESSAGES: 1},
|
||||
}
|
||||
)
|
||||
hass.states.async_set("light.permitted", "on", {"color": "red"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msg = json_loads(data)
|
||||
assert msg["id"] == 1
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msgs = json_loads(data)
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"a": {
|
||||
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
|
||||
}
|
||||
}
|
||||
|
||||
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msgs = json_loads(data)
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
|
||||
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||
await websocket_client.close()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_message_coalescing_not_supported_by_websocket_client(
|
||||
hass, websocket_client, hass_admin_user
|
||||
):
|
||||
"""Test enabling message coalescing not supported by websocket client."""
|
||||
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msg = json_loads(data)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
hass.states.async_set("light.permitted", "on", {"color": "red"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msg = json_loads(data)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {"a": {}}
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msg = json_loads(data)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"a": {
|
||||
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
|
||||
}
|
||||
}
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msg = json_loads(data)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
await websocket_client.close()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_client_message_coalescing(hass, websocket_client, hass_admin_user):
|
||||
"""Test client message coalescing."""
|
||||
await websocket_client.send_json(
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"type": "supported_features",
|
||||
"features": {FEATURE_COALESCE_MESSAGES: 1},
|
||||
},
|
||||
{"id": 7, "type": "subscribe_entities"},
|
||||
]
|
||||
)
|
||||
hass.states.async_set("light.permitted", "on", {"color": "red"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msgs = json_loads(data)
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 1
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"a": {
|
||||
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
|
||||
}
|
||||
}
|
||||
|
||||
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||
|
||||
data = await websocket_client.receive_str()
|
||||
msgs = json_loads(data)
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
|
||||
msg = msgs.pop(0)
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {
|
||||
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
|
||||
}
|
||||
|
||||
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||
await websocket_client.close()
|
||||
await hass.async_block_till_done()
|
||||
|
@ -2,15 +2,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from contextlib import asynccontextmanager
|
||||
import functools
|
||||
from json import JSONDecoder, loads
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
from aiohttp import client
|
||||
from aiohttp.pytest_plugin import AiohttpClient
|
||||
from aiohttp.test_utils import (
|
||||
BaseTestServer,
|
||||
TestClient,
|
||||
TestServer,
|
||||
make_mocked_request,
|
||||
)
|
||||
from aiohttp.web import Application
|
||||
import freezegun
|
||||
import multidict
|
||||
import pytest
|
||||
@ -57,6 +67,7 @@ from tests.components.recorder.common import ( # noqa: E402, isort:skip
|
||||
async_recorder_block_till_done,
|
||||
)
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@ -203,6 +214,97 @@ def load_registries():
|
||||
return True
|
||||
|
||||
|
||||
class CoalescingResponse(client.ClientWebSocketResponse):
|
||||
"""ClientWebSocketResponse client that mimics the websocket js code."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Init the ClientWebSocketResponse."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._recv_buffer: list[Any] = []
|
||||
|
||||
async def receive_json(
|
||||
self,
|
||||
*,
|
||||
loads: JSONDecoder = loads,
|
||||
timeout: float | None = None,
|
||||
) -> Any:
|
||||
"""receive_json or from buffer."""
|
||||
if self._recv_buffer:
|
||||
return self._recv_buffer.pop(0)
|
||||
data = await self.receive_str(timeout=timeout)
|
||||
decoded = loads(data)
|
||||
if isinstance(decoded, list):
|
||||
self._recv_buffer = decoded
|
||||
return self._recv_buffer.pop(0)
|
||||
return decoded
|
||||
|
||||
|
||||
class CoalescingClient(TestClient):
|
||||
"""Client that mimics the websocket js code."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Init TestClient."""
|
||||
super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client_cls():
|
||||
"""Override the test class for aiohttp."""
|
||||
return CoalescingClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> Generator[AiohttpClient, None, None]:
|
||||
"""Override the default aiohttp_client since 3.x does not support aiohttp_client_cls.
|
||||
|
||||
Remove this when upgrading to 4.x as aiohttp_client_cls
|
||||
will do the same thing
|
||||
|
||||
aiohttp_client(app, **kwargs)
|
||||
aiohttp_client(server, **kwargs)
|
||||
aiohttp_client(raw_server, **kwargs)
|
||||
"""
|
||||
clients = []
|
||||
|
||||
async def go(
|
||||
__param: Application | BaseTestServer,
|
||||
*args: Any,
|
||||
server_kwargs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> TestClient:
|
||||
|
||||
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
|
||||
__param, (Application, BaseTestServer)
|
||||
):
|
||||
__param = __param(loop, *args, **kwargs)
|
||||
kwargs = {}
|
||||
else:
|
||||
assert not args, "args should be empty"
|
||||
|
||||
if isinstance(__param, Application):
|
||||
server_kwargs = server_kwargs or {}
|
||||
server = TestServer(__param, loop=loop, **server_kwargs)
|
||||
client = CoalescingClient(server, loop=loop, **kwargs)
|
||||
elif isinstance(__param, BaseTestServer):
|
||||
client = TestClient(__param, loop=loop, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unknown argument type: %r" % type(__param))
|
||||
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
yield go
|
||||
|
||||
async def finalize() -> None:
|
||||
while clients:
|
||||
await clients.pop().close()
|
||||
|
||||
loop.run_until_complete(finalize())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass(loop, load_registries, hass_storage, request):
|
||||
"""Fixture to provide a test instance of Home Assistant."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user