mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Add assist satellite entity component (#125351)
* Add assist_satellite * Update homeassistant/components/assist_satellite/manifest.json Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Update homeassistant/components/assist_satellite/manifest.json Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Add platform constant * Update Dockerfile * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Address comments * Update docstring async_internal_announce * Update CODEOWNERS --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
c3921f2112
commit
60b0f0dc53
@ -14,6 +14,7 @@ core: &core
|
|||||||
base_platforms: &base_platforms
|
base_platforms: &base_platforms
|
||||||
- homeassistant/components/air_quality/**
|
- homeassistant/components/air_quality/**
|
||||||
- homeassistant/components/alarm_control_panel/**
|
- homeassistant/components/alarm_control_panel/**
|
||||||
|
- homeassistant/components/assist_satellite/**
|
||||||
- homeassistant/components/binary_sensor/**
|
- homeassistant/components/binary_sensor/**
|
||||||
- homeassistant/components/button/**
|
- homeassistant/components/button/**
|
||||||
- homeassistant/components/calendar/**
|
- homeassistant/components/calendar/**
|
||||||
|
@ -95,6 +95,7 @@ homeassistant.components.aruba.*
|
|||||||
homeassistant.components.arwn.*
|
homeassistant.components.arwn.*
|
||||||
homeassistant.components.aseko_pool_live.*
|
homeassistant.components.aseko_pool_live.*
|
||||||
homeassistant.components.assist_pipeline.*
|
homeassistant.components.assist_pipeline.*
|
||||||
|
homeassistant.components.assist_satellite.*
|
||||||
homeassistant.components.asuswrt.*
|
homeassistant.components.asuswrt.*
|
||||||
homeassistant.components.autarco.*
|
homeassistant.components.autarco.*
|
||||||
homeassistant.components.auth.*
|
homeassistant.components.auth.*
|
||||||
|
@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
|
|||||||
/tests/components/aseko_pool_live/ @milanmeu
|
/tests/components/aseko_pool_live/ @milanmeu
|
||||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||||
|
/homeassistant/components/assist_satellite/ @home-assistant/core @synesthesiam
|
||||||
|
/tests/components/assist_satellite/ @home-assistant/core @synesthesiam
|
||||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||||
/homeassistant/components/atag/ @MatsNL
|
/homeassistant/components/atag/ @MatsNL
|
||||||
|
@ -17,6 +17,7 @@ from .const import (
|
|||||||
DATA_LAST_WAKE_UP,
|
DATA_LAST_WAKE_UP,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
EVENT_RECORDING,
|
EVENT_RECORDING,
|
||||||
|
OPTION_PREFERRED,
|
||||||
SAMPLE_CHANNELS,
|
SAMPLE_CHANNELS,
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
SAMPLE_WIDTH,
|
SAMPLE_WIDTH,
|
||||||
@ -58,6 +59,7 @@ __all__ = (
|
|||||||
"PipelineNotFound",
|
"PipelineNotFound",
|
||||||
"WakeWordSettings",
|
"WakeWordSettings",
|
||||||
"EVENT_RECORDING",
|
"EVENT_RECORDING",
|
||||||
|
"OPTION_PREFERRED",
|
||||||
"SAMPLES_PER_CHUNK",
|
"SAMPLES_PER_CHUNK",
|
||||||
"SAMPLE_RATE",
|
"SAMPLE_RATE",
|
||||||
"SAMPLE_WIDTH",
|
"SAMPLE_WIDTH",
|
||||||
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
|||||||
MS_PER_CHUNK = 10
|
MS_PER_CHUNK = 10
|
||||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||||
|
|
||||||
|
OPTION_PREFERRED = "preferred"
|
||||||
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
|||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN, OPTION_PREFERRED
|
||||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||||
from .vad import VadSensitivity
|
from .vad import VadSensitivity
|
||||||
|
|
||||||
OPTION_PREFERRED = "preferred"
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def get_chosen_pipeline(
|
def get_chosen_pipeline(
|
||||||
|
65
homeassistant/components/assist_satellite/__init__.py
Normal file
65
homeassistant/components/assist_satellite/__init__.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""Base class for assist satellite entities."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
|
from .const import DOMAIN, AssistSatelliteEntityFeature
|
||||||
|
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||||
|
from .errors import SatelliteBusyError
|
||||||
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DOMAIN",
|
||||||
|
"AssistSatelliteEntity",
|
||||||
|
"AssistSatelliteEntityDescription",
|
||||||
|
"AssistSatelliteEntityFeature",
|
||||||
|
"SatelliteBusyError",
|
||||||
|
]
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||||
|
_LOGGER, DOMAIN, hass
|
||||||
|
)
|
||||||
|
await component.async_setup(config)
|
||||||
|
|
||||||
|
component.async_register_entity_service(
|
||||||
|
"announce",
|
||||||
|
vol.All(
|
||||||
|
cv.make_entity_service_schema(
|
||||||
|
{
|
||||||
|
vol.Optional("message"): str,
|
||||||
|
vol.Optional("media_id"): str,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
cv.has_at_least_one_key("message", "media_id"),
|
||||||
|
),
|
||||||
|
"async_internal_announce",
|
||||||
|
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||||
|
)
|
||||||
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
"""Set up a config entry."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
return await component.async_setup_entry(entry)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
"""Unload a config entry."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
return await component.async_unload_entry(entry)
|
12
homeassistant/components/assist_satellite/const.py
Normal file
12
homeassistant/components/assist_satellite/const.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
"""Constants for assist satellite."""
|
||||||
|
|
||||||
|
from enum import IntFlag
|
||||||
|
|
||||||
|
DOMAIN = "assist_satellite"
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteEntityFeature(IntFlag):
|
||||||
|
"""Supported features of Assist satellite entity."""
|
||||||
|
|
||||||
|
ANNOUNCE = 1
|
||||||
|
"""Device supports remotely triggered announcements."""
|
332
homeassistant/components/assist_satellite/entity.py
Normal file
332
homeassistant/components/assist_satellite/entity.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
"""Assist satellite entity."""
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from enum import StrEnum
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Final, final
|
||||||
|
|
||||||
|
from homeassistant.components import media_source, stt, tts
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
OPTION_PREFERRED,
|
||||||
|
AudioSettings,
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
async_get_pipeline,
|
||||||
|
async_get_pipelines,
|
||||||
|
async_pipeline_from_audio_stream,
|
||||||
|
vad,
|
||||||
|
)
|
||||||
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
|
from homeassistant.components.tts.media_source import (
|
||||||
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
|
)
|
||||||
|
from homeassistant.core import Context, callback
|
||||||
|
from homeassistant.helpers import entity
|
||||||
|
from homeassistant.helpers.entity import EntityDescription
|
||||||
|
from homeassistant.util import ulid
|
||||||
|
|
||||||
|
from .const import AssistSatelliteEntityFeature
|
||||||
|
from .errors import AssistSatelliteError, SatelliteBusyError
|
||||||
|
|
||||||
|
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteState(StrEnum):
|
||||||
|
"""Valid states of an Assist satellite entity."""
|
||||||
|
|
||||||
|
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||||
|
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||||
|
|
||||||
|
LISTENING_COMMAND = "listening_command"
|
||||||
|
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||||
|
|
||||||
|
PROCESSING = "processing"
|
||||||
|
"""Home Assistant is processing the voice command."""
|
||||||
|
|
||||||
|
RESPONDING = "responding"
|
||||||
|
"""Device is speaking the response."""
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||||
|
"""A class that describes Assist satellite entities."""
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteEntity(entity.Entity):
|
||||||
|
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||||
|
|
||||||
|
entity_description: AssistSatelliteEntityDescription
|
||||||
|
_attr_should_poll = False
|
||||||
|
_attr_supported_features = AssistSatelliteEntityFeature(0)
|
||||||
|
_attr_pipeline_entity_id: str | None = None
|
||||||
|
_attr_vad_sensitivity_entity_id: str | None = None
|
||||||
|
|
||||||
|
_conversation_id: str | None = None
|
||||||
|
_conversation_id_time: float | None = None
|
||||||
|
|
||||||
|
_run_has_tts: bool = False
|
||||||
|
_is_announcing = False
|
||||||
|
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||||
|
|
||||||
|
__assist_satellite_state: AssistSatelliteState | None = None
|
||||||
|
|
||||||
|
@final
|
||||||
|
@property
|
||||||
|
def state(self) -> str | None:
|
||||||
|
"""Return state of the entity."""
|
||||||
|
return self.__assist_satellite_state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pipeline_entity_id(self) -> str | None:
|
||||||
|
"""Entity ID of the pipeline to use for the next conversation."""
|
||||||
|
return self._attr_pipeline_entity_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vad_sensitivity_entity_id(self) -> str | None:
|
||||||
|
"""Entity ID of the VAD sensitivity to use for the next conversation."""
|
||||||
|
return self._attr_vad_sensitivity_entity_id
|
||||||
|
|
||||||
|
async def async_intercept_wake_word(self) -> str | None:
|
||||||
|
"""Intercept the next wake word from the satellite.
|
||||||
|
|
||||||
|
Returns the detected wake word phrase or None.
|
||||||
|
"""
|
||||||
|
if self._wake_word_intercept_future is not None:
|
||||||
|
raise SatelliteBusyError("Wake word interception already in progress")
|
||||||
|
|
||||||
|
# Will cause next wake word to be intercepted in
|
||||||
|
# async_accept_pipeline_from_satellite
|
||||||
|
self._wake_word_intercept_future = asyncio.Future()
|
||||||
|
|
||||||
|
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._wake_word_intercept_future
|
||||||
|
finally:
|
||||||
|
self._wake_word_intercept_future = None
|
||||||
|
|
||||||
|
async def async_internal_announce(
|
||||||
|
self,
|
||||||
|
message: str | None = None,
|
||||||
|
media_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Play and show an announcement on the satellite.
|
||||||
|
|
||||||
|
If media_id is not provided, message is synthesized to
|
||||||
|
audio with the selected pipeline.
|
||||||
|
|
||||||
|
If media_id is provided, it is played directly. It is possible
|
||||||
|
to omit the message and the satellite will not show any text.
|
||||||
|
|
||||||
|
Calls async_announce with message and media id.
|
||||||
|
"""
|
||||||
|
if message is None:
|
||||||
|
message = ""
|
||||||
|
|
||||||
|
if not media_id:
|
||||||
|
# Synthesize audio and get URL
|
||||||
|
pipeline_id = self._resolve_pipeline()
|
||||||
|
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||||
|
|
||||||
|
tts_options: dict[str, Any] = {}
|
||||||
|
if pipeline.tts_voice is not None:
|
||||||
|
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||||
|
|
||||||
|
media_id = tts_generate_media_source_id(
|
||||||
|
self.hass,
|
||||||
|
message,
|
||||||
|
engine=pipeline.tts_engine,
|
||||||
|
language=pipeline.tts_language,
|
||||||
|
options=tts_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
if media_source.is_media_source_id(media_id):
|
||||||
|
media = await media_source.async_resolve_media(
|
||||||
|
self.hass,
|
||||||
|
media_id,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
media_id = media.url
|
||||||
|
|
||||||
|
# Resolve to full URL
|
||||||
|
media_id = async_process_play_media_url(self.hass, media_id)
|
||||||
|
|
||||||
|
if self._is_announcing:
|
||||||
|
raise SatelliteBusyError
|
||||||
|
|
||||||
|
self._is_announcing = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Block until announcement is finished
|
||||||
|
await self.async_announce(message, media_id)
|
||||||
|
finally:
|
||||||
|
self._is_announcing = False
|
||||||
|
|
||||||
|
async def async_announce(self, message: str, media_id: str) -> None:
|
||||||
|
"""Announce media on the satellite.
|
||||||
|
|
||||||
|
Should block until the announcement is done playing.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_accept_pipeline_from_satellite(
|
||||||
|
self,
|
||||||
|
audio_stream: AsyncIterable[bytes],
|
||||||
|
start_stage: PipelineStage = PipelineStage.STT,
|
||||||
|
end_stage: PipelineStage = PipelineStage.TTS,
|
||||||
|
wake_word_phrase: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||||
|
if self._wake_word_intercept_future and start_stage in (
|
||||||
|
PipelineStage.WAKE_WORD,
|
||||||
|
PipelineStage.STT,
|
||||||
|
):
|
||||||
|
if start_stage == PipelineStage.WAKE_WORD:
|
||||||
|
self._wake_word_intercept_future.set_exception(
|
||||||
|
AssistSatelliteError(
|
||||||
|
"Only on-device wake words currently supported"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Intercepting wake word and immediately end pipeline
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Intercepted wake word: %s (entity_id=%s)",
|
||||||
|
wake_word_phrase,
|
||||||
|
self.entity_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if wake_word_phrase is None:
|
||||||
|
self._wake_word_intercept_future.set_exception(
|
||||||
|
AssistSatelliteError("No wake word phrase provided")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._wake_word_intercept_future.set_result(wake_word_phrase)
|
||||||
|
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||||
|
return
|
||||||
|
|
||||||
|
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||||
|
|
||||||
|
# Refresh context if necessary
|
||||||
|
if (
|
||||||
|
(self._context is None)
|
||||||
|
or (self._context_set is None)
|
||||||
|
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||||
|
):
|
||||||
|
self.async_set_context(Context())
|
||||||
|
|
||||||
|
assert self._context is not None
|
||||||
|
|
||||||
|
# Reset conversation id if necessary
|
||||||
|
if (self._conversation_id_time is None) or (
|
||||||
|
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||||
|
):
|
||||||
|
self._conversation_id = None
|
||||||
|
|
||||||
|
if self._conversation_id is None:
|
||||||
|
self._conversation_id = ulid.ulid()
|
||||||
|
|
||||||
|
# Update timeout
|
||||||
|
self._conversation_id_time = time.monotonic()
|
||||||
|
|
||||||
|
# Set entity state based on pipeline events
|
||||||
|
self._run_has_tts = False
|
||||||
|
|
||||||
|
await async_pipeline_from_audio_stream(
|
||||||
|
self.hass,
|
||||||
|
context=self._context,
|
||||||
|
event_callback=self._internal_on_pipeline_event,
|
||||||
|
stt_metadata=stt.SpeechMetadata(
|
||||||
|
language="", # set in async_pipeline_from_audio_stream
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
stt_stream=audio_stream,
|
||||||
|
pipeline_id=self._resolve_pipeline(),
|
||||||
|
conversation_id=self._conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
|
tts_audio_output="wav",
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
audio_settings=AudioSettings(
|
||||||
|
silence_seconds=self._resolve_vad_sensitivity()
|
||||||
|
),
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Handle pipeline events."""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Set state based on pipeline stage."""
|
||||||
|
if event.type is PipelineEventType.WAKE_WORD_START:
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
elif event.type is PipelineEventType.STT_START:
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||||
|
elif event.type is PipelineEventType.INTENT_START:
|
||||||
|
self._set_state(AssistSatelliteState.PROCESSING)
|
||||||
|
elif event.type is PipelineEventType.TTS_START:
|
||||||
|
# Wait until tts_response_finished is called to return to waiting state
|
||||||
|
self._run_has_tts = True
|
||||||
|
self._set_state(AssistSatelliteState.RESPONDING)
|
||||||
|
elif event.type is PipelineEventType.RUN_END:
|
||||||
|
if not self._run_has_tts:
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
|
||||||
|
self.on_pipeline_event(event)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _set_state(self, state: AssistSatelliteState) -> None:
|
||||||
|
"""Set the entity's state."""
|
||||||
|
self.__assist_satellite_state = state
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def tts_response_finished(self) -> None:
|
||||||
|
"""Tell entity that the text-to-speech response has finished playing."""
|
||||||
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _resolve_pipeline(self) -> str | None:
|
||||||
|
"""Resolve pipeline from select entity to id.
|
||||||
|
|
||||||
|
Return None to make async_get_pipeline look up the preferred pipeline.
|
||||||
|
"""
|
||||||
|
if not (pipeline_entity_id := self.pipeline_entity_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
|
||||||
|
raise RuntimeError("Pipeline entity not found")
|
||||||
|
|
||||||
|
if pipeline_entity_state.state != OPTION_PREFERRED:
|
||||||
|
# Resolve pipeline by name
|
||||||
|
for pipeline in async_get_pipelines(self.hass):
|
||||||
|
if pipeline.name == pipeline_entity_state.state:
|
||||||
|
return pipeline.id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _resolve_vad_sensitivity(self) -> float:
|
||||||
|
"""Resolve VAD sensitivity from select entity to enum."""
|
||||||
|
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||||
|
|
||||||
|
if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
|
||||||
|
if (
|
||||||
|
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
|
||||||
|
) is None:
|
||||||
|
raise RuntimeError("VAD sensitivity entity not found")
|
||||||
|
|
||||||
|
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||||
|
|
||||||
|
return vad.VadSensitivity.to_seconds(vad_sensitivity)
|
11
homeassistant/components/assist_satellite/errors.py
Normal file
11
homeassistant/components/assist_satellite/errors.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
"""Errors for assist satellite."""
|
||||||
|
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
|
||||||
|
class AssistSatelliteError(HomeAssistantError):
|
||||||
|
"""Base class for assist satellite errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class SatelliteBusyError(AssistSatelliteError):
|
||||||
|
"""Satellite is busy and cannot handle the request."""
|
12
homeassistant/components/assist_satellite/icons.json
Normal file
12
homeassistant/components/assist_satellite/icons.json
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"entity_component": {
|
||||||
|
"_": {
|
||||||
|
"default": "mdi:account-voice"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"announce": {
|
||||||
|
"service": "mdi:bullhorn"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"domain": "assist_satellite",
|
||||||
|
"name": "Assist Satellite",
|
||||||
|
"codeowners": ["@home-assistant/core", "@synesthesiam"],
|
||||||
|
"dependencies": ["assist_pipeline", "stt", "tts"],
|
||||||
|
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||||
|
"integration_type": "entity",
|
||||||
|
"quality_scale": "internal"
|
||||||
|
}
|
16
homeassistant/components/assist_satellite/services.yaml
Normal file
16
homeassistant/components/assist_satellite/services.yaml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
announce:
|
||||||
|
target:
|
||||||
|
entity:
|
||||||
|
domain: assist_satellite
|
||||||
|
supported_features:
|
||||||
|
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
fields:
|
||||||
|
message:
|
||||||
|
required: false
|
||||||
|
example: "Time to wake up!"
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
media_id:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
text:
|
30
homeassistant/components/assist_satellite/strings.json
Normal file
30
homeassistant/components/assist_satellite/strings.json
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"title": "Assist satellite",
|
||||||
|
"entity_component": {
|
||||||
|
"_": {
|
||||||
|
"name": "Assist satellite",
|
||||||
|
"state": {
|
||||||
|
"listening_wake_word": "Wake word",
|
||||||
|
"listening_command": "Voice command",
|
||||||
|
"responding": "Responding",
|
||||||
|
"processing": "Processing"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"announce": {
|
||||||
|
"name": "Announce",
|
||||||
|
"description": "Let the satellite announce a message.",
|
||||||
|
"fields": {
|
||||||
|
"message": {
|
||||||
|
"name": "Message",
|
||||||
|
"description": "The message to announce."
|
||||||
|
},
|
||||||
|
"media_id": {
|
||||||
|
"name": "Media ID",
|
||||||
|
"description": "The media ID to announce instead of using text-to-speech."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
46
homeassistant/components/assist_satellite/websocket_api.py
Normal file
46
homeassistant/components/assist_satellite/websocket_api.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""Assist satellite Websocket API."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import websocket_api
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import AssistSatelliteEntity
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
|
"""Register the websocket API."""
|
||||||
|
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_satellite/intercept_wake_word",
|
||||||
|
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.require_admin
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_intercept_wake_word(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.connection.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Intercept the next wake word from a satellite."""
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||||
|
satellite = component.get_entity(msg["entity_id"])
|
||||||
|
if satellite is None:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||||
|
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
|||||||
|
|
||||||
AIR_QUALITY = "air_quality"
|
AIR_QUALITY = "air_quality"
|
||||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||||
|
ASSIST_SATELLITE = "assist_satellite"
|
||||||
BINARY_SENSOR = "binary_sensor"
|
BINARY_SENSOR = "binary_sensor"
|
||||||
BUTTON = "button"
|
BUTTON = "button"
|
||||||
CALENDAR = "calendar"
|
CALENDAR = "calendar"
|
||||||
|
10
mypy.ini
10
mypy.ini
@ -705,6 +705,16 @@ disallow_untyped_defs = true
|
|||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.components.assist_satellite.*]
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
disallow_subclassing_any = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unreachable = true
|
||||||
|
|
||||||
[mypy-homeassistant.components.asuswrt.*]
|
[mypy-homeassistant.components.asuswrt.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
|
@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.2.27,source=/uv,target=/bin/uv \
|
|||||||
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
|
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
|
||||||
-r /usr/src/homeassistant/requirements.txt \
|
-r /usr/src/homeassistant/requirements.txt \
|
||||||
stdlib-list==0.10.0 pipdeptree==2.23.1 tqdm==4.66.4 ruff==0.6.2 \
|
stdlib-list==0.10.0 pipdeptree==2.23.1 tqdm==4.66.4 ruff==0.6.2 \
|
||||||
PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0
|
PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
|
||||||
|
|
||||||
LABEL "name"="hassfest"
|
LABEL "name"="hassfest"
|
||||||
LABEL "maintainer"="Home Assistant <hello@home-assistant.io>"
|
LABEL "maintainer"="Home Assistant <hello@home-assistant.io>"
|
||||||
|
3
tests/components/assist_satellite/__init__.py
Normal file
3
tests/components/assist_satellite/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""Tests for Assist Satellite."""
|
||||||
|
|
||||||
|
ENTITY_ID = "assist_satellite.test_entity"
|
107
tests/components/assist_satellite/conftest.py
Normal file
107
tests/components/assist_satellite/conftest.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
"""Test helpers for Assist Satellite."""
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||||
|
from homeassistant.components.assist_satellite import (
|
||||||
|
DOMAIN as AS_DOMAIN,
|
||||||
|
AssistSatelliteEntity,
|
||||||
|
AssistSatelliteEntityFeature,
|
||||||
|
)
|
||||||
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import (
|
||||||
|
MockConfigEntry,
|
||||||
|
MockModule,
|
||||||
|
mock_config_flow,
|
||||||
|
mock_integration,
|
||||||
|
mock_platform,
|
||||||
|
setup_test_component_platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_DOMAIN = "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_tts(mock_tts_cache_dir: pathlib.Path) -> None:
|
||||||
|
"""Mock TTS cache dir fixture."""
|
||||||
|
|
||||||
|
|
||||||
|
class MockAssistSatellite(AssistSatelliteEntity):
|
||||||
|
"""Mock Assist Satellite Entity."""
|
||||||
|
|
||||||
|
_attr_name = "Test Entity"
|
||||||
|
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the mock entity."""
|
||||||
|
self.events = []
|
||||||
|
self.announcements = []
|
||||||
|
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Handle pipeline events."""
|
||||||
|
self.events.append(event)
|
||||||
|
|
||||||
|
async def async_announce(self, message: str, media_id: str) -> None:
|
||||||
|
"""Announce media on a device."""
|
||||||
|
self.announcements.append((message, media_id))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def entity() -> MockAssistSatellite:
|
||||||
|
"""Mock Assist Satellite Entity."""
|
||||||
|
return MockAssistSatellite()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||||
|
"""Mock config entry."""
|
||||||
|
entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_components(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize components."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
async def async_setup_entry_init(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Set up test config entry."""
|
||||||
|
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def async_unload_entry_init(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Unload test config entry."""
|
||||||
|
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
|
||||||
|
return True
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
TEST_DOMAIN,
|
||||||
|
async_setup_entry=async_setup_entry_init,
|
||||||
|
async_unload_entry=async_unload_entry_init,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
setup_test_component_platform(hass, AS_DOMAIN, [entity], from_config_entry=True)
|
||||||
|
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
|
||||||
|
|
||||||
|
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||||
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
return config_entry
|
332
tests/components/assist_satellite/test_entity.py
Normal file
332
tests/components/assist_satellite/test_entity.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
"""Test the Assist Satellite entity."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
OPTION_PREFERRED,
|
||||||
|
AudioSettings,
|
||||||
|
Pipeline,
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
async_get_pipeline,
|
||||||
|
async_update_pipeline,
|
||||||
|
vad,
|
||||||
|
)
|
||||||
|
from homeassistant.components.assist_satellite import SatelliteBusyError
|
||||||
|
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
||||||
|
from homeassistant.components.media_source import PlayMedia
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
|
||||||
|
from . import ENTITY_ID
|
||||||
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entity_state(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test entity state represent events."""
|
||||||
|
|
||||||
|
state = hass.states.get(ENTITY_ID)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
|
context = Context()
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
entity.async_set_context(context)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||||
|
) as mock_start_pipeline:
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
|
||||||
|
assert mock_start_pipeline.called
|
||||||
|
kwargs = mock_start_pipeline.call_args[1]
|
||||||
|
assert kwargs["context"] is context
|
||||||
|
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
|
||||||
|
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
|
||||||
|
language="",
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
)
|
||||||
|
assert kwargs["stt_stream"] is audio_stream
|
||||||
|
assert kwargs["pipeline_id"] is None
|
||||||
|
assert kwargs["device_id"] is None
|
||||||
|
assert kwargs["tts_audio_output"] == "wav"
|
||||||
|
assert kwargs["wake_word_phrase"] is None
|
||||||
|
assert kwargs["audio_settings"] == AudioSettings(
|
||||||
|
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||||
|
)
|
||||||
|
assert kwargs["start_stage"] == PipelineStage.STT
|
||||||
|
assert kwargs["end_stage"] == PipelineStage.TTS
|
||||||
|
|
||||||
|
for event_type, expected_state in (
|
||||||
|
(PipelineEventType.RUN_START, STATE_UNKNOWN),
|
||||||
|
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||||
|
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||||
|
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||||
|
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||||
|
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
|
||||||
|
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
|
||||||
|
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
|
||||||
|
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
|
||||||
|
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
|
||||||
|
):
|
||||||
|
kwargs["event_callback"](PipelineEvent(event_type, {}))
|
||||||
|
state = hass.states.get(ENTITY_ID)
|
||||||
|
assert state.state == expected_state, event_type
|
||||||
|
|
||||||
|
entity.tts_response_finished()
|
||||||
|
state = hass.states.get(ENTITY_ID)
|
||||||
|
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("service_data", "expected_params"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{"message": "Hello"},
|
||||||
|
("Hello", "https://www.home-assistant.io/resolved.mp3"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"message": "Hello",
|
||||||
|
"media_id": "http://example.com/bla.mp3",
|
||||||
|
},
|
||||||
|
("Hello", "http://example.com/bla.mp3"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"media_id": "http://example.com/bla.mp3"},
|
||||||
|
("", "http://example.com/bla.mp3"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_announce(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
service_data: dict,
|
||||||
|
expected_params: tuple[str, str],
|
||||||
|
) -> None:
|
||||||
|
"""Test announcing on a device."""
|
||||||
|
await async_update_pipeline(
|
||||||
|
hass,
|
||||||
|
async_get_pipeline(hass),
|
||||||
|
tts_engine="tts.mock_entity",
|
||||||
|
tts_language="en",
|
||||||
|
tts_voice="test-voice",
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||||
|
return_value="media-source://bla",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
return_value=PlayMedia(
|
||||||
|
url="https://www.home-assistant.io/resolved.mp3",
|
||||||
|
mime_type="audio/mp3",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"assist_satellite",
|
||||||
|
"announce",
|
||||||
|
service_data,
|
||||||
|
target={"entity_id": "assist_satellite.test_entity"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.announcements[0] == expected_params
|
||||||
|
|
||||||
|
|
||||||
|
async def test_announce_busy(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
) -> None:
|
||||||
|
"""Test that announcing while an announcement is in progress raises an error."""
|
||||||
|
media_id = "https://www.home-assistant.io/resolved.mp3"
|
||||||
|
announce_started = asyncio.Event()
|
||||||
|
got_error = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_announce(message, media_id):
|
||||||
|
announce_started.set()
|
||||||
|
|
||||||
|
# Block so we can do another announcement
|
||||||
|
await got_error.wait()
|
||||||
|
|
||||||
|
with patch.object(entity, "async_announce", new=async_announce):
|
||||||
|
announce_task = asyncio.create_task(
|
||||||
|
entity.async_internal_announce(media_id=media_id)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await announce_started.wait()
|
||||||
|
|
||||||
|
# Try to do a second announcement
|
||||||
|
with pytest.raises(SatelliteBusyError):
|
||||||
|
await entity.async_internal_announce(media_id=media_id)
|
||||||
|
|
||||||
|
# Avoid lingering task
|
||||||
|
got_error.set()
|
||||||
|
await announce_task
|
||||||
|
|
||||||
|
|
||||||
|
async def test_context_refresh(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test that the context will be automatically refreshed."""
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
# Remove context
|
||||||
|
entity._context = None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||||
|
):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
|
||||||
|
# Context should have been refreshed
|
||||||
|
assert entity._context is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_entity(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test getting pipeline from an entity."""
|
||||||
|
audio_stream = object()
|
||||||
|
pipeline = Pipeline(
|
||||||
|
conversation_engine="test",
|
||||||
|
conversation_language="en",
|
||||||
|
language="en",
|
||||||
|
name="test-pipeline",
|
||||||
|
stt_engine=None,
|
||||||
|
stt_language=None,
|
||||||
|
tts_engine=None,
|
||||||
|
tts_language=None,
|
||||||
|
tts_voice=None,
|
||||||
|
wake_word_entity=None,
|
||||||
|
wake_word_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline_entity_id = "select.pipeline"
|
||||||
|
hass.states.async_set(pipeline_entity_id, pipeline.name)
|
||||||
|
entity._attr_pipeline_entity_id = pipeline_entity_id
|
||||||
|
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs):
|
||||||
|
assert pipeline_id == pipeline.id
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_get_pipelines",
|
||||||
|
return_value=[pipeline],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_entity_preferred(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test getting pipeline from an entity with a preferred state."""
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
pipeline_entity_id = "select.pipeline"
|
||||||
|
hass.states.async_set(pipeline_entity_id, OPTION_PREFERRED)
|
||||||
|
entity._attr_pipeline_entity_id = pipeline_entity_id
|
||||||
|
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs):
|
||||||
|
# Preferred pipeline
|
||||||
|
assert pipeline_id is None
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vad_sensitivity_entity(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test getting vad sensitivity from an entity."""
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
vad_sensitivity_entity_id = "select.vad_sensitivity"
|
||||||
|
hass.states.async_set(vad_sensitivity_entity_id, vad.VadSensitivity.AGGRESSIVE)
|
||||||
|
entity._attr_vad_sensitivity_entity_id = vad_sensitivity_entity_id
|
||||||
|
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
|
||||||
|
*args, audio_settings: AudioSettings, **kwargs
|
||||||
|
):
|
||||||
|
# Verify vad sensitivity
|
||||||
|
assert audio_settings.silence_seconds == vad.VadSensitivity.to_seconds(
|
||||||
|
vad.VadSensitivity.AGGRESSIVE
|
||||||
|
)
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_entity_not_found(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test that setting the pipeline entity id to a non-existent entity raises an error."""
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
# Set to an entity that doesn't exist
|
||||||
|
entity._attr_pipeline_entity_id = "select.pipeline"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vad_sensitivity_entity_not_found(
|
||||||
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
|
) -> None:
|
||||||
|
"""Test that setting the vad sensitivity entity id to a non-existent entity raises an error."""
|
||||||
|
audio_stream = object()
|
||||||
|
|
||||||
|
# Set to an entity that doesn't exist
|
||||||
|
entity._attr_vad_sensitivity_entity_id = "select.vad_sensitivity"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
192
tests/components/assist_satellite/test_websocket_api.py
Normal file
192
tests/components/assist_satellite/test_websocket_api.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""Test WebSocket API."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline import PipelineStage
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from . import ENTITY_ID
|
||||||
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
from tests.common import MockUser
|
||||||
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test intercepting a wake word."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(),
|
||||||
|
start_stage=PipelineStage.STT,
|
||||||
|
wake_word_phrase="ok, nabu",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert response["success"]
|
||||||
|
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word_requires_on_device_wake_word(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test intercepting a wake word fails if detection happens in HA."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(),
|
||||||
|
# Emulate wake word processing in Home Assistant
|
||||||
|
start_stage=PipelineStage.WAKE_WORD,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "home_assistant_error",
|
||||||
|
"message": "Only on-device wake words currently supported",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word_requires_wake_word_phrase(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test intercepting a wake word fails if detection happens in HA."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(),
|
||||||
|
start_stage=PipelineStage.STT,
|
||||||
|
# We are not passing wake word phrase
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "home_assistant_error",
|
||||||
|
"message": "No wake word phrase provided",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word_require_admin(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
hass_admin_user: MockUser,
|
||||||
|
) -> None:
|
||||||
|
"""Test intercepting a wake word requires admin access."""
|
||||||
|
# Remove admin permission and verify we're not allowed
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "unauthorized",
|
||||||
|
"message": "Unauthorized",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word_invalid_satellite(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test intercepting a wake word requires admin access."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": "assist_satellite.invalid",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "Entity not found",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word_twice(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test intercepting a wake word requires admin access."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert not response["success"]
|
||||||
|
assert response["error"] == {
|
||||||
|
"code": "home_assistant_error",
|
||||||
|
"message": "Wake word interception already in progress",
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user