mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Implement websocket message coalescing (#77238)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
2161b6f049
commit
f6a03625ba
@ -74,6 +74,7 @@ def async_register_commands(
|
|||||||
async_reg(hass, handle_validate_config)
|
async_reg(hass, handle_validate_config)
|
||||||
async_reg(hass, handle_subscribe_entities)
|
async_reg(hass, handle_subscribe_entities)
|
||||||
async_reg(hass, handle_supported_brands)
|
async_reg(hass, handle_supported_brands)
|
||||||
|
async_reg(hass, handle_supported_features)
|
||||||
|
|
||||||
|
|
||||||
def pong_message(iden: int) -> dict[str, Any]:
|
def pong_message(iden: int) -> dict[str, Any]:
|
||||||
@ -723,3 +724,18 @@ async def handle_supported_brands(
|
|||||||
raise int_or_exc
|
raise int_or_exc
|
||||||
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
|
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
|
||||||
connection.send_result(msg["id"], data)
|
connection.send_result(msg["id"], data)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@decorators.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "supported_features",
|
||||||
|
vol.Required("features"): {str: int},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def handle_supported_features(
|
||||||
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Handle setting supported features."""
|
||||||
|
connection.supported_features = msg["features"]
|
||||||
|
connection.send_result(msg["id"])
|
||||||
|
@ -42,6 +42,7 @@ class ActiveConnection:
|
|||||||
self.refresh_token_id = refresh_token.id
|
self.refresh_token_id = refresh_token.id
|
||||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||||
self.last_id = 0
|
self.last_id = 0
|
||||||
|
self.supported_features: dict[str, float] = {}
|
||||||
current_connection.set(self)
|
current_connection.set(self)
|
||||||
|
|
||||||
def context(self, msg: dict[str, Any]) -> Context:
|
def context(self, msg: dict[str, Any]) -> Context:
|
||||||
|
@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a"
|
|||||||
COMPRESSED_STATE_CONTEXT = "c"
|
COMPRESSED_STATE_CONTEXT = "c"
|
||||||
COMPRESSED_STATE_LAST_CHANGED = "lc"
|
COMPRESSED_STATE_LAST_CHANGED = "lc"
|
||||||
COMPRESSED_STATE_LAST_UPDATED = "lu"
|
COMPRESSED_STATE_LAST_UPDATED = "lu"
|
||||||
|
|
||||||
|
FEATURE_COALESCE_MESSAGES = "coalesce_messages"
|
||||||
|
@ -6,7 +6,7 @@ from collections.abc import Callable
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Final
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
|
|
||||||
from aiohttp import WSMsgType, web
|
from aiohttp import WSMsgType, web
|
||||||
import async_timeout
|
import async_timeout
|
||||||
@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
|||||||
from homeassistant.core import Event, HomeAssistant, callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
|
from homeassistant.helpers.json import json_loads
|
||||||
|
|
||||||
from .auth import AuthPhase, auth_required_message
|
from .auth import AuthPhase, auth_required_message
|
||||||
from .const import (
|
from .const import (
|
||||||
CANCELLATION_ERRORS,
|
CANCELLATION_ERRORS,
|
||||||
DATA_CONNECTIONS,
|
DATA_CONNECTIONS,
|
||||||
|
FEATURE_COALESCE_MESSAGES,
|
||||||
MAX_PENDING_MSG,
|
MAX_PENDING_MSG,
|
||||||
PENDING_MSG_PEAK,
|
PENDING_MSG_PEAK,
|
||||||
PENDING_MSG_PEAK_TIME,
|
PENDING_MSG_PEAK_TIME,
|
||||||
@ -31,6 +33,10 @@ from .const import (
|
|||||||
from .error import Disconnect
|
from .error import Disconnect
|
||||||
from .messages import message_to_json
|
from .messages import message_to_json
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .connection import ActiveConnection
|
||||||
|
|
||||||
|
|
||||||
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
||||||
|
|
||||||
|
|
||||||
@ -67,26 +73,47 @@ class WebSocketHandler:
|
|||||||
self._writer_task: asyncio.Task | None = None
|
self._writer_task: asyncio.Task | None = None
|
||||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||||
|
self.connection: ActiveConnection | None = None
|
||||||
|
|
||||||
async def _writer(self) -> None:
|
async def _writer(self) -> None:
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
to_write = self._to_write
|
||||||
while not self.wsock.closed:
|
logger = self._logger
|
||||||
if (process := await self._to_write.get()) is None:
|
wsock = self.wsock
|
||||||
break
|
try:
|
||||||
|
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||||
|
while not self.wsock.closed:
|
||||||
|
if (process := await to_write.get()) is None:
|
||||||
|
return
|
||||||
|
message = process if isinstance(process, str) else process()
|
||||||
|
|
||||||
if not isinstance(process, str):
|
if (
|
||||||
message: str = process()
|
to_write.empty()
|
||||||
else:
|
or not self.connection
|
||||||
message = process
|
or FEATURE_COALESCE_MESSAGES
|
||||||
self._logger.debug("Sending %s", message)
|
not in self.connection.supported_features
|
||||||
await self.wsock.send_str(message)
|
):
|
||||||
|
logger.debug("Sending %s", message)
|
||||||
|
await wsock.send_str(message)
|
||||||
|
continue
|
||||||
|
|
||||||
# Clean up the peaker checker when we shut down the writer
|
messages: list[str] = [message]
|
||||||
if self._peak_checker_unsub is not None:
|
while not to_write.empty():
|
||||||
self._peak_checker_unsub()
|
if (process := to_write.get_nowait()) is None:
|
||||||
self._peak_checker_unsub = None
|
return
|
||||||
|
messages.append(
|
||||||
|
process if isinstance(process, str) else process()
|
||||||
|
)
|
||||||
|
|
||||||
|
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||||
|
self._logger.debug("Sending %s", coalesced_messages)
|
||||||
|
await self.wsock.send_str(coalesced_messages)
|
||||||
|
finally:
|
||||||
|
# Clean up the peaker checker when we shut down the writer
|
||||||
|
if self._peak_checker_unsub is not None:
|
||||||
|
self._peak_checker_unsub()
|
||||||
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
|
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
|
||||||
@ -194,13 +221,13 @@ class WebSocketHandler:
|
|||||||
raise Disconnect
|
raise Disconnect
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg_data = msg.json()
|
msg_data = msg.json(loads=json_loads)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
disconnect_warn = "Received invalid JSON."
|
disconnect_warn = "Received invalid JSON."
|
||||||
raise Disconnect from err
|
raise Disconnect from err
|
||||||
|
|
||||||
self._logger.debug("Received %s", msg_data)
|
self._logger.debug("Received %s", msg_data)
|
||||||
connection = await auth.async_handle(msg_data)
|
self.connection = connection = await auth.async_handle(msg_data)
|
||||||
self.hass.data[DATA_CONNECTIONS] = (
|
self.hass.data[DATA_CONNECTIONS] = (
|
||||||
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
|
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||||
)
|
)
|
||||||
@ -218,13 +245,18 @@ class WebSocketHandler:
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg_data = msg.json()
|
msg_data = msg.json(loads=json_loads)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
disconnect_warn = "Received invalid JSON."
|
disconnect_warn = "Received invalid JSON."
|
||||||
break
|
break
|
||||||
|
|
||||||
self._logger.debug("Received %s", msg_data)
|
self._logger.debug("Received %s", msg_data)
|
||||||
connection.async_handle(msg_data)
|
if not isinstance(msg_data, list):
|
||||||
|
connection.async_handle(msg_data)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for split_msg in msg_data:
|
||||||
|
connection.async_handle(split_msg)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self._logger.info("Connection closed by client")
|
self._logger.info("Connection closed by client")
|
||||||
@ -257,6 +289,8 @@ class WebSocketHandler:
|
|||||||
|
|
||||||
if connection is not None:
|
if connection is not None:
|
||||||
self.hass.data[DATA_CONNECTIONS] -= 1
|
self.hass.data[DATA_CONNECTIONS] -= 1
|
||||||
|
self.connection = None
|
||||||
|
|
||||||
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||||
|
|
||||||
return wsock
|
return wsock
|
||||||
|
@ -13,12 +13,13 @@ from homeassistant.components.websocket_api.auth import (
|
|||||||
TYPE_AUTH_OK,
|
TYPE_AUTH_OK,
|
||||||
TYPE_AUTH_REQUIRED,
|
TYPE_AUTH_REQUIRED,
|
||||||
)
|
)
|
||||||
from homeassistant.components.websocket_api.const import URL
|
from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL
|
||||||
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS
|
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS
|
||||||
from homeassistant.core import Context, HomeAssistant, State, callback
|
from homeassistant.core import Context, HomeAssistant, State, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import entity
|
from homeassistant.helpers import entity
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||||
|
from homeassistant.helpers.json import json_loads
|
||||||
from homeassistant.loader import async_get_integration
|
from homeassistant.loader import async_get_integration
|
||||||
from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
|
from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
|
||||||
|
|
||||||
@ -1788,3 +1789,186 @@ async def test_supported_brands(hass, websocket_client):
|
|||||||
"hello": "World",
|
"hello": "World",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_message_coalescing(hass, websocket_client, hass_admin_user):
|
||||||
|
"""Test enabling message coalescing."""
|
||||||
|
await websocket_client.send_json(
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"type": "supported_features",
|
||||||
|
"features": {FEATURE_COALESCE_MESSAGES: 1},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "red"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msg = json_loads(data)
|
||||||
|
assert msg["id"] == 1
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msgs = json_loads(data)
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"a": {
|
||||||
|
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msgs = json_loads(data)
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||||
|
await websocket_client.close()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_message_coalescing_not_supported_by_websocket_client(
|
||||||
|
hass, websocket_client, hass_admin_user
|
||||||
|
):
|
||||||
|
"""Test enabling message coalescing not supported by websocket client."""
|
||||||
|
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msg = json_loads(data)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "red"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msg = json_loads(data)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {"a": {}}
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msg = json_loads(data)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"a": {
|
||||||
|
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msg = json_loads(data)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
await websocket_client.close()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_message_coalescing(hass, websocket_client, hass_admin_user):
|
||||||
|
"""Test client message coalescing."""
|
||||||
|
await websocket_client.send_json(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"type": "supported_features",
|
||||||
|
"features": {FEATURE_COALESCE_MESSAGES: 1},
|
||||||
|
},
|
||||||
|
{"id": 7, "type": "subscribe_entities"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "red"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msgs = json_loads(data)
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 1
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == const.TYPE_RESULT
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"a": {
|
||||||
|
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||||
|
|
||||||
|
data = await websocket_client.receive_str()
|
||||||
|
msgs = json_loads(data)
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = msgs.pop(0)
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {
|
||||||
|
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "green"})
|
||||||
|
hass.states.async_set("light.permitted", "on", {"color": "blue"})
|
||||||
|
await websocket_client.close()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
@ -2,15 +2,25 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator, Callable, Generator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import functools
|
import functools
|
||||||
|
from json import JSONDecoder, loads
|
||||||
import logging
|
import logging
|
||||||
import ssl
|
import ssl
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
from aiohttp.test_utils import make_mocked_request
|
from aiohttp import client
|
||||||
|
from aiohttp.pytest_plugin import AiohttpClient
|
||||||
|
from aiohttp.test_utils import (
|
||||||
|
BaseTestServer,
|
||||||
|
TestClient,
|
||||||
|
TestServer,
|
||||||
|
make_mocked_request,
|
||||||
|
)
|
||||||
|
from aiohttp.web import Application
|
||||||
import freezegun
|
import freezegun
|
||||||
import multidict
|
import multidict
|
||||||
import pytest
|
import pytest
|
||||||
@ -57,6 +67,7 @@ from tests.components.recorder.common import ( # noqa: E402, isort:skip
|
|||||||
async_recorder_block_till_done,
|
async_recorder_block_till_done,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@ -203,6 +214,97 @@ def load_registries():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class CoalescingResponse(client.ClientWebSocketResponse):
|
||||||
|
"""ClientWebSocketResponse client that mimics the websocket js code."""
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Init the ClientWebSocketResponse."""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._recv_buffer: list[Any] = []
|
||||||
|
|
||||||
|
async def receive_json(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
loads: JSONDecoder = loads,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""receive_json or from buffer."""
|
||||||
|
if self._recv_buffer:
|
||||||
|
return self._recv_buffer.pop(0)
|
||||||
|
data = await self.receive_str(timeout=timeout)
|
||||||
|
decoded = loads(data)
|
||||||
|
if isinstance(decoded, list):
|
||||||
|
self._recv_buffer = decoded
|
||||||
|
return self._recv_buffer.pop(0)
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
|
||||||
|
class CoalescingClient(TestClient):
|
||||||
|
"""Client that mimics the websocket js code."""
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Init TestClient."""
|
||||||
|
super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def aiohttp_client_cls():
|
||||||
|
"""Override the test class for aiohttp."""
|
||||||
|
return CoalescingClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def aiohttp_client(
|
||||||
|
loop: asyncio.AbstractEventLoop,
|
||||||
|
) -> Generator[AiohttpClient, None, None]:
|
||||||
|
"""Override the default aiohttp_client since 3.x does not support aiohttp_client_cls.
|
||||||
|
|
||||||
|
Remove this when upgrading to 4.x as aiohttp_client_cls
|
||||||
|
will do the same thing
|
||||||
|
|
||||||
|
aiohttp_client(app, **kwargs)
|
||||||
|
aiohttp_client(server, **kwargs)
|
||||||
|
aiohttp_client(raw_server, **kwargs)
|
||||||
|
"""
|
||||||
|
clients = []
|
||||||
|
|
||||||
|
async def go(
|
||||||
|
__param: Application | BaseTestServer,
|
||||||
|
*args: Any,
|
||||||
|
server_kwargs: dict[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> TestClient:
|
||||||
|
|
||||||
|
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
|
||||||
|
__param, (Application, BaseTestServer)
|
||||||
|
):
|
||||||
|
__param = __param(loop, *args, **kwargs)
|
||||||
|
kwargs = {}
|
||||||
|
else:
|
||||||
|
assert not args, "args should be empty"
|
||||||
|
|
||||||
|
if isinstance(__param, Application):
|
||||||
|
server_kwargs = server_kwargs or {}
|
||||||
|
server = TestServer(__param, loop=loop, **server_kwargs)
|
||||||
|
client = CoalescingClient(server, loop=loop, **kwargs)
|
||||||
|
elif isinstance(__param, BaseTestServer):
|
||||||
|
client = TestClient(__param, loop=loop, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown argument type: %r" % type(__param))
|
||||||
|
|
||||||
|
await client.start_server()
|
||||||
|
clients.append(client)
|
||||||
|
return client
|
||||||
|
|
||||||
|
yield go
|
||||||
|
|
||||||
|
async def finalize() -> None:
|
||||||
|
while clients:
|
||||||
|
await clients.pop().close()
|
||||||
|
|
||||||
|
loop.run_until_complete(finalize())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hass(loop, load_registries, hass_storage, request):
|
def hass(loop, load_registries, hass_storage, request):
|
||||||
"""Fixture to provide a test instance of Home Assistant."""
|
"""Fixture to provide a test instance of Home Assistant."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user