Pipeline to stream TTS on tool call (#145477)

This commit is contained in:
Paulus Schoutsen 2025-05-25 15:59:07 -04:00 committed by GitHub
parent f472bf7c87
commit 1cc2baa95e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 376 additions and 58 deletions

View File

@ -1178,25 +1178,33 @@ class PipelineRun:
if role := delta.get("role"): if role := delta.get("role"):
chat_log_role = role chat_log_role = role
# We are only interested in assistant deltas with content # We are only interested in assistant deltas
if chat_log_role != "assistant" or not ( if chat_log_role != "assistant":
content := delta.get("content")
):
return return
tts_input_stream.put_nowait(content) if content := delta.get("content"):
tts_input_stream.put_nowait(content)
if self._streamed_response_text: if self._streamed_response_text:
return return
nonlocal delta_character_count nonlocal delta_character_count
delta_character_count += len(content) # Streamed responses are not cached. That's why we only start streaming text after
if delta_character_count < STREAM_RESPONSE_CHARS: # we have received enough characters to indicate it will be a long response
# or if we have received text, and then a tool call.
# Tool call after we already received text
start_streaming = delta_character_count > 0 and delta.get("tool_calls")
# Count characters in the content and test if we exceed streaming threshold
if not start_streaming and content:
delta_character_count += len(content)
start_streaming = delta_character_count > STREAM_RESPONSE_CHARS
if not start_streaming:
return return
# Streamed responses are not cached. We only start streaming text after
# we have received a couple of words that indicates it will be a long response.
self._streamed_response_text = True self._streamed_response_text = True
async def tts_input_stream_generator() -> AsyncGenerator[str]: async def tts_input_stream_generator() -> AsyncGenerator[str]:
@ -1204,6 +1212,17 @@ class PipelineRun:
while (tts_input := await tts_input_stream.get()) is not None: while (tts_input := await tts_input_stream.get()) is not None:
yield tts_input yield tts_input
# Concatenate all existing queue items
parts = []
while not tts_input_stream.empty():
parts.append(tts_input_stream.get_nowait())
tts_input_stream.put_nowait(
"".join(
# At this point parts is only strings, None indicates end of queue
cast(list[str], parts)
)
)
assert self.tts_stream is not None assert self.tts_stream is not None
self.tts_stream.async_set_message_stream(tts_input_stream_generator()) self.tts_stream.async_set_message_stream(tts_input_stream_generator())

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_chat_log_tts_streaming[to_stream_tts0-0-] # name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
@ -154,7 +154,7 @@
}), }),
]) ])
# --- # ---
# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?] # name: test_chat_log_tts_streaming[to_stream_deltas1-3-hello, how are you? I'm doing well, thank you. What about you?!]
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
@ -317,10 +317,18 @@
}), }),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>, 'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}), }),
dict({
'data': dict({
'chat_log_delta': dict({
'content': '!',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'continue_conversation': True, 'continue_conversation': False,
'conversation_id': <ANY>, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
@ -338,7 +346,7 @@
'speech': dict({ 'speech': dict({
'plain': dict({ 'plain': dict({
'extra_data': None, 'extra_data': None,
'speech': "hello, how are you? I'm doing well, thank you. What about you?", 'speech': "hello, how are you? I'm doing well, thank you. What about you?!",
}), }),
}), }),
}), }),
@ -351,7 +359,229 @@
'data': dict({ 'data': dict({
'engine': 'tts.test', 'engine': 'tts.test',
'language': 'en_US', 'language': 'en_US',
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?", 'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!",
'voice': None,
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
'mime_type': 'audio/mpeg',
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_chat_log_tts_streaming[to_stream_deltas2-8-hello, how are you? I'm doing well, thank you.]
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'stream_response': True,
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'test-agent',
'intent_input': 'Set a timer',
'language': 'en',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'role': 'assistant',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'hello, ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'how ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'are ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'you',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': '? ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'tool_calls': list([
dict({
'id': 'test_tool_id',
'tool_args': dict({
}),
'tool_name': 'test_tool',
}),
]),
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'agent_id': 'test-agent',
'role': 'tool_result',
'tool_call_id': 'test_tool_id',
'tool_name': 'test_tool',
'tool_result': 'Test response',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'role': 'assistant',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': "I'm ",
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'doing ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'well',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': ', ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'thank ',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': 'you',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'chat_log_delta': dict({
'content': '.',
}),
}),
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "I'm doing well, thank you.",
}),
}),
}),
}),
'processed_locally': False,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': dict({
'engine': 'tts.test',
'language': 'en_US',
'tts_input': "I'm doing well, thank you.",
'voice': None, 'voice': None,
}), }),
'type': <PipelineEventType.TTS_START: 'tts-start'>, 'type': <PipelineEventType.TTS_START: 'tts-start'>,

View File

@ -2,11 +2,12 @@
from collections.abc import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator
from typing import Any from typing import Any
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, AsyncMock, Mock, patch
from hassil.recognize import Intent, IntentData, RecognizeResult from hassil.recognize import Intent, IntentData, RecognizeResult
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import ( from homeassistant.components import (
assist_pipeline, assist_pipeline,
@ -33,7 +34,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
) )
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent from homeassistant.helpers import chat_session, intent, llm
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES, process_events from . import MANY_LANGUAGES, process_events
@ -1575,47 +1576,86 @@ async def test_pipeline_language_used_instead_of_conversation_language(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("to_stream_tts", "expected_chunks", "chunk_text"), ("to_stream_deltas", "expected_chunks", "chunk_text"),
[ [
# Size below STREAM_RESPONSE_CHARS # Size below STREAM_RESPONSE_CHARS
( (
[ (
"hello,", [
" ", "hello,",
"how", " ",
" ", "how",
"are", " ",
" ", "are",
"you", " ",
"?", "you",
], "?",
],
),
# We are not streaming, so 0 chunks via streaming method # We are not streaming, so 0 chunks via streaming method
0, 0,
"", "",
), ),
# Size above STREAM_RESPONSE_CHARS # Size above STREAM_RESPONSE_CHARS
( (
[ (
"hello, ", [
"how ", "hello, ",
"are ", "how ",
"you", "are ",
"? ", "you",
"I'm ", "? ",
"doing ", "I'm ",
"well", "doing ",
", ", "well",
"thank ", ", ",
"you", "thank ",
". ", "you",
"What ", ". ",
"about ", "What ",
"you", "about ",
"?", "you",
], "?",
# We are streamed, so equal to count above list items "!",
16, ],
"hello, how are you? I'm doing well, thank you. What about you?", ),
# We are streamed. First 15 chunks are grouped into 1 chunk
# and the rest are streamed
3,
"hello, how are you? I'm doing well, thank you. What about you?!",
),
# Stream a bit, then a tool call, then stream some more
(
(
[
"hello, ",
"how ",
"are ",
"you",
"? ",
],
{
"tool_calls": [
llm.ToolInput(
tool_name="test_tool",
tool_args={},
id="test_tool_id",
)
],
},
[
"I'm ",
"doing ",
"well",
", ",
"thank ",
"you",
".",
],
),
# 1 chunk before tool call, then 7 after
8,
"hello, how are you? I'm doing well, thank you.",
), ),
], ],
) )
@ -1627,11 +1667,18 @@ async def test_chat_log_tts_streaming(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
mock_tts_entity: MockTTSEntity, mock_tts_entity: MockTTSEntity,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
to_stream_tts: list[str], to_stream_deltas: tuple[dict | list[str]],
expected_chunks: int, expected_chunks: int,
chunk_text: str, chunk_text: str,
) -> None: ) -> None:
"""Test that chat log events are streamed to the TTS entity.""" """Test that chat log events are streamed to the TTS entity."""
text_deltas = [
delta
for deltas in to_stream_deltas
if isinstance(deltas, list)
for delta in deltas
]
events: list[assist_pipeline.PipelineEvent] = [] events: list[assist_pipeline.PipelineEvent] = []
pipeline_store = pipeline_data.pipeline_store pipeline_store = pipeline_data.pipeline_store
@ -1678,7 +1725,7 @@ async def test_chat_log_tts_streaming(
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
"""Mock get TTS audio.""" """Mock get TTS audio."""
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts])) return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))
mock_tts_entity.async_get_tts_audio = async_get_tts_audio mock_tts_entity.async_get_tts_audio = async_get_tts_audio
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
@ -1716,9 +1763,13 @@ async def test_chat_log_tts_streaming(
) )
async def stream_llm_response(): async def stream_llm_response():
yield {"role": "assistant"} for deltas in to_stream_deltas:
for chunk in to_stream_tts: if isinstance(deltas, dict):
yield {"content": chunk} yield deltas
else:
yield {"role": "assistant"}
for chunk in deltas:
yield {"content": chunk}
with ( with (
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
@ -1728,21 +1779,39 @@ async def test_chat_log_tts_streaming(
conversation_input, conversation_input,
) as chat_log, ) as chat_log,
): ):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
)
async for _content in chat_log.async_add_delta_content_stream( async for _content in chat_log.async_add_delta_content_stream(
agent_id, stream_llm_response() agent_id, stream_llm_response()
): ):
pass pass
intent_response = intent.IntentResponse(language) intent_response = intent.IntentResponse(language)
intent_response.async_set_speech("".join(to_stream_tts)) intent_response.async_set_speech("".join(to_stream_deltas[-1]))
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, response=intent_response,
conversation_id=chat_log.conversation_id, conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation, continue_conversation=chat_log.continue_conversation,
) )
with patch( mock_tool = AsyncMock()
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse", mock_tool.name = "test_tool"
mock_converse, mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema({})
mock_tool.async_call.return_value = "Test response"
with (
patch(
"homeassistant.helpers.llm.AssistAPI._async_get_tools",
return_value=[mock_tool],
),
patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
mock_converse,
),
): ):
await pipeline_input.execute() await pipeline_input.execute()
@ -1752,7 +1821,7 @@ async def test_chat_log_tts_streaming(
[chunk.decode() async for chunk in stream.async_stream_result()] [chunk.decode() async for chunk in stream.async_stream_result()]
) )
streamed_text = "".join(to_stream_tts) streamed_text = "".join(text_deltas)
assert tts_result == streamed_text assert tts_result == streamed_text
assert len(received_tts) == expected_chunks assert len(received_tts) == expected_chunks
assert "".join(received_tts) == chunk_text assert "".join(received_tts) == chunk_text