Wake word cleanup (#98652)

* Make arguments for async_pipeline_from_audio_stream keyword-only after hass

* Use a bytearray ring buffer

* Move generator outside

* Move stt stream generator outside

* Clean up execute

* Refactor VAD to use bytearray

* More tests

* Refactor chunk_samples to be more correct and robust

* Change AudioBuffer to use append instead of setitem

* Cleanup

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Michael Hansen 2023-08-25 12:28:48 -05:00 committed by GitHub
parent 49897341ba
commit 8768c39021
9 changed files with 458 additions and 163 deletions


@@ -52,6 +52,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
 async def async_pipeline_from_audio_stream(
     hass: HomeAssistant,
+    *,
     context: Context,
     event_callback: PipelineEventCallback,
     stt_metadata: stt.SpeechMetadata,
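The bare `*` makes every argument after `hass` keyword-only, so existing positional callers must be updated. A minimal sketch of the resulting call style (the argument values here are illustrative placeholders, not taken from this diff):

    # Inside an async function; positional arguments after `hass` now raise TypeError.
    await async_pipeline_from_audio_stream(
        hass,
        context=Context(),
        event_callback=events.append,
        stt_metadata=metadata,
        stt_stream=audio_stream(),
    )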


@@ -49,6 +49,7 @@ from .error import (
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
+from .ring_buffer import RingBuffer
 from .vad import VoiceActivityTimeout, VoiceCommandSegmenter

 _LOGGER = logging.getLogger(__name__)
@@ -425,7 +426,6 @@ class PipelineRun:
     async def prepare_wake_word_detection(self) -> None:
         """Prepare wake-word-detection."""
-        # Need to add to pipeline store
         engine = wake_word.async_default_engine(self.hass)
         if engine is None:
             raise WakeWordDetectionError(
@@ -448,7 +448,7 @@ class PipelineRun:
     async def wake_word_detection(
         self,
         stream: AsyncIterable[bytes],
-        audio_buffer: list[bytes],
+        audio_chunks_for_stt: list[bytes],
     ) -> wake_word.DetectionResult | None:
         """Run wake-word-detection portion of pipeline. Returns detection result."""
         metadata_dict = asdict(
@@ -484,46 +484,29 @@
             # Use VAD to determine timeout
             wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)

-        # Audio chunk buffer.
-        audio_bytes_to_buffer = int(
-            wake_word_settings.audio_seconds_to_buffer * 16000 * 2
-        )
-        audio_ring_buffer = b""
-
-        async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
-            """Yield audio with timestamps (milliseconds since start of stream)."""
-            nonlocal audio_ring_buffer
-            timestamp_ms = 0
-            async for chunk in stream:
-                yield chunk, timestamp_ms
-                timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz
-
-                # Keeping audio right before wake word detection allows the
-                # voice command to be spoken immediately after the wake word.
-                if audio_bytes_to_buffer > 0:
-                    audio_ring_buffer += chunk
-                    if len(audio_ring_buffer) > audio_bytes_to_buffer:
-                        # A proper ring buffer would be far more efficient
-                        audio_ring_buffer = audio_ring_buffer[
-                            len(audio_ring_buffer) - audio_bytes_to_buffer :
-                        ]
-
-                if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
-                    raise WakeWordTimeoutError(
-                        code="wake-word-timeout", message="Wake word was not detected"
-                    )
+        # Audio chunk buffer. This audio will be forwarded to speech-to-text
+        # after wake-word-detection.
+        num_audio_bytes_to_buffer = int(
+            wake_word_settings.audio_seconds_to_buffer * 16000 * 2  # 16-bit @ 16Khz
+        )
+        stt_audio_buffer: RingBuffer | None = None
+        if num_audio_bytes_to_buffer > 0:
+            stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)

         try:
             # Detect wake word(s)
             result = await self.wake_word_provider.async_process_audio_stream(
-                timestamped_stream()
+                _wake_word_audio_stream(
+                    audio_stream=stream,
+                    stt_audio_buffer=stt_audio_buffer,
+                    wake_word_vad=wake_word_vad,
+                )
             )

-            if audio_ring_buffer:
+            if stt_audio_buffer is not None:
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
-                audio_buffer.append(audio_ring_buffer)
+                audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
@@ -540,9 +523,14 @@
             wake_word_output: dict[str, Any] = {}
         else:
             if result.queued_audio:
-                # Add audio that was pending at detection
+                # Add audio that was pending at detection.
+                #
+                # Because detection occurs *after* the wake word was actually
+                # spoken, we need to make sure pending audio is forwarded to
+                # speech-to-text so the user does not have to pause before
+                # speaking the voice command.
                 for chunk_ts in result.queued_audio:
-                    audio_buffer.append(chunk_ts[0])
+                    audio_chunks_for_stt.append(chunk_ts[0])

             wake_word_output = asdict(result)
@@ -608,41 +596,12 @@
         )

         try:
-            segmenter = VoiceCommandSegmenter()
-
-            async def segment_stream(
-                stream: AsyncIterable[bytes],
-            ) -> AsyncGenerator[bytes, None]:
-                """Stop stream when voice command is finished."""
-                sent_vad_start = False
-                timestamp_ms = 0
-                async for chunk in stream:
-                    if not segmenter.process(chunk):
-                        # Silence detected at the end of voice command
-                        self.process_event(
-                            PipelineEvent(
-                                PipelineEventType.STT_VAD_END,
-                                {"timestamp": timestamp_ms},
-                            )
-                        )
-                        break
-
-                    if segmenter.in_command and (not sent_vad_start):
-                        # Speech detected at start of voice command
-                        self.process_event(
-                            PipelineEvent(
-                                PipelineEventType.STT_VAD_START,
-                                {"timestamp": timestamp_ms},
-                            )
-                        )
-                        sent_vad_start = True
-
-                    yield chunk
-                    timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz
-
             # Transcribe audio stream
             result = await self.stt_provider.async_process_audio_stream(
-                metadata, segment_stream(stream)
+                metadata,
+                self._speech_to_text_stream(
+                    audio_stream=stream, stt_vad=VoiceCommandSegmenter()
+                ),
             )
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during speech-to-text")
@@ -677,6 +636,42 @@
         return result.text

+    async def _speech_to_text_stream(
+        self,
+        audio_stream: AsyncIterable[bytes],
+        stt_vad: VoiceCommandSegmenter | None,
+        sample_rate: int = 16000,
+        sample_width: int = 2,
+    ) -> AsyncGenerator[bytes, None]:
+        """Yield audio chunks until VAD detects silence or speech-to-text completes."""
+        ms_per_sample = sample_rate // 1000
+        sent_vad_start = False
+        timestamp_ms = 0
+        async for chunk in audio_stream:
+            if stt_vad is not None:
+                if not stt_vad.process(chunk):
+                    # Silence detected at the end of voice command
+                    self.process_event(
+                        PipelineEvent(
+                            PipelineEventType.STT_VAD_END,
+                            {"timestamp": timestamp_ms},
+                        )
+                    )
+                    break
+
+                if stt_vad.in_command and (not sent_vad_start):
+                    # Speech detected at start of voice command
+                    self.process_event(
+                        PipelineEvent(
+                            PipelineEventType.STT_VAD_START,
+                            {"timestamp": timestamp_ms},
+                        )
+                    )
+                    sent_vad_start = True
+
+            yield chunk
+            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
     async def prepare_recognize_intent(self) -> None:
         """Prepare recognizing an intent."""
         agent_info = conversation.async_get_agent_info(
@@ -861,13 +856,14 @@
         """Run pipeline."""
         self.run.start()
         current_stage: PipelineStage | None = self.run.start_stage
-        audio_buffer: list[bytes] = []
+        stt_audio_buffer: list[bytes] = []

         try:
             if current_stage == PipelineStage.WAKE_WORD:
+                # wake-word-detection
                 assert self.stt_stream is not None
                 detect_result = await self.run.wake_word_detection(
-                    self.stt_stream, audio_buffer
+                    self.stt_stream, stt_audio_buffer
                 )
                 if detect_result is None:
                     # No wake word. Abort the rest of the pipeline.
@@ -882,19 +878,22 @@
                 assert self.stt_metadata is not None
                 assert self.stt_stream is not None

-                if audio_buffer:
-
-                    async def buffered_stream() -> AsyncGenerator[bytes, None]:
-                        for chunk in audio_buffer:
-                            yield chunk
-
-                        assert self.stt_stream is not None
-                        async for chunk in self.stt_stream:
-                            yield chunk
-
-                    stt_stream = cast(AsyncIterable[bytes], buffered_stream())
-                else:
-                    stt_stream = self.stt_stream
+                stt_stream = self.stt_stream
+
+                if stt_audio_buffer:
+                    # Send audio in the buffer first to speech-to-text, then move on to stt_stream.
+                    # This is basically an async itertools.chain.
+                    async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
+                        # Buffered audio
+                        for chunk in stt_audio_buffer:
+                            yield chunk
+
+                        # Streamed audio
+                        assert self.stt_stream is not None
+                        async for chunk in self.stt_stream:
+                            yield chunk
+
+                    stt_stream = buffer_then_audio_stream()

                 intent_input = await self.run.speech_to_text(
                     self.stt_metadata,
@@ -906,6 +905,7 @@
             tts_input = self.tts_input

             if current_stage == PipelineStage.INTENT:
+                # intent-recognition
                 assert intent_input is not None
                 tts_input = await self.run.recognize_intent(
                     intent_input,
@@ -915,6 +915,7 @@
                 current_stage = PipelineStage.TTS

             if self.run.end_stage != PipelineStage.INTENT:
+                # text-to-speech
                 if current_stage == PipelineStage.TTS:
                     assert tts_input is not None
                     await self.run.text_to_speech(tts_input)
@@ -999,6 +1000,36 @@
         await asyncio.gather(*prepare_tasks)

+
+async def _wake_word_audio_stream(
+    audio_stream: AsyncIterable[bytes],
+    stt_audio_buffer: RingBuffer | None,
+    wake_word_vad: VoiceActivityTimeout | None,
+    sample_rate: int = 16000,
+    sample_width: int = 2,
+) -> AsyncIterable[tuple[bytes, int]]:
+    """Yield audio chunks with timestamps (milliseconds since start of stream).
+
+    Adds audio to a ring buffer that will be forwarded to speech-to-text after
+    detection. Times out if VAD detects enough silence.
+    """
+    ms_per_sample = sample_rate // 1000
+    timestamp_ms = 0
+    async for chunk in audio_stream:
+        yield chunk, timestamp_ms
+        timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
+        # Wake-word-detection occurs *after* the wake word was actually
+        # spoken. Keeping audio right before detection allows the voice
+        # command to be spoken immediately after the wake word.
+        if stt_audio_buffer is not None:
+            stt_audio_buffer.put(chunk)
+
+        if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
+            raise WakeWordTimeoutError(
+                code="wake-word-timeout", message="Wake word was not detected"
+            )
+
+
 class PipelinePreferred(CollectionError):
     """Raised when attempting to delete the preferred pipelen."""


@@ -0,0 +1,57 @@
"""Implementation of a ring buffer using bytearray."""


class RingBuffer:
    """Basic ring buffer using a bytearray.

    Not threadsafe.
    """

    def __init__(self, maxlen: int) -> None:
        """Initialize empty buffer."""
        self._buffer = bytearray(maxlen)
        self._pos = 0
        self._length = 0
        self._maxlen = maxlen

    @property
    def maxlen(self) -> int:
        """Return the maximum size of the buffer."""
        return self._maxlen

    @property
    def pos(self) -> int:
        """Return the current put position."""
        return self._pos

    def __len__(self) -> int:
        """Return the length of data stored in the buffer."""
        return self._length

    def put(self, data: bytes) -> None:
        """Put a chunk of data into the buffer, possibly wrapping around."""
        data_len = len(data)
        new_pos = self._pos + data_len
        if new_pos >= self._maxlen:
            # Split into two chunks
            num_bytes_1 = self._maxlen - self._pos
            num_bytes_2 = new_pos - self._maxlen

            self._buffer[self._pos : self._maxlen] = data[:num_bytes_1]
            self._buffer[:num_bytes_2] = data[num_bytes_1:]
            new_pos = new_pos - self._maxlen
        else:
            # Entire chunk fits at current position
            self._buffer[self._pos : self._pos + data_len] = data

        self._pos = new_pos
        self._length = min(self._maxlen, self._length + data_len)

    def getvalue(self) -> bytes:
        """Get bytes written to the buffer."""
        if (self._pos + self._length) <= self._maxlen:
            # Single chunk
            return bytes(self._buffer[: self._length])

        # Two chunks
        return bytes(self._buffer[self._pos :] + self._buffer[: self._pos])
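Compared with the old approach of concatenating and re-slicing a `bytes` object on every chunk, the preallocated `bytearray` lets `put()` copy in place and only ever keeps the newest `maxlen` bytes. An illustrative sketch of the wrap-around behavior (values chosen for this example only):

    rb = RingBuffer(maxlen=8)
    rb.put(b"abcdef")                    # pos = 6, holds b"abcdef"
    rb.put(b"ghij")                      # wraps: b"ij" overwrites b"ab", pos = 2
    assert len(rb) == 8
    assert rb.getvalue() == b"cdefghij"  # oldest byte first, newest byte last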


@@ -1,12 +1,15 @@
 """Voice activity detection."""
 from __future__ import annotations

+from collections.abc import Iterable
 from dataclasses import dataclass, field
 from enum import StrEnum
+from typing import Final

 import webrtcvad

-_SAMPLE_RATE = 16000
+_SAMPLE_RATE: Final = 16000  # Hz
+_SAMPLE_WIDTH: Final = 2  # bytes


 class VadSensitivity(StrEnum):
@@ -29,6 +32,45 @@
         return 1.0


+class AudioBuffer:
+    """Fixed-sized audio buffer with variable internal length."""
+
+    def __init__(self, maxlen: int) -> None:
+        """Initialize buffer."""
+        self._buffer = bytearray(maxlen)
+        self._length = 0
+
+    @property
+    def length(self) -> int:
+        """Get number of bytes currently in the buffer."""
+        return self._length
+
+    def clear(self) -> None:
+        """Clear the buffer."""
+        self._length = 0
+
+    def append(self, data: bytes) -> None:
+        """Append bytes to the buffer, increasing the internal length."""
+        data_len = len(data)
+        if (self._length + data_len) > len(self._buffer):
+            raise ValueError("Length cannot be greater than buffer size")
+
+        self._buffer[self._length : self._length + data_len] = data
+        self._length += data_len
+
+    def bytes(self) -> bytes:
+        """Convert written portion of buffer to bytes."""
+        return bytes(self._buffer[: self._length])
+
+    def __len__(self) -> int:
+        """Get the number of bytes currently in the buffer."""
+        return self._length
+
+    def __bool__(self) -> bool:
+        """Return True if there are bytes in the buffer."""
+        return self._length > 0
+
+
 @dataclass
 class VoiceCommandSegmenter:
     """Segments an audio stream into voice commands using webrtcvad."""
@@ -36,7 +78,7 @@ class VoiceCommandSegmenter:
     vad_mode: int = 3
     """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

-    vad_frames: int = 480  # 30 ms
+    vad_samples_per_chunk: int = 480  # 30 ms
     """Must be 10, 20, or 30 ms at 16Khz."""

     speech_seconds: float = 0.3
@@ -67,20 +109,23 @@ class VoiceCommandSegmenter:
     """Seconds left before resetting start/stop time counters."""

     _vad: webrtcvad.Vad = None
-    _audio_buffer: bytes = field(default_factory=bytes)
-    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
-    _seconds_per_chunk: float = 0.03  # 30 ms
+    _leftover_chunk_buffer: AudioBuffer = field(init=False)
+    _bytes_per_chunk: int = field(init=False)
+    _seconds_per_chunk: float = field(init=False)

     def __post_init__(self) -> None:
         """Initialize VAD."""
         self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_frames * 2
-        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
+        self._leftover_chunk_buffer = AudioBuffer(
+            self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        )
         self.reset()

     def reset(self) -> None:
         """Reset all counters and state."""
-        self._audio_buffer = b""
+        self._leftover_chunk_buffer.clear()
         self._speech_seconds_left = self.speech_seconds
         self._silence_seconds_left = self.silence_seconds
         self._timeout_seconds_left = self.timeout_seconds
@@ -92,27 +137,20 @@ class VoiceCommandSegmenter:
         Returns False when command is done.
         """
-        self._audio_buffer += samples
-
-        # Process in 10, 20, or 30 ms chunks.
-        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
-        for chunk_idx in range(num_chunks):
-            chunk_offset = chunk_idx * self._bytes_per_chunk
-            chunk = self._audio_buffer[
-                chunk_offset : chunk_offset + self._bytes_per_chunk
-            ]
+        for chunk in chunk_samples(
+            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
+        ):
             if not self._process_chunk(chunk):
                 self.reset()
                 return False

-        if num_chunks > 0:
-            # Remove from buffer
-            self._audio_buffer = self._audio_buffer[
-                num_chunks * self._bytes_per_chunk :
-            ]
-
         return True

+    @property
+    def audio_buffer(self) -> bytes:
+        """Get partial chunk in the audio buffer."""
+        return self._leftover_chunk_buffer.bytes()
+
     def _process_chunk(self, chunk: bytes) -> bool:
         """Process a single chunk of 16-bit 16Khz mono audio.
@@ -163,7 +201,7 @@ class VoiceActivityTimeout:
     vad_mode: int = 3
     """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

-    vad_frames: int = 480  # 30 ms
+    vad_samples_per_chunk: int = 480  # 30 ms
     """Must be 10, 20, or 30 ms at 16Khz."""

     _silence_seconds_left: float = 0.0
@@ -173,20 +211,23 @@ class VoiceActivityTimeout:
     """Seconds left before resetting start/stop time counters."""

     _vad: webrtcvad.Vad = None
-    _audio_buffer: bytes = field(default_factory=bytes)
-    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
-    _seconds_per_chunk: float = 0.03  # 30 ms
+    _leftover_chunk_buffer: AudioBuffer = field(init=False)
+    _bytes_per_chunk: int = field(init=False)
+    _seconds_per_chunk: float = field(init=False)

     def __post_init__(self) -> None:
         """Initialize VAD."""
         self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_frames * 2
-        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
+        self._leftover_chunk_buffer = AudioBuffer(
+            self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        )
        self.reset()

     def reset(self) -> None:
         """Reset all counters and state."""
-        self._audio_buffer = b""
+        self._leftover_chunk_buffer.clear()
         self._silence_seconds_left = self.silence_seconds
         self._reset_seconds_left = self.reset_seconds
@@ -195,24 +236,12 @@ class VoiceActivityTimeout:
         Returns False when timeout is reached.
         """
-        self._audio_buffer += samples
-
-        # Process in 10, 20, or 30 ms chunks.
-        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
-        for chunk_idx in range(num_chunks):
-            chunk_offset = chunk_idx * self._bytes_per_chunk
-            chunk = self._audio_buffer[
-                chunk_offset : chunk_offset + self._bytes_per_chunk
-            ]
+        for chunk in chunk_samples(
+            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
+        ):
             if not self._process_chunk(chunk):
                 return False

-        if num_chunks > 0:
-            # Remove from buffer
-            self._audio_buffer = self._audio_buffer[
-                num_chunks * self._bytes_per_chunk :
-            ]
-
         return True

     def _process_chunk(self, chunk: bytes) -> bool:
@@ -239,3 +268,37 @@ class VoiceActivityTimeout:
             )

         return True
+
+
+def chunk_samples(
+    samples: bytes,
+    bytes_per_chunk: int,
+    leftover_chunk_buffer: AudioBuffer,
+) -> Iterable[bytes]:
+    """Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s)."""
+
+    if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
+        # Extend leftover chunk, but not enough samples to complete it
+        leftover_chunk_buffer.append(samples)
+        return
+
+    next_chunk_idx = 0
+    if leftover_chunk_buffer:
+        # Add to leftover chunk from previous call(s).
+        bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
+        leftover_chunk_buffer.append(samples[:bytes_to_copy])
+        next_chunk_idx = bytes_to_copy
+
+        # Process full chunk in buffer
+        yield leftover_chunk_buffer.bytes()
+        leftover_chunk_buffer.clear()
+
+    while next_chunk_idx < len(samples) - bytes_per_chunk + 1:
+        # Process full chunk
+        yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
+        next_chunk_idx += bytes_per_chunk
+
+    # Capture leftover chunks
+    if rest_samples := samples[next_chunk_idx:]:
+        leftover_chunk_buffer.append(rest_samples)
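A minimal usage sketch of `chunk_samples` together with `AudioBuffer`, showing how a partial chunk is carried over between calls (example values only):

    leftover = AudioBuffer(maxlen=4)

    # 7 bytes with a 4-byte chunk size: one full chunk, 3 bytes left over.
    chunks = list(chunk_samples(bytes([1, 2, 3, 4, 5, 6, 7]), 4, leftover))
    assert chunks == [bytes([1, 2, 3, 4])]
    assert leftover.bytes() == bytes([5, 6, 7])

    # The next call completes the leftover chunk before starting a new one.
    chunks = list(chunk_samples(bytes([8, 9]), 4, leftover))
    assert chunks == [bytes([5, 6, 7, 8])]
    assert leftover.bytes() == bytes([9])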


@@ -79,8 +79,6 @@
     @final
     def state(self) -> str | None:
         """Return the state of the entity."""
-        if self.__last_detected is None:
-            return None
         return self.__last_detected

     @property


@@ -317,6 +317,12 @@
       }),
       'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
     }),
+    dict({
+      'data': dict({
+        'timestamp': 1500,
+      }),
+      'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
+    }),
     dict({
       'data': dict({
         'stt_output': dict({


@@ -1,7 +1,7 @@
 """Test Voice Assistant init."""
 from dataclasses import asdict
 import itertools as it
-from unittest.mock import ANY
+from unittest.mock import ANY, patch

 import pytest
 from syrupy.assertion import SnapshotAssertion
@@ -49,9 +49,9 @@ async def test_pipeline_from_audio_stream_auto(
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -59,7 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
     )

     assert process_events(events) == snapshot
@@ -108,9 +108,9 @@ async def test_pipeline_from_audio_stream_legacy(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -118,7 +118,7 @@ async def test_pipeline_from_audio_stream_legacy(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
@@ -168,9 +168,9 @@ async def test_pipeline_from_audio_stream_entity(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -178,7 +178,7 @@ async def test_pipeline_from_audio_stream_entity(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
@@ -229,9 +229,9 @@ async def test_pipeline_from_audio_stream_no_stt(
     with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -239,7 +239,7 @@ async def test_pipeline_from_audio_stream_no_stt(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id=pipeline_id,
         )
@@ -268,9 +268,9 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
     with pytest.raises(assist_pipeline.PipelineNotFound):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -278,7 +278,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id="blah",
         )
@@ -308,26 +308,38 @@ async def test_pipeline_from_audio_stream_wake_word(
         yield b"wake word"
         yield b"part1"
         yield b"part2"
+        yield b"end"
         yield b""

-    await assist_pipeline.async_pipeline_from_audio_stream(
-        hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
-            language="",
-            format=stt.AudioFormats.WAV,
-            codec=stt.AudioCodecs.PCM,
-            bit_rate=stt.AudioBitRates.BITRATE_16,
-            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-            channel=stt.AudioChannels.CHANNEL_MONO,
-        ),
-        audio_data(),
-        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
-        wake_word_settings=assist_pipeline.WakeWordSettings(
-            audio_seconds_to_buffer=1.5
-        ),
-    )
+    def continue_stt(self, chunk):
+        # Ensure stt_vad_start event is triggered
+        self.in_command = True
+
+        # Stop on fake end chunk to trigger stt_vad_end
+        return chunk != b"end"
+
+    with patch(
+        "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
+        continue_stt,
+    ):
+        await assist_pipeline.async_pipeline_from_audio_stream(
+            hass,
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
+                language="",
+                format=stt.AudioFormats.WAV,
+                codec=stt.AudioCodecs.PCM,
+                bit_rate=stt.AudioBitRates.BITRATE_16,
+                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+                channel=stt.AudioChannels.CHANNEL_MONO,
+            ),
+            stt_stream=audio_data(),
+            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+            wake_word_settings=assist_pipeline.WakeWordSettings(
+                audio_seconds_to_buffer=1.5
+            ),
+        )

     assert process_events(events) == snapshot


@@ -0,0 +1,38 @@
"""Tests for audio ring buffer."""
from homeassistant.components.assist_pipeline.ring_buffer import RingBuffer


def test_ring_buffer_empty() -> None:
    """Test empty ring buffer."""
    rb = RingBuffer(10)
    assert rb.maxlen == 10
    assert rb.pos == 0
    assert rb.getvalue() == b""


def test_ring_buffer_put_1() -> None:
    """Test putting some data smaller than the maximum length."""
    rb = RingBuffer(10)
    rb.put(bytes([1, 2, 3, 4, 5]))

    assert len(rb) == 5
    assert rb.pos == 5
    assert rb.getvalue() == bytes([1, 2, 3, 4, 5])


def test_ring_buffer_put_2() -> None:
    """Test putting some data past the end of the buffer."""
    rb = RingBuffer(10)
    rb.put(bytes([1, 2, 3, 4, 5]))
    rb.put(bytes([6, 7, 8, 9, 10, 11, 12]))

    assert len(rb) == 10
    assert rb.pos == 2
    assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])


def test_ring_buffer_put_too_large() -> None:
    """Test putting data too large for the buffer."""
    rb = RingBuffer(10)
    rb.put(bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]))

    assert len(rb) == 10
    assert rb.pos == 2
    assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])


@@ -1,7 +1,12 @@
 """Tests for webrtcvad voice command segmenter."""
+import itertools as it
 from unittest.mock import patch

-from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
+from homeassistant.components.assist_pipeline.vad import (
+    AudioBuffer,
+    VoiceCommandSegmenter,
+    chunk_samples,
+)

 _ONE_SECOND = 16000 * 2  # 16Khz 16-bit
@@ -36,3 +41,87 @@
     # silence
     # False return value indicates voice command is finished
     assert not segmenter.process(bytes(_ONE_SECOND))
+
+
+def test_audio_buffer() -> None:
+    """Test audio buffer wrapping."""
+
+    def is_speech(self, chunk, sample_rate):
+        """Disable VAD."""
+        return False
+
+    with patch(
+        "webrtcvad.Vad.is_speech",
+        new=is_speech,
+    ):
+        segmenter = VoiceCommandSegmenter()
+        bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
+
+        with patch.object(
+            segmenter, "_process_chunk", return_value=True
+        ) as mock_process:
+            # Partially fill audio buffer
+            half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
+            segmenter.process(half_chunk)
+
+            assert not mock_process.called
+            assert segmenter.audio_buffer == half_chunk
+
+            # Fill and wrap with 1/4 chunk left over
+            three_quarters_chunk = bytes(
+                it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
+            )
+            segmenter.process(three_quarters_chunk)
+
+            assert mock_process.call_count == 1
+            assert (
+                segmenter.audio_buffer
+                == three_quarters_chunk[
+                    len(three_quarters_chunk) - (bytes_per_chunk // 4) :
+                ]
+            )
+            assert (
+                mock_process.call_args[0][0]
+                == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
+            )
+
+            # Run 2 chunks through
+            segmenter.reset()
+            assert len(segmenter.audio_buffer) == 0
+
+            mock_process.reset_mock()
+            two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
+            segmenter.process(two_chunks)
+
+            assert mock_process.call_count == 2
+            assert len(segmenter.audio_buffer) == 0
+            assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
+            assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
+
+
+def test_partial_chunk() -> None:
+    """Test that chunk_samples returns when given a partial chunk."""
+    bytes_per_chunk = 5
+    samples = bytes([1, 2, 3])
+    leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 0
+    assert leftover_chunk_buffer.bytes() == samples
+
+
+def test_chunk_samples_leftover() -> None:
+    """Test that chunk_samples keeps left over bytes across calls."""
+    bytes_per_chunk = 5
+    samples = bytes([1, 2, 3, 4, 5, 6])
+    leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
+
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+    assert len(chunks) == 1
+    assert leftover_chunk_buffer.bytes() == bytes([6])
+
+    # Add some more to the chunk
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+    assert len(chunks) == 1
+    assert leftover_chunk_buffer.bytes() == bytes([5, 6])