diff --git a/CODEOWNERS b/CODEOWNERS index e8617ad7703..084d83b0da1 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1373,6 +1373,8 @@ build.json @home-assistant/supervisor /tests/components/vulcan/ @Antoni-Czaplicki /homeassistant/components/wake_on_lan/ @ntilley905 /tests/components/wake_on_lan/ @ntilley905 +/homeassistant/components/wake_word/ @home-assistant/core @synesthesiam +/tests/components/wake_word/ @home-assistant/core @synesthesiam /homeassistant/components/wallbox/ @hesselonline /tests/components/wallbox/ @hesselonline /homeassistant/components/waqi/ @andrey-git diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 55b192a730a..c2d25da2162 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -18,6 +18,7 @@ from .pipeline import ( PipelineInput, PipelineRun, PipelineStage, + WakeWordSettings, async_create_default_pipeline, async_get_pipeline, async_get_pipelines, @@ -35,6 +36,7 @@ __all__ = ( "PipelineEvent", "PipelineEventType", "PipelineNotFound", + "WakeWordSettings", ) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) @@ -57,7 +59,10 @@ async def async_pipeline_from_audio_stream( pipeline_id: str | None = None, conversation_id: str | None = None, tts_audio_output: str | None = None, + wake_word_settings: WakeWordSettings | None = None, device_id: str | None = None, + start_stage: PipelineStage = PipelineStage.STT, + end_stage: PipelineStage = PipelineStage.TTS, ) -> None: """Create an audio pipeline from an audio stream. @@ -72,10 +77,11 @@ async def async_pipeline_from_audio_stream( hass, context=context, pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id), - start_stage=PipelineStage.STT, - end_stage=PipelineStage.TTS, + start_stage=start_stage, + end_stage=end_stage, event_callback=event_callback, tts_audio_output=tts_audio_output, + wake_word_settings=wake_word_settings, ), ) await pipeline_input.validate() diff --git a/homeassistant/components/assist_pipeline/error.py b/homeassistant/components/assist_pipeline/error.py index c5ffdcaf2d3..094913424b6 100644 --- a/homeassistant/components/assist_pipeline/error.py +++ b/homeassistant/components/assist_pipeline/error.py @@ -18,6 +18,14 @@ class PipelineNotFound(PipelineError): """Unspecified pipeline picked.""" +class WakeWordDetectionError(PipelineError): + """Error in wake-word-detection portion of pipeline.""" + + +class WakeWordTimeoutError(WakeWordDetectionError): + """Timeout when wake word was not detected.""" + + class SpeechToTextError(PipelineError): """Error in speech-to-text portion of pipeline.""" diff --git a/homeassistant/components/assist_pipeline/manifest.json b/homeassistant/components/assist_pipeline/manifest.json index e97ceae5dec..1db415b29d2 100644 --- a/homeassistant/components/assist_pipeline/manifest.json +++ b/homeassistant/components/assist_pipeline/manifest.json @@ -2,7 +2,7 @@ "domain": "assist_pipeline", "name": "Assist pipeline", "codeowners": ["@balloob", "@synesthesiam"], - "dependencies": ["conversation", "stt", "tts"], + "dependencies": ["conversation", "stt", "tts", "wake_word"], "documentation": "https://www.home-assistant.io/integrations/assist_pipeline", "iot_class": "local_push", "quality_scale": "internal", diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 1be9ddbb14f..3303895eec2 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterable, Callable, Iterable +from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable from dataclasses import asdict, dataclass, field from enum import StrEnum import logging @@ -10,7 +10,14 @@ from typing import Any, cast import voluptuous as vol -from homeassistant.components import conversation, media_source, stt, tts, websocket_api +from homeassistant.components import ( + conversation, + media_source, + stt, + tts, + wake_word, + websocket_api, +) from homeassistant.components.tts.media_source import ( generate_media_source_id as tts_generate_media_source_id, ) @@ -39,7 +46,10 @@ from .error import ( PipelineNotFound, SpeechToTextError, TextToSpeechError, + WakeWordDetectionError, + WakeWordTimeoutError, ) +from .vad import VoiceActivityTimeout, VoiceCommandSegmenter _LOGGER = logging.getLogger(__name__) @@ -241,6 +251,8 @@ class PipelineEventType(StrEnum): RUN_START = "run-start" RUN_END = "run-end" + WAKE_WORD_START = "wake_word-start" + WAKE_WORD_END = "wake_word-end" STT_START = "stt-start" STT_END = "stt-end" INTENT_START = "intent-start" @@ -297,12 +309,14 @@ class Pipeline: class PipelineStage(StrEnum): """Stages of a pipeline.""" + WAKE_WORD = "wake_word" STT = "stt" INTENT = "intent" TTS = "tts" PIPELINE_STAGE_ORDER = [ + PipelineStage.WAKE_WORD, PipelineStage.STT, PipelineStage.INTENT, PipelineStage.TTS, @@ -327,6 +341,17 @@ class InvalidPipelineStagesError(PipelineRunValidationError): ) +@dataclass(frozen=True) +class WakeWordSettings: + """Settings for wake word detection.""" + + timeout: float | None = None + """Seconds of silence before detection times out.""" + + audio_seconds_to_buffer: float = 0 + """Seconds of audio to buffer before detection and forward to STT.""" + + @dataclass class PipelineRun: """Running context for a pipeline.""" @@ -341,17 +366,20 @@ class PipelineRun: runner_data: Any | None = None intent_agent: str | None = None tts_audio_output: str | None = None + wake_word_settings: WakeWordSettings | None = None id: str = field(default_factory=ulid_util.ulid) stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False) tts_engine: str = field(init=False) tts_options: dict | None = field(init=False, default=None) + wake_word_engine: str = field(init=False) + wake_word_provider: wake_word.WakeWordDetectionEntity = field(init=False) def __post_init__(self) -> None: """Set language for pipeline.""" self.language = self.pipeline.language or self.hass.config.language - # stt -> intent -> tts + # wake -> stt -> intent -> tts if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index( self.start_stage ): @@ -393,6 +421,141 @@ class PipelineRun: ) ) + async def prepare_wake_word_detection(self) -> None: + """Prepare wake-word-detection.""" + # Need to add to pipeline store + engine = wake_word.async_default_engine(self.hass) + if engine is None: + raise WakeWordDetectionError( + code="wake-engine-missing", + message="No wake word engine", + ) + + wake_word_provider = wake_word.async_get_wake_word_detection_entity( + self.hass, engine + ) + if wake_word_provider is None: + raise WakeWordDetectionError( + code="wake-provider-missing", + message=f"No wake-word-detection provider for: {engine}", + ) + + self.wake_word_engine = engine + self.wake_word_provider = wake_word_provider + + async def wake_word_detection( + self, + stream: AsyncIterable[bytes], + audio_buffer: list[bytes], + ) -> wake_word.DetectionResult | None: + """Run wake-word-detection portion of pipeline. Returns detection result.""" + metadata_dict = asdict( + stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ) + ) + + # Remove language since it doesn't apply to wake words yet + metadata_dict.pop("language", None) + + self.process_event( + PipelineEvent( + PipelineEventType.WAKE_WORD_START, + { + "engine": self.wake_word_engine, + "metadata": metadata_dict, + }, + ) + ) + + wake_word_settings = self.wake_word_settings or WakeWordSettings() + + wake_word_vad: VoiceActivityTimeout | None = None + if (wake_word_settings.timeout is not None) and ( + wake_word_settings.timeout > 0 + ): + # Use VAD to determine timeout + wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout) + + # Audio chunk buffer. + audio_bytes_to_buffer = int( + wake_word_settings.audio_seconds_to_buffer * 16000 * 2 + ) + audio_ring_buffer = b"" + + async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]: + """Yield audio with timestamps (milliseconds since start of stream).""" + nonlocal audio_ring_buffer + + timestamp_ms = 0 + async for chunk in stream: + yield chunk, timestamp_ms + timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz + + # Keeping audio right before wake word detection allows the + # voice command to be spoken immediately after the wake word. + if audio_bytes_to_buffer > 0: + audio_ring_buffer += chunk + if len(audio_ring_buffer) > audio_bytes_to_buffer: + # A proper ring buffer would be far more efficient + audio_ring_buffer = audio_ring_buffer[ + len(audio_ring_buffer) - audio_bytes_to_buffer : + ] + + if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)): + raise WakeWordTimeoutError( + code="wake-word-timeout", message="Wake word was not detected" + ) + + try: + # Detect wake word(s) + result = await self.wake_word_provider.async_process_audio_stream( + timestamped_stream() + ) + + if audio_ring_buffer: + # All audio kept from right before the wake word was detected as + # a single chunk. + audio_buffer.append(audio_ring_buffer) + except WakeWordTimeoutError: + _LOGGER.debug("Timeout during wake word detection") + raise + except Exception as src_error: + _LOGGER.exception("Unexpected error during wake-word-detection") + raise WakeWordDetectionError( + code="wake-stream-failed", + message="Unexpected error during wake-word-detection", + ) from src_error + + _LOGGER.debug("wake-word-detection result %s", result) + + if result is None: + wake_word_output: dict[str, Any] = {} + else: + if result.queued_audio: + # Add audio that was pending at detection + for chunk_ts in result.queued_audio: + audio_buffer.append(chunk_ts[0]) + + wake_word_output = asdict(result) + + # Remove non-JSON fields + wake_word_output.pop("queued_audio", None) + + self.process_event( + PipelineEvent( + PipelineEventType.WAKE_WORD_END, + {"wake_word_output": wake_word_output}, + ) + ) + + return result + async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None: """Prepare speech-to-text.""" # pipeline.stt_engine can't be None or this function is not called @@ -443,9 +606,21 @@ class PipelineRun: ) try: + segmenter = VoiceCommandSegmenter() + + async def segment_stream( + stream: AsyncIterable[bytes], + ) -> AsyncGenerator[bytes, None]: + """Stop stream when voice command is finished.""" + async for chunk in stream: + if not segmenter.process(chunk): + break + + yield chunk + # Transcribe audio stream result = await self.stt_provider.async_process_audio_stream( - metadata, stream + metadata, segment_stream(stream) ) except Exception as src_error: _LOGGER.exception("Unexpected error during speech-to-text") @@ -663,17 +838,45 @@ class PipelineInput: async def execute(self) -> None: """Run pipeline.""" self.run.start() - current_stage = self.run.start_stage + current_stage: PipelineStage | None = self.run.start_stage + audio_buffer: list[bytes] = [] try: + if current_stage == PipelineStage.WAKE_WORD: + assert self.stt_stream is not None + detect_result = await self.run.wake_word_detection( + self.stt_stream, audio_buffer + ) + if detect_result is None: + # No wake word. Abort the rest of the pipeline. + self.run.end() + return + + current_stage = PipelineStage.STT + # speech-to-text intent_input = self.intent_input if current_stage == PipelineStage.STT: assert self.stt_metadata is not None assert self.stt_stream is not None + + if audio_buffer: + + async def buffered_stream() -> AsyncGenerator[bytes, None]: + for chunk in audio_buffer: + yield chunk + + assert self.stt_stream is not None + async for chunk in self.stt_stream: + yield chunk + + stt_stream = cast(AsyncIterable[bytes], buffered_stream()) + else: + stt_stream = self.stt_stream + intent_input = await self.run.speech_to_text( self.stt_metadata, - self.stt_stream, + stt_stream, ) current_stage = PipelineStage.INTENT @@ -707,7 +910,7 @@ class PipelineInput: async def validate(self) -> None: """Validate pipeline input against start stage.""" - if self.run.start_stage == PipelineStage.STT: + if self.run.start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT): if self.run.pipeline.stt_engine is None: raise PipelineRunValidationError( "the pipeline does not support speech-to-text" @@ -741,6 +944,13 @@ class PipelineInput: prepare_tasks = [] + if ( + start_stage_index + <= PIPELINE_STAGE_ORDER.index(PipelineStage.WAKE_WORD) + <= end_stage_index + ): + prepare_tasks.append(self.run.prepare_wake_word_detection()) + if ( start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT) diff --git a/homeassistant/components/assist_pipeline/vad.py b/homeassistant/components/assist_pipeline/vad.py index cb19811d650..cae31671a3c 100644 --- a/homeassistant/components/assist_pipeline/vad.py +++ b/homeassistant/components/assist_pipeline/vad.py @@ -88,7 +88,7 @@ class VoiceCommandSegmenter: self.in_command = False def process(self, samples: bytes) -> bool: - """Process a 16-bit 16Khz mono audio samples. + """Process 16-bit 16Khz mono audio samples. Returns False when command is done. """ @@ -148,3 +148,94 @@ class VoiceCommandSegmenter: self._silence_seconds_left = self.silence_seconds return True + + +@dataclass +class VoiceActivityTimeout: + """Detects silence in audio until a timeout is reached.""" + + silence_seconds: float + """Seconds of silence before timeout.""" + + reset_seconds: float = 0.5 + """Seconds of speech before resetting timeout.""" + + vad_mode: int = 3 + """Aggressiveness in filtering out non-speech. 3 is the most aggressive.""" + + vad_frames: int = 480 # 30 ms + """Must be 10, 20, or 30 ms at 16Khz.""" + + _silence_seconds_left: float = 0.0 + """Seconds left before considering voice command as stopped.""" + + _reset_seconds_left: float = 0.0 + """Seconds left before resetting start/stop time counters.""" + + _vad: webrtcvad.Vad = None + _audio_buffer: bytes = field(default_factory=bytes) + _bytes_per_chunk: int = 480 * 2 # 16-bit samples + _seconds_per_chunk: float = 0.03 # 30 ms + + def __post_init__(self) -> None: + """Initialize VAD.""" + self._vad = webrtcvad.Vad(self.vad_mode) + self._bytes_per_chunk = self.vad_frames * 2 + self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE + self.reset() + + def reset(self) -> None: + """Reset all counters and state.""" + self._audio_buffer = b"" + self._silence_seconds_left = self.silence_seconds + self._reset_seconds_left = self.reset_seconds + + def process(self, samples: bytes) -> bool: + """Process 16-bit 16Khz mono audio samples. + + Returns False when timeout is reached. + """ + self._audio_buffer += samples + + # Process in 10, 20, or 30 ms chunks. + num_chunks = len(self._audio_buffer) // self._bytes_per_chunk + for chunk_idx in range(num_chunks): + chunk_offset = chunk_idx * self._bytes_per_chunk + chunk = self._audio_buffer[ + chunk_offset : chunk_offset + self._bytes_per_chunk + ] + if not self._process_chunk(chunk): + return False + + if num_chunks > 0: + # Remove from buffer + self._audio_buffer = self._audio_buffer[ + num_chunks * self._bytes_per_chunk : + ] + + return True + + def _process_chunk(self, chunk: bytes) -> bool: + """Process a single chunk of 16-bit 16Khz mono audio. + + Returns False when timeout is reached. + """ + if self._vad.is_speech(chunk, _SAMPLE_RATE): + # Speech + self._reset_seconds_left -= self._seconds_per_chunk + if self._reset_seconds_left <= 0: + # Reset timeout + self._silence_seconds_left = self.silence_seconds + else: + # Silence + self._silence_seconds_left -= self._seconds_per_chunk + if self._silence_seconds_left <= 0: + # Timeout reached + return False + + # Slowly build reset counter back up + self._reset_seconds_left = min( + self.reset_seconds, self._reset_seconds_left + self._seconds_per_chunk + ) + + return True diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 4e6d44a8868..bf61b9776e9 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -26,11 +26,12 @@ from .pipeline import ( PipelineInput, PipelineRun, PipelineStage, + WakeWordSettings, async_get_pipeline, ) -from .vad import VoiceCommandSegmenter DEFAULT_TIMEOUT = 30 +DEFAULT_WAKE_WORD_TIMEOUT = 3 _LOGGER = logging.getLogger(__name__) @@ -63,6 +64,18 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: cv.key_value_schemas( "start_stage", { + PipelineStage.WAKE_WORD: vol.Schema( + { + vol.Required("input"): { + vol.Required("sample_rate"): int, + vol.Optional("timeout"): vol.Any(float, int), + vol.Optional("audio_seconds_to_buffer"): vol.Any( + float, int + ), + } + }, + extra=vol.ALLOW_EXTRA, + ), PipelineStage.STT: vol.Schema( {vol.Required("input"): {vol.Required("sample_rate"): int}}, extra=vol.ALLOW_EXTRA, @@ -102,6 +115,7 @@ async def websocket_run( end_stage = PipelineStage(msg["end_stage"]) handler_id: int | None = None unregister_handler: Callable[[], None] | None = None + wake_word_settings: WakeWordSettings | None = None # Arguments to PipelineInput input_args: dict[str, Any] = { @@ -109,24 +123,26 @@ async def websocket_run( "device_id": msg.get("device_id"), } - if start_stage == PipelineStage.STT: + if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT): # Audio pipeline that will receive audio as binary websocket messages audio_queue: asyncio.Queue[bytes] = asyncio.Queue() incoming_sample_rate = msg["input"]["sample_rate"] + if start_stage == PipelineStage.WAKE_WORD: + wake_word_settings = WakeWordSettings( + timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT), + audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0), + ) + async def stt_stream() -> AsyncGenerator[bytes, None]: state = None - segmenter = VoiceCommandSegmenter() # Yield until we receive an empty chunk while chunk := await audio_queue.get(): - chunk, state = audioop.ratecv( - chunk, 2, 1, incoming_sample_rate, 16000, state - ) - if not segmenter.process(chunk): - # Voice command is finished - break - + if incoming_sample_rate != 16000: + chunk, state = audioop.ratecv( + chunk, 2, 1, incoming_sample_rate, 16000, state + ) yield chunk def handle_binary( @@ -169,6 +185,7 @@ async def websocket_run( "stt_binary_handler_id": handler_id, "timeout": timeout, }, + wake_word_settings=wake_word_settings, ) pipeline_input = PipelineInput(**input_args) diff --git a/homeassistant/components/wake_word/__init__.py b/homeassistant/components/wake_word/__init__.py new file mode 100644 index 00000000000..f33d06c64da --- /dev/null +++ b/homeassistant/components/wake_word/__init__.py @@ -0,0 +1,119 @@ +"""Provide functionality to wake word.""" +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import AsyncIterable +import logging +from typing import final + +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.helpers.typing import ConfigType +from homeassistant.util import dt as dt_util + +from .const import DOMAIN +from .models import DetectionResult, WakeWord + +__all__ = [ + "async_default_engine", + "async_get_wake_word_detection_entity", + "DetectionResult", + "DOMAIN", + "WakeWord", + "WakeWordDetectionEntity", +] + +_LOGGER = logging.getLogger(__name__) + +CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) + + +@callback +def async_default_engine(hass: HomeAssistant) -> str | None: + """Return the domain or entity id of the default engine.""" + return next(iter(hass.states.async_entity_ids(DOMAIN)), None) + + +@callback +def async_get_wake_word_detection_entity( + hass: HomeAssistant, entity_id: str +) -> WakeWordDetectionEntity | None: + """Return wake word entity.""" + component: EntityComponent = hass.data[DOMAIN] + + return component.get_entity(entity_id) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up STT.""" + component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass) + component.register_shutdown() + + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + component: EntityComponent = hass.data[DOMAIN] + return await component.async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + component: EntityComponent = hass.data[DOMAIN] + return await component.async_unload_entry(entry) + + +class WakeWordDetectionEntity(RestoreEntity): + """Represent a single wake word provider.""" + + _attr_should_poll = False + __last_processed: str | None = None + + @property + @final + def state(self) -> str | None: + """Return the state of the entity.""" + if self.__last_processed is None: + return None + return self.__last_processed + + @property + @abstractmethod + def supported_wake_words(self) -> list[WakeWord]: + """Return a list of supported wake words.""" + + @abstractmethod + async def _async_process_audio_stream( + self, stream: AsyncIterable[tuple[bytes, int]] + ) -> DetectionResult | None: + """Try to detect wake word(s) in an audio stream with timestamps. + + Audio must be 16Khz sample rate with 16-bit mono PCM samples. + """ + + async def async_process_audio_stream( + self, stream: AsyncIterable[tuple[bytes, int]] + ) -> DetectionResult | None: + """Try to detect wake word(s) in an audio stream with timestamps. + + Audio must be 16Khz sample rate with 16-bit mono PCM samples. + """ + self.__last_processed = dt_util.utcnow().isoformat() + self.async_write_ha_state() + return await self._async_process_audio_stream(stream) + + async def async_internal_added_to_hass(self) -> None: + """Call when the entity is added to hass.""" + await super().async_internal_added_to_hass() + state = await self.async_get_last_state() + if ( + state is not None + and state.state is not None + and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + self.__last_processed = state.state diff --git a/homeassistant/components/wake_word/const.py b/homeassistant/components/wake_word/const.py new file mode 100644 index 00000000000..fdca6cfab6e --- /dev/null +++ b/homeassistant/components/wake_word/const.py @@ -0,0 +1,2 @@ +"""Wake word constants.""" +DOMAIN = "wake_word" diff --git a/homeassistant/components/wake_word/manifest.json b/homeassistant/components/wake_word/manifest.json new file mode 100644 index 00000000000..7834fad665c --- /dev/null +++ b/homeassistant/components/wake_word/manifest.json @@ -0,0 +1,8 @@ +{ + "domain": "wake_word", + "name": "Wake-word detection", + "codeowners": ["@home-assistant/core", "@synesthesiam"], + "documentation": "https://www.home-assistant.io/integrations/wake_word", + "integration_type": "entity", + "quality_scale": "internal" +} diff --git a/homeassistant/components/wake_word/models.py b/homeassistant/components/wake_word/models.py new file mode 100644 index 00000000000..1ea154f1393 --- /dev/null +++ b/homeassistant/components/wake_word/models.py @@ -0,0 +1,24 @@ +"""Wake word models.""" +from dataclasses import dataclass + + +@dataclass(frozen=True) +class WakeWord: + """Wake word model.""" + + ww_id: str + name: str + + +@dataclass +class DetectionResult: + """Result of wake word detection.""" + + ww_id: str + """Id of detected wake word""" + + timestamp: int | None + """Timestamp of audio chunk with detected wake word""" + + queued_audio: list[tuple[bytes, int]] | None = None + """Audio chunks that were queued when wake word was detected.""" diff --git a/homeassistant/components/wyoming/config_flow.py b/homeassistant/components/wyoming/config_flow.py index d7d5d0278e8..3fccbaea9c4 100644 --- a/homeassistant/components/wyoming/config_flow.py +++ b/homeassistant/components/wyoming/config_flow.py @@ -50,14 +50,21 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors={"base": "cannot_connect"}, ) - # ASR = automated speech recognition (STT) + # ASR = automated speech recognition (speech-to-text) asr_installed = [asr for asr in service.info.asr if asr.installed] + + # TTS = text-to-speech tts_installed = [tts for tts in service.info.tts if tts.installed] + # wake-word-detection + wake_installed = [wake for wake in service.info.wake if wake.installed] + if asr_installed: name = asr_installed[0].name elif tts_installed: name = tts_installed[0].name + elif wake_installed: + name = wake_installed[0].name else: return self.async_abort(reason="no_services") diff --git a/homeassistant/components/wyoming/data.py b/homeassistant/components/wyoming/data.py index c2d71835c65..1fe4d60b974 100644 --- a/homeassistant/components/wyoming/data.py +++ b/homeassistant/components/wyoming/data.py @@ -29,6 +29,8 @@ class WyomingService: platforms.append(Platform.STT) if any(tts.installed for tts in info.tts): platforms.append(Platform.TTS) + if any(wake.installed for wake in info.wake): + platforms.append(Platform.WAKE_WORD) self.platforms = platforms @classmethod diff --git a/homeassistant/components/wyoming/wake_word.py b/homeassistant/components/wyoming/wake_word.py new file mode 100644 index 00000000000..0e7fb3c4429 --- /dev/null +++ b/homeassistant/components/wyoming/wake_word.py @@ -0,0 +1,157 @@ +"""Support for Wyoming wake-word-detection services.""" +import asyncio +from collections.abc import AsyncIterable +import logging + +from wyoming.audio import AudioChunk, AudioStart +from wyoming.client import AsyncTcpClient +from wyoming.wake import Detection + +from homeassistant.components import wake_word +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .data import WyomingService +from .error import WyomingError + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Wyoming wake-word-detection.""" + service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] + async_add_entities( + [ + WyomingWakeWordProvider(config_entry, service), + ] + ) + + +class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): + """Wyoming wake-word-detection provider.""" + + def __init__( + self, + config_entry: ConfigEntry, + service: WyomingService, + ) -> None: + """Set up provider.""" + self.service = service + wake_service = service.info.wake[0] + + self._supported_wake_words = [ + wake_word.WakeWord(ww_id=ww.name, name=ww.name) + for ww in wake_service.models + ] + self._attr_name = wake_service.name + self._attr_unique_id = f"{config_entry.entry_id}-wake_word" + + @property + def supported_wake_words(self) -> list[wake_word.WakeWord]: + """Return a list of supported wake words.""" + return self._supported_wake_words + + async def _async_process_audio_stream( + self, stream: AsyncIterable[tuple[bytes, int]] + ) -> wake_word.DetectionResult | None: + """Try to detect one or more wake words in an audio stream. + + Audio must be 16Khz sample rate with 16-bit mono PCM samples. + """ + + async def next_chunk(): + """Get the next chunk from audio stream.""" + async for chunk_bytes in stream: + return chunk_bytes + + try: + async with AsyncTcpClient(self.service.host, self.service.port) as client: + await client.write_event( + AudioStart( + rate=16000, + width=2, + channels=1, + ).event(), + ) + + # Read audio and wake events in "parallel" + audio_task = asyncio.create_task(next_chunk()) + wake_task = asyncio.create_task(client.read_event()) + pending = {audio_task, wake_task} + + try: + while True: + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + + if wake_task in done: + event = wake_task.result() + if event is None: + _LOGGER.debug("Connection lost") + break + + if Detection.is_type(event.type): + # Successful detection + detection = Detection.from_event(event) + _LOGGER.info(detection) + + # Retrieve queued audio + queued_audio: list[tuple[bytes, int]] | None = None + if audio_task in pending: + # Save queued audio + await audio_task + pending.remove(audio_task) + queued_audio = [audio_task.result()] + + return wake_word.DetectionResult( + ww_id=detection.name, + timestamp=detection.timestamp, + queued_audio=queued_audio, + ) + + # Next event + wake_task = asyncio.create_task(client.read_event()) + pending.add(wake_task) + + if audio_task in done: + # Forward audio to wake service + chunk_info = audio_task.result() + if chunk_info is None: + break + + chunk_bytes, chunk_timestamp = chunk_info + chunk = AudioChunk( + rate=16000, + width=2, + channels=1, + audio=chunk_bytes, + timestamp=chunk_timestamp, + ) + await client.write_event(chunk.event()) + + # Next chunk + audio_task = asyncio.create_task(next_chunk()) + pending.add(audio_task) + finally: + # Clean up + if audio_task in pending: + # It's critical that we don't cancel the audio task or + # leave it hanging. This would mess up the pipeline STT + # by stopping the audio stream. + await audio_task + pending.remove(audio_task) + + for task in pending: + task.cancel() + + except (OSError, WyomingError) as err: + _LOGGER.exception("Error processing audio stream: %s", err) + + return None diff --git a/homeassistant/const.py b/homeassistant/const.py index a41710f1280..adca3dc965c 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -57,6 +57,7 @@ class Platform(StrEnum): TTS = "tts" VACUUM = "vacuum" UPDATE = "update" + WAKE_WORD = "wake_word" WATER_HEATER = "water_heater" WEATHER = "weather" diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 5aa760cc606..0cc18d73e6f 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock import pytest -from homeassistant.components import stt, tts +from homeassistant.components import stt, tts, wake_word from homeassistant.components.assist_pipeline import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( PipelineData, @@ -174,6 +174,40 @@ class MockSttPlatform(MockPlatform): self.async_get_engine = async_get_engine +class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): + """Mock wake word entity.""" + + fail_process_audio = False + url_path = "wake_word.test" + _attr_name = "test" + + @property + def supported_wake_words(self) -> list[wake_word.WakeWord]: + """Return a list of supported wake words.""" + return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")] + + async def _async_process_audio_stream( + self, stream: AsyncIterable[tuple[bytes, int]] + ) -> wake_word.DetectionResult | None: + """Try to detect wake word(s) in an audio stream with timestamps.""" + async for chunk, timestamp in stream: + if chunk == b"wake word": + return wake_word.DetectionResult( + ww_id=self.supported_wake_words[0].ww_id, + timestamp=timestamp, + queued_audio=[(b"queued audio", 0)], + ) + + # Not detected + return None + + +@pytest.fixture +async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity: + """Mock wake word provider.""" + return MockWakeWordEntity() + + class MockFlow(ConfigFlow): """Test flow.""" @@ -193,6 +227,7 @@ async def init_supporting_components( mock_stt_provider: MockSttProvider, mock_stt_provider_entity: MockSttProviderEntity, mock_tts_provider: MockTTSProvider, + mock_wake_word_provider_entity: MockWakeWordEntity, config_flow_fixture, ): """Initialize relevant components with empty configs.""" @@ -201,14 +236,18 @@ async def init_supporting_components( hass: HomeAssistant, config_entry: ConfigEntry ) -> bool: """Set up test config entry.""" - await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN) + await hass.config_entries.async_forward_entry_setups( + config_entry, [stt.DOMAIN, wake_word.DOMAIN] + ) return True async def async_unload_entry_init( hass: HomeAssistant, config_entry: ConfigEntry ) -> bool: """Unload up test config entry.""" - await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN) + await hass.config_entries.async_unload_platforms( + config_entry, [stt.DOMAIN, wake_word.DOMAIN] + ) return True async def async_setup_entry_stt_platform( @@ -219,6 +258,14 @@ async def init_supporting_components( """Set up test stt platform via config entry.""" async_add_entities([mock_stt_provider_entity]) + async def async_setup_entry_wake_word_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + ) -> None: + """Set up test wake word platform via config entry.""" + async_add_entities([mock_wake_word_provider_entity]) + mock_integration( hass, MockModule( @@ -242,11 +289,19 @@ async def init_supporting_components( async_setup_entry=async_setup_entry_stt_platform, ), ) + mock_platform( + hass, + "test.wake_word", + MockPlatform( + async_setup_entry=async_setup_entry_wake_word_platform, + ), + ) mock_platform(hass, "test.config_flow") assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) + # assert await async_setup_component(hass, wake_word.DOMAIN, {"wake_word": {}}) assert await async_setup_component(hass, "media_source", {}) config_entry = MockConfigEntry(domain="test") diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index d8858cec4b6..d0330952f04 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -266,3 +266,114 @@ }), ]) # --- +# name: test_pipeline_from_audio_stream_wake_word + list([ + dict({ + 'data': dict({ + 'language': 'en', + 'pipeline': , + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': , + 'channel': , + 'codec': , + 'format': , + 'sample_rate': , + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'wake_word_output': dict({ + 'timestamp': 2000, + 'ww_id': 'test_ww', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': , + 'channel': , + 'codec': , + 'format': , + 'language': 'en-US', + 'sample_rate': , + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'conversation_id': None, + 'device_id': None, + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + 'language': 'en', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'test', + 'language': 'en-US', + 'tts_input': "Sorry, I couldn't understand that", + 'voice': 'james_earl_jones', + }), + 'type': , + }), + dict({ + 'data': dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 12a4d766f06..ea642546e6d 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -155,6 +155,243 @@ }), }) # --- +# name: test_audio_pipeline_no_wake_word_engine + dict({ + 'code': 'wake-engine-missing', + 'message': 'No wake word engine', + }) +# --- +# name: test_audio_pipeline_no_wake_word_entity + dict({ + 'code': 'wake-provider-missing', + 'message': 'No wake-word-detection provider for: wake_word.bad-entity-id', + }) +# --- +# name: test_audio_pipeline_with_wake_word + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 30, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word.1 + dict({ + 'engine': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word.2 + dict({ + 'wake_word_output': dict({ + 'queued_audio': None, + 'timestamp': 1000, + 'ww_id': 'test_ww', + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word.3 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word.4 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word.5 + dict({ + 'conversation_id': None, + 'device_id': None, + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + 'language': 'en', + }) +# --- +# name: test_audio_pipeline_with_wake_word.6 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word.7 + dict({ + 'engine': 'test', + 'language': 'en-US', + 'tts_input': "Sorry, I couldn't understand that", + 'voice': 'james_earl_jones', + }) +# --- +# name: test_audio_pipeline_with_wake_word.8 + dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 30, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.1 + dict({ + 'engine': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.2 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'ww_id': 'test_ww', + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.3 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.4 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.5 + dict({ + 'conversation_id': None, + 'device_id': None, + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + 'language': 'en', + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.6 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.7 + dict({ + 'engine': 'test', + 'language': 'en-US', + 'tts_input': "Sorry, I couldn't understand that", + 'voice': 'james_earl_jones', + }) +# --- +# name: test_audio_pipeline_with_wake_word_no_timeout.8 + dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_timeout + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 30, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_timeout.1 + dict({ + 'engine': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_with_wake_word_timeout.2 + dict({ + 'code': 'wake-word-timeout', + 'message': 'Wake word was not detected', + }) +# --- # name: test_intent_failed dict({ 'language': 'en', diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 392363fc0cc..44e448aa785 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -1,5 +1,6 @@ """Test Voice Assistant init.""" from dataclasses import asdict +import itertools as it from unittest.mock import ANY import pytest @@ -8,10 +9,12 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components import assist_pipeline, stt from homeassistant.core import Context, HomeAssistant -from .conftest import MockSttProvider, MockSttProviderEntity +from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity from tests.typing import WebSocketGenerator +BYTES_ONE_SECOND = 16000 * 2 + def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: """Process events to remove dynamic values.""" @@ -280,3 +283,61 @@ async def test_pipeline_from_audio_stream_unknown_pipeline( ) assert not events + + +async def test_pipeline_from_audio_stream_wake_word( + hass: HomeAssistant, + mock_stt_provider: MockSttProvider, + mock_wake_word_provider_entity: MockWakeWordEntity, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test creating a pipeline from an audio stream with wake word.""" + + events = [] + + # [0, 1, ...] + wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND)) + + # [0, 2, ...] + wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND)) + + async def audio_data(): + yield wake_chunk_1 # 1 second + yield wake_chunk_2 # 1 second + yield b"wake word" + yield b"part1" + yield b"part2" + yield b"" + + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + Context(), + events.append, + stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + audio_data(), + start_stage=assist_pipeline.PipelineStage.WAKE_WORD, + wake_word_settings=assist_pipeline.WakeWordSettings( + audio_seconds_to_buffer=1.5 + ), + ) + + assert process_events(events) == snapshot + + # 1. Half of wake_chunk_1 + all wake_chunk_2 + # 2. queued audio (from mock wake word entity) + # 3. part1 + # 4. part2 + assert len(mock_stt_provider.received) == 4 + + first_chunk = mock_stt_provider.received[0] + assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2 + + assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"] diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 4ebf0a1fb98..1f2b657dcfa 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -167,6 +167,224 @@ async def test_audio_pipeline( assert msg["result"] == {"events": events} +async def test_audio_pipeline_with_wake_word_timeout( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test timeout from a pipeline run with audio input/output + wake word.""" + events = [] + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "timeout": 1, + }, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"], msg + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # wake_word + msg = await client.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # 2 seconds of silence + await client.send_bytes(bytes([1]) + bytes(16000 * 2 * 2)) + + # Time out error + msg = await client.receive_json() + assert msg["event"]["type"] == "error" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + +async def test_audio_pipeline_with_wake_word_no_timeout( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test events from a pipeline run with audio input/output + wake word with no timeout.""" + events = [] + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "timeout": 0, + }, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"], msg + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # wake_word + msg = await client.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # "audio" + await client.send_bytes(bytes([1]) + b"wake word") + + msg = await client.receive_json() + assert msg["event"]["type"] == "wake_word-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # stt + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # End of audio stream (handler id + empty payload) + await client.send_bytes(bytes([1])) + + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # intent + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # text-to-speech + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # run end + msg = await client.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] is None + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} + + +async def test_audio_pipeline_no_wake_word_engine( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test timeout from a pipeline run with audio input/output + wake word.""" + client = await hass_ws_client(hass) + + with patch( + "homeassistant.components.wake_word.async_default_engine", return_value=None + ): + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + }, + } + ) + + # error + msg = await client.receive_json() + assert not msg["success"] + assert "error" in msg + assert msg["error"] == snapshot + + +async def test_audio_pipeline_no_wake_word_entity( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test timeout from a pipeline run with audio input/output + wake word.""" + client = await hass_ws_client(hass) + + with patch( + "homeassistant.components.wake_word.async_default_engine", + return_value="wake_word.bad-entity-id", + ), patch( + "homeassistant.components.wake_word.async_get_wake_word_detection_entity", + return_value=None, + ): + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + }, + } + ) + + # error + msg = await client.receive_json() + assert not msg["success"] + assert "error" in msg + assert msg["error"] == snapshot + + async def test_intent_timeout( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, diff --git a/tests/components/wake_word/__init__.py b/tests/components/wake_word/__init__.py new file mode 100644 index 00000000000..ed2fe81a7fe --- /dev/null +++ b/tests/components/wake_word/__init__.py @@ -0,0 +1 @@ +"""Wake-word-detection tests.""" diff --git a/tests/components/wake_word/common.py b/tests/components/wake_word/common.py new file mode 100644 index 00000000000..f732044bc13 --- /dev/null +++ b/tests/components/wake_word/common.py @@ -0,0 +1,29 @@ +"""Provide common test tools for wake-word-detection.""" +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from pathlib import Path +from typing import Any + +from homeassistant.components import wake_word +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from tests.common import MockPlatform, mock_platform + + +def mock_wake_word_entity_platform( + hass: HomeAssistant, + tmp_path: Path, + integration: str, + async_setup_entry: Callable[ + [HomeAssistant, ConfigEntry, AddEntitiesCallback], + Coroutine[Any, Any, None], + ] + | None = None, +) -> MockPlatform: + """Specialize the mock platform for stt.""" + loaded_platform = MockPlatform(async_setup_entry=async_setup_entry) + mock_platform(hass, f"{integration}.{wake_word.DOMAIN}", loaded_platform) + return loaded_platform diff --git a/tests/components/wake_word/snapshots/test_init.ambr b/tests/components/wake_word/snapshots/test_init.ambr new file mode 100644 index 00000000000..ca6d5d950f0 --- /dev/null +++ b/tests/components/wake_word/snapshots/test_init.ambr @@ -0,0 +1,11 @@ +# serializer version: 1 +# name: test_ws_detect + dict({ + 'event': dict({ + 'timestamp': 2048.0, + 'ww_id': 'test_ww', + }), + 'id': 1, + 'type': 'event', + }) +# --- diff --git a/tests/components/wake_word/test_init.py b/tests/components/wake_word/test_init.py new file mode 100644 index 00000000000..954cbe6dc8c --- /dev/null +++ b/tests/components/wake_word/test_init.py @@ -0,0 +1,226 @@ +"""Test wake_word component setup.""" +from collections.abc import AsyncIterable, Generator +from pathlib import Path + +import pytest + +from homeassistant.components import wake_word +from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow +from homeassistant.core import HomeAssistant, State +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.setup import async_setup_component + +from .common import mock_wake_word_entity_platform + +from tests.common import ( + MockConfigEntry, + MockModule, + mock_config_flow, + mock_integration, + mock_platform, + mock_restore_cache, +) + +TEST_DOMAIN = "test" + +_SAMPLES_PER_CHUNK = 1024 +_BYTES_PER_CHUNK = _SAMPLES_PER_CHUNK * 2 # 16-bit +_MS_PER_CHUNK = (_BYTES_PER_CHUNK // 2) // 16 # 16Khz + + +class MockProviderEntity(wake_word.WakeWordDetectionEntity): + """Mock provider entity.""" + + url_path = "wake_word.test" + _attr_name = "test" + + @property + def supported_wake_words(self) -> list[wake_word.WakeWord]: + """Return a list of supported wake words.""" + return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")] + + async def _async_process_audio_stream( + self, stream: AsyncIterable[tuple[bytes, int]] + ) -> wake_word.DetectionResult | None: + """Try to detect wake word(s) in an audio stream with timestamps.""" + async for _chunk, timestamp in stream: + if timestamp >= 2000: + return wake_word.DetectionResult( + ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp + ) + + # Not detected + return None + + +@pytest.fixture +def mock_provider_entity() -> MockProviderEntity: + """Test provider entity fixture.""" + return MockProviderEntity() + + +class WakeWordFlow(ConfigFlow): + """Test flow.""" + + +@pytest.fixture(autouse=True) +def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]: + """Mock config flow.""" + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + + with mock_config_flow(TEST_DOMAIN, WakeWordFlow): + yield + + +@pytest.fixture(name="setup") +async def setup_fixture( + hass: HomeAssistant, + tmp_path: Path, +) -> MockProviderEntity: + """Set up the test environment.""" + provider = MockProviderEntity() + await mock_config_entry_setup(hass, tmp_path, provider) + + return provider + + +async def mock_config_entry_setup( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> MockConfigEntry: + """Set up a test provider via config entry.""" + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setup( + config_entry, wake_word.DOMAIN + ) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload up test config entry.""" + await hass.config_entries.async_forward_entry_unload( + config_entry, wake_word.DOMAIN + ) + return True + + mock_integration( + hass, + MockModule( + TEST_DOMAIN, + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) + + async def async_setup_entry_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + ) -> None: + """Set up test stt platform via config entry.""" + async_add_entities([mock_provider_entity]) + + mock_wake_word_entity_platform( + hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform + ) + + config_entry = MockConfigEntry(domain=TEST_DOMAIN) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + return config_entry + + +async def test_config_entry_unload( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> None: + """Test we can unload config entry.""" + config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + assert config_entry.state == ConfigEntryState.LOADED + await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state == ConfigEntryState.NOT_LOADED + + +async def test_detected_entity( + hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity +) -> None: + """Test successful detection through entity.""" + + async def three_second_stream(): + timestamp = 0 + while timestamp < 3000: + yield bytes(_BYTES_PER_CHUNK), timestamp + timestamp += _MS_PER_CHUNK + + # Need 2 seconds to trigger + result = await setup.async_process_audio_stream(three_second_stream()) + assert result == wake_word.DetectionResult("test_ww", 2048) + + +async def test_not_detected_entity( + hass: HomeAssistant, setup: MockProviderEntity +) -> None: + """Test unsuccessful detection through entity.""" + + async def one_second_stream(): + timestamp = 0 + while timestamp < 1000: + yield bytes(_BYTES_PER_CHUNK), timestamp + timestamp += _MS_PER_CHUNK + + # Need 2 seconds to trigger + result = await setup.async_process_audio_stream(one_second_stream()) + assert result is None + + +async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None: + """Test async_default_engine.""" + assert await async_setup_component(hass, wake_word.DOMAIN, {wake_word.DOMAIN: {}}) + await hass.async_block_till_done() + + assert wake_word.async_default_engine(hass) is None + + +async def test_default_engine_entity( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> None: + """Test async_default_engine.""" + await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + + assert wake_word.async_default_engine(hass) == f"{wake_word.DOMAIN}.{TEST_DOMAIN}" + + +async def test_get_engine_entity( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> None: + """Test async_get_speech_to_text_engine.""" + await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + + assert ( + wake_word.async_get_wake_word_detection_entity(hass, f"{wake_word.DOMAIN}.test") + is mock_provider_entity + ) + + +async def test_restore_state( + hass: HomeAssistant, + tmp_path: Path, + mock_provider_entity: MockProviderEntity, +) -> None: + """Test we restore state in the integration.""" + entity_id = f"{wake_word.DOMAIN}.{TEST_DOMAIN}" + timestamp = "2023-01-01T23:59:59+00:00" + mock_restore_cache(hass, (State(entity_id, timestamp),)) + + config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + await hass.async_block_till_done() + + assert config_entry.state == ConfigEntryState.LOADED + state = hass.states.get(entity_id) + assert state + assert state.state == timestamp diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 3d12d41ce5e..c326228ec8b 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -1,4 +1,6 @@ """Tests for the Wyoming integration.""" +import asyncio + from wyoming.info import ( AsrModel, AsrProgram, @@ -7,6 +9,8 @@ from wyoming.info import ( TtsProgram, TtsVoice, TtsVoiceSpeaker, + WakeModel, + WakeProgram, ) TEST_ATTR = Attribution(name="Test", url="http://www.test.com") @@ -49,6 +53,25 @@ TTS_INFO = Info( ) ] ) +WAKE_WORD_INFO = Info( + wake=[ + WakeProgram( + name="Test Wake Word", + description="Test Wake Word", + installed=True, + attribution=TEST_ATTR, + models=[ + WakeModel( + name="Test Model", + description="Test Model", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + ) + ], + ) + ] +) EMPTY_INFO = Info() @@ -68,6 +91,7 @@ class MockAsyncTcpClient: async def read_event(self): """Receive.""" + await asyncio.sleep(0) # force context switch return self.responses.pop(0) async def __aenter__(self): diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 6b4e705914f..2c8081908f7 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -8,7 +8,7 @@ from homeassistant.components import stt from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from . import STT_INFO, TTS_INFO +from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO from tests.common import MockConfigEntry @@ -52,6 +52,21 @@ def tts_config_entry(hass: HomeAssistant) -> ConfigEntry: return entry +@pytest.fixture +def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry: + """Create a config entry.""" + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test Wake Word", + ) + entry.add_to_hass(hass) + return entry + + @pytest.fixture async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry): """Initialize Wyoming STT.""" @@ -72,6 +87,18 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry): await hass.config_entries.async_setup(tts_config_entry.entry_id) +@pytest.fixture +async def init_wyoming_wake_word( + hass: HomeAssistant, wake_word_config_entry: ConfigEntry +): + """Initialize Wyoming Wake Word.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=WAKE_WORD_INFO, + ): + await hass.config_entries.async_setup(wake_word_config_entry.entry_id) + + @pytest.fixture def metadata(hass: HomeAssistant) -> stt.SpeechMetadata: """Get default STT metadata.""" diff --git a/tests/components/wyoming/snapshots/test_wake_word.ambr b/tests/components/wyoming/snapshots/test_wake_word.ambr new file mode 100644 index 00000000000..041112cb6ff --- /dev/null +++ b/tests/components/wyoming/snapshots/test_wake_word.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_streaming_audio + dict({ + 'queued_audio': list([ + tuple( + b'chunk', + 1, + ), + ]), + 'timestamp': 0, + 'ww_id': 'Test Model', + }) +# --- diff --git a/tests/components/wyoming/test_wake_word.py b/tests/components/wyoming/test_wake_word.py new file mode 100644 index 00000000000..cd156c660a8 --- /dev/null +++ b/tests/components/wyoming/test_wake_word.py @@ -0,0 +1,108 @@ +"""Test stt.""" +from __future__ import annotations + +import asyncio +from unittest.mock import patch + +from syrupy.assertion import SnapshotAssertion +from wyoming.asr import Transcript +from wyoming.wake import Detection + +from homeassistant.components import wake_word +from homeassistant.core import HomeAssistant + +from . import MockAsyncTcpClient + + +async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None: + """Test supported properties.""" + state = hass.states.get("wake_word.test_wake_word") + assert state is not None + + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + assert entity.supported_wake_words == [ + wake_word.WakeWord(ww_id="Test Model", name="Test Model") + ] + + +async def test_streaming_audio( + hass: HomeAssistant, init_wyoming_wake_word, snapshot: SnapshotAssertion +) -> None: + """Test streaming audio.""" + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + async def audio_stream(): + yield b"chunk", 0 + + # Delay to force a pending audio chunk + await asyncio.sleep(0.05) + yield b"chunk", 1 + + client_events = [ + Transcript("not a wake word event").event(), + Detection(name="Test Model", timestamp=0).event(), + ] + + with patch( + "homeassistant.components.wyoming.wake_word.AsyncTcpClient", + MockAsyncTcpClient(client_events), + ): + result = await entity.async_process_audio_stream(audio_stream()) + + assert result is not None + assert result == snapshot + + +async def test_streaming_audio_connection_lost( + hass: HomeAssistant, init_wyoming_wake_word +) -> None: + """Test streaming audio and losing connection.""" + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + async def audio_stream(): + # Delay to force a pending audio chunk + await asyncio.sleep(0.05) + yield b"chunk", 1 + + with patch( + "homeassistant.components.wyoming.wake_word.AsyncTcpClient", + MockAsyncTcpClient([None]), + ): + result = await entity.async_process_audio_stream(audio_stream()) + + assert result is None + + +async def test_streaming_audio_oserror( + hass: HomeAssistant, init_wyoming_wake_word +) -> None: + """Test streaming audio and error raising.""" + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + async def audio_stream(): + yield b"chunk1", 1000 + + mock_client = MockAsyncTcpClient( + [Detection(name="Test Model", timestamp=1000).event()] + ) + + with patch( + "homeassistant.components.wyoming.wake_word.AsyncTcpClient", + mock_client, + ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): + result = await entity.async_process_audio_stream(audio_stream()) + + assert result is None