Switch from WebRTC to microVAD (#122861)

* Switch WebRTC to microVAD

* Remove webrtc-noise-gain from licenses
Michael Hansen 2024-07-31 02:42:45 -05:00 committed by GitHub
parent beb2ef121e
commit 7f4dabf546
15 changed files with 320 additions and 347 deletions
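For context on the new dependency as this commit uses it: pymicro-vad exposes a MicroVad class whose Process10ms method takes one 10 ms frame of 16 kHz, 16-bit mono PCM (160 samples) and returns a speech probability that callers compare against a threshold. A minimal usage sketch based only on the calls visible in this diff (the frame_is_speech helper is illustrative, not part of the commit):

from pymicro_vad import MicroVad

vad = MicroVad()
threshold = 0.5
samples_per_10ms = 160  # 10 ms at 16 kHz
bytes_per_10ms = samples_per_10ms * 2  # 16-bit samples

def frame_is_speech(frame: bytes) -> bool:
    """Return True if one 10 ms frame of 16-bit mono PCM is likely speech."""
    speech_prob = vad.Process10ms(frame)
    return speech_prob > threshold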

View File

@ -0,0 +1,82 @@
"""Audio enhancement for Assist."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
from pymicro_vad import MicroVad
_LOGGER = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class EnhancedAudioChunk:
"""Enhanced audio chunk and metadata."""
audio: bytes
"""Raw PCM audio @ 16Khz with 16-bit mono samples"""
timestamp_ms: int
"""Timestamp relative to start of audio stream (milliseconds)"""
is_speech: bool | None
"""True if audio chunk likely contains speech, False if not, None if unknown"""
class AudioEnhancer(ABC):
"""Base class for audio enhancement."""
def __init__(
self, auto_gain: int, noise_suppression: int, is_vad_enabled: bool
) -> None:
"""Initialize audio enhancer."""
self.auto_gain = auto_gain
self.noise_suppression = noise_suppression
self.is_vad_enabled = is_vad_enabled
@abstractmethod
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
@property
@abstractmethod
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking isn't required."""
class MicroVadEnhancer(AudioEnhancer):
"""Audio enhancer that just runs microVAD."""
def __init__(
self, auto_gain: int, noise_suppression: int, is_vad_enabled: bool
) -> None:
"""Initialize audio enhancer."""
super().__init__(auto_gain, noise_suppression, is_vad_enabled)
self.vad: MicroVad | None = None
self.threshold = 0.5
if self.is_vad_enabled:
self.vad = MicroVad()
_LOGGER.debug("Initialized microVAD with threshold=%s", self.threshold)
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
is_speech: bool | None = None
if self.vad is not None:
# Run VAD
speech_prob = self.vad.Process10ms(audio)
is_speech = speech_prob > self.threshold
return EnhancedAudioChunk(
audio=audio, timestamp_ms=timestamp_ms, is_speech=is_speech
)
@property
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking isn't required."""
if self.is_vad_enabled:
return 160 # 10ms
return None
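A short usage sketch of the enhancer defined above; only enhance_chunk, samples_per_chunk, and the constructor signature come from this file, while the silent test chunk is illustrative:

enhancer = MicroVadEnhancer(auto_gain=0, noise_suppression=0, is_vad_enabled=True)

# One 10 ms chunk (160 samples) of 16-bit silence
chunk = bytes(enhancer.samples_per_chunk * 2)

enhanced = enhancer.enhance_chunk(chunk, timestamp_ms=0)
# enhanced.is_speech is a bool when VAD is enabled, None otherwise
print(enhanced.is_speech, enhanced.timestamp_ms)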

View File

@ -15,3 +15,8 @@ DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
 WAKE_WORD_COOLDOWN = 2  # seconds
 EVENT_RECORDING = f"{DOMAIN}_recording"
+SAMPLE_RATE = 16000  # hertz
+SAMPLE_WIDTH = 2  # bytes
+SAMPLE_CHANNELS = 1  # mono
+SAMPLES_PER_CHUNK = 240  # 20 ms @ 16Khz
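These constants feed the chunking math used later in pipeline.py: for a given samples_per_chunk, the chunk size in bytes and milliseconds follows directly. A sketch using the formulas from this commit (the example samples_per_chunk of 160 is microVAD's 10 ms chunk):

SAMPLE_RATE = 16000  # hertz
SAMPLE_WIDTH = 2  # bytes
samples_per_chunk = 160  # e.g. microVAD's 10 ms chunk

bytes_per_chunk = samples_per_chunk * SAMPLE_WIDTH  # 320 bytes
ms_per_sample = SAMPLE_RATE // 1000  # 16 samples per millisecond
ms_per_chunk = samples_per_chunk // ms_per_sample  # 10 ms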

View File

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline", "documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
"iot_class": "local_push", "iot_class": "local_push",
"quality_scale": "internal", "quality_scale": "internal",
"requirements": ["webrtc-noise-gain==1.2.3"] "requirements": ["pymicro-vad==1.0.0"]
} }

View File

@ -13,14 +13,11 @@ from pathlib import Path
 from queue import Empty, Queue
 from threading import Thread
 import time
-from typing import TYPE_CHECKING, Any, Final, Literal, cast
+from typing import Any, Literal, cast
 import wave
 import voluptuous as vol
-if TYPE_CHECKING:
-    from webrtc_noise_gain import AudioProcessor
 from homeassistant.components import (
     conversation,
     media_source,
@ -52,12 +49,17 @@ from homeassistant.util import (
 )
 from homeassistant.util.limited_size_dict import LimitedSizeDict
+from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadEnhancer
 from .const import (
     CONF_DEBUG_RECORDING_DIR,
     DATA_CONFIG,
     DATA_LAST_WAKE_UP,
     DATA_MIGRATIONS,
     DOMAIN,
+    SAMPLE_CHANNELS,
+    SAMPLE_RATE,
+    SAMPLE_WIDTH,
+    SAMPLES_PER_CHUNK,
     WAKE_WORD_COOLDOWN,
 )
 from .error import (
@ -111,9 +113,6 @@ STORED_PIPELINE_RUNS = 10
 SAVE_DELAY = 10
-AUDIO_PROCESSOR_SAMPLES: Final = 160  # 10 ms @ 16 Khz
-AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2  # 16-bit samples
 @callback
 def _async_resolve_default_pipeline_settings(
@ -503,8 +502,8 @@ class AudioSettings:
     is_vad_enabled: bool = True
     """True if VAD is used to determine the end of the voice command."""
-    is_chunking_enabled: bool = True
-    """True if audio is automatically split into 10 ms chunks (required for VAD, etc.)"""
+    samples_per_chunk: int | None = None
+    """Number of samples that will be in each audio chunk (None for no chunking)."""
     def __post_init__(self) -> None:
         """Verify settings post-initialization."""
@ -514,9 +513,6 @@ class AudioSettings:
         if (self.auto_gain_dbfs < 0) or (self.auto_gain_dbfs > 31):
             raise ValueError("auto_gain_dbfs must be in [0, 31]")
-        if self.needs_processor and (not self.is_chunking_enabled):
-            raise ValueError("Chunking must be enabled for audio processing")
     @property
     def needs_processor(self) -> bool:
         """True if an audio processor is needed."""
@ -526,19 +522,10 @@ class AudioSettings:
             or (self.auto_gain_dbfs > 0)
         )
+    @property
+    def is_chunking_enabled(self) -> bool:
+        """True if chunk size is set."""
+        return self.samples_per_chunk is not None
-@dataclass(frozen=True, slots=True)
-class ProcessedAudioChunk:
-    """Processed audio chunk and metadata."""
-    audio: bytes
-    """Raw PCM audio @ 16Khz with 16-bit mono samples"""
-    timestamp_ms: int
-    """Timestamp relative to start of audio stream (milliseconds)"""
-    is_speech: bool | None
-    """True if audio chunk likely contains speech, False if not, None if unknown"""
 @dataclass
@ -573,10 +560,10 @@ class PipelineRun:
     debug_recording_queue: Queue[str | bytes | None] | None = None
     """Queue to communicate with debug recording thread"""
-    audio_processor: AudioProcessor | None = None
+    audio_enhancer: AudioEnhancer | None = None
     """VAD/noise suppression/auto gain"""
-    audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
+    audio_chunking_buffer: AudioBuffer | None = None
     """Buffer used when splitting audio into chunks for audio processing"""
     _device_id: str | None = None
@ -601,19 +588,16 @@ class PipelineRun:
         pipeline_data.pipeline_runs.add_run(self)
         # Initialize with audio settings
-        self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
-        if self.audio_settings.needs_processor:
-            # Delay import of webrtc so HA start up is not crashing
-            # on older architectures (armhf).
-            #
-            # pylint: disable=import-outside-toplevel
-            from webrtc_noise_gain import AudioProcessor
-            self.audio_processor = AudioProcessor(
+        if self.audio_settings.needs_processor and (self.audio_enhancer is None):
+            # Default audio enhancer
+            self.audio_enhancer = MicroVadEnhancer(
                 self.audio_settings.auto_gain_dbfs,
                 self.audio_settings.noise_suppression_level,
+                self.audio_settings.is_vad_enabled,
             )
+        self.audio_chunking_buffer = AudioBuffer(self.samples_per_chunk * SAMPLE_WIDTH)
     def __eq__(self, other: object) -> bool:
         """Compare pipeline runs by id."""
         if isinstance(other, PipelineRun):
@ -621,6 +605,14 @@ class PipelineRun:
         return False
+    @property
+    def samples_per_chunk(self) -> int:
+        """Return number of samples expected in each audio chunk."""
+        if self.audio_enhancer is not None:
+            return self.audio_enhancer.samples_per_chunk or SAMPLES_PER_CHUNK
+        return self.audio_settings.samples_per_chunk or SAMPLES_PER_CHUNK
     @callback
     def process_event(self, event: PipelineEvent) -> None:
         """Log an event and call listener."""
@ -688,8 +680,8 @@ class PipelineRun:
     async def wake_word_detection(
         self,
-        stream: AsyncIterable[ProcessedAudioChunk],
-        audio_chunks_for_stt: list[ProcessedAudioChunk],
+        stream: AsyncIterable[EnhancedAudioChunk],
+        audio_chunks_for_stt: list[EnhancedAudioChunk],
     ) -> wake_word.DetectionResult | None:
         """Run wake-word-detection portion of pipeline. Returns detection result."""
         metadata_dict = asdict(
@ -732,10 +724,11 @@ class PipelineRun:
         # Audio chunk buffer. This audio will be forwarded to speech-to-text
        # after wake-word-detection.
         num_audio_chunks_to_buffer = int(
-            (wake_word_settings.audio_seconds_to_buffer * 16000)
-            / AUDIO_PROCESSOR_SAMPLES
+            (wake_word_settings.audio_seconds_to_buffer * SAMPLE_RATE)
+            / self.samples_per_chunk
         )
-        stt_audio_buffer: deque[ProcessedAudioChunk] | None = None
+        stt_audio_buffer: deque[EnhancedAudioChunk] | None = None
         if num_audio_chunks_to_buffer > 0:
             stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer)
@ -797,7 +790,7 @@ class PipelineRun:
             # speech-to-text so the user does not have to pause before
             # speaking the voice command.
             audio_chunks_for_stt.extend(
-                ProcessedAudioChunk(
+                EnhancedAudioChunk(
                     audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False
                 )
                 for chunk_ts in result.queued_audio
@ -819,18 +812,17 @@ class PipelineRun:
     async def _wake_word_audio_stream(
         self,
-        audio_stream: AsyncIterable[ProcessedAudioChunk],
-        stt_audio_buffer: deque[ProcessedAudioChunk] | None,
+        audio_stream: AsyncIterable[EnhancedAudioChunk],
+        stt_audio_buffer: deque[EnhancedAudioChunk] | None,
         wake_word_vad: VoiceActivityTimeout | None,
-        sample_rate: int = 16000,
-        sample_width: int = 2,
+        sample_rate: int = SAMPLE_RATE,
+        sample_width: int = SAMPLE_WIDTH,
     ) -> AsyncIterable[tuple[bytes, int]]:
         """Yield audio chunks with timestamps (milliseconds since start of stream).
         Adds audio to a ring buffer that will be forwarded to speech-to-text after
         detection. Times out if VAD detects enough silence.
         """
-        chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         async for chunk in audio_stream:
             if self.abort_wake_word_detection:
                 raise WakeWordDetectionAborted
@ -845,6 +837,7 @@ class PipelineRun:
                 stt_audio_buffer.append(chunk)
             if wake_word_vad is not None:
+                chunk_seconds = (len(chunk.audio) // sample_width) / sample_rate
                 if not wake_word_vad.process(chunk_seconds, chunk.is_speech):
                     raise WakeWordTimeoutError(
                         code="wake-word-timeout", message="Wake word was not detected"
@ -881,7 +874,7 @@ class PipelineRun:
     async def speech_to_text(
         self,
         metadata: stt.SpeechMetadata,
-        stream: AsyncIterable[ProcessedAudioChunk],
+        stream: AsyncIterable[EnhancedAudioChunk],
     ) -> str:
         """Run speech-to-text portion of pipeline. Returns the spoken text."""
         # Create a background task to prepare the conversation agent
@ -957,18 +950,18 @@ class PipelineRun:
     async def _speech_to_text_stream(
         self,
-        audio_stream: AsyncIterable[ProcessedAudioChunk],
+        audio_stream: AsyncIterable[EnhancedAudioChunk],
         stt_vad: VoiceCommandSegmenter | None,
-        sample_rate: int = 16000,
-        sample_width: int = 2,
+        sample_rate: int = SAMPLE_RATE,
+        sample_width: int = SAMPLE_WIDTH,
     ) -> AsyncGenerator[bytes]:
         """Yield audio chunks until VAD detects silence or speech-to-text completes."""
-        chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         sent_vad_start = False
         async for chunk in audio_stream:
             self._capture_chunk(chunk.audio)
             if stt_vad is not None:
+                chunk_seconds = (len(chunk.audio) // sample_width) / sample_rate
                 if not stt_vad.process(chunk_seconds, chunk.is_speech):
                     # Silence detected at the end of voice command
                     self.process_event(
@ -1072,8 +1065,8 @@ class PipelineRun:
             tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
             if self.tts_audio_output == "wav":
                 # 16 Khz, 16-bit mono
-                tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
-                tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1
+                tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = SAMPLE_RATE
+                tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = SAMPLE_CHANNELS
         try:
             options_supported = await tts.async_support_options(
@ -1220,12 +1213,15 @@ class PipelineRun:
     async def process_volume_only(
         self,
         audio_stream: AsyncIterable[bytes],
-        sample_rate: int = 16000,
-        sample_width: int = 2,
-    ) -> AsyncGenerator[ProcessedAudioChunk]:
+        sample_rate: int = SAMPLE_RATE,
+        sample_width: int = SAMPLE_WIDTH,
+    ) -> AsyncGenerator[EnhancedAudioChunk]:
         """Apply volume transformation only (no VAD/audio enhancements) with optional chunking."""
+        assert self.audio_chunking_buffer is not None
+        bytes_per_chunk = self.samples_per_chunk * sample_width
         ms_per_sample = sample_rate // 1000
-        ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
+        ms_per_chunk = self.samples_per_chunk // ms_per_sample
         timestamp_ms = 0
         async for chunk in audio_stream:
@ -1233,19 +1229,18 @@ class PipelineRun:
                 chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
             if self.audio_settings.is_chunking_enabled:
-                # 10 ms chunking
-                for chunk_10ms in chunk_samples(
-                    chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
+                for sub_chunk in chunk_samples(
+                    chunk, bytes_per_chunk, self.audio_chunking_buffer
                 ):
-                    yield ProcessedAudioChunk(
-                        audio=chunk_10ms,
+                    yield EnhancedAudioChunk(
+                        audio=sub_chunk,
                         timestamp_ms=timestamp_ms,
                         is_speech=None,  # no VAD
                     )
                     timestamp_ms += ms_per_chunk
             else:
                 # No chunking
-                yield ProcessedAudioChunk(
+                yield EnhancedAudioChunk(
                     audio=chunk,
                     timestamp_ms=timestamp_ms,
                     is_speech=None,  # no VAD
@ -1255,14 +1250,19 @@ class PipelineRun:
     async def process_enhance_audio(
         self,
         audio_stream: AsyncIterable[bytes],
-        sample_rate: int = 16000,
-        sample_width: int = 2,
-    ) -> AsyncGenerator[ProcessedAudioChunk]:
+        sample_rate: int = SAMPLE_RATE,
+        sample_width: int = SAMPLE_WIDTH,
+    ) -> AsyncGenerator[EnhancedAudioChunk]:
         """Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation."""
-        assert self.audio_processor is not None
+        assert self.audio_enhancer is not None
+        assert self.audio_enhancer.samples_per_chunk is not None
+        assert self.audio_chunking_buffer is not None
+        bytes_per_chunk = self.audio_enhancer.samples_per_chunk * sample_width
         ms_per_sample = sample_rate // 1000
-        ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
+        ms_per_chunk = (
+            self.audio_enhancer.samples_per_chunk // sample_width
+        ) // ms_per_sample
         timestamp_ms = 0
         async for dirty_samples in audio_stream:
@ -1272,17 +1272,11 @@ class PipelineRun:
                     dirty_samples, self.audio_settings.volume_multiplier
                 )
-            # Split into 10ms chunks for audio enhancements/VAD
-            for dirty_10ms_chunk in chunk_samples(
-                dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
+            # Split into chunks for audio enhancements/VAD
+            for dirty_chunk in chunk_samples(
+                dirty_samples, bytes_per_chunk, self.audio_chunking_buffer
             ):
-                ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk)
-                yield ProcessedAudioChunk(
-                    audio=ap_result.audio,
-                    timestamp_ms=timestamp_ms,
-                    is_speech=ap_result.is_speech,
-                )
+                yield self.audio_enhancer.enhance_chunk(dirty_chunk, timestamp_ms)
                 timestamp_ms += ms_per_chunk
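Outside the pipeline, the enhancement loop above reduces to chunk_samples feeding fixed-size chunks into the enhancer. A standalone sketch under the assumption that VAD is enabled (so samples_per_chunk is 160); the silent input buffer is illustrative:

from homeassistant.components.assist_pipeline.audio_enhancer import MicroVadEnhancer
from homeassistant.components.assist_pipeline.vad import AudioBuffer, chunk_samples

enhancer = MicroVadEnhancer(0, 0, True)
bytes_per_chunk = enhancer.samples_per_chunk * 2  # 16-bit samples
leftover = AudioBuffer(bytes_per_chunk)

timestamp_ms = 0
dirty_samples = bytes(bytes_per_chunk * 3)  # 30 ms of silence
for dirty_chunk in chunk_samples(dirty_samples, bytes_per_chunk, leftover):
    enhanced = enhancer.enhance_chunk(dirty_chunk, timestamp_ms)
    timestamp_ms += 10  # each microVAD chunk is 10 ms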
@ -1323,9 +1317,9 @@ def _pipeline_debug_recording_thread_proc(
                 wav_path = run_recording_dir / f"{message}.wav"
                 wav_writer = wave.open(str(wav_path), "wb")
-                wav_writer.setframerate(16000)
-                wav_writer.setsampwidth(2)
-                wav_writer.setnchannels(1)
+                wav_writer.setframerate(SAMPLE_RATE)
+                wav_writer.setsampwidth(SAMPLE_WIDTH)
+                wav_writer.setnchannels(SAMPLE_CHANNELS)
             elif isinstance(message, bytes):
                 # Chunk of 16-bit mono audio at 16Khz
                 if wav_writer is not None:
@ -1368,8 +1362,8 @@ class PipelineInput:
         """Run pipeline."""
         self.run.start(device_id=self.device_id)
         current_stage: PipelineStage | None = self.run.start_stage
-        stt_audio_buffer: list[ProcessedAudioChunk] = []
-        stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None
+        stt_audio_buffer: list[EnhancedAudioChunk] = []
+        stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
         if self.stt_stream is not None:
             if self.run.audio_settings.needs_processor:
@ -1423,7 +1417,7 @@ class PipelineInput:
             # Send audio in the buffer first to speech-to-text, then move on to stt_stream.
             # This is basically an async itertools.chain.
             async def buffer_then_audio_stream() -> (
-                AsyncGenerator[ProcessedAudioChunk]
+                AsyncGenerator[EnhancedAudioChunk]
             ):
                 # Buffered audio
                 for chunk in stt_audio_buffer:

View File

@ -2,12 +2,11 @@
 from __future__ import annotations
-from abc import ABC, abstractmethod
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import StrEnum
 import logging
-from typing import Final, cast
+from typing import Final
 _LOGGER = logging.getLogger(__name__)
@ -35,44 +34,6 @@ class VadSensitivity(StrEnum):
         return 1.0
-class VoiceActivityDetector(ABC):
-    """Base class for voice activity detectors (VAD)."""
-    @abstractmethod
-    def is_speech(self, chunk: bytes) -> bool:
-        """Return True if audio chunk contains speech."""
-    @property
-    @abstractmethod
-    def samples_per_chunk(self) -> int | None:
-        """Return number of samples per chunk or None if chunking is not required."""
-class WebRtcVad(VoiceActivityDetector):
-    """Voice activity detector based on webrtc."""
-    def __init__(self) -> None:
-        """Initialize webrtcvad."""
-        # Delay import of webrtc so HA start up is not crashing
-        # on older architectures (armhf).
-        #
-        # pylint: disable=import-outside-toplevel
-        from webrtc_noise_gain import AudioProcessor
-        # Just VAD: no noise suppression or auto gain
-        self._audio_processor = AudioProcessor(0, 0)
-    def is_speech(self, chunk: bytes) -> bool:
-        """Return True if audio chunk contains speech."""
-        result = self._audio_processor.Process10ms(chunk)
-        return cast(bool, result.is_speech)
-    @property
-    def samples_per_chunk(self) -> int | None:
-        """Return 10 ms."""
-        return int(0.01 * _SAMPLE_RATE)  # 10 ms
 class AudioBuffer:
     """Fixed-sized audio buffer with variable internal length."""
class AudioBuffer: class AudioBuffer:
"""Fixed-sized audio buffer with variable internal length.""" """Fixed-sized audio buffer with variable internal length."""
@ -176,29 +137,38 @@ class VoiceCommandSegmenter:
                 if self._speech_seconds_left <= 0:
                     # Inside voice command
                     self.in_command = True
+                    self._silence_seconds_left = self.silence_seconds
+                    _LOGGER.debug("Voice command started")
             else:
                 # Reset if enough silence
                 self._reset_seconds_left -= chunk_seconds
                 if self._reset_seconds_left <= 0:
                     self._speech_seconds_left = self.speech_seconds
+                    self._reset_seconds_left = self.reset_seconds
         elif not is_speech:
+            # Silence in command
             self._reset_seconds_left = self.reset_seconds
             self._silence_seconds_left -= chunk_seconds
             if self._silence_seconds_left <= 0:
+                # Command finished successfully
                 self.reset()
+                _LOGGER.debug("Voice command finished")
                 return False
         else:
-            # Reset if enough speech
+            # Speech in command.
+            # Reset silence counter if enough speech.
             self._reset_seconds_left -= chunk_seconds
             if self._reset_seconds_left <= 0:
                 self._silence_seconds_left = self.silence_seconds
+                self._reset_seconds_left = self.reset_seconds
         return True
     def process_with_vad(
         self,
         chunk: bytes,
-        vad: VoiceActivityDetector,
+        vad_samples_per_chunk: int | None,
+        vad_is_speech: Callable[[bytes], bool],
         leftover_chunk_buffer: AudioBuffer | None,
     ) -> bool:
         """Process an audio chunk using an external VAD.
@ -207,20 +177,20 @@ class VoiceCommandSegmenter:
         Returns False when voice command is finished.
         """
-        if vad.samples_per_chunk is None:
+        if vad_samples_per_chunk is None:
             # No chunking
             chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE
-            is_speech = vad.is_speech(chunk)
+            is_speech = vad_is_speech(chunk)
             return self.process(chunk_seconds, is_speech)
         if leftover_chunk_buffer is None:
             raise ValueError("leftover_chunk_buffer is required when vad uses chunking")
         # With chunking
-        seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE
-        bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH
+        seconds_per_chunk = vad_samples_per_chunk / _SAMPLE_RATE
+        bytes_per_chunk = vad_samples_per_chunk * _SAMPLE_WIDTH
         for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer):
-            is_speech = vad.is_speech(vad_chunk)
+            is_speech = vad_is_speech(vad_chunk)
             if not self.process(seconds_per_chunk, is_speech):
                 return False
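With the VoiceActivityDetector class removed, callers now pass the chunk size and a speech callback directly, which is how voip.py below wires in MicroVadEnhancer. A condensed sketch of the new call (the silent chunk is illustrative):

from homeassistant.components.assist_pipeline.audio_enhancer import MicroVadEnhancer
from homeassistant.components.assist_pipeline.vad import AudioBuffer, VoiceCommandSegmenter

enhancer = MicroVadEnhancer(0, 0, True)
segmenter = VoiceCommandSegmenter()
vad_buffer = AudioBuffer(enhancer.samples_per_chunk * 2)

chunk = bytes(enhancer.samples_per_chunk * 2)  # one 10 ms frame of silence
still_listening = segmenter.process_with_vad(
    chunk,
    enhancer.samples_per_chunk,
    lambda x: enhancer.enhance_chunk(x, 0).is_speech is True,
    vad_buffer,
)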

View File

@ -24,6 +24,9 @@ from .const import (
     DEFAULT_WAKE_WORD_TIMEOUT,
     DOMAIN,
     EVENT_RECORDING,
+    SAMPLE_CHANNELS,
+    SAMPLE_RATE,
+    SAMPLE_WIDTH,
 )
 from .error import PipelineNotFound
 from .pipeline import (
@ -92,7 +95,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
                     vol.Optional("volume_multiplier"): float,
                     # Advanced use cases/testing
                     vol.Optional("no_vad"): bool,
-                    vol.Optional("no_chunking"): bool,
                 }
             },
             extra=vol.ALLOW_EXTRA,
@ -170,9 +172,14 @@ async def websocket_run(
         # Yield until we receive an empty chunk
         while chunk := await audio_queue.get():
-            if incoming_sample_rate != 16000:
+            if incoming_sample_rate != SAMPLE_RATE:
                 chunk, state = audioop.ratecv(
-                    chunk, 2, 1, incoming_sample_rate, 16000, state
+                    chunk,
+                    SAMPLE_WIDTH,
+                    SAMPLE_CHANNELS,
+                    incoming_sample_rate,
+                    SAMPLE_RATE,
+                    state,
                 )
             yield chunk
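The resample step above is the standard-library audioop.ratecv call; the state value must be threaded between calls so successive chunks resample continuously. A minimal sketch with an assumed 44.1 kHz input and a hypothetical incoming_chunks iterable:

import audioop

SAMPLE_RATE = 16000
SAMPLE_WIDTH = 2  # bytes (16-bit)
SAMPLE_CHANNELS = 1  # mono
incoming_sample_rate = 44100  # assumed input rate

state = None
for chunk in incoming_chunks:  # hypothetical iterable of 16-bit mono PCM
    chunk, state = audioop.ratecv(
        chunk, SAMPLE_WIDTH, SAMPLE_CHANNELS, incoming_sample_rate, SAMPLE_RATE, state
    )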
@ -206,7 +213,6 @@ async def websocket_run(
             auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
             volume_multiplier=msg_input.get("volume_multiplier", 1.0),
             is_vad_enabled=not msg_input.get("no_vad", False),
-            is_chunking_enabled=not msg_input.get("no_chunking", False),
         )
     elif start_stage == PipelineStage.INTENT:
         # Input to conversation agent
@ -424,9 +430,9 @@ def websocket_list_languages(
     connection.send_result(
         msg["id"],
         {
-            "languages": sorted(pipeline_languages)
-            if pipeline_languages
-            else pipeline_languages
+            "languages": (
+                sorted(pipeline_languages) if pipeline_languages else pipeline_languages
+            )
         },
     )

View File

@ -31,12 +31,14 @@ from homeassistant.components.assist_pipeline import (
     async_pipeline_from_audio_stream,
     select as pipeline_select,
 )
+from homeassistant.components.assist_pipeline.audio_enhancer import (
+    AudioEnhancer,
+    MicroVadEnhancer,
+)
 from homeassistant.components.assist_pipeline.vad import (
     AudioBuffer,
     VadSensitivity,
-    VoiceActivityDetector,
     VoiceCommandSegmenter,
-    WebRtcVad,
 )
 from homeassistant.const import __version__
 from homeassistant.core import Context, HomeAssistant
@ -233,13 +235,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
             try:
                 # Wait for speech before starting pipeline
                 segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
-                vad = WebRtcVad()
+                audio_enhancer = MicroVadEnhancer(0, 0, True)
                 chunk_buffer: deque[bytes] = deque(
                     maxlen=self.buffered_chunks_before_speech,
                 )
                 speech_detected = await self._wait_for_speech(
                     segmenter,
-                    vad,
+                    audio_enhancer,
                     chunk_buffer,
                 )
                 if not speech_detected:
@ -253,7 +255,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
                 try:
                     async for chunk in self._segment_audio(
                         segmenter,
-                        vad,
+                        audio_enhancer,
                         chunk_buffer,
                     ):
                         yield chunk
@ -317,7 +319,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
     async def _wait_for_speech(
         self,
         segmenter: VoiceCommandSegmenter,
-        vad: VoiceActivityDetector,
+        audio_enhancer: AudioEnhancer,
         chunk_buffer: MutableSequence[bytes],
     ):
         """Buffer audio chunks until speech is detected.
@ -329,13 +331,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
         async with asyncio.timeout(self.audio_timeout):
             chunk = await self._audio_queue.get()
-        assert vad.samples_per_chunk is not None
-        vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
+        assert audio_enhancer.samples_per_chunk is not None
+        vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH)
         while chunk:
             chunk_buffer.append(chunk)
-            segmenter.process_with_vad(chunk, vad, vad_buffer)
+            segmenter.process_with_vad(
+                chunk,
+                audio_enhancer.samples_per_chunk,
+                lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
+                vad_buffer,
+            )
             if segmenter.in_command:
                 # Buffer until command starts
                 if len(vad_buffer) > 0:
@ -351,7 +358,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
     async def _segment_audio(
         self,
         segmenter: VoiceCommandSegmenter,
-        vad: VoiceActivityDetector,
+        audio_enhancer: AudioEnhancer,
         chunk_buffer: Sequence[bytes],
     ) -> AsyncIterable[bytes]:
         """Yield audio chunks until voice command has finished."""
@ -364,11 +371,16 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
         async with asyncio.timeout(self.audio_timeout):
             chunk = await self._audio_queue.get()
-        assert vad.samples_per_chunk is not None
-        vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
+        assert audio_enhancer.samples_per_chunk is not None
+        vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH)
         while chunk:
-            if not segmenter.process_with_vad(chunk, vad, vad_buffer):
+            if not segmenter.process_with_vad(
+                chunk,
+                audio_enhancer.samples_per_chunk,
+                lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
+                vad_buffer,
+            ):
                 # Voice command is finished
                 break

View File

@ -45,6 +45,7 @@ Pillow==10.4.0
 pip>=21.3.1
 psutil-home-assistant==0.0.1
 PyJWT==2.8.0
+pymicro-vad==1.0.0
 PyNaCl==1.5.0
 pyOpenSSL==24.2.1
 pyserial==3.5
@ -60,7 +61,6 @@ urllib3>=1.26.5,<2
 voluptuous-openapi==0.0.5
 voluptuous-serialize==2.6.0
 voluptuous==0.15.2
-webrtc-noise-gain==1.2.3
 yarl==1.9.4
 zeroconf==0.132.2

View File

@ -2007,6 +2007,9 @@ pymelcloud==2.5.9
 # homeassistant.components.meteoclimatic
 pymeteoclimatic==0.1.0
+# homeassistant.components.assist_pipeline
+pymicro-vad==1.0.0
 # homeassistant.components.xiaomi_tv
 pymitv==1.4.3
@ -2896,9 +2899,6 @@ weatherflow4py==0.2.21
 # homeassistant.components.webmin
 webmin-xmlrpc==0.0.2
-# homeassistant.components.assist_pipeline
-webrtc-noise-gain==1.2.3
 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.8

View File

@ -1603,6 +1603,9 @@ pymelcloud==2.5.9
 # homeassistant.components.meteoclimatic
 pymeteoclimatic==0.1.0
+# homeassistant.components.assist_pipeline
+pymicro-vad==1.0.0
 # homeassistant.components.mochad
 pymochad==0.2.0
@ -2282,9 +2285,6 @@ weatherflow4py==0.2.21
 # homeassistant.components.webmin
 webmin-xmlrpc==0.0.2
-# homeassistant.components.assist_pipeline
-webrtc-noise-gain==1.2.3
 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.8

View File

@ -172,7 +172,6 @@ EXCEPTIONS = {
"tapsaff", # https://github.com/bazwilliams/python-taps-aff/pull/5 "tapsaff", # https://github.com/bazwilliams/python-taps-aff/pull/5
"tellduslive", # https://github.com/molobrakos/tellduslive/pull/24 "tellduslive", # https://github.com/molobrakos/tellduslive/pull/24
"tellsticknet", # https://github.com/molobrakos/tellsticknet/pull/33 "tellsticknet", # https://github.com/molobrakos/tellsticknet/pull/33
"webrtc_noise_gain", # https://github.com/rhasspy/webrtc-noise-gain/pull/24
"vincenty", # Public domain "vincenty", # Public domain
"zeversolar", # https://github.com/kvanzuijlen/zeversolar/pull/46 "zeversolar", # https://github.com/kvanzuijlen/zeversolar/pull/46
} }

View File

@ -75,9 +75,7 @@ async def test_pipeline_from_audio_stream_auto(
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
         stt_stream=audio_data(),
-        audio_settings=assist_pipeline.AudioSettings(
-            is_vad_enabled=False, is_chunking_enabled=False
-        ),
+        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
     )
     assert process_events(events) == snapshot
@ -140,9 +138,7 @@ async def test_pipeline_from_audio_stream_legacy(
         ),
         stt_stream=audio_data(),
         pipeline_id=pipeline_id,
-        audio_settings=assist_pipeline.AudioSettings(
-            is_vad_enabled=False, is_chunking_enabled=False
-        ),
+        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
     )
     assert process_events(events) == snapshot
@ -205,9 +201,7 @@ async def test_pipeline_from_audio_stream_entity(
         ),
         stt_stream=audio_data(),
         pipeline_id=pipeline_id,
-        audio_settings=assist_pipeline.AudioSettings(
-            is_vad_enabled=False, is_chunking_enabled=False
-        ),
+        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
     )
     assert process_events(events) == snapshot
@ -271,9 +265,7 @@ async def test_pipeline_from_audio_stream_no_stt(
         ),
         stt_stream=audio_data(),
         pipeline_id=pipeline_id,
-        audio_settings=assist_pipeline.AudioSettings(
-            is_vad_enabled=False, is_chunking_enabled=False
-        ),
+        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
     )
     assert not events
@ -335,24 +327,25 @@ async def test_pipeline_from_audio_stream_wake_word(
     # [0, 2, ...]
     wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
-    bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
+    samples_per_chunk = 160
+    bytes_per_chunk = samples_per_chunk * 2  # 16-bit
     async def audio_data():
-        # 1 second in 10 ms chunks
+        # 1 second in chunks
         i = 0
         while i < len(wake_chunk_1):
             yield wake_chunk_1[i : i + bytes_per_chunk]
             i += bytes_per_chunk
-        # 1 second in 30 ms chunks
+        # 1 second in chunks
        i = 0
         while i < len(wake_chunk_2):
             yield wake_chunk_2[i : i + bytes_per_chunk]
             i += bytes_per_chunk
-        yield b"wake word!"
-        yield b"part1"
-        yield b"part2"
+        for chunk in (b"wake word!", b"part1", b"part2"):
+            yield chunk + bytes(bytes_per_chunk - len(chunk))
         yield b""
     await assist_pipeline.async_pipeline_from_audio_stream(
@ -373,7 +366,7 @@ async def test_pipeline_from_audio_stream_wake_word(
             audio_seconds_to_buffer=1.5
         ),
         audio_settings=assist_pipeline.AudioSettings(
-            is_vad_enabled=False, is_chunking_enabled=False
+            is_vad_enabled=False, samples_per_chunk=samples_per_chunk
         ),
     )
@ -390,7 +383,9 @@ async def test_pipeline_from_audio_stream_wake_word(
     )
     assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
-    assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
+    assert mock_stt_provider.received[-3] == b"queued audio"
+    assert mock_stt_provider.received[-2].startswith(b"part1")
+    assert mock_stt_provider.received[-1].startswith(b"part2")
 async def test_pipeline_save_audio(
@ -438,9 +433,7 @@ async def test_pipeline_save_audio(
         pipeline_id=pipeline.id,
         start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
         end_stage=assist_pipeline.PipelineStage.STT,
-        audio_settings=assist_pipeline.AudioSettings(
-            is_vad_enabled=False, is_chunking_enabled=False
-        ),
+        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
     )
     pipeline_dirs = list(temp_dir.iterdir())
@ -685,9 +678,7 @@ async def test_wake_word_detection_aborted(
             wake_word_settings=assist_pipeline.WakeWordSettings(
                 audio_seconds_to_buffer=1.5
             ),
-            audio_settings=assist_pipeline.AudioSettings(
-                is_vad_enabled=False, is_chunking_enabled=False
-            ),
+            audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
         ),
     )
     await pipeline_input.validate()

View File

@ -1,11 +1,9 @@
 """Tests for voice command segmenter."""
 import itertools as it
-from unittest.mock import patch
 from homeassistant.components.assist_pipeline.vad import (
     AudioBuffer,
-    VoiceActivityDetector,
     VoiceCommandSegmenter,
     chunk_samples,
 )
@ -44,59 +42,41 @@ def test_speech() -> None:
 def test_audio_buffer() -> None:
     """Test audio buffer wrapping."""
-    class DisabledVad(VoiceActivityDetector):
-        def is_speech(self, chunk):
-            return False
-        @property
-        def samples_per_chunk(self):
-            return 160  # 10 ms
-    vad = DisabledVad()
-    bytes_per_chunk = vad.samples_per_chunk * 2
-    vad_buffer = AudioBuffer(bytes_per_chunk)
-    segmenter = VoiceCommandSegmenter()
-    with patch.object(vad, "is_speech", return_value=False) as mock_process:
-        # Partially fill audio buffer
-        half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
-        segmenter.process_with_vad(half_chunk, vad, vad_buffer)
-        assert not mock_process.called
-        assert vad_buffer is not None
-        assert vad_buffer.bytes() == half_chunk
-        # Fill and wrap with 1/4 chunk left over
-        three_quarters_chunk = bytes(
-            it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
-        )
-        segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer)
-        assert mock_process.call_count == 1
-        assert (
-            vad_buffer.bytes()
-            == three_quarters_chunk[
-                len(three_quarters_chunk) - (bytes_per_chunk // 4) :
-            ]
-        )
-        assert (
-            mock_process.call_args[0][0]
-            == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
-        )
-        # Run 2 chunks through
-        segmenter.reset()
-        vad_buffer.clear()
-        assert len(vad_buffer) == 0
-        mock_process.reset_mock()
-        two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
-        segmenter.process_with_vad(two_chunks, vad, vad_buffer)
-        assert mock_process.call_count == 2
-        assert len(vad_buffer) == 0
-        assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
-        assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
+    samples_per_chunk = 160  # 10 ms
+    bytes_per_chunk = samples_per_chunk * 2
+    leftover_buffer = AudioBuffer(bytes_per_chunk)
+    # Partially fill audio buffer
+    half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
+    chunks = list(chunk_samples(half_chunk, bytes_per_chunk, leftover_buffer))
+    assert not chunks
+    assert leftover_buffer.bytes() == half_chunk
+    # Fill and wrap with 1/4 chunk left over
+    three_quarters_chunk = bytes(
+        it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
+    )
+    chunks = list(chunk_samples(three_quarters_chunk, bytes_per_chunk, leftover_buffer))
+    assert len(chunks) == 1
+    assert (
+        leftover_buffer.bytes()
+        == three_quarters_chunk[len(three_quarters_chunk) - (bytes_per_chunk // 4) :]
+    )
+    assert chunks[0] == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
+    # Run 2 chunks through
+    leftover_buffer.clear()
+    assert len(leftover_buffer) == 0
+    two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
+    chunks = list(chunk_samples(two_chunks, bytes_per_chunk, leftover_buffer))
+    assert len(chunks) == 2
+    assert len(leftover_buffer) == 0
+    assert chunks[0] == two_chunks[:bytes_per_chunk]
+    assert chunks[1] == two_chunks[bytes_per_chunk:]
 def test_partial_chunk() -> None:
@ -125,43 +105,3 @@ def test_chunk_samples_leftover() -> None:
     assert len(chunks) == 1
     assert leftover_chunk_buffer.bytes() == bytes([5, 6])
-def test_vad_no_chunking() -> None:
-    """Test VAD that doesn't require chunking."""
-    class VadNoChunk(VoiceActivityDetector):
-        def is_speech(self, chunk: bytes) -> bool:
-            return sum(chunk) > 0
-        @property
-        def samples_per_chunk(self) -> int | None:
-            return None
-    vad = VadNoChunk()
-    segmenter = VoiceCommandSegmenter(
-        speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
-    )
-    silence = bytes([0] * 16000)
-    speech = bytes([255] * (16000 // 2))
-    # Test with differently-sized chunks
-    assert vad.is_speech(speech)
-    assert not vad.is_speech(silence)
-    # Simulate voice command
-    assert segmenter.process_with_vad(silence, vad, None)
-    # begin
-    assert segmenter.process_with_vad(speech, vad, None)
-    assert segmenter.process_with_vad(speech, vad, None)
-    assert segmenter.process_with_vad(speech, vad, None)
-    # reset with silence
-    assert segmenter.process_with_vad(silence, vad, None)
-    # resume
-    assert segmenter.process_with_vad(speech, vad, None)
-    assert segmenter.process_with_vad(speech, vad, None)
-    assert segmenter.process_with_vad(speech, vad, None)
-    assert segmenter.process_with_vad(speech, vad, None)
-    # end
-    assert segmenter.process_with_vad(silence, vad, None)
-    assert not segmenter.process_with_vad(silence, vad, None)

View File

@ -259,12 +259,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
             "type": "assist_pipeline/run",
             "start_stage": "wake_word",
             "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "timeout": 0,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "timeout": 0, "no_vad": True},
         }
     )
@ -1876,11 +1871,7 @@ async def test_wake_word_cooldown_same_id(
             "type": "assist_pipeline/run",
             "start_stage": "wake_word",
             "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
         }
     )
@ -1889,11 +1880,7 async def test_wake_word_cooldown_same_id(
            "type": "assist_pipeline/run",
            "start_stage": "wake_word",
            "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
         }
     )
@ -1967,11 +1954,7 @@ async def test_wake_word_cooldown_different_ids(
             "type": "assist_pipeline/run",
             "start_stage": "wake_word",
             "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
         }
     )
@ -1980,11 +1963,7 @@ async def test_wake_word_cooldown_different_ids(
             "type": "assist_pipeline/run",
             "start_stage": "wake_word",
             "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
         }
     )
@ -2094,11 +2073,7 @@ async def test_wake_word_cooldown_different_entities(
             "pipeline": pipeline_id_1,
             "start_stage": "wake_word",
             "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
         }
     )
@ -2109,11 +2084,7 @@ async def test_wake_word_cooldown_different_entities(
             "pipeline": pipeline_id_2,
             "start_stage": "wake_word",
             "end_stage": "tts",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
         }
     )
@ -2210,11 +2181,7 @@ async def test_device_capture(
             "type": "assist_pipeline/run",
             "start_stage": "stt",
             "end_stage": "stt",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
             "device_id": satellite_device.id,
         }
     )
@ -2315,11 +2282,7 @@ async def test_device_capture_override(
             "type": "assist_pipeline/run",
             "start_stage": "stt",
             "end_stage": "stt",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
             "device_id": satellite_device.id,
         }
     )
@ -2464,11 +2427,7 @@ async def test_device_capture_queue_full(
             "type": "assist_pipeline/run",
             "start_stage": "stt",
             "end_stage": "stt",
-            "input": {
-                "sample_rate": 16000,
-                "no_vad": True,
-                "no_chunking": True,
-            },
+            "input": {"sample_rate": 16000, "no_vad": True},
             "device_id": satellite_device.id,
         }
    )

View File

@ -43,9 +43,12 @@ async def test_pipeline(
     """Test that pipeline function is called from RTP protocol."""
     assert await async_setup_component(hass, "voip", {})
-    def is_speech(self, chunk):
+    def process_10ms(self, chunk):
         """Anything non-zero is speech."""
-        return sum(chunk) > 0
+        if sum(chunk) > 0:
+            return 1
+        return 0
     done = asyncio.Event()
@ -98,8 +101,8 @@ async def test_pipeline(
     with (
         patch(
-            "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
-            new=is_speech,
+            "pymicro_vad.MicroVad.Process10ms",
+            new=process_10ms,
         ),
         patch(
             "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -238,9 +241,12 @@ async def test_tts_timeout(
     """Test that TTS will time out based on its length."""
     assert await async_setup_component(hass, "voip", {})
-    def is_speech(self, chunk):
+    def process_10ms(self, chunk):
         """Anything non-zero is speech."""
-        return sum(chunk) > 0
+        if sum(chunk) > 0:
+            return 1
+        return 0
     done = asyncio.Event()
@ -298,8 +304,8 @@ async def test_tts_timeout(
     with (
         patch(
-            "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
-            new=is_speech,
+            "pymicro_vad.MicroVad.Process10ms",
+            new=process_10ms,
         ),
         patch(
             "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -361,9 +367,12 @@ async def test_tts_wrong_extension(
     """Test that TTS will only stream WAV audio."""
     assert await async_setup_component(hass, "voip", {})
-    def is_speech(self, chunk):
+    def process_10ms(self, chunk):
         """Anything non-zero is speech."""
-        return sum(chunk) > 0
+        if sum(chunk) > 0:
+            return 1
+        return 0
     done = asyncio.Event()
@ -403,8 +412,8 @@ async def test_tts_wrong_extension(
     with (
         patch(
-            "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
-            new=is_speech,
+            "pymicro_vad.MicroVad.Process10ms",
+            new=process_10ms,
         ),
         patch(
             "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -456,9 +465,12 @@ async def test_tts_wrong_wav_format(
     """Test that TTS will only stream WAV audio with a specific format."""
     assert await async_setup_component(hass, "voip", {})
-    def is_speech(self, chunk):
+    def process_10ms(self, chunk):
         """Anything non-zero is speech."""
-        return sum(chunk) > 0
+        if sum(chunk) > 0:
+            return 1
+        return 0
     done = asyncio.Event()
@ -505,8 +517,8 @@ async def test_tts_wrong_wav_format(
     with (
         patch(
-            "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
-            new=is_speech,
+            "pymicro_vad.MicroVad.Process10ms",
+            new=process_10ms,
         ),
         patch(
             "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -558,9 +570,12 @@ async def test_empty_tts_output(
     """Test that TTS will not stream when output is empty."""
     assert await async_setup_component(hass, "voip", {})
-    def is_speech(self, chunk):
+    def process_10ms(self, chunk):
         """Anything non-zero is speech."""
-        return sum(chunk) > 0
+        if sum(chunk) > 0:
+            return 1
+        return 0
     async def async_pipeline_from_audio_stream(*args, **kwargs):
         stt_stream = kwargs["stt_stream"]
@ -591,8 +606,8 @@ async def test_empty_tts_output(
     with (
         patch(
-            "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
-            new=is_speech,
+            "pymicro_vad.MicroVad.Process10ms",
+            new=process_10ms,
         ),
         patch(
             "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",