diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py
index c2d25da2162..4c2fe01036f 100644
--- a/homeassistant/components/assist_pipeline/__init__.py
+++ b/homeassistant/components/assist_pipeline/__init__.py
@@ -52,6 +52,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
 
 async def async_pipeline_from_audio_stream(
     hass: HomeAssistant,
+    *,
     context: Context,
     event_callback: PipelineEventCallback,
     stt_metadata: stt.SpeechMetadata,
diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py
index 320812b2039..3759fc12c75 100644
--- a/homeassistant/components/assist_pipeline/pipeline.py
+++ b/homeassistant/components/assist_pipeline/pipeline.py
@@ -49,6 +49,7 @@ from .error import (
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
+from .ring_buffer import RingBuffer
 from .vad import VoiceActivityTimeout, VoiceCommandSegmenter
 
 _LOGGER = logging.getLogger(__name__)
@@ -425,7 +426,6 @@ class PipelineRun:
 
     async def prepare_wake_word_detection(self) -> None:
         """Prepare wake-word-detection."""
-        # Need to add to pipeline store
        engine = wake_word.async_default_engine(self.hass)
         if engine is None:
             raise WakeWordDetectionError(
@@ -448,7 +448,7 @@ class PipelineRun:
     async def wake_word_detection(
         self,
         stream: AsyncIterable[bytes],
-        audio_buffer: list[bytes],
+        audio_chunks_for_stt: list[bytes],
     ) -> wake_word.DetectionResult | None:
         """Run wake-word-detection portion of pipeline. Returns detection result."""
         metadata_dict = asdict(
@@ -484,46 +484,29 @@ class PipelineRun:
         # Use VAD to determine timeout
         wake_word_vad = VoiceActivityTimeout(wake_word_settings.timeout)
 
-        # Audio chunk buffer.
-        audio_bytes_to_buffer = int(
-            wake_word_settings.audio_seconds_to_buffer * 16000 * 2
+        # Audio chunk buffer. This audio will be forwarded to speech-to-text
+        # after wake-word-detection.
+        num_audio_bytes_to_buffer = int(
+            wake_word_settings.audio_seconds_to_buffer * 16000 * 2  # 16-bit @ 16Khz
         )
-        audio_ring_buffer = b""
-
-        async def timestamped_stream() -> AsyncIterable[tuple[bytes, int]]:
-            """Yield audio with timestamps (milliseconds since start of stream)."""
-            nonlocal audio_ring_buffer
-
-            timestamp_ms = 0
-            async for chunk in stream:
-                yield chunk, timestamp_ms
-                timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz
-
-                # Keeping audio right before wake word detection allows the
-                # voice command to be spoken immediately after the wake word.
-                if audio_bytes_to_buffer > 0:
-                    audio_ring_buffer += chunk
-                    if len(audio_ring_buffer) > audio_bytes_to_buffer:
-                        # A proper ring buffer would be far more efficient
-                        audio_ring_buffer = audio_ring_buffer[
-                            len(audio_ring_buffer) - audio_bytes_to_buffer :
-                        ]
-
-                if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
-                    raise WakeWordTimeoutError(
-                        code="wake-word-timeout", message="Wake word was not detected"
-                    )
+        stt_audio_buffer: RingBuffer | None = None
+        if num_audio_bytes_to_buffer > 0:
+            stt_audio_buffer = RingBuffer(num_audio_bytes_to_buffer)
 
         try:
             # Detect wake word(s)
             result = await self.wake_word_provider.async_process_audio_stream(
-                timestamped_stream()
+                _wake_word_audio_stream(
+                    audio_stream=stream,
+                    stt_audio_buffer=stt_audio_buffer,
+                    wake_word_vad=wake_word_vad,
+                )
             )
-            if audio_ring_buffer:
+
+            if stt_audio_buffer is not None:
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
-                audio_buffer.append(audio_ring_buffer)
+                audio_chunks_for_stt.append(stt_audio_buffer.getvalue())
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
@@ -540,9 +523,14 @@ class PipelineRun:
             wake_word_output: dict[str, Any] = {}
         else:
             if result.queued_audio:
-                # Add audio that was pending at detection
+                # Add audio that was pending at detection.
+                #
+                # Because detection occurs *after* the wake word was actually
+                # spoken, we need to make sure pending audio is forwarded to
+                # speech-to-text so the user does not have to pause before
+                # speaking the voice command.
                 for chunk_ts in result.queued_audio:
-                    audio_buffer.append(chunk_ts[0])
+                    audio_chunks_for_stt.append(chunk_ts[0])
 
             wake_word_output = asdict(result)
@@ -608,41 +596,12 @@ class PipelineRun:
         )
 
         try:
-            segmenter = VoiceCommandSegmenter()
-
-            async def segment_stream(
-                stream: AsyncIterable[bytes],
-            ) -> AsyncGenerator[bytes, None]:
-                """Stop stream when voice command is finished."""
-                sent_vad_start = False
-                timestamp_ms = 0
-                async for chunk in stream:
-                    if not segmenter.process(chunk):
-                        # Silence detected at the end of voice command
-                        self.process_event(
-                            PipelineEvent(
-                                PipelineEventType.STT_VAD_END,
-                                {"timestamp": timestamp_ms},
-                            )
-                        )
-                        break
-
-                    if segmenter.in_command and (not sent_vad_start):
-                        # Speech detected at start of voice command
-                        self.process_event(
-                            PipelineEvent(
-                                PipelineEventType.STT_VAD_START,
-                                {"timestamp": timestamp_ms},
-                            )
-                        )
-                        sent_vad_start = True
-
-                    yield chunk
-                    timestamp_ms += (len(chunk) // 2) // 16  # milliseconds @ 16Khz
-
             # Transcribe audio stream
             result = await self.stt_provider.async_process_audio_stream(
-                metadata, segment_stream(stream)
+                metadata,
+                self._speech_to_text_stream(
+                    audio_stream=stream, stt_vad=VoiceCommandSegmenter()
+                ),
             )
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during speech-to-text")
@@ -677,6 +636,42 @@ class PipelineRun:
 
         return result.text
 
+    async def _speech_to_text_stream(
+        self,
+        audio_stream: AsyncIterable[bytes],
+        stt_vad: VoiceCommandSegmenter | None,
+        sample_rate: int = 16000,
+        sample_width: int = 2,
+    ) -> AsyncGenerator[bytes, None]:
+        """Yield audio chunks until VAD detects silence or speech-to-text completes."""
+        ms_per_sample = sample_rate // 1000
+        sent_vad_start = False
+        timestamp_ms = 0
+        async for chunk in audio_stream:
+            if stt_vad is not None:
+                if not stt_vad.process(chunk):
+                    # Silence detected at the end of voice command
+                    self.process_event(
+                        PipelineEvent(
+                            PipelineEventType.STT_VAD_END,
+                            {"timestamp": timestamp_ms},
+                        )
+                    )
+                    break
+
+                if stt_vad.in_command and (not sent_vad_start):
+                    # Speech detected at start of voice command
+                    self.process_event(
+                        PipelineEvent(
+                            PipelineEventType.STT_VAD_START,
+                            {"timestamp": timestamp_ms},
+                        )
+                    )
+                    sent_vad_start = True
+
+            yield chunk
+            timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
     async def prepare_recognize_intent(self) -> None:
         """Prepare recognizing an intent."""
         agent_info = conversation.async_get_agent_info(
@@ -861,13 +856,14 @@ class PipelineInput:
         """Run pipeline."""
         self.run.start()
         current_stage: PipelineStage | None = self.run.start_stage
-        audio_buffer: list[bytes] = []
+        stt_audio_buffer: list[bytes] = []
 
         try:
             if current_stage == PipelineStage.WAKE_WORD:
+                # wake-word-detection
                 assert self.stt_stream is not None
                 detect_result = await self.run.wake_word_detection(
-                    self.stt_stream, audio_buffer
+                    self.stt_stream, stt_audio_buffer
                 )
                 if detect_result is None:
                     # No wake word. Abort the rest of the pipeline.
@@ -882,19 +878,22 @@ class PipelineInput:
                 assert self.stt_metadata is not None
                 assert self.stt_stream is not None
 
-                if audio_buffer:
+                stt_stream = self.stt_stream
 
-                    async def buffered_stream() -> AsyncGenerator[bytes, None]:
-                        for chunk in audio_buffer:
+                if stt_audio_buffer:
+                    # Send audio in the buffer first to speech-to-text, then move on to stt_stream.
+                    # This is basically an async itertools.chain.
+                    async def buffer_then_audio_stream() -> AsyncGenerator[bytes, None]:
+                        # Buffered audio
+                        for chunk in stt_audio_buffer:
                             yield chunk
 
+                        # Streamed audio
                         assert self.stt_stream is not None
                         async for chunk in self.stt_stream:
                             yield chunk
 
-                    stt_stream = cast(AsyncIterable[bytes], buffered_stream())
-                else:
-                    stt_stream = self.stt_stream
+                    stt_stream = buffer_then_audio_stream()
 
                 intent_input = await self.run.speech_to_text(
                     self.stt_metadata,
@@ -906,6 +905,7 @@ class PipelineInput:
             tts_input = self.tts_input
 
             if current_stage == PipelineStage.INTENT:
+                # intent-recognition
                 assert intent_input is not None
                 tts_input = await self.run.recognize_intent(
                     intent_input,
@@ -915,6 +915,7 @@ class PipelineInput:
                 current_stage = PipelineStage.TTS
 
             if self.run.end_stage != PipelineStage.INTENT:
+                # text-to-speech
                 if current_stage == PipelineStage.TTS:
                     assert tts_input is not None
                     await self.run.text_to_speech(tts_input)
@@ -999,6 +1000,36 @@ class PipelineInput:
         await asyncio.gather(*prepare_tasks)
 
 
+async def _wake_word_audio_stream(
+    audio_stream: AsyncIterable[bytes],
+    stt_audio_buffer: RingBuffer | None,
+    wake_word_vad: VoiceActivityTimeout | None,
+    sample_rate: int = 16000,
+    sample_width: int = 2,
+) -> AsyncIterable[tuple[bytes, int]]:
+    """Yield audio chunks with timestamps (milliseconds since start of stream).
+
+    Adds audio to a ring buffer that will be forwarded to speech-to-text after
+    detection. Times out if VAD detects enough silence.
+    """
+    ms_per_sample = sample_rate // 1000
+    timestamp_ms = 0
+    async for chunk in audio_stream:
+        yield chunk, timestamp_ms
+        timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
+
+        # Wake-word-detection occurs *after* the wake word was actually
+        # spoken. Keeping audio right before detection allows the voice
+        # command to be spoken immediately after the wake word.
+        if stt_audio_buffer is not None:
+            stt_audio_buffer.put(chunk)
+
+        if (wake_word_vad is not None) and (not wake_word_vad.process(chunk)):
+            raise WakeWordTimeoutError(
+                code="wake-word-timeout", message="Wake word was not detected"
+            )
+
+
 class PipelinePreferred(CollectionError):
     """Raised when attempting to delete the preferred pipeline."""
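
The `buffer_then_audio_stream` generator above is effectively an async `itertools.chain`: audio saved in the ring buffer is replayed to speech-to-text before the live stream, so the user can speak the command immediately after the wake word. Below is a minimal runnable sketch of the same pattern; the names `chain_buffer_then_stream`, `fake_mic`, and `_demo` are illustrative only, not part of the patch. The final assertion spells out the timestamp arithmetic used by `_wake_word_audio_stream` and `_speech_to_text_stream`.

```python
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable


async def chain_buffer_then_stream(
    buffered_chunks: list[bytes],
    audio_stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]:
    """Yield buffered chunks first, then live chunks (an async itertools.chain)."""
    for chunk in buffered_chunks:
        yield chunk

    async for chunk in audio_stream:
        yield chunk


async def _demo() -> None:
    async def fake_mic() -> AsyncGenerator[bytes, None]:
        for chunk in (b"part1", b"part2"):
            yield chunk

    # Ring-buffered wake word audio is replayed before the live stream, so
    # speech-to-text sees one continuous signal.
    chunks = [c async for c in chain_buffer_then_stream([b"buffered"], fake_mic())]
    assert chunks == [b"buffered", b"part1", b"part2"]

    # Timestamp arithmetic: 960 bytes of 16-bit (2-byte) samples at 16Khz is
    # 480 samples, and 480 samples // 16 samples-per-ms = 30 ms.
    sample_width = 2
    ms_per_sample = 16000 // 1000
    assert (960 // sample_width) // ms_per_sample == 30


asyncio.run(_demo())
```
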
+ """ + + def __init__(self, maxlen: int) -> None: + """Initialize empty buffer.""" + self._buffer = bytearray(maxlen) + self._pos = 0 + self._length = 0 + self._maxlen = maxlen + + @property + def maxlen(self) -> int: + """Return the maximum size of the buffer.""" + return self._maxlen + + @property + def pos(self) -> int: + """Return the current put position.""" + return self._pos + + def __len__(self) -> int: + """Return the length of data stored in the buffer.""" + return self._length + + def put(self, data: bytes) -> None: + """Put a chunk of data into the buffer, possibly wrapping around.""" + data_len = len(data) + new_pos = self._pos + data_len + if new_pos >= self._maxlen: + # Split into two chunks + num_bytes_1 = self._maxlen - self._pos + num_bytes_2 = new_pos - self._maxlen + + self._buffer[self._pos : self._maxlen] = data[:num_bytes_1] + self._buffer[:num_bytes_2] = data[num_bytes_1:] + new_pos = new_pos - self._maxlen + else: + # Entire chunk fits at current position + self._buffer[self._pos : self._pos + data_len] = data + + self._pos = new_pos + self._length = min(self._maxlen, self._length + data_len) + + def getvalue(self) -> bytes: + """Get bytes written to the buffer.""" + if (self._pos + self._length) <= self._maxlen: + # Single chunk + return bytes(self._buffer[: self._length]) + + # Two chunks + return bytes(self._buffer[self._pos :] + self._buffer[: self._pos]) diff --git a/homeassistant/components/assist_pipeline/vad.py b/homeassistant/components/assist_pipeline/vad.py index cae31671a3c..20a048d5621 100644 --- a/homeassistant/components/assist_pipeline/vad.py +++ b/homeassistant/components/assist_pipeline/vad.py @@ -1,12 +1,15 @@ """Voice activity detection.""" from __future__ import annotations +from collections.abc import Iterable from dataclasses import dataclass, field from enum import StrEnum +from typing import Final import webrtcvad -_SAMPLE_RATE = 16000 +_SAMPLE_RATE: Final = 16000 # Hz +_SAMPLE_WIDTH: Final = 2 # bytes class VadSensitivity(StrEnum): @@ -29,6 +32,45 @@ class VadSensitivity(StrEnum): return 1.0 +class AudioBuffer: + """Fixed-sized audio buffer with variable internal length.""" + + def __init__(self, maxlen: int) -> None: + """Initialize buffer.""" + self._buffer = bytearray(maxlen) + self._length = 0 + + @property + def length(self) -> int: + """Get number of bytes currently in the buffer.""" + return self._length + + def clear(self) -> None: + """Clear the buffer.""" + self._length = 0 + + def append(self, data: bytes) -> None: + """Append bytes to the buffer, increasing the internal length.""" + data_len = len(data) + if (self._length + data_len) > len(self._buffer): + raise ValueError("Length cannot be greater than buffer size") + + self._buffer[self._length : self._length + data_len] = data + self._length += data_len + + def bytes(self) -> bytes: + """Convert written portion of buffer to bytes.""" + return bytes(self._buffer[: self._length]) + + def __len__(self) -> int: + """Get the number of bytes currently in the buffer.""" + return self._length + + def __bool__(self) -> bool: + """Return True if there are bytes in the buffer.""" + return self._length > 0 + + @dataclass class VoiceCommandSegmenter: """Segments an audio stream into voice commands using webrtcvad.""" @@ -36,7 +78,7 @@ class VoiceCommandSegmenter: vad_mode: int = 3 """Aggressiveness in filtering out non-speech. 
diff --git a/homeassistant/components/assist_pipeline/vad.py b/homeassistant/components/assist_pipeline/vad.py
index cae31671a3c..20a048d5621 100644
--- a/homeassistant/components/assist_pipeline/vad.py
+++ b/homeassistant/components/assist_pipeline/vad.py
@@ -1,12 +1,15 @@
 """Voice activity detection."""
 from __future__ import annotations
 
+from collections.abc import Iterable
 from dataclasses import dataclass, field
 from enum import StrEnum
+from typing import Final
 
 import webrtcvad
 
-_SAMPLE_RATE = 16000
+_SAMPLE_RATE: Final = 16000  # Hz
+_SAMPLE_WIDTH: Final = 2  # bytes
 
 
 class VadSensitivity(StrEnum):
@@ -29,6 +32,45 @@ class VadSensitivity(StrEnum):
         return 1.0
 
 
+class AudioBuffer:
+    """Fixed-sized audio buffer with variable internal length."""
+
+    def __init__(self, maxlen: int) -> None:
+        """Initialize buffer."""
+        self._buffer = bytearray(maxlen)
+        self._length = 0
+
+    @property
+    def length(self) -> int:
+        """Get number of bytes currently in the buffer."""
+        return self._length
+
+    def clear(self) -> None:
+        """Clear the buffer."""
+        self._length = 0
+
+    def append(self, data: bytes) -> None:
+        """Append bytes to the buffer, increasing the internal length."""
+        data_len = len(data)
+        if (self._length + data_len) > len(self._buffer):
+            raise ValueError("Length cannot be greater than buffer size")
+
+        self._buffer[self._length : self._length + data_len] = data
+        self._length += data_len
+
+    def bytes(self) -> bytes:
+        """Convert written portion of buffer to bytes."""
+        return bytes(self._buffer[: self._length])
+
+    def __len__(self) -> int:
+        """Get the number of bytes currently in the buffer."""
+        return self._length
+
+    def __bool__(self) -> bool:
+        """Return True if there are bytes in the buffer."""
+        return self._length > 0
+
+
 @dataclass
 class VoiceCommandSegmenter:
     """Segments an audio stream into voice commands using webrtcvad."""
 
@@ -36,7 +78,7 @@ class VoiceCommandSegmenter:
     vad_mode: int = 3
     """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
 
-    vad_frames: int = 480  # 30 ms
+    vad_samples_per_chunk: int = 480  # 30 ms
     """Must be 10, 20, or 30 ms at 16Khz."""
 
     speech_seconds: float = 0.3
@@ -67,20 +109,23 @@ class VoiceCommandSegmenter:
     """Seconds left before resetting start/stop time counters."""
 
     _vad: webrtcvad.Vad = None
-    _audio_buffer: bytes = field(default_factory=bytes)
-    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
-    _seconds_per_chunk: float = 0.03  # 30 ms
+    _leftover_chunk_buffer: AudioBuffer = field(init=False)
+    _bytes_per_chunk: int = field(init=False)
+    _seconds_per_chunk: float = field(init=False)
 
     def __post_init__(self) -> None:
         """Initialize VAD."""
         self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_frames * 2
-        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
+        self._leftover_chunk_buffer = AudioBuffer(
+            self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        )
         self.reset()
 
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._audio_buffer = b""
+        self._leftover_chunk_buffer.clear()
         self._speech_seconds_left = self.speech_seconds
         self._silence_seconds_left = self.silence_seconds
         self._timeout_seconds_left = self.timeout_seconds
@@ -92,27 +137,20 @@ class VoiceCommandSegmenter:
 
         Returns False when command is done.
         """
-        self._audio_buffer += samples
-
-        # Process in 10, 20, or 30 ms chunks.
-        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
-        for chunk_idx in range(num_chunks):
-            chunk_offset = chunk_idx * self._bytes_per_chunk
-            chunk = self._audio_buffer[
-                chunk_offset : chunk_offset + self._bytes_per_chunk
-            ]
+        for chunk in chunk_samples(
+            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
+        ):
             if not self._process_chunk(chunk):
                 self.reset()
                 return False
 
-        if num_chunks > 0:
-            # Remove from buffer
-            self._audio_buffer = self._audio_buffer[
-                num_chunks * self._bytes_per_chunk :
-            ]
-
         return True
 
+    @property
+    def audio_buffer(self) -> bytes:
+        """Get partial chunk in the audio buffer."""
+        return self._leftover_chunk_buffer.bytes()
+
     def _process_chunk(self, chunk: bytes) -> bool:
         """Process a single chunk of 16-bit 16Khz mono audio.
 
@@ -163,7 +201,7 @@ class VoiceActivityTimeout:
     vad_mode: int = 3
     """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
-    vad_frames: int = 480  # 30 ms
+    vad_samples_per_chunk: int = 480  # 30 ms
     """Must be 10, 20, or 30 ms at 16Khz."""
 
     _silence_seconds_left: float = 0.0
@@ -173,20 +211,23 @@ class VoiceActivityTimeout:
     """Seconds left before resetting start/stop time counters."""
 
     _vad: webrtcvad.Vad = None
-    _audio_buffer: bytes = field(default_factory=bytes)
-    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
-    _seconds_per_chunk: float = 0.03  # 30 ms
+    _leftover_chunk_buffer: AudioBuffer = field(init=False)
+    _bytes_per_chunk: int = field(init=False)
+    _seconds_per_chunk: float = field(init=False)
 
     def __post_init__(self) -> None:
         """Initialize VAD."""
         self._vad = webrtcvad.Vad(self.vad_mode)
-        self._bytes_per_chunk = self.vad_frames * 2
-        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self._bytes_per_chunk = self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        self._seconds_per_chunk = self.vad_samples_per_chunk / _SAMPLE_RATE
+        self._leftover_chunk_buffer = AudioBuffer(
+            self.vad_samples_per_chunk * _SAMPLE_WIDTH
+        )
         self.reset()
 
     def reset(self) -> None:
         """Reset all counters and state."""
-        self._audio_buffer = b""
+        self._leftover_chunk_buffer.clear()
         self._silence_seconds_left = self.silence_seconds
         self._reset_seconds_left = self.reset_seconds
 
@@ -195,24 +236,12 @@
 
         Returns False when timeout is reached.
         """
-        self._audio_buffer += samples
-
-        # Process in 10, 20, or 30 ms chunks.
-        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
-        for chunk_idx in range(num_chunks):
-            chunk_offset = chunk_idx * self._bytes_per_chunk
-            chunk = self._audio_buffer[
-                chunk_offset : chunk_offset + self._bytes_per_chunk
-            ]
+        for chunk in chunk_samples(
+            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
+        ):
             if not self._process_chunk(chunk):
                 return False
 
-        if num_chunks > 0:
-            # Remove from buffer
-            self._audio_buffer = self._audio_buffer[
-                num_chunks * self._bytes_per_chunk :
-            ]
-
         return True
 
     def _process_chunk(self, chunk: bytes) -> bool:
@@ -239,3 +268,37 @@
         )
 
         return True
+
+
+def chunk_samples(
+    samples: bytes,
+    bytes_per_chunk: int,
+    leftover_chunk_buffer: AudioBuffer,
+) -> Iterable[bytes]:
+    """Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s)."""
+
+    if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
+        # Extend leftover chunk, but not enough samples to complete it
+        leftover_chunk_buffer.append(samples)
+        return
+
+    next_chunk_idx = 0
+
+    if leftover_chunk_buffer:
+        # Add to leftover chunk from previous call(s).
+        bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
+        leftover_chunk_buffer.append(samples[:bytes_to_copy])
+        next_chunk_idx = bytes_to_copy
+
+        # Process full chunk in buffer
+        yield leftover_chunk_buffer.bytes()
+        leftover_chunk_buffer.clear()
+
+    while next_chunk_idx < len(samples) - bytes_per_chunk + 1:
+        # Process full chunk
+        yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
+        next_chunk_idx += bytes_per_chunk
+
+    # Capture leftover chunks
+    if rest_samples := samples[next_chunk_idx:]:
+        leftover_chunk_buffer.append(rest_samples)
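
`chunk_samples` is what lets both VAD classes accept arbitrarily sized audio chunks even though webrtcvad only accepts 10, 20, or 30 ms frames: it re-frames incoming bytes into fixed-size chunks and carries the remainder across calls in an `AudioBuffer`. A sketch of the intended call pattern (import paths taken from the diff; the 700-byte network chunks are arbitrary stand-ins):

```python
import webrtcvad

from homeassistant.components.assist_pipeline.vad import AudioBuffer, chunk_samples

SAMPLE_RATE = 16000
SAMPLES_PER_CHUNK = 480  # 30 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * 2  # 16-bit samples

vad = webrtcvad.Vad(3)
leftover = AudioBuffer(BYTES_PER_CHUNK)

# Chunks arriving from a network stream rarely align to 30 ms frame
# boundaries; chunk_samples re-frames them and keeps the remainder.
for network_chunk in (bytes(700), bytes(700), bytes(700)):
    for frame in chunk_samples(network_chunk, BYTES_PER_CHUNK, leftover):
        assert len(frame) == BYTES_PER_CHUNK
        is_speech = vad.is_speech(frame, SAMPLE_RATE)  # silence here -> False

# 2100 bytes total = 2 full 960-byte frames + 180 bytes left over
assert len(leftover) == 2100 - (2 * BYTES_PER_CHUNK)
```
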
diff --git a/homeassistant/components/wake_word/__init__.py b/homeassistant/components/wake_word/__init__.py
index 0a751b7eea2..b308cf98912 100644
--- a/homeassistant/components/wake_word/__init__.py
+++ b/homeassistant/components/wake_word/__init__.py
@@ -79,8 +79,6 @@ class WakeWordDetectionEntity(RestoreEntity):
     @final
     def state(self) -> str | None:
         """Return the state of the entity."""
-        if self.__last_detected is None:
-            return None
         return self.__last_detected
 
     @property
diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr
index 58835e37973..7c1cf0e2b2d 100644
--- a/tests/components/assist_pipeline/snapshots/test_init.ambr
+++ b/tests/components/assist_pipeline/snapshots/test_init.ambr
@@ -317,6 +317,12 @@
       }),
       'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
     }),
+    dict({
+      'data': dict({
+        'timestamp': 1500,
+      }),
+      'type': <PipelineEventType.STT_VAD_END: 'stt-vad-end'>,
+    }),
     dict({
       'data': dict({
         'stt_output': dict({
diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py
index 184f479f830..aba9862614b 100644
--- a/tests/components/assist_pipeline/test_init.py
+++ b/tests/components/assist_pipeline/test_init.py
@@ -1,7 +1,7 @@
 """Test Voice Assistant init."""
 from dataclasses import asdict
 import itertools as it
-from unittest.mock import ANY
+from unittest.mock import ANY, patch
 
 import pytest
 from syrupy.assertion import SnapshotAssertion
@@ -49,9 +49,9 @@ async def test_pipeline_from_audio_stream_auto(
 
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -59,7 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
     )
 
     assert process_events(events) == snapshot
@@ -108,9 +108,9 @@ async def test_pipeline_from_audio_stream_legacy(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -118,7 +118,7 @@ async def test_pipeline_from_audio_stream_legacy(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
 
@@ -168,9 +168,9 @@ async def test_pipeline_from_audio_stream_entity(
     # Use the created pipeline
     await assist_pipeline.async_pipeline_from_audio_stream(
         hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
+        context=Context(),
+        event_callback=events.append,
+        stt_metadata=stt.SpeechMetadata(
             language="en-UK",
             format=stt.AudioFormats.WAV,
             codec=stt.AudioCodecs.PCM,
@@ -178,7 +178,7 @@ async def test_pipeline_from_audio_stream_entity(
             sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
             channel=stt.AudioChannels.CHANNEL_MONO,
         ),
-        audio_data(),
+        stt_stream=audio_data(),
         pipeline_id=pipeline_id,
     )
 
@@ -229,9 +229,9 @@ async def test_pipeline_from_audio_stream_no_stt(
     with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -239,7 +239,7 @@ async def test_pipeline_from_audio_stream_no_stt(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id=pipeline_id,
         )
 
@@ -268,9 +268,9 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
     with pytest.raises(assist_pipeline.PipelineNotFound):
         await assist_pipeline.async_pipeline_from_audio_stream(
             hass,
-            Context(),
-            events.append,
-            stt.SpeechMetadata(
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
                 language="en-UK",
                 format=stt.AudioFormats.WAV,
                 codec=stt.AudioCodecs.PCM,
@@ -278,7 +278,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
                 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                 channel=stt.AudioChannels.CHANNEL_MONO,
             ),
-            audio_data(),
+            stt_stream=audio_data(),
             pipeline_id="blah",
         )
 
@@ -308,26 +308,38 @@ async def test_pipeline_from_audio_stream_wake_word(
         yield b"wake word"
         yield b"part1"
         yield b"part2"
+        yield b"end"
         yield b""
 
-    await assist_pipeline.async_pipeline_from_audio_stream(
-        hass,
-        Context(),
-        events.append,
-        stt.SpeechMetadata(
-            language="",
-            format=stt.AudioFormats.WAV,
-            codec=stt.AudioCodecs.PCM,
-            bit_rate=stt.AudioBitRates.BITRATE_16,
-            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-            channel=stt.AudioChannels.CHANNEL_MONO,
-        ),
-        audio_data(),
-        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
-        wake_word_settings=assist_pipeline.WakeWordSettings(
-            audio_seconds_to_buffer=1.5
-        ),
-    )
+    def continue_stt(self, chunk):
+        # Ensure stt_vad_start event is triggered
+        self.in_command = True
+
+        # Stop on fake end chunk to trigger stt_vad_end
+        return chunk != b"end"
+
+    with patch(
+        "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter.process",
+        continue_stt,
+    ):
+        await assist_pipeline.async_pipeline_from_audio_stream(
+            hass,
+            context=Context(),
+            event_callback=events.append,
+            stt_metadata=stt.SpeechMetadata(
+                language="",
+                format=stt.AudioFormats.WAV,
+                codec=stt.AudioCodecs.PCM,
+                bit_rate=stt.AudioBitRates.BITRATE_16,
+                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+                channel=stt.AudioChannels.CHANNEL_MONO,
+            ),
+            stt_stream=audio_data(),
+            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+            wake_word_settings=assist_pipeline.WakeWordSettings(
+                audio_seconds_to_buffer=1.5
+            ),
+        )
 
     assert process_events(events) == snapshot
diff --git a/tests/components/assist_pipeline/test_ring_buffer.py b/tests/components/assist_pipeline/test_ring_buffer.py
new file mode 100644
index 00000000000..22185c3ad5b
--- /dev/null
+++ b/tests/components/assist_pipeline/test_ring_buffer.py
@@ -0,0 +1,38 @@
+"""Tests for audio ring buffer."""
+from homeassistant.components.assist_pipeline.ring_buffer import RingBuffer
+
+
+def test_ring_buffer_empty() -> None:
+    """Test empty ring buffer."""
+    rb = RingBuffer(10)
+    assert rb.maxlen == 10
+    assert rb.pos == 0
+    assert rb.getvalue() == b""
+
+
+def test_ring_buffer_put_1() -> None:
+    """Test putting some data smaller than the maximum length."""
+    rb = RingBuffer(10)
+    rb.put(bytes([1, 2, 3, 4, 5]))
+    assert len(rb) == 5
+    assert rb.pos == 5
+    assert rb.getvalue() == bytes([1, 2, 3, 4, 5])
+
+
+def test_ring_buffer_put_2() -> None:
+    """Test putting some data past the end of the buffer."""
+    rb = RingBuffer(10)
+    rb.put(bytes([1, 2, 3, 4, 5]))
+    rb.put(bytes([6, 7, 8, 9, 10, 11, 12]))
+    assert len(rb) == 10
+    assert rb.pos == 2
+    assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
+
+
+def test_ring_buffer_put_too_large() -> None:
+    """Test putting data too large for the buffer."""
+    rb = RingBuffer(10)
+    rb.put(bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]))
+    assert len(rb) == 10
+    assert rb.pos == 2
+    assert rb.getvalue() == bytes([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
diff --git a/tests/components/assist_pipeline/test_vad.py b/tests/components/assist_pipeline/test_vad.py
index 3a5c763ee5c..4dc8c8f6197 100644
--- a/tests/components/assist_pipeline/test_vad.py
+++ b/tests/components/assist_pipeline/test_vad.py
@@ -1,7 +1,12 @@
 """Tests for webrtcvad voice command segmenter."""
+import itertools as it
 from unittest.mock import patch
 
-from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
+from homeassistant.components.assist_pipeline.vad import (
+    AudioBuffer,
+    VoiceCommandSegmenter,
+    chunk_samples,
+)
 
 _ONE_SECOND = 16000 * 2  # 16Khz 16-bit
 
@@ -36,3 +41,87 @@ def test_speech() -> None:
     # silence
     # False return value indicates voice command is finished
     assert not segmenter.process(bytes(_ONE_SECOND))
+
+
+def test_audio_buffer() -> None:
+    """Test audio buffer wrapping."""
+
+    def is_speech(self, chunk, sample_rate):
+        """Disable VAD."""
+        return False
+
+    with patch(
+        "webrtcvad.Vad.is_speech",
+        new=is_speech,
+    ):
+        segmenter = VoiceCommandSegmenter()
+        bytes_per_chunk = segmenter.vad_samples_per_chunk * 2
+
+        with patch.object(
+            segmenter, "_process_chunk", return_value=True
+        ) as mock_process:
+            # Partially fill audio buffer
+            half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
+            segmenter.process(half_chunk)
+
+            assert not mock_process.called
+            assert segmenter.audio_buffer == half_chunk
+
+            # Fill and wrap with 1/4 chunk left over
+            three_quarters_chunk = bytes(
+                it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
+            )
+            segmenter.process(three_quarters_chunk)
+
+            assert mock_process.call_count == 1
+            assert (
+                segmenter.audio_buffer
+                == three_quarters_chunk[
+                    len(three_quarters_chunk) - (bytes_per_chunk // 4) :
+                ]
+            )
+            assert (
+                mock_process.call_args[0][0]
+                == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
+            )
+
+            # Run 2 chunks through
+            segmenter.reset()
+            assert len(segmenter.audio_buffer) == 0
+
+            mock_process.reset_mock()
+            two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
+            segmenter.process(two_chunks)
+
+            assert mock_process.call_count == 2
+            assert len(segmenter.audio_buffer) == 0
+            assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
+            assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
+
+
+def test_partial_chunk() -> None:
+    """Test that chunk_samples returns when given a partial chunk."""
+    bytes_per_chunk = 5
+    samples = bytes([1, 2, 3])
+    leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 0
+    assert leftover_chunk_buffer.bytes() == samples
+
+
+def test_chunk_samples_leftover() -> None:
+    """Test that chunk_samples keeps leftover bytes across calls."""
+    bytes_per_chunk = 5
+    samples = bytes([1, 2, 3, 4, 5, 6])
+    leftover_chunk_buffer = AudioBuffer(bytes_per_chunk)
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 1
+    assert leftover_chunk_buffer.bytes() == bytes([6])
+
+    # Add some more to the chunk
+    chunks = list(chunk_samples(samples, bytes_per_chunk, leftover_chunk_buffer))
+
+    assert len(chunks) == 1
+    assert leftover_chunk_buffer.bytes() == bytes([5, 6])
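
As a final sanity check on `chunk_samples` (a sketch, not part of the PR's test suite): when all samples arrive in one call and the leftover buffer starts empty, it should agree with naive fixed-stride slicing, with the tail ending up in the leftover buffer:

```python
from homeassistant.components.assist_pipeline.vad import AudioBuffer, chunk_samples

bytes_per_chunk = 960
samples = bytes(range(256)) * 20  # 5120 bytes -> 5 full chunks + 320 leftover

buf = AudioBuffer(bytes_per_chunk)
chunks = list(chunk_samples(samples, bytes_per_chunk, buf))

# Reference: plain fixed-stride slicing over the same input.
expected = [
    samples[i : i + bytes_per_chunk]
    for i in range(0, len(samples) - bytes_per_chunk + 1, bytes_per_chunk)
]
assert chunks == expected
assert buf.bytes() == samples[len(expected) * bytes_per_chunk :]
```
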