From b3e247d5f03c7d934ee96d361643545640402c84 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 16 Nov 2023 10:28:06 -0600 Subject: [PATCH] Add websocket command to capture audio from a device (#103936) * Add websocket command to capture audio from a device * Update homeassistant/components/assist_pipeline/pipeline.py Co-authored-by: Paulus Schoutsen * Add device capture test * More tests * Add logbook * Remove unnecessary check * Remove seconds and make logbook message past tense --------- Co-authored-by: Paulus Schoutsen --- .../components/assist_pipeline/__init__.py | 9 +- .../components/assist_pipeline/const.py | 2 + .../components/assist_pipeline/logbook.py | 39 ++ .../components/assist_pipeline/pipeline.py | 62 +++- .../assist_pipeline/websocket_api.py | 119 +++++- .../snapshots/test_websocket.ambr | 113 ++++++ .../assist_pipeline/test_logbook.py | 42 +++ .../assist_pipeline/test_websocket.py | 350 +++++++++++++++++- 8 files changed, 720 insertions(+), 16 deletions(-) create mode 100644 homeassistant/components/assist_pipeline/logbook.py create mode 100644 tests/components/assist_pipeline/test_logbook.py diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 64fe9e1f5f4..6d00f26ee15 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -9,7 +9,13 @@ from homeassistant.components import stt from homeassistant.core import Context, HomeAssistant from homeassistant.helpers.typing import ConfigType -from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN +from .const import ( + CONF_DEBUG_RECORDING_DIR, + DATA_CONFIG, + DATA_LAST_WAKE_UP, + DOMAIN, + EVENT_RECORDING, +) from .error import PipelineNotFound from .pipeline import ( AudioSettings, @@ -40,6 +46,7 @@ __all__ = ( "PipelineEventType", "PipelineNotFound", "WakeWordSettings", + "EVENT_RECORDING", ) CONFIG_SCHEMA = vol.Schema( 
diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index 84b49fc18fa..091b19db69e 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -11,3 +11,5 @@ CONF_DEBUG_RECORDING_DIR = "debug_recording_dir" DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up" DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds + +EVENT_RECORDING = f"{DOMAIN}_recording" diff --git a/homeassistant/components/assist_pipeline/logbook.py b/homeassistant/components/assist_pipeline/logbook.py new file mode 100644 index 00000000000..f2cfb8d3d5e --- /dev/null +++ b/homeassistant/components/assist_pipeline/logbook.py @@ -0,0 +1,39 @@ +"""Describe assist_pipeline logbook events.""" +from __future__ import annotations + +from collections.abc import Callable + +from homeassistant.components.logbook import LOGBOOK_ENTRY_MESSAGE, LOGBOOK_ENTRY_NAME +from homeassistant.const import ATTR_DEVICE_ID +from homeassistant.core import Event, HomeAssistant, callback +import homeassistant.helpers.device_registry as dr + +from .const import DOMAIN, EVENT_RECORDING + + +@callback +def async_describe_events( + hass: HomeAssistant, + async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None], +) -> None: + """Describe logbook events.""" + device_registry = dr.async_get(hass) + + @callback + def async_describe_logbook_event(event: Event) -> dict[str, str]: + """Describe a recording event, naming unregistered devices "Unknown device".""" + device: dr.DeviceEntry | None = None + device_name: str = "Unknown device" + + device = device_registry.async_get(event.data[ATTR_DEVICE_ID]) + if device: + device_name = device.name_by_user or device.name or "Unknown device" + + message = f"{device_name} started recording audio" + + return { + LOGBOOK_ENTRY_NAME: device_name, + LOGBOOK_ENTRY_MESSAGE: message, + } + + async_describe_event(DOMAIN, EVENT_RECORDING, async_describe_logbook_event) diff --git 
a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index c6d0f6c5435..71e93371257 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -503,6 +503,9 @@ class PipelineRun: audio_processor_buffer: AudioBuffer = field(init=False, repr=False) """Buffer used when splitting audio into chunks for audio processing""" + _device_id: str | None = None + """Optional device id set during run start.""" + def __post_init__(self) -> None: """Set language for pipeline.""" self.language = self.pipeline.language or self.hass.config.language @@ -554,7 +557,8 @@ class PipelineRun: def start(self, device_id: str | None) -> None: """Emit run start event.""" - self._start_debug_recording_thread(device_id) + self._device_id = device_id + self._start_debug_recording_thread() data = { "pipeline": self.pipeline.id, @@ -567,6 +571,9 @@ class PipelineRun: async def end(self) -> None: """Emit run end event.""" + # Signal end of stream to listeners + self._capture_chunk(None) + # Stop the recording thread before emitting run-end. # This ensures that files are properly closed if the event handler reads them. 
await self._stop_debug_recording_thread() @@ -746,9 +753,7 @@ class PipelineRun: if self.abort_wake_word_detection: raise WakeWordDetectionAborted - if self.debug_recording_queue is not None: - self.debug_recording_queue.put_nowait(chunk.audio) - + self._capture_chunk(chunk.audio) yield chunk.audio, chunk.timestamp_ms # Wake-word-detection occurs *after* the wake word was actually @@ -870,8 +875,7 @@ class PipelineRun: chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate sent_vad_start = False async for chunk in audio_stream: - if self.debug_recording_queue is not None: - self.debug_recording_queue.put_nowait(chunk.audio) + self._capture_chunk(chunk.audio) if stt_vad is not None: if not stt_vad.process(chunk_seconds, chunk.is_speech): @@ -1057,7 +1061,28 @@ class PipelineRun: return tts_media.url - def _start_debug_recording_thread(self, device_id: str | None) -> None: + def _capture_chunk(self, audio_bytes: bytes | None) -> None: + """Forward audio chunk to various capturing mechanisms.""" + if self.debug_recording_queue is not None: + # Forward to debug WAV file recording + self.debug_recording_queue.put_nowait(audio_bytes) + + if self._device_id is None: + return + + # Forward to device audio capture + pipeline_data: PipelineData = self.hass.data[DOMAIN] + audio_queue = pipeline_data.device_audio_queues.get(self._device_id) + if audio_queue is None: + return + + try: + audio_queue.queue.put_nowait(audio_bytes) + except asyncio.QueueFull: + audio_queue.overflow = True + _LOGGER.warning("Audio queue full for device %s", self._device_id) + + def _start_debug_recording_thread(self) -> None: """Start thread to record wake/stt audio if debug_recording_dir is set.""" if self.debug_recording_thread is not None: # Already started @@ -1068,7 +1093,7 @@ class PipelineRun: if debug_recording_dir := self.hass.data[DATA_CONFIG].get( CONF_DEBUG_RECORDING_DIR ): - if device_id is None: + if self._device_id is None: # // run_recording_dir = ( Path(debug_recording_dir) @@ 
-1079,7 +1104,7 @@ class PipelineRun: # /// run_recording_dir = ( Path(debug_recording_dir) - / device_id + / self._device_id / self.pipeline.name / str(time.monotonic_ns()) ) @@ -1100,8 +1125,8 @@ class PipelineRun: # Not running return - # Signal thread to stop gracefully - self.debug_recording_queue.put(None) + # NOTE: Expecting a None to have been put in self.debug_recording_queue + # in self.end() to signal the thread to stop. # Wait until the thread has finished to ensure that files are fully written await self.hass.async_add_executor_job(self.debug_recording_thread.join) @@ -1632,6 +1657,20 @@ class PipelineRuns: pipeline_run.abort_wake_word_detection = True +@dataclass +class DeviceAudioQueue: + """Audio capture queue for a satellite device.""" + + queue: asyncio.Queue[bytes | None] + """Queue of audio chunks (None = stop signal)""" + + id: str = field(default_factory=ulid_util.ulid) + """Unique id to ensure the correct audio queue is cleaned up in websocket API.""" + + overflow: bool = False + """Flag to be set if audio samples were dropped because the queue was full.""" + + class PipelineData: """Store and debug data stored in hass.data.""" @@ -1641,6 +1680,7 @@ class PipelineData: self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {} self.pipeline_devices: set[str] = set() self.pipeline_runs = PipelineRuns(pipeline_store) + self.device_audio_queues: dict[str, DeviceAudioQueue] = {} @dataclass diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index fda3e266490..6bfe969dc3e 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -3,22 +3,31 @@ import asyncio # Suppressing disable=deprecated-module is needed for Python 3.11 import audioop # pylint: disable=deprecated-module +import base64 from collections.abc import AsyncGenerator, Callable +import contextlib import logging 
-from typing import Any +import math +from typing import Any, Final import voluptuous as vol from homeassistant.components import conversation, stt, tts, websocket_api -from homeassistant.const import MATCH_ALL +from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.util import language as language_util -from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN +from .const import ( + DEFAULT_PIPELINE_TIMEOUT, + DEFAULT_WAKE_WORD_TIMEOUT, + DOMAIN, + EVENT_RECORDING, +) from .error import PipelineNotFound from .pipeline import ( AudioSettings, + DeviceAudioQueue, PipelineData, PipelineError, PipelineEvent, @@ -32,6 +41,11 @@ from .pipeline import ( _LOGGER = logging.getLogger(__name__) +CAPTURE_RATE: Final = 16000 +CAPTURE_WIDTH: Final = 2 +CAPTURE_CHANNELS: Final = 1 +MAX_CAPTURE_TIMEOUT: Final = 60.0 + @callback def async_register_websocket_api(hass: HomeAssistant) -> None: @@ -40,6 +54,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_list_languages) websocket_api.async_register_command(hass, websocket_list_runs) websocket_api.async_register_command(hass, websocket_get_run) + websocket_api.async_register_command(hass, websocket_device_capture) @websocket_api.websocket_command( @@ -371,3 +386,101 @@ async def websocket_list_languages( else pipeline_languages }, ) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_pipeline/device/capture", + vol.Required("device_id"): str, + vol.Required("timeout"): vol.All( + # 0 < timeout <= MAX_CAPTURE_TIMEOUT + vol.Coerce(float), + vol.Range(min=0, min_included=False, max=MAX_CAPTURE_TIMEOUT), + ), + } +) +@websocket_api.async_response +async def websocket_device_capture( + hass: HomeAssistant, + connection: 
websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Capture raw audio from a satellite device and forward to client.""" + pipeline_data: PipelineData = hass.data[DOMAIN] + device_id = msg["device_id"] + + # Number of seconds to record audio in wall clock time + timeout_seconds = msg["timeout"] + + # We don't know the chunk size, so the upper bound is calculated assuming a + # single sample (16 bits) per queue item. + max_queue_items = ( + # +1 for None to signal end + int(math.ceil(timeout_seconds * CAPTURE_RATE)) + + 1 + ) + + audio_queue = DeviceAudioQueue(queue=asyncio.Queue(maxsize=max_queue_items)) + + # Running simultaneous captures for a single device will not work by design. + # The new capture will cause the old capture to stop. + if ( + old_audio_queue := pipeline_data.device_audio_queues.pop(device_id, None) + ) is not None: + with contextlib.suppress(asyncio.QueueFull): + # Signal other websocket command that we're taking over + old_audio_queue.queue.put_nowait(None) + + # Only one client can be capturing audio at a time + pipeline_data.device_audio_queues[device_id] = audio_queue + + def clean_up_queue() -> None: + # Clean up our audio queue + maybe_audio_queue = pipeline_data.device_audio_queues.get(device_id) + if (maybe_audio_queue is not None) and (maybe_audio_queue.id == audio_queue.id): + # Only pop if this is our queue + pipeline_data.device_audio_queues.pop(device_id) + + # Unsubscribe cleans up queue + connection.subscriptions[msg["id"]] = clean_up_queue + + # Audio will follow as events + connection.send_result(msg["id"]) + + # Record to logbook + hass.bus.async_fire( + EVENT_RECORDING, + { + ATTR_DEVICE_ID: device_id, + ATTR_SECONDS: timeout_seconds, + }, + ) + + try: + with contextlib.suppress(asyncio.TimeoutError): + async with asyncio.timeout(timeout_seconds): + while True: + # Send audio chunks encoded as base64 + audio_bytes = await audio_queue.queue.get() + if audio_bytes is None: + # Signal to stop + 
break + + connection.send_event( + msg["id"], + { + "type": "audio", + "rate": CAPTURE_RATE, # hertz + "width": CAPTURE_WIDTH, # bytes + "channels": CAPTURE_CHANNELS, + "audio": base64.b64encode(audio_bytes).decode("ascii"), + }, + ) + + # Capture has ended + connection.send_event( + msg["id"], {"type": "end", "overflow": audio_queue.overflow} + ) + finally: + clean_up_queue() diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 9eb7e1e5a05..1f625528806 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -487,6 +487,119 @@ # name: test_audio_pipeline_with_wake_word_timeout.3 None # --- +# name: test_device_capture + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_device_capture.1 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_device_capture.2 + None +# --- +# name: test_device_capture_override + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_device_capture_override.1 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_device_capture_override.2 + dict({ + 'audio': 'Y2h1bmsx', + 'channels': 1, + 'rate': 16000, + 'type': 'audio', + 'width': 2, + }) +# --- +# name: test_device_capture_override.3 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_device_capture_override.4 + None +# --- +# name: test_device_capture_override.5 + dict({ + 'overflow': 
False, + 'type': 'end', + }) +# --- +# name: test_device_capture_queue_full + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_device_capture_queue_full.1 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_device_capture_queue_full.2 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_device_capture_queue_full.3 + None +# --- +# name: test_device_capture_queue_full.4 + dict({ + 'overflow': True, + 'type': 'end', + }) +# --- # name: test_intent_failed dict({ 'language': 'en', diff --git a/tests/components/assist_pipeline/test_logbook.py b/tests/components/assist_pipeline/test_logbook.py new file mode 100644 index 00000000000..6a997236f1c --- /dev/null +++ b/tests/components/assist_pipeline/test_logbook.py @@ -0,0 +1,42 @@ +"""The tests for assist_pipeline logbook.""" +from homeassistant.components import assist_pipeline, logbook +from homeassistant.const import ATTR_DEVICE_ID +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr +from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry +from tests.components.logbook.common import MockRow, mock_humanify + + +async def test_recording_event( + hass: HomeAssistant, init_components, device_registry: dr.DeviceRegistry +) -> None: + """Test recording event.""" + hass.config.components.add("recorder") + assert await async_setup_component(hass, "logbook", {}) + + entry = MockConfigEntry() + entry.add_to_hass(hass) + satellite_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "satellite-1234")}, + ) + + device_registry.async_update_device(satellite_device.id, name="My Satellite") + 
event = mock_humanify( + hass, + [ + MockRow( + assist_pipeline.EVENT_RECORDING, + {ATTR_DEVICE_ID: satellite_device.id}, + ), + ], + )[0] + + assert event[logbook.LOGBOOK_ENTRY_NAME] == "My Satellite" + assert event[logbook.LOGBOOK_ENTRY_DOMAIN] == assist_pipeline.DOMAIN + assert ( + event[logbook.LOGBOOK_ENTRY_MESSAGE] == "My Satellite started recording audio" + ) diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 9a4e78a29af..931b31dd77b 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -1,16 +1,23 @@ """Websocket tests for Voice Assistant integration.""" import asyncio +import base64 from unittest.mock import ANY, patch from syrupy.assertion import SnapshotAssertion from homeassistant.components.assist_pipeline.const import DOMAIN -from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData +from homeassistant.components.assist_pipeline.pipeline import ( + DeviceAudioQueue, + Pipeline, + PipelineData, +) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr from .conftest import MockWakeWordEntity, MockWakeWordEntity2 +from tests.common import MockConfigEntry from tests.typing import WebSocketGenerator @@ -2104,3 +2111,344 @@ async def test_wake_word_cooldown_different_entities( # Wake words should be the same assert ww_id_1 == ww_id_2 + + +async def test_device_capture( + hass: HomeAssistant, + init_components, + hass_ws_client: WebSocketGenerator, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, +) -> None: + """Test audio capture from a satellite device.""" + entry = MockConfigEntry() + entry.add_to_hass(hass) + satellite_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "satellite-1234")}, + ) 
+ + audio_chunks = [b"chunk1", b"chunk2", b"chunk3"] + + # Start capture + client_capture = await hass_ws_client(hass) + await client_capture.send_json_auto_id( + { + "type": "assist_pipeline/device/capture", + "timeout": 30, + "device_id": satellite_device.id, + } + ) + + # result + msg = await client_capture.receive_json() + assert msg["success"] + + # Run pipeline + client_pipeline = await hass_ws_client(hass) + await client_pipeline.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "stt", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + "device_id": satellite_device.id, + } + ) + + # result + msg = await client_pipeline.receive_json() + assert msg["success"] + + # run start + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + + # stt + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == snapshot + + for audio_chunk in audio_chunks: + await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunk) + + # End of audio stream + await client_pipeline.send_bytes(bytes([handler_id])) + + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "stt-end" + + # run end + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] == snapshot + + # Verify capture + events = [] + async with asyncio.timeout(1): + while True: + msg = await client_capture.receive_json() + assert msg["type"] == "event" + event_data = msg["event"] + events.append(event_data) + if event_data["type"] == "end": + break + + assert len(events) == len(audio_chunks) + 1 + + # Verify audio chunks + for i, audio_chunk in enumerate(audio_chunks): + assert events[i]["type"] == "audio" 
+ assert events[i]["rate"] == 16000 + assert events[i]["width"] == 2 + assert events[i]["channels"] == 1 + + # Audio is base64 encoded + assert events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii") + + # Last event is the end + assert events[-1]["type"] == "end" + + +async def test_device_capture_override( + hass: HomeAssistant, + init_components, + hass_ws_client: WebSocketGenerator, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, +) -> None: + """Test overriding an existing audio capture from a satellite device.""" + entry = MockConfigEntry() + entry.add_to_hass(hass) + satellite_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "satellite-1234")}, + ) + + audio_chunks = [b"chunk1", b"chunk2", b"chunk3"] + + # Start first capture + client_capture_1 = await hass_ws_client(hass) + await client_capture_1.send_json_auto_id( + { + "type": "assist_pipeline/device/capture", + "timeout": 30, + "device_id": satellite_device.id, + } + ) + + # result + msg = await client_capture_1.receive_json() + assert msg["success"] + + # Run pipeline + client_pipeline = await hass_ws_client(hass) + await client_pipeline.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "stt", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + "device_id": satellite_device.id, + } + ) + + # result + msg = await client_pipeline.receive_json() + assert msg["success"] + + # run start + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + + # stt + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == snapshot + + # Send first audio chunk + await 
client_pipeline.send_bytes(bytes([handler_id]) + audio_chunks[0]) + + # Verify first capture + msg = await client_capture_1.receive_json() + assert msg["type"] == "event" + assert msg["event"] == snapshot + assert msg["event"]["audio"] == base64.b64encode(audio_chunks[0]).decode("ascii") + + # Start a new capture + client_capture_2 = await hass_ws_client(hass) + await client_capture_2.send_json_auto_id( + { + "type": "assist_pipeline/device/capture", + "timeout": 30, + "device_id": satellite_device.id, + } + ) + + # result (capture 2) + msg = await client_capture_2.receive_json() + assert msg["success"] + + # Send remaining audio chunks + for audio_chunk in audio_chunks[1:]: + await client_pipeline.send_bytes(bytes([handler_id]) + audio_chunk) + + # End of audio stream + await client_pipeline.send_bytes(bytes([handler_id])) + + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "stt-end" + assert msg["event"]["data"] == snapshot + + # run end + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] == snapshot + + # Verify that first capture ended with no more audio + msg = await client_capture_1.receive_json() + assert msg["type"] == "event" + assert msg["event"] == snapshot + assert msg["event"]["type"] == "end" + + # Verify that the second capture got the remaining audio + events = [] + async with asyncio.timeout(1): + while True: + msg = await client_capture_2.receive_json() + assert msg["type"] == "event" + event_data = msg["event"] + events.append(event_data) + if event_data["type"] == "end": + break + + # -1 since first audio chunk went to the first capture + assert len(events) == len(audio_chunks) + + # Verify all but first audio chunk + for i, audio_chunk in enumerate(audio_chunks[1:]): + assert events[i]["type"] == "audio" + assert events[i]["rate"] == 16000 + assert events[i]["width"] == 2 + assert events[i]["channels"] == 1 + + # Audio is base64 encoded + assert 
events[i]["audio"] == base64.b64encode(audio_chunk).decode("ascii") + + # Last event is the end + assert events[-1]["type"] == "end" + + +async def test_device_capture_queue_full( + hass: HomeAssistant, + init_components, + hass_ws_client: WebSocketGenerator, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, +) -> None: + """Test audio capture from a satellite device when the recording queue fills up.""" + entry = MockConfigEntry() + entry.add_to_hass(hass) + satellite_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "satellite-1234")}, + ) + + class FakeQueue(asyncio.Queue): + """Queue that reports full for anything but None.""" + + def put_nowait(self, item): + if item is not None: + raise asyncio.QueueFull() + + super().put_nowait(item) + + with patch( + "homeassistant.components.assist_pipeline.websocket_api.DeviceAudioQueue" + ) as mock: + mock.return_value = DeviceAudioQueue(queue=FakeQueue()) + + # Start capture + client_capture = await hass_ws_client(hass) + await client_capture.send_json_auto_id( + { + "type": "assist_pipeline/device/capture", + "timeout": 30, + "device_id": satellite_device.id, + } + ) + + # result + msg = await client_capture.receive_json() + assert msg["success"] + + # Run pipeline + client_pipeline = await hass_ws_client(hass) + await client_pipeline.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "stt", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + "device_id": satellite_device.id, + } + ) + + # result + msg = await client_pipeline.receive_json() + assert msg["success"] + + # run start + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + + # stt + msg = await 
client_pipeline.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == snapshot + + # Single sample will "overflow" the queue + await client_pipeline.send_bytes(bytes([handler_id, 0, 0])) + + # End of audio stream + await client_pipeline.send_bytes(bytes([handler_id])) + + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "stt-end" + assert msg["event"]["data"] == snapshot + + msg = await client_pipeline.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] == snapshot + + # Queue should have been overflowed + async with asyncio.timeout(1): + msg = await client_capture.receive_json() + assert msg["type"] == "event" + assert msg["event"] == snapshot + assert msg["event"]["type"] == "end" + assert msg["event"]["overflow"]