mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Improve websocket throughput and reduce latency (#92967)
This commit is contained in:
parent
9a70f47049
commit
8711735ec0
@ -715,7 +715,7 @@ def handle_supported_features(
|
|||||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle setting supported features."""
|
"""Handle setting supported features."""
|
||||||
connection.supported_features = msg["features"]
|
connection.set_supported_features(msg["features"])
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ class ActiveConnection:
|
|||||||
self.refresh_token_id = refresh_token.id
|
self.refresh_token_id = refresh_token.id
|
||||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||||
self.last_id = 0
|
self.last_id = 0
|
||||||
|
self.can_coalesce = False
|
||||||
self.supported_features: dict[str, float] = {}
|
self.supported_features: dict[str, float] = {}
|
||||||
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
|
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
|
||||||
const.DOMAIN
|
const.DOMAIN
|
||||||
@ -55,6 +56,11 @@ class ActiveConnection:
|
|||||||
self.binary_handlers: list[BinaryHandler | None] = []
|
self.binary_handlers: list[BinaryHandler | None] = []
|
||||||
current_connection.set(self)
|
current_connection.set(self)
|
||||||
|
|
||||||
|
def set_supported_features(self, features: dict[str, float]) -> None:
|
||||||
|
"""Set supported features."""
|
||||||
|
self.supported_features = features
|
||||||
|
self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features
|
||||||
|
|
||||||
def get_description(self, request: web.Request | None) -> str:
|
def get_description(self, request: web.Request | None) -> str:
|
||||||
"""Return a description of the connection."""
|
"""Return a description of the connection."""
|
||||||
description = self.user.name or ""
|
description = self.user.name or ""
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
@ -22,7 +23,6 @@ from .auth import AuthPhase, auth_required_message
|
|||||||
from .const import (
|
from .const import (
|
||||||
CANCELLATION_ERRORS,
|
CANCELLATION_ERRORS,
|
||||||
DATA_CONNECTIONS,
|
DATA_CONNECTIONS,
|
||||||
FEATURE_COALESCE_MESSAGES,
|
|
||||||
MAX_PENDING_MSG,
|
MAX_PENDING_MSG,
|
||||||
PENDING_MSG_PEAK,
|
PENDING_MSG_PEAK,
|
||||||
PENDING_MSG_PEAK_TIME,
|
PENDING_MSG_PEAK_TIME,
|
||||||
@ -71,7 +71,6 @@ class WebSocketHandler:
|
|||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.request = request
|
self.request = request
|
||||||
self.wsock = web.WebSocketResponse(heartbeat=55)
|
self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||||
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
|
|
||||||
self._handle_task: asyncio.Task | None = None
|
self._handle_task: asyncio.Task | None = None
|
||||||
self._writer_task: asyncio.Task | None = None
|
self._writer_task: asyncio.Task | None = None
|
||||||
self._closing: bool = False
|
self._closing: bool = False
|
||||||
@ -79,6 +78,13 @@ class WebSocketHandler:
|
|||||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||||
self.connection: ActiveConnection | None = None
|
self.connection: ActiveConnection | None = None
|
||||||
|
|
||||||
|
# The WebSocketHandler has a single consumer and path
|
||||||
|
# to where messages are queued. This allows the implementation
|
||||||
|
# to use a deque and an asyncio.Future to avoid the overhead of
|
||||||
|
# an asyncio.Queue.
|
||||||
|
self._message_queue: deque = deque()
|
||||||
|
self._ready_future: asyncio.Future[None] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
"""Return a description of the connection."""
|
"""Return a description of the connection."""
|
||||||
@ -88,39 +94,52 @@ class WebSocketHandler:
|
|||||||
|
|
||||||
async def _writer(self) -> None:
|
async def _writer(self) -> None:
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
# Variables are set locally to avoid lookups in the loop
|
||||||
to_write = self._to_write
|
message_queue = self._message_queue
|
||||||
logger = self._logger
|
logger = self._logger
|
||||||
wsock = self.wsock
|
send_str = self.wsock.send_str
|
||||||
|
loop = self.hass.loop
|
||||||
|
debug = logger.debug
|
||||||
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
try:
|
try:
|
||||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||||
while not self.wsock.closed:
|
while not self.wsock.closed:
|
||||||
if (process := await to_write.get()) is None:
|
if (messages_remaining := len(message_queue)) == 0:
|
||||||
|
self._ready_future = loop.create_future()
|
||||||
|
await self._ready_future
|
||||||
|
messages_remaining = len(message_queue)
|
||||||
|
|
||||||
|
# A None message is used to signal the end of the connection
|
||||||
|
if (process := message_queue.popleft()) is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
messages_remaining -= 1
|
||||||
message = process if isinstance(process, str) else process()
|
message = process if isinstance(process, str) else process()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
to_write.empty()
|
not messages_remaining
|
||||||
or not self.connection
|
or not self.connection
|
||||||
or FEATURE_COALESCE_MESSAGES
|
or not self.connection.can_coalesce
|
||||||
not in self.connection.supported_features
|
|
||||||
):
|
):
|
||||||
logger.debug("Sending %s", message)
|
debug("Sending %s", message)
|
||||||
await wsock.send_str(message)
|
await send_str(message)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
messages: list[str] = [message]
|
messages: list[str] = [message]
|
||||||
while not to_write.empty():
|
while messages_remaining:
|
||||||
if (process := to_write.get_nowait()) is None:
|
# A None message is used to signal the end of the connection
|
||||||
|
if (process := message_queue.popleft()) is None:
|
||||||
return
|
return
|
||||||
messages.append(
|
messages.append(
|
||||||
process if isinstance(process, str) else process()
|
process if isinstance(process, str) else process()
|
||||||
)
|
)
|
||||||
|
messages_remaining -= 1
|
||||||
|
|
||||||
coalesced_messages = "[" + ",".join(messages) + "]"
|
coalesced_messages = "[" + ",".join(messages) + "]"
|
||||||
logger.debug("Sending %s", coalesced_messages)
|
debug("Sending %s", coalesced_messages)
|
||||||
await wsock.send_str(coalesced_messages)
|
await send_str(coalesced_messages)
|
||||||
finally:
|
finally:
|
||||||
# Clean up the peaker checker when we shut down the writer
|
# Clean up the peak checker when we shut down the writer
|
||||||
self._cancel_peak_checker()
|
self._cancel_peak_checker()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -146,11 +165,9 @@ class WebSocketHandler:
|
|||||||
if isinstance(message, dict):
|
if isinstance(message, dict):
|
||||||
message = message_to_json(message)
|
message = message_to_json(message)
|
||||||
|
|
||||||
to_write = self._to_write
|
message_queue = self._message_queue
|
||||||
|
queue_size_before_add = len(message_queue)
|
||||||
try:
|
if queue_size_before_add >= MAX_PENDING_MSG:
|
||||||
to_write.put_nowait(message)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
self._logger.error(
|
self._logger.error(
|
||||||
(
|
(
|
||||||
"%s: Client unable to keep up with pending messages. Reached %s pending"
|
"%s: Client unable to keep up with pending messages. Reached %s pending"
|
||||||
@ -162,10 +179,15 @@ class WebSocketHandler:
|
|||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
self._cancel()
|
self._cancel()
|
||||||
|
return
|
||||||
|
|
||||||
|
message_queue.append(message)
|
||||||
|
if self._ready_future and not self._ready_future.done():
|
||||||
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
peak_checker_active = self._peak_checker_unsub is not None
|
peak_checker_active = self._peak_checker_unsub is not None
|
||||||
|
|
||||||
if to_write.qsize() < PENDING_MSG_PEAK:
|
if queue_size_before_add <= PENDING_MSG_PEAK:
|
||||||
if peak_checker_active:
|
if peak_checker_active:
|
||||||
self._cancel_peak_checker()
|
self._cancel_peak_checker()
|
||||||
return
|
return
|
||||||
@ -180,7 +202,7 @@ class WebSocketHandler:
|
|||||||
"""Check that we are no longer above the write peak."""
|
"""Check that we are no longer above the write peak."""
|
||||||
self._peak_checker_unsub = None
|
self._peak_checker_unsub = None
|
||||||
|
|
||||||
if self._to_write.qsize() < PENDING_MSG_PEAK:
|
if len(self._message_queue) < PENDING_MSG_PEAK:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._logger.error(
|
self._logger.error(
|
||||||
@ -199,6 +221,7 @@ class WebSocketHandler:
|
|||||||
def _cancel(self) -> None:
|
def _cancel(self) -> None:
|
||||||
"""Cancel the connection."""
|
"""Cancel the connection."""
|
||||||
self._closing = True
|
self._closing = True
|
||||||
|
self._cancel_peak_checker()
|
||||||
if self._handle_task is not None:
|
if self._handle_task is not None:
|
||||||
self._handle_task.cancel()
|
self._handle_task.cancel()
|
||||||
if self._writer_task is not None:
|
if self._writer_task is not None:
|
||||||
@ -356,14 +379,14 @@ class WebSocketHandler:
|
|||||||
|
|
||||||
self._closing = True
|
self._closing = True
|
||||||
|
|
||||||
|
self._message_queue.append(None)
|
||||||
|
if self._ready_future and not self._ready_future.done():
|
||||||
|
self._ready_future.set_result(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._to_write.put_nowait(None)
|
|
||||||
# Make sure all error messages are written before closing
|
# Make sure all error messages are written before closing
|
||||||
await self._writer_task
|
await self._writer_task
|
||||||
await wsock.close()
|
await wsock.close()
|
||||||
except asyncio.QueueFull: # can be raised by put_nowait
|
|
||||||
self._writer_task.cancel()
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if disconnect_warn is None:
|
if disconnect_warn is None:
|
||||||
self._logger.debug("Disconnected")
|
self._logger.debug("Disconnected")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Test Websocket API http module."""
|
"""Test Websocket API http module."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from aiohttp import ServerDisconnectedError, WSMsgType, web
|
from aiohttp import ServerDisconnectedError, WSMsgType, web
|
||||||
@ -53,12 +53,12 @@ async def test_pending_msg_peak(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test pending msg overflow command."""
|
"""Test pending msg overflow command."""
|
||||||
orig_handler = http.WebSocketHandler
|
orig_handler = http.WebSocketHandler
|
||||||
instance = None
|
setup_instance: http.WebSocketHandler | None = None
|
||||||
|
|
||||||
def instantiate_handler(*args):
|
def instantiate_handler(*args):
|
||||||
nonlocal instance
|
nonlocal setup_instance
|
||||||
instance = orig_handler(*args)
|
setup_instance = orig_handler(*args)
|
||||||
return instance
|
return setup_instance
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||||
@ -66,11 +66,10 @@ async def test_pending_msg_peak(
|
|||||||
):
|
):
|
||||||
websocket_client = await hass_ws_client()
|
websocket_client = await hass_ws_client()
|
||||||
|
|
||||||
# Kill writer task and fill queue past peak
|
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
|
||||||
for _ in range(5):
|
|
||||||
instance._to_write.put_nowait(None)
|
|
||||||
|
|
||||||
# Trigger the peak check
|
# Fill the queue past the allowed peak
|
||||||
|
for _ in range(10):
|
||||||
instance._send_message({})
|
instance._send_message({})
|
||||||
|
|
||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
@ -79,8 +78,54 @@ async def test_pending_msg_peak(
|
|||||||
|
|
||||||
msg = await websocket_client.receive()
|
msg = await websocket_client.receive()
|
||||||
assert msg.type == WSMsgType.close
|
assert msg.type == WSMsgType.close
|
||||||
|
|
||||||
assert "Client unable to keep up with pending messages" in caplog.text
|
assert "Client unable to keep up with pending messages" in caplog.text
|
||||||
|
assert "Stayed over 5 for 5 seconds"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pending_msg_peak_recovery(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_low_peak,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test pending msg nears the peak but recovers."""
|
||||||
|
orig_handler = http.WebSocketHandler
|
||||||
|
setup_instance: http.WebSocketHandler | None = None
|
||||||
|
|
||||||
|
def instantiate_handler(*args):
|
||||||
|
nonlocal setup_instance
|
||||||
|
setup_instance = orig_handler(*args)
|
||||||
|
return setup_instance
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||||
|
instantiate_handler,
|
||||||
|
):
|
||||||
|
websocket_client = await hass_ws_client()
|
||||||
|
|
||||||
|
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
|
||||||
|
|
||||||
|
# Make sure the call later is started
|
||||||
|
for _ in range(10):
|
||||||
|
instance._send_message({})
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.TEXT
|
||||||
|
|
||||||
|
instance._send_message({})
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.TEXT
|
||||||
|
|
||||||
|
# Cleanly shutdown
|
||||||
|
instance._send_message({})
|
||||||
|
instance._handle_task.cancel()
|
||||||
|
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.TEXT
|
||||||
|
msg = await websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.close
|
||||||
|
assert "Client unable to keep up with pending messages" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_pending_msg_peak_but_does_not_overflow(
|
async def test_pending_msg_peak_but_does_not_overflow(
|
||||||
@ -91,12 +136,12 @@ async def test_pending_msg_peak_but_does_not_overflow(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test pending msg hits the low peak but recovers and does not overflow."""
|
"""Test pending msg hits the low peak but recovers and does not overflow."""
|
||||||
orig_handler = http.WebSocketHandler
|
orig_handler = http.WebSocketHandler
|
||||||
instance: http.WebSocketHandler | None = None
|
setup_instance: http.WebSocketHandler | None = None
|
||||||
|
|
||||||
def instantiate_handler(*args):
|
def instantiate_handler(*args):
|
||||||
nonlocal instance
|
nonlocal setup_instance
|
||||||
instance = orig_handler(*args)
|
setup_instance = orig_handler(*args)
|
||||||
return instance
|
return setup_instance
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
"homeassistant.components.websocket_api.http.WebSocketHandler",
|
||||||
@ -104,18 +149,17 @@ async def test_pending_msg_peak_but_does_not_overflow(
|
|||||||
):
|
):
|
||||||
websocket_client = await hass_ws_client()
|
websocket_client = await hass_ws_client()
|
||||||
|
|
||||||
assert instance is not None
|
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
|
||||||
|
|
||||||
# Kill writer task and fill queue past peak
|
# Kill writer task and fill queue past peak
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
instance._to_write.put_nowait(None)
|
instance._message_queue.append(None)
|
||||||
|
|
||||||
# Trigger the peak check
|
# Trigger the peak check
|
||||||
instance._send_message({})
|
instance._send_message({})
|
||||||
|
|
||||||
# Clear the queue
|
# Clear the queue
|
||||||
while instance._to_write.qsize() > 0:
|
instance._message_queue.clear()
|
||||||
instance._to_write.get_nowait()
|
|
||||||
|
|
||||||
# Trigger the peak clear
|
# Trigger the peak clear
|
||||||
instance._send_message({})
|
instance._send_message({})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user