mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 22:27:07 +00:00
Pipeline to stream TTS on tool call (#145477)
This commit is contained in:
parent
f472bf7c87
commit
1cc2baa95e
@ -1178,25 +1178,33 @@ class PipelineRun:
|
||||
if role := delta.get("role"):
|
||||
chat_log_role = role
|
||||
|
||||
# We are only interested in assistant deltas with content
|
||||
if chat_log_role != "assistant" or not (
|
||||
content := delta.get("content")
|
||||
):
|
||||
# We are only interested in assistant deltas
|
||||
if chat_log_role != "assistant":
|
||||
return
|
||||
|
||||
tts_input_stream.put_nowait(content)
|
||||
if content := delta.get("content"):
|
||||
tts_input_stream.put_nowait(content)
|
||||
|
||||
if self._streamed_response_text:
|
||||
return
|
||||
|
||||
nonlocal delta_character_count
|
||||
|
||||
delta_character_count += len(content)
|
||||
if delta_character_count < STREAM_RESPONSE_CHARS:
|
||||
# Streamed responses are not cached. That's why we only start streaming text after
|
||||
# we have received enough characters that indicates it will be a long response
|
||||
# or if we have received text, and then a tool call.
|
||||
|
||||
# Tool call after we already received text
|
||||
start_streaming = delta_character_count > 0 and delta.get("tool_calls")
|
||||
|
||||
# Count characters in the content and test if we exceed streaming threshold
|
||||
if not start_streaming and content:
|
||||
delta_character_count += len(content)
|
||||
start_streaming = delta_character_count > STREAM_RESPONSE_CHARS
|
||||
|
||||
if not start_streaming:
|
||||
return
|
||||
|
||||
# Streamed responses are not cached. We only start streaming text after
|
||||
# we have received a couple of words that indicates it will be a long response.
|
||||
self._streamed_response_text = True
|
||||
|
||||
async def tts_input_stream_generator() -> AsyncGenerator[str]:
|
||||
@ -1204,6 +1212,17 @@ class PipelineRun:
|
||||
while (tts_input := await tts_input_stream.get()) is not None:
|
||||
yield tts_input
|
||||
|
||||
# Concatenate all existing queue items
|
||||
parts = []
|
||||
while not tts_input_stream.empty():
|
||||
parts.append(tts_input_stream.get_nowait())
|
||||
tts_input_stream.put_nowait(
|
||||
"".join(
|
||||
# At this point parts is only strings, None indicates end of queue
|
||||
cast(list[str], parts)
|
||||
)
|
||||
)
|
||||
|
||||
assert self.tts_stream is not None
|
||||
self.tts_stream.async_set_message_stream(tts_input_stream_generator())
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# serializer version: 1
|
||||
# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
|
||||
# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
@ -154,7 +154,7 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
|
||||
# name: test_chat_log_tts_streaming[to_stream_deltas1-3-hello, how are you? I'm doing well, thank you. What about you?!]
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
@ -317,10 +317,18 @@
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '!',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': True,
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
@ -338,7 +346,7 @@
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "hello, how are you? I'm doing well, thank you. What about you?",
|
||||
'speech': "hello, how are you? I'm doing well, thank you. What about you?!",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
@ -351,7 +359,229 @@
|
||||
'data': dict({
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
|
||||
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!",
|
||||
'voice': None,
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_chat_log_tts_streaming[to_stream_deltas2-8-hello, how are you? I'm doing well, thank you.]
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': True,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'test-agent',
|
||||
'intent_input': 'Set a timer',
|
||||
'language': 'en',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'role': 'assistant',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'hello, ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'how ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'are ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'you',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '? ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'test_tool_id',
|
||||
'tool_args': dict({
|
||||
}),
|
||||
'tool_name': 'test_tool',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'agent_id': 'test-agent',
|
||||
'role': 'tool_result',
|
||||
'tool_call_id': 'test_tool_id',
|
||||
'tool_name': 'test_tool',
|
||||
'tool_result': 'Test response',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'role': 'assistant',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': "I'm ",
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'doing ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'well',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': ', ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'thank ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'you',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '.',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "I'm doing well, thank you.",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "I'm doing well, thank you.",
|
||||
'voice': None,
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
|
@ -2,11 +2,12 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
|
||||
from hassil.recognize import Intent, IntentData, RecognizeResult
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import (
|
||||
assist_pipeline,
|
||||
@ -33,7 +34,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||
)
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
from homeassistant.helpers import chat_session, intent, llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MANY_LANGUAGES, process_events
|
||||
@ -1575,47 +1576,86 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("to_stream_tts", "expected_chunks", "chunk_text"),
|
||||
("to_stream_deltas", "expected_chunks", "chunk_text"),
|
||||
[
|
||||
# Size below STREAM_RESPONSE_CHUNKS
|
||||
(
|
||||
[
|
||||
"hello,",
|
||||
" ",
|
||||
"how",
|
||||
" ",
|
||||
"are",
|
||||
" ",
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
(
|
||||
[
|
||||
"hello,",
|
||||
" ",
|
||||
"how",
|
||||
" ",
|
||||
"are",
|
||||
" ",
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
),
|
||||
# We are not streaming, so 0 chunks via streaming method
|
||||
0,
|
||||
"",
|
||||
),
|
||||
# Size above STREAM_RESPONSE_CHUNKS
|
||||
(
|
||||
[
|
||||
"hello, ",
|
||||
"how ",
|
||||
"are ",
|
||||
"you",
|
||||
"? ",
|
||||
"I'm ",
|
||||
"doing ",
|
||||
"well",
|
||||
", ",
|
||||
"thank ",
|
||||
"you",
|
||||
". ",
|
||||
"What ",
|
||||
"about ",
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
# We are streamed, so equal to count above list items
|
||||
16,
|
||||
"hello, how are you? I'm doing well, thank you. What about you?",
|
||||
(
|
||||
[
|
||||
"hello, ",
|
||||
"how ",
|
||||
"are ",
|
||||
"you",
|
||||
"? ",
|
||||
"I'm ",
|
||||
"doing ",
|
||||
"well",
|
||||
", ",
|
||||
"thank ",
|
||||
"you",
|
||||
". ",
|
||||
"What ",
|
||||
"about ",
|
||||
"you",
|
||||
"?",
|
||||
"!",
|
||||
],
|
||||
),
|
||||
# We are streamed. First 15 chunks are grouped into 1 chunk
|
||||
# and the rest are streamed
|
||||
3,
|
||||
"hello, how are you? I'm doing well, thank you. What about you?!",
|
||||
),
|
||||
# Stream a bit, then a tool call, then stream some more
|
||||
(
|
||||
(
|
||||
[
|
||||
"hello, ",
|
||||
"how ",
|
||||
"are ",
|
||||
"you",
|
||||
"? ",
|
||||
],
|
||||
{
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={},
|
||||
id="test_tool_id",
|
||||
)
|
||||
],
|
||||
},
|
||||
[
|
||||
"I'm ",
|
||||
"doing ",
|
||||
"well",
|
||||
", ",
|
||||
"thank ",
|
||||
"you",
|
||||
".",
|
||||
],
|
||||
),
|
||||
# 1 chunk before tool call, then 7 after
|
||||
8,
|
||||
"hello, how are you? I'm doing well, thank you.",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -1627,11 +1667,18 @@ async def test_chat_log_tts_streaming(
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
to_stream_tts: list[str],
|
||||
to_stream_deltas: tuple[dict | list[str]],
|
||||
expected_chunks: int,
|
||||
chunk_text: str,
|
||||
) -> None:
|
||||
"""Test that chat log events are streamed to the TTS entity."""
|
||||
text_deltas = [
|
||||
delta
|
||||
for deltas in to_stream_deltas
|
||||
if isinstance(deltas, list)
|
||||
for delta in deltas
|
||||
]
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
@ -1678,7 +1725,7 @@ async def test_chat_log_tts_streaming(
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tts.TtsAudioType:
|
||||
"""Mock get TTS audio."""
|
||||
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
||||
return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))
|
||||
|
||||
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||
@ -1716,9 +1763,13 @@ async def test_chat_log_tts_streaming(
|
||||
)
|
||||
|
||||
async def stream_llm_response():
|
||||
yield {"role": "assistant"}
|
||||
for chunk in to_stream_tts:
|
||||
yield {"content": chunk}
|
||||
for deltas in to_stream_deltas:
|
||||
if isinstance(deltas, dict):
|
||||
yield deltas
|
||||
else:
|
||||
yield {"role": "assistant"}
|
||||
for chunk in deltas:
|
||||
yield {"content": chunk}
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
@ -1728,21 +1779,39 @@ async def test_chat_log_tts_streaming(
|
||||
conversation_input,
|
||||
) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=conversation_input,
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
async for _content in chat_log.async_add_delta_content_stream(
|
||||
agent_id, stream_llm_response()
|
||||
):
|
||||
pass
|
||||
intent_response = intent.IntentResponse(language)
|
||||
intent_response.async_set_speech("".join(to_stream_tts))
|
||||
intent_response.async_set_speech("".join(to_stream_deltas[-1]))
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
mock_converse,
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema({})
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.helpers.llm.AssistAPI._async_get_tools",
|
||||
return_value=[mock_tool],
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
mock_converse,
|
||||
),
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
@ -1752,7 +1821,7 @@ async def test_chat_log_tts_streaming(
|
||||
[chunk.decode() async for chunk in stream.async_stream_result()]
|
||||
)
|
||||
|
||||
streamed_text = "".join(to_stream_tts)
|
||||
streamed_text = "".join(text_deltas)
|
||||
assert tts_result == streamed_text
|
||||
assert len(received_tts) == expected_chunks
|
||||
assert "".join(received_tts) == chunk_text
|
||||
|
Loading…
x
Reference in New Issue
Block a user