From c42d0feec1d8d6bf109b7f8cbf7023451f48744e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 23 Jun 2023 22:29:56 -0400 Subject: [PATCH] Allow passing in device_id to pipeline run WS API (#95139) --- homeassistant/components/assist_pipeline/pipeline.py | 2 ++ .../components/assist_pipeline/websocket_api.py | 2 ++ .../assist_pipeline/snapshots/test_init.ambr | 6 ++++++ .../assist_pipeline/snapshots/test_websocket.ambr | 10 ++++++++++ tests/components/assist_pipeline/test_websocket.py | 3 ++- 5 files changed, 22 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index d08e1fc3e50..4a811b25f1f 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -512,6 +512,8 @@ class PipelineRun: "engine": self.intent_agent, "language": self.pipeline.conversation_language, "intent_input": intent_input, + "conversation_id": conversation_id, + "device_id": device_id, }, ) ) diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index bd2ec53db40..ea3aacf43a4 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -56,6 +56,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: vol.Optional("input"): dict, vol.Optional("pipeline"): str, vol.Optional("conversation_id"): vol.Any(str, None), + vol.Optional("device_id"): vol.Any(str, None), vol.Optional("timeout"): vol.Any(float, int), }, ), @@ -105,6 +106,7 @@ async def websocket_run( # Arguments to PipelineInput input_args: dict[str, Any] = { "conversation_id": msg.get("conversation_id"), + "device_id": msg.get("device_id"), } if start_stage == PipelineStage.STT: diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 619c59606ed..d8858cec4b6 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -32,6 +32,8 @@ }), dict({ 'data': dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'test transcript', 'language': 'en', @@ -119,6 +121,8 @@ }), dict({ 'data': dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'test transcript', 'language': 'en-US', @@ -206,6 +210,8 @@ }), dict({ 'data': dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'test transcript', 'language': 'en-US', diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index a2e5ac72b07..12a4d766f06 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -31,6 +31,8 @@ # --- # name: test_audio_pipeline.3 dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'test transcript', 'language': 'en', @@ -107,6 +109,8 @@ # --- # name: test_audio_pipeline_debug.3 dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'test transcript', 'language': 'en', @@ -163,6 +167,8 @@ # --- # name: test_intent_failed.1 dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', 'language': 'en', @@ -180,6 +186,8 @@ # --- # name: test_intent_timeout.1 dict({ + 'conversation_id': None, + 'device_id': None, 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', 'language': 'en', @@ -249,6 +257,8 @@ # --- # name: test_text_only_pipeline.1 dict({ + 'conversation_id': 'mock-conversation-id', + 'device_id': 'mock-device-id', 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', 'language': 'en', diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 23044073368..4ebf0a1fb98 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -28,6 +28,8 @@ async def test_text_only_pipeline( "start_stage": "intent", "end_stage": "intent", "input": {"text": "Are the lights on?"}, + "conversation_id": "mock-conversation-id", + "device_id": "mock-device-id", } ) @@ -954,7 +956,6 @@ async def test_list_pipelines( ) -> None: """Test we can list pipelines.""" client = await hass_ws_client(hass) - hass.data[DOMAIN] await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) msg = await client.receive_json()