mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 19:09:32 +00:00
Assist Pipeline stream TTS when supported and long response (#145264)
* Assist Pipeline stream TTS when supported and long response
* Indicate in run-start if streaming supported
* Simplify a little bit
* Trigger streaming based on characters
* 60
This commit is contained in:
@@ -1575,8 +1575,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("to_stream_tts", "expected_chunks"),
|
||||
("to_stream_tts", "expected_chunks", "chunk_text"),
|
||||
[
|
||||
# Size below STREAM_RESPONSE_CHUNKS
|
||||
(
|
||||
[
|
||||
"hello,",
|
||||
@@ -1588,7 +1589,33 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
1,
|
||||
# We are not streaming, so 0 chunks via streaming method
|
||||
0,
|
||||
"",
|
||||
),
|
||||
# Size above STREAM_RESPONSE_CHUNKS
|
||||
(
|
||||
[
|
||||
"hello, ",
|
||||
"how ",
|
||||
"are ",
|
||||
"you",
|
||||
"? ",
|
||||
"I'm ",
|
||||
"doing ",
|
||||
"well",
|
||||
", ",
|
||||
"thank ",
|
||||
"you",
|
||||
". ",
|
||||
"What ",
|
||||
"about ",
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
# We are streaming, so the chunk count equals the number of items in the list above
|
||||
16,
|
||||
"hello, how are you? I'm doing well, thank you. What about you?",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -1602,6 +1629,7 @@ async def test_chat_log_tts_streaming(
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
to_stream_tts: list[str],
|
||||
expected_chunks: int,
|
||||
chunk_text: str,
|
||||
) -> None:
|
||||
"""Test that chat log events are streamed to the TTS entity."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
@@ -1627,22 +1655,41 @@ async def test_chat_log_tts_streaming(
|
||||
),
|
||||
)
|
||||
|
||||
received_tts = []
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
request: tts.TTSAudioRequest,
|
||||
) -> tts.TTSAudioResponse:
|
||||
"""Mock stream TTS audio."""
|
||||
|
||||
async def gen_data():
|
||||
async for msg in request.message_gen:
|
||||
received_tts.append(msg)
|
||||
yield msg.encode()
|
||||
|
||||
return tts.TTSAudioResponse(
|
||||
extension="mp3",
|
||||
data_gen=gen_data(),
|
||||
)
|
||||
|
||||
async def async_get_tts_audio(
|
||||
message: str,
|
||||
language: str,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tts.TTSAudioResponse:
|
||||
) -> tts.TtsAudioType:
|
||||
"""Mock get TTS audio."""
|
||||
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
||||
|
||||
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(
|
||||
id="test-agent",
|
||||
name="Test Agent",
|
||||
supports_streaming=False,
|
||||
supports_streaming=True,
|
||||
),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
@@ -1707,6 +1754,7 @@ async def test_chat_log_tts_streaming(
|
||||
|
||||
streamed_text = "".join(to_stream_tts)
|
||||
assert tts_result == streamed_text
|
||||
assert expected_chunks == 1
|
||||
assert len(received_tts) == expected_chunks
|
||||
assert "".join(received_tts) == chunk_text
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
Reference in New Issue
Block a user