mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-11-25 10:38:03 +00:00
* Simplify ensure_access_token Make the caller of ensure_access_token responsible for connection error handling. This is especially useful for API connection checks, as it avoids an extra call to the API (if we fail to connect when refreshing the token there is no point in calling the API to check if it is up). Document the change in the docstring. Also avoid the overhead of creating a Job object. We can simply use an asyncio.Lock() to ensure only one coroutine is refreshing the token at a time. This also avoids Job interference in exception handling. * Remove check_port from API checks Remove check_port usage from Home Assistant API connection checks. Simply rely on errors raised from actual connection attempts. During a Supervisor startup when Home Assistant Core is running (e.g. after a Supervisor update) we make about 10 successful API checks. The old code path did a port check and then a connection check, causing two socket creations. The new code, without the separate port check, saves 10 socket creations per startup (the aiohttp connections are reused, so the API checks themselves cause only one socket creation). * Log API exceptions on call site Since make_request is no longer logging API exceptions on its own, we need to log them where we call make_request. This approach gives the user more context about what Supervisor was trying to do when the error happened. * Avoid unnecessary nesting * Improve error when ingress panel update fails * Add comment about fast path
356 lines
12 KiB
Python
356 lines
12 KiB
Python
"""Home Assistant Websocket API."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import suppress
|
|
import logging
|
|
from typing import Any, TypeVar, cast
|
|
|
|
import aiohttp
|
|
from aiohttp.http_websocket import WSMsgType
|
|
from awesomeversion import AwesomeVersion
|
|
|
|
from ..const import (
|
|
ATTR_ACCESS_TOKEN,
|
|
ATTR_DATA,
|
|
ATTR_EVENT,
|
|
ATTR_TYPE,
|
|
ATTR_UPDATE_KEY,
|
|
STARTING_STATES,
|
|
BusEvent,
|
|
CoreState,
|
|
)
|
|
from ..coresys import CoreSys, CoreSysAttributes
|
|
from ..exceptions import (
|
|
HomeAssistantAPIError,
|
|
HomeAssistantWSConnectionError,
|
|
HomeAssistantWSError,
|
|
)
|
|
from ..utils.json import json_dumps
|
|
from .const import CLOSING_STATES, WSEvent, WSType
|
|
|
|
_LOGGER: logging.Logger = logging.getLogger(__name__)

# Type variable for the result payload returned by websocket commands.
T = TypeVar("T")

# Oldest Home Assistant Core version that accepts each WebSocket message type.
# Message types that are not listed here are not version-gated.
MIN_VERSION: dict[WSType, str] = {
    WSType.SUPERVISOR_EVENT: "2021.2.4",
    WSType.BACKUP_START: "2022.1.0",
    WSType.BACKUP_END: "2022.1.0",
}
|
|
|
|
|
|
class WSClient:
    """Home Assistant Websocket client.

    Wraps an already-authenticated aiohttp websocket connection. Commands sent
    with ``async_send_command`` are tagged with an incrementing message id and
    resolved when ``_receive_json`` reads the matching ``result`` reply.
    """

    def __init__(
        self,
        loop: asyncio.BaseEventLoop,
        ha_version: AwesomeVersion,
        client: aiohttp.ClientWebSocketResponse,
    ):
        """Initialise the WS client.

        Args:
            loop: Event loop used to create the per-command response futures.
            ha_version: Core version reported in the websocket hello message.
            client: The open websocket connection to Home Assistant Core.
        """
        self.ha_version = ha_version
        self._client = client
        self._message_id: int = 0
        self._loop = loop
        # Pending command futures keyed by message id; resolved in _receive_json.
        self._futures: dict[int, asyncio.Future[T]] = {}  # type: ignore

    @property
    def connected(self) -> bool:
        """Return if we're currently connected."""
        return self._client is not None and not self._client.closed

    async def close(self) -> None:
        """Close down the client.

        Every still-pending command future is failed with
        HomeAssistantWSConnectionError so awaiting callers do not hang.
        """
        for future in self._futures.values():
            if not future.done():
                future.set_exception(
                    HomeAssistantWSConnectionError("Connection was closed")
                )

        if not self._client.closed:
            await self._client.close()

    async def async_send_message(self, message: dict[str, Any]) -> None:
        """Send a websocket message, don't wait for response.

        Raises:
            HomeAssistantWSConnectionError: if sending fails with a
                ConnectionError.
        """
        # NOTE(review): the counter is advanced but, unlike async_send_command,
        # no "id" is attached to ``message`` here; confirm this is intended.
        self._message_id += 1
        _LOGGER.debug("Sending: %s", message)
        try:
            await self._client.send_json(message, dumps=json_dumps)
        except ConnectionError as err:
            raise HomeAssistantWSConnectionError(str(err)) from err

    async def async_send_command(self, message: dict[str, Any]) -> T | None:
        """Send a websocket message, and return the response.

        ``message`` is mutated: its ``id`` key is set to the new message id.

        Returns:
            The ``result`` payload of the matching reply.

        Raises:
            HomeAssistantWSConnectionError: if sending fails, or the client is
                closed before the reply arrives.
            HomeAssistantWSError: if Home Assistant reports the command as
                unsuccessful.
        """
        self._message_id += 1
        message_id = self._message_id
        message["id"] = message_id
        self._futures[message_id] = self._loop.create_future()
        _LOGGER.debug("Sending: %s", message)

        try:
            try:
                await self._client.send_json(message, dumps=json_dumps)
            except ConnectionError as err:
                raise HomeAssistantWSConnectionError(str(err)) from err

            return await self._futures[message_id]
        finally:
            # Drop the future on every exit path, including a failed send,
            # so it does not linger in _futures.
            self._futures.pop(message_id, None)

    async def start_listener(self) -> None:
        """Start listening to the websocket.

        Reads messages until the connection closes or a HomeAssistantWSError
        is raised, then closes the client.

        Raises:
            HomeAssistantWSConnectionError: if called while not connected.
        """
        if not self.connected:
            raise HomeAssistantWSConnectionError("Not connected when start listening")

        try:
            while self.connected:
                await self._receive_json()
        except HomeAssistantWSError:
            pass

        finally:
            await self.close()

    async def _receive_json(self) -> None:
        """Receive one websocket message and dispatch command results.

        Raises:
            HomeAssistantWSConnectionError: on CLOSE/CLOSED/CLOSING messages.
            HomeAssistantWSError: on ERROR, non-text or invalid JSON messages.
        """
        msg = await self._client.receive()
        _LOGGER.debug("Received: %s", msg)

        if msg.type == WSMsgType.CLOSE:
            raise HomeAssistantWSConnectionError("Connection was closed", _LOGGER.debug)

        if msg.type in (
            WSMsgType.CLOSED,
            WSMsgType.CLOSING,
        ):
            raise HomeAssistantWSConnectionError(
                "Connection is closed", _LOGGER.warning
            )

        if msg.type == WSMsgType.ERROR:
            raise HomeAssistantWSError(f"WebSocket Error: {msg}", _LOGGER.error)

        if msg.type != WSMsgType.TEXT:
            raise HomeAssistantWSError(
                f"Received non-Text message: {msg.type}", _LOGGER.error
            )

        try:
            data = msg.json()
        except ValueError as err:
            raise HomeAssistantWSError(
                f"Received invalid JSON - {msg}", _LOGGER.error
            ) from err

        if data["type"] == "result":
            future = self._futures.get(data["id"])
            # Ignore replies nobody is waiting for, and replies whose future
            # was already completed (e.g. failed by close()); setting a
            # result on a done future would raise InvalidStateError.
            if future is None or future.done():
                return

            if data["success"]:
                future.set_result(data["result"])
                return

            future.set_exception(
                HomeAssistantWSError(f"Unsuccessful websocket message - {data}")
            )

    @classmethod
    async def connect_with_auth(
        cls, session: aiohttp.ClientSession, loop, url: str, token: str
    ) -> WSClient:
        """Create an authenticated websocket client.

        Opens the websocket, reads the hello message, sends the access token
        and expects an ``auth_ok`` reply. The websocket is closed again if any
        step after opening it fails.

        Raises:
            HomeAssistantWSError: if the connection cannot be established.
            HomeAssistantAPIError: if the auth reply is not ``auth_ok``.
        """
        try:
            client = await session.ws_connect(url, ssl=False)
        except aiohttp.client_exceptions.ClientConnectorError:
            raise HomeAssistantWSError("Can't connect") from None

        try:
            hello_message = await client.receive_json()

            await client.send_json(
                {ATTR_TYPE: WSType.AUTH, ATTR_ACCESS_TOKEN: token}, dumps=json_dumps
            )

            auth_ok_message = await client.receive_json()

            if auth_ok_message[ATTR_TYPE] != "auth_ok":
                raise HomeAssistantAPIError("AUTH NOT OK")

            return cls(loop, AwesomeVersion(hello_message["ha_version"]), client)
        except BaseException:
            # Do not leak the open websocket when the handshake fails.
            await client.close()
            raise
|
|
|
|
|
|
class HomeAssistantWebSocket(CoreSysAttributes):
    """Home Assistant Websocket API.

    Holds one shared WSClient to Home Assistant Core, (re)connecting on demand,
    and queues outgoing messages while Supervisor is in a starting state.
    """

    def __init__(self, coresys: CoreSys):
        """Initialize Home Assistant object."""
        self.coresys: CoreSys = coresys
        self._client: WSClient | None = None
        # Serializes (re)connection so only one client is created at a time.
        self._lock: asyncio.Lock = asyncio.Lock()
        # Messages deferred while Supervisor is in STARTING_STATES.
        self._queue: list[dict[str, Any]] = []

    async def _process_queue(self, reference: CoreState) -> None:
        """Process queue once supervisor is running.

        Registered for BusEvent.SUPERVISOR_STATE_CHANGE in load(). A failure
        to deliver one queued message is logged and does not stop the
        remaining messages from being sent.
        """
        if reference != CoreState.RUNNING:
            return

        # Detach the queue before awaiting so the list is never mutated
        # while it is being iterated.
        queue, self._queue = self._queue, []
        for msg in queue:
            try:
                await self.async_send_message(msg)
            except (HomeAssistantWSError, HomeAssistantAPIError) as err:
                _LOGGER.error(
                    "Could not send queued message to Home Assistant due to %s", err
                )

    async def _get_ws_client(self) -> WSClient:
        """Return a connected websocket client, creating one if needed.

        The new client is stored on ``self._client`` while the lock is held,
        so concurrent callers reuse it instead of opening another connection.

        Raises:
            HomeAssistantWSError: if the websocket connection cannot be opened.
            HomeAssistantAPIError: if authentication is rejected.
        """
        async with self._lock:
            if self._client is not None and self._client.connected:
                return self._client

            # A failed token refresh is suppressed; the connection attempt
            # below is made regardless and reports its own errors.
            with suppress(asyncio.TimeoutError, aiohttp.ClientError):
                await self.sys_homeassistant.api.ensure_access_token()
            client = await WSClient.connect_with_auth(
                self.sys_websession,
                self.sys_loop,
                self.sys_homeassistant.ws_url,
                cast(str, self.sys_homeassistant.api.access_token),
            )

            self.sys_create_task(client.start_listener())
            self._client = client
            return client

    async def _can_send(self, message: dict[str, Any]) -> bool:
        """Determine if we can use WebSocket messages.

        Returns False while Supervisor is closing, when Core's API is not
        reachable, or when the connected Core is older than the minimum
        version for this message type. Otherwise ensures a connected client.

        Raises:
            HomeAssistantWSError / HomeAssistantAPIError: if (re)connecting fails.
        """
        if self.sys_core.state in CLOSING_STATES:
            return False

        connected = self._client is not None and self._client.connected
        # If we are already connected, we can avoid the check_api_state call
        # since it makes a new socket connection and we already have one.
        if not connected and not await self.sys_homeassistant.api.check_api_state():
            # No core access, don't try.
            return False

        # Re-check after the await above: the connection may have dropped.
        if self._client is None or not self._client.connected:
            self._client = await self._get_ws_client()

        message_type = message.get("type")

        if (
            message_type is not None
            and message_type in MIN_VERSION
            and self._client.ha_version < MIN_VERSION[message_type]
        ):
            _LOGGER.info(
                "WebSocket command %s is not supported until core-%s. Ignoring WebSocket message.",
                message_type,
                MIN_VERSION[message_type],
            )
            return False
        return True

    async def load(self) -> None:
        """Set up queue processor after startup completes."""
        self.sys_bus.register_event(
            BusEvent.SUPERVISOR_STATE_CHANGE, self._process_queue
        )

    async def async_send_message(self, message: dict[str, Any]) -> None:
        """Send a message with the WS client.

        Connection errors while sending drop the client and are swallowed.
        """
        # Only commands allowed during startup as those tell Home Assistant to do something.
        # Messages may cause clients to make follow-up API calls so those wait.
        if self.sys_core.state in STARTING_STATES:
            self._queue.append(message)
            _LOGGER.debug("Queuing message until startup has completed: %s", message)
            return

        if not await self._can_send(message):
            return

        try:
            if self._client:
                await self._client.async_send_command(message)
        except HomeAssistantWSConnectionError:
            if self._client:
                await self._client.close()
            self._client = None

    async def async_send_command(self, message: dict[str, Any]) -> T | None:
        """Send a command with the WS client and wait for the response.

        Returns None if the message cannot be sent (see _can_send).

        Raises:
            HomeAssistantWSConnectionError: after dropping the client, if the
                connection fails while sending or waiting.
        """
        if not await self._can_send(message):
            return None

        try:
            if self._client:
                return await self._client.async_send_command(message)
        except HomeAssistantWSConnectionError:
            if self._client:
                await self._client.close()
            self._client = None
            raise
        return None

    def send_message(self, message: dict[str, Any]) -> None:
        """Send a supervisor/event message."""
        if self.sys_core.state in CLOSING_STATES:
            return
        self.sys_create_task(self.async_send_message(message))

    async def async_supervisor_event_custom(
        self, event: WSEvent, extra_data: dict[str, Any] | None = None
    ) -> None:
        """Send a supervisor/event message to Home Assistant with custom data.

        Websocket and API errors (including a rejected authentication while
        connecting) are logged instead of raised.
        """
        try:
            await self.async_send_message(
                {
                    ATTR_TYPE: WSType.SUPERVISOR_EVENT,
                    ATTR_DATA: {
                        ATTR_EVENT: event,
                        **(extra_data or {}),
                    },
                }
            )
        except (HomeAssistantWSError, HomeAssistantAPIError) as err:
            _LOGGER.error("Could not send message to Home Assistant due to %s", err)

    def supervisor_event_custom(
        self, event: WSEvent, extra_data: dict[str, Any] | None = None
    ) -> None:
        """Send a supervisor/event message to Home Assistant with custom data."""
        if self.sys_core.state in CLOSING_STATES:
            return
        self.sys_create_task(self.async_supervisor_event_custom(event, extra_data))

    def supervisor_event(
        self, event: WSEvent, data: dict[str, Any] | None = None
    ) -> None:
        """Send a supervisor/event message to Home Assistant."""
        if self.sys_core.state in CLOSING_STATES:
            return
        self.sys_create_task(
            self.async_supervisor_event_custom(event, {ATTR_DATA: data or {}})
        )

    async def async_supervisor_update_event(
        self,
        key: str,
        data: dict[str, Any] | None = None,
    ) -> None:
        """Send an update supervisor/event message."""
        await self.async_supervisor_event_custom(
            WSEvent.SUPERVISOR_UPDATE,
            {
                ATTR_UPDATE_KEY: key,
                ATTR_DATA: data or {},
            },
        )

    def supervisor_update_event(
        self,
        key: str,
        data: dict[str, Any] | None = None,
    ) -> None:
        """Send an update supervisor/event message."""
        if self.sys_core.state in CLOSING_STATES:
            return
        self.sys_create_task(self.async_supervisor_update_event(key, data))
|