mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 18:27:09 +00:00
Add connection test feature to assist_satellite (#126256)
* Add connection test feature to assist_satellite * Add http to assist_satellite dependencies * Remove extra logging * Incorporate feedback * Fix tests * ruff * Apply suggestions from code review Co-authored-by: Bram Kragten <mail@bramkragten.nl> * Use asyncio.Event instead of dispatcher * Respond asap * Update homeassistant/components/assist_satellite/websocket_api.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> --------- Co-authored-by: Michael Hansen <mike@rhasspy.org> Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Bram Kragten <mail@bramkragten.nl> Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
bb2c2d161a
commit
8158ca7c69
@ -10,7 +10,13 @@ from homeassistant.helpers import config_validation as cv
|
|||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
|
from .connection_test import ConnectionTestView
|
||||||
|
from .const import (
|
||||||
|
CONNECTION_TEST_DATA,
|
||||||
|
DOMAIN,
|
||||||
|
DOMAIN_DATA,
|
||||||
|
AssistSatelliteEntityFeature,
|
||||||
|
)
|
||||||
from .entity import (
|
from .entity import (
|
||||||
AssistSatelliteAnnouncement,
|
AssistSatelliteAnnouncement,
|
||||||
AssistSatelliteConfiguration,
|
AssistSatelliteConfiguration,
|
||||||
@ -57,7 +63,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
"async_internal_announce",
|
"async_internal_announce",
|
||||||
[AssistSatelliteEntityFeature.ANNOUNCE],
|
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||||
)
|
)
|
||||||
|
hass.data[CONNECTION_TEST_DATA] = {}
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
|
hass.http.register_view(ConnectionTestView())
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
BIN
homeassistant/components/assist_satellite/connection_test.mp3
Executable file
BIN
homeassistant/components/assist_satellite/connection_test.mp3
Executable file
Binary file not shown.
43
homeassistant/components/assist_satellite/connection_test.py
Normal file
43
homeassistant/components/assist_satellite/connection_test.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
"""Assist satellite connection test."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||||
|
|
||||||
|
from .const import CONNECTION_TEST_DATA
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CONNECTION_TEST_CONTENT_TYPE = "audio/mpeg"
|
||||||
|
CONNECTION_TEST_FILENAME = "connection_test.mp3"
|
||||||
|
CONNECTION_TEST_URL_BASE = "/api/assist_satellite/connection_test"
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionTestView(HomeAssistantView):
|
||||||
|
"""View to serve an audio sample for connection test."""
|
||||||
|
|
||||||
|
requires_auth = False
|
||||||
|
url = f"{CONNECTION_TEST_URL_BASE}/{{connection_id}}"
|
||||||
|
name = "api:assist_satellite_connection_test"
|
||||||
|
|
||||||
|
async def get(self, request: web.Request, connection_id: str) -> web.Response:
|
||||||
|
"""Start a get request."""
|
||||||
|
_LOGGER.debug("Request for connection test with id %s", connection_id)
|
||||||
|
|
||||||
|
hass = request.app[KEY_HASS]
|
||||||
|
connection_test_data = hass.data[CONNECTION_TEST_DATA]
|
||||||
|
|
||||||
|
connection_test_event = connection_test_data.pop(connection_id, None)
|
||||||
|
|
||||||
|
if connection_test_event is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
connection_test_event.set()
|
||||||
|
|
||||||
|
audio_path = Path(__file__).parent / CONNECTION_TEST_FILENAME
|
||||||
|
audio_data = await hass.async_add_executor_job(audio_path.read_bytes)
|
||||||
|
|
||||||
|
return web.Response(body=audio_data, content_type=CONNECTION_TEST_CONTENT_TYPE)
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from enum import IntFlag
|
from enum import IntFlag
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@ -15,6 +16,9 @@ if TYPE_CHECKING:
|
|||||||
DOMAIN = "assist_satellite"
|
DOMAIN = "assist_satellite"
|
||||||
|
|
||||||
DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN)
|
DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN)
|
||||||
|
CONNECTION_TEST_DATA: HassKey[dict[str, asyncio.Event]] = HassKey(
|
||||||
|
f"{DOMAIN}_connection_tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AssistSatelliteEntityFeature(IntFlag):
|
class AssistSatelliteEntityFeature(IntFlag):
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"domain": "assist_satellite",
|
"domain": "assist_satellite",
|
||||||
"name": "Assist Satellite",
|
"name": "Assist Satellite",
|
||||||
"codeowners": ["@home-assistant/core", "@synesthesiam"],
|
"codeowners": ["@home-assistant/core", "@synesthesiam"],
|
||||||
"dependencies": ["assist_pipeline", "stt", "tts"],
|
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||||
"integration_type": "entity",
|
"integration_type": "entity",
|
||||||
"quality_scale": "internal"
|
"quality_scale": "internal"
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Assist satellite Websocket API."""
|
"""Assist satellite Websocket API."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from dataclasses import asdict, replace
|
from dataclasses import asdict, replace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -9,8 +10,19 @@ from homeassistant.components import websocket_api
|
|||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.util import uuid as uuid_util
|
||||||
|
|
||||||
from .const import DOMAIN, DOMAIN_DATA
|
from .connection_test import CONNECTION_TEST_URL_BASE
|
||||||
|
from .const import (
|
||||||
|
CONNECTION_TEST_DATA,
|
||||||
|
DOMAIN,
|
||||||
|
DOMAIN_DATA,
|
||||||
|
AssistSatelliteEntityFeature,
|
||||||
|
)
|
||||||
|
from .entity import AssistSatelliteEntity
|
||||||
|
|
||||||
|
CONNECTION_TEST_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -19,6 +31,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||||||
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||||
websocket_api.async_register_command(hass, websocket_get_configuration)
|
websocket_api.async_register_command(hass, websocket_get_configuration)
|
||||||
websocket_api.async_register_command(hass, websocket_set_wake_words)
|
websocket_api.async_register_command(hass, websocket_set_wake_words)
|
||||||
|
websocket_api.async_register_command(hass, websocket_test_connection)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -138,3 +151,57 @@ async def websocket_set_wake_words(
|
|||||||
replace(config, active_wake_words=actual_ids)
|
replace(config, active_wake_words=actual_ids)
|
||||||
)
|
)
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_satellite/test_connection",
|
||||||
|
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_test_connection(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Test the connection between the device and Home Assistant.
|
||||||
|
|
||||||
|
Send an announcement to the device with a special media id.
|
||||||
|
"""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
satellite = component.get_entity(msg["entity_id"])
|
||||||
|
if satellite is None:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"],
|
||||||
|
websocket_api.ERR_NOT_SUPPORTED,
|
||||||
|
"Entity does not support announce",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Announce and wait for event
|
||||||
|
connection_test_data = hass.data[CONNECTION_TEST_DATA]
|
||||||
|
connection_id = uuid_util.random_uuid_hex()
|
||||||
|
connection_test_event = asyncio.Event()
|
||||||
|
connection_test_data[connection_id] = connection_test_event
|
||||||
|
|
||||||
|
hass.async_create_background_task(
|
||||||
|
satellite.async_internal_announce(
|
||||||
|
media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
|
||||||
|
),
|
||||||
|
f"assist_satellite_connection_test_{msg['entity_id']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
|
||||||
|
await connection_test_event.wait()
|
||||||
|
connection.send_result(msg["id"], {"status": "success"})
|
||||||
|
except TimeoutError:
|
||||||
|
connection.send_result(msg["id"], {"status": "timeout"})
|
||||||
|
finally:
|
||||||
|
connection_test_data.pop(connection_id, None)
|
||||||
|
@ -44,7 +44,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the mock entity."""
|
"""Initialize the mock entity."""
|
||||||
self.events = []
|
self.events = []
|
||||||
self.announcements = []
|
self.announcements: list[AssistSatelliteAnnouncement] = []
|
||||||
self.config = AssistSatelliteConfiguration(
|
self.config = AssistSatelliteConfiguration(
|
||||||
available_wake_words=[
|
available_wake_words=[
|
||||||
AssistSatelliteWakeWord(
|
AssistSatelliteWakeWord(
|
||||||
|
@ -1,11 +1,16 @@
|
|||||||
"""Test WebSocket API."""
|
"""Test WebSocket API."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from http import HTTPStatus
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import PipelineStage
|
from homeassistant.components.assist_pipeline import PipelineStage
|
||||||
|
from homeassistant.components.assist_satellite.websocket_api import (
|
||||||
|
CONNECTION_TEST_TIMEOUT,
|
||||||
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
@ -13,7 +18,7 @@ from . import ENTITY_ID
|
|||||||
from .conftest import MockAssistSatellite
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
from tests.common import MockUser
|
from tests.common import MockUser
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
async def test_intercept_wake_word(
|
async def test_intercept_wake_word(
|
||||||
@ -385,3 +390,129 @@ async def test_set_wake_words_bad_id(
|
|||||||
"code": "not_supported",
|
"code": "not_supported",
|
||||||
"message": "Wake word id is not supported: abcd",
|
"message": "Wake word id is not supported: abcd",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connection_test(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test connection test."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/test_connection",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(entity.announcements) == 1
|
||||||
|
assert entity.announcements[0].message == ""
|
||||||
|
announcement_media_id = entity.announcements[0].media_id
|
||||||
|
hass_url = "http://10.10.10.10:8123"
|
||||||
|
assert announcement_media_id.startswith(
|
||||||
|
f"{hass_url}/api/assist_satellite/connection_test/"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fake satellite fetches the URL
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get(announcement_media_id[len(hass_url) :])
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
assert response["success"]
|
||||||
|
assert response["result"] == {"status": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connection_test_timeout(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
freezer: FrozenDateTimeFactory,
|
||||||
|
) -> None:
|
||||||
|
"""Test connection test timeout."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/test_connection",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(entity.announcements) == 1
|
||||||
|
assert entity.announcements[0].message == ""
|
||||||
|
announcement_media_id = entity.announcements[0].media_id
|
||||||
|
hass_url = "http://10.10.10.10:8123"
|
||||||
|
assert announcement_media_id.startswith(
|
||||||
|
f"{hass_url}/api/assist_satellite/connection_test/"
|
||||||
|
)
|
||||||
|
|
||||||
|
freezer.tick(CONNECTION_TEST_TIMEOUT + 1)
|
||||||
|
|
||||||
|
# Timeout
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
assert response["success"]
|
||||||
|
assert response["result"] == {"status": "timeout"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connection_test_invalid_satellite(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test connection test with unknown entity id."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/test_connection",
|
||||||
|
"entity_id": "assist_satellite.invalid",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "Entity not found",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connection_test_timeout_announcement_unsupported(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test connection test entity which does not support announce."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
# Disable announce support
|
||||||
|
entity.supported_features = 0
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/test_connection",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "not_supported",
|
||||||
|
"message": "Entity does not support announce",
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user