diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py
index 5f811ac955b..7d5f98e87f6 100644
--- a/homeassistant/components/assist_pipeline/pipeline.py
+++ b/homeassistant/components/assist_pipeline/pipeline.py
@@ -89,6 +89,8 @@ KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
 KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
     "pipeline_conversation_data"
 )
+# Number of response characters to receive before streaming the response
+STREAM_RESPONSE_CHARS = 60
 
 
 def validate_language(data: dict[str, Any]) -> Any:
@@ -552,7 +554,7 @@ class PipelineRun:
     event_callback: PipelineEventCallback
     language: str = None  # type: ignore[assignment]
     runner_data: Any | None = None
-    intent_agent: str | None = None
+    intent_agent: conversation.AgentInfo | None = None
     tts_audio_output: str | dict[str, Any] | None = None
     wake_word_settings: WakeWordSettings | None = None
     audio_settings: AudioSettings = field(default_factory=AudioSettings)
@@ -588,6 +590,9 @@ class PipelineRun:
     _intent_agent_only = False
     """If request should only be handled by agent, ignoring sentence triggers and local processing."""
 
+    _streamed_response_text = False
+    """If the conversation agent streamed response text to the TTS result."""
+
     def __post_init__(self) -> None:
         """Set language for pipeline."""
         self.language = self.pipeline.language or self.hass.config.language
@@ -649,6 +654,11 @@ class PipelineRun:
                 "token": self.tts_stream.token,
                 "url": self.tts_stream.url,
                 "mime_type": self.tts_stream.content_type,
+                "stream_response": (
+                    self.tts_stream.supports_streaming_input
+                    and self.intent_agent
+                    and self.intent_agent.supports_streaming
+                ),
             }
 
         self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
@@ -896,12 +906,12 @@ class PipelineRun:
     ) -> str:
         """Run speech-to-text portion of pipeline. Returns the spoken text."""
         # Create a background task to prepare the conversation agent
-        if self.end_stage >= PipelineStage.INTENT:
+        if self.end_stage >= PipelineStage.INTENT and self.intent_agent:
             self.hass.async_create_background_task(
                 conversation.async_prepare_agent(
-                    self.hass, self.intent_agent, self.language
+                    self.hass, self.intent_agent.id, self.language
                 ),
-                f"prepare conversation agent {self.intent_agent}",
+                f"prepare conversation agent {self.intent_agent.id}",
             )
 
         if isinstance(self.stt_provider, stt.Provider):
@@ -1042,7 +1052,7 @@ class PipelineRun:
                 message=f"Intent recognition engine {engine} is not found",
             )
 
-        self.intent_agent = agent_info.id
+        self.intent_agent = agent_info
 
     async def recognize_intent(
         self,
@@ -1075,7 +1085,7 @@ class PipelineRun:
             PipelineEvent(
                 PipelineEventType.INTENT_START,
                 {
-                    "engine": self.intent_agent,
+                    "engine": self.intent_agent.id,
                     "language": input_language,
                     "intent_input": intent_input,
                     "conversation_id": conversation_id,
@@ -1092,11 +1102,11 @@ class PipelineRun:
                 conversation_id=conversation_id,
                 device_id=device_id,
                 language=input_language,
-                agent_id=self.intent_agent,
+                agent_id=self.intent_agent.id,
                 extra_system_prompt=conversation_extra_system_prompt,
             )
 
-            agent_id = self.intent_agent
+            agent_id = self.intent_agent.id
             processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
             intent_response: intent.IntentResponse | None = None
             if not processed_locally and not self._intent_agent_only:
@@ -1118,7 +1128,7 @@ class PipelineRun:
                 # If the LLM has API access, we filter out some sentences that are
                 # interfering with LLM operation.
                 if (
-                    intent_agent_state := self.hass.states.get(self.intent_agent)
+                    intent_agent_state := self.hass.states.get(self.intent_agent.id)
                 ) and intent_agent_state.attributes.get(
                     ATTR_SUPPORTED_FEATURES, 0
                 ) & conversation.ConversationEntityFeature.CONTROL:
@@ -1140,6 +1150,13 @@ class PipelineRun:
                 agent_id = conversation.HOME_ASSISTANT_AGENT
                 processed_locally = True
 
+        if self.tts_stream and self.tts_stream.supports_streaming_input:
+            tts_input_stream: asyncio.Queue[str | None] | None = asyncio.Queue()
+        else:
+            tts_input_stream = None
+        chat_log_role = None
+        delta_character_count = 0
+
         @callback
         def chat_log_delta_listener(
             chat_log: conversation.ChatLog, delta: dict
@@ -1153,6 +1170,42 @@ class PipelineRun:
                     },
                 )
             )
+            if tts_input_stream is None:
+                return
+
+            nonlocal chat_log_role
+
+            if role := delta.get("role"):
+                chat_log_role = role
+
+            # We are only interested in assistant deltas with content
+            if chat_log_role != "assistant" or not (
+                content := delta.get("content")
+            ):
+                return
+
+            tts_input_stream.put_nowait(content)
+
+            if self._streamed_response_text:
+                return
+
+            nonlocal delta_character_count
+
+            delta_character_count += len(content)
+            if delta_character_count < STREAM_RESPONSE_CHARS:
+                return
+
+            # Streamed responses are not cached. We only start streaming text after
+            # we have received enough characters to indicate a long response.
+            self._streamed_response_text = True
+
+            async def tts_input_stream_generator() -> AsyncGenerator[str]:
+                """Yield TTS input stream."""
+                while (tts_input := await tts_input_stream.get()) is not None:
+                    yield tts_input
+
+            assert self.tts_stream is not None
+            self.tts_stream.async_set_message_stream(tts_input_stream_generator())
 
         with (
             chat_session.async_get_chat_session(
@@ -1196,6 +1249,8 @@ class PipelineRun:
                 speech = conversation_result.response.speech.get("plain", {}).get(
                     "speech", ""
                 )
+                if tts_input_stream and self._streamed_response_text:
+                    tts_input_stream.put_nowait(None)
 
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during intent recognition")
@@ -1273,7 +1328,8 @@ class PipelineRun:
                 )
             )
 
-        self.tts_stream.async_set_message(tts_input)
+        if not self._streamed_response_text:
+            self.tts_stream.async_set_message(tts_input)
 
         tts_output = {
             "media_id": self.tts_stream.media_source_id,
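Taken together, the pipeline.py hunks above implement a buffer-then-stream handoff: every assistant content delta is queued immediately, the queue is only handed to the TTS stream once STREAM_RESPONSE_CHARS characters have accumulated, and a None sentinel ends the stream. A standalone sketch of that pattern (plain asyncio, no Home Assistant imports; `demo` and its names are illustrative, not code from this PR):

```python
import asyncio
from collections.abc import AsyncGenerator

STREAM_RESPONSE_CHARS = 60  # same threshold as the constant added above


async def demo(deltas: list[str]) -> str:
    """Queue deltas as they arrive; hand the queue to a consumer at the threshold."""
    queue: asyncio.Queue[str | None] = asyncio.Queue()
    streaming = False
    char_count = 0

    async def message_gen() -> AsyncGenerator[str, None]:
        # A None sentinel ends the stream, mirroring tts_input_stream_generator.
        while (chunk := await queue.get()) is not None:
            yield chunk

    for content in deltas:
        queue.put_nowait(content)  # every delta is queued, streaming or not
        if not streaming:
            char_count += len(content)
            if char_count >= STREAM_RESPONSE_CHARS:
                # Past this point the real pipeline would call
                # tts_stream.async_set_message_stream(message_gen()).
                streaming = True
    queue.put_nowait(None)

    return "".join([chunk async for chunk in message_gen()])


print(asyncio.run(demo(["hello, ", "how ", "are ", "you?"])))
```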
diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr
index 5d2d25ddc5c..4ae4b5dce4c 100644
--- a/tests/components/assist_pipeline/snapshots/test_init.ambr
+++ b/tests/components/assist_pipeline/snapshots/test_init.ambr
@@ -8,6 +8,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': False,
         'token': 'test_token.mp3',
         'url': '/api/tts_proxy/test_token.mp3',
       }),
@@ -107,6 +108,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': False,
         'token': 'test_token.mp3',
         'url': '/api/tts_proxy/test_token.mp3',
       }),
@@ -206,6 +208,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': False,
         'token': 'test_token.mp3',
         'url': '/api/tts_proxy/test_token.mp3',
       }),
@@ -305,6 +308,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': False,
         'token': 'test_token.mp3',
         'url': '/api/tts_proxy/test_token.mp3',
       }),
@@ -428,6 +432,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': False,
         'token': 'mocked-token.mp3',
         'url': '/api/tts_proxy/mocked-token.mp3',
       }),
diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr
index f5940edbc76..2e005fb4c13 100644
--- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr
+++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr
@@ -1,5 +1,5 @@
 # serializer version: 1
-# name: test_chat_log_tts_streaming[to_stream_tts0-1]
+# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
   list([
     dict({
       'data': dict({
@@ -8,6 +8,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': True,
         'token': 'mocked-token.mp3',
         'url': '/api/tts_proxy/mocked-token.mp3',
       }),
@@ -153,6 +154,225 @@
     }),
   ])
 # ---
+# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
+  list([
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'language': 'en',
+        'pipeline': <ANY>,
+        'tts_output': dict({
+          'mime_type': 'audio/mpeg',
+          'stream_response': True,
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.RUN_START: 'run-start'>,
+    }),
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'device_id': None,
+        'engine': 'test-agent',
+        'intent_input': 'Set a timer',
+        'language': 'en',
+        'prefer_local_intents': False,
+      }),
+      'type': <PipelineEventType.INTENT_START: 'intent-start'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'role': 'assistant',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'hello, ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'how ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'are ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '? ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': "I'm ",
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'doing ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'well',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': ', ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'thank ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '. ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'What ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'about ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '?',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'intent_output': dict({
+          'continue_conversation': True,
+          'conversation_id': <ANY>,
+          'response': dict({
+            'card': dict({
+            }),
+            'data': dict({
+              'failed': list([
+              ]),
+              'success': list([
+              ]),
+              'targets': list([
+              ]),
+            }),
+            'language': 'en',
+            'response_type': 'action_done',
+            'speech': dict({
+              'plain': dict({
+                'extra_data': None,
+                'speech': "hello, how are you? I'm doing well, thank you. What about you?",
+              }),
+            }),
+          }),
+        }),
+        'processed_locally': False,
+      }),
+      'type': <PipelineEventType.INTENT_END: 'intent-end'>,
+    }),
+    dict({
+      'data': dict({
+        'engine': 'tts.test',
+        'language': 'en_US',
+        'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
+        'voice': None,
+      }),
+      'type': <PipelineEventType.TTS_START: 'tts-start'>,
+    }),
+    dict({
+      'data': dict({
+        'tts_output': dict({
+          'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
+          'mime_type': 'audio/mpeg',
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.TTS_END: 'tts-end'>,
+    }),
+    dict({
+      'data': None,
+      'type': <PipelineEventType.RUN_END: 'run-end'>,
+    }),
+  ])
+# ---
 # name: test_pipeline_language_used_instead_of_conversation_language
   list([
     dict({
@@ -321,6 +541,7 @@
       'pipeline': <ANY>,
       'tts_output': dict({
         'mime_type': 'audio/mpeg',
+        'stream_response': False,
         'token': 'mocked-token.mp3',
         'url': '/api/tts_proxy/mocked-token.mp3',
       }),
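For the streamed case captured by the new snapshot above, the 60-character threshold trips on the 15th delta (61 of 62 characters), yet all 16 chunks still flow through the streaming path, because every assistant delta is queued from the first one; the threshold only decides whether the queue is handed to TTS at all. A quick, hypothetical check of that arithmetic (not part of the tests):

```python
# Cumulative character counts for the 16 chunks recorded in the snapshot above.
chunks = [
    "hello, ", "how ", "are ", "you", "? ", "I'm ", "doing ", "well",
    ", ", "thank ", "you", ". ", "What ", "about ", "you", "?",
]
total = 0
for i, chunk in enumerate(chunks, start=1):
    total += len(chunk)
    if total >= 60:
        # Prints: threshold reached at chunk 15 (61 of 62 chars)
        print(f"threshold reached at chunk {i} ({total} of {sum(map(len, chunks))} chars)")
        break
```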
What about you?", + }), + }), + }), + }), + 'processed_locally': False, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'engine': 'tts.test', + 'language': 'en_US', + 'tts_input': "hello, how are you? I'm doing well, thank you. What about you?", + 'voice': None, + }), + 'type': , + }), + dict({ + 'data': dict({ + 'tts_output': dict({ + 'media_id': 'media-source://tts/-stream-/mocked-token.mp3', + 'mime_type': 'audio/mpeg', + 'token': 'mocked-token.mp3', + 'url': '/api/tts_proxy/mocked-token.mp3', + }), + }), + 'type': , + }), + dict({ + 'data': None, + 'type': , + }), + ]) +# --- # name: test_pipeline_language_used_instead_of_conversation_language list([ dict({ @@ -321,6 +541,7 @@ 'pipeline': , 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 827b9c71ba8..4f29fd79568 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -10,6 +10,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), @@ -101,6 +102,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), @@ -204,6 +206,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), @@ -295,6 +298,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), @@ -408,6 +412,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'test_token.mp3', 'url': '/api/tts_proxy/test_token.mp3', }), @@ -616,6 +621,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -670,6 +676,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -686,6 +693,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -702,6 +710,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -718,6 +727,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -734,6 +744,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -868,6 +879,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -884,6 +896,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 'mocked-token.mp3', 'url': '/api/tts_proxy/mocked-token.mp3', }), @@ -941,6 +954,7 @@ }), 'tts_output': dict({ 'mime_type': 'audio/mpeg', + 'stream_response': False, 'token': 
diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py
index 1714c909a18..d8550f34deb 100644
--- a/tests/components/assist_pipeline/test_pipeline.py
+++ b/tests/components/assist_pipeline/test_pipeline.py
@@ -1575,8 +1575,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
 
 
 @pytest.mark.parametrize(
-    ("to_stream_tts", "expected_chunks"),
+    ("to_stream_tts", "expected_chunks", "chunk_text"),
     [
+        # Size below STREAM_RESPONSE_CHARS
         (
             [
                 "hello,",
@@ -1588,7 +1589,33 @@ async def test_pipeline_language_used_instead_of_conversation_language(
                 "you",
                 "?",
             ],
-            1,
+            # We are not streaming, so 0 chunks arrive via the streaming method
+            0,
+            "",
         ),
+        # Size above STREAM_RESPONSE_CHARS
+        (
+            [
+                "hello, ",
+                "how ",
+                "are ",
+                "you",
+                "? ",
+                "I'm ",
+                "doing ",
+                "well",
+                ", ",
+                "thank ",
+                "you",
+                ". ",
+                "What ",
+                "about ",
+                "you",
+                "?",
+            ],
+            # We are streaming, so the chunk count equals the number of items above
+            16,
+            "hello, how are you? I'm doing well, thank you. What about you?",
+        ),
     ],
 )
@@ -1602,6 +1629,7 @@ async def test_chat_log_tts_streaming(
     pipeline_data: assist_pipeline.pipeline.PipelineData,
     to_stream_tts: list[str],
     expected_chunks: int,
+    chunk_text: str,
 ) -> None:
     """Test that chat log events are streamed to the TTS entity."""
     events: list[assist_pipeline.PipelineEvent] = []
@@ -1627,22 +1655,41 @@ async def test_chat_log_tts_streaming(
         ),
     )
 
+    received_tts = []
+
+    async def async_stream_tts_audio(
+        request: tts.TTSAudioRequest,
+    ) -> tts.TTSAudioResponse:
+        """Mock stream TTS audio."""
+
+        async def gen_data():
+            async for msg in request.message_gen:
+                received_tts.append(msg)
+                yield msg.encode()
+
+        return tts.TTSAudioResponse(
+            extension="mp3",
+            data_gen=gen_data(),
+        )
+
     async def async_get_tts_audio(
         message: str,
         language: str,
         options: dict[str, Any] | None = None,
-    ) -> tts.TTSAudioResponse:
+    ) -> tts.TtsAudioType:
         """Mock get TTS audio."""
         return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
 
     mock_tts_entity.async_get_tts_audio = async_get_tts_audio
+    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
+    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
 
     with patch(
         "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
         return_value=conversation.AgentInfo(
             id="test-agent",
             name="Test Agent",
-            supports_streaming=False,
+            supports_streaming=True,
         ),
     ):
         await pipeline_input.validate()
@@ -1707,6 +1754,7 @@ async def test_chat_log_tts_streaming(
 
     streamed_text = "".join(to_stream_tts)
     assert tts_result == streamed_text
-    assert expected_chunks == 1
+    assert len(received_tts) == expected_chunks
+    assert "".join(received_tts) == chunk_text
 
     assert process_events(events) == snapshot
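The snapshots also document the delta shapes the listener must handle: an initial delta carrying only the role, followed by content-only deltas. A minimal, hypothetical rendition of the same filtering the new chat_log_delta_listener performs (standalone loop, not the actual callback):

```python
from typing import Any

# Delta sequence as recorded in the snapshot: role first, then content chunks.
deltas: list[dict[str, Any]] = [
    {"role": "assistant"},
    {"content": "hello, "},
    {"content": "how "},
    {"content": "are "},
    {"content": "you?"},
]

role: str | None = None
queued: list[str] = []
for delta in deltas:
    if new_role := delta.get("role"):
        role = new_role
    # Only assistant deltas that actually carry content are queued for TTS.
    if role != "assistant" or not (content := delta.get("content")):
        continue
    queued.append(content)

print(queued)  # ['hello, ', 'how ', 'are ', 'you?']
```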