Refactor websocket handler to reduce complexity (#124174)

This commit is contained in:
J. Nick Koston 2024-08-18 14:17:17 -05:00 committed by GitHub
parent 1010edf4bd
commit 14c2ca85ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,6 +11,7 @@ import logging
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
from aiohttp import WSMsgType, web from aiohttp import WSMsgType, web
from aiohttp.http_websocket import WebSocketWriter
from homeassistant.components.http import KEY_HASS, HomeAssistantView from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
@ -124,7 +125,9 @@ class WebSocketHandler:
return "finished connection" return "finished connection"
async def _writer( async def _writer(
self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]] self,
connection: ActiveConnection,
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
) -> None: ) -> None:
"""Write outgoing messages.""" """Write outgoing messages."""
# Variables are set locally to avoid lookups in the loop # Variables are set locally to avoid lookups in the loop
@ -134,7 +137,7 @@ class WebSocketHandler:
loop = self._loop loop = self._loop
is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG) is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG)
debug = logger.debug debug = logger.debug
can_coalesce = self._connection and self._connection.can_coalesce can_coalesce = connection.can_coalesce
ready_message_count = len(message_queue) ready_message_count = len(message_queue)
# Exceptions if Socket disconnected or cancelled by connection handler # Exceptions if Socket disconnected or cancelled by connection handler
try: try:
@ -148,7 +151,7 @@ class WebSocketHandler:
if not can_coalesce: if not can_coalesce:
# coalesce may be enabled later in the connection # coalesce may be enabled later in the connection
can_coalesce = self._connection and self._connection.can_coalesce can_coalesce = connection.can_coalesce
if not can_coalesce or ready_message_count == 1: if not can_coalesce or ready_message_count == 1:
message = message_queue.popleft() message = message_queue.popleft()
@ -298,19 +301,16 @@ class WebSocketHandler:
request = self._request request = self._request
wsock = self._wsock wsock = self._wsock
logger = self._logger logger = self._logger
debug = logger.debug
hass = self._hass hass = self._hass
is_enabled_for = logger.isEnabledFor
logging_debug = logging.DEBUG
try: try:
async with asyncio.timeout(10): async with asyncio.timeout(10):
await wsock.prepare(request) await wsock.prepare(request)
except TimeoutError: except TimeoutError:
self._logger.warning("Timeout preparing request from %s", request.remote) logger.warning("Timeout preparing request from %s", request.remote)
return wsock return wsock
debug("%s: Connected from %s", self.description, request.remote) logger.debug("%s: Connected from %s", self.description, request.remote)
self._handle_task = asyncio.current_task() self._handle_task = asyncio.current_task()
unsub_stop = hass.bus.async_listen( unsub_stop = hass.bus.async_listen(
@ -325,34 +325,68 @@ class WebSocketHandler:
auth = AuthPhase( auth = AuthPhase(
logger, hass, self._send_message, self._cancel, request, send_bytes_text logger, hass, self._send_message, self._cancel, request, send_bytes_text
) )
connection = None connection: ActiveConnection | None = None
disconnect_warn = None disconnect_warn: str | None = None
try: try:
connection = await self._async_handle_auth_phase(auth, send_bytes_text)
self._async_increase_writer_limit(writer)
await self._async_websocket_command_phase(connection, send_bytes_text)
except asyncio.CancelledError:
logger.debug("%s: Connection cancelled", self.description)
raise
except Disconnect as ex:
if disconnect_msg := str(ex):
disconnect_warn = disconnect_msg
logger.debug("%s: Connection closed by client: %s", self.description, ex)
except Exception:
logger.exception(
"%s: Unexpected error inside websocket API", self.description
)
finally:
unsub_stop()
self._cancel_peak_checker()
if connection is not None:
connection.async_handle_close()
self._closing = True
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(len(self._message_queue))
await self._async_cleanup_writer_and_close(disconnect_warn, connection)
return wsock
async def _async_handle_auth_phase(
self,
auth: AuthPhase,
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
) -> ActiveConnection:
"""Handle the auth phase of the websocket connection."""
await send_bytes_text(AUTH_REQUIRED_MESSAGE) await send_bytes_text(AUTH_REQUIRED_MESSAGE)
# Auth Phase # Auth Phase
try: try:
msg = await wsock.receive(10) msg = await self._wsock.receive(10)
except TimeoutError as err: except TimeoutError as err:
disconnect_warn = "Did not receive auth message within 10 seconds" raise Disconnect("Did not receive auth message within 10 seconds") from err
raise Disconnect from err
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
raise Disconnect # noqa: TRY301 raise Disconnect("Received close message during auth phase")
if msg.type != WSMsgType.TEXT: if msg.type is not WSMsgType.TEXT:
disconnect_warn = "Received non-Text message." raise Disconnect("Received non-Text message during auth phase")
raise Disconnect # noqa: TRY301
try: try:
auth_msg_data = json_loads(msg.data) auth_msg_data = json_loads(msg.data)
except ValueError as err: except ValueError as err:
disconnect_warn = "Received invalid JSON." raise Disconnect("Received invalid JSON during auth phase") from err
raise Disconnect from err
if is_enabled_for(logging_debug): if self._logger.isEnabledFor(logging.DEBUG):
debug("%s: Received %s", self.description, auth_msg_data) self._logger.debug("%s: Received %s", self.description, auth_msg_data)
connection = await auth.async_handle(auth_msg_data) connection = await auth.async_handle(auth_msg_data)
# As the webserver is now started before the start # As the webserver is now started before the start
# event we do not want to block for websocket responses # event we do not want to block for websocket responses
@ -360,11 +394,15 @@ class WebSocketHandler:
# We only start the writer queue after the auth phase is completed # We only start the writer queue after the auth phase is completed
# since there is no need to queue messages before the auth phase # since there is no need to queue messages before the auth phase
self._connection = connection self._connection = connection
self._writer_task = create_eager_task(self._writer(send_bytes_text)) self._writer_task = create_eager_task(self._writer(connection, send_bytes_text))
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1 self._hass.data[DATA_CONNECTIONS] = self._hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED) async_dispatcher_send(self._hass, SIGNAL_WEBSOCKET_CONNECTED)
self._authenticated = True self._authenticated = True
return connection
@callback
def _async_increase_writer_limit(self, writer: WebSocketWriter) -> None:
# #
# #
# Our websocket implementation is backed by a deque # Our websocket implementation is backed by a deque
@ -401,8 +439,17 @@ class WebSocketHandler:
# reach the code to set the limit, so we have to set it directly. # reach the code to set the limit, so we have to set it directly.
# #
writer._limit = 2**20 # noqa: SLF001 writer._limit = 2**20 # noqa: SLF001
async def _async_websocket_command_phase(
self,
connection: ActiveConnection,
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
) -> None:
"""Handle the command phase of the websocket connection."""
wsock = self._wsock
async_handle_str = connection.async_handle async_handle_str = connection.async_handle
async_handle_binary = connection.async_handle_binary async_handle_binary = connection.async_handle_binary
_debug_enabled = partial(self._logger.isEnabledFor, logging.DEBUG)
# Command phase # Command phase
while not wsock.closed: while not wsock.closed:
@ -413,25 +460,25 @@ class WebSocketHandler:
if msg.type is WSMsgType.BINARY: if msg.type is WSMsgType.BINARY:
if len(msg.data) < 1: if len(msg.data) < 1:
disconnect_warn = "Received invalid binary message." raise Disconnect("Received invalid binary message.")
break
handler = msg.data[0] handler = msg.data[0]
payload = msg.data[1:] payload = msg.data[1:]
async_handle_binary(handler, payload) async_handle_binary(handler, payload)
continue continue
if msg.type is not WSMsgType.TEXT: if msg.type is not WSMsgType.TEXT:
disconnect_warn = "Received non-Text message." raise Disconnect("Received non-Text message.")
break
try: try:
command_msg_data = json_loads(msg.data) command_msg_data = json_loads(msg.data)
except ValueError: except ValueError as ex:
disconnect_warn = "Received invalid JSON." raise Disconnect("Received invalid JSON.") from ex
break
if is_enabled_for(logging_debug): if _debug_enabled():
debug("%s: Received %s", self.description, command_msg_data) self._logger.debug(
"%s: Received %s", self.description, command_msg_data
)
# command_msg_data is always deserialized from JSON as a list # command_msg_data is always deserialized from JSON as a list
if type(command_msg_data) is not list: # noqa: E721 if type(command_msg_data) is not list: # noqa: E721
@ -441,33 +488,16 @@ class WebSocketHandler:
for split_msg in command_msg_data: for split_msg in command_msg_data:
async_handle_str(split_msg) async_handle_str(split_msg)
except asyncio.CancelledError: async def _async_cleanup_writer_and_close(
debug("%s: Connection cancelled", self.description) self, disconnect_warn: str | None, connection: ActiveConnection | None
raise ) -> None:
"""Cleanup the writer and close the websocket."""
except Disconnect as ex:
debug("%s: Connection closed by client: %s", self.description, ex)
except Exception:
self._logger.exception(
"%s: Unexpected error inside websocket API", self.description
)
finally:
unsub_stop()
self._cancel_peak_checker()
if connection is not None:
connection.async_handle_close()
self._closing = True
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(len(self._message_queue))
# If the writer gets canceled we still need to close the websocket # If the writer gets canceled we still need to close the websocket
# so we have another finally block to make sure we close the websocket # so we have another finally block to make sure we close the websocket
# if the writer gets canceled. # if the writer gets canceled.
wsock = self._wsock
hass = self._hass
logger = self._logger
try: try:
if self._writer_task: if self._writer_task:
await self._writer_task await self._writer_task
@ -477,9 +507,9 @@ class WebSocketHandler:
await wsock.close() await wsock.close()
finally: finally:
if disconnect_warn is None: if disconnect_warn is None:
debug("%s: Disconnected", self.description) logger.debug("%s: Disconnected", self.description)
else: else:
self._logger.warning( logger.warning(
"%s: Disconnected: %s", self.description, disconnect_warn "%s: Disconnected: %s", self.description, disconnect_warn
) )
@ -498,5 +528,3 @@ class WebSocketHandler:
self._handle_task = None self._handle_task = None
self._writer_task = None self._writer_task = None
self._ready_future = None self._ready_future = None
return wsock