Pipeline to stream TTS on tool call (#145477)

Paulus Schoutsen 2025-05-25 15:59:07 -04:00 committed by GitHub
parent f472bf7c87
commit 1cc2baa95e
3 changed files with 376 additions and 58 deletions

homeassistant/components/assist_pipeline/pipeline.py

@@ -1178,25 +1178,33 @@ class PipelineRun:
             if role := delta.get("role"):
                 chat_log_role = role
 
-            # We are only interested in assistant deltas with content
-            if chat_log_role != "assistant" or not (
-                content := delta.get("content")
-            ):
+            # We are only interested in assistant deltas
+            if chat_log_role != "assistant":
                 return
 
-            tts_input_stream.put_nowait(content)
+            if content := delta.get("content"):
+                tts_input_stream.put_nowait(content)
 
             if self._streamed_response_text:
                 return
 
             nonlocal delta_character_count
-            delta_character_count += len(content)
 
-            if delta_character_count < STREAM_RESPONSE_CHARS:
+            # Streamed responses are not cached. That's why we only start streaming text after
+            # we have received enough characters that indicates it will be a long response
+            # or if we have received text, and then a tool call.
+
+            # Tool call after we already received text
+            start_streaming = delta_character_count > 0 and delta.get("tool_calls")
+
+            # Count characters in the content and test if we exceed streaming threshold
+            if not start_streaming and content:
+                delta_character_count += len(content)
+                start_streaming = delta_character_count > STREAM_RESPONSE_CHARS
+
+            if not start_streaming:
                 return
 
-            # Streamed responses are not cached. We only start streaming text after
-            # we have received a couple of words that indicates it will be a long response.
             self._streamed_response_text = True
 
         async def tts_input_stream_generator() -> AsyncGenerator[str]:
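The hunk above changes when the pipeline flips into streaming mode: instead of only waiting for STREAM_RESPONSE_CHARS characters of content, it now also starts streaming as soon as a tool call arrives after any text has been received. Below is a minimal standalone sketch of that decision rule; `should_start_streaming` and the simplified delta dicts are hypothetical stand-ins for illustration, not the pipeline's real API.

# Hypothetical sketch of the start-streaming rule introduced above.
STREAM_RESPONSE_CHARS = 60  # simplified stand-in for the real constant

def should_start_streaming(deltas: list[dict]) -> bool:
    """True once enough text arrives, or a tool call follows any text."""
    char_count = 0
    for delta in deltas:
        # Tool call after we already received text: stream immediately
        if char_count > 0 and delta.get("tool_calls"):
            return True
        if content := delta.get("content"):
            char_count += len(content)
            if char_count > STREAM_RESPONSE_CHARS:
                return True
    return False

# A short answer alone is not streamed; text followed by a tool call is.
assert not should_start_streaming([{"content": "hi"}])
assert should_start_streaming([{"content": "hello, "}, {"tool_calls": ["t"]}])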
@@ -1204,6 +1212,17 @@ class PipelineRun:
             while (tts_input := await tts_input_stream.get()) is not None:
                 yield tts_input
 
+        # Concatenate all existing queue items
+        parts = []
+        while not tts_input_stream.empty():
+            parts.append(tts_input_stream.get_nowait())
+        tts_input_stream.put_nowait(
+            "".join(
+                # At this point parts is only strings, None indicates end of queue
+                cast(list[str], parts)
+            )
+        )
+
         assert self.tts_stream is not None
         self.tts_stream.async_set_message_stream(tts_input_stream_generator())
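Once streaming starts, everything already buffered in the queue is collapsed into a single chunk before the generator is handed to TTS, which is why early deltas arrive merged in the test snapshots below. A runnable sketch of that drain-and-requeue step, using a bare asyncio.Queue rather than the pipeline's own objects:

import asyncio

async def demo() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()
    for part in ("hello, ", "how ", "are ", "you? "):
        queue.put_nowait(part)

    # Drain the buffered items and re-queue them as one concatenated string
    parts: list[str] = []
    while not queue.empty():
        parts.append(queue.get_nowait())
    queue.put_nowait("".join(parts))

    assert queue.get_nowait() == "hello, how are you? "

asyncio.run(demo())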

tests/components/assist_pipeline/snapshots/test_pipeline.ambr

@@ -1,5 +1,5 @@
 # serializer version: 1
-# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
+# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
   list([
     dict({
       'data': dict({
@@ -154,7 +154,7 @@
     }),
   ])
 # ---
-# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
+# name: test_chat_log_tts_streaming[to_stream_deltas1-3-hello, how are you? I'm doing well, thank you. What about you?!]
   list([
     dict({
       'data': dict({
@@ -317,10 +317,18 @@
       }),
       'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
     }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '!',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
     dict({
       'data': dict({
         'intent_output': dict({
-          'continue_conversation': True,
+          'continue_conversation': False,
           'conversation_id': <ANY>,
           'response': dict({
             'card': dict({
@@ -338,7 +346,7 @@
             'speech': dict({
               'plain': dict({
                 'extra_data': None,
-                'speech': "hello, how are you? I'm doing well, thank you. What about you?",
+                'speech': "hello, how are you? I'm doing well, thank you. What about you?!",
               }),
             }),
           }),
@@ -351,7 +359,229 @@
       'data': dict({
         'engine': 'tts.test',
         'language': 'en_US',
-        'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
+        'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!",
         'voice': None,
       }),
       'type': <PipelineEventType.TTS_START: 'tts-start'>,
+    }),
+    dict({
+      'data': dict({
+        'tts_output': dict({
+          'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
+          'mime_type': 'audio/mpeg',
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.TTS_END: 'tts-end'>,
+    }),
+    dict({
+      'data': None,
+      'type': <PipelineEventType.RUN_END: 'run-end'>,
+    }),
+  ])
+# ---
+# name: test_chat_log_tts_streaming[to_stream_deltas2-8-hello, how are you? I'm doing well, thank you.]
+  list([
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'language': 'en',
+        'pipeline': <ANY>,
+        'tts_output': dict({
+          'mime_type': 'audio/mpeg',
+          'stream_response': True,
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.RUN_START: 'run-start'>,
+    }),
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'device_id': None,
+        'engine': 'test-agent',
+        'intent_input': 'Set a timer',
+        'language': 'en',
+        'prefer_local_intents': False,
+      }),
+      'type': <PipelineEventType.INTENT_START: 'intent-start'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'role': 'assistant',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'hello, ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'how ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'are ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '? ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'tool_calls': list([
+            dict({
+              'id': 'test_tool_id',
+              'tool_args': dict({
+              }),
+              'tool_name': 'test_tool',
+            }),
+          ]),
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'agent_id': 'test-agent',
+          'role': 'tool_result',
+          'tool_call_id': 'test_tool_id',
+          'tool_name': 'test_tool',
+          'tool_result': 'Test response',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'role': 'assistant',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': "I'm ",
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'doing ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'well',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': ', ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'thank ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '.',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'intent_output': dict({
+          'continue_conversation': False,
+          'conversation_id': <ANY>,
+          'response': dict({
+            'card': dict({
+            }),
+            'data': dict({
+              'failed': list([
+              ]),
+              'success': list([
+              ]),
+              'targets': list([
+              ]),
+            }),
+            'language': 'en',
+            'response_type': 'action_done',
+            'speech': dict({
+              'plain': dict({
+                'extra_data': None,
+                'speech': "I'm doing well, thank you.",
+              }),
+            }),
+          }),
+        }),
+        'processed_locally': False,
+      }),
+      'type': <PipelineEventType.INTENT_END: 'intent-end'>,
+    }),
+    dict({
+      'data': dict({
+        'engine': 'tts.test',
+        'language': 'en_US',
+        'tts_input': "I'm doing well, thank you.",
+        'voice': None,
+      }),
+      'type': <PipelineEventType.TTS_START: 'tts-start'>,

tests/components/assist_pipeline/test_pipeline.py

@@ -2,11 +2,12 @@
 from collections.abc import AsyncGenerator, Generator
 from typing import Any
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, AsyncMock, Mock, patch
 
 from hassil.recognize import Intent, IntentData, RecognizeResult
 import pytest
 from syrupy.assertion import SnapshotAssertion
+import voluptuous as vol
 
 from homeassistant.components import (
     assist_pipeline,
@@ -33,7 +34,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
 )
 from homeassistant.const import MATCH_ALL
 from homeassistant.core import Context, HomeAssistant
-from homeassistant.helpers import chat_session, intent
+from homeassistant.helpers import chat_session, intent, llm
 from homeassistant.setup import async_setup_component
 
 from . import MANY_LANGUAGES, process_events
@@ -1575,47 +1576,86 @@ async def test_pipeline_language_used_instead_of_conversation_language(
 @pytest.mark.parametrize(
-    ("to_stream_tts", "expected_chunks", "chunk_text"),
+    ("to_stream_deltas", "expected_chunks", "chunk_text"),
     [
         # Size below STREAM_RESPONSE_CHUNKS
         (
-            [
-                "hello,",
-                " ",
-                "how",
-                " ",
-                "are",
-                " ",
-                "you",
-                "?",
-            ],
+            (
+                [
+                    "hello,",
+                    " ",
+                    "how",
+                    " ",
+                    "are",
+                    " ",
+                    "you",
+                    "?",
+                ],
+            ),
             # We are not streaming, so 0 chunks via streaming method
             0,
             "",
         ),
         # Size above STREAM_RESPONSE_CHUNKS
         (
-            [
-                "hello, ",
-                "how ",
-                "are ",
-                "you",
-                "? ",
-                "I'm ",
-                "doing ",
-                "well",
-                ", ",
-                "thank ",
-                "you",
-                ". ",
-                "What ",
-                "about ",
-                "you",
-                "?",
-            ],
-            # We are streamed, so equal to count above list items
-            16,
-            "hello, how are you? I'm doing well, thank you. What about you?",
+            (
+                [
+                    "hello, ",
+                    "how ",
+                    "are ",
+                    "you",
+                    "? ",
+                    "I'm ",
+                    "doing ",
+                    "well",
+                    ", ",
+                    "thank ",
+                    "you",
+                    ". ",
+                    "What ",
+                    "about ",
+                    "you",
+                    "?",
+                    "!",
+                ],
+            ),
+            # We are streamed. First 15 chunks are grouped into 1 chunk
+            # and the rest are streamed
+            3,
+            "hello, how are you? I'm doing well, thank you. What about you?!",
+        ),
+        # Stream a bit, then a tool call, then stream some more
+        (
+            (
+                [
+                    "hello, ",
+                    "how ",
+                    "are ",
+                    "you",
+                    "? ",
+                ],
+                {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            tool_name="test_tool",
+                            tool_args={},
+                            id="test_tool_id",
+                        )
+                    ],
+                },
+                [
+                    "I'm ",
+                    "doing ",
+                    "well",
+                    ", ",
+                    "thank ",
+                    "you",
+                    ".",
+                ],
+            ),
+            # 1 chunk before tool call, then 7 after
+            8,
+            "hello, how are you? I'm doing well, thank you.",
         ),
     ],
 )
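The third parametrize case mixes text-chunk lists with a tool-call delta dict. The test flattens the lists back into plain text deltas (the `text_deltas` comprehension in the next hunk); a quick standalone check of that flattening with made-up data:

to_stream_deltas = (
    ["hello, ", "how "],
    {"tool_calls": ["test_tool"]},  # dict deltas carry no TTS text
    ["I'm ", "fine."],
)
text_deltas = [
    delta
    for deltas in to_stream_deltas
    if isinstance(deltas, list)
    for delta in deltas
]
assert text_deltas == ["hello, ", "how ", "I'm ", "fine."]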
@@ -1627,11 +1667,18 @@ async def test_chat_log_tts_streaming(
     snapshot: SnapshotAssertion,
     mock_tts_entity: MockTTSEntity,
     pipeline_data: assist_pipeline.pipeline.PipelineData,
-    to_stream_tts: list[str],
+    to_stream_deltas: tuple[dict | list[str]],
     expected_chunks: int,
     chunk_text: str,
 ) -> None:
     """Test that chat log events are streamed to the TTS entity."""
+    text_deltas = [
+        delta
+        for deltas in to_stream_deltas
+        if isinstance(deltas, list)
+        for delta in deltas
+    ]
+
     events: list[assist_pipeline.PipelineEvent] = []
     pipeline_store = pipeline_data.pipeline_store
@@ -1678,7 +1725,7 @@ async def test_chat_log_tts_streaming(
         options: dict[str, Any] | None = None,
     ) -> tts.TtsAudioType:
         """Mock get TTS audio."""
-        return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
+        return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))
 
     mock_tts_entity.async_get_tts_audio = async_get_tts_audio
     mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
@@ -1716,9 +1763,13 @@ async def test_chat_log_tts_streaming(
     )
 
     async def stream_llm_response():
-        yield {"role": "assistant"}
-        for chunk in to_stream_tts:
-            yield {"content": chunk}
+        for deltas in to_stream_deltas:
+            if isinstance(deltas, dict):
+                yield deltas
+            else:
+                yield {"role": "assistant"}
+                for chunk in deltas:
+                    yield {"content": chunk}
 
     with (
         chat_session.async_get_chat_session(hass, conversation_id) as session,
@@ -1728,21 +1779,39 @@ async def test_chat_log_tts_streaming(
             conversation_input,
         ) as chat_log,
     ):
+        await chat_log.async_update_llm_data(
+            conversing_domain="test",
+            user_input=conversation_input,
+            user_llm_hass_api="assist",
+            user_llm_prompt=None,
+        )
         async for _content in chat_log.async_add_delta_content_stream(
             agent_id, stream_llm_response()
         ):
             pass
         intent_response = intent.IntentResponse(language)
-        intent_response.async_set_speech("".join(to_stream_tts))
+        intent_response.async_set_speech("".join(to_stream_deltas[-1]))
         return conversation.ConversationResult(
             response=intent_response,
             conversation_id=chat_log.conversation_id,
             continue_conversation=chat_log.continue_conversation,
         )
 
-    with patch(
-        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
-        mock_converse,
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema({})
+    mock_tool.async_call.return_value = "Test response"
+
+    with (
+        patch(
+            "homeassistant.helpers.llm.AssistAPI._async_get_tools",
+            return_value=[mock_tool],
+        ),
+        patch(
+            "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
+            mock_converse,
+        ),
     ):
         await pipeline_input.execute()
@@ -1752,7 +1821,7 @@ async def test_chat_log_tts_streaming(
         [chunk.decode() async for chunk in stream.async_stream_result()]
     )
 
-    streamed_text = "".join(to_stream_tts)
+    streamed_text = "".join(text_deltas)
     assert tts_result == streamed_text
     assert len(received_tts) == expected_chunks
     assert "".join(received_tts) == chunk_text