diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 1f6fd302968..a91a5178830 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -56,6 +56,10 @@ class ActiveConnection: self.binary_handlers: list[BinaryHandler | None] = [] current_connection.set(self) + def __repr__(self) -> str: + """Return the representation.""" + return f"" + def set_supported_features(self, features: dict[str, float]) -> None: """Set supported features.""" self.supported_features = features @@ -193,7 +197,24 @@ class ActiveConnection: def async_handle_close(self) -> None: """Handle closing down connection.""" for unsub in self.subscriptions.values(): - unsub() + try: + unsub() + except Exception: # pylint: disable=broad-except + # If one fails, make sure we still try the rest + self.logger.exception( + "Error unsubscribing from subscription: %s", unsub + ) + self.subscriptions.clear() + self.send_message = self._connect_closed_error + current_request.set(None) + current_connection.set(None) + + @callback + def _connect_closed_error( + self, msg: str | dict[str, Any] | Callable[[], str] + ) -> None: + """Send a message when the connection is closed.""" + self.logger.debug("Tried to send message %s on closed connection", msg) @callback def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None: diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 9eb04ecbc51..4b9a0943d9a 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -1,9 +1,7 @@ """Websocket constants.""" from __future__ import annotations -import asyncio from collections.abc import Awaitable, Callable -from concurrent import futures from typing import TYPE_CHECKING, Any, Final from homeassistant.core import HomeAssistant @@ -42,10 +40,6 @@ ERR_TEMPLATE_ERROR: Final = "template_error" TYPE_RESULT: Final = "result" -# Define the possible errors that occur when connections are cancelled. -# Originally, this was just asyncio.CancelledError, but issue #9546 showed -# that futures.CancelledErrors can also occur in some situations. -CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError) # Event types SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected" diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 5ca5ea62578..54daf89d8dd 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -4,7 +4,6 @@ from __future__ import annotations import asyncio from collections import deque from collections.abc import Callable -from contextlib import suppress import datetime as dt import logging from typing import TYPE_CHECKING, Any, Final @@ -21,7 +20,6 @@ from homeassistant.util.json import json_loads from .auth import AuthPhase, auth_required_message from .const import ( - CANCELLATION_ERRORS, DATA_CONNECTIONS, MAX_PENDING_MSG, PENDING_MSG_PEAK, @@ -68,15 +66,16 @@ class WebSocketHandler: def __init__(self, hass: HomeAssistant, request: web.Request) -> None: """Initialize an active connection.""" - self.hass = hass - self.request = request - self.wsock = web.WebSocketResponse(heartbeat=55) + self._hass = hass + self._request: web.Request = request + self._wsock = web.WebSocketResponse(heartbeat=55) self._handle_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None self._closing: bool = False + self._authenticated: bool = False self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._peak_checker_unsub: Callable[[], None] | None = None - self.connection: ActiveConnection | None = None + self._connection: ActiveConnection | None = None # The WebSocketHandler has a single consumer and path # to where messages are queued. This allows the implementation @@ -85,61 +84,81 @@ class WebSocketHandler: self._message_queue: deque = deque() self._ready_future: asyncio.Future[None] | None = None + def __repr__(self) -> str: + """Return the representation.""" + return ( + "" + ) + @property def description(self) -> str: """Return a description of the connection.""" - if self.connection is not None: - return self.connection.get_description(self.request) - return describe_request(self.request) + if connection := self._connection: + return connection.get_description(self._request) + if request := self._request: + return describe_request(request) + return "finished connection" async def _writer(self) -> None: """Write outgoing messages.""" # Variables are set locally to avoid lookups in the loop message_queue = self._message_queue logger = self._logger - send_str = self.wsock.send_str - loop = self.hass.loop + wsock = self._wsock + send_str = wsock.send_str + loop = self._hass.loop debug = logger.debug + is_enabled_for = logger.isEnabledFor + logging_debug = logging.DEBUG # Exceptions if Socket disconnected or cancelled by connection handler try: - with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): - while not self.wsock.closed: - if (messages_remaining := len(message_queue)) == 0: - self._ready_future = loop.create_future() - await self._ready_future - messages_remaining = len(message_queue) + while not wsock.closed: + if (messages_remaining := len(message_queue)) == 0: + self._ready_future = loop.create_future() + await self._ready_future + messages_remaining = len(message_queue) + # A None message is used to signal the end of the connection + if (process := message_queue.popleft()) is None: + return + + debug_enabled = is_enabled_for(logging_debug) + messages_remaining -= 1 + message = process if isinstance(process, str) else process() + + if ( + not messages_remaining + or not (connection := self._connection) + or not connection.can_coalesce + ): + if debug_enabled: + debug("%s: Sending %s", self.description, message) + await send_str(message) + continue + + messages: list[str] = [message] + while messages_remaining: # A None message is used to signal the end of the connection if (process := message_queue.popleft()) is None: return - + messages.append(process if isinstance(process, str) else process()) messages_remaining -= 1 - message = process if isinstance(process, str) else process() - if ( - not messages_remaining - or not self.connection - or not self.connection.can_coalesce - ): - debug("Sending %s", message) - await send_str(message) - continue - - messages: list[str] = [message] - while messages_remaining: - # A None message is used to signal the end of the connection - if (process := message_queue.popleft()) is None: - return - messages.append( - process if isinstance(process, str) else process() - ) - messages_remaining -= 1 - - joined_messages = ",".join(messages) - coalesced_messages = f"[{joined_messages}]" - debug("Sending %s", coalesced_messages) - await send_str(coalesced_messages) + joined_messages = ",".join(messages) + coalesced_messages = f"[{joined_messages}]" + if debug_enabled: + debug("%s: Sending %s", self.description, coalesced_messages) + await send_str(coalesced_messages) + except asyncio.CancelledError: + debug("%s: Writer cancelled", self.description) + raise + except (RuntimeError, ConnectionResetError) as ex: + debug("%s: Unexpected error in writer: %s", self.description, ex) finally: + debug("%s: Writer done", self.description) # Clean up the peak checker when we shut down the writer self._cancel_peak_checker() @@ -195,7 +214,7 @@ class WebSocketHandler: if not peak_checker_active: self._peak_checker_unsub = async_call_later( - self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak + self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak ) @callback @@ -231,8 +250,14 @@ class WebSocketHandler: async def async_handle(self) -> web.WebSocketResponse: """Handle a websocket response.""" - request = self.request - wsock = self.wsock + request = self._request + wsock = self._wsock + logger = self._logger + debug = logger.debug + hass = self._hass + is_enabled_for = logger.isEnabledFor + logging_debug = logging.DEBUG + try: async with async_timeout.timeout(10): await wsock.prepare(request) @@ -240,7 +265,7 @@ class WebSocketHandler: self._logger.warning("Timeout preparing request from %s", request.remote) return wsock - self._logger.debug("Connected from %s", request.remote) + debug("%s: Connected from %s", self.description, request.remote) self._handle_task = asyncio.current_task() @callback @@ -248,17 +273,13 @@ class WebSocketHandler: """Cancel this connection.""" self._cancel() - unsub_stop = self.hass.bus.async_listen( - EVENT_HOMEASSISTANT_STOP, handle_hass_stop - ) + unsub_stop = hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_hass_stop) # As the webserver is now started before the start # event we do not want to block for websocket responses self._writer_task = asyncio.create_task(self._writer()) - auth = AuthPhase( - self._logger, self.hass, self._send_message, self._cancel, request - ) + auth = AuthPhase(logger, hass, self._send_message, self._cancel, request) connection = None disconnect_warn = None @@ -286,13 +307,14 @@ class WebSocketHandler: disconnect_warn = "Received invalid JSON." raise Disconnect from err - self._logger.debug("Received %s", msg_data) - self.connection = connection = await auth.async_handle(msg_data) - self.hass.data[DATA_CONNECTIONS] = ( - self.hass.data.get(DATA_CONNECTIONS, 0) + 1 - ) - async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_CONNECTED) + if is_enabled_for(logging_debug): + debug("%s: Received %s", self.description, msg_data) + connection = await auth.async_handle(msg_data) + self._connection = connection + hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1 + async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED) + self._authenticated = True # # # Our websocket implementation is backed by an asyncio.Queue @@ -356,7 +378,9 @@ class WebSocketHandler: disconnect_warn = "Received invalid JSON." break - self._logger.debug("Received %s", msg_data) + if is_enabled_for(logging_debug): + debug("%s: Received %s", self.description, msg_data) + if not isinstance(msg_data, list): connection.async_handle(msg_data) continue @@ -365,17 +389,22 @@ class WebSocketHandler: connection.async_handle(split_msg) except asyncio.CancelledError: - self._logger.info("Connection closed by client") + debug("%s: Connection cancelled", self.description) + raise - except Disconnect: - pass + except Disconnect as ex: + debug("%s: Connection closed by client: %s", self.description, ex) except Exception: # pylint: disable=broad-except - self._logger.exception("Unexpected error inside websocket API") + self._logger.exception( + "%s: Unexpected error inside websocket API", self.description + ) finally: unsub_stop() + self._cancel_peak_checker() + if connection is not None: connection.async_handle_close() @@ -385,20 +414,37 @@ class WebSocketHandler: if self._ready_future and not self._ready_future.done(): self._ready_future.set_result(None) + # If the writer gets canceled we still need to close the websocket + # so we have another finally block to make sure we close the websocket + # if the writer gets canceled. try: - # Make sure all error messages are written before closing await self._writer_task - await wsock.close() finally: - if disconnect_warn is None: - self._logger.debug("Disconnected") - else: - self._logger.warning("Disconnected: %s", disconnect_warn) + try: + # Make sure all error messages are written before closing + await wsock.close() + finally: + if disconnect_warn is None: + debug("%s: Disconnected", self.description) + else: + self._logger.warning( + "%s: Disconnected: %s", self.description, disconnect_warn + ) - if connection is not None: - self.hass.data[DATA_CONNECTIONS] -= 1 - self.connection = None + if connection is not None: + hass.data[DATA_CONNECTIONS] -= 1 + self._connection = None - async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED) + async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED) + + # Break reference cycles to make sure GC can happen sooner + self._wsock = None # type: ignore[assignment] + self._request = None # type: ignore[assignment] + self._hass = None # type: ignore[assignment] + self._logger = None # type: ignore[assignment] + self._message_queue = None # type: ignore[assignment] + self._handle_task = None + self._writer_task = None + self._ready_future = None return wsock diff --git a/tests/components/websocket_api/conftest.py b/tests/components/websocket_api/conftest.py index 53569c3fa6a..69adf890584 100644 --- a/tests/components/websocket_api/conftest.py +++ b/tests/components/websocket_api/conftest.py @@ -3,11 +3,19 @@ import pytest from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED from homeassistant.components.websocket_api.http import URL +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component +from tests.typing import ( + MockHAClientWebSocket, + WebSocketGenerator, +) + @pytest.fixture -async def websocket_client(hass, hass_ws_client): +async def websocket_client( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> MockHAClientWebSocket: """Create a websocket client.""" return await hass_ws_client(hass) diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index 02384aace89..3205d40b52d 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -18,7 +18,10 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.util.dt import utcnow from tests.common import async_fire_time_changed -from tests.typing import WebSocketGenerator +from tests.typing import ( + MockHAClientWebSocket, + WebSocketGenerator, +) @pytest.fixture @@ -36,15 +39,103 @@ def mock_low_peak(): async def test_pending_msg_overflow( - hass: HomeAssistant, mock_low_queue, websocket_client + hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket ) -> None: - """Test get_panels command.""" + """Test pending messages overflows.""" for idx in range(10): await websocket_client.send_json({"id": idx + 1, "type": "ping"}) msg = await websocket_client.receive() assert msg.type == WSMsgType.close +async def test_cleanup_on_cancellation( + hass: HomeAssistant, websocket_client: MockHAClientWebSocket +) -> None: + """Test cleanup on cancellation.""" + + subscriptions = None + + # Register a handler that registers a subscription + @callback + @websocket_command( + { + "type": "fake_subscription", + } + ) + def fake_subscription( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: + nonlocal subscriptions + msg_id: int = msg["id"] + connection.subscriptions[msg_id] = callback(lambda: None) + connection.send_result(msg_id) + subscriptions = connection.subscriptions + + async_register_command(hass, fake_subscription) + + # Register a handler that raises on cancel + @callback + @websocket_command( + { + "type": "subscription_that_raises_on_cancel", + } + ) + def subscription_that_raises_on_cancel( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: + nonlocal subscriptions + msg_id: int = msg["id"] + + @callback + def _raise(): + raise ValueError() + + connection.subscriptions[msg_id] = _raise + connection.send_result(msg_id) + subscriptions = connection.subscriptions + + async_register_command(hass, subscription_that_raises_on_cancel) + + # Register a handler that cancels in handler + @callback + @websocket_command( + { + "type": "cancel_in_handler", + } + ) + def cancel_in_handler( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: + raise asyncio.CancelledError() + + async_register_command(hass, cancel_in_handler) + + await websocket_client.send_json({"id": 1, "type": "ping"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 1 + assert msg["type"] == "pong" + assert not subscriptions + await websocket_client.send_json({"id": 2, "type": "fake_subscription"}) + msg = await websocket_client.receive_json() + assert msg["id"] == 2 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert len(subscriptions) == 2 + await websocket_client.send_json( + {"id": 3, "type": "subscription_that_raises_on_cancel"} + ) + msg = await websocket_client.receive_json() + assert msg["id"] == 3 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert len(subscriptions) == 3 + await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"}) + await hass.async_block_till_done() + msg = await websocket_client.receive() + assert msg.type == WSMsgType.close + assert len(subscriptions) == 0 + + async def test_pending_msg_peak( hass: HomeAssistant, mock_low_peak,