Add Assist satellite configuration (#126063)

* Basic implementation

* Add websocket commands

* Clean up

* Add callback to other signatures

* Remove unused constant

* Re-add callback

* Add callback to test
This commit is contained in:
Michael Hansen 2024-09-16 21:34:07 -05:00 committed by GitHub
parent 738818aa7a
commit dde989685c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 300 additions and 3 deletions

View File

@ -11,15 +11,22 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, AssistSatelliteEntityFeature from .const import DOMAIN, AssistSatelliteEntityFeature
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription from .entity import (
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteWakeWord,
)
from .errors import SatelliteBusyError from .errors import SatelliteBusyError
from .websocket_api import async_register_websocket_api from .websocket_api import async_register_websocket_api
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"AssistSatelliteEntity", "AssistSatelliteEntity",
"AssistSatelliteConfiguration",
"AssistSatelliteEntityDescription", "AssistSatelliteEntityDescription",
"AssistSatelliteEntityFeature", "AssistSatelliteEntityFeature",
"AssistSatelliteWakeWord",
"SatelliteBusyError", "SatelliteBusyError",
] ]

View File

@ -4,6 +4,7 @@ from abc import abstractmethod
import asyncio import asyncio
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
import contextlib import contextlib
from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
import logging import logging
import time import time
@ -57,6 +58,34 @@ class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True)
"""A class that describes Assist satellite entities.""" """A class that describes Assist satellite entities."""
@dataclass(frozen=True)
class AssistSatelliteWakeWord:
"""Available wake word model."""
id: str
"""Unique id for wake word model."""
wake_word: str
"""Wake word phrase."""
trained_languages: list[str]
"""List of languages that the wake word was trained on."""
@dataclass
class AssistSatelliteConfiguration:
"""Satellite configuration."""
available_wake_words: list[AssistSatelliteWakeWord]
"""List of available available wake word models."""
active_wake_words: list[str]
"""List of active wake word ids."""
max_active_wake_words: int
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""
class AssistSatelliteEntity(entity.Entity): class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite.""" """Entity encapsulating the state and functionality of an Assist satellite."""
@ -98,6 +127,17 @@ class AssistSatelliteEntity(entity.Entity):
"""Options passed for text-to-speech.""" """Options passed for text-to-speech."""
return self._attr_tts_options return self._attr_tts_options
@callback
@abstractmethod
def async_get_configuration(self) -> AssistSatelliteConfiguration:
"""Get the current satellite configuration."""
@abstractmethod
async def async_set_configuration(
self, config: AssistSatelliteConfiguration
) -> None:
"""Set the current satellite configuration."""
async def async_intercept_wake_word(self) -> str | None: async def async_intercept_wake_word(self) -> str | None:
"""Intercept the next wake word from the satellite. """Intercept the next wake word from the satellite.

View File

@ -1,5 +1,6 @@
"""Assist satellite Websocket API.""" """Assist satellite Websocket API."""
from dataclasses import asdict, replace
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
@ -18,6 +19,8 @@ from .entity import AssistSatelliteEntity
def async_register_websocket_api(hass: HomeAssistant) -> None: def async_register_websocket_api(hass: HomeAssistant) -> None:
"""Register the websocket API.""" """Register the websocket API."""
websocket_api.async_register_command(hass, websocket_intercept_wake_word) websocket_api.async_register_command(hass, websocket_intercept_wake_word)
websocket_api.async_register_command(hass, websocket_get_configuration)
websocket_api.async_register_command(hass, websocket_set_wake_words)
@callback @callback
@ -59,3 +62,84 @@ async def websocket_intercept_wake_word(
task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word") task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
connection.subscriptions[msg["id"]] = task.cancel connection.subscriptions[msg["id"]] = task.cancel
connection.send_message(websocket_api.result_message(msg["id"])) connection.send_message(websocket_api.result_message(msg["id"]))
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/get_configuration",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
def websocket_get_configuration(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Get the current satellite configuration."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
config_dict = asdict(satellite.async_get_configuration())
config_dict["pipeline_entity_id"] = satellite.pipeline_entity_id
config_dict["vad_entity_id"] = satellite.vad_sensitivity_entity_id
connection.send_result(msg["id"], config_dict)
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/set_wake_words",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
vol.Required("wake_word_ids"): [str],
}
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_set_wake_words(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Set the active wake words for the satellite."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
config = satellite.async_get_configuration()
# Don't set too many active wake words
actual_ids = msg["wake_word_ids"]
if len(actual_ids) > config.max_active_wake_words:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_SUPPORTED,
f"Maximum number of active wake words is {config.max_active_wake_words}",
)
return
# Verify all ids are available
available_ids = {ww.id for ww in config.available_wake_words}
for ww_id in actual_ids:
if ww_id not in available_ids:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_SUPPORTED,
f"Wake word id is not supported: {ww_id}",
)
return
await satellite.async_set_configuration(
replace(config, active_wake_words=actual_ids)
)
connection.send_result(msg["id"])

View File

@ -36,7 +36,7 @@ from homeassistant.components.intent import (
from homeassistant.components.media_player import async_process_play_media_url from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory, Platform from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -150,6 +150,19 @@ class EsphomeAssistSatellite(
f"{self.entry_data.device_info.mac_address}-vad_sensitivity", f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
) )
@callback
def async_get_configuration(
self,
) -> assist_satellite.AssistSatelliteConfiguration:
"""Get the current satellite configuration."""
raise NotImplementedError
async def async_set_configuration(
self, config: assist_satellite.AssistSatelliteConfiguration
) -> None:
"""Set the current satellite configuration."""
raise NotImplementedError
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass.""" """Run when entity about to be added to hass."""
await super().async_added_to_hass() await super().async_added_to_hass()

View File

@ -20,6 +20,7 @@ from homeassistant.components.assist_pipeline import (
PipelineNotFound, PipelineNotFound,
) )
from homeassistant.components.assist_satellite import ( from homeassistant.components.assist_satellite import (
AssistSatelliteConfiguration,
AssistSatelliteEntity, AssistSatelliteEntity,
AssistSatelliteEntityDescription, AssistSatelliteEntityDescription,
) )
@ -141,6 +142,19 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
assert self.voip_device.protocol == self assert self.voip_device.protocol == self
self.voip_device.protocol = None self.voip_device.protocol = None
@callback
def async_get_configuration(
self,
) -> AssistSatelliteConfiguration:
"""Get the current satellite configuration."""
raise NotImplementedError
async def async_set_configuration(
self, config: AssistSatelliteConfiguration
) -> None:
"""Set the current satellite configuration."""
raise NotImplementedError
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# VoIP # VoIP
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------

View File

@ -8,11 +8,13 @@ import pytest
from homeassistant.components.assist_pipeline import PipelineEvent from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import ( from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN, DOMAIN as AS_DOMAIN,
AssistSatelliteConfiguration,
AssistSatelliteEntity, AssistSatelliteEntity,
AssistSatelliteEntityFeature, AssistSatelliteEntityFeature,
AssistSatelliteWakeWord,
) )
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
@ -42,6 +44,20 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Initialize the mock entity.""" """Initialize the mock entity."""
self.events = [] self.events = []
self.announcements = [] self.announcements = []
self.config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord(
id="1234", wake_word="okay nabu", trained_languages=["en"]
),
AssistSatelliteWakeWord(
id="5678",
wake_word="hey jarvis",
trained_languages=["en"],
),
],
active_wake_words=["1234"],
max_active_wake_words=1,
)
def on_pipeline_event(self, event: PipelineEvent) -> None: def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events.""" """Handle pipeline events."""
@ -51,6 +67,17 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Announce media on a device.""" """Announce media on a device."""
self.announcements.append((message, media_id)) self.announcements.append((message, media_id))
@callback
def async_get_configuration(self) -> AssistSatelliteConfiguration:
"""Get the current satellite configuration."""
return self.config
async def async_set_configuration(
self, config: AssistSatelliteConfiguration
) -> None:
"""Set the current satellite configuration."""
self.config = config
@pytest.fixture @pytest.fixture
def entity() -> MockAssistSatellite: def entity() -> MockAssistSatellite:

View File

@ -273,3 +273,115 @@ async def test_intercept_wake_word_unsubscribe(
# Wake word should not be intercepted # Wake word should not be intercepted
mock_pipeline_from_audio_stream.assert_called_once() mock_pipeline_from_audio_stream.assert_called_once()
async def test_get_configuration(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test getting satellite configuration."""
ws_client = await hass_ws_client(hass)
with (
patch.object(entity, "_attr_pipeline_entity_id", "select.test_pipeline"),
patch.object(entity, "_attr_vad_sensitivity_entity_id", "select.test_vad"),
):
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/get_configuration",
"entity_id": ENTITY_ID,
}
)
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] == {
"active_wake_words": ["1234"],
"available_wake_words": [
{"id": "1234", "trained_languages": ["en"], "wake_word": "okay nabu"},
{"id": "5678", "trained_languages": ["en"], "wake_word": "hey jarvis"},
],
"max_active_wake_words": 1,
"pipeline_entity_id": "select.test_pipeline",
"vad_entity_id": "select.test_vad",
}
async def test_set_wake_words(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test setting active wake words."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/set_wake_words",
"entity_id": ENTITY_ID,
"wake_word_ids": ["5678"],
}
)
msg = await ws_client.receive_json()
assert msg["success"]
# Verify change
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/get_configuration",
"entity_id": ENTITY_ID,
}
)
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"].get("active_wake_words") == ["5678"]
async def test_set_wake_words_exceed_maximum(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test setting too many active wake words."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/set_wake_words",
"entity_id": ENTITY_ID,
"wake_word_ids": ["1234", "5678"], # max of 1
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_supported",
"message": "Maximum number of active wake words is 1",
}
async def test_set_wake_words_bad_id(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test setting active wake words with a bad id."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/set_wake_words",
"entity_id": ENTITY_ID,
"wake_word_ids": ["abcd"], # not an available id
}
)
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_supported",
"message": "Wake word id is not supported: abcd",
}