Improve websocket throughput and reduce latency (#92967)

This commit is contained in:
J. Nick Koston 2023-05-13 00:13:57 +09:00 committed by GitHub
parent 9a70f47049
commit 8711735ec0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 120 additions and 47 deletions

View File

@ -715,7 +715,7 @@ def handle_supported_features(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Handle setting supported features.""" """Handle setting supported features."""
connection.supported_features = msg["features"] connection.set_supported_features(msg["features"])
connection.send_result(msg["id"]) connection.send_result(msg["id"])

View File

@ -48,6 +48,7 @@ class ActiveConnection:
self.refresh_token_id = refresh_token.id self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0 self.last_id = 0
self.can_coalesce = False
self.supported_features: dict[str, float] = {} self.supported_features: dict[str, float] = {}
self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[ self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[
const.DOMAIN const.DOMAIN
@ -55,6 +56,11 @@ class ActiveConnection:
self.binary_handlers: list[BinaryHandler | None] = [] self.binary_handlers: list[BinaryHandler | None] = []
current_connection.set(self) current_connection.set(self)
def set_supported_features(self, features: dict[str, float]) -> None:
"""Set supported features."""
self.supported_features = features
self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features
def get_description(self, request: web.Request | None) -> str: def get_description(self, request: web.Request | None) -> str:
"""Return a description of the connection.""" """Return a description of the connection."""
description = self.user.name or "" description = self.user.name or ""

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import deque
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import datetime as dt import datetime as dt
@ -22,7 +23,6 @@ from .auth import AuthPhase, auth_required_message
from .const import ( from .const import (
CANCELLATION_ERRORS, CANCELLATION_ERRORS,
DATA_CONNECTIONS, DATA_CONNECTIONS,
FEATURE_COALESCE_MESSAGES,
MAX_PENDING_MSG, MAX_PENDING_MSG,
PENDING_MSG_PEAK, PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME, PENDING_MSG_PEAK_TIME,
@ -71,7 +71,6 @@ class WebSocketHandler:
self.hass = hass self.hass = hass
self.request = request self.request = request
self.wsock = web.WebSocketResponse(heartbeat=55) self.wsock = web.WebSocketResponse(heartbeat=55)
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
self._handle_task: asyncio.Task | None = None self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None
self._closing: bool = False self._closing: bool = False
@ -79,6 +78,13 @@ class WebSocketHandler:
self._peak_checker_unsub: Callable[[], None] | None = None self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None self.connection: ActiveConnection | None = None
# The WebSocketHandler has a single consumer and path
# to where messages are queued. This allows the implementation
# to use a deque and an asyncio.Future to avoid the overhead of
# an asyncio.Queue.
self._message_queue: deque = deque()
self._ready_future: asyncio.Future[None] | None = None
@property @property
def description(self) -> str: def description(self) -> str:
"""Return a description of the connection.""" """Return a description of the connection."""
@ -88,39 +94,52 @@ class WebSocketHandler:
async def _writer(self) -> None: async def _writer(self) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler # Variables are set locally to avoid lookups in the loop
to_write = self._to_write message_queue = self._message_queue
logger = self._logger logger = self._logger
wsock = self.wsock send_str = self.wsock.send_str
loop = self.hass.loop
debug = logger.debug
# Exceptions if Socket disconnected or cancelled by connection handler
try: try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed: while not self.wsock.closed:
if (process := await to_write.get()) is None: if (messages_remaining := len(message_queue)) == 0:
self._ready_future = loop.create_future()
await self._ready_future
messages_remaining = len(message_queue)
# A None message is used to signal the end of the connection
if (process := message_queue.popleft()) is None:
return return
messages_remaining -= 1
message = process if isinstance(process, str) else process() message = process if isinstance(process, str) else process()
if ( if (
to_write.empty() not messages_remaining
or not self.connection or not self.connection
or FEATURE_COALESCE_MESSAGES or not self.connection.can_coalesce
not in self.connection.supported_features
): ):
logger.debug("Sending %s", message) debug("Sending %s", message)
await wsock.send_str(message) await send_str(message)
continue continue
messages: list[str] = [message] messages: list[str] = [message]
while not to_write.empty(): while messages_remaining:
if (process := to_write.get_nowait()) is None: # A None message is used to signal the end of the connection
if (process := message_queue.popleft()) is None:
return return
messages.append( messages.append(
process if isinstance(process, str) else process() process if isinstance(process, str) else process()
) )
messages_remaining -= 1
coalesced_messages = "[" + ",".join(messages) + "]" coalesced_messages = "[" + ",".join(messages) + "]"
logger.debug("Sending %s", coalesced_messages) debug("Sending %s", coalesced_messages)
await wsock.send_str(coalesced_messages) await send_str(coalesced_messages)
finally: finally:
# Clean up the peaker checker when we shut down the writer # Clean up the peak checker when we shut down the writer
self._cancel_peak_checker() self._cancel_peak_checker()
@callback @callback
@ -146,11 +165,9 @@ class WebSocketHandler:
if isinstance(message, dict): if isinstance(message, dict):
message = message_to_json(message) message = message_to_json(message)
to_write = self._to_write message_queue = self._message_queue
queue_size_before_add = len(message_queue)
try: if queue_size_before_add >= MAX_PENDING_MSG:
to_write.put_nowait(message)
except asyncio.QueueFull:
self._logger.error( self._logger.error(
( (
"%s: Client unable to keep up with pending messages. Reached %s pending" "%s: Client unable to keep up with pending messages. Reached %s pending"
@ -162,10 +179,15 @@ class WebSocketHandler:
message, message,
) )
self._cancel() self._cancel()
return
message_queue.append(message)
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
peak_checker_active = self._peak_checker_unsub is not None peak_checker_active = self._peak_checker_unsub is not None
if to_write.qsize() < PENDING_MSG_PEAK: if queue_size_before_add <= PENDING_MSG_PEAK:
if peak_checker_active: if peak_checker_active:
self._cancel_peak_checker() self._cancel_peak_checker()
return return
@ -180,7 +202,7 @@ class WebSocketHandler:
"""Check that we are no longer above the write peak.""" """Check that we are no longer above the write peak."""
self._peak_checker_unsub = None self._peak_checker_unsub = None
if self._to_write.qsize() < PENDING_MSG_PEAK: if len(self._message_queue) < PENDING_MSG_PEAK:
return return
self._logger.error( self._logger.error(
@ -199,6 +221,7 @@ class WebSocketHandler:
def _cancel(self) -> None: def _cancel(self) -> None:
"""Cancel the connection.""" """Cancel the connection."""
self._closing = True self._closing = True
self._cancel_peak_checker()
if self._handle_task is not None: if self._handle_task is not None:
self._handle_task.cancel() self._handle_task.cancel()
if self._writer_task is not None: if self._writer_task is not None:
@ -356,14 +379,14 @@ class WebSocketHandler:
self._closing = True self._closing = True
self._message_queue.append(None)
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
try: try:
self._to_write.put_nowait(None)
# Make sure all error messages are written before closing # Make sure all error messages are written before closing
await self._writer_task await self._writer_task
await wsock.close() await wsock.close()
except asyncio.QueueFull: # can be raised by put_nowait
self._writer_task.cancel()
finally: finally:
if disconnect_warn is None: if disconnect_warn is None:
self._logger.debug("Disconnected") self._logger.debug("Disconnected")

View File

@ -1,7 +1,7 @@
"""Test Websocket API http module.""" """Test Websocket API http module."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any, cast
from unittest.mock import patch from unittest.mock import patch
from aiohttp import ServerDisconnectedError, WSMsgType, web from aiohttp import ServerDisconnectedError, WSMsgType, web
@ -53,12 +53,12 @@ async def test_pending_msg_peak(
) -> None: ) -> None:
"""Test pending msg overflow command.""" """Test pending msg overflow command."""
orig_handler = http.WebSocketHandler orig_handler = http.WebSocketHandler
instance = None setup_instance: http.WebSocketHandler | None = None
def instantiate_handler(*args): def instantiate_handler(*args):
nonlocal instance nonlocal setup_instance
instance = orig_handler(*args) setup_instance = orig_handler(*args)
return instance return setup_instance
with patch( with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler", "homeassistant.components.websocket_api.http.WebSocketHandler",
@ -66,12 +66,11 @@ async def test_pending_msg_peak(
): ):
websocket_client = await hass_ws_client() websocket_client = await hass_ws_client()
# Kill writer task and fill queue past peak instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
for _ in range(5):
instance._to_write.put_nowait(None)
# Trigger the peak check # Fill the queue past the allowed peak
instance._send_message({}) for _ in range(10):
instance._send_message({})
async_fire_time_changed( async_fire_time_changed(
hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1) hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
@ -79,8 +78,54 @@ async def test_pending_msg_peak(
msg = await websocket_client.receive() msg = await websocket_client.receive()
assert msg.type == WSMsgType.close assert msg.type == WSMsgType.close
assert "Client unable to keep up with pending messages" in caplog.text assert "Client unable to keep up with pending messages" in caplog.text
assert "Stayed over 5 for 5 seconds"
async def test_pending_msg_peak_recovery(
hass: HomeAssistant,
mock_low_peak,
hass_ws_client: WebSocketGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test pending msg nears the peak but recovers."""
orig_handler = http.WebSocketHandler
setup_instance: http.WebSocketHandler | None = None
def instantiate_handler(*args):
nonlocal setup_instance
setup_instance = orig_handler(*args)
return setup_instance
with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler",
instantiate_handler,
):
websocket_client = await hass_ws_client()
instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
# Make sure the call later is started
for _ in range(10):
instance._send_message({})
for _ in range(10):
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
instance._send_message({})
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
# Cleanly shutdown
instance._send_message({})
instance._handle_task.cancel()
msg = await websocket_client.receive()
assert msg.type == WSMsgType.TEXT
msg = await websocket_client.receive()
assert msg.type == WSMsgType.close
assert "Client unable to keep up with pending messages" not in caplog.text
async def test_pending_msg_peak_but_does_not_overflow( async def test_pending_msg_peak_but_does_not_overflow(
@ -91,12 +136,12 @@ async def test_pending_msg_peak_but_does_not_overflow(
) -> None: ) -> None:
"""Test pending msg hits the low peak but recovers and does not overflow.""" """Test pending msg hits the low peak but recovers and does not overflow."""
orig_handler = http.WebSocketHandler orig_handler = http.WebSocketHandler
instance: http.WebSocketHandler | None = None setup_instance: http.WebSocketHandler | None = None
def instantiate_handler(*args): def instantiate_handler(*args):
nonlocal instance nonlocal setup_instance
instance = orig_handler(*args) setup_instance = orig_handler(*args)
return instance return setup_instance
with patch( with patch(
"homeassistant.components.websocket_api.http.WebSocketHandler", "homeassistant.components.websocket_api.http.WebSocketHandler",
@ -104,18 +149,17 @@ async def test_pending_msg_peak_but_does_not_overflow(
): ):
websocket_client = await hass_ws_client() websocket_client = await hass_ws_client()
assert instance is not None instance: http.WebSocketHandler = cast(http.WebSocketHandler, setup_instance)
# Kill writer task and fill queue past peak # Kill writer task and fill queue past peak
for _ in range(5): for _ in range(5):
instance._to_write.put_nowait(None) instance._message_queue.append(None)
# Trigger the peak check # Trigger the peak check
instance._send_message({}) instance._send_message({})
# Clear the queue # Clear the queue
while instance._to_write.qsize() > 0: instance._message_queue.clear()
instance._to_write.get_nowait()
# Trigger the peak clear # Trigger the peak clear
instance._send_message({}) instance._send_message({})