diff --git a/supervisor/addons/addon.py b/supervisor/addons/addon.py index 68a6b5fbd..bab9aff6f 100644 --- a/supervisor/addons/addon.py +++ b/supervisor/addons/addon.py @@ -119,7 +119,7 @@ class Addon(AddonModel): if self._state == new_state: return self._state = new_state - self.sys_homeassistant.websocket.send_command( + self.sys_homeassistant.websocket.send_message( { ATTR_TYPE: WSType.SUPERVISOR_EVENT, ATTR_DATA: { diff --git a/supervisor/exceptions.py b/supervisor/exceptions.py index 626500258..57281963e 100644 --- a/supervisor/exceptions.py +++ b/supervisor/exceptions.py @@ -69,6 +69,10 @@ class HomeAssistantWSNotSupported(HomeAssistantWSError): """Raise when WebSockets are not supported.""" +class HomeAssistantWSConnectionError(HomeAssistantWSError): + """Raise when the WebSocket connection has an error.""" + + class HomeAssistantJobError(HomeAssistantError, JobException): """Raise on Home Assistant job error.""" diff --git a/supervisor/homeassistant/module.py b/supervisor/homeassistant/module.py index 88a37f634..d6d4f9a41 100644 --- a/supervisor/homeassistant/module.py +++ b/supervisor/homeassistant/module.py @@ -118,6 +118,11 @@ class HomeAssistant(FileConfiguration, CoreSysAttributes): f"{'https' if self.api_ssl else 'http'}://{self.ip_address}:{self.api_port}" ) + @property + def ws_url(self) -> str: + """Return API url to Home Assistant.""" + return f"{'wss' if self.api_ssl else 'ws'}://{self.ip_address}:{self.api_port}/api/websocket" + @property def watchdog(self) -> bool: """Return True if the watchdog should protect Home Assistant.""" @@ -278,4 +283,4 @@ class HomeAssistant(FileConfiguration, CoreSysAttributes): if not configuration or "usb" not in configuration.get("components", []): return - self.sys_homeassistant.websocket.send_command({ATTR_TYPE: "usb/scan"}) + self.sys_homeassistant.websocket.send_message({ATTR_TYPE: "usb/scan"}) diff --git a/supervisor/homeassistant/websocket.py b/supervisor/homeassistant/websocket.py index aa711e365..78c475dc8 100644 --- a/supervisor/homeassistant/websocket.py +++ b/supervisor/homeassistant/websocket.py @@ -1,15 +1,19 @@ """Home Assistant Websocket API.""" +from __future__ import annotations + import asyncio import logging -from typing import Any, Optional +from typing import Any import aiohttp +from aiohttp.http_websocket import WSMsgType from awesomeversion import AwesomeVersion from ..const import ATTR_ACCESS_TOKEN, ATTR_DATA, ATTR_EVENT, ATTR_TYPE, ATTR_UPDATE_KEY from ..coresys import CoreSys, CoreSysAttributes from ..exceptions import ( HomeAssistantAPIError, + HomeAssistantWSConnectionError, HomeAssistantWSError, HomeAssistantWSNotSupported, ) @@ -22,41 +26,113 @@ class WSClient: """Home Assistant Websocket client.""" def __init__( - self, ha_version: AwesomeVersion, client: aiohttp.ClientWebSocketResponse + self, + loop: asyncio.BaseEventLoop, + ha_version: AwesomeVersion, + client: aiohttp.ClientWebSocketResponse, ): """Initialise the WS client.""" - self.ha_version: AwesomeVersion = ha_version - self.client: aiohttp.ClientWebSocketResponse = client - self.message_id: int = 0 - self._lock: asyncio.Lock = asyncio.Lock() + self.ha_version = ha_version + self._client = client + self._message_id: int = 0 + self._loop = loop + self._futures: dict[int, asyncio.Future[dict]] = {} - async def async_send_command(self, message: dict[str, Any]): - """Send a websocket command.""" - async with self._lock: - self.message_id += 1 - message["id"] = self.message_id + @property + def connected(self) -> bool: + """Return if we're currently connected.""" + return self._client is not None and not self._client.closed - _LOGGER.debug("Sending: %s", message) - try: - await self.client.send_json(message) - except ConnectionError: + async def close(self) -> None: + """Close down the client.""" + if not self._client.closed: + await self._client.close() + + async def async_send_message(self, message: dict[str, Any]) -> None: + """Send a websocket message, don't wait for response.""" + self._message_id += 1 + _LOGGER.debug("Sending: %s", message) + try: + await self._client.send_json(message) + except ConnectionError as err: + raise HomeAssistantWSConnectionError(err) from err + + async def async_send_command(self, message: dict[str, Any]) -> dict | None: + """Send a websocket message, and return the response.""" + self._message_id += 1 + message["id"] = self._message_id + self._futures[message["id"]] = self._loop.create_future() + _LOGGER.debug("Sending: %s", message) + try: + await self._client.send_json(message) + except ConnectionError as err: + raise HomeAssistantWSConnectionError(err) from err + + try: + return await self._futures[message["id"]] + finally: + self._futures.pop(message["id"]) + + async def start_listener(self) -> None: + """Start listening to the websocket.""" + if not self.connected: + raise HomeAssistantWSConnectionError("Not connected when start listening") + + try: + while self.connected: + await self._receive_json() + except HomeAssistantWSError: + pass + + finally: + await self.close() + + async def _receive_json(self) -> None: + """Receive json.""" + msg = await self._client.receive() + _LOGGER.debug("Received: %s", msg) + + if msg.type == WSMsgType.CLOSE: + raise HomeAssistantWSConnectionError("Connection was closed", _LOGGER.debug) + + if msg.type in ( + WSMsgType.CLOSED, + WSMsgType.CLOSING, + ): + raise HomeAssistantWSConnectionError( + "Connection is closed", _LOGGER.warning + ) + + if msg.type == WSMsgType.ERROR: + raise HomeAssistantWSError(f"WebSocket Error: {msg}", _LOGGER.error) + + if msg.type != WSMsgType.TEXT: + raise HomeAssistantWSError( + f"Received non-Text message: {msg.type}", _LOGGER.error + ) + + try: + data = msg.json() + except ValueError as err: + raise HomeAssistantWSError( + f"Received invalid JSON - {msg}", _LOGGER.error + ) from err + + if data["type"] == "result": + if (future := self._futures.get(data["id"])) is None: return - try: - response = await self.client.receive_json() - except ConnectionError: + if data["success"]: + future.set_result(data["result"]) return - _LOGGER.debug("Received: %s", response) - - if response["success"]: - return response["result"] - - raise HomeAssistantWSError(response) + future.set_exception( + HomeAssistantWSError(f"Unsuccessful websocket message - {data}") + ) @classmethod async def connect_with_auth( - cls, session: aiohttp.ClientSession, url: str, token: str + cls, session: aiohttp.ClientSession, loop, url: str, token: str ) -> "WSClient": """Create an authenticated websocket client.""" try: @@ -66,17 +142,14 @@ class WSClient: hello_message = await client.receive_json() - try: - await client.send_json({ATTR_TYPE: WSType.AUTH, ATTR_ACCESS_TOKEN: token}) - except HomeAssistantWSNotSupported: - return + await client.send_json({ATTR_TYPE: WSType.AUTH, ATTR_ACCESS_TOKEN: token}) auth_ok_message = await client.receive_json() if auth_ok_message[ATTR_TYPE] != "auth_ok": raise HomeAssistantAPIError("AUTH NOT OK") - return cls(AwesomeVersion(hello_message["ha_version"]), client) + return cls(loop, AwesomeVersion(hello_message["ha_version"]), client) class HomeAssistantWebSocket(CoreSysAttributes): @@ -85,37 +158,40 @@ class HomeAssistantWebSocket(CoreSysAttributes): def __init__(self, coresys: CoreSys): """Initialize Home Assistant object.""" self.coresys: CoreSys = coresys - self._client: Optional[WSClient] = None + self._client: WSClient | None = None self._lock: asyncio.Lock = asyncio.Lock() async def _get_ws_client(self) -> WSClient: """Return a websocket client.""" async with self._lock: - if self._client is not None: + if self._client is not None and self._client.connected: return self._client await self.sys_homeassistant.api.ensure_access_token() client = await WSClient.connect_with_auth( self.sys_websession, - f"{self.sys_homeassistant.api_url}/api/websocket", + self.sys_loop, + self.sys_homeassistant.ws_url, self.sys_homeassistant.api.access_token, ) + self.sys_create_task(client.start_listener()) return client - async def async_send_command(self, message: dict[str, Any]): - """Send a command with the WS client.""" + async def _can_send(self, message: dict[str, Any]) -> bool: + """Determine if we can use WebSocket messages.""" if self.sys_core.state in CLOSING_STATES: - raise HomeAssistantWSNotSupported( - f"Can't execute in a ${self.sys_core.state} state" - ) + return False if not await self.sys_homeassistant.api.check_api_state(): # No core access, don't try. - return + return False if not self._client: self._client = await self._get_ws_client() + if not self._client.connected: + self._client = await self._get_ws_client() + message_type = message.get("type") if ( @@ -128,19 +204,39 @@ class HomeAssistantWebSocket(CoreSysAttributes): message_type, MIN_VERSION[message_type], ) + return False + return True + + async def async_send_message(self, message: dict[str, Any]) -> None: + """Send a command with the WS client.""" + if not await self._can_send(message): + return + + try: + await self._client.async_send_command(message) + except HomeAssistantWSError: + await self._client.close() + self._client = None + + async def async_send_command(self, message: dict[str, Any]) -> dict[str, Any]: + """Send a command with the WS client and wait for the response.""" + if not await self._can_send(message): return try: return await self._client.async_send_command(message) - except HomeAssistantAPIError as err: - raise HomeAssistantWSError from err + except HomeAssistantWSError: + await self._client.close() + self._client = None async def async_supervisor_update_event( - self, key: str, data: Optional[dict[str, Any]] = None - ): + self, + key: str, + data: dict[str, Any] | None = None, + ) -> None: """Send a supervisor/event command.""" try: - await self.async_send_command( + await self.async_send_message( { ATTR_TYPE: WSType.SUPERVISOR_EVENT, ATTR_DATA: { @@ -155,14 +251,18 @@ class HomeAssistantWebSocket(CoreSysAttributes): except HomeAssistantWSError as err: _LOGGER.error(err) - def supervisor_update_event(self, key: str, data: Optional[dict[str, Any]] = None): + def supervisor_update_event( + self, + key: str, + data: dict[str, Any] | None = None, + ) -> None: """Send a supervisor/event command.""" if self.sys_core.state in CLOSING_STATES: return self.sys_create_task(self.async_supervisor_update_event(key, data)) - def send_command(self, message: dict[str, Any]): + def send_message(self, message: dict[str, Any]) -> None: """Send a supervisor/event command.""" if self.sys_core.state in CLOSING_STATES: return - self.sys_create_task(self.async_send_command(message)) + self.sys_create_task(self.async_send_message(message))