From 14c2ca85ec3a8ce8779c26a7af2bb3a6aa3d3735 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Aug 2024 14:17:17 -0500 Subject: [PATCH] Refactor websocket handler to reduce complexity (#124174) --- .../components/websocket_api/http.py | 348 ++++++++++-------- 1 file changed, 188 insertions(+), 160 deletions(-) diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 8ed3469d7ed..e33da9a8b4a 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -11,6 +11,7 @@ import logging from typing import TYPE_CHECKING, Any, Final from aiohttp import WSMsgType, web +from aiohttp.http_websocket import WebSocketWriter from homeassistant.components.http import KEY_HASS, HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP @@ -124,7 +125,9 @@ class WebSocketHandler: return "finished connection" async def _writer( - self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]] + self, + connection: ActiveConnection, + send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]], ) -> None: """Write outgoing messages.""" # Variables are set locally to avoid lookups in the loop @@ -134,7 +137,7 @@ class WebSocketHandler: loop = self._loop is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG) debug = logger.debug - can_coalesce = self._connection and self._connection.can_coalesce + can_coalesce = connection.can_coalesce ready_message_count = len(message_queue) # Exceptions if Socket disconnected or cancelled by connection handler try: @@ -148,7 +151,7 @@ class WebSocketHandler: if not can_coalesce: # coalesce may be enabled later in the connection - can_coalesce = self._connection and self._connection.can_coalesce + can_coalesce = connection.can_coalesce if not can_coalesce or ready_message_count == 1: message = message_queue.popleft() @@ -298,19 +301,16 @@ class WebSocketHandler: request = self._request wsock = self._wsock logger = self._logger - debug = logger.debug hass = self._hass - is_enabled_for = logger.isEnabledFor - logging_debug = logging.DEBUG try: async with asyncio.timeout(10): await wsock.prepare(request) except TimeoutError: - self._logger.warning("Timeout preparing request from %s", request.remote) + logger.warning("Timeout preparing request from %s", request.remote) return wsock - debug("%s: Connected from %s", self.description, request.remote) + logger.debug("%s: Connected from %s", self.description, request.remote) self._handle_task = asyncio.current_task() unsub_stop = hass.bus.async_listen( @@ -325,134 +325,25 @@ class WebSocketHandler: auth = AuthPhase( logger, hass, self._send_message, self._cancel, request, send_bytes_text ) - connection = None - disconnect_warn = None + connection: ActiveConnection | None = None + disconnect_warn: str | None = None try: - await send_bytes_text(AUTH_REQUIRED_MESSAGE) - - # Auth Phase - try: - msg = await wsock.receive(10) - except TimeoutError as err: - disconnect_warn = "Did not receive auth message within 10 seconds" - raise Disconnect from err - - if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): - raise Disconnect # noqa: TRY301 - - if msg.type != WSMsgType.TEXT: - disconnect_warn = "Received non-Text message." - raise Disconnect # noqa: TRY301 - - try: - auth_msg_data = json_loads(msg.data) - except ValueError as err: - disconnect_warn = "Received invalid JSON." - raise Disconnect from err - - if is_enabled_for(logging_debug): - debug("%s: Received %s", self.description, auth_msg_data) - connection = await auth.async_handle(auth_msg_data) - # As the webserver is now started before the start - # event we do not want to block for websocket responses - # - # We only start the writer queue after the auth phase is completed - # since there is no need to queue messages before the auth phase - self._connection = connection - self._writer_task = create_eager_task(self._writer(send_bytes_text)) - hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1 - async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED) - - self._authenticated = True - # - # - # Our websocket implementation is backed by a deque - # - # As back-pressure builds, the queue will back up and use more memory - # until we disconnect the client when the queue size reaches - # MAX_PENDING_MSG. When we are generating a high volume of websocket messages, - # we hit a bottleneck in aiohttp where it will wait for - # the buffer to drain before sending the next message and messages - # start backing up in the queue. - # - # https://github.com/aio-libs/aiohttp/issues/1367 added drains - # to the websocket writer to handle malicious clients and network issues. - # The drain causes multiple problems for us since the buffer cannot be - # drained fast enough when we deliver a high volume or large messages: - # - # - We end up disconnecting the client. The client will then reconnect, - # and the cycle repeats itself, which results in a significant amount of - # CPU usage. - # - # - Messages latency increases because messages cannot be moved into - # the TCP buffer because it is blocked waiting for the drain to happen because - # of the low default limit of 16KiB. By increasing the limit, we instead - # rely on the underlying TCP buffer and stack to deliver the messages which - # can typically happen much faster. - # - # After the auth phase is completed, and we are not concerned about - # the user being a malicious client, we set the limit to force a drain - # to 1MiB. 1MiB is the maximum expected size of the serialized entity - # registry, which is the largest message we usually send. - # - # https://github.com/aio-libs/aiohttp/commit/b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99 - # added a way to set the limit, but there is no way to actually - # reach the code to set the limit, so we have to set it directly. - # - writer._limit = 2**20 # noqa: SLF001 - async_handle_str = connection.async_handle - async_handle_binary = connection.async_handle_binary - - # Command phase - while not wsock.closed: - msg = await wsock.receive() - - if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): - break - - if msg.type is WSMsgType.BINARY: - if len(msg.data) < 1: - disconnect_warn = "Received invalid binary message." - break - handler = msg.data[0] - payload = msg.data[1:] - async_handle_binary(handler, payload) - continue - - if msg.type is not WSMsgType.TEXT: - disconnect_warn = "Received non-Text message." - break - - try: - command_msg_data = json_loads(msg.data) - except ValueError: - disconnect_warn = "Received invalid JSON." - break - - if is_enabled_for(logging_debug): - debug("%s: Received %s", self.description, command_msg_data) - - # command_msg_data is always deserialized from JSON as a list - if type(command_msg_data) is not list: # noqa: E721 - async_handle_str(command_msg_data) - continue - - for split_msg in command_msg_data: - async_handle_str(split_msg) - + connection = await self._async_handle_auth_phase(auth, send_bytes_text) + self._async_increase_writer_limit(writer) + await self._async_websocket_command_phase(connection, send_bytes_text) except asyncio.CancelledError: - debug("%s: Connection cancelled", self.description) + logger.debug("%s: Connection cancelled", self.description) raise - except Disconnect as ex: - debug("%s: Connection closed by client: %s", self.description, ex) + if disconnect_msg := str(ex): + disconnect_warn = disconnect_msg + logger.debug("%s: Connection closed by client: %s", self.description, ex) except Exception: - self._logger.exception( + logger.exception( "%s: Unexpected error inside websocket API", self.description ) - finally: unsub_stop() @@ -465,38 +356,175 @@ class WebSocketHandler: if self._ready_future and not self._ready_future.done(): self._ready_future.set_result(len(self._message_queue)) - # If the writer gets canceled we still need to close the websocket - # so we have another finally block to make sure we close the websocket - # if the writer gets canceled. - try: - if self._writer_task: - await self._writer_task - finally: - try: - # Make sure all error messages are written before closing - await wsock.close() - finally: - if disconnect_warn is None: - debug("%s: Disconnected", self.description) - else: - self._logger.warning( - "%s: Disconnected: %s", self.description, disconnect_warn - ) - - if connection is not None: - hass.data[DATA_CONNECTIONS] -= 1 - self._connection = None - - async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED) - - # Break reference cycles to make sure GC can happen sooner - self._wsock = None # type: ignore[assignment] - self._request = None # type: ignore[assignment] - self._hass = None # type: ignore[assignment] - self._logger = None # type: ignore[assignment] - self._message_queue = None # type: ignore[assignment] - self._handle_task = None - self._writer_task = None - self._ready_future = None + await self._async_cleanup_writer_and_close(disconnect_warn, connection) return wsock + + async def _async_handle_auth_phase( + self, + auth: AuthPhase, + send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]], + ) -> ActiveConnection: + """Handle the auth phase of the websocket connection.""" + await send_bytes_text(AUTH_REQUIRED_MESSAGE) + + # Auth Phase + try: + msg = await self._wsock.receive(10) + except TimeoutError as err: + raise Disconnect("Did not receive auth message within 10 seconds") from err + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + raise Disconnect("Received close message during auth phase") + + if msg.type is not WSMsgType.TEXT: + raise Disconnect("Received non-Text message during auth phase") + + try: + auth_msg_data = json_loads(msg.data) + except ValueError as err: + raise Disconnect("Received invalid JSON during auth phase") from err + + if self._logger.isEnabledFor(logging.DEBUG): + self._logger.debug("%s: Received %s", self.description, auth_msg_data) + connection = await auth.async_handle(auth_msg_data) + # As the webserver is now started before the start + # event we do not want to block for websocket responses + # + # We only start the writer queue after the auth phase is completed + # since there is no need to queue messages before the auth phase + self._connection = connection + self._writer_task = create_eager_task(self._writer(connection, send_bytes_text)) + self._hass.data[DATA_CONNECTIONS] = self._hass.data.get(DATA_CONNECTIONS, 0) + 1 + async_dispatcher_send(self._hass, SIGNAL_WEBSOCKET_CONNECTED) + + self._authenticated = True + return connection + + @callback + def _async_increase_writer_limit(self, writer: WebSocketWriter) -> None: + # + # + # Our websocket implementation is backed by a deque + # + # As back-pressure builds, the queue will back up and use more memory + # until we disconnect the client when the queue size reaches + # MAX_PENDING_MSG. When we are generating a high volume of websocket messages, + # we hit a bottleneck in aiohttp where it will wait for + # the buffer to drain before sending the next message and messages + # start backing up in the queue. + # + # https://github.com/aio-libs/aiohttp/issues/1367 added drains + # to the websocket writer to handle malicious clients and network issues. + # The drain causes multiple problems for us since the buffer cannot be + # drained fast enough when we deliver a high volume or large messages: + # + # - We end up disconnecting the client. The client will then reconnect, + # and the cycle repeats itself, which results in a significant amount of + # CPU usage. + # + # - Messages latency increases because messages cannot be moved into + # the TCP buffer because it is blocked waiting for the drain to happen because + # of the low default limit of 16KiB. By increasing the limit, we instead + # rely on the underlying TCP buffer and stack to deliver the messages which + # can typically happen much faster. + # + # After the auth phase is completed, and we are not concerned about + # the user being a malicious client, we set the limit to force a drain + # to 1MiB. 1MiB is the maximum expected size of the serialized entity + # registry, which is the largest message we usually send. + # + # https://github.com/aio-libs/aiohttp/commit/b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99 + # added a way to set the limit, but there is no way to actually + # reach the code to set the limit, so we have to set it directly. + # + writer._limit = 2**20 # noqa: SLF001 + + async def _async_websocket_command_phase( + self, + connection: ActiveConnection, + send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]], + ) -> None: + """Handle the command phase of the websocket connection.""" + wsock = self._wsock + async_handle_str = connection.async_handle + async_handle_binary = connection.async_handle_binary + _debug_enabled = partial(self._logger.isEnabledFor, logging.DEBUG) + + # Command phase + while not wsock.closed: + msg = await wsock.receive() + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + break + + if msg.type is WSMsgType.BINARY: + if len(msg.data) < 1: + raise Disconnect("Received invalid binary message.") + + handler = msg.data[0] + payload = msg.data[1:] + async_handle_binary(handler, payload) + continue + + if msg.type is not WSMsgType.TEXT: + raise Disconnect("Received non-Text message.") + + try: + command_msg_data = json_loads(msg.data) + except ValueError as ex: + raise Disconnect("Received invalid JSON.") from ex + + if _debug_enabled(): + self._logger.debug( + "%s: Received %s", self.description, command_msg_data + ) + + # command_msg_data is always deserialized from JSON as a list + if type(command_msg_data) is not list: # noqa: E721 + async_handle_str(command_msg_data) + continue + + for split_msg in command_msg_data: + async_handle_str(split_msg) + + async def _async_cleanup_writer_and_close( + self, disconnect_warn: str | None, connection: ActiveConnection | None + ) -> None: + """Cleanup the writer and close the websocket.""" + # If the writer gets canceled we still need to close the websocket + # so we have another finally block to make sure we close the websocket + # if the writer gets canceled. + wsock = self._wsock + hass = self._hass + logger = self._logger + try: + if self._writer_task: + await self._writer_task + finally: + try: + # Make sure all error messages are written before closing + await wsock.close() + finally: + if disconnect_warn is None: + logger.debug("%s: Disconnected", self.description) + else: + logger.warning( + "%s: Disconnected: %s", self.description, disconnect_warn + ) + + if connection is not None: + hass.data[DATA_CONNECTIONS] -= 1 + self._connection = None + + async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED) + + # Break reference cycles to make sure GC can happen sooner + self._wsock = None # type: ignore[assignment] + self._request = None # type: ignore[assignment] + self._hass = None # type: ignore[assignment] + self._logger = None # type: ignore[assignment] + self._message_queue = None # type: ignore[assignment] + self._handle_task = None + self._writer_task = None + self._ready_future = None