Assist Pipeline stream TTS when supported and long response (#145264)

* Assist Pipeline stream TTS when supported and long response

* Indicate in run-start if streaming supported

* Simplify a little bit

* Trigger streaming based on characters

* 60
Paulus Schoutsen
2025-05-20 14:00:27 -04:00
committed by GitHub
parent 37e13505cf
commit abcf925b79
5 changed files with 363 additions and 16 deletions
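The commit message above sketches the behavior under test: the pipeline only hands the response text to the TTS engine as a stream when the engine advertises streaming support and the response is long enough, with the trigger based on a character count (the final bullet suggests a threshold of 60). A minimal sketch of that decision, using illustrative names that are not taken from this change:

# Illustrative sketch only; the constant and function names are assumptions,
# not the actual assist_pipeline implementation.
STREAM_RESPONSE_CHARS = 60  # assumed threshold hinted at by the "60" bullet above


def should_stream_to_tts(engine_supports_streaming: bool, response_so_far: str) -> bool:
    """Return True when the partial response should be handed to TTS as a stream."""
    return engine_supports_streaming and len(response_so_far) >= STREAM_RESPONSE_CHARS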


@@ -1575,8 +1575,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
@pytest.mark.parametrize(
    ("to_stream_tts", "expected_chunks"),
    ("to_stream_tts", "expected_chunks", "chunk_text"),
    [
        # Size below STREAM_RESPONSE_CHUNKS
        (
            [
                "hello,",
@@ -1588,7 +1589,33 @@ async def test_pipeline_language_used_instead_of_conversation_language(
"you",
"?",
],
1,
# We are not streaming, so 0 chunks via streaming method
0,
"",
),
# Size above STREAM_RESPONSE_CHUNKS
(
[
"hello, ",
"how ",
"are ",
"you",
"? ",
"I'm ",
"doing ",
"well",
", ",
"thank ",
"you",
". ",
"What ",
"about ",
"you",
"?",
],
            # We are streaming, so the chunk count equals the number of items in the list above
            16,
            "hello, how are you? I'm doing well, thank you. What about you?",
        ),
    ],
)
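A quick sanity check of the second parametrize case: the 16 chunks concatenate to exactly the chunk_text string, which is what the assertions at the end of the test compare against:

chunks = [
    "hello, ", "how ", "are ", "you", "? ", "I'm ", "doing ", "well",
    ", ", "thank ", "you", ". ", "What ", "about ", "you", "?",
]
assert len(chunks) == 16
assert "".join(chunks) == "hello, how are you? I'm doing well, thank you. What about you?"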
@@ -1602,6 +1629,7 @@ async def test_chat_log_tts_streaming(
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    to_stream_tts: list[str],
    expected_chunks: int,
    chunk_text: str,
) -> None:
    """Test that chat log events are streamed to the TTS entity."""
    events: list[assist_pipeline.PipelineEvent] = []
@@ -1627,22 +1655,41 @@ async def test_chat_log_tts_streaming(
        ),
    )

    received_tts = []

    async def async_stream_tts_audio(
        request: tts.TTSAudioRequest,
    ) -> tts.TTSAudioResponse:
        """Mock stream TTS audio."""

        async def gen_data():
            async for msg in request.message_gen:
                received_tts.append(msg)
                yield msg.encode()

        return tts.TTSAudioResponse(
            extension="mp3",
            data_gen=gen_data(),
        )

    async def async_get_tts_audio(
        message: str,
        language: str,
        options: dict[str, Any] | None = None,
    ) -> tts.TTSAudioResponse:
    ) -> tts.TtsAudioType:
        """Mock get TTS audio."""
        return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))

    mock_tts_entity.async_get_tts_audio = async_get_tts_audio
    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(
            id="test-agent",
            name="Test Agent",
            supports_streaming=False,
            supports_streaming=True,
        ),
    ):
        await pipeline_input.validate()
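The mocks above mirror the streaming contract the pipeline expects from a TTS entity: async_supports_streaming_input() reports the capability, async_stream_tts_audio() consumes text chunks from request.message_gen and yields audio bytes, and async_get_tts_audio() remains the non-streaming fallback. A rough sketch of an entity implementing the streaming side, with the provider call left as a hypothetical placeholder:

from homeassistant.components import tts


class ExampleStreamingTTSEntity(tts.TextToSpeechEntity):
    """Sketch of a TTS entity that accepts streamed text input (illustrative only)."""

    def async_supports_streaming_input(self) -> bool:
        # Advertise streaming support so the pipeline can stream long responses.
        return True

    async def async_stream_tts_audio(
        self, request: tts.TTSAudioRequest
    ) -> tts.TTSAudioResponse:
        async def gen_audio():
            async for text_chunk in request.message_gen:
                # _synthesize is a hypothetical provider call, not part of this change.
                yield await self._synthesize(text_chunk)

        return tts.TTSAudioResponse(extension="mp3", data_gen=gen_audio())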
@@ -1707,6 +1754,7 @@ async def test_chat_log_tts_streaming(
    streamed_text = "".join(to_stream_tts)
    assert tts_result == streamed_text
    assert expected_chunks == 1
    assert len(received_tts) == expected_chunks
    assert "".join(received_tts) == chunk_text
    assert process_events(events) == snapshot