diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index cc7ecc1c426..9a32821e3a0 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -9,6 +9,7 @@ import voluptuous as vol from homeassistant.components import stt from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import chat_session from homeassistant.helpers.typing import ConfigType from .const import ( @@ -114,24 +115,25 @@ async def async_pipeline_from_audio_stream( Raises PipelineNotFound if no pipeline is found. """ - pipeline_input = PipelineInput( - conversation_id=conversation_id, - device_id=device_id, - stt_metadata=stt_metadata, - stt_stream=stt_stream, - wake_word_phrase=wake_word_phrase, - conversation_extra_system_prompt=conversation_extra_system_prompt, - run=PipelineRun( - hass, - context=context, - pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id), - start_stage=start_stage, - end_stage=end_stage, - event_callback=event_callback, - tts_audio_output=tts_audio_output, - wake_word_settings=wake_word_settings, - audio_settings=audio_settings or AudioSettings(), - ), - ) - await pipeline_input.validate() - await pipeline_input.execute() + with chat_session.async_get_chat_session(hass, conversation_id) as session: + pipeline_input = PipelineInput( + conversation_id=session.conversation_id, + device_id=device_id, + stt_metadata=stt_metadata, + stt_stream=stt_stream, + wake_word_phrase=wake_word_phrase, + conversation_extra_system_prompt=conversation_extra_system_prompt, + run=PipelineRun( + hass, + context=context, + pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id), + start_stage=start_stage, + end_stage=end_stage, + event_callback=event_callback, + tts_audio_output=tts_audio_output, + wake_word_settings=wake_word_settings, + audio_settings=audio_settings or AudioSettings(), + ), + ) + await pipeline_input.validate() + await pipeline_input.execute() diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index c5f9098623a..262f4c59687 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -624,7 +624,7 @@ class PipelineRun: return pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event) - def start(self, device_id: str | None) -> None: + def start(self, conversation_id: str, device_id: str | None) -> None: """Emit run start event.""" self._device_id = device_id self._start_debug_recording_thread() @@ -632,6 +632,7 @@ class PipelineRun: data = { "pipeline": self.pipeline.id, "language": self.language, + "conversation_id": conversation_id, } if self.runner_data is not None: data["runner_data"] = self.runner_data @@ -1015,7 +1016,7 @@ class PipelineRun: async def recognize_intent( self, intent_input: str, - conversation_id: str | None, + conversation_id: str, device_id: str | None, conversation_extra_system_prompt: str | None, ) -> str: @@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc( wav_writer.close() -@dataclass +@dataclass(kw_only=True) class PipelineInput: """Input to a pipeline run.""" run: PipelineRun + conversation_id: str + """Identifier for the conversation.""" + stt_metadata: stt.SpeechMetadata | None = None """Metadata of stt input audio. Required when start_stage = stt.""" @@ -1430,9 +1434,6 @@ class PipelineInput: tts_input: str | None = None """Input for text-to-speech. Required when start_stage = tts.""" - conversation_id: str | None = None - """Identifier for the conversation.""" - conversation_extra_system_prompt: str | None = None """Extra prompt information for the conversation agent.""" @@ -1441,7 +1442,7 @@ class PipelineInput: async def execute(self) -> None: """Run pipeline.""" - self.run.start(device_id=self.device_id) + self.run.start(conversation_id=self.conversation_id, device_id=self.device_id) current_stage: PipelineStage | None = self.run.start_stage stt_audio_buffer: list[EnhancedAudioChunk] = [] stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 69f917fcf83..d2d54a1b7c3 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -14,7 +14,11 @@ import voluptuous as vol from homeassistant.components import conversation, stt, tts, websocket_api from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import config_validation as cv, entity_registry as er +from homeassistant.helpers import ( + chat_session, + config_validation as cv, + entity_registry as er, +) from homeassistant.util import language as language_util from .const import ( @@ -145,7 +149,6 @@ async def websocket_run( # Arguments to PipelineInput input_args: dict[str, Any] = { - "conversation_id": msg.get("conversation_id"), "device_id": msg.get("device_id"), } @@ -233,38 +236,42 @@ async def websocket_run( audio_settings=audio_settings or AudioSettings(), ) - pipeline_input = PipelineInput(**input_args) + with chat_session.async_get_chat_session( + hass, msg.get("conversation_id") + ) as session: + input_args["conversation_id"] = session.conversation_id + pipeline_input = PipelineInput(**input_args) - try: - await pipeline_input.validate() - except PipelineError as error: - # Report more specific error when possible - connection.send_error(msg["id"], error.code, error.message) - return + try: + await pipeline_input.validate() + except PipelineError as error: + # Report more specific error when possible + connection.send_error(msg["id"], error.code, error.message) + return - # Confirm subscription - connection.send_result(msg["id"]) + # Confirm subscription + connection.send_result(msg["id"]) - run_task = hass.async_create_task(pipeline_input.execute()) + run_task = hass.async_create_task(pipeline_input.execute()) - # Cancel pipeline if user unsubscribes - connection.subscriptions[msg["id"]] = run_task.cancel + # Cancel pipeline if user unsubscribes + connection.subscriptions[msg["id"]] = run_task.cancel - try: - # Task contains a timeout - async with asyncio.timeout(timeout): - await run_task - except TimeoutError: - pipeline_input.run.process_event( - PipelineEvent( - PipelineEventType.ERROR, - {"code": "timeout", "message": "Timeout running pipeline"}, + try: + # Task contains a timeout + async with asyncio.timeout(timeout): + await run_task + except TimeoutError: + pipeline_input.run.process_event( + PipelineEvent( + PipelineEventType.ERROR, + {"code": "timeout", "message": "Timeout running pipeline"}, + ) ) - ) - finally: - if unregister_handler is not None: - # Unregister binary handler - unregister_handler() + finally: + if unregister_handler is not None: + # Unregister binary handler + unregister_handler() @callback diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 526e1bff151..11e6bc2339a 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -3,6 +3,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -32,7 +33,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -94,6 +95,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -123,7 +125,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -185,6 +187,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -214,7 +217,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -276,6 +279,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -329,7 +333,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -391,6 +395,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , }), @@ -427,6 +432,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-conversation-id', 'language': 'en', 'pipeline': , }), @@ -434,7 +440,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-conversation-id', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test input', @@ -478,6 +484,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-conversation-id', 'language': 'en', 'pipeline': , }), @@ -485,7 +492,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-conversation-id', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test input', @@ -529,6 +536,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-conversation-id', 'language': 'en', 'pipeline': , }), @@ -536,7 +544,7 @@ }), dict({ 'data': dict({ - 'conversation_id': None, + 'conversation_id': 'mock-conversation-id', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test input', @@ -580,6 +588,7 @@ list([ dict({ 'data': dict({ + 'conversation_id': 'mock-conversation-id', 'language': 'en', 'pipeline': , }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 5f06172404b..f677fa6d8cf 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -1,6 +1,7 @@ # serializer version: 1 # name: test_audio_pipeline dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -31,7 +32,7 @@ # --- # name: test_audio_pipeline.3 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -84,6 +85,7 @@ # --- # name: test_audio_pipeline_debug dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -114,7 +116,7 @@ # --- # name: test_audio_pipeline_debug.3 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -179,6 +181,7 @@ # --- # name: test_audio_pipeline_with_enhancements dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -209,7 +212,7 @@ # --- # name: test_audio_pipeline_with_enhancements.3 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -262,6 +265,7 @@ # --- # name: test_audio_pipeline_with_wake_word_no_timeout dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -314,7 +318,7 @@ # --- # name: test_audio_pipeline_with_wake_word_no_timeout.5 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'test transcript', @@ -367,6 +371,7 @@ # --- # name: test_audio_pipeline_with_wake_word_timeout dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -399,6 +404,7 @@ # --- # name: test_device_capture dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -425,6 +431,7 @@ # --- # name: test_device_capture_override dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -473,6 +480,7 @@ # --- # name: test_device_capture_queue_full dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -512,6 +520,7 @@ # --- # name: test_intent_failed dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -522,7 +531,7 @@ # --- # name: test_intent_failed.1 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'Are the lights on?', @@ -535,6 +544,7 @@ # --- # name: test_intent_timeout dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -545,7 +555,7 @@ # --- # name: test_intent_timeout.1 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'Are the lights on?', @@ -564,6 +574,7 @@ # --- # name: test_pipeline_empty_tts_output dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -574,7 +585,7 @@ # --- # name: test_pipeline_empty_tts_output.1 dict({ - 'conversation_id': None, + 'conversation_id': 'mock-ulid', 'device_id': None, 'engine': 'conversation.home_assistant', 'intent_input': 'never mind', @@ -611,6 +622,7 @@ # --- # name: test_stt_cooldown_different_ids dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -621,6 +633,7 @@ # --- # name: test_stt_cooldown_different_ids.1 dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -631,6 +644,7 @@ # --- # name: test_stt_cooldown_same_id dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -641,6 +655,7 @@ # --- # name: test_stt_cooldown_same_id.1 dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -651,6 +666,7 @@ # --- # name: test_stt_stream_failed dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -677,6 +693,7 @@ # --- # name: test_text_only_pipeline[extra_msg0] dict({ + 'conversation_id': 'mock-conversation-id', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -723,6 +740,7 @@ # --- # name: test_text_only_pipeline[extra_msg1] dict({ + 'conversation_id': 'mock-conversation-id', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -775,6 +793,7 @@ # --- # name: test_tts_failed dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -796,6 +815,7 @@ # --- # name: test_wake_word_cooldown_different_entities dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -806,6 +826,7 @@ # --- # name: test_wake_word_cooldown_different_entities.1 dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -857,6 +878,7 @@ # --- # name: test_wake_word_cooldown_different_ids dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -867,6 +889,7 @@ # --- # name: test_wake_word_cooldown_different_ids.1 dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -921,6 +944,7 @@ # --- # name: test_wake_word_cooldown_same_id dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ @@ -931,6 +955,7 @@ # --- # name: test_wake_word_cooldown_same_id.1 dict({ + 'conversation_id': 'mock-ulid', 'language': 'en', 'pipeline': , 'runner_data': dict({ diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index a2cb9ef382a..1651950c173 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -1,11 +1,12 @@ """Test Voice Assistant init.""" import asyncio +from collections.abc import Generator from dataclasses import asdict import itertools as it from pathlib import Path import tempfile -from unittest.mock import ANY, patch +from unittest.mock import ANY, Mock, patch import wave import hass_nabucasa @@ -41,6 +42,14 @@ from .conftest import ( from tests.typing import ClientSessionGenerator, WebSocketGenerator +@pytest.fixture(autouse=True) +def mock_ulid() -> Generator[Mock]: + """Mock the ulid of chat sessions.""" + with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: + mock_ulid_now.return_value = "mock-ulid" + yield mock_ulid_now + + def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: """Process events to remove dynamic values.""" processed = [] @@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted( pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline_input = assist_pipeline.pipeline.PipelineInput( - conversation_id=None, + conversation_id="mock-conversation-id", device_id=None, stt_metadata=stt.SpeechMetadata( language="", @@ -771,7 +780,7 @@ async def test_tts_audio_output( pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", - conversation_id=None, + conversation_id="mock-conversation-id", device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, @@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format( pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", - conversation_id=None, + conversation_id="mock-conversation-id", device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, @@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format( pipeline_input = assist_pipeline.pipeline.PipelineInput( tts_input="This is a test.", - conversation_id=None, + conversation_id="mock-conversation-id", device_id=None, run=assist_pipeline.pipeline.PipelineRun( hass, @@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test trigger sentence", + conversation_id="mock-conversation-id", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1059,6 +1069,7 @@ async def test_prefer_local_intents( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="I'd like to order a stout please", + conversation_id="mock-conversation-id", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", + conversation_id="mock-conversation-id", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", + conversation_id="mock-conversation-id", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), @@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language( pipeline_input = assist_pipeline.pipeline.PipelineInput( intent_input="test input", + conversation_id="mock-conversation-id", run=assist_pipeline.pipeline.PipelineRun( hass, context=Context(), diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index c1caf6f86a4..2cd56f094dd 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -2,8 +2,9 @@ import asyncio import base64 +from collections.abc import Generator from typing import Any -from unittest.mock import ANY, patch +from unittest.mock import ANY, Mock, patch import pytest from syrupy.assertion import SnapshotAssertion @@ -35,6 +36,14 @@ from tests.common import MockConfigEntry from tests.typing import WebSocketGenerator +@pytest.fixture(autouse=True) +def mock_ulid() -> Generator[Mock]: + """Mock the ulid of chat sessions.""" + with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: + mock_ulid_now.return_value = "mock-ulid" + yield mock_ulid_now + + @pytest.mark.parametrize( "extra_msg", [