Add wake word integration (#96380)

* Add wake component

* Add wake support to Wyoming

* Add helper function to assist_pipeline (not complete)

* Rename wake to wake_word

* Fix platform

* Use send_event and clean up

* Merge wake word into pipeline

* Add wake option to async_pipeline_from_audio_stream

* Add start/end stages to async_pipeline_from_audio_stream

* Add wake timeout

* Remove layer in wake_output

* Use VAD for wake word timeout

* Include audio metadata in wake-start

* Remove unnecessary websocket command

* wake -> wake_word

* Incorporate feedback

* Clean up wake_word tests

* Add wyoming wake word tests

* Add pipeline wake word test

* Add last processed state

* Fix tests

* Add tests for wake word

* More tests for the codebot
This commit is contained in:
Michael Hansen 2023-08-07 21:22:16 -05:00 committed by GitHub
parent 798fb3e31a
commit 7ea2998b55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 1802 additions and 27 deletions

View File

@ -1373,6 +1373,8 @@ build.json @home-assistant/supervisor
/tests/components/vulcan/ @Antoni-Czaplicki /tests/components/vulcan/ @Antoni-Czaplicki
/homeassistant/components/wake_on_lan/ @ntilley905 /homeassistant/components/wake_on_lan/ @ntilley905
/tests/components/wake_on_lan/ @ntilley905 /tests/components/wake_on_lan/ @ntilley905
/homeassistant/components/wake_word/ @home-assistant/core @synesthesiam
/tests/components/wake_word/ @home-assistant/core @synesthesiam
/homeassistant/components/wallbox/ @hesselonline /homeassistant/components/wallbox/ @hesselonline
/tests/components/wallbox/ @hesselonline /tests/components/wallbox/ @hesselonline
/homeassistant/components/waqi/ @andrey-git /homeassistant/components/waqi/ @andrey-git

View File

@ -18,6 +18,7 @@ from .pipeline import (
PipelineInput, PipelineInput,
PipelineRun, PipelineRun,
PipelineStage, PipelineStage,
WakeWordSettings,
async_create_default_pipeline, async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines, async_get_pipelines,
@ -35,6 +36,7 @@ __all__ = (
"PipelineEvent", "PipelineEvent",
"PipelineEventType", "PipelineEventType",
"PipelineNotFound", "PipelineNotFound",
"WakeWordSettings",
) )
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@ -57,7 +59,10 @@ async def async_pipeline_from_audio_stream(
pipeline_id: str | None = None, pipeline_id: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
tts_audio_output: str | None = None, tts_audio_output: str | None = None,
wake_word_settings: WakeWordSettings | None = None,
device_id: str | None = None, device_id: str | None = None,
start_stage: PipelineStage = PipelineStage.STT,
end_stage: PipelineStage = PipelineStage.TTS,
) -> None: ) -> None:
"""Create an audio pipeline from an audio stream. """Create an audio pipeline from an audio stream.
@ -72,10 +77,11 @@ async def async_pipeline_from_audio_stream(
hass, hass,
context=context, context=context,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id), pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=PipelineStage.STT, start_stage=start_stage,
end_stage=PipelineStage.TTS, end_stage=end_stage,
event_callback=event_callback, event_callback=event_callback,
tts_audio_output=tts_audio_output, tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
), ),
) )
await pipeline_input.validate() await pipeline_input.validate()

View File

@ -18,6 +18,14 @@ class PipelineNotFound(PipelineError):
"""Unspecified pipeline picked.""" """Unspecified pipeline picked."""
class WakeWordDetectionError(PipelineError):
    """Error in wake-word-detection portion of pipeline.

    Raised with a ``code`` and ``message`` when the wake word stage cannot be
    prepared (no engine or no entity) or when detection fails unexpectedly.
    """
class WakeWordTimeoutError(WakeWordDetectionError):
    """Timeout when wake word was not detected.

    Raised from the audio stream fed to the wake word engine once the
    voice-activity timeout reports that the silence budget is used up.
    """
class SpeechToTextError(PipelineError): class SpeechToTextError(PipelineError):
"""Error in speech-to-text portion of pipeline.""" """Error in speech-to-text portion of pipeline."""

View File

@ -2,7 +2,7 @@
"domain": "assist_pipeline", "domain": "assist_pipeline",
"name": "Assist pipeline", "name": "Assist pipeline",
"codeowners": ["@balloob", "@synesthesiam"], "codeowners": ["@balloob", "@synesthesiam"],
"dependencies": ["conversation", "stt", "tts"], "dependencies": ["conversation", "stt", "tts", "wake_word"],
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline", "documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
"iot_class": "local_push", "iot_class": "local_push",
"quality_scale": "internal", "quality_scale": "internal",

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncIterable, Callable, Iterable from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from enum import StrEnum from enum import StrEnum
import logging import logging
@ -10,7 +10,14 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation, media_source, stt, tts, websocket_api from homeassistant.components import (
conversation,
media_source,
stt,
tts,
wake_word,
websocket_api,
)
from homeassistant.components.tts.media_source import ( from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id, generate_media_source_id as tts_generate_media_source_id,
) )
@ -39,7 +46,10 @@ from .error import (
PipelineNotFound, PipelineNotFound,
SpeechToTextError, SpeechToTextError,
TextToSpeechError, TextToSpeechError,
WakeWordDetectionError,
WakeWordTimeoutError,
) )
from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -241,6 +251,8 @@ class PipelineEventType(StrEnum):
RUN_START = "run-start" RUN_START = "run-start"
RUN_END = "run-end" RUN_END = "run-end"
WAKE_WORD_START = "wake_word-start"
WAKE_WORD_END = "wake_word-end"
STT_START = "stt-start" STT_START = "stt-start"
STT_END = "stt-end" STT_END = "stt-end"
INTENT_START = "intent-start" INTENT_START = "intent-start"
@ -297,12 +309,14 @@ class Pipeline:
class PipelineStage(StrEnum): class PipelineStage(StrEnum):
"""Stages of a pipeline.""" """Stages of a pipeline."""
WAKE_WORD = "wake_word"
STT = "stt" STT = "stt"
INTENT = "intent" INTENT = "intent"
TTS = "tts" TTS = "tts"
PIPELINE_STAGE_ORDER = [ PIPELINE_STAGE_ORDER = [
PipelineStage.WAKE_WORD,
PipelineStage.STT, PipelineStage.STT,
PipelineStage.INTENT, PipelineStage.INTENT,
PipelineStage.TTS, PipelineStage.TTS,
@ -327,6 +341,17 @@ class InvalidPipelineStagesError(PipelineRunValidationError):
) )
@dataclass(frozen=True)
class WakeWordSettings:
    """Settings for wake word detection.

    A pipeline run that is given no settings falls back to a default instance
    of this class (no timeout, no audio buffering).
    """

    # None or a value <= 0 means no timeout is applied during detection.
    timeout: float | None = None
    """Seconds of silence before detection times out."""

    # Converted to a byte count assuming 16-bit samples at 16Khz; 0 disables
    # buffering.
    audio_seconds_to_buffer: float = 0
    """Seconds of audio to buffer before detection and forward to STT."""
@dataclass @dataclass
class PipelineRun: class PipelineRun:
"""Running context for a pipeline.""" """Running context for a pipeline."""
@ -341,17 +366,20 @@ class PipelineRun:
runner_data: Any | None = None runner_data: Any | None = None
intent_agent: str | None = None intent_agent: str | None = None
tts_audio_output: str | None = None tts_audio_output: str | None = None
wake_word_settings: WakeWordSettings | None = None
id: str = field(default_factory=ulid_util.ulid) id: str = field(default_factory=ulid_util.ulid)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False) stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
tts_engine: str = field(init=False) tts_engine: str = field(init=False)
tts_options: dict | None = field(init=False, default=None) tts_options: dict | None = field(init=False, default=None)
wake_word_engine: str = field(init=False)
wake_word_provider: wake_word.WakeWordDetectionEntity = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Set language for pipeline.""" """Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language self.language = self.pipeline.language or self.hass.config.language
# stt -> intent -> tts # wake -> stt -> intent -> tts
if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index( if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
self.start_stage self.start_stage
): ):
@ -393,6 +421,141 @@ class PipelineRun:
) )
) )
async def prepare_wake_word_detection(self) -> None:
    """Resolve the wake word engine and entity for this run.

    Stores the results in ``wake_word_engine`` and ``wake_word_provider``.

    Raises:
        WakeWordDetectionError: with code ``wake-engine-missing`` when there
            is no default engine, or ``wake-provider-missing`` when the
            engine has no matching entity.
    """
    # The engine is not read from the pipeline store yet (still to be added),
    # so the default engine is always used.
    if (engine := wake_word.async_default_engine(self.hass)) is None:
        raise WakeWordDetectionError(
            code="wake-engine-missing",
            message="No wake word engine",
        )

    provider = wake_word.async_get_wake_word_detection_entity(self.hass, engine)
    if provider is None:
        raise WakeWordDetectionError(
            code="wake-provider-missing",
            message=f"No wake-word-detection provider for: {engine}",
        )

    # Only commit to the run once both lookups have succeeded.
    self.wake_word_engine, self.wake_word_provider = engine, provider
async def wake_word_detection(
    self,
    stream: AsyncIterable[bytes],
    audio_buffer: list[bytes],
) -> wake_word.DetectionResult | None:
    """Run wake-word-detection portion of pipeline. Returns detection result.

    Fires a WAKE_WORD_START event, feeds timestamped audio from ``stream`` to
    the wake word provider, then fires a WAKE_WORD_END event with the result.

    Args:
        stream: Audio chunks; the timestamp and buffer arithmetic below treat
            them as 16-bit 16Khz samples.
        audio_buffer: Extended in place with the audio kept from before the
            detection, followed by any audio the provider reports as queued.

    Returns:
        The provider's detection result, or None when it reported none.

    Raises:
        WakeWordTimeoutError: the voice-activity timeout was reached.
        WakeWordDetectionError: any other exception raised during detection.
    """
    # The metadata is fixed; it is only reported in the start event.
    metadata_dict = asdict(
        stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        )
    )

    # Remove language since it doesn't apply to wake words yet
    metadata_dict.pop("language", None)

    self.process_event(
        PipelineEvent(
            PipelineEventType.WAKE_WORD_START,
            {
                "engine": self.wake_word_engine,
                "metadata": metadata_dict,
            },
        )
    )

    # Fall back to defaults (no timeout, no buffering) when unset.
    wake_word_settings = self.wake_word_settings or WakeWordSettings()

    wake_word_vad: VoiceActivityTimeout | None = None
    if (wake_word_settings.timeout is not None) and (
        wake_word_settings.timeout > 0
    ):
        # Use VAD to determine timeout
        wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)

    # Audio chunk buffer.
    # seconds * 16000 samples/second * 2 bytes/sample
    audio_bytes_to_buffer = int(
        wake_word_settings.audio_seconds_to_buffer * 16000 * 2
    )
    audio_ring_buffer = b""

    async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
        """Yield audio with timestamps (milliseconds since start of stream)."""
        nonlocal audio_ring_buffer

        timestamp_ms = 0
        async for chunk in stream:
            # NOTE: everything after this yield only runs once the consumer
            # asks for the next chunk.
            yield chunk, timestamp_ms
            timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz

            # Keeping audio right before wake word detection allows the
            # voice command to be spoken immediately after the wake word.
            if audio_bytes_to_buffer > 0:
                audio_ring_buffer += chunk
                if len(audio_ring_buffer) > audio_bytes_to_buffer:
                    # A proper ring buffer would be far more efficient
                    audio_ring_buffer = audio_ring_buffer[
                        len(audio_ring_buffer) - audio_bytes_to_buffer :
                    ]

            # process() returning False means the timeout was reached.
            if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
                raise WakeWordTimeoutError(
                    code="wake-word-timeout", message="Wake word was not detected"
                )

    try:
        # Detect wake word(s)
        result = await self.wake_word_provider.async_process_audio_stream(
            timestamped_stream()
        )

        if audio_ring_buffer:
            # All audio kept from right before the wake word was detected as
            # a single chunk.
            audio_buffer.append(audio_ring_buffer)
    except WakeWordTimeoutError:
        # Timeouts are re-raised unchanged rather than wrapped below.
        _LOGGER.debug("Timeout during wake word detection")
        raise
    except Exception as src_error:
        _LOGGER.exception("Unexpected error during wake-word-detection")
        raise WakeWordDetectionError(
            code="wake-stream-failed",
            message="Unexpected error during wake-word-detection",
        ) from src_error

    _LOGGER.debug("wake-word-detection result %s", result)

    if result is None:
        # No detection: the end event carries an empty output.
        wake_word_output: dict[str, Any] = {}
    else:
        if result.queued_audio:
            # Add audio that was pending at detection
            for chunk_ts in result.queued_audio:
                # Each entry is (audio bytes, timestamp); only the bytes are kept.
                audio_buffer.append(chunk_ts[0])

        wake_word_output = asdict(result)

        # Remove non-JSON fields
        wake_word_output.pop("queued_audio", None)

    self.process_event(
        PipelineEvent(
            PipelineEventType.WAKE_WORD_END,
            {"wake_word_output": wake_word_output},
        )
    )

    return result
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None: async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
"""Prepare speech-to-text.""" """Prepare speech-to-text."""
# pipeline.stt_engine can't be None or this function is not called # pipeline.stt_engine can't be None or this function is not called
@ -443,9 +606,21 @@ class PipelineRun:
) )
try: try:
segmenter = VoiceCommandSegmenter()
async def segment_stream(
stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]:
"""Stop stream when voice command is finished."""
async for chunk in stream:
if not segmenter.process(chunk):
break
yield chunk
# Transcribe audio stream # Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream( result = await self.stt_provider.async_process_audio_stream(
metadata, stream metadata, segment_stream(stream)
) )
except Exception as src_error: except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text") _LOGGER.exception("Unexpected error during speech-to-text")
@ -663,17 +838,45 @@ class PipelineInput:
async def execute(self) -> None: async def execute(self) -> None:
"""Run pipeline.""" """Run pipeline."""
self.run.start() self.run.start()
current_stage = self.run.start_stage current_stage: PipelineStage | None = self.run.start_stage
audio_buffer: list[bytes] = []
try: try:
if current_stage == PipelineStage.WAKE_WORD:
assert self.stt_stream is not None
detect_result = await self.run.wake_word_detection(
self.stt_stream, audio_buffer
)
if detect_result is None:
# No wake word. Abort the rest of the pipeline.
self.run.end()
return
current_stage = PipelineStage.STT
# speech-to-text # speech-to-text
intent_input = self.intent_input intent_input = self.intent_input
if current_stage == PipelineStage.STT: if current_stage == PipelineStage.STT:
assert self.stt_metadata is not None assert self.stt_metadata is not None
assert self.stt_stream is not None assert self.stt_stream is not None
if audio_buffer:
async def buffered_stream() -> AsyncGenerator[bytes, None]:
for chunk in audio_buffer:
yield chunk
assert self.stt_stream is not None
async for chunk in self.stt_stream:
yield chunk
stt_stream = cast(AsyncIterable[bytes], buffered_stream())
else:
stt_stream = self.stt_stream
intent_input = await self.run.speech_to_text( intent_input = await self.run.speech_to_text(
self.stt_metadata, self.stt_metadata,
self.stt_stream, stt_stream,
) )
current_stage = PipelineStage.INTENT current_stage = PipelineStage.INTENT
@ -707,7 +910,7 @@ class PipelineInput:
async def validate(self) -> None: async def validate(self) -> None:
"""Validate pipeline input against start stage.""" """Validate pipeline input against start stage."""
if self.run.start_stage == PipelineStage.STT: if self.run.start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
if self.run.pipeline.stt_engine is None: if self.run.pipeline.stt_engine is None:
raise PipelineRunValidationError( raise PipelineRunValidationError(
"the pipeline does not support speech-to-text" "the pipeline does not support speech-to-text"
@ -741,6 +944,13 @@ class PipelineInput:
prepare_tasks = [] prepare_tasks = []
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.WAKE_WORD)
<= end_stage_index
):
prepare_tasks.append(self.run.prepare_wake_word_detection())
if ( if (
start_stage_index start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.STT) <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT)

View File

@ -88,7 +88,7 @@ class VoiceCommandSegmenter:
self.in_command = False self.in_command = False
def process(self, samples: bytes) -> bool: def process(self, samples: bytes) -> bool:
"""Process a 16-bit 16Khz mono audio samples. """Process 16-bit 16Khz mono audio samples.
Returns False when command is done. Returns False when command is done.
""" """
@ -148,3 +148,94 @@ class VoiceCommandSegmenter:
self._silence_seconds_left = self.silence_seconds self._silence_seconds_left = self.silence_seconds
return True return True
@dataclass
class VoiceActivityTimeout:
    """Detects silence in audio until a timeout is reached."""

    silence_seconds: float
    """Seconds of silence before timeout."""

    reset_seconds: float = 0.5
    """Seconds of speech before resetting timeout."""

    vad_mode: int = 3
    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

    vad_frames: int = 480  # 30 ms
    """Must be 10, 20, or 30 ms at 16Khz."""

    _silence_seconds_left: float = 0.0
    """Seconds left before considering voice command as stopped."""

    _reset_seconds_left: float = 0.0
    """Seconds left before resetting start/stop time counters."""

    _vad: webrtcvad.Vad = None
    _audio_buffer: bytes = field(default_factory=bytes)
    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
    _seconds_per_chunk: float = 0.03  # 30 ms

    def __post_init__(self) -> None:
        """Create the VAD and derive the chunk sizes from ``vad_frames``."""
        self._vad = webrtcvad.Vad(self.vad_mode)
        # Two bytes per 16-bit sample.
        self._bytes_per_chunk = self.vad_frames * 2
        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
        self.reset()

    def reset(self) -> None:
        """Clear buffered audio and restore both countdowns to their maximum."""
        self._audio_buffer = b""
        self._silence_seconds_left = self.silence_seconds
        self._reset_seconds_left = self.reset_seconds

    def process(self, samples: bytes) -> bool:
        """Feed 16-bit 16Khz mono audio samples to the timeout detector.

        Audio is accumulated and examined in whole VAD-sized chunks; a
        trailing partial chunk is kept for the next call.

        Returns False when timeout is reached.
        """
        self._audio_buffer += samples

        chunk_size = self._bytes_per_chunk
        num_chunks = len(self._audio_buffer) // chunk_size

        # Walk the buffer one whole chunk at a time.
        for start in range(0, num_chunks * chunk_size, chunk_size):
            if not self._process_chunk(self._audio_buffer[start : start + chunk_size]):
                return False

        if num_chunks > 0:
            # Keep only the unprocessed remainder.
            self._audio_buffer = self._audio_buffer[num_chunks * chunk_size :]

        return True

    def _process_chunk(self, chunk: bytes) -> bool:
        """Update the countdowns for one VAD-sized chunk of audio.

        Returns False when timeout is reached.
        """
        step = self._seconds_per_chunk

        if self._vad.is_speech(chunk, _SAMPLE_RATE):
            # Speech: once enough of it has been heard, the silence
            # countdown starts over.
            self._reset_seconds_left -= step
            if self._reset_seconds_left <= 0:
                self._silence_seconds_left = self.silence_seconds
            return True

        # Silence: use up the silence budget.
        self._silence_seconds_left -= step
        if self._silence_seconds_left <= 0:
            # Timeout reached
            return False

        # Gradually restore the speech requirement, capped at its maximum.
        self._reset_seconds_left = min(
            self.reset_seconds, self._reset_seconds_left + step
        )
        return True

View File

@ -26,11 +26,12 @@ from .pipeline import (
PipelineInput, PipelineInput,
PipelineRun, PipelineRun,
PipelineStage, PipelineStage,
WakeWordSettings,
async_get_pipeline, async_get_pipeline,
) )
from .vad import VoiceCommandSegmenter
DEFAULT_TIMEOUT = 30 DEFAULT_TIMEOUT = 30
DEFAULT_WAKE_WORD_TIMEOUT = 3
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -63,6 +64,18 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
cv.key_value_schemas( cv.key_value_schemas(
"start_stage", "start_stage",
{ {
PipelineStage.WAKE_WORD: vol.Schema(
{
vol.Required("input"): {
vol.Required("sample_rate"): int,
vol.Optional("timeout"): vol.Any(float, int),
vol.Optional("audio_seconds_to_buffer"): vol.Any(
float, int
),
}
},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.STT: vol.Schema( PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}}, {vol.Required("input"): {vol.Required("sample_rate"): int}},
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
@ -102,6 +115,7 @@ async def websocket_run(
end_stage = PipelineStage(msg["end_stage"]) end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None handler_id: int | None = None
unregister_handler: Callable[[], None] | None = None unregister_handler: Callable[[], None] | None = None
wake_word_settings: WakeWordSettings | None = None
# Arguments to PipelineInput # Arguments to PipelineInput
input_args: dict[str, Any] = { input_args: dict[str, Any] = {
@ -109,24 +123,26 @@ async def websocket_run(
"device_id": msg.get("device_id"), "device_id": msg.get("device_id"),
} }
if start_stage == PipelineStage.STT: if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
# Audio pipeline that will receive audio as binary websocket messages # Audio pipeline that will receive audio as binary websocket messages
audio_queue: asyncio.Queue[bytes] = asyncio.Queue() audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
incoming_sample_rate = msg["input"]["sample_rate"] incoming_sample_rate = msg["input"]["sample_rate"]
if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
audio_seconds_to_buffer=msg["input"].get("audio_seconds_to_buffer", 0),
)
async def stt_stream() -> AsyncGenerator[bytes, None]: async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None state = None
segmenter = VoiceCommandSegmenter()
# Yield until we receive an empty chunk # Yield until we receive an empty chunk
while chunk := await audio_queue.get(): while chunk := await audio_queue.get():
chunk, state = audioop.ratecv( if incoming_sample_rate != 16000:
chunk, 2, 1, incoming_sample_rate, 16000, state chunk, state = audioop.ratecv(
) chunk, 2, 1, incoming_sample_rate, 16000, state
if not segmenter.process(chunk): )
# Voice command is finished
break
yield chunk yield chunk
def handle_binary( def handle_binary(
@ -169,6 +185,7 @@ async def websocket_run(
"stt_binary_handler_id": handler_id, "stt_binary_handler_id": handler_id,
"timeout": timeout, "timeout": timeout,
}, },
wake_word_settings=wake_word_settings,
) )
pipeline_input = PipelineInput(**input_args) pipeline_input = PipelineInput(**input_args)

View File

@ -0,0 +1,119 @@
"""Provide functionality for wake word detection."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import AsyncIterable
import logging
from typing import final
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util
from .const import DOMAIN
from .models import DetectionResult, WakeWord
__all__ = [
"async_default_engine",
"async_get_wake_word_detection_entity",
"DetectionResult",
"DOMAIN",
"WakeWord",
"WakeWordDetectionEntity",
]
_LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
    """Return the entity id of the default engine.

    The default is the first entity id listed for this domain in the state
    machine; None is returned when there is none.
    """
    for entity_id in hass.states.async_entity_ids(DOMAIN):
        # The first id listed is the default.
        return entity_id
    return None
@callback
def async_get_wake_word_detection_entity(
    hass: HomeAssistant, entity_id: str
) -> WakeWordDetectionEntity | None:
    """Look up a wake word entity by entity id.

    Returns whatever the domain's entity component reports for ``entity_id``.
    """
    wake_word_component: EntityComponent = hass.data[DOMAIN]
    entity = wake_word_component.get_entity(entity_id)
    return entity
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the wake word integration.

    Creates the entity component for the domain, stores it in ``hass.data``
    under ``DOMAIN`` and calls ``register_shutdown()`` on it.
    """
    component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass)
    component.register_shutdown()

    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Forward setup of a config entry to the domain's entity component."""
    wake_word_component: EntityComponent = hass.data[DOMAIN]
    was_set_up = await wake_word_component.async_setup_entry(entry)
    return was_set_up
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Forward unloading of a config entry to the domain's entity component."""
    wake_word_component: EntityComponent = hass.data[DOMAIN]
    was_unloaded = await wake_word_component.async_unload_entry(entry)
    return was_unloaded
class WakeWordDetectionEntity(RestoreEntity):
    """Base entity for a single wake word provider.

    The entity state is the ISO timestamp of the most recent call to
    ``async_process_audio_stream``; it is restored when the entity is added.
    """

    _attr_should_poll = False
    __last_processed: str | None = None

    @property
    @final
    def state(self) -> str | None:
        """Return the timestamp of the last processed stream, if any."""
        return self.__last_processed

    @property
    @abstractmethod
    def supported_wake_words(self) -> list[WakeWord]:
        """Return the wake words this provider supports."""

    @abstractmethod
    async def _async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]]
    ) -> DetectionResult | None:
        """Detect wake word(s) in a stream of (audio, timestamp) pairs.

        Audio must be 16Khz sample rate with 16-bit mono PCM samples.
        """

    async def async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]]
    ) -> DetectionResult | None:
        """Record the processing time, then detect wake word(s) in the stream.

        Audio must be 16Khz sample rate with 16-bit mono PCM samples.
        """
        # The state is written before detection starts, not when it ends.
        started_at = dt_util.utcnow()
        self.__last_processed = started_at.isoformat()
        self.async_write_ha_state()

        return await self._async_process_audio_stream(stream)

    async def async_internal_added_to_hass(self) -> None:
        """Restore the last processed timestamp when added to hass."""
        await super().async_internal_added_to_hass()

        last_state = await self.async_get_last_state()
        if last_state is None or last_state.state is None:
            return
        if last_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            # Placeholder states carry no timestamp worth restoring.
            return

        self.__last_processed = last_state.state

View File

@ -0,0 +1,2 @@
"""Wake word constants."""
DOMAIN = "wake_word"

View File

@ -0,0 +1,8 @@
{
"domain": "wake_word",
"name": "Wake-word detection",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"documentation": "https://www.home-assistant.io/integrations/wake_word",
"integration_type": "entity",
"quality_scale": "internal"
}

View File

@ -0,0 +1,24 @@
"""Wake word models."""
from dataclasses import dataclass
@dataclass(frozen=True)
class WakeWord:
    """Wake word model.

    Immutable description of one wake word a provider supports.
    """

    # Identifier of the wake word.
    ww_id: str
    # Name of the wake word.
    name: str
@dataclass
class DetectionResult:
    """Result of wake word detection."""

    ww_id: str
    """Id of detected wake word"""

    # May be None.
    timestamp: int | None
    """Timestamp of audio chunk with detected wake word"""

    # Each entry is an (audio bytes, timestamp) pair.
    queued_audio: list[tuple[bytes, int]] | None = None
    """Audio chunks that were queued when wake word was detected."""

View File

@ -50,14 +50,21 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors={"base": "cannot_connect"}, errors={"base": "cannot_connect"},
) )
# ASR = automated speech recognition (STT) # ASR = automated speech recognition (speech-to-text)
asr_installed = [asr for asr in service.info.asr if asr.installed] asr_installed = [asr for asr in service.info.asr if asr.installed]
# TTS = text-to-speech
tts_installed = [tts for tts in service.info.tts if tts.installed] tts_installed = [tts for tts in service.info.tts if tts.installed]
# wake-word-detection
wake_installed = [wake for wake in service.info.wake if wake.installed]
if asr_installed: if asr_installed:
name = asr_installed[0].name name = asr_installed[0].name
elif tts_installed: elif tts_installed:
name = tts_installed[0].name name = tts_installed[0].name
elif wake_installed:
name = wake_installed[0].name
else: else:
return self.async_abort(reason="no_services") return self.async_abort(reason="no_services")

View File

@ -29,6 +29,8 @@ class WyomingService:
platforms.append(Platform.STT) platforms.append(Platform.STT)
if any(tts.installed for tts in info.tts): if any(tts.installed for tts in info.tts):
platforms.append(Platform.TTS) platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake):
platforms.append(Platform.WAKE_WORD)
self.platforms = platforms self.platforms = platforms
@classmethod @classmethod

View File

@ -0,0 +1,157 @@
"""Support for Wyoming wake-word-detection services."""
import asyncio
from collections.abc import AsyncIterable
import logging
from wyoming.audio import AudioChunk, AudioStart
from wyoming.client import AsyncTcpClient
from wyoming.wake import Detection
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .data import WyomingService
from .error import WyomingError
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Create the Wyoming wake word entity for a config entry."""
    # The service object is looked up by config entry id.
    wyoming_service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
    provider = WyomingWakeWordProvider(config_entry, wyoming_service)
    async_add_entities([provider])
class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
    """Wyoming wake-word-detection provider."""

    def __init__(
        self,
        config_entry: ConfigEntry,
        service: WyomingService,
    ) -> None:
        """Set up provider.

        The first wake service reported by the Wyoming service supplies the
        entity name and the list of supported wake words.
        """
        self.service = service
        wake_service = service.info.wake[0]

        # The model name is used as both the id and the display name.
        self._supported_wake_words = [
            wake_word.WakeWord(ww_id=ww.name, name=ww.name)
            for ww in wake_service.models
        ]
        self._attr_name = wake_service.name
        self._attr_unique_id = f"{config_entry.entry_id}-wake_word"

    @property
    def supported_wake_words(self) -> list[wake_word.WakeWord]:
        """Return a list of supported wake words."""
        return self._supported_wake_words

    async def _async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]]
    ) -> wake_word.DetectionResult | None:
        """Try to detect one or more wake words in an audio stream.

        Audio must be 16Khz sample rate with 16-bit mono PCM samples.

        Audio chunks are forwarded to the Wyoming service while its events are
        read concurrently. On a detection event the result is returned
        together with the audio chunk that was already read from ``stream``
        but not forwarded, so the caller can pass it on.

        Returns None when the stream ends, the connection is lost, or an
        OSError/WyomingError occurs (the error is logged).
        """

        async def next_chunk() -> tuple[bytes, int] | None:
            """Get the next chunk from audio stream, or None at its end."""
            async for chunk_info in stream:
                return chunk_info

            return None

        try:
            async with AsyncTcpClient(self.service.host, self.service.port) as client:
                await client.write_event(
                    AudioStart(
                        rate=16000,
                        width=2,
                        channels=1,
                    ).event(),
                )

                # Read audio and wake events in "parallel"
                audio_task = asyncio.create_task(next_chunk())
                wake_task = asyncio.create_task(client.read_event())
                pending = {audio_task, wake_task}

                try:
                    while True:
                        done, pending = await asyncio.wait(
                            pending, return_when=asyncio.FIRST_COMPLETED
                        )

                        if wake_task in done:
                            event = wake_task.result()
                            if event is None:
                                _LOGGER.debug("Connection lost")
                                break

                            if Detection.is_type(event.type):
                                # Successful detection
                                detection = Detection.from_event(event)
                                _LOGGER.info(detection)

                                # Let an in-flight read finish so the shared
                                # audio stream is not left mid-read.
                                if audio_task in pending:
                                    await audio_task
                                    pending.remove(audio_task)

                                # Keep the chunk whether the read finished
                                # just now or in the same wake-up as the
                                # detection; otherwise it would be lost.
                                # The end-of-stream None is never queued.
                                queued_chunk = audio_task.result()
                                queued_audio: list[tuple[bytes, int]] | None = (
                                    None if queued_chunk is None else [queued_chunk]
                                )

                                return wake_word.DetectionResult(
                                    ww_id=detection.name,
                                    timestamp=detection.timestamp,
                                    queued_audio=queued_audio,
                                )

                            # Next event
                            wake_task = asyncio.create_task(client.read_event())
                            pending.add(wake_task)

                        if audio_task in done:
                            # Forward audio to wake service
                            chunk_info = audio_task.result()
                            if chunk_info is None:
                                # End of the audio stream without a detection.
                                break

                            chunk_bytes, chunk_timestamp = chunk_info
                            chunk = AudioChunk(
                                rate=16000,
                                width=2,
                                channels=1,
                                audio=chunk_bytes,
                                timestamp=chunk_timestamp,
                            )
                            await client.write_event(chunk.event())

                            # Next chunk
                            audio_task = asyncio.create_task(next_chunk())
                            pending.add(audio_task)
                finally:
                    # Clean up
                    if audio_task in pending:
                        # It's critical that we don't cancel the audio task or
                        # leave it hanging. This would mess up the pipeline STT
                        # by stopping the audio stream.
                        await audio_task
                        pending.remove(audio_task)

                    for task in pending:
                        task.cancel()

        except (OSError, WyomingError) as err:
            _LOGGER.exception("Error processing audio stream: %s", err)

        return None

View File

@ -57,6 +57,7 @@ class Platform(StrEnum):
TTS = "tts" TTS = "tts"
VACUUM = "vacuum" VACUUM = "vacuum"
UPDATE = "update" UPDATE = "update"
WAKE_WORD = "wake_word"
WATER_HEATER = "water_heater" WATER_HEATER = "water_heater"
WEATHER = "weather" WEATHER = "weather"

View File

@ -7,7 +7,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from homeassistant.components import stt, tts from homeassistant.components import stt, tts, wake_word
from homeassistant.components.assist_pipeline import DOMAIN from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import ( from homeassistant.components.assist_pipeline.pipeline import (
PipelineData, PipelineData,
@ -174,6 +174,40 @@ class MockSttPlatform(MockPlatform):
self.async_get_engine = async_get_engine self.async_get_engine = async_get_engine
class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
    """Wake word entity double that triggers on a magic audio chunk.

    A detection is reported as soon as a chunk equal to ``b"wake word"``
    arrives; the result always carries one fixed chunk of queued audio.
    """

    fail_process_audio = False
    url_path = "wake_word.test"
    _attr_name = "test"

    @property
    def supported_wake_words(self) -> list[wake_word.WakeWord]:
        """Return the single wake word known to this mock."""
        only_word = wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")
        return [only_word]

    async def _async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]]
    ) -> wake_word.DetectionResult | None:
        """Scan timestamped audio chunks for the magic wake word chunk.

        Returns a detection stamped with the matching chunk's timestamp, or
        None when the stream ends without a match.
        """
        async for audio, chunk_time in stream:
            if audio != b"wake word":
                continue

            return wake_word.DetectionResult(
                ww_id=self.supported_wake_words[0].ww_id,
                timestamp=chunk_time,
                queued_audio=[(b"queued audio", 0)],
            )

        # Stream ended without seeing the magic chunk
        return None
@pytest.fixture
async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity:
    """Provide a fresh mock wake word entity."""
    entity = MockWakeWordEntity()
    return entity
class MockFlow(ConfigFlow): class MockFlow(ConfigFlow):
"""Test flow.""" """Test flow."""
@ -193,6 +227,7 @@ async def init_supporting_components(
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSttProviderEntity,
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
config_flow_fixture, config_flow_fixture,
): ):
"""Initialize relevant components with empty configs.""" """Initialize relevant components with empty configs."""
@ -201,14 +236,18 @@ async def init_supporting_components(
hass: HomeAssistant, config_entry: ConfigEntry hass: HomeAssistant, config_entry: ConfigEntry
) -> bool: ) -> bool:
"""Set up test config entry.""" """Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN) await hass.config_entries.async_forward_entry_setups(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
)
return True return True
async def async_unload_entry_init( async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry hass: HomeAssistant, config_entry: ConfigEntry
) -> bool: ) -> bool:
"""Unload up test config entry.""" """Unload up test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN) await hass.config_entries.async_unload_platforms(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
)
return True return True
async def async_setup_entry_stt_platform( async def async_setup_entry_stt_platform(
@ -219,6 +258,14 @@ async def init_supporting_components(
"""Set up test stt platform via config entry.""" """Set up test stt platform via config entry."""
async_add_entities([mock_stt_provider_entity]) async_add_entities([mock_stt_provider_entity])
async def async_setup_entry_wake_word_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test wake word platform via config entry."""
async_add_entities([mock_wake_word_provider_entity])
mock_integration( mock_integration(
hass, hass,
MockModule( MockModule(
@ -242,11 +289,19 @@ async def init_supporting_components(
async_setup_entry=async_setup_entry_stt_platform, async_setup_entry=async_setup_entry_stt_platform,
), ),
) )
mock_platform(
hass,
"test.wake_word",
MockPlatform(
async_setup_entry=async_setup_entry_wake_word_platform,
),
)
mock_platform(hass, "test.config_flow") mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
# assert await async_setup_component(hass, wake_word.DOMAIN, {"wake_word": {}})
assert await async_setup_component(hass, "media_source", {}) assert await async_setup_component(hass, "media_source", {})
config_entry = MockConfigEntry(domain="test") config_entry = MockConfigEntry(domain="test")

View File

@ -266,3 +266,114 @@
}), }),
]) ])
# --- # ---
# name: test_pipeline_from_audio_stream_wake_word
list([
dict({
'data': dict({
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),
dict({
'data': dict({
'wake_word_output': dict({
'timestamp': 2000,
'ww_id': 'test_ww',
}),
}),
'type': <PipelineEventType.WAKE_WORD_END: 'wake_word-end'>,
}),
dict({
'data': dict({
'engine': 'test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'language': 'en-US',
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
}),
'type': <PipelineEventType.STT_START: 'stt-start'>,
}),
dict({
'data': dict({
'stt_output': dict({
'text': 'test transcript',
}),
}),
'type': <PipelineEventType.STT_END: 'stt-end'>,
}),
dict({
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---

View File

@ -155,6 +155,243 @@
}), }),
}) })
# --- # ---
# name: test_audio_pipeline_no_wake_word_engine
dict({
'code': 'wake-engine-missing',
'message': 'No wake word engine',
})
# ---
# name: test_audio_pipeline_no_wake_word_entity
dict({
'code': 'wake-provider-missing',
'message': 'No wake-word-detection provider for: wake_word.bad-entity-id',
})
# ---
# name: test_audio_pipeline_with_wake_word
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.1
dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.2
dict({
'wake_word_output': dict({
'queued_audio': None,
'timestamp': 1000,
'ww_id': 'test_ww',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.3
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.4
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.5
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
})
# ---
# name: test_audio_pipeline_with_wake_word.6
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_with_wake_word.7
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_with_wake_word.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.1
dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.2
dict({
'wake_word_output': dict({
'timestamp': 0,
'ww_id': 'test_ww',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.3
dict({
'engine': 'test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.4
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.5
dict({
'conversation_id': None,
'device_id': None,
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en',
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.6
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.7
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.1
dict({
'engine': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline_with_wake_word_timeout.2
dict({
'code': 'wake-word-timeout',
'message': 'Wake word was not detected',
})
# ---
# name: test_intent_failed # name: test_intent_failed
dict({ dict({
'language': 'en', 'language': 'en',

View File

@ -1,5 +1,6 @@
"""Test Voice Assistant init.""" """Test Voice Assistant init."""
from dataclasses import asdict from dataclasses import asdict
import itertools as it
from unittest.mock import ANY from unittest.mock import ANY
import pytest import pytest
@ -8,10 +9,12 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt from homeassistant.components import assist_pipeline, stt
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from .conftest import MockSttProvider, MockSttProviderEntity from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
BYTES_ONE_SECOND = 16000 * 2
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values.""" """Process events to remove dynamic values."""
@ -280,3 +283,61 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
) )
assert not events assert not events
async def test_pipeline_from_audio_stream_wake_word(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test running a pipeline from an audio stream that starts at wake word.

    Two one-second chunks precede the wake word. With 1.5 seconds of audio
    buffered, speech-to-text is expected to receive the last half of the first
    chunk plus the whole second chunk, then the audio queued by the mock wake
    word entity, then the remaining stream chunks.
    """
    recorded_events = []

    # One second of the repeating byte pattern 0, 1, 2, ...
    first_second = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND))

    # One second of the repeating byte pattern 0, 2, 4, ...
    second_second = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))

    stream_chunks = (
        first_second,
        second_second,
        b"wake word",
        b"part1",
        b"part2",
        b"",
    )

    async def audio_data():
        for piece in stream_chunks:
            yield piece

    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    settings = assist_pipeline.WakeWordSettings(audio_seconds_to_buffer=1.5)

    await assist_pipeline.async_pipeline_from_audio_stream(
        hass,
        Context(),
        recorded_events.append,
        metadata,
        audio_data(),
        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
        wake_word_settings=settings,
    )

    assert process_events(recorded_events) == snapshot

    # Expected STT input:
    # 1. Half of the first chunk + all of the second chunk
    # 2. Queued audio (from the mock wake word entity)
    # 3. part1
    # 4. part2
    assert len(mock_stt_provider.received) == 4

    midpoint = len(first_second) // 2
    buffered_audio = mock_stt_provider.received[0]
    assert buffered_audio == first_second[midpoint:] + second_second

    assert mock_stt_provider.received[1:] == [b"queued audio", b"part1", b"part2"]

View File

@ -167,6 +167,224 @@ async def test_audio_pipeline(
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_audio_pipeline_with_wake_word_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test timeout from a pipeline run with audio input/output + wake word.

The run starts at the wake word stage with a timeout of 1. Two seconds of
silence are then sent and an error event is expected instead of a detection.
"""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
# NOTE(review): presumably seconds, given the 2 seconds of silence below
"timeout": 1,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
# Replace the pipeline value with ANY before comparing against the snapshot
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# 2 seconds of silence: one leading byte with value 1, then
# 16000 * 2 * 2 zero bytes of audio
await client.send_bytes(bytes([1]) + bytes(16000 * 2 * 2))
# Time out error
msg = await client.receive_json()
assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
async def test_audio_pipeline_with_wake_word_no_timeout(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with audio input/output + wake word with no timeout.

The run is started with a timeout of 0, the wake word chunk is sent, and
every event from run-start to run-end is checked in order. The collected
events are finally compared with what the pipeline debug command returns.
"""
events = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"timeout": 0,
},
}
)
# result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
# Replace the pipeline value with ANY before comparing against the snapshot
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# "audio": handler id byte followed by the wake word payload
await client.send_bytes(bytes([1]) + b"wake word")
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# stt
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([1]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text-to-speech
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] is None
events.append(msg["event"])
# Look up the first recorded pipeline run and fetch its debug events
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_runs)[0]
pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
# The debug command must return exactly the events seen over the websocket
assert msg["result"] == {"events": events}
async def test_audio_pipeline_no_wake_word_engine(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that a wake word pipeline run fails when there is no wake word engine."""
client = await hass_ws_client(hass)
# Make the default wake word engine lookup return nothing
with patch(
"homeassistant.components.wake_word.async_default_engine", return_value=None
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
},
}
)
# error
msg = await client.receive_json()
assert not msg["success"]
assert "error" in msg
assert msg["error"] == snapshot
async def test_audio_pipeline_no_wake_word_entity(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that a wake word pipeline run fails when the wake word entity is missing."""
client = await hass_ws_client(hass)
# The default engine names an entity id, but the entity lookup returns None
with patch(
"homeassistant.components.wake_word.async_default_engine",
return_value="wake_word.bad-entity-id",
), patch(
"homeassistant.components.wake_word.async_get_wake_word_detection_entity",
return_value=None,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
},
}
)
# error
msg = await client.receive_json()
assert not msg["success"]
assert "error" in msg
assert msg["error"] == snapshot
async def test_intent_timeout( async def test_intent_timeout(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,

View File

@ -0,0 +1 @@
"""Wake-word-detection tests."""

View File

@ -0,0 +1,29 @@
"""Provide common test tools for wake-word-detection."""
from __future__ import annotations
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from tests.common import MockPlatform, mock_platform
def mock_wake_word_entity_platform(
    hass: HomeAssistant,
    tmp_path: Path,
    integration: str,
    async_setup_entry: Callable[
        [HomeAssistant, ConfigEntry, AddEntitiesCallback],
        Coroutine[Any, Any, None],
    ]
    | None = None,
) -> MockPlatform:
    """Specialize the mock platform for wake word detection.

    Builds a MockPlatform around ``async_setup_entry`` and registers it as
    ``<integration>.<wake_word.DOMAIN>``. ``tmp_path`` is accepted but not
    used by this helper. Returns the registered platform.
    """
    loaded_platform = MockPlatform(async_setup_entry=async_setup_entry)
    mock_platform(hass, f"{integration}.{wake_word.DOMAIN}", loaded_platform)
    return loaded_platform

View File

@ -0,0 +1,11 @@
# serializer version: 1
# name: test_ws_detect
dict({
'event': dict({
'timestamp': 2048.0,
'ww_id': 'test_ww',
}),
'id': 1,
'type': 'event',
})
# ---

View File

@ -0,0 +1,226 @@
"""Test wake_word component setup."""
from collections.abc import AsyncIterable, Generator
from pathlib import Path
import pytest
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from .common import mock_wake_word_entity_platform
from tests.common import (
MockConfigEntry,
MockModule,
mock_config_flow,
mock_integration,
mock_platform,
mock_restore_cache,
)
TEST_DOMAIN = "test"
_SAMPLES_PER_CHUNK = 1024
_BYTES_PER_CHUNK = _SAMPLES_PER_CHUNK * 2 # 16-bit
_MS_PER_CHUNK = (_BYTES_PER_CHUNK // 2) // 16 # 16Khz
class MockProviderEntity(wake_word.WakeWordDetectionEntity):
    """Wake word entity double that detects based on chunk timestamps."""

    url_path = "wake_word.test"
    _attr_name = "test"

    @property
    def supported_wake_words(self) -> list[wake_word.WakeWord]:
        """Return the single wake word known to this mock."""
        only_word = wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")
        return [only_word]

    async def _async_process_audio_stream(
        self, stream: AsyncIterable[tuple[bytes, int]]
    ) -> wake_word.DetectionResult | None:
        """Report a detection at the first chunk stamped 2000 or later.

        The audio content is ignored. Returns None when the stream ends
        before such a chunk arrives.
        """
        async for _audio, chunk_time in stream:
            if chunk_time < 2000:
                continue

            detected_id = self.supported_wake_words[0].ww_id
            return wake_word.DetectionResult(ww_id=detected_id, timestamp=chunk_time)

        # Stream ended before the trigger timestamp
        return None
@pytest.fixture
def mock_provider_entity() -> MockProviderEntity:
    """Provide a fresh mock wake word provider entity."""
    entity = MockProviderEntity()
    return entity
class WakeWordFlow(ConfigFlow):
"""Empty config flow for the test domain."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
    """Keep the test config flow mocked while each test runs."""
    flow_platform = f"{TEST_DOMAIN}.config_flow"
    mock_platform(hass, flow_platform)

    with mock_config_flow(TEST_DOMAIN, WakeWordFlow):
        yield
@pytest.fixture(name="setup")
async def setup_fixture(
    hass: HomeAssistant,
    tmp_path: Path,
) -> MockProviderEntity:
    """Create a mock provider entity and load it through a config entry."""
    entity = MockProviderEntity()
    await mock_config_entry_setup(hass, tmp_path, entity)
    return entity
async def mock_config_entry_setup(
    hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> MockConfigEntry:
    """Set up a test wake word provider via config entry.

    Registers a mock integration for TEST_DOMAIN whose entry setup forwards to
    the wake_word platform, registers a mock wake_word platform that adds
    ``mock_provider_entity``, then creates and sets up a config entry.

    Returns the config entry after setup has completed.
    """

    async def async_setup_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Set up test config entry."""
        # Use the plural helper, matching the assist_pipeline test conftest
        await hass.config_entries.async_forward_entry_setups(
            config_entry, [wake_word.DOMAIN]
        )
        return True

    async def async_unload_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload test config entry."""
        await hass.config_entries.async_unload_platforms(
            config_entry, [wake_word.DOMAIN]
        )
        return True

    mock_integration(
        hass,
        MockModule(
            TEST_DOMAIN,
            async_setup_entry=async_setup_entry_init,
            async_unload_entry=async_unload_entry_init,
        ),
    )

    async def async_setup_entry_platform(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Set up test wake word platform via config entry."""
        async_add_entities([mock_provider_entity])

    mock_wake_word_entity_platform(
        hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform
    )

    config_entry = MockConfigEntry(domain=TEST_DOMAIN)
    config_entry.add_to_hass(hass)
    assert await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    return config_entry
async def test_config_entry_unload(
    hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
    """Test that a loaded config entry can be unloaded again."""
    entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
    entry_id = entry.entry_id

    assert entry.state == ConfigEntryState.LOADED

    await hass.config_entries.async_unload(entry_id)

    assert entry.state == ConfigEntryState.NOT_LOADED
async def test_detected_entity(
    hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity
) -> None:
    """Test that the entity reports a detection on a long enough stream."""

    async def three_second_stream():
        # Silent chunks stamped 0, _MS_PER_CHUNK, ... up to (not including) 3000
        for stamp in range(0, 3000, _MS_PER_CHUNK):
            yield bytes(_BYTES_PER_CHUNK), stamp

    # The mock triggers on the first timestamp at or past 2000
    detection = await setup.async_process_audio_stream(three_second_stream())
    expected = wake_word.DetectionResult("test_ww", 2048)
    assert detection == expected
async def test_not_detected_entity(
    hass: HomeAssistant, setup: MockProviderEntity
) -> None:
    """Test that the entity reports nothing on a stream that is too short."""

    async def one_second_stream():
        # Silent chunks stamped 0, _MS_PER_CHUNK, ... up to (not including) 1000
        for stamp in range(0, 1000, _MS_PER_CHUNK):
            yield bytes(_BYTES_PER_CHUNK), stamp

    # The mock needs a timestamp at or past 2000, which never arrives
    detection = await setup.async_process_audio_stream(one_second_stream())
    assert detection is None
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
    """Test async_default_engine when no wake word entity has been added."""
    empty_config = {wake_word.DOMAIN: {}}
    assert await async_setup_component(hass, wake_word.DOMAIN, empty_config)
    await hass.async_block_till_done()

    default_engine = wake_word.async_default_engine(hass)
    assert default_engine is None
async def test_default_engine_entity(
    hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
    """Test async_default_engine once a provider entity has been set up."""
    await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)

    expected_engine = f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
    default_engine = wake_word.async_default_engine(hass)
    assert default_engine == expected_engine
async def test_get_engine_entity(
    hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
    """Test async_get_wake_word_detection_entity.

    After the provider is set up via a config entry, looking it up by entity
    id must return the very same entity object.
    """
    await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)

    assert (
        wake_word.async_get_wake_word_detection_entity(hass, f"{wake_word.DOMAIN}.test")
        is mock_provider_entity
    )
async def test_restore_state(
    hass: HomeAssistant,
    tmp_path: Path,
    mock_provider_entity: MockProviderEntity,
) -> None:
    """Test that the entity state is restored from the restore cache."""
    entity_id = f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
    timestamp = "2023-01-01T23:59:59+00:00"

    # Seed the restore cache before the entity is added
    cached_state = State(entity_id, timestamp)
    mock_restore_cache(hass, (cached_state,))

    entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
    await hass.async_block_till_done()
    assert entry.state == ConfigEntryState.LOADED

    restored = hass.states.get(entity_id)
    assert restored
    assert restored.state == timestamp

View File

@ -1,4 +1,6 @@
"""Tests for the Wyoming integration.""" """Tests for the Wyoming integration."""
import asyncio
from wyoming.info import ( from wyoming.info import (
AsrModel, AsrModel,
AsrProgram, AsrProgram,
@ -7,6 +9,8 @@ from wyoming.info import (
TtsProgram, TtsProgram,
TtsVoice, TtsVoice,
TtsVoiceSpeaker, TtsVoiceSpeaker,
WakeModel,
WakeProgram,
) )
TEST_ATTR = Attribution(name="Test", url="http://www.test.com") TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
@ -49,6 +53,25 @@ TTS_INFO = Info(
) )
] ]
) )
# Wyoming service info advertising one installed wake word program
# with a single installed model named "Test Model".
WAKE_WORD_INFO = Info(
wake=[
WakeProgram(
name="Test Wake Word",
description="Test Wake Word",
installed=True,
attribution=TEST_ATTR,
models=[
WakeModel(
name="Test Model",
description="Test Model",
installed=True,
attribution=TEST_ATTR,
languages=["en-US"],
)
],
)
]
)
EMPTY_INFO = Info() EMPTY_INFO = Info()
@ -68,6 +91,7 @@ class MockAsyncTcpClient:
async def read_event(self): async def read_event(self):
"""Receive.""" """Receive."""
await asyncio.sleep(0) # force context switch
return self.responses.pop(0) return self.responses.pop(0)
async def __aenter__(self): async def __aenter__(self):

View File

@ -8,7 +8,7 @@ from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import STT_INFO, TTS_INFO from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -52,6 +52,21 @@ def tts_config_entry(hass: HomeAssistant) -> ConfigEntry:
return entry return entry
@pytest.fixture
def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry:
    """Create a Wyoming config entry titled for wake word and add it to hass."""
    connection = {
        "host": "1.2.3.4",
        "port": 1234,
    }
    wake_entry = MockConfigEntry(
        domain="wyoming",
        data=connection,
        title="Test Wake Word",
    )
    wake_entry.add_to_hass(hass)
    return wake_entry
@pytest.fixture @pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry): async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
"""Initialize Wyoming STT.""" """Initialize Wyoming STT."""
@ -72,6 +87,18 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
await hass.config_entries.async_setup(tts_config_entry.entry_id) await hass.config_entries.async_setup(tts_config_entry.entry_id)
@pytest.fixture
async def init_wyoming_wake_word(
    hass: HomeAssistant, wake_word_config_entry: ConfigEntry
):
    """Set up the Wyoming wake word entry with mocked service info."""
    info_patch = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=WAKE_WORD_INFO,
    )
    entry_id = wake_word_config_entry.entry_id

    with info_patch:
        await hass.config_entries.async_setup(entry_id)
@pytest.fixture @pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata: def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
"""Get default STT metadata.""" """Get default STT metadata."""

View File

@ -0,0 +1,13 @@
# serializer version: 1
# name: test_streaming_audio
dict({
'queued_audio': list([
tuple(
b'chunk',
1,
),
]),
'timestamp': 0,
'ww_id': 'Test Model',
})
# ---

View File

@ -0,0 +1,108 @@
"""Test wake word."""
from __future__ import annotations
import asyncio
from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion
from wyoming.asr import Transcript
from wyoming.wake import Detection
from homeassistant.components import wake_word
from homeassistant.core import HomeAssistant
from . import MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
    """Test that the entity exists and exposes its supported wake words."""
    entity_id = "wake_word.test_wake_word"

    assert hass.states.get(entity_id) is not None

    entity = wake_word.async_get_wake_word_detection_entity(hass, entity_id)
    assert entity is not None

    expected_words = [wake_word.WakeWord(ww_id="Test Model", name="Test Model")]
    assert entity.supported_wake_words == expected_words
async def test_streaming_audio(
    hass: HomeAssistant, init_wyoming_wake_word, snapshot: SnapshotAssertion
) -> None:
    """Test that streaming audio produces a detection result.

    The mocked client first returns an event that is not a detection,
    followed by the detection itself.
    """
    entity = wake_word.async_get_wake_word_detection_entity(
        hass, "wake_word.test_wake_word"
    )
    assert entity is not None

    async def audio_stream():
        yield b"chunk", 0

        # Delay to force a pending audio chunk
        await asyncio.sleep(0.05)
        yield b"chunk", 1

    unrelated_event = Transcript("not a wake word event").event()
    detection_event = Detection(name="Test Model", timestamp=0).event()
    fake_client = MockAsyncTcpClient([unrelated_event, detection_event])

    with patch(
        "homeassistant.components.wyoming.wake_word.AsyncTcpClient",
        fake_client,
    ):
        result = await entity.async_process_audio_stream(audio_stream())

    assert result is not None
    assert result == snapshot
async def test_streaming_audio_connection_lost(
    hass: HomeAssistant, init_wyoming_wake_word
) -> None:
    """Test that no result is returned when the client yields no event."""
    entity = wake_word.async_get_wake_word_detection_entity(
        hass, "wake_word.test_wake_word"
    )
    assert entity is not None

    async def audio_stream():
        # Delay to force a pending audio chunk
        await asyncio.sleep(0.05)
        yield b"chunk", 1

    # A None event stands in for a lost connection
    fake_client = MockAsyncTcpClient([None])

    with patch(
        "homeassistant.components.wyoming.wake_word.AsyncTcpClient",
        fake_client,
    ):
        result = await entity.async_process_audio_stream(audio_stream())

    assert result is None
async def test_streaming_audio_oserror(
    hass: HomeAssistant, init_wyoming_wake_word
) -> None:
    """Test that no result is returned when reading an event raises OSError."""
    entity = wake_word.async_get_wake_word_detection_entity(
        hass, "wake_word.test_wake_word"
    )
    assert entity is not None

    async def audio_stream():
        yield b"chunk1", 1000

    detection_event = Detection(name="Test Model", timestamp=1000).event()
    mock_client = MockAsyncTcpClient([detection_event])

    with patch(
        "homeassistant.components.wyoming.wake_word.AsyncTcpClient",
        mock_client,
    ):
        # Every read from the client fails
        with patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
            result = await entity.async_process_audio_stream(audio_stream())

    assert result is None