mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 18:27:09 +00:00
Add Assist satellite configuration (#126063)
* Basic implementation * Add websocket commands * Clean up * Add callback to other signatures * Remove unused constant * Re-add callback * Add callback to test
This commit is contained in:
parent
738818aa7a
commit
dde989685c
@ -11,15 +11,22 @@ from homeassistant.helpers.entity_component import EntityComponent
|
|||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN, AssistSatelliteEntityFeature
|
from .const import DOMAIN, AssistSatelliteEntityFeature
|
||||||
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
from .entity import (
|
||||||
|
AssistSatelliteConfiguration,
|
||||||
|
AssistSatelliteEntity,
|
||||||
|
AssistSatelliteEntityDescription,
|
||||||
|
AssistSatelliteWakeWord,
|
||||||
|
)
|
||||||
from .errors import SatelliteBusyError
|
from .errors import SatelliteBusyError
|
||||||
from .websocket_api import async_register_websocket_api
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"AssistSatelliteEntity",
|
"AssistSatelliteEntity",
|
||||||
|
"AssistSatelliteConfiguration",
|
||||||
"AssistSatelliteEntityDescription",
|
"AssistSatelliteEntityDescription",
|
||||||
"AssistSatelliteEntityFeature",
|
"AssistSatelliteEntityFeature",
|
||||||
|
"AssistSatelliteWakeWord",
|
||||||
"SatelliteBusyError",
|
"SatelliteBusyError",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from abc import abstractmethod
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@ -57,6 +58,34 @@ class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True)
|
|||||||
"""A class that describes Assist satellite entities."""
|
"""A class that describes Assist satellite entities."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AssistSatelliteWakeWord:
|
||||||
|
"""Available wake word model."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""Unique id for wake word model."""
|
||||||
|
|
||||||
|
wake_word: str
|
||||||
|
"""Wake word phrase."""
|
||||||
|
|
||||||
|
trained_languages: list[str]
|
||||||
|
"""List of languages that the wake word was trained on."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistSatelliteConfiguration:
|
||||||
|
"""Satellite configuration."""
|
||||||
|
|
||||||
|
available_wake_words: list[AssistSatelliteWakeWord]
|
||||||
|
"""List of available available wake word models."""
|
||||||
|
|
||||||
|
active_wake_words: list[str]
|
||||||
|
"""List of active wake word ids."""
|
||||||
|
|
||||||
|
max_active_wake_words: int
|
||||||
|
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""
|
||||||
|
|
||||||
|
|
||||||
class AssistSatelliteEntity(entity.Entity):
|
class AssistSatelliteEntity(entity.Entity):
|
||||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||||
|
|
||||||
@ -98,6 +127,17 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
"""Options passed for text-to-speech."""
|
"""Options passed for text-to-speech."""
|
||||||
return self._attr_tts_options
|
return self._attr_tts_options
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@abstractmethod
|
||||||
|
def async_get_configuration(self) -> AssistSatelliteConfiguration:
|
||||||
|
"""Get the current satellite configuration."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def async_set_configuration(
|
||||||
|
self, config: AssistSatelliteConfiguration
|
||||||
|
) -> None:
|
||||||
|
"""Set the current satellite configuration."""
|
||||||
|
|
||||||
async def async_intercept_wake_word(self) -> str | None:
|
async def async_intercept_wake_word(self) -> str | None:
|
||||||
"""Intercept the next wake word from the satellite.
|
"""Intercept the next wake word from the satellite.
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Assist satellite Websocket API."""
|
"""Assist satellite Websocket API."""
|
||||||
|
|
||||||
|
from dataclasses import asdict, replace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -18,6 +19,8 @@ from .entity import AssistSatelliteEntity
|
|||||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
"""Register the websocket API."""
|
"""Register the websocket API."""
|
||||||
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||||
|
websocket_api.async_register_command(hass, websocket_get_configuration)
|
||||||
|
websocket_api.async_register_command(hass, websocket_set_wake_words)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -59,3 +62,84 @@ async def websocket_intercept_wake_word(
|
|||||||
task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
|
task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
|
||||||
connection.subscriptions[msg["id"]] = task.cancel
|
connection.subscriptions[msg["id"]] = task.cancel
|
||||||
connection.send_message(websocket_api.result_message(msg["id"]))
|
connection.send_message(websocket_api.result_message(msg["id"]))
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_satellite/get_configuration",
|
||||||
|
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def websocket_get_configuration(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Get the current satellite configuration."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
satellite = component.get_entity(msg["entity_id"])
|
||||||
|
if satellite is None:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
config_dict = asdict(satellite.async_get_configuration())
|
||||||
|
config_dict["pipeline_entity_id"] = satellite.pipeline_entity_id
|
||||||
|
config_dict["vad_entity_id"] = satellite.vad_sensitivity_entity_id
|
||||||
|
|
||||||
|
connection.send_result(msg["id"], config_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_satellite/set_wake_words",
|
||||||
|
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||||
|
vol.Required("wake_word_ids"): [str],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.require_admin
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_set_wake_words(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Set the active wake words for the satellite."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
satellite = component.get_entity(msg["entity_id"])
|
||||||
|
if satellite is None:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
config = satellite.async_get_configuration()
|
||||||
|
|
||||||
|
# Don't set too many active wake words
|
||||||
|
actual_ids = msg["wake_word_ids"]
|
||||||
|
if len(actual_ids) > config.max_active_wake_words:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"],
|
||||||
|
websocket_api.ERR_NOT_SUPPORTED,
|
||||||
|
f"Maximum number of active wake words is {config.max_active_wake_words}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Verify all ids are available
|
||||||
|
available_ids = {ww.id for ww in config.available_wake_words}
|
||||||
|
for ww_id in actual_ids:
|
||||||
|
if ww_id not in available_ids:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"],
|
||||||
|
websocket_api.ERR_NOT_SUPPORTED,
|
||||||
|
f"Wake word id is not supported: {ww_id}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await satellite.async_set_configuration(
|
||||||
|
replace(config, active_wake_words=actual_ids)
|
||||||
|
)
|
||||||
|
connection.send_result(msg["id"])
|
||||||
|
@ -36,7 +36,7 @@ from homeassistant.components.intent import (
|
|||||||
from homeassistant.components.media_player import async_process_play_media_url
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import EntityCategory, Platform
|
from homeassistant.const import EntityCategory, Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
@ -150,6 +150,19 @@ class EsphomeAssistSatellite(
|
|||||||
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_configuration(
|
||||||
|
self,
|
||||||
|
) -> assist_satellite.AssistSatelliteConfiguration:
|
||||||
|
"""Get the current satellite configuration."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_set_configuration(
|
||||||
|
self, config: assist_satellite.AssistSatelliteConfiguration
|
||||||
|
) -> None:
|
||||||
|
"""Set the current satellite configuration."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def async_added_to_hass(self) -> None:
|
||||||
"""Run when entity about to be added to hass."""
|
"""Run when entity about to be added to hass."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
|
@ -20,6 +20,7 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
PipelineNotFound,
|
PipelineNotFound,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_satellite import (
|
from homeassistant.components.assist_satellite import (
|
||||||
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
AssistSatelliteEntityDescription,
|
AssistSatelliteEntityDescription,
|
||||||
)
|
)
|
||||||
@ -141,6 +142,19 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
assert self.voip_device.protocol == self
|
assert self.voip_device.protocol == self
|
||||||
self.voip_device.protocol = None
|
self.voip_device.protocol = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_configuration(
|
||||||
|
self,
|
||||||
|
) -> AssistSatelliteConfiguration:
|
||||||
|
"""Get the current satellite configuration."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_set_configuration(
|
||||||
|
self, config: AssistSatelliteConfiguration
|
||||||
|
) -> None:
|
||||||
|
"""Set the current satellite configuration."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# VoIP
|
# VoIP
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
@ -8,11 +8,13 @@ import pytest
|
|||||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||||
from homeassistant.components.assist_satellite import (
|
from homeassistant.components.assist_satellite import (
|
||||||
DOMAIN as AS_DOMAIN,
|
DOMAIN as AS_DOMAIN,
|
||||||
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
AssistSatelliteEntityFeature,
|
AssistSatelliteEntityFeature,
|
||||||
|
AssistSatelliteWakeWord,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
@ -42,6 +44,20 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||||||
"""Initialize the mock entity."""
|
"""Initialize the mock entity."""
|
||||||
self.events = []
|
self.events = []
|
||||||
self.announcements = []
|
self.announcements = []
|
||||||
|
self.config = AssistSatelliteConfiguration(
|
||||||
|
available_wake_words=[
|
||||||
|
AssistSatelliteWakeWord(
|
||||||
|
id="1234", wake_word="okay nabu", trained_languages=["en"]
|
||||||
|
),
|
||||||
|
AssistSatelliteWakeWord(
|
||||||
|
id="5678",
|
||||||
|
wake_word="hey jarvis",
|
||||||
|
trained_languages=["en"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
active_wake_words=["1234"],
|
||||||
|
max_active_wake_words=1,
|
||||||
|
)
|
||||||
|
|
||||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
"""Handle pipeline events."""
|
"""Handle pipeline events."""
|
||||||
@ -51,6 +67,17 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||||||
"""Announce media on a device."""
|
"""Announce media on a device."""
|
||||||
self.announcements.append((message, media_id))
|
self.announcements.append((message, media_id))
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_configuration(self) -> AssistSatelliteConfiguration:
|
||||||
|
"""Get the current satellite configuration."""
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
async def async_set_configuration(
|
||||||
|
self, config: AssistSatelliteConfiguration
|
||||||
|
) -> None:
|
||||||
|
"""Set the current satellite configuration."""
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def entity() -> MockAssistSatellite:
|
def entity() -> MockAssistSatellite:
|
||||||
|
@ -273,3 +273,115 @@ async def test_intercept_wake_word_unsubscribe(
|
|||||||
|
|
||||||
# Wake word should not be intercepted
|
# Wake word should not be intercepted
|
||||||
mock_pipeline_from_audio_stream.assert_called_once()
|
mock_pipeline_from_audio_stream.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_configuration(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting satellite configuration."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(entity, "_attr_pipeline_entity_id", "select.test_pipeline"),
|
||||||
|
patch.object(entity, "_attr_vad_sensitivity_entity_id", "select.test_vad"),
|
||||||
|
):
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/get_configuration",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"active_wake_words": ["1234"],
|
||||||
|
"available_wake_words": [
|
||||||
|
{"id": "1234", "trained_languages": ["en"], "wake_word": "okay nabu"},
|
||||||
|
{"id": "5678", "trained_languages": ["en"], "wake_word": "hey jarvis"},
|
||||||
|
],
|
||||||
|
"max_active_wake_words": 1,
|
||||||
|
"pipeline_entity_id": "select.test_pipeline",
|
||||||
|
"vad_entity_id": "select.test_vad",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_wake_words(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test setting active wake words."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/set_wake_words",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
"wake_word_ids": ["5678"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# Verify change
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/get_configuration",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"].get("active_wake_words") == ["5678"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_wake_words_exceed_maximum(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test setting too many active wake words."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/set_wake_words",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
"wake_word_ids": ["1234", "5678"], # max of 1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_supported",
|
||||||
|
"message": "Maximum number of active wake words is 1",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_wake_words_bad_id(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test setting active wake words with a bad id."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/set_wake_words",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
"wake_word_ids": ["abcd"], # not an available id
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_supported",
|
||||||
|
"message": "Wake word id is not supported: abcd",
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user