mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 15:17:35 +00:00
Pipeline to stream TTS on tool call (#145477)
This commit is contained in:
parent
f472bf7c87
commit
1cc2baa95e
@ -1178,25 +1178,33 @@ class PipelineRun:
|
|||||||
if role := delta.get("role"):
|
if role := delta.get("role"):
|
||||||
chat_log_role = role
|
chat_log_role = role
|
||||||
|
|
||||||
# We are only interested in assistant deltas with content
|
# We are only interested in assistant deltas
|
||||||
if chat_log_role != "assistant" or not (
|
if chat_log_role != "assistant":
|
||||||
content := delta.get("content")
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
tts_input_stream.put_nowait(content)
|
if content := delta.get("content"):
|
||||||
|
tts_input_stream.put_nowait(content)
|
||||||
|
|
||||||
if self._streamed_response_text:
|
if self._streamed_response_text:
|
||||||
return
|
return
|
||||||
|
|
||||||
nonlocal delta_character_count
|
nonlocal delta_character_count
|
||||||
|
|
||||||
delta_character_count += len(content)
|
# Streamed responses are not cached. That's why we only start streaming text after
|
||||||
if delta_character_count < STREAM_RESPONSE_CHARS:
|
# we have received enough characters that indicates it will be a long response
|
||||||
|
# or if we have received text, and then a tool call.
|
||||||
|
|
||||||
|
# Tool call after we already received text
|
||||||
|
start_streaming = delta_character_count > 0 and delta.get("tool_calls")
|
||||||
|
|
||||||
|
# Count characters in the content and test if we exceed streaming threshold
|
||||||
|
if not start_streaming and content:
|
||||||
|
delta_character_count += len(content)
|
||||||
|
start_streaming = delta_character_count > STREAM_RESPONSE_CHARS
|
||||||
|
|
||||||
|
if not start_streaming:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Streamed responses are not cached. We only start streaming text after
|
|
||||||
# we have received a couple of words that indicates it will be a long response.
|
|
||||||
self._streamed_response_text = True
|
self._streamed_response_text = True
|
||||||
|
|
||||||
async def tts_input_stream_generator() -> AsyncGenerator[str]:
|
async def tts_input_stream_generator() -> AsyncGenerator[str]:
|
||||||
@ -1204,6 +1212,17 @@ class PipelineRun:
|
|||||||
while (tts_input := await tts_input_stream.get()) is not None:
|
while (tts_input := await tts_input_stream.get()) is not None:
|
||||||
yield tts_input
|
yield tts_input
|
||||||
|
|
||||||
|
# Concatenate all existing queue items
|
||||||
|
parts = []
|
||||||
|
while not tts_input_stream.empty():
|
||||||
|
parts.append(tts_input_stream.get_nowait())
|
||||||
|
tts_input_stream.put_nowait(
|
||||||
|
"".join(
|
||||||
|
# At this point parts is only strings, None indicates end of queue
|
||||||
|
cast(list[str], parts)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
assert self.tts_stream is not None
|
assert self.tts_stream is not None
|
||||||
self.tts_stream.async_set_message_stream(tts_input_stream_generator())
|
self.tts_stream.async_set_message_stream(tts_input_stream_generator())
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
|
# name: test_chat_log_tts_streaming[to_stream_deltas0-0-]
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
@ -154,7 +154,7 @@
|
|||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
|
# name: test_chat_log_tts_streaming[to_stream_deltas1-3-hello, how are you? I'm doing well, thank you. What about you?!]
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
@ -317,10 +317,18 @@
|
|||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '!',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'continue_conversation': True,
|
'continue_conversation': False,
|
||||||
'conversation_id': <ANY>,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
@ -338,7 +346,7 @@
|
|||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
'extra_data': None,
|
'extra_data': None,
|
||||||
'speech': "hello, how are you? I'm doing well, thank you. What about you?",
|
'speech': "hello, how are you? I'm doing well, thank you. What about you?!",
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
@ -351,7 +359,229 @@
|
|||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'tts.test',
|
'engine': 'tts.test',
|
||||||
'language': 'en_US',
|
'language': 'en_US',
|
||||||
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
|
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?!",
|
||||||
|
'voice': None,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_chat_log_tts_streaming[to_stream_deltas2-8-hello, how are you? I'm doing well, thank you.]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'tts_output': dict({
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': True,
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'test-agent',
|
||||||
|
'intent_input': 'Set a timer',
|
||||||
|
'language': 'en',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'role': 'assistant',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'hello, ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'how ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'are ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'you',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '? ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'test_tool_id',
|
||||||
|
'tool_args': dict({
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'agent_id': 'test-agent',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'test_tool_id',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'Test response',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'role': 'assistant',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': "I'm ",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'doing ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'well',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': ', ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'thank ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'you',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '.',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'continue_conversation': False,
|
||||||
|
'conversation_id': <ANY>,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': "I'm doing well, thank you.",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'engine': 'tts.test',
|
||||||
|
'language': 'en_US',
|
||||||
|
'tts_input': "I'm doing well, thank you.",
|
||||||
'voice': None,
|
'voice': None,
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
|
@ -2,11 +2,12 @@
|
|||||||
|
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import ANY, Mock, patch
|
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||||
|
|
||||||
from hassil.recognize import Intent, IntentData, RecognizeResult
|
from hassil.recognize import Intent, IntentData, RecognizeResult
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import (
|
from homeassistant.components import (
|
||||||
assist_pipeline,
|
assist_pipeline,
|
||||||
@ -33,7 +34,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||||||
)
|
)
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import chat_session, intent
|
from homeassistant.helpers import chat_session, intent, llm
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import MANY_LANGUAGES, process_events
|
from . import MANY_LANGUAGES, process_events
|
||||||
@ -1575,47 +1576,86 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("to_stream_tts", "expected_chunks", "chunk_text"),
|
("to_stream_deltas", "expected_chunks", "chunk_text"),
|
||||||
[
|
[
|
||||||
# Size below STREAM_RESPONSE_CHUNKS
|
# Size below STREAM_RESPONSE_CHUNKS
|
||||||
(
|
(
|
||||||
[
|
(
|
||||||
"hello,",
|
[
|
||||||
" ",
|
"hello,",
|
||||||
"how",
|
" ",
|
||||||
" ",
|
"how",
|
||||||
"are",
|
" ",
|
||||||
" ",
|
"are",
|
||||||
"you",
|
" ",
|
||||||
"?",
|
"you",
|
||||||
],
|
"?",
|
||||||
|
],
|
||||||
|
),
|
||||||
# We are not streaming, so 0 chunks via streaming method
|
# We are not streaming, so 0 chunks via streaming method
|
||||||
0,
|
0,
|
||||||
"",
|
"",
|
||||||
),
|
),
|
||||||
# Size above STREAM_RESPONSE_CHUNKS
|
# Size above STREAM_RESPONSE_CHUNKS
|
||||||
(
|
(
|
||||||
[
|
(
|
||||||
"hello, ",
|
[
|
||||||
"how ",
|
"hello, ",
|
||||||
"are ",
|
"how ",
|
||||||
"you",
|
"are ",
|
||||||
"? ",
|
"you",
|
||||||
"I'm ",
|
"? ",
|
||||||
"doing ",
|
"I'm ",
|
||||||
"well",
|
"doing ",
|
||||||
", ",
|
"well",
|
||||||
"thank ",
|
", ",
|
||||||
"you",
|
"thank ",
|
||||||
". ",
|
"you",
|
||||||
"What ",
|
". ",
|
||||||
"about ",
|
"What ",
|
||||||
"you",
|
"about ",
|
||||||
"?",
|
"you",
|
||||||
],
|
"?",
|
||||||
# We are streamed, so equal to count above list items
|
"!",
|
||||||
16,
|
],
|
||||||
"hello, how are you? I'm doing well, thank you. What about you?",
|
),
|
||||||
|
# We are streamed. First 15 chunks are grouped into 1 chunk
|
||||||
|
# and the rest are streamed
|
||||||
|
3,
|
||||||
|
"hello, how are you? I'm doing well, thank you. What about you?!",
|
||||||
|
),
|
||||||
|
# Stream a bit, then a tool call, then stream some more
|
||||||
|
(
|
||||||
|
(
|
||||||
|
[
|
||||||
|
"hello, ",
|
||||||
|
"how ",
|
||||||
|
"are ",
|
||||||
|
"you",
|
||||||
|
"? ",
|
||||||
|
],
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={},
|
||||||
|
id="test_tool_id",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
[
|
||||||
|
"I'm ",
|
||||||
|
"doing ",
|
||||||
|
"well",
|
||||||
|
", ",
|
||||||
|
"thank ",
|
||||||
|
"you",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# 1 chunk before tool call, then 7 after
|
||||||
|
8,
|
||||||
|
"hello, how are you? I'm doing well, thank you.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1627,11 +1667,18 @@ async def test_chat_log_tts_streaming(
|
|||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
mock_tts_entity: MockTTSEntity,
|
mock_tts_entity: MockTTSEntity,
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
to_stream_tts: list[str],
|
to_stream_deltas: tuple[dict | list[str]],
|
||||||
expected_chunks: int,
|
expected_chunks: int,
|
||||||
chunk_text: str,
|
chunk_text: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that chat log events are streamed to the TTS entity."""
|
"""Test that chat log events are streamed to the TTS entity."""
|
||||||
|
text_deltas = [
|
||||||
|
delta
|
||||||
|
for deltas in to_stream_deltas
|
||||||
|
if isinstance(deltas, list)
|
||||||
|
for delta in deltas
|
||||||
|
]
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
@ -1678,7 +1725,7 @@ async def test_chat_log_tts_streaming(
|
|||||||
options: dict[str, Any] | None = None,
|
options: dict[str, Any] | None = None,
|
||||||
) -> tts.TtsAudioType:
|
) -> tts.TtsAudioType:
|
||||||
"""Mock get TTS audio."""
|
"""Mock get TTS audio."""
|
||||||
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
return ("mp3", b"".join([chunk.encode() for chunk in text_deltas]))
|
||||||
|
|
||||||
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||||
@ -1716,9 +1763,13 @@ async def test_chat_log_tts_streaming(
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def stream_llm_response():
|
async def stream_llm_response():
|
||||||
yield {"role": "assistant"}
|
for deltas in to_stream_deltas:
|
||||||
for chunk in to_stream_tts:
|
if isinstance(deltas, dict):
|
||||||
yield {"content": chunk}
|
yield deltas
|
||||||
|
else:
|
||||||
|
yield {"role": "assistant"}
|
||||||
|
for chunk in deltas:
|
||||||
|
yield {"content": chunk}
|
||||||
|
|
||||||
with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
@ -1728,21 +1779,39 @@ async def test_chat_log_tts_streaming(
|
|||||||
conversation_input,
|
conversation_input,
|
||||||
) as chat_log,
|
) as chat_log,
|
||||||
):
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
|
conversing_domain="test",
|
||||||
|
user_input=conversation_input,
|
||||||
|
user_llm_hass_api="assist",
|
||||||
|
user_llm_prompt=None,
|
||||||
|
)
|
||||||
async for _content in chat_log.async_add_delta_content_stream(
|
async for _content in chat_log.async_add_delta_content_stream(
|
||||||
agent_id, stream_llm_response()
|
agent_id, stream_llm_response()
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
intent_response = intent.IntentResponse(language)
|
intent_response = intent.IntentResponse(language)
|
||||||
intent_response.async_set_speech("".join(to_stream_tts))
|
intent_response.async_set_speech("".join(to_stream_deltas[-1]))
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response,
|
response=intent_response,
|
||||||
conversation_id=chat_log.conversation_id,
|
conversation_id=chat_log.conversation_id,
|
||||||
continue_conversation=chat_log.continue_conversation,
|
continue_conversation=chat_log.continue_conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
mock_tool = AsyncMock()
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
mock_tool.name = "test_tool"
|
||||||
mock_converse,
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema({})
|
||||||
|
mock_tool.async_call.return_value = "Test response"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.helpers.llm.AssistAPI._async_get_tools",
|
||||||
|
return_value=[mock_tool],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
mock_converse,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
await pipeline_input.execute()
|
await pipeline_input.execute()
|
||||||
|
|
||||||
@ -1752,7 +1821,7 @@ async def test_chat_log_tts_streaming(
|
|||||||
[chunk.decode() async for chunk in stream.async_stream_result()]
|
[chunk.decode() async for chunk in stream.async_stream_result()]
|
||||||
)
|
)
|
||||||
|
|
||||||
streamed_text = "".join(to_stream_tts)
|
streamed_text = "".join(text_deltas)
|
||||||
assert tts_result == streamed_text
|
assert tts_result == streamed_text
|
||||||
assert len(received_tts) == expected_chunks
|
assert len(received_tts) == expected_chunks
|
||||||
assert "".join(received_tts) == chunk_text
|
assert "".join(received_tts) == chunk_text
|
||||||
|
Loading…
x
Reference in New Issue
Block a user