Add connection test feature to assist_satellite (#126256)

* Add connection test feature to assist_satellite

* Add http to assist_satellite dependencies

* Remove extra logging

* Incorporate feedback

* Fix tests

* ruff

* Apply suggestions from code review

Co-authored-by: Bram Kragten <mail@bramkragten.nl>

* Use asyncio.Event instead of dispatcher

* Respond asap

* Update homeassistant/components/assist_satellite/websocket_api.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Bram Kragten <mail@bramkragten.nl>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2024-09-22 16:55:31 +02:00 committed by GitHub
parent bb2c2d161a
commit 8158ca7c69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 258 additions and 5 deletions

View File

@ -10,7 +10,13 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature from .connection_test import ConnectionTestView
from .const import (
CONNECTION_TEST_DATA,
DOMAIN,
DOMAIN_DATA,
AssistSatelliteEntityFeature,
)
from .entity import ( from .entity import (
AssistSatelliteAnnouncement, AssistSatelliteAnnouncement,
AssistSatelliteConfiguration, AssistSatelliteConfiguration,
@ -57,7 +63,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_internal_announce", "async_internal_announce",
[AssistSatelliteEntityFeature.ANNOUNCE], [AssistSatelliteEntityFeature.ANNOUNCE],
) )
hass.data[CONNECTION_TEST_DATA] = {}
async_register_websocket_api(hass) async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())
return True return True

View File

@ -0,0 +1,43 @@
"""Assist satellite connection test."""
import logging
from pathlib import Path
from aiohttp import web
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from .const import CONNECTION_TEST_DATA
_LOGGER = logging.getLogger(__name__)
CONNECTION_TEST_CONTENT_TYPE = "audio/mpeg"
CONNECTION_TEST_FILENAME = "connection_test.mp3"
CONNECTION_TEST_URL_BASE = "/api/assist_satellite/connection_test"
class ConnectionTestView(HomeAssistantView):
"""View to serve an audio sample for connection test."""
requires_auth = False
url = f"{CONNECTION_TEST_URL_BASE}/{{connection_id}}"
name = "api:assist_satellite_connection_test"
async def get(self, request: web.Request, connection_id: str) -> web.Response:
"""Start a get request."""
_LOGGER.debug("Request for connection test with id %s", connection_id)
hass = request.app[KEY_HASS]
connection_test_data = hass.data[CONNECTION_TEST_DATA]
connection_test_event = connection_test_data.pop(connection_id, None)
if connection_test_event is None:
return web.Response(status=404)
connection_test_event.set()
audio_path = Path(__file__).parent / CONNECTION_TEST_FILENAME
audio_data = await hass.async_add_executor_job(audio_path.read_bytes)
return web.Response(body=audio_data, content_type=CONNECTION_TEST_CONTENT_TYPE)

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from enum import IntFlag from enum import IntFlag
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -15,6 +16,9 @@ if TYPE_CHECKING:
DOMAIN = "assist_satellite" DOMAIN = "assist_satellite"
DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN) DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN)
CONNECTION_TEST_DATA: HassKey[dict[str, asyncio.Event]] = HassKey(
f"{DOMAIN}_connection_tests"
)
class AssistSatelliteEntityFeature(IntFlag): class AssistSatelliteEntityFeature(IntFlag):

View File

@ -2,7 +2,7 @@
"domain": "assist_satellite", "domain": "assist_satellite",
"name": "Assist Satellite", "name": "Assist Satellite",
"codeowners": ["@home-assistant/core", "@synesthesiam"], "codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["assist_pipeline", "stt", "tts"], "dependencies": ["assist_pipeline", "http", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite", "documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity", "integration_type": "entity",
"quality_scale": "internal" "quality_scale": "internal"

View File

@ -1,5 +1,6 @@
"""Assist satellite Websocket API.""" """Assist satellite Websocket API."""
import asyncio
from dataclasses import asdict, replace from dataclasses import asdict, replace
from typing import Any from typing import Any
@ -9,8 +10,19 @@ from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import uuid as uuid_util
from .const import DOMAIN, DOMAIN_DATA from .connection_test import CONNECTION_TEST_URL_BASE
from .const import (
CONNECTION_TEST_DATA,
DOMAIN,
DOMAIN_DATA,
AssistSatelliteEntityFeature,
)
from .entity import AssistSatelliteEntity
CONNECTION_TEST_TIMEOUT = 30
@callback @callback
@ -19,6 +31,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_intercept_wake_word) websocket_api.async_register_command(hass, websocket_intercept_wake_word)
websocket_api.async_register_command(hass, websocket_get_configuration) websocket_api.async_register_command(hass, websocket_get_configuration)
websocket_api.async_register_command(hass, websocket_set_wake_words) websocket_api.async_register_command(hass, websocket_set_wake_words)
websocket_api.async_register_command(hass, websocket_test_connection)
@callback @callback
@ -138,3 +151,57 @@ async def websocket_set_wake_words(
replace(config, active_wake_words=actual_ids) replace(config, active_wake_words=actual_ids)
) )
connection.send_result(msg["id"]) connection.send_result(msg["id"])
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/test_connection",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
@websocket_api.async_response
async def websocket_test_connection(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Test the connection between the device and Home Assistant.
Send an announcement to the device with a special media id.
"""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_SUPPORTED,
"Entity does not support announce",
)
return
# Announce and wait for event
connection_test_data = hass.data[CONNECTION_TEST_DATA]
connection_id = uuid_util.random_uuid_hex()
connection_test_event = asyncio.Event()
connection_test_data[connection_id] = connection_test_event
hass.async_create_background_task(
satellite.async_internal_announce(
media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
),
f"assist_satellite_connection_test_{msg['entity_id']}",
)
try:
async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
await connection_test_event.wait()
connection.send_result(msg["id"], {"status": "success"})
except TimeoutError:
connection.send_result(msg["id"], {"status": "timeout"})
finally:
connection_test_data.pop(connection_id, None)

View File

@ -44,7 +44,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the mock entity.""" """Initialize the mock entity."""
self.events = [] self.events = []
self.announcements = [] self.announcements: list[AssistSatelliteAnnouncement] = []
self.config = AssistSatelliteConfiguration( self.config = AssistSatelliteConfiguration(
available_wake_words=[ available_wake_words=[
AssistSatelliteWakeWord( AssistSatelliteWakeWord(

View File

@ -1,11 +1,16 @@
"""Test WebSocket API.""" """Test WebSocket API."""
import asyncio import asyncio
from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
from homeassistant.components.assist_pipeline import PipelineStage from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.components.assist_satellite.websocket_api import (
CONNECTION_TEST_TIMEOUT,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -13,7 +18,7 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite from .conftest import MockAssistSatellite
from tests.common import MockUser from tests.common import MockUser
from tests.typing import WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
async def test_intercept_wake_word( async def test_intercept_wake_word(
@ -385,3 +390,129 @@ async def test_set_wake_words_bad_id(
"code": "not_supported", "code": "not_supported",
"message": "Wake word id is not supported: abcd", "message": "Wake word id is not supported: abcd",
} }
async def test_connection_test(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
) -> None:
"""Test connection test."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
assert len(entity.announcements) == 1
assert entity.announcements[0].message == ""
announcement_media_id = entity.announcements[0].media_id
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)
# Fake satellite fetches the URL
client = await hass_client()
resp = await client.get(announcement_media_id[len(hass_url) :])
assert resp.status == HTTPStatus.OK
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "success"}
async def test_connection_test_timeout(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test connection test timeout."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
assert len(entity.announcements) == 1
assert entity.announcements[0].message == ""
announcement_media_id = entity.announcements[0].media_id
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)
freezer.tick(CONNECTION_TEST_TIMEOUT + 1)
# Timeout
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "timeout"}
async def test_connection_test_invalid_satellite(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test with unknown entity id."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "not_found",
"message": "Entity not found",
}
async def test_connection_test_timeout_announcement_unsupported(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test entity which does not support announce."""
ws_client = await hass_ws_client(hass)
# Disable announce support
entity.supported_features = 0
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "not_supported",
"message": "Entity does not support announce",
}