diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index f481411e551..8ee053162b0 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -16,6 +16,10 @@ from .const import ( DATA_LAST_WAKE_UP, DOMAIN, EVENT_RECORDING, + SAMPLE_CHANNELS, + SAMPLE_RATE, + SAMPLE_WIDTH, + SAMPLES_PER_CHUNK, ) from .error import PipelineNotFound from .pipeline import ( @@ -53,6 +57,10 @@ __all__ = ( "PipelineNotFound", "WakeWordSettings", "EVENT_RECORDING", + "SAMPLES_PER_CHUNK", + "SAMPLE_RATE", + "SAMPLE_WIDTH", + "SAMPLE_CHANNELS", ) CONFIG_SCHEMA = vol.Schema( diff --git a/homeassistant/components/assist_pipeline/audio_enhancer.py b/homeassistant/components/assist_pipeline/audio_enhancer.py index e7a149bd00e..c9c60f421b1 100644 --- a/homeassistant/components/assist_pipeline/audio_enhancer.py +++ b/homeassistant/components/assist_pipeline/audio_enhancer.py @@ -6,6 +6,8 @@ import logging from pymicro_vad import MicroVad +from .const import BYTES_PER_CHUNK + _LOGGER = logging.getLogger(__name__) @@ -38,11 +40,6 @@ class AudioEnhancer(ABC): def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk: """Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples.""" - @property - @abstractmethod - def samples_per_chunk(self) -> int | None: - """Return number of samples per chunk or None if chunking isn't required.""" - class MicroVadEnhancer(AudioEnhancer): """Audio enhancer that just runs microVAD.""" @@ -61,22 +58,15 @@ class MicroVadEnhancer(AudioEnhancer): _LOGGER.debug("Initialized microVAD with threshold=%s", self.threshold) def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk: - """Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples.""" + """Enhance 10ms chunk of PCM audio @ 16Khz with 16-bit mono samples.""" is_speech: bool | None = None if self.vad is not None: # Run VAD + assert len(audio) == BYTES_PER_CHUNK speech_prob = self.vad.Process10ms(audio) is_speech = speech_prob > self.threshold return EnhancedAudioChunk( audio=audio, timestamp_ms=timestamp_ms, is_speech=is_speech ) - - @property - def samples_per_chunk(self) -> int | None: - """Return number of samples per chunk or None if chunking isn't required.""" - if self.is_vad_enabled: - return 160 # 10ms - - return None diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index 14b93a90372..f7306b89a54 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -19,4 +19,6 @@ EVENT_RECORDING = f"{DOMAIN}_recording" SAMPLE_RATE = 16000 # hertz SAMPLE_WIDTH = 2 # bytes SAMPLE_CHANNELS = 1 # mono -SAMPLES_PER_CHUNK = 240 # 20 ms @ 16Khz +MS_PER_CHUNK = 10 +SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz +BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index af29888eb07..9fada934ca1 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -51,11 +51,13 @@ from homeassistant.util.limited_size_dict import LimitedSizeDict from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadEnhancer from .const import ( + BYTES_PER_CHUNK, CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DATA_MIGRATIONS, DOMAIN, + MS_PER_CHUNK, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH, @@ -502,9 +504,6 @@ class AudioSettings: is_vad_enabled: bool = True """True if VAD is used to determine the end of the voice command.""" - samples_per_chunk: int | None = None - """Number of samples that will be in each audio chunk (None for no chunking).""" - silence_seconds: float = 0.5 """Seconds of silence after voice command has ended.""" @@ -525,11 +524,6 @@ class AudioSettings: or (self.auto_gain_dbfs > 0) ) - @property - def is_chunking_enabled(self) -> bool: - """True if chunk size is set.""" - return self.samples_per_chunk is not None - @dataclass class PipelineRun: @@ -566,7 +560,9 @@ class PipelineRun: audio_enhancer: AudioEnhancer | None = None """VAD/noise suppression/auto gain""" - audio_chunking_buffer: AudioBuffer | None = None + audio_chunking_buffer: AudioBuffer = field( + default_factory=lambda: AudioBuffer(BYTES_PER_CHUNK) + ) """Buffer used when splitting audio into chunks for audio processing""" _device_id: str | None = None @@ -599,8 +595,6 @@ class PipelineRun: self.audio_settings.is_vad_enabled, ) - self.audio_chunking_buffer = AudioBuffer(self.samples_per_chunk * SAMPLE_WIDTH) - def __eq__(self, other: object) -> bool: """Compare pipeline runs by id.""" if isinstance(other, PipelineRun): @@ -608,14 +602,6 @@ class PipelineRun: return False - @property - def samples_per_chunk(self) -> int: - """Return number of samples expected in each audio chunk.""" - if self.audio_enhancer is not None: - return self.audio_enhancer.samples_per_chunk or SAMPLES_PER_CHUNK - - return self.audio_settings.samples_per_chunk or SAMPLES_PER_CHUNK - @callback def process_event(self, event: PipelineEvent) -> None: """Log an event and call listener.""" @@ -728,7 +714,7 @@ class PipelineRun: # after wake-word-detection. num_audio_chunks_to_buffer = int( (wake_word_settings.audio_seconds_to_buffer * SAMPLE_RATE) - / self.samples_per_chunk + / SAMPLES_PER_CHUNK ) stt_audio_buffer: deque[EnhancedAudioChunk] | None = None @@ -1216,60 +1202,31 @@ class PipelineRun: self.debug_recording_thread = None async def process_volume_only( - self, - audio_stream: AsyncIterable[bytes], - sample_rate: int = SAMPLE_RATE, - sample_width: int = SAMPLE_WIDTH, + self, audio_stream: AsyncIterable[bytes] ) -> AsyncGenerator[EnhancedAudioChunk]: """Apply volume transformation only (no VAD/audio enhancements) with optional chunking.""" - assert self.audio_chunking_buffer is not None - - bytes_per_chunk = self.samples_per_chunk * sample_width - ms_per_sample = sample_rate // 1000 - ms_per_chunk = self.samples_per_chunk // ms_per_sample timestamp_ms = 0 - async for chunk in audio_stream: if self.audio_settings.volume_multiplier != 1.0: chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier) - if self.audio_settings.is_chunking_enabled: - for sub_chunk in chunk_samples( - chunk, bytes_per_chunk, self.audio_chunking_buffer - ): - yield EnhancedAudioChunk( - audio=sub_chunk, - timestamp_ms=timestamp_ms, - is_speech=None, # no VAD - ) - timestamp_ms += ms_per_chunk - else: - # No chunking + for sub_chunk in chunk_samples( + chunk, BYTES_PER_CHUNK, self.audio_chunking_buffer + ): yield EnhancedAudioChunk( - audio=chunk, + audio=sub_chunk, timestamp_ms=timestamp_ms, is_speech=None, # no VAD ) - timestamp_ms += (len(chunk) // sample_width) // ms_per_sample + timestamp_ms += MS_PER_CHUNK async def process_enhance_audio( - self, - audio_stream: AsyncIterable[bytes], - sample_rate: int = SAMPLE_RATE, - sample_width: int = SAMPLE_WIDTH, + self, audio_stream: AsyncIterable[bytes] ) -> AsyncGenerator[EnhancedAudioChunk]: - """Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation.""" + """Split audio into chunks and apply VAD/noise suppression/auto gain/volume transformation.""" assert self.audio_enhancer is not None - assert self.audio_enhancer.samples_per_chunk is not None - assert self.audio_chunking_buffer is not None - bytes_per_chunk = self.audio_enhancer.samples_per_chunk * sample_width - ms_per_sample = sample_rate // 1000 - ms_per_chunk = ( - self.audio_enhancer.samples_per_chunk // sample_width - ) // ms_per_sample timestamp_ms = 0 - async for dirty_samples in audio_stream: if self.audio_settings.volume_multiplier != 1.0: # Static gain @@ -1279,10 +1236,10 @@ class PipelineRun: # Split into chunks for audio enhancements/VAD for dirty_chunk in chunk_samples( - dirty_samples, bytes_per_chunk, self.audio_chunking_buffer + dirty_samples, BYTES_PER_CHUNK, self.audio_chunking_buffer ): yield self.audio_enhancer.enhance_chunk(dirty_chunk, timestamp_ms) - timestamp_ms += ms_per_chunk + timestamp_ms += MS_PER_CHUNK def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes: diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py index 243909629cf..161e938a3b6 100644 --- a/homeassistant/components/voip/voip.py +++ b/homeassistant/components/voip/voip.py @@ -21,7 +21,7 @@ from voip_utils import ( VoipDatagramProtocol, ) -from homeassistant.components import stt, tts +from homeassistant.components import assist_pipeline, stt, tts from homeassistant.components.assist_pipeline import ( Pipeline, PipelineEvent, @@ -331,15 +331,14 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): async with asyncio.timeout(self.audio_timeout): chunk = await self._audio_queue.get() - assert audio_enhancer.samples_per_chunk is not None - vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH) + vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH) while chunk: chunk_buffer.append(chunk) segmenter.process_with_vad( chunk, - audio_enhancer.samples_per_chunk, + assist_pipeline.SAMPLES_PER_CHUNK, lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True, vad_buffer, ) @@ -371,13 +370,12 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): async with asyncio.timeout(self.audio_timeout): chunk = await self._audio_queue.get() - assert audio_enhancer.samples_per_chunk is not None - vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH) + vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH) while chunk: if not segmenter.process_with_vad( chunk, - audio_enhancer.samples_per_chunk, + assist_pipeline.SAMPLES_PER_CHUNK, lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True, vad_buffer, ): @@ -437,13 +435,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): sample_channels = wav_file.getnchannels() if ( - (sample_rate != 16000) - or (sample_width != 2) - or (sample_channels != 1) + (sample_rate != RATE) + or (sample_width != WIDTH) + or (sample_channels != CHANNELS) ): raise ValueError( - "Expected rate/width/channels as 16000/2/1," - " got {sample_rate}/{sample_width}/{sample_channels}}" + f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS}," + f" got {sample_rate}/{sample_width}/{sample_channels}" ) audio_bytes = wav_file.readframes(wav_file.getnframes()) diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index c041a54d8fa..b2eca1e7ce1 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -11,6 +11,12 @@ import pytest from homeassistant.components import stt, tts, wake_word from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select +from homeassistant.components.assist_pipeline.const import ( + BYTES_PER_CHUNK, + SAMPLE_CHANNELS, + SAMPLE_RATE, + SAMPLE_WIDTH, +) from homeassistant.components.assist_pipeline.pipeline import ( PipelineData, PipelineStorageCollection, @@ -33,6 +39,8 @@ from tests.common import ( _TRANSCRIPT = "test transcript" +BYTES_ONE_SECOND = SAMPLE_RATE * SAMPLE_WIDTH * SAMPLE_CHANNELS + @pytest.fixture(autouse=True) def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> Path: @@ -462,3 +470,8 @@ def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData: def pipeline_storage(pipeline_data) -> PipelineStorageCollection: """Return pipeline storage collection.""" return pipeline_data.pipeline_store + + +def make_10ms_chunk(header: bytes) -> bytes: + """Return 10ms of zeros with the given header.""" + return header + bytes(BYTES_PER_CHUNK - len(header)) diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 0b04b67bb22..e5ae18d28f2 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -440,7 +440,7 @@ # --- # name: test_device_capture_override.2 dict({ - 'audio': 'Y2h1bmsx', + 'audio': 'Y2h1bmsxAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=', 'channels': 1, 'rate': 16000, 'type': 'audio', diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 8fb7ce5b5a5..4206a288331 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -13,6 +13,7 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components import assist_pipeline, media_source, stt, tts from homeassistant.components.assist_pipeline.const import ( + BYTES_PER_CHUNK, CONF_DEBUG_RECORDING_DIR, DOMAIN, ) @@ -20,16 +21,16 @@ from homeassistant.core import Context, HomeAssistant from homeassistant.setup import async_setup_component from .conftest import ( + BYTES_ONE_SECOND, MockSttProvider, MockSttProviderEntity, MockTTSProvider, MockWakeWordEntity, + make_10ms_chunk, ) from tests.typing import ClientSessionGenerator, WebSocketGenerator -BYTES_ONE_SECOND = 16000 * 2 - def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: """Process events to remove dynamic values.""" @@ -58,8 +59,8 @@ async def test_pipeline_from_audio_stream_auto( events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): - yield b"part1" - yield b"part2" + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" await assist_pipeline.async_pipeline_from_audio_stream( @@ -79,7 +80,9 @@ async def test_pipeline_from_audio_stream_auto( ) assert process_events(events) == snapshot - assert mock_stt_provider.received == [b"part1", b"part2"] + assert len(mock_stt_provider.received) == 2 + assert mock_stt_provider.received[0].startswith(b"part1") + assert mock_stt_provider.received[1].startswith(b"part2") async def test_pipeline_from_audio_stream_legacy( @@ -98,8 +101,8 @@ async def test_pipeline_from_audio_stream_legacy( events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): - yield b"part1" - yield b"part2" + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" # Create a pipeline using an stt entity @@ -142,7 +145,9 @@ async def test_pipeline_from_audio_stream_legacy( ) assert process_events(events) == snapshot - assert mock_stt_provider.received == [b"part1", b"part2"] + assert len(mock_stt_provider.received) == 2 + assert mock_stt_provider.received[0].startswith(b"part1") + assert mock_stt_provider.received[1].startswith(b"part2") async def test_pipeline_from_audio_stream_entity( @@ -161,8 +166,8 @@ async def test_pipeline_from_audio_stream_entity( events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): - yield b"part1" - yield b"part2" + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" # Create a pipeline using an stt entity @@ -205,7 +210,9 @@ async def test_pipeline_from_audio_stream_entity( ) assert process_events(events) == snapshot - assert mock_stt_provider_entity.received == [b"part1", b"part2"] + assert len(mock_stt_provider_entity.received) == 2 + assert mock_stt_provider_entity.received[0].startswith(b"part1") + assert mock_stt_provider_entity.received[1].startswith(b"part2") async def test_pipeline_from_audio_stream_no_stt( @@ -224,8 +231,8 @@ async def test_pipeline_from_audio_stream_no_stt( events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): - yield b"part1" - yield b"part2" + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" # Create a pipeline without stt support @@ -285,8 +292,8 @@ async def test_pipeline_from_audio_stream_unknown_pipeline( events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): - yield b"part1" - yield b"part2" + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" # Try to use the created pipeline @@ -327,7 +334,7 @@ async def test_pipeline_from_audio_stream_wake_word( # [0, 2, ...] wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND)) - samples_per_chunk = 160 + samples_per_chunk = 160 # 10ms @ 16Khz bytes_per_chunk = samples_per_chunk * 2 # 16-bit async def audio_data(): @@ -343,8 +350,8 @@ async def test_pipeline_from_audio_stream_wake_word( yield wake_chunk_2[i : i + bytes_per_chunk] i += bytes_per_chunk - for chunk in (b"wake word!", b"part1", b"part2"): - yield chunk + bytes(bytes_per_chunk - len(chunk)) + for header in (b"wake word!", b"part1", b"part2"): + yield make_10ms_chunk(header) yield b"" @@ -365,9 +372,7 @@ async def test_pipeline_from_audio_stream_wake_word( wake_word_settings=assist_pipeline.WakeWordSettings( audio_seconds_to_buffer=1.5 ), - audio_settings=assist_pipeline.AudioSettings( - is_vad_enabled=False, samples_per_chunk=samples_per_chunk - ), + audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ) assert process_events(events) == snapshot @@ -408,13 +413,11 @@ async def test_pipeline_save_audio( pipeline = assist_pipeline.async_get_pipeline(hass) events: list[assist_pipeline.PipelineEvent] = [] - # Pad out to an even number of bytes since these "samples" will be saved - # as 16-bit values. async def audio_data(): - yield b"wake word_" + yield make_10ms_chunk(b"wake word") # queued audio - yield b"part1_" - yield b"part2_" + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" await assist_pipeline.async_pipeline_from_audio_stream( @@ -457,12 +460,16 @@ async def test_pipeline_save_audio( # Verify wake file with wave.open(str(wake_file), "rb") as wake_wav: wake_data = wake_wav.readframes(wake_wav.getnframes()) - assert wake_data == b"wake word_" + assert wake_data.startswith(b"wake word") # Verify stt file with wave.open(str(stt_file), "rb") as stt_wav: stt_data = stt_wav.readframes(stt_wav.getnframes()) - assert stt_data == b"queued audiopart1_part2_" + assert stt_data.startswith(b"queued audio") + stt_data = stt_data[len(b"queued audio") :] + assert stt_data.startswith(b"part1") + stt_data = stt_data[BYTES_PER_CHUNK:] + assert stt_data.startswith(b"part2") async def test_pipeline_saved_audio_with_device_id( @@ -645,10 +652,10 @@ async def test_wake_word_detection_aborted( events: list[assist_pipeline.PipelineEvent] = [] async def audio_data(): - yield b"silence!" - yield b"wake word!" - yield b"part1" - yield b"part2" + yield make_10ms_chunk(b"silence!") + yield make_10ms_chunk(b"wake word!") + yield make_10ms_chunk(b"part1") + yield make_10ms_chunk(b"part2") yield b"" pipeline_store = pipeline_data.pipeline_store diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 7d4a9b18c12..2da914f4252 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -8,7 +8,12 @@ from unittest.mock import ANY, patch import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components.assist_pipeline.const import DOMAIN +from homeassistant.components.assist_pipeline.const import ( + DOMAIN, + SAMPLE_CHANNELS, + SAMPLE_RATE, + SAMPLE_WIDTH, +) from homeassistant.components.assist_pipeline.pipeline import ( DeviceAudioQueue, Pipeline, @@ -18,7 +23,13 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr -from .conftest import MockWakeWordEntity, MockWakeWordEntity2 +from .conftest import ( + BYTES_ONE_SECOND, + BYTES_PER_CHUNK, + MockWakeWordEntity, + MockWakeWordEntity2, + make_10ms_chunk, +) from tests.common import MockConfigEntry from tests.typing import WebSocketGenerator @@ -205,7 +216,7 @@ async def test_audio_pipeline_with_wake_word_timeout( "start_stage": "wake_word", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, "timeout": 1, }, } @@ -229,7 +240,7 @@ async def test_audio_pipeline_with_wake_word_timeout( events.append(msg["event"]) # 2 seconds of silence - await client.send_bytes(bytes([1]) + bytes(16000 * 2 * 2)) + await client.send_bytes(bytes([1]) + bytes(2 * BYTES_ONE_SECOND)) # Time out error msg = await client.receive_json() @@ -259,7 +270,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout( "type": "assist_pipeline/run", "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "timeout": 0, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "timeout": 0, "no_vad": True}, } ) @@ -282,9 +293,10 @@ async def test_audio_pipeline_with_wake_word_no_timeout( events.append(msg["event"]) # "audio" - await client.send_bytes(bytes([handler_id]) + b"wake word") + await client.send_bytes(bytes([handler_id]) + make_10ms_chunk(b"wake word")) - msg = await client.receive_json() + async with asyncio.timeout(1): + msg = await client.receive_json() assert msg["event"]["type"] == "wake_word-end" assert msg["event"]["data"] == snapshot events.append(msg["event"]) @@ -365,7 +377,7 @@ async def test_audio_pipeline_no_wake_word_engine( "start_stage": "wake_word", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, }, } ) @@ -402,7 +414,7 @@ async def test_audio_pipeline_no_wake_word_entity( "start_stage": "wake_word", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, }, } ) @@ -1771,7 +1783,7 @@ async def test_audio_pipeline_with_enhancements( "start_stage": "stt", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, # Enhancements "noise_suppression_level": 2, "auto_gain_dbfs": 15, @@ -1801,7 +1813,7 @@ async def test_audio_pipeline_with_enhancements( # One second of silence. # This will pass through the audio enhancement pipeline, but we don't test # the actual output. - await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2)) + await client.send_bytes(bytes([handler_id]) + bytes(BYTES_ONE_SECOND)) # End of audio stream (handler id + empty payload) await client.send_bytes(bytes([handler_id])) @@ -1871,7 +1883,7 @@ async def test_wake_word_cooldown_same_id( "type": "assist_pipeline/run", "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, } ) @@ -1880,7 +1892,7 @@ async def test_wake_word_cooldown_same_id( "type": "assist_pipeline/run", "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, } ) @@ -1914,8 +1926,8 @@ async def test_wake_word_cooldown_same_id( assert msg["event"]["data"] == snapshot # Wake both up at the same time - await client_1.send_bytes(bytes([handler_id_1]) + b"wake word") - await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") + await client_1.send_bytes(bytes([handler_id_1]) + make_10ms_chunk(b"wake word")) + await client_2.send_bytes(bytes([handler_id_2]) + make_10ms_chunk(b"wake word")) # Get response events error_data: dict[str, Any] | None = None @@ -1954,7 +1966,7 @@ async def test_wake_word_cooldown_different_ids( "type": "assist_pipeline/run", "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, } ) @@ -1963,7 +1975,7 @@ async def test_wake_word_cooldown_different_ids( "type": "assist_pipeline/run", "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, } ) @@ -1997,8 +2009,8 @@ async def test_wake_word_cooldown_different_ids( assert msg["event"]["data"] == snapshot # Wake both up at the same time, but they will have different wake word ids - await client_1.send_bytes(bytes([handler_id_1]) + b"wake word") - await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") + await client_1.send_bytes(bytes([handler_id_1]) + make_10ms_chunk(b"wake word")) + await client_2.send_bytes(bytes([handler_id_2]) + make_10ms_chunk(b"wake word")) # Get response events msg = await client_1.receive_json() @@ -2073,7 +2085,7 @@ async def test_wake_word_cooldown_different_entities( "pipeline": pipeline_id_1, "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, } ) @@ -2084,7 +2096,7 @@ async def test_wake_word_cooldown_different_entities( "pipeline": pipeline_id_2, "start_stage": "wake_word", "end_stage": "tts", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, } ) @@ -2119,8 +2131,8 @@ async def test_wake_word_cooldown_different_entities( # Wake both up at the same time. # They will have the same wake word id, but different entities. - await client_1.send_bytes(bytes([handler_id_1]) + b"wake word") - await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") + await client_1.send_bytes(bytes([handler_id_1]) + make_10ms_chunk(b"wake word")) + await client_2.send_bytes(bytes([handler_id_2]) + make_10ms_chunk(b"wake word")) # Get response events error_data: dict[str, Any] | None = None @@ -2158,7 +2170,11 @@ async def test_device_capture( identifiers={("demo", "satellite-1234")}, ) - audio_chunks = [b"chunk1", b"chunk2", b"chunk3"] + audio_chunks = [ + make_10ms_chunk(b"chunk1"), + make_10ms_chunk(b"chunk2"), + make_10ms_chunk(b"chunk3"), + ] # Start capture client_capture = await hass_ws_client(hass) @@ -2181,7 +2197,7 @@ async def test_device_capture( "type": "assist_pipeline/run", "start_stage": "stt", "end_stage": "stt", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, "device_id": satellite_device.id, } ) @@ -2232,9 +2248,9 @@ async def test_device_capture( # Verify audio chunks for i, audio_chunk in enumerate(audio_chunks): assert events[i]["type"] == "audio" - assert events[i]["rate"] == 16000 - assert events[i]["width"] == 2 - assert events[i]["channels"] == 1 + assert events[i]["rate"] == SAMPLE_RATE + assert events[i]["width"] == SAMPLE_WIDTH + assert events[i]["channels"] == SAMPLE_CHANNELS # Audio is base64 encoded assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii") @@ -2259,7 +2275,11 @@ async def test_device_capture_override( identifiers={("demo", "satellite-1234")}, ) - audio_chunks = [b"chunk1", b"chunk2", b"chunk3"] + audio_chunks = [ + make_10ms_chunk(b"chunk1"), + make_10ms_chunk(b"chunk2"), + make_10ms_chunk(b"chunk3"), + ] # Start first capture client_capture_1 = await hass_ws_client(hass) @@ -2282,7 +2302,7 @@ async def test_device_capture_override( "type": "assist_pipeline/run", "start_stage": "stt", "end_stage": "stt", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, "device_id": satellite_device.id, } ) @@ -2365,9 +2385,9 @@ async def test_device_capture_override( # Verify all but first audio chunk for i, audio_chunk in enumerate(audio_chunks[1:]): assert events[i]["type"] == "audio" - assert events[i]["rate"] == 16000 - assert events[i]["width"] == 2 - assert events[i]["channels"] == 1 + assert events[i]["rate"] == SAMPLE_RATE + assert events[i]["width"] == SAMPLE_WIDTH + assert events[i]["channels"] == SAMPLE_CHANNELS # Audio is base64 encoded assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii") @@ -2427,7 +2447,7 @@ async def test_device_capture_queue_full( "type": "assist_pipeline/run", "start_stage": "stt", "end_stage": "stt", - "input": {"sample_rate": 16000, "no_vad": True}, + "input": {"sample_rate": SAMPLE_RATE, "no_vad": True}, "device_id": satellite_device.id, } ) @@ -2448,8 +2468,8 @@ async def test_device_capture_queue_full( assert msg["event"]["type"] == "stt-start" assert msg["event"]["data"] == snapshot - # Single sample will "overflow" the queue - await client_pipeline.send_bytes(bytes([handler_id, 0, 0])) + # Single chunk will "overflow" the queue + await client_pipeline.send_bytes(bytes([handler_id]) + bytes(BYTES_PER_CHUNK)) # End of audio stream await client_pipeline.send_bytes(bytes([handler_id])) @@ -2557,7 +2577,7 @@ async def test_stt_cooldown_same_id( "start_stage": "stt", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, "wake_word_phrase": "ok_nabu", }, } @@ -2569,7 +2589,7 @@ async def test_stt_cooldown_same_id( "start_stage": "stt", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, "wake_word_phrase": "ok_nabu", }, } @@ -2628,7 +2648,7 @@ async def test_stt_cooldown_different_ids( "start_stage": "stt", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, "wake_word_phrase": "ok_nabu", }, } @@ -2640,7 +2660,7 @@ async def test_stt_cooldown_different_ids( "start_stage": "stt", "end_stage": "tts", "input": { - "sample_rate": 16000, + "sample_rate": SAMPLE_RATE, "wake_word_phrase": "hey_jarvis", }, }