Implement websocket message coalescing (#77238)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2022-08-24 22:50:48 -05:00 committed by GitHub
parent 2161b6f049
commit f6a03625ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 361 additions and 22 deletions

View File

@ -74,6 +74,7 @@ def async_register_commands(
async_reg(hass, handle_validate_config) async_reg(hass, handle_validate_config)
async_reg(hass, handle_subscribe_entities) async_reg(hass, handle_subscribe_entities)
async_reg(hass, handle_supported_brands) async_reg(hass, handle_supported_brands)
async_reg(hass, handle_supported_features)
def pong_message(iden: int) -> dict[str, Any]: def pong_message(iden: int) -> dict[str, Any]:
@ -723,3 +724,18 @@ async def handle_supported_brands(
raise int_or_exc raise int_or_exc
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"] data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
connection.send_result(msg["id"], data) connection.send_result(msg["id"], data)
@callback
@decorators.websocket_command(
{
vol.Required("type"): "supported_features",
vol.Required("features"): {str: int},
}
)
def handle_supported_features(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle setting supported features."""
connection.supported_features = msg["features"]
connection.send_result(msg["id"])

View File

@ -42,6 +42,7 @@ class ActiveConnection:
self.refresh_token_id = refresh_token.id self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0 self.last_id = 0
self.supported_features: dict[str, float] = {}
current_connection.set(self) current_connection.set(self)
def context(self, msg: dict[str, Any]) -> Context: def context(self, msg: dict[str, Any]) -> Context:

View File

@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c" COMPRESSED_STATE_CONTEXT = "c"
COMPRESSED_STATE_LAST_CHANGED = "lc" COMPRESSED_STATE_LAST_CHANGED = "lc"
COMPRESSED_STATE_LAST_UPDATED = "lu" COMPRESSED_STATE_LAST_UPDATED = "lu"
FEATURE_COALESCE_MESSAGES = "coalesce_messages"

View File

@ -6,7 +6,7 @@ from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import datetime as dt import datetime as dt
import logging import logging
from typing import Any, Final from typing import TYPE_CHECKING, Any, Final
from aiohttp import WSMsgType, web from aiohttp import WSMsgType, web
import async_timeout import async_timeout
@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.json import json_loads
from .auth import AuthPhase, auth_required_message from .auth import AuthPhase, auth_required_message
from .const import ( from .const import (
CANCELLATION_ERRORS, CANCELLATION_ERRORS,
DATA_CONNECTIONS, DATA_CONNECTIONS,
FEATURE_COALESCE_MESSAGES,
MAX_PENDING_MSG, MAX_PENDING_MSG,
PENDING_MSG_PEAK, PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME, PENDING_MSG_PEAK_TIME,
@ -31,6 +33,10 @@ from .const import (
from .error import Disconnect from .error import Disconnect
from .messages import message_to_json from .messages import message_to_json
if TYPE_CHECKING:
from .connection import ActiveConnection
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection") _WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
@ -67,26 +73,47 @@ class WebSocketHandler:
self._writer_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None
async def _writer(self) -> None: async def _writer(self) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): to_write = self._to_write
while not self.wsock.closed: logger = self._logger
if (process := await self._to_write.get()) is None: wsock = self.wsock
break try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
if (process := await to_write.get()) is None:
return
message = process if isinstance(process, str) else process()
if not isinstance(process, str): if (
message: str = process() to_write.empty()
else: or not self.connection
message = process or FEATURE_COALESCE_MESSAGES
self._logger.debug("Sending %s", message) not in self.connection.supported_features
await self.wsock.send_str(message) ):
logger.debug("Sending %s", message)
await wsock.send_str(message)
continue
# Clean up the peaker checker when we shut down the writer messages: list[str] = [message]
if self._peak_checker_unsub is not None: while not to_write.empty():
self._peak_checker_unsub() if (process := to_write.get_nowait()) is None:
self._peak_checker_unsub = None return
messages.append(
process if isinstance(process, str) else process()
)
coalesced_messages = "[" + ",".join(messages) + "]"
self._logger.debug("Sending %s", coalesced_messages)
await self.wsock.send_str(coalesced_messages)
finally:
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
@callback @callback
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None: def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
@ -194,13 +221,13 @@ class WebSocketHandler:
raise Disconnect raise Disconnect
try: try:
msg_data = msg.json() msg_data = msg.json(loads=json_loads)
except ValueError as err: except ValueError as err:
disconnect_warn = "Received invalid JSON." disconnect_warn = "Received invalid JSON."
raise Disconnect from err raise Disconnect from err
self._logger.debug("Received %s", msg_data) self._logger.debug("Received %s", msg_data)
connection = await auth.async_handle(msg_data) self.connection = connection = await auth.async_handle(msg_data)
self.hass.data[DATA_CONNECTIONS] = ( self.hass.data[DATA_CONNECTIONS] = (
self.hass.data.get(DATA_CONNECTIONS, 0) + 1 self.hass.data.get(DATA_CONNECTIONS, 0) + 1
) )
@ -218,13 +245,18 @@ class WebSocketHandler:
break break
try: try:
msg_data = msg.json() msg_data = msg.json(loads=json_loads)
except ValueError: except ValueError:
disconnect_warn = "Received invalid JSON." disconnect_warn = "Received invalid JSON."
break break
self._logger.debug("Received %s", msg_data) self._logger.debug("Received %s", msg_data)
connection.async_handle(msg_data) if not isinstance(msg_data, list):
connection.async_handle(msg_data)
continue
for split_msg in msg_data:
connection.async_handle(split_msg)
except asyncio.CancelledError: except asyncio.CancelledError:
self._logger.info("Connection closed by client") self._logger.info("Connection closed by client")
@ -257,6 +289,8 @@ class WebSocketHandler:
if connection is not None: if connection is not None:
self.hass.data[DATA_CONNECTIONS] -= 1 self.hass.data[DATA_CONNECTIONS] -= 1
self.connection = None
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED) async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
return wsock return wsock

View File

@ -13,12 +13,13 @@ from homeassistant.components.websocket_api.auth import (
TYPE_AUTH_OK, TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED, TYPE_AUTH_REQUIRED,
) )
from homeassistant.components.websocket_api.const import URL from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS
from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity from homeassistant.helpers import entity
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.json import json_loads
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_setup_component from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
@ -1788,3 +1789,186 @@ async def test_supported_brands(hass, websocket_client):
"hello": "World", "hello": "World",
}, },
} }
async def test_message_coalescing(hass, websocket_client, hass_admin_user):
"""Test enabling message coalescing."""
await websocket_client.send_json(
{
"id": 1,
"type": "supported_features",
"features": {FEATURE_COALESCE_MESSAGES: 1},
}
)
hass.states.async_set("light.permitted", "on", {"color": "red"})
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 1
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
await websocket_client.close()
await hass.async_block_till_done()
async def test_message_coalescing_not_supported_by_websocket_client(
hass, websocket_client, hass_admin_user
):
"""Test enabling message coalescing not supported by websocket client."""
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
hass.states.async_set("light.permitted", "on", {"color": "red"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {"a": {}}
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
}
}
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
}
await websocket_client.close()
await hass.async_block_till_done()
async def test_client_message_coalescing(hass, websocket_client, hass_admin_user):
"""Test client message coalescing."""
await websocket_client.send_json(
[
{
"id": 1,
"type": "supported_features",
"features": {FEATURE_COALESCE_MESSAGES: 1},
},
{"id": 7, "type": "subscribe_entities"},
]
)
hass.states.async_set("light.permitted", "on", {"color": "red"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 1
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
await websocket_client.close()
await hass.async_block_till_done()

View File

@ -2,15 +2,25 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Callable, Generator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import functools import functools
from json import JSONDecoder, loads
import logging import logging
import ssl import ssl
import threading import threading
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp.test_utils import make_mocked_request from aiohttp import client
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import (
BaseTestServer,
TestClient,
TestServer,
make_mocked_request,
)
from aiohttp.web import Application
import freezegun import freezegun
import multidict import multidict
import pytest import pytest
@ -57,6 +67,7 @@ from tests.components.recorder.common import ( # noqa: E402, isort:skip
async_recorder_block_till_done, async_recorder_block_till_done,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -203,6 +214,97 @@ def load_registries():
return True return True
class CoalescingResponse(client.ClientWebSocketResponse):
"""ClientWebSocketResponse client that mimics the websocket js code."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init the ClientWebSocketResponse."""
super().__init__(*args, **kwargs)
self._recv_buffer: list[Any] = []
async def receive_json(
self,
*,
loads: JSONDecoder = loads,
timeout: float | None = None,
) -> Any:
"""receive_json or from buffer."""
if self._recv_buffer:
return self._recv_buffer.pop(0)
data = await self.receive_str(timeout=timeout)
decoded = loads(data)
if isinstance(decoded, list):
self._recv_buffer = decoded
return self._recv_buffer.pop(0)
return decoded
class CoalescingClient(TestClient):
"""Client that mimics the websocket js code."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init TestClient."""
super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs)
@pytest.fixture
def aiohttp_client_cls():
"""Override the test class for aiohttp."""
return CoalescingClient
@pytest.fixture
def aiohttp_client(
loop: asyncio.AbstractEventLoop,
) -> Generator[AiohttpClient, None, None]:
"""Override the default aiohttp_client since 3.x does not support aiohttp_client_cls.
Remove this when upgrading to 4.x as aiohttp_client_cls
will do the same thing
aiohttp_client(app, **kwargs)
aiohttp_client(server, **kwargs)
aiohttp_client(raw_server, **kwargs)
"""
clients = []
async def go(
__param: Application | BaseTestServer,
*args: Any,
server_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> TestClient:
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
__param, (Application, BaseTestServer)
):
__param = __param(loop, *args, **kwargs)
kwargs = {}
else:
assert not args, "args should be empty"
if isinstance(__param, Application):
server_kwargs = server_kwargs or {}
server = TestServer(__param, loop=loop, **server_kwargs)
client = CoalescingClient(server, loop=loop, **kwargs)
elif isinstance(__param, BaseTestServer):
client = TestClient(__param, loop=loop, **kwargs)
else:
raise ValueError("Unknown argument type: %r" % type(__param))
await client.start_server()
clients.append(client)
return client
yield go
async def finalize() -> None:
while clients:
await clients.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture @pytest.fixture
def hass(loop, load_registries, hass_storage, request): def hass(loop, load_registries, hass_storage, request):
"""Fixture to provide a test instance of Home Assistant.""" """Fixture to provide a test instance of Home Assistant."""