From 0ecd23baeebb97fde241ba42bf4ec09dd5c71802 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 17 Apr 2023 17:48:02 +0200 Subject: [PATCH] Add WS API for debugging previous assist_pipeline runs (#91541) * Add WS API for debugging previous assist_pipeline runs * Improve typing --- .../components/assist_pipeline/pipeline.py | 56 ++- .../assist_pipeline/websocket_api.py | 78 +++- homeassistant/components/trace/__init__.py | 2 +- .../utils.py => util/limited_size_dict.py} | 0 .../snapshots/test_websocket.ambr | 73 ++++ .../assist_pipeline/test_pipeline.py | 7 +- .../assist_pipeline/test_websocket.py | 380 +++++++++++++++++- 7 files changed, 564 insertions(+), 32 deletions(-) rename homeassistant/{components/trace/utils.py => util/limited_size_dict.py} (100%) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 07581d10b87..9f5f28ec727 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -24,6 +24,7 @@ from homeassistant.helpers.collection import ( ) from homeassistant.helpers.storage import Store from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.util.limited_size_dict import LimitedSizeDict from .const import DOMAIN from .error import ( @@ -46,6 +47,8 @@ STORAGE_FIELDS = { vol.Required("tts_engine"): str, } +STORED_PIPELINE_RUNS = 10 + SAVE_DELAY = 10 @@ -53,14 +56,14 @@ async def async_get_pipeline( hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None ) -> Pipeline | None: """Get a pipeline by id or create one for a language.""" - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] if pipeline_id is not None: - return pipeline_store.data.get(pipeline_id) + return pipeline_data.pipeline_store.data.get(pipeline_id) # Construct a pipeline for the required/configured language language = language or hass.config.language - return await pipeline_store.async_create_item( + return await pipeline_data.pipeline_store.async_create_item( { "name": language, "language": language, @@ -171,6 +174,8 @@ class PipelineRun: tts_engine: str | None = None tts_options: dict | None = None + id: str = field(default_factory=ulid_util.ulid) + def __post_init__(self) -> None: """Set language for pipeline.""" self.language = self.pipeline.language or self.hass.config.language @@ -181,6 +186,23 @@ class PipelineRun: ): raise InvalidPipelineStagesError(self.start_stage, self.end_stage) + pipeline_data: PipelineData = self.hass.data[DOMAIN] + if self.pipeline.id not in pipeline_data.pipeline_runs: + pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict( + size_limit=STORED_PIPELINE_RUNS + ) + pipeline_data.pipeline_runs[self.pipeline.id][self.id] = [] + + @callback + def process_event(self, event: PipelineEvent) -> None: + """Log an event and call listener.""" + self.event_callback(event) + pipeline_data: PipelineData = self.hass.data[DOMAIN] + if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]: + # This run has been evicted from the logged pipeline runs already + return + pipeline_data.pipeline_runs[self.pipeline.id][self.id].append(event) + def start(self) -> None: """Emit run start event.""" data = { @@ -190,11 +212,11 @@ class PipelineRun: if self.runner_data is not None: data["runner_data"] = self.runner_data - self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data)) + self.process_event(PipelineEvent(PipelineEventType.RUN_START, data)) def end(self) -> None: """Emit run end event.""" - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.RUN_END, ) @@ -233,7 +255,7 @@ class PipelineRun: engine = self.stt_provider.name - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.STT_START, { @@ -268,7 +290,7 @@ class PipelineRun: code="stt-no-text-recognized", message="No text recognized" ) - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.STT_END, { @@ -306,7 +328,7 @@ class PipelineRun: if self.intent_agent is None: raise RuntimeError("Recognize intent was not prepared") - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.INTENT_START, { @@ -334,7 +356,7 @@ class PipelineRun: _LOGGER.debug("conversation result %s", conversation_result) - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.INTENT_END, {"intent_output": conversation_result.as_dict()}, @@ -379,7 +401,7 @@ class PipelineRun: if self.tts_engine is None: raise RuntimeError("Text to speech was not prepared") - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.TTS_START, { @@ -412,7 +434,7 @@ class PipelineRun: _LOGGER.debug("TTS result %s", tts_media) - self.event_callback( + self.process_event( PipelineEvent( PipelineEventType.TTS_END, { @@ -480,7 +502,7 @@ class PipelineInput: await self.run.text_to_speech(tts_input) except PipelineError as err: - self.run.event_callback( + self.run.process_event( PipelineEvent( PipelineEventType.ERROR, {"code": err.code, "message": err.message}, @@ -691,6 +713,14 @@ class PipelineStorageCollectionWebsocket( connection.send_result(msg["id"]) +@dataclass +class PipelineData: + """Store and debug data stored in hass.data.""" + + pipeline_runs: dict[str, LimitedSizeDict[str, list[PipelineEvent]]] + pipeline_store: PipelineStorageCollection + + async def async_setup_pipeline_store(hass: HomeAssistant) -> None: """Set up the pipeline storage collection.""" pipeline_store = PipelineStorageCollection( @@ -700,4 +730,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> None: PipelineStorageCollectionWebsocket( pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) - hass.data[DOMAIN] = pipeline_store + hass.data[DOMAIN] = PipelineData({}, pipeline_store) diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 93b7c47681f..c99e92a1993 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -12,7 +12,9 @@ from homeassistant.components import stt, websocket_api from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv +from .const import DOMAIN from .pipeline import ( + PipelineData, PipelineError, PipelineEvent, PipelineEventType, @@ -69,6 +71,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: ), ), ) + websocket_api.async_register_command(hass, websocket_list_runs) + websocket_api.async_register_command(hass, websocket_get_run) @websocket_api.async_response @@ -193,14 +197,82 @@ async def websocket_run( async with async_timeout.timeout(timeout): await run_task except asyncio.TimeoutError: - connection.send_event( - msg["id"], + pipeline_input.run.process_event( PipelineEvent( PipelineEventType.ERROR, {"code": "timeout", "message": "Timeout running pipeline"}, - ), + ) ) finally: if unregister_handler is not None: # Unregister binary handler unregister_handler() + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_pipeline/pipeline_debug/list", + vol.Required("pipeline_id"): str, + } +) +def websocket_list_runs( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """List pipeline runs for which debug data is available.""" + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = msg["pipeline_id"] + + if pipeline_id not in pipeline_data.pipeline_runs: + connection.send_result(msg["id"], {"pipeline_runs": []}) + return + + pipeline_runs = pipeline_data.pipeline_runs[pipeline_id] + + connection.send_result(msg["id"], {"pipeline_runs": list(pipeline_runs)}) + + +@callback +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_pipeline/pipeline_debug/get", + vol.Required("pipeline_id"): str, + vol.Required("pipeline_run_id"): str, + } +) +def websocket_get_run( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Get debug data for a pipeline run.""" + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = msg["pipeline_id"] + pipeline_run_id = msg["pipeline_run_id"] + + if pipeline_id not in pipeline_data.pipeline_runs: + connection.send_error( + msg["id"], + websocket_api.const.ERR_NOT_FOUND, + f"pipeline_id {pipeline_id} not found", + ) + return + + pipeline_runs = pipeline_data.pipeline_runs[pipeline_id] + + if pipeline_run_id not in pipeline_runs: + connection.send_error( + msg["id"], + websocket_api.const.ERR_NOT_FOUND, + f"pipeline_run_id {pipeline_run_id} not found", + ) + return + + connection.send_result( + msg["id"], + {"events": pipeline_runs[pipeline_run_id]}, + ) diff --git a/homeassistant/components/trace/__init__.py b/homeassistant/components/trace/__init__.py index 3d7510b57b2..5d0b188f724 100644 --- a/homeassistant/components/trace/__init__.py +++ b/homeassistant/components/trace/__init__.py @@ -14,6 +14,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.storage import Store from homeassistant.helpers.typing import ConfigType +from homeassistant.util.limited_size_dict import LimitedSizeDict from . import websocket_api from .const import ( @@ -24,7 +25,6 @@ from .const import ( DEFAULT_STORED_TRACES, ) from .models import ActionTrace, BaseTrace, RestoredTrace -from .utils import LimitedSizeDict _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/trace/utils.py b/homeassistant/util/limited_size_dict.py similarity index 100% rename from homeassistant/components/trace/utils.py rename to homeassistant/util/limited_size_dict.py diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index ad7d7f570ce..fcda918dd0d 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -72,6 +72,79 @@ }), }) # --- +# name: test_audio_pipeline_debug + dict({ + 'language': 'en-US', + 'pipeline': 'en-US', + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 30, + }), + }) +# --- +# name: test_audio_pipeline_debug.1 + dict({ + 'engine': 'test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'language': 'en-US', + 'sample_rate': 16000, + }), + }) +# --- +# name: test_audio_pipeline_debug.2 + dict({ + 'stt_output': dict({ + 'text': 'test transcript', + }), + }) +# --- +# name: test_audio_pipeline_debug.3 + dict({ + 'engine': 'homeassistant', + 'intent_input': 'test transcript', + }) +# --- +# name: test_audio_pipeline_debug.4 + dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'no_intent_match', + }), + 'language': 'en-US', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': "Sorry, I couldn't understand that", + }), + }), + }), + }), + }) +# --- +# name: test_audio_pipeline_debug.5 + dict({ + 'engine': 'test', + 'tts_input': "Sorry, I couldn't understand that", + }) +# --- +# name: test_audio_pipeline_debug.6 + dict({ + 'tts_output': dict({ + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US", + 'mime_type': 'audio/mpeg', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + }), + }) +# --- # name: test_intent_failed dict({ 'language': 'en-US', diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 1898e3d2237..6eee84d9e9b 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -5,6 +5,7 @@ from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, STORAGE_VERSION, + PipelineData, PipelineStorageCollection, ) from homeassistant.core import HomeAssistant @@ -42,7 +43,8 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: ] pipeline_ids = [] - store1: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + store1 = pipeline_data.pipeline_store for pipeline in pipelines: pipeline_ids.append((await store1.async_create_item(pipeline)).id) assert len(store1.data) == 3 @@ -103,6 +105,7 @@ async def test_loading_datasets_from_storage( assert await async_setup_component(hass, "assist_pipeline", {}) - store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store assert len(store.data) == 3 assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 2b3f66c4159..c34dc187490 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -5,10 +5,7 @@ from unittest.mock import ANY, MagicMock, patch from syrupy.assertion import SnapshotAssertion from homeassistant.components.assist_pipeline.const import DOMAIN -from homeassistant.components.assist_pipeline.pipeline import ( - Pipeline, - PipelineStorageCollection, -) +from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData from homeassistant.core import HomeAssistant from tests.typing import WebSocketGenerator @@ -21,6 +18,7 @@ async def test_text_only_pipeline( snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with text input (no STT/TTS).""" + events = [] client = await hass_ws_client(hass) await client.send_json_auto_id( @@ -40,20 +38,39 @@ async def test_text_only_pipeline( msg = await client.receive_json() assert msg["event"]["type"] == "run-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # intent msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) msg = await client.receive_json() assert msg["event"]["type"] == "intent-end" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # run end msg = await client.receive_json() assert msg["event"]["type"] == "run-end" assert msg["event"]["data"] is None + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_audio_pipeline( @@ -63,6 +80,7 @@ async def test_audio_pipeline( snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with audio input/output.""" + events = [] client = await hass_ws_client(hass) await client.send_json_auto_id( @@ -84,11 +102,13 @@ async def test_audio_pipeline( msg = await client.receive_json() assert msg["event"]["type"] == "run-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # stt msg = await client.receive_json() assert msg["event"]["type"] == "stt-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # End of audio stream (handler id + empty payload) await client.send_bytes(bytes([1])) @@ -96,29 +116,50 @@ async def test_audio_pipeline( msg = await client.receive_json() assert msg["event"]["type"] == "stt-end" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # intent msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) msg = await client.receive_json() assert msg["event"]["type"] == "intent-end" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # text to speech msg = await client.receive_json() assert msg["event"]["type"] == "tts-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) msg = await client.receive_json() assert msg["event"]["type"] == "tts-end" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # run end msg = await client.receive_json() assert msg["event"]["type"] == "run-end" assert msg["event"]["data"] is None + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_intent_timeout( @@ -128,6 +169,7 @@ async def test_intent_timeout( snapshot: SnapshotAssertion, ) -> None: """Test partial pipeline run with conversation agent timeout.""" + events = [] client = await hass_ws_client(hass) async def sleepy_converse(*args, **kwargs): @@ -155,16 +197,34 @@ async def test_intent_timeout( msg = await client.receive_json() assert msg["event"]["type"] == "run-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # intent msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # timeout error msg = await client.receive_json() assert msg["event"]["type"] == "error" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_text_pipeline_timeout( @@ -174,6 +234,7 @@ async def test_text_pipeline_timeout( snapshot: SnapshotAssertion, ) -> None: """Test text-only pipeline run with immediate timeout.""" + events = [] client = await hass_ws_client(hass) async def sleepy_run(*args, **kwargs): @@ -201,6 +262,22 @@ async def test_text_pipeline_timeout( msg = await client.receive_json() assert msg["event"]["type"] == "error" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_intent_failed( @@ -210,6 +287,7 @@ async def test_intent_failed( snapshot: SnapshotAssertion, ) -> None: """Test text-only pipeline run with conversation agent error.""" + events = [] client = await hass_ws_client(hass) with patch( @@ -233,16 +311,34 @@ async def test_intent_failed( msg = await client.receive_json() assert msg["event"]["type"] == "run-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # intent start msg = await client.receive_json() assert msg["event"]["type"] == "intent-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # intent error msg = await client.receive_json() assert msg["event"]["type"] == "error" assert msg["event"]["data"]["code"] == "intent-failed" + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_audio_pipeline_timeout( @@ -252,6 +348,7 @@ async def test_audio_pipeline_timeout( snapshot: SnapshotAssertion, ) -> None: """Test audio pipeline run with immediate timeout.""" + events = [] client = await hass_ws_client(hass) async def sleepy_run(*args, **kwargs): @@ -281,6 +378,22 @@ async def test_audio_pipeline_timeout( msg = await client.receive_json() assert msg["event"]["type"] == "error" assert msg["event"]["data"]["code"] == "timeout" + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_stt_provider_missing( @@ -320,12 +433,13 @@ async def test_stt_stream_failed( snapshot: SnapshotAssertion, ) -> None: """Test events from a pipeline run with a non-existent STT provider.""" + events = [] + client = await hass_ws_client(hass) + with patch( "tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream", new=MagicMock(side_effect=RuntimeError), ): - client = await hass_ws_client(hass) - await client.send_json_auto_id( { "type": "assist_pipeline/run", @@ -345,11 +459,13 @@ async def test_stt_stream_failed( msg = await client.receive_json() assert msg["event"]["type"] == "run-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # stt msg = await client.receive_json() assert msg["event"]["type"] == "stt-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # End of audio stream (handler id + empty payload) await client.send_bytes(b"1") @@ -358,6 +474,22 @@ async def test_stt_stream_failed( msg = await client.receive_json() assert msg["event"]["type"] == "error" assert msg["event"]["data"]["code"] == "stt-stream-failed" + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_tts_failed( @@ -367,15 +499,15 @@ async def test_tts_failed( snapshot: SnapshotAssertion, ) -> None: """Test pipeline run with text to speech error.""" + events = [] client = await hass_ws_client(hass) with patch( "homeassistant.components.media_source.async_resolve_media", new=MagicMock(return_value=RuntimeError), ): - await client.send_json( + await client.send_json_auto_id( { - "id": 5, "type": "assist_pipeline/run", "start_stage": "tts", "end_stage": "tts", @@ -391,16 +523,34 @@ async def test_tts_failed( msg = await client.receive_json() assert msg["event"]["type"] == "run-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # tts start msg = await client.receive_json() assert msg["event"]["type"] == "tts-start" assert msg["event"]["data"] == snapshot + events.append(msg["event"]) # tts error msg = await client.receive_json() assert msg["event"]["type"] == "error" assert msg["event"]["data"]["code"] == "tts-failed" + events.append(msg["event"]) + + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_id = list(pipeline_data.pipeline_runs)[0] + pipeline_run_id = list(pipeline_data.pipeline_runs[pipeline_id])[0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} async def test_invalid_stage_order( @@ -428,7 +578,8 @@ async def test_add_pipeline( ) -> None: """Test we can add a pipeline.""" client = await hass_ws_client(hass) - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store await client.send_json_auto_id( { @@ -468,7 +619,8 @@ async def test_delete_pipeline( ) -> None: """Test we can delete a pipeline.""" client = await hass_ws_client(hass) - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store await client.send_json_auto_id( { @@ -542,7 +694,8 @@ async def test_list_pipelines( ) -> None: """Test we can list pipelines.""" client = await hass_ws_client(hass) - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) msg = await client.receive_json() @@ -586,7 +739,8 @@ async def test_update_pipeline( ) -> None: """Test we can list pipelines.""" client = await hass_ws_client(hass) - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store await client.send_json_auto_id( { @@ -660,7 +814,8 @@ async def test_set_preferred_pipeline( ) -> None: """Test updating the preferred pipeline.""" client = await hass_ws_client(hass) - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store await client.send_json_auto_id( { @@ -715,3 +870,202 @@ async def test_set_preferred_pipeline_wrong_id( ) msg = await client.receive_json() assert msg["error"]["code"] == "not_found" + + +async def test_audio_pipeline_debug( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test debug listing events from a pipeline run with audio input/output.""" + events = [] + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 44100, + }, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # run start + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # stt + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # End of audio stream (handler id + empty payload) + await client.send_bytes(bytes([1])) + + msg = await client.receive_json() + assert msg["event"]["type"] == "stt-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # intent + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # text to speech + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-start" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + msg = await client.receive_json() + assert msg["event"]["type"] == "tts-end" + assert msg["event"]["data"] == snapshot + events.append(msg["event"]) + + # run end + msg = await client.receive_json() + assert msg["event"]["type"] == "run-end" + assert msg["event"]["data"] is None + events.append(msg["event"]) + + # Get the id of the pipeline + await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) + msg = await client.receive_json() + assert msg["success"] + assert len(msg["result"]["pipelines"]) == 1 + + pipeline_id = msg["result"]["pipelines"][0]["id"] + + # Get the id for the run + await client.send_json_auto_id( + {"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": pipeline_id} + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"pipeline_runs": [ANY]} + + pipeline_run_id = msg["result"]["pipeline_runs"][0] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": pipeline_run_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"events": events} + + +async def test_pipeline_debug_list_runs_wrong_pipeline( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, +) -> None: + """Test debug listing events from a pipeline.""" + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + {"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": "blah"} + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"pipeline_runs": []} + + +async def test_pipeline_debug_get_run_wrong_pipeline( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, +) -> None: + """Test debug listing events from a pipeline.""" + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": "blah", + "pipeline_run_id": "blah", + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": "pipeline_id blah not found", + } + + +async def test_pipeline_debug_get_run_wrong_pipeline_run( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, +) -> None: + """Test debug listing events from a pipeline.""" + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "intent", + "end_stage": "intent", + "input": {"text": "Are the lights on?"}, + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + # consume events + msg = await client.receive_json() + assert msg["event"]["type"] == "run-start" + + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-start" + + msg = await client.receive_json() + assert msg["event"]["type"] == "intent-end" + + msg = await client.receive_json() + assert msg["event"]["type"] == "run-end" + + # Get the id of the pipeline + await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) + msg = await client.receive_json() + assert msg["success"] + assert len(msg["result"]["pipelines"]) == 1 + pipeline_id = msg["result"]["pipelines"][0]["id"] + + # get debug data for the wrong run + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline_debug/get", + "pipeline_id": pipeline_id, + "pipeline_run_id": "blah", + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": "pipeline_run_id blah not found", + }