From 1cc2baa95e69c0cdaa25aa7c16dd1c5fc27d37d2 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen
Date: Sun, 25 May 2025 15:59:07 -0400
Subject: [PATCH] Pipeline to stream TTS on tool call (#145477)

---
 .../components/assist_pipeline/pipeline.py    |  37 ++-
 .../snapshots/test_pipeline.ambr              | 240 +++++++++++++++++-
 .../assist_pipeline/test_pipeline.py          | 157 ++++++++----
 3 files changed, 376 insertions(+), 58 deletions(-)

diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py
index 7d5f98e87f6..34f590574d4 100644
--- a/homeassistant/components/assist_pipeline/pipeline.py
+++ b/homeassistant/components/assist_pipeline/pipeline.py
@@ -1178,25 +1178,33 @@ class PipelineRun:
                 if role := delta.get("role"):
                     chat_log_role = role
 
-                # We are only interested in assistant deltas with content
-                if chat_log_role != "assistant" or not (
-                    content := delta.get("content")
-                ):
+                # We are only interested in assistant deltas
+                if chat_log_role != "assistant":
                     return
 
-                tts_input_stream.put_nowait(content)
+                if content := delta.get("content"):
+                    tts_input_stream.put_nowait(content)
 
                 if self._streamed_response_text:
                     return
 
                 nonlocal delta_character_count
-                delta_character_count += len(content)
-                if delta_character_count < STREAM_RESPONSE_CHARS:
+
+                # Streamed responses are not cached. That's why we only start streaming
+                # text after we have received enough characters to indicate it will be
+                # a long response, or if we have received text and then a tool call.
+
+                # Tool call after we already received text
+                start_streaming = delta_character_count > 0 and delta.get("tool_calls")
+
+                # Count characters in the content and test if we exceed the streaming threshold
+                if not start_streaming and content:
+                    delta_character_count += len(content)
+                    start_streaming = delta_character_count > STREAM_RESPONSE_CHARS
+
+                if not start_streaming:
                     return
 
-                # Streamed responses are not cached. We only start streaming text after
-                # we have received a couple of words that indicates it will be a long response.
                 self._streamed_response_text = True
 
                 async def tts_input_stream_generator() -> AsyncGenerator[str]:
@@ -1204,6 +1212,17 @@ class PipelineRun:
                     while (tts_input := await tts_input_stream.get()) is not None:
                         yield tts_input
 
+                # Concatenate all existing queue items
+                parts = []
+                while not tts_input_stream.empty():
+                    parts.append(tts_input_stream.get_nowait())
+                tts_input_stream.put_nowait(
+                    "".join(
+                        # At this point parts is only strings, None indicates end of queue
+                        cast(list[str], parts)
+                    )
+                )
+
                 assert self.tts_stream is not None
                 self.tts_stream.async_set_message_stream(tts_input_stream_generator())
 
diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr
index 2e005fb4c13..8431e32ed87 100644
--- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr
+++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr
@@ -1,5 +1,5 @@
 # serializer version: 1
-# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
+# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
   list([
     dict({
       'data': dict({
@@ -154,7 +154,7 @@
     }),
   ])
 # ---
-# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
+# name: test_chat_log_tts_streaming[to_stream_deltas1-3-hello, how are you? I'm doing well, thank you. What about you?!]
   list([
     dict({
       'data': dict({
@@ -317,10 +317,18 @@
       }),
       'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
     }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '!',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
     dict({
       'data': dict({
         'intent_output': dict({
-          'continue_conversation': True,
+          'continue_conversation': False,
           'conversation_id': <ANY>,
           'response': dict({
             'card': dict({
@@ -338,7 +346,7 @@
           'speech': dict({
             'plain': dict({
               'extra_data': None,
-              'speech': "hello, how are you? I'm doing well, thank you. What about you?",
+              'speech': "hello, how are you? I'm doing well, thank you. What about you?!",
             }),
           }),
         }),
@@ -351,7 +359,229 @@
       'data': dict({
         'engine': 'tts.test',
         'language': 'en_US',
-        'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
+        'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!",
         'voice': None,
       }),
       'type': <PipelineEventType.TTS_START: 'tts-start'>,
     }),
+    dict({
+      'data': dict({
+        'tts_output': dict({
+          'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
+          'mime_type': 'audio/mpeg',
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.TTS_END: 'tts-end'>,
+    }),
+    dict({
+      'data': None,
+      'type': <PipelineEventType.RUN_END: 'run-end'>,
+    }),
+  ])
+# ---
+# name: test_chat_log_tts_streaming[to_stream_deltas2-8-hello, how are you? I'm doing well, thank you.]
+  list([
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'language': 'en',
+        'pipeline': <ANY>,
+        'tts_output': dict({
+          'mime_type': 'audio/mpeg',
+          'stream_response': True,
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.RUN_START: 'run-start'>,
+    }),
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'device_id': None,
+        'engine': 'test-agent',
+        'intent_input': 'Set a timer',
+        'language': 'en',
+        'prefer_local_intents': False,
+      }),
+      'type': <PipelineEventType.INTENT_START: 'intent-start'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'role': 'assistant',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'hello, ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'how ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'are ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '? ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'tool_calls': list([
+            dict({
+              'id': 'test_tool_id',
+              'tool_args': dict({
+              }),
+              'tool_name': 'test_tool',
+            }),
+          ]),
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'agent_id': 'test-agent',
+          'role': 'tool_result',
+          'tool_call_id': 'test_tool_id',
+          'tool_name': 'test_tool',
+          'tool_result': 'Test response',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'role': 'assistant',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': "I'm ",
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'doing ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'well',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': ', ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'thank ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '.',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'intent_output': dict({
+          'continue_conversation': False,
+          'conversation_id': <ANY>,
+          'response': dict({
+            'card': dict({
+            }),
+            'data': dict({
+              'failed': list([
+              ]),
+              'success': list([
+              ]),
+              'targets': list([
+              ]),
+            }),
+            'language': 'en',
+            'response_type': 'action_done',
+            'speech': dict({
+              'plain': dict({
+                'extra_data': None,
+                'speech': "I'm doing well, thank you.",
+              }),
+            }),
+          }),
+        }),
+        'processed_locally': False,
+      }),
+      'type': <PipelineEventType.INTENT_END: 'intent-end'>,
+    }),
+    dict({
+      'data': dict({
+        'engine': 'tts.test',
+        'language': 'en_US',
+        'tts_input': "I'm doing well, thank you.",
         'voice': None,
       }),
       'type': <PipelineEventType.TTS_START: 'tts-start'>,
diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py
index d8550f34deb..abdcb55054c 100644
--- a/tests/components/assist_pipeline/test_pipeline.py
+++ b/tests/components/assist_pipeline/test_pipeline.py
@@ -2,11 +2,12 @@
 
 from collections.abc import AsyncGenerator, Generator
 from typing import Any
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, AsyncMock, Mock, patch
 
 from hassil.recognize import Intent, IntentData, RecognizeResult
 import pytest
 from syrupy.assertion import SnapshotAssertion
+import voluptuous as vol
 
 from homeassistant.components import (
     assist_pipeline,
@@ -33,7 +34,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
 )
 from homeassistant.const import MATCH_ALL
 from homeassistant.core import Context, HomeAssistant
-from homeassistant.helpers import chat_session, intent
+from homeassistant.helpers import chat_session, intent, llm
 from homeassistant.setup import async_setup_component
 
 from . import MANY_LANGUAGES, process_events
@@ -1575,47 +1576,86 @@ async def test_pipeline_language_used_instead_of_conversation_language(
 
 
 @pytest.mark.parametrize(
-    ("to_stream_tts", "expected_chunks", "chunk_text"),
+    ("to_stream_deltas", "expected_chunks", "chunk_text"),
     [
         # Size below STREAM_RESPONSE_CHUNKS
         (
-            [
-                "hello,",
-                " ",
-                "how",
-                " ",
-                "are",
-                " ",
-                "you",
-                "?",
-            ],
+            (
+                [
+                    "hello,",
+                    " ",
+                    "how",
+                    " ",
+                    "are",
+                    " ",
+                    "you",
+                    "?",
+                ],
+            ),
             # We are not streaming, so 0 chunks via streaming method
             0,
             "",
         ),
         # Size above STREAM_RESPONSE_CHUNKS
         (
-            [
-                "hello, ",
-                "how ",
-                "are ",
-                "you",
-                "? ",
-                "I'm ",
-                "doing ",
-                "well",
-                ", ",
-                "thank ",
-                "you",
-                ". ",
-                "What ",
-                "about ",
-                "you",
-                "?",
-            ],
-            # We are streamed, so equal to count above list items
-            16,
-            "hello, how are you? I'm doing well, thank you. What about you?",
+            (
+                [
+                    "hello, ",
+                    "how ",
+                    "are ",
+                    "you",
+                    "? ",
+                    "I'm ",
+                    "doing ",
+                    "well",
+                    ", ",
+                    "thank ",
+                    "you",
+                    ". ",
+                    "What ",
+                    "about ",
+                    "you",
+                    "?",
+                    "!",
+                ],
+            ),
+            # The response is streamed: the first 15 chunks are grouped into
+            # 1 chunk and the rest are streamed individually
+            3,
+            "hello, how are you? I'm doing well, thank you. What about you?!",
         ),
+        # Stream a bit, then a tool call, then stream some more
+        (
+            (
+                [
+                    "hello, ",
+                    "how ",
+                    "are ",
+                    "you",
+                    "? ",
+                ],
+                {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            tool_name="test_tool",
+                            tool_args={},
+                            id="test_tool_id",
+                        )
+                    ],
+                },
+                [
+                    "I'm ",
+                    "doing ",
+                    "well",
+                    ", ",
+                    "thank ",
+                    "you",
+                    ".",
+                ],
+            ),
+            # 1 chunk before the tool call, then 7 after
+            8,
+            "hello, how are you? I'm doing well, thank you.",
+        ),
     ],
 )
@@ -1627,11 +1667,18 @@ async def test_chat_log_tts_streaming(
     snapshot: SnapshotAssertion,
     mock_tts_entity: MockTTSEntity,
     pipeline_data: assist_pipeline.pipeline.PipelineData,
-    to_stream_tts: list[str],
+    to_stream_deltas: tuple[dict | list[str], ...],
     expected_chunks: int,
     chunk_text: str,
 ) -> None:
     """Test that chat log events are streamed to the TTS entity."""
+    text_deltas = [
+        delta
+        for deltas in to_stream_deltas
+        if isinstance(deltas, list)
+        for delta in deltas
+    ]
+
     events: list[assist_pipeline.PipelineEvent] = []
     pipeline_store = pipeline_data.pipeline_store
@@ -1678,7 +1725,7 @@ async def test_chat_log_tts_streaming(
         options: dict[str, Any] | None = None,
     ) -> tts.TtsAudioType:
         """Mock get TTS audio."""
-        return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
+        return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))
 
     mock_tts_entity.async_get_tts_audio = async_get_tts_audio
     mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
@@ -1716,9 +1763,13 @@ async def test_chat_log_tts_streaming(
     )
 
     async def stream_llm_response():
-        yield {"role": "assistant"}
-        for chunk in to_stream_tts:
-            yield {"content": chunk}
+        for deltas in to_stream_deltas:
+            if isinstance(deltas, dict):
+                yield deltas
+            else:
+                yield {"role": "assistant"}
+                for chunk in deltas:
+                    yield {"content": chunk}
 
     with (
         chat_session.async_get_chat_session(hass, conversation_id) as session,
@@ -1728,21 +1779,39 @@ async def test_chat_log_tts_streaming(
             conversation_input,
         ) as chat_log,
     ):
+        await chat_log.async_update_llm_data(
+            conversing_domain="test",
+            user_input=conversation_input,
+            user_llm_hass_api="assist",
+            user_llm_prompt=None,
+        )
         async for _content in chat_log.async_add_delta_content_stream(
             agent_id, stream_llm_response()
         ):
             pass
 
         intent_response = intent.IntentResponse(language)
-        intent_response.async_set_speech("".join(to_stream_tts))
+        intent_response.async_set_speech("".join(to_stream_deltas[-1]))
         return conversation.ConversationResult(
             response=intent_response,
             conversation_id=chat_log.conversation_id,
             continue_conversation=chat_log.continue_conversation,
         )
 
-    with patch(
-        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
-        mock_converse,
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema({})
+    mock_tool.async_call.return_value = "Test response"
+
+    with (
+        patch(
+            "homeassistant.helpers.llm.AssistAPI._async_get_tools",
+            return_value=[mock_tool],
+        ),
+        patch(
+            "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
+            mock_converse,
+        ),
     ):
         await pipeline_input.execute()
 
@@ -1752,7 +1821,7 @@ async def test_chat_log_tts_streaming(
         [chunk.decode() async for chunk in stream.async_stream_result()]
     )
 
-    streamed_text = "".join(to_stream_tts)
+    streamed_text = "".join(text_deltas)
     assert tts_result == streamed_text
     assert len(received_tts) == expected_chunks
     assert "".join(received_tts) == chunk_text
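
The pipeline change above has two triggers for switching from cached to streamed TTS (enough characters for a long response, or a tool call arriving after text) and, at the moment of the switch, it coalesces everything already queued into a single chunk. The following is a minimal, self-contained sketch of that logic under stated assumptions, not part of the patch: demo, on_delta, the sample deltas, and the threshold value of 60 are illustrative only.

import asyncio
from typing import cast

STREAM_RESPONSE_CHARS = 60  # assumed threshold; the real constant lives in assist_pipeline


async def demo() -> None:
    queue: asyncio.Queue[str | None] = asyncio.Queue()
    char_count = 0
    streaming = False

    def on_delta(delta: dict) -> None:
        """Mimic the chat log delta listener from the patch."""
        nonlocal char_count, streaming
        content = delta.get("content")
        if content:
            queue.put_nowait(content)
        if streaming:
            return
        # Trigger 1: a tool call after we already received text
        start = char_count > 0 and bool(delta.get("tool_calls"))
        # Trigger 2: enough characters to indicate a long response
        if not start and content:
            char_count += len(content)
            start = char_count > STREAM_RESPONSE_CHARS
        if not start:
            return
        streaming = True
        # Coalesce everything queued so far into one chunk, as the patch
        # does before handing the stream to the TTS entity
        parts: list[str | None] = []
        while not queue.empty():
            parts.append(queue.get_nowait())
        queue.put_nowait("".join(cast(list[str], parts)))

    for delta in (
        {"role": "assistant"},
        {"content": "hello, "},
        {"content": "how are you? "},
        {"tool_calls": ["test_tool"]},  # tool call after text: streaming starts here
        {"content": "I'm doing well."},
    ):
        on_delta(delta)
    queue.put_nowait(None)  # None marks end of queue, as in the pipeline

    chunks: list[str] = []
    while (chunk := await queue.get()) is not None:
        chunks.append(chunk)
    print(chunks)  # ['hello, how are you? ', "I'm doing well."]


asyncio.run(demo())

Under these assumptions the consumer sees exactly what the third parametrized test case asserts: one merged chunk for the text received before the tool call, then one chunk per delta afterwards.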