mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Standardize assist pipelines on 10ms chunk size (#123024)
* Make chunk size always 10ms
* Fix voip
This commit is contained in:
parent
a3b5dcc21b
commit
80aa2c269b
@ -16,6 +16,10 @@ from .const import (
|
|||||||
DATA_LAST_WAKE_UP,
|
DATA_LAST_WAKE_UP,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
EVENT_RECORDING,
|
EVENT_RECORDING,
|
||||||
|
SAMPLE_CHANNELS,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
SAMPLE_WIDTH,
|
||||||
|
SAMPLES_PER_CHUNK,
|
||||||
)
|
)
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
@ -53,6 +57,10 @@ __all__ = (
|
|||||||
"PipelineNotFound",
|
"PipelineNotFound",
|
||||||
"WakeWordSettings",
|
"WakeWordSettings",
|
||||||
"EVENT_RECORDING",
|
"EVENT_RECORDING",
|
||||||
|
"SAMPLES_PER_CHUNK",
|
||||||
|
"SAMPLE_RATE",
|
||||||
|
"SAMPLE_WIDTH",
|
||||||
|
"SAMPLE_CHANNELS",
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema(
|
CONFIG_SCHEMA = vol.Schema(
|
||||||
|
@ -6,6 +6,8 @@ import logging
|
|||||||
|
|
||||||
from pymicro_vad import MicroVad
|
from pymicro_vad import MicroVad
|
||||||
|
|
||||||
|
from .const import BYTES_PER_CHUNK
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -38,11 +40,6 @@ class AudioEnhancer(ABC):
|
|||||||
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
|
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
|
||||||
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
|
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def samples_per_chunk(self) -> int | None:
|
|
||||||
"""Return number of samples per chunk or None if chunking isn't required."""
|
|
||||||
|
|
||||||
|
|
||||||
class MicroVadEnhancer(AudioEnhancer):
|
class MicroVadEnhancer(AudioEnhancer):
|
||||||
"""Audio enhancer that just runs microVAD."""
|
"""Audio enhancer that just runs microVAD."""
|
||||||
@ -61,22 +58,15 @@ class MicroVadEnhancer(AudioEnhancer):
|
|||||||
_LOGGER.debug("Initialized microVAD with threshold=%s", self.threshold)
|
_LOGGER.debug("Initialized microVAD with threshold=%s", self.threshold)
|
||||||
|
|
||||||
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
|
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
|
||||||
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
|
"""Enhance 10ms chunk of PCM audio @ 16Khz with 16-bit mono samples."""
|
||||||
is_speech: bool | None = None
|
is_speech: bool | None = None
|
||||||
|
|
||||||
if self.vad is not None:
|
if self.vad is not None:
|
||||||
# Run VAD
|
# Run VAD
|
||||||
|
assert len(audio) == BYTES_PER_CHUNK
|
||||||
speech_prob = self.vad.Process10ms(audio)
|
speech_prob = self.vad.Process10ms(audio)
|
||||||
is_speech = speech_prob > self.threshold
|
is_speech = speech_prob > self.threshold
|
||||||
|
|
||||||
return EnhancedAudioChunk(
|
return EnhancedAudioChunk(
|
||||||
audio=audio, timestamp_ms=timestamp_ms, is_speech=is_speech
|
audio=audio, timestamp_ms=timestamp_ms, is_speech=is_speech
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def samples_per_chunk(self) -> int | None:
|
|
||||||
"""Return number of samples per chunk or None if chunking isn't required."""
|
|
||||||
if self.is_vad_enabled:
|
|
||||||
return 160 # 10ms
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
@ -19,4 +19,6 @@ EVENT_RECORDING = f"{DOMAIN}_recording"
|
|||||||
SAMPLE_RATE = 16000 # hertz
|
SAMPLE_RATE = 16000 # hertz
|
||||||
SAMPLE_WIDTH = 2 # bytes
|
SAMPLE_WIDTH = 2 # bytes
|
||||||
SAMPLE_CHANNELS = 1 # mono
|
SAMPLE_CHANNELS = 1 # mono
|
||||||
SAMPLES_PER_CHUNK = 240 # 20 ms @ 16Khz
|
MS_PER_CHUNK = 10
|
||||||
|
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||||
|
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||||
|
@ -51,11 +51,13 @@ from homeassistant.util.limited_size_dict import LimitedSizeDict
|
|||||||
|
|
||||||
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadEnhancer
|
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadEnhancer
|
||||||
from .const import (
|
from .const import (
|
||||||
|
BYTES_PER_CHUNK,
|
||||||
CONF_DEBUG_RECORDING_DIR,
|
CONF_DEBUG_RECORDING_DIR,
|
||||||
DATA_CONFIG,
|
DATA_CONFIG,
|
||||||
DATA_LAST_WAKE_UP,
|
DATA_LAST_WAKE_UP,
|
||||||
DATA_MIGRATIONS,
|
DATA_MIGRATIONS,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
MS_PER_CHUNK,
|
||||||
SAMPLE_CHANNELS,
|
SAMPLE_CHANNELS,
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
SAMPLE_WIDTH,
|
SAMPLE_WIDTH,
|
||||||
@ -502,9 +504,6 @@ class AudioSettings:
|
|||||||
is_vad_enabled: bool = True
|
is_vad_enabled: bool = True
|
||||||
"""True if VAD is used to determine the end of the voice command."""
|
"""True if VAD is used to determine the end of the voice command."""
|
||||||
|
|
||||||
samples_per_chunk: int | None = None
|
|
||||||
"""Number of samples that will be in each audio chunk (None for no chunking)."""
|
|
||||||
|
|
||||||
silence_seconds: float = 0.5
|
silence_seconds: float = 0.5
|
||||||
"""Seconds of silence after voice command has ended."""
|
"""Seconds of silence after voice command has ended."""
|
||||||
|
|
||||||
@ -525,11 +524,6 @@ class AudioSettings:
|
|||||||
or (self.auto_gain_dbfs > 0)
|
or (self.auto_gain_dbfs > 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def is_chunking_enabled(self) -> bool:
|
|
||||||
"""True if chunk size is set."""
|
|
||||||
return self.samples_per_chunk is not None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineRun:
|
class PipelineRun:
|
||||||
@ -566,7 +560,9 @@ class PipelineRun:
|
|||||||
audio_enhancer: AudioEnhancer | None = None
|
audio_enhancer: AudioEnhancer | None = None
|
||||||
"""VAD/noise suppression/auto gain"""
|
"""VAD/noise suppression/auto gain"""
|
||||||
|
|
||||||
audio_chunking_buffer: AudioBuffer | None = None
|
audio_chunking_buffer: AudioBuffer = field(
|
||||||
|
default_factory=lambda: AudioBuffer(BYTES_PER_CHUNK)
|
||||||
|
)
|
||||||
"""Buffer used when splitting audio into chunks for audio processing"""
|
"""Buffer used when splitting audio into chunks for audio processing"""
|
||||||
|
|
||||||
_device_id: str | None = None
|
_device_id: str | None = None
|
||||||
@ -599,8 +595,6 @@ class PipelineRun:
|
|||||||
self.audio_settings.is_vad_enabled,
|
self.audio_settings.is_vad_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.audio_chunking_buffer = AudioBuffer(self.samples_per_chunk * SAMPLE_WIDTH)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
"""Compare pipeline runs by id."""
|
"""Compare pipeline runs by id."""
|
||||||
if isinstance(other, PipelineRun):
|
if isinstance(other, PipelineRun):
|
||||||
@ -608,14 +602,6 @@ class PipelineRun:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def samples_per_chunk(self) -> int:
|
|
||||||
"""Return number of samples expected in each audio chunk."""
|
|
||||||
if self.audio_enhancer is not None:
|
|
||||||
return self.audio_enhancer.samples_per_chunk or SAMPLES_PER_CHUNK
|
|
||||||
|
|
||||||
return self.audio_settings.samples_per_chunk or SAMPLES_PER_CHUNK
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def process_event(self, event: PipelineEvent) -> None:
|
def process_event(self, event: PipelineEvent) -> None:
|
||||||
"""Log an event and call listener."""
|
"""Log an event and call listener."""
|
||||||
@ -728,7 +714,7 @@ class PipelineRun:
|
|||||||
# after wake-word-detection.
|
# after wake-word-detection.
|
||||||
num_audio_chunks_to_buffer = int(
|
num_audio_chunks_to_buffer = int(
|
||||||
(wake_word_settings.audio_seconds_to_buffer * SAMPLE_RATE)
|
(wake_word_settings.audio_seconds_to_buffer * SAMPLE_RATE)
|
||||||
/ self.samples_per_chunk
|
/ SAMPLES_PER_CHUNK
|
||||||
)
|
)
|
||||||
|
|
||||||
stt_audio_buffer: deque[EnhancedAudioChunk] | None = None
|
stt_audio_buffer: deque[EnhancedAudioChunk] | None = None
|
||||||
@ -1216,60 +1202,31 @@ class PipelineRun:
|
|||||||
self.debug_recording_thread = None
|
self.debug_recording_thread = None
|
||||||
|
|
||||||
async def process_volume_only(
|
async def process_volume_only(
|
||||||
self,
|
self, audio_stream: AsyncIterable[bytes]
|
||||||
audio_stream: AsyncIterable[bytes],
|
|
||||||
sample_rate: int = SAMPLE_RATE,
|
|
||||||
sample_width: int = SAMPLE_WIDTH,
|
|
||||||
) -> AsyncGenerator[EnhancedAudioChunk]:
|
) -> AsyncGenerator[EnhancedAudioChunk]:
|
||||||
"""Apply volume transformation only (no VAD/audio enhancements) with optional chunking."""
|
"""Apply volume transformation only (no VAD/audio enhancements) and split audio into 10 ms chunks."""
|
||||||
assert self.audio_chunking_buffer is not None
|
|
||||||
|
|
||||||
bytes_per_chunk = self.samples_per_chunk * sample_width
|
|
||||||
ms_per_sample = sample_rate // 1000
|
|
||||||
ms_per_chunk = self.samples_per_chunk // ms_per_sample
|
|
||||||
timestamp_ms = 0
|
timestamp_ms = 0
|
||||||
|
|
||||||
async for chunk in audio_stream:
|
async for chunk in audio_stream:
|
||||||
if self.audio_settings.volume_multiplier != 1.0:
|
if self.audio_settings.volume_multiplier != 1.0:
|
||||||
chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
|
chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
|
||||||
|
|
||||||
if self.audio_settings.is_chunking_enabled:
|
|
||||||
for sub_chunk in chunk_samples(
|
for sub_chunk in chunk_samples(
|
||||||
chunk, bytes_per_chunk, self.audio_chunking_buffer
|
chunk, BYTES_PER_CHUNK, self.audio_chunking_buffer
|
||||||
):
|
):
|
||||||
yield EnhancedAudioChunk(
|
yield EnhancedAudioChunk(
|
||||||
audio=sub_chunk,
|
audio=sub_chunk,
|
||||||
timestamp_ms=timestamp_ms,
|
timestamp_ms=timestamp_ms,
|
||||||
is_speech=None, # no VAD
|
is_speech=None, # no VAD
|
||||||
)
|
)
|
||||||
timestamp_ms += ms_per_chunk
|
timestamp_ms += MS_PER_CHUNK
|
||||||
else:
|
|
||||||
# No chunking
|
|
||||||
yield EnhancedAudioChunk(
|
|
||||||
audio=chunk,
|
|
||||||
timestamp_ms=timestamp_ms,
|
|
||||||
is_speech=None, # no VAD
|
|
||||||
)
|
|
||||||
timestamp_ms += (len(chunk) // sample_width) // ms_per_sample
|
|
||||||
|
|
||||||
async def process_enhance_audio(
|
async def process_enhance_audio(
|
||||||
self,
|
self, audio_stream: AsyncIterable[bytes]
|
||||||
audio_stream: AsyncIterable[bytes],
|
|
||||||
sample_rate: int = SAMPLE_RATE,
|
|
||||||
sample_width: int = SAMPLE_WIDTH,
|
|
||||||
) -> AsyncGenerator[EnhancedAudioChunk]:
|
) -> AsyncGenerator[EnhancedAudioChunk]:
|
||||||
"""Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation."""
|
"""Split audio into chunks and apply VAD/noise suppression/auto gain/volume transformation."""
|
||||||
assert self.audio_enhancer is not None
|
assert self.audio_enhancer is not None
|
||||||
assert self.audio_enhancer.samples_per_chunk is not None
|
|
||||||
assert self.audio_chunking_buffer is not None
|
|
||||||
|
|
||||||
bytes_per_chunk = self.audio_enhancer.samples_per_chunk * sample_width
|
|
||||||
ms_per_sample = sample_rate // 1000
|
|
||||||
ms_per_chunk = (
|
|
||||||
self.audio_enhancer.samples_per_chunk // sample_width
|
|
||||||
) // ms_per_sample
|
|
||||||
timestamp_ms = 0
|
timestamp_ms = 0
|
||||||
|
|
||||||
async for dirty_samples in audio_stream:
|
async for dirty_samples in audio_stream:
|
||||||
if self.audio_settings.volume_multiplier != 1.0:
|
if self.audio_settings.volume_multiplier != 1.0:
|
||||||
# Static gain
|
# Static gain
|
||||||
@ -1279,10 +1236,10 @@ class PipelineRun:
|
|||||||
|
|
||||||
# Split into chunks for audio enhancements/VAD
|
# Split into chunks for audio enhancements/VAD
|
||||||
for dirty_chunk in chunk_samples(
|
for dirty_chunk in chunk_samples(
|
||||||
dirty_samples, bytes_per_chunk, self.audio_chunking_buffer
|
dirty_samples, BYTES_PER_CHUNK, self.audio_chunking_buffer
|
||||||
):
|
):
|
||||||
yield self.audio_enhancer.enhance_chunk(dirty_chunk, timestamp_ms)
|
yield self.audio_enhancer.enhance_chunk(dirty_chunk, timestamp_ms)
|
||||||
timestamp_ms += ms_per_chunk
|
timestamp_ms += MS_PER_CHUNK
|
||||||
|
|
||||||
|
|
||||||
def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes:
|
def _multiply_volume(chunk: bytes, volume_multiplier: float) -> bytes:
|
||||||
|
@ -21,7 +21,7 @@ from voip_utils import (
|
|||||||
VoipDatagramProtocol,
|
VoipDatagramProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
from homeassistant.components import assist_pipeline, stt, tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
@ -331,15 +331,14 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
async with asyncio.timeout(self.audio_timeout):
|
async with asyncio.timeout(self.audio_timeout):
|
||||||
chunk = await self._audio_queue.get()
|
chunk = await self._audio_queue.get()
|
||||||
|
|
||||||
assert audio_enhancer.samples_per_chunk is not None
|
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||||
vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH)
|
|
||||||
|
|
||||||
while chunk:
|
while chunk:
|
||||||
chunk_buffer.append(chunk)
|
chunk_buffer.append(chunk)
|
||||||
|
|
||||||
segmenter.process_with_vad(
|
segmenter.process_with_vad(
|
||||||
chunk,
|
chunk,
|
||||||
audio_enhancer.samples_per_chunk,
|
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||||
vad_buffer,
|
vad_buffer,
|
||||||
)
|
)
|
||||||
@ -371,13 +370,12 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
async with asyncio.timeout(self.audio_timeout):
|
async with asyncio.timeout(self.audio_timeout):
|
||||||
chunk = await self._audio_queue.get()
|
chunk = await self._audio_queue.get()
|
||||||
|
|
||||||
assert audio_enhancer.samples_per_chunk is not None
|
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||||
vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH)
|
|
||||||
|
|
||||||
while chunk:
|
while chunk:
|
||||||
if not segmenter.process_with_vad(
|
if not segmenter.process_with_vad(
|
||||||
chunk,
|
chunk,
|
||||||
audio_enhancer.samples_per_chunk,
|
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||||
vad_buffer,
|
vad_buffer,
|
||||||
):
|
):
|
||||||
@ -437,13 +435,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
sample_channels = wav_file.getnchannels()
|
sample_channels = wav_file.getnchannels()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(sample_rate != 16000)
|
(sample_rate != RATE)
|
||||||
or (sample_width != 2)
|
or (sample_width != WIDTH)
|
||||||
or (sample_channels != 1)
|
or (sample_channels != CHANNELS)
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected rate/width/channels as 16000/2/1,"
|
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||||
" got {sample_rate}/{sample_width}/{sample_channels}}"
|
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||||
|
@ -11,6 +11,12 @@ import pytest
|
|||||||
|
|
||||||
from homeassistant.components import stt, tts, wake_word
|
from homeassistant.components import stt, tts, wake_word
|
||||||
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
|
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
|
||||||
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
|
BYTES_PER_CHUNK,
|
||||||
|
SAMPLE_CHANNELS,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
SAMPLE_WIDTH,
|
||||||
|
)
|
||||||
from homeassistant.components.assist_pipeline.pipeline import (
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineStorageCollection,
|
PipelineStorageCollection,
|
||||||
@ -33,6 +39,8 @@ from tests.common import (
|
|||||||
|
|
||||||
_TRANSCRIPT = "test transcript"
|
_TRANSCRIPT = "test transcript"
|
||||||
|
|
||||||
|
BYTES_ONE_SECOND = SAMPLE_RATE * SAMPLE_WIDTH * SAMPLE_CHANNELS
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> Path:
|
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> Path:
|
||||||
@ -462,3 +470,8 @@ def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
|
|||||||
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
|
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
|
||||||
"""Return pipeline storage collection."""
|
"""Return pipeline storage collection."""
|
||||||
return pipeline_data.pipeline_store
|
return pipeline_data.pipeline_store
|
||||||
|
|
||||||
|
|
||||||
|
def make_10ms_chunk(header: bytes) -> bytes:
|
||||||
|
"""Return 10ms of zeros with the given header."""
|
||||||
|
return header + bytes(BYTES_PER_CHUNK - len(header))
|
||||||
|
@ -440,7 +440,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_device_capture_override.2
|
# name: test_device_capture_override.2
|
||||||
dict({
|
dict({
|
||||||
'audio': 'Y2h1bmsx',
|
'audio': 'Y2h1bmsxAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=',
|
||||||
'channels': 1,
|
'channels': 1,
|
||||||
'rate': 16000,
|
'rate': 16000,
|
||||||
'type': 'audio',
|
'type': 'audio',
|
||||||
|
@ -13,6 +13,7 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
|
|
||||||
from homeassistant.components import assist_pipeline, media_source, stt, tts
|
from homeassistant.components import assist_pipeline, media_source, stt, tts
|
||||||
from homeassistant.components.assist_pipeline.const import (
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
|
BYTES_PER_CHUNK,
|
||||||
CONF_DEBUG_RECORDING_DIR,
|
CONF_DEBUG_RECORDING_DIR,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
@ -20,16 +21,16 @@ from homeassistant.core import Context, HomeAssistant
|
|||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
BYTES_ONE_SECOND,
|
||||||
MockSttProvider,
|
MockSttProvider,
|
||||||
MockSttProviderEntity,
|
MockSttProviderEntity,
|
||||||
MockTTSProvider,
|
MockTTSProvider,
|
||||||
MockWakeWordEntity,
|
MockWakeWordEntity,
|
||||||
|
make_10ms_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
BYTES_ONE_SECOND = 16000 * 2
|
|
||||||
|
|
||||||
|
|
||||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||||
"""Process events to remove dynamic values."""
|
"""Process events to remove dynamic values."""
|
||||||
@ -58,8 +59,8 @@ async def test_pipeline_from_audio_stream_auto(
|
|||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"part1"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
@ -79,7 +80,9 @@ async def test_pipeline_from_audio_stream_auto(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
assert mock_stt_provider.received == [b"part1", b"part2"]
|
assert len(mock_stt_provider.received) == 2
|
||||||
|
assert mock_stt_provider.received[0].startswith(b"part1")
|
||||||
|
assert mock_stt_provider.received[1].startswith(b"part2")
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_from_audio_stream_legacy(
|
async def test_pipeline_from_audio_stream_legacy(
|
||||||
@ -98,8 +101,8 @@ async def test_pipeline_from_audio_stream_legacy(
|
|||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"part1"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
# Create a pipeline using an stt entity
|
# Create a pipeline using an stt entity
|
||||||
@ -142,7 +145,9 @@ async def test_pipeline_from_audio_stream_legacy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
assert mock_stt_provider.received == [b"part1", b"part2"]
|
assert len(mock_stt_provider.received) == 2
|
||||||
|
assert mock_stt_provider.received[0].startswith(b"part1")
|
||||||
|
assert mock_stt_provider.received[1].startswith(b"part2")
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_from_audio_stream_entity(
|
async def test_pipeline_from_audio_stream_entity(
|
||||||
@ -161,8 +166,8 @@ async def test_pipeline_from_audio_stream_entity(
|
|||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"part1"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
# Create a pipeline using an stt entity
|
# Create a pipeline using an stt entity
|
||||||
@ -205,7 +210,9 @@ async def test_pipeline_from_audio_stream_entity(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
assert mock_stt_provider_entity.received == [b"part1", b"part2"]
|
assert len(mock_stt_provider_entity.received) == 2
|
||||||
|
assert mock_stt_provider_entity.received[0].startswith(b"part1")
|
||||||
|
assert mock_stt_provider_entity.received[1].startswith(b"part2")
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_from_audio_stream_no_stt(
|
async def test_pipeline_from_audio_stream_no_stt(
|
||||||
@ -224,8 +231,8 @@ async def test_pipeline_from_audio_stream_no_stt(
|
|||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"part1"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
# Create a pipeline without stt support
|
# Create a pipeline without stt support
|
||||||
@ -285,8 +292,8 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
|
|||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"part1"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
# Try to use the created pipeline
|
# Try to use the created pipeline
|
||||||
@ -327,7 +334,7 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||||||
# [0, 2, ...]
|
# [0, 2, ...]
|
||||||
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
|
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
|
||||||
|
|
||||||
samples_per_chunk = 160
|
samples_per_chunk = 160 # 10ms @ 16Khz
|
||||||
bytes_per_chunk = samples_per_chunk * 2 # 16-bit
|
bytes_per_chunk = samples_per_chunk * 2 # 16-bit
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
@ -343,8 +350,8 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||||||
yield wake_chunk_2[i : i + bytes_per_chunk]
|
yield wake_chunk_2[i : i + bytes_per_chunk]
|
||||||
i += bytes_per_chunk
|
i += bytes_per_chunk
|
||||||
|
|
||||||
for chunk in (b"wake word!", b"part1", b"part2"):
|
for header in (b"wake word!", b"part1", b"part2"):
|
||||||
yield chunk + bytes(bytes_per_chunk - len(chunk))
|
yield make_10ms_chunk(header)
|
||||||
|
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
@ -365,9 +372,7 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||||
audio_seconds_to_buffer=1.5
|
audio_seconds_to_buffer=1.5
|
||||||
),
|
),
|
||||||
audio_settings=assist_pipeline.AudioSettings(
|
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||||
is_vad_enabled=False, samples_per_chunk=samples_per_chunk
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
@ -408,13 +413,11 @@ async def test_pipeline_save_audio(
|
|||||||
pipeline = assist_pipeline.async_get_pipeline(hass)
|
pipeline = assist_pipeline.async_get_pipeline(hass)
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
# Pad out to an even number of bytes since these "samples" will be saved
|
|
||||||
# as 16-bit values.
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"wake word_"
|
yield make_10ms_chunk(b"wake word")
|
||||||
# queued audio
|
# queued audio
|
||||||
yield b"part1_"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2_"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
@ -457,12 +460,16 @@ async def test_pipeline_save_audio(
|
|||||||
# Verify wake file
|
# Verify wake file
|
||||||
with wave.open(str(wake_file), "rb") as wake_wav:
|
with wave.open(str(wake_file), "rb") as wake_wav:
|
||||||
wake_data = wake_wav.readframes(wake_wav.getnframes())
|
wake_data = wake_wav.readframes(wake_wav.getnframes())
|
||||||
assert wake_data == b"wake word_"
|
assert wake_data.startswith(b"wake word")
|
||||||
|
|
||||||
# Verify stt file
|
# Verify stt file
|
||||||
with wave.open(str(stt_file), "rb") as stt_wav:
|
with wave.open(str(stt_file), "rb") as stt_wav:
|
||||||
stt_data = stt_wav.readframes(stt_wav.getnframes())
|
stt_data = stt_wav.readframes(stt_wav.getnframes())
|
||||||
assert stt_data == b"queued audiopart1_part2_"
|
assert stt_data.startswith(b"queued audio")
|
||||||
|
stt_data = stt_data[len(b"queued audio") :]
|
||||||
|
assert stt_data.startswith(b"part1")
|
||||||
|
stt_data = stt_data[BYTES_PER_CHUNK:]
|
||||||
|
assert stt_data.startswith(b"part2")
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_saved_audio_with_device_id(
|
async def test_pipeline_saved_audio_with_device_id(
|
||||||
@ -645,10 +652,10 @@ async def test_wake_word_detection_aborted(
|
|||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
async def audio_data():
|
async def audio_data():
|
||||||
yield b"silence!"
|
yield make_10ms_chunk(b"silence!")
|
||||||
yield b"wake word!"
|
yield make_10ms_chunk(b"wake word!")
|
||||||
yield b"part1"
|
yield make_10ms_chunk(b"part1")
|
||||||
yield b"part2"
|
yield make_10ms_chunk(b"part2")
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
@ -8,7 +8,12 @@ from unittest.mock import ANY, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
|
DOMAIN,
|
||||||
|
SAMPLE_CHANNELS,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
SAMPLE_WIDTH,
|
||||||
|
)
|
||||||
from homeassistant.components.assist_pipeline.pipeline import (
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
DeviceAudioQueue,
|
DeviceAudioQueue,
|
||||||
Pipeline,
|
Pipeline,
|
||||||
@ -18,7 +23,13 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr
|
from homeassistant.helpers import device_registry as dr
|
||||||
|
|
||||||
from .conftest import MockWakeWordEntity, MockWakeWordEntity2
|
from .conftest import (
|
||||||
|
BYTES_ONE_SECOND,
|
||||||
|
BYTES_PER_CHUNK,
|
||||||
|
MockWakeWordEntity,
|
||||||
|
MockWakeWordEntity2,
|
||||||
|
make_10ms_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
@ -205,7 +216,7 @@ async def test_audio_pipeline_with_wake_word_timeout(
|
|||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
"timeout": 1,
|
"timeout": 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -229,7 +240,7 @@ async def test_audio_pipeline_with_wake_word_timeout(
|
|||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# 2 seconds of silence
|
# 2 seconds of silence
|
||||||
await client.send_bytes(bytes([1]) + bytes(16000 * 2 * 2))
|
await client.send_bytes(bytes([1]) + bytes(2 * BYTES_ONE_SECOND))
|
||||||
|
|
||||||
# Time out error
|
# Time out error
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -259,7 +270,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "timeout": 0, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "timeout": 0, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -282,8 +293,9 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
|
|||||||
events.append(msg["event"])
|
events.append(msg["event"])
|
||||||
|
|
||||||
# "audio"
|
# "audio"
|
||||||
await client.send_bytes(bytes([handler_id]) + b"wake word")
|
await client.send_bytes(bytes([handler_id]) + make_10ms_chunk(b"wake word"))
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["event"]["type"] == "wake_word-end"
|
assert msg["event"]["type"] == "wake_word-end"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
@ -365,7 +377,7 @@ async def test_audio_pipeline_no_wake_word_engine(
|
|||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -402,7 +414,7 @@ async def test_audio_pipeline_no_wake_word_entity(
|
|||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -1771,7 +1783,7 @@ async def test_audio_pipeline_with_enhancements(
|
|||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
# Enhancements
|
# Enhancements
|
||||||
"noise_suppression_level": 2,
|
"noise_suppression_level": 2,
|
||||||
"auto_gain_dbfs": 15,
|
"auto_gain_dbfs": 15,
|
||||||
@ -1801,7 +1813,7 @@ async def test_audio_pipeline_with_enhancements(
|
|||||||
# One second of silence.
|
# One second of silence.
|
||||||
# This will pass through the audio enhancement pipeline, but we don't test
|
# This will pass through the audio enhancement pipeline, but we don't test
|
||||||
# the actual output.
|
# the actual output.
|
||||||
await client.send_bytes(bytes([handler_id]) + bytes(16000 * 2))
|
await client.send_bytes(bytes([handler_id]) + bytes(BYTES_ONE_SECOND))
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
# End of audio stream (handler id + empty payload)
|
||||||
await client.send_bytes(bytes([handler_id]))
|
await client.send_bytes(bytes([handler_id]))
|
||||||
@ -1871,7 +1883,7 @@ async def test_wake_word_cooldown_same_id(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1880,7 +1892,7 @@ async def test_wake_word_cooldown_same_id(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1914,8 +1926,8 @@ async def test_wake_word_cooldown_same_id(
|
|||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
# Wake both up at the same time
|
# Wake both up at the same time
|
||||||
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
|
await client_1.send_bytes(bytes([handler_id_1]) + make_10ms_chunk(b"wake word"))
|
||||||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
await client_2.send_bytes(bytes([handler_id_2]) + make_10ms_chunk(b"wake word"))
|
||||||
|
|
||||||
# Get response events
|
# Get response events
|
||||||
error_data: dict[str, Any] | None = None
|
error_data: dict[str, Any] | None = None
|
||||||
@ -1954,7 +1966,7 @@ async def test_wake_word_cooldown_different_ids(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1963,7 +1975,7 @@ async def test_wake_word_cooldown_different_ids(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1997,8 +2009,8 @@ async def test_wake_word_cooldown_different_ids(
|
|||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
# Wake both up at the same time, but they will have different wake word ids
|
# Wake both up at the same time, but they will have different wake word ids
|
||||||
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
|
await client_1.send_bytes(bytes([handler_id_1]) + make_10ms_chunk(b"wake word"))
|
||||||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
await client_2.send_bytes(bytes([handler_id_2]) + make_10ms_chunk(b"wake word"))
|
||||||
|
|
||||||
# Get response events
|
# Get response events
|
||||||
msg = await client_1.receive_json()
|
msg = await client_1.receive_json()
|
||||||
@ -2073,7 +2085,7 @@ async def test_wake_word_cooldown_different_entities(
|
|||||||
"pipeline": pipeline_id_1,
|
"pipeline": pipeline_id_1,
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2084,7 +2096,7 @@ async def test_wake_word_cooldown_different_entities(
|
|||||||
"pipeline": pipeline_id_2,
|
"pipeline": pipeline_id_2,
|
||||||
"start_stage": "wake_word",
|
"start_stage": "wake_word",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2119,8 +2131,8 @@ async def test_wake_word_cooldown_different_entities(
|
|||||||
|
|
||||||
# Wake both up at the same time.
|
# Wake both up at the same time.
|
||||||
# They will have the same wake word id, but different entities.
|
# They will have the same wake word id, but different entities.
|
||||||
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
|
await client_1.send_bytes(bytes([handler_id_1]) + make_10ms_chunk(b"wake word"))
|
||||||
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
await client_2.send_bytes(bytes([handler_id_2]) + make_10ms_chunk(b"wake word"))
|
||||||
|
|
||||||
# Get response events
|
# Get response events
|
||||||
error_data: dict[str, Any] | None = None
|
error_data: dict[str, Any] | None = None
|
||||||
@ -2158,7 +2170,11 @@ async def test_device_capture(
|
|||||||
identifiers={("demo", "satellite-1234")},
|
identifiers={("demo", "satellite-1234")},
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
|
audio_chunks = [
|
||||||
|
make_10ms_chunk(b"chunk1"),
|
||||||
|
make_10ms_chunk(b"chunk2"),
|
||||||
|
make_10ms_chunk(b"chunk3"),
|
||||||
|
]
|
||||||
|
|
||||||
# Start capture
|
# Start capture
|
||||||
client_capture = await hass_ws_client(hass)
|
client_capture = await hass_ws_client(hass)
|
||||||
@ -2181,7 +2197,7 @@ async def test_device_capture(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "stt",
|
"end_stage": "stt",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
"device_id": satellite_device.id,
|
"device_id": satellite_device.id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -2232,9 +2248,9 @@ async def test_device_capture(
|
|||||||
# Verify audio chunks
|
# Verify audio chunks
|
||||||
for i, audio_chunk in enumerate(audio_chunks):
|
for i, audio_chunk in enumerate(audio_chunks):
|
||||||
assert events[i]["type"] == "audio"
|
assert events[i]["type"] == "audio"
|
||||||
assert events[i]["rate"] == 16000
|
assert events[i]["rate"] == SAMPLE_RATE
|
||||||
assert events[i]["width"] == 2
|
assert events[i]["width"] == SAMPLE_WIDTH
|
||||||
assert events[i]["channels"] == 1
|
assert events[i]["channels"] == SAMPLE_CHANNELS
|
||||||
|
|
||||||
# Audio is base64 encoded
|
# Audio is base64 encoded
|
||||||
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
|
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
|
||||||
@ -2259,7 +2275,11 @@ async def test_device_capture_override(
|
|||||||
identifiers={("demo", "satellite-1234")},
|
identifiers={("demo", "satellite-1234")},
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
|
audio_chunks = [
|
||||||
|
make_10ms_chunk(b"chunk1"),
|
||||||
|
make_10ms_chunk(b"chunk2"),
|
||||||
|
make_10ms_chunk(b"chunk3"),
|
||||||
|
]
|
||||||
|
|
||||||
# Start first capture
|
# Start first capture
|
||||||
client_capture_1 = await hass_ws_client(hass)
|
client_capture_1 = await hass_ws_client(hass)
|
||||||
@ -2282,7 +2302,7 @@ async def test_device_capture_override(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "stt",
|
"end_stage": "stt",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
"device_id": satellite_device.id,
|
"device_id": satellite_device.id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -2365,9 +2385,9 @@ async def test_device_capture_override(
|
|||||||
# Verify all but first audio chunk
|
# Verify all but first audio chunk
|
||||||
for i, audio_chunk in enumerate(audio_chunks[1:]):
|
for i, audio_chunk in enumerate(audio_chunks[1:]):
|
||||||
assert events[i]["type"] == "audio"
|
assert events[i]["type"] == "audio"
|
||||||
assert events[i]["rate"] == 16000
|
assert events[i]["rate"] == SAMPLE_RATE
|
||||||
assert events[i]["width"] == 2
|
assert events[i]["width"] == SAMPLE_WIDTH
|
||||||
assert events[i]["channels"] == 1
|
assert events[i]["channels"] == SAMPLE_CHANNELS
|
||||||
|
|
||||||
# Audio is base64 encoded
|
# Audio is base64 encoded
|
||||||
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
|
assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii")
|
||||||
@ -2427,7 +2447,7 @@ async def test_device_capture_queue_full(
|
|||||||
"type": "assist_pipeline/run",
|
"type": "assist_pipeline/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "stt",
|
"end_stage": "stt",
|
||||||
"input": {"sample_rate": 16000, "no_vad": True},
|
"input": {"sample_rate": SAMPLE_RATE, "no_vad": True},
|
||||||
"device_id": satellite_device.id,
|
"device_id": satellite_device.id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -2448,8 +2468,8 @@ async def test_device_capture_queue_full(
|
|||||||
assert msg["event"]["type"] == "stt-start"
|
assert msg["event"]["type"] == "stt-start"
|
||||||
assert msg["event"]["data"] == snapshot
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
# Single sample will "overflow" the queue
|
# Single chunk will "overflow" the queue
|
||||||
await client_pipeline.send_bytes(bytes([handler_id, 0, 0]))
|
await client_pipeline.send_bytes(bytes([handler_id]) + bytes(BYTES_PER_CHUNK))
|
||||||
|
|
||||||
# End of audio stream
|
# End of audio stream
|
||||||
await client_pipeline.send_bytes(bytes([handler_id]))
|
await client_pipeline.send_bytes(bytes([handler_id]))
|
||||||
@ -2557,7 +2577,7 @@ async def test_stt_cooldown_same_id(
|
|||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
"wake_word_phrase": "ok_nabu",
|
"wake_word_phrase": "ok_nabu",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -2569,7 +2589,7 @@ async def test_stt_cooldown_same_id(
|
|||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
"wake_word_phrase": "ok_nabu",
|
"wake_word_phrase": "ok_nabu",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -2628,7 +2648,7 @@ async def test_stt_cooldown_different_ids(
|
|||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
"wake_word_phrase": "ok_nabu",
|
"wake_word_phrase": "ok_nabu",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -2640,7 +2660,7 @@ async def test_stt_cooldown_different_ids(
|
|||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
"input": {
|
"input": {
|
||||||
"sample_rate": 16000,
|
"sample_rate": SAMPLE_RATE,
|
||||||
"wake_word_phrase": "hey_jarvis",
|
"wake_word_phrase": "hey_jarvis",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user