mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 04:07:08 +00:00
Refactor websocket handler to reduce complexity (#124174)
This commit is contained in:
parent
1010edf4bd
commit
14c2ca85ec
@ -11,6 +11,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
from aiohttp import WSMsgType, web
|
||||
from aiohttp.http_websocket import WebSocketWriter
|
||||
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
@ -124,7 +125,9 @@ class WebSocketHandler:
|
||||
return "finished connection"
|
||||
|
||||
async def _writer(
|
||||
self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]]
|
||||
self,
|
||||
connection: ActiveConnection,
|
||||
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Variables are set locally to avoid lookups in the loop
|
||||
@ -134,7 +137,7 @@ class WebSocketHandler:
|
||||
loop = self._loop
|
||||
is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG)
|
||||
debug = logger.debug
|
||||
can_coalesce = self._connection and self._connection.can_coalesce
|
||||
can_coalesce = connection.can_coalesce
|
||||
ready_message_count = len(message_queue)
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
try:
|
||||
@ -148,7 +151,7 @@ class WebSocketHandler:
|
||||
|
||||
if not can_coalesce:
|
||||
# coalesce may be enabled later in the connection
|
||||
can_coalesce = self._connection and self._connection.can_coalesce
|
||||
can_coalesce = connection.can_coalesce
|
||||
|
||||
if not can_coalesce or ready_message_count == 1:
|
||||
message = message_queue.popleft()
|
||||
@ -298,19 +301,16 @@ class WebSocketHandler:
|
||||
request = self._request
|
||||
wsock = self._wsock
|
||||
logger = self._logger
|
||||
debug = logger.debug
|
||||
hass = self._hass
|
||||
is_enabled_for = logger.isEnabledFor
|
||||
logging_debug = logging.DEBUG
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(10):
|
||||
await wsock.prepare(request)
|
||||
except TimeoutError:
|
||||
self._logger.warning("Timeout preparing request from %s", request.remote)
|
||||
logger.warning("Timeout preparing request from %s", request.remote)
|
||||
return wsock
|
||||
|
||||
debug("%s: Connected from %s", self.description, request.remote)
|
||||
logger.debug("%s: Connected from %s", self.description, request.remote)
|
||||
self._handle_task = asyncio.current_task()
|
||||
|
||||
unsub_stop = hass.bus.async_listen(
|
||||
@ -325,134 +325,25 @@ class WebSocketHandler:
|
||||
auth = AuthPhase(
|
||||
logger, hass, self._send_message, self._cancel, request, send_bytes_text
|
||||
)
|
||||
connection = None
|
||||
disconnect_warn = None
|
||||
connection: ActiveConnection | None = None
|
||||
disconnect_warn: str | None = None
|
||||
|
||||
try:
|
||||
await send_bytes_text(AUTH_REQUIRED_MESSAGE)
|
||||
|
||||
# Auth Phase
|
||||
try:
|
||||
msg = await wsock.receive(10)
|
||||
except TimeoutError as err:
|
||||
disconnect_warn = "Did not receive auth message within 10 seconds"
|
||||
raise Disconnect from err
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
raise Disconnect # noqa: TRY301
|
||||
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
disconnect_warn = "Received non-Text message."
|
||||
raise Disconnect # noqa: TRY301
|
||||
|
||||
try:
|
||||
auth_msg_data = json_loads(msg.data)
|
||||
except ValueError as err:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
raise Disconnect from err
|
||||
|
||||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, auth_msg_data)
|
||||
connection = await auth.async_handle(auth_msg_data)
|
||||
# As the webserver is now started before the start
|
||||
# event we do not want to block for websocket responses
|
||||
#
|
||||
# We only start the writer queue after the auth phase is completed
|
||||
# since there is no need to queue messages before the auth phase
|
||||
self._connection = connection
|
||||
self._writer_task = create_eager_task(self._writer(send_bytes_text))
|
||||
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
|
||||
|
||||
self._authenticated = True
|
||||
#
|
||||
#
|
||||
# Our websocket implementation is backed by a deque
|
||||
#
|
||||
# As back-pressure builds, the queue will back up and use more memory
|
||||
# until we disconnect the client when the queue size reaches
|
||||
# MAX_PENDING_MSG. When we are generating a high volume of websocket messages,
|
||||
# we hit a bottleneck in aiohttp where it will wait for
|
||||
# the buffer to drain before sending the next message and messages
|
||||
# start backing up in the queue.
|
||||
#
|
||||
# https://github.com/aio-libs/aiohttp/issues/1367 added drains
|
||||
# to the websocket writer to handle malicious clients and network issues.
|
||||
# The drain causes multiple problems for us since the buffer cannot be
|
||||
# drained fast enough when we deliver a high volume or large messages:
|
||||
#
|
||||
# - We end up disconnecting the client. The client will then reconnect,
|
||||
# and the cycle repeats itself, which results in a significant amount of
|
||||
# CPU usage.
|
||||
#
|
||||
# - Messages latency increases because messages cannot be moved into
|
||||
# the TCP buffer because it is blocked waiting for the drain to happen because
|
||||
# of the low default limit of 16KiB. By increasing the limit, we instead
|
||||
# rely on the underlying TCP buffer and stack to deliver the messages which
|
||||
# can typically happen much faster.
|
||||
#
|
||||
# After the auth phase is completed, and we are not concerned about
|
||||
# the user being a malicious client, we set the limit to force a drain
|
||||
# to 1MiB. 1MiB is the maximum expected size of the serialized entity
|
||||
# registry, which is the largest message we usually send.
|
||||
#
|
||||
# https://github.com/aio-libs/aiohttp/commit/b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99
|
||||
# added a way to set the limit, but there is no way to actually
|
||||
# reach the code to set the limit, so we have to set it directly.
|
||||
#
|
||||
writer._limit = 2**20 # noqa: SLF001
|
||||
async_handle_str = connection.async_handle
|
||||
async_handle_binary = connection.async_handle_binary
|
||||
|
||||
# Command phase
|
||||
while not wsock.closed:
|
||||
msg = await wsock.receive()
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
break
|
||||
|
||||
if msg.type is WSMsgType.BINARY:
|
||||
if len(msg.data) < 1:
|
||||
disconnect_warn = "Received invalid binary message."
|
||||
break
|
||||
handler = msg.data[0]
|
||||
payload = msg.data[1:]
|
||||
async_handle_binary(handler, payload)
|
||||
continue
|
||||
|
||||
if msg.type is not WSMsgType.TEXT:
|
||||
disconnect_warn = "Received non-Text message."
|
||||
break
|
||||
|
||||
try:
|
||||
command_msg_data = json_loads(msg.data)
|
||||
except ValueError:
|
||||
disconnect_warn = "Received invalid JSON."
|
||||
break
|
||||
|
||||
if is_enabled_for(logging_debug):
|
||||
debug("%s: Received %s", self.description, command_msg_data)
|
||||
|
||||
# command_msg_data is always deserialized from JSON as a list
|
||||
if type(command_msg_data) is not list: # noqa: E721
|
||||
async_handle_str(command_msg_data)
|
||||
continue
|
||||
|
||||
for split_msg in command_msg_data:
|
||||
async_handle_str(split_msg)
|
||||
|
||||
connection = await self._async_handle_auth_phase(auth, send_bytes_text)
|
||||
self._async_increase_writer_limit(writer)
|
||||
await self._async_websocket_command_phase(connection, send_bytes_text)
|
||||
except asyncio.CancelledError:
|
||||
debug("%s: Connection cancelled", self.description)
|
||||
logger.debug("%s: Connection cancelled", self.description)
|
||||
raise
|
||||
|
||||
except Disconnect as ex:
|
||||
debug("%s: Connection closed by client: %s", self.description, ex)
|
||||
if disconnect_msg := str(ex):
|
||||
disconnect_warn = disconnect_msg
|
||||
|
||||
logger.debug("%s: Connection closed by client: %s", self.description, ex)
|
||||
except Exception:
|
||||
self._logger.exception(
|
||||
logger.exception(
|
||||
"%s: Unexpected error inside websocket API", self.description
|
||||
)
|
||||
|
||||
finally:
|
||||
unsub_stop()
|
||||
|
||||
@ -465,38 +356,175 @@ class WebSocketHandler:
|
||||
if self._ready_future and not self._ready_future.done():
|
||||
self._ready_future.set_result(len(self._message_queue))
|
||||
|
||||
# If the writer gets canceled we still need to close the websocket
|
||||
# so we have another finally block to make sure we close the websocket
|
||||
# if the writer gets canceled.
|
||||
try:
|
||||
if self._writer_task:
|
||||
await self._writer_task
|
||||
finally:
|
||||
try:
|
||||
# Make sure all error messages are written before closing
|
||||
await wsock.close()
|
||||
finally:
|
||||
if disconnect_warn is None:
|
||||
debug("%s: Disconnected", self.description)
|
||||
else:
|
||||
self._logger.warning(
|
||||
"%s: Disconnected: %s", self.description, disconnect_warn
|
||||
)
|
||||
|
||||
if connection is not None:
|
||||
hass.data[DATA_CONNECTIONS] -= 1
|
||||
self._connection = None
|
||||
|
||||
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||
|
||||
# Break reference cycles to make sure GC can happen sooner
|
||||
self._wsock = None # type: ignore[assignment]
|
||||
self._request = None # type: ignore[assignment]
|
||||
self._hass = None # type: ignore[assignment]
|
||||
self._logger = None # type: ignore[assignment]
|
||||
self._message_queue = None # type: ignore[assignment]
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
self._ready_future = None
|
||||
await self._async_cleanup_writer_and_close(disconnect_warn, connection)
|
||||
|
||||
return wsock
|
||||
|
||||
async def _async_handle_auth_phase(
|
||||
self,
|
||||
auth: AuthPhase,
|
||||
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
||||
) -> ActiveConnection:
|
||||
"""Handle the auth phase of the websocket connection."""
|
||||
await send_bytes_text(AUTH_REQUIRED_MESSAGE)
|
||||
|
||||
# Auth Phase
|
||||
try:
|
||||
msg = await self._wsock.receive(10)
|
||||
except TimeoutError as err:
|
||||
raise Disconnect("Did not receive auth message within 10 seconds") from err
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
raise Disconnect("Received close message during auth phase")
|
||||
|
||||
if msg.type is not WSMsgType.TEXT:
|
||||
raise Disconnect("Received non-Text message during auth phase")
|
||||
|
||||
try:
|
||||
auth_msg_data = json_loads(msg.data)
|
||||
except ValueError as err:
|
||||
raise Disconnect("Received invalid JSON during auth phase") from err
|
||||
|
||||
if self._logger.isEnabledFor(logging.DEBUG):
|
||||
self._logger.debug("%s: Received %s", self.description, auth_msg_data)
|
||||
connection = await auth.async_handle(auth_msg_data)
|
||||
# As the webserver is now started before the start
|
||||
# event we do not want to block for websocket responses
|
||||
#
|
||||
# We only start the writer queue after the auth phase is completed
|
||||
# since there is no need to queue messages before the auth phase
|
||||
self._connection = connection
|
||||
self._writer_task = create_eager_task(self._writer(connection, send_bytes_text))
|
||||
self._hass.data[DATA_CONNECTIONS] = self._hass.data.get(DATA_CONNECTIONS, 0) + 1
|
||||
async_dispatcher_send(self._hass, SIGNAL_WEBSOCKET_CONNECTED)
|
||||
|
||||
self._authenticated = True
|
||||
return connection
|
||||
|
||||
@callback
|
||||
def _async_increase_writer_limit(self, writer: WebSocketWriter) -> None:
|
||||
#
|
||||
#
|
||||
# Our websocket implementation is backed by a deque
|
||||
#
|
||||
# As back-pressure builds, the queue will back up and use more memory
|
||||
# until we disconnect the client when the queue size reaches
|
||||
# MAX_PENDING_MSG. When we are generating a high volume of websocket messages,
|
||||
# we hit a bottleneck in aiohttp where it will wait for
|
||||
# the buffer to drain before sending the next message and messages
|
||||
# start backing up in the queue.
|
||||
#
|
||||
# https://github.com/aio-libs/aiohttp/issues/1367 added drains
|
||||
# to the websocket writer to handle malicious clients and network issues.
|
||||
# The drain causes multiple problems for us since the buffer cannot be
|
||||
# drained fast enough when we deliver a high volume or large messages:
|
||||
#
|
||||
# - We end up disconnecting the client. The client will then reconnect,
|
||||
# and the cycle repeats itself, which results in a significant amount of
|
||||
# CPU usage.
|
||||
#
|
||||
# - Messages latency increases because messages cannot be moved into
|
||||
# the TCP buffer because it is blocked waiting for the drain to happen because
|
||||
# of the low default limit of 16KiB. By increasing the limit, we instead
|
||||
# rely on the underlying TCP buffer and stack to deliver the messages which
|
||||
# can typically happen much faster.
|
||||
#
|
||||
# After the auth phase is completed, and we are not concerned about
|
||||
# the user being a malicious client, we set the limit to force a drain
|
||||
# to 1MiB. 1MiB is the maximum expected size of the serialized entity
|
||||
# registry, which is the largest message we usually send.
|
||||
#
|
||||
# https://github.com/aio-libs/aiohttp/commit/b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99
|
||||
# added a way to set the limit, but there is no way to actually
|
||||
# reach the code to set the limit, so we have to set it directly.
|
||||
#
|
||||
writer._limit = 2**20 # noqa: SLF001
|
||||
|
||||
async def _async_websocket_command_phase(
|
||||
self,
|
||||
connection: ActiveConnection,
|
||||
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
"""Handle the command phase of the websocket connection."""
|
||||
wsock = self._wsock
|
||||
async_handle_str = connection.async_handle
|
||||
async_handle_binary = connection.async_handle_binary
|
||||
_debug_enabled = partial(self._logger.isEnabledFor, logging.DEBUG)
|
||||
|
||||
# Command phase
|
||||
while not wsock.closed:
|
||||
msg = await wsock.receive()
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
break
|
||||
|
||||
if msg.type is WSMsgType.BINARY:
|
||||
if len(msg.data) < 1:
|
||||
raise Disconnect("Received invalid binary message.")
|
||||
|
||||
handler = msg.data[0]
|
||||
payload = msg.data[1:]
|
||||
async_handle_binary(handler, payload)
|
||||
continue
|
||||
|
||||
if msg.type is not WSMsgType.TEXT:
|
||||
raise Disconnect("Received non-Text message.")
|
||||
|
||||
try:
|
||||
command_msg_data = json_loads(msg.data)
|
||||
except ValueError as ex:
|
||||
raise Disconnect("Received invalid JSON.") from ex
|
||||
|
||||
if _debug_enabled():
|
||||
self._logger.debug(
|
||||
"%s: Received %s", self.description, command_msg_data
|
||||
)
|
||||
|
||||
# command_msg_data is always deserialized from JSON as a list
|
||||
if type(command_msg_data) is not list: # noqa: E721
|
||||
async_handle_str(command_msg_data)
|
||||
continue
|
||||
|
||||
for split_msg in command_msg_data:
|
||||
async_handle_str(split_msg)
|
||||
|
||||
async def _async_cleanup_writer_and_close(
|
||||
self, disconnect_warn: str | None, connection: ActiveConnection | None
|
||||
) -> None:
|
||||
"""Cleanup the writer and close the websocket."""
|
||||
# If the writer gets canceled we still need to close the websocket
|
||||
# so we have another finally block to make sure we close the websocket
|
||||
# if the writer gets canceled.
|
||||
wsock = self._wsock
|
||||
hass = self._hass
|
||||
logger = self._logger
|
||||
try:
|
||||
if self._writer_task:
|
||||
await self._writer_task
|
||||
finally:
|
||||
try:
|
||||
# Make sure all error messages are written before closing
|
||||
await wsock.close()
|
||||
finally:
|
||||
if disconnect_warn is None:
|
||||
logger.debug("%s: Disconnected", self.description)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s: Disconnected: %s", self.description, disconnect_warn
|
||||
)
|
||||
|
||||
if connection is not None:
|
||||
hass.data[DATA_CONNECTIONS] -= 1
|
||||
self._connection = None
|
||||
|
||||
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
||||
|
||||
# Break reference cycles to make sure GC can happen sooner
|
||||
self._wsock = None # type: ignore[assignment]
|
||||
self._request = None # type: ignore[assignment]
|
||||
self._hass = None # type: ignore[assignment]
|
||||
self._logger = None # type: ignore[assignment]
|
||||
self._message_queue = None # type: ignore[assignment]
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
self._ready_future = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user