Assist Pipeline stream TTS when supported and long response (#145264)

* Assist Pipeline stream TTS when supported and long response

* Indicate in run-start if streaming supported

* Simplify a little bit

* Trigger streaming based on characters

* 60
Paulus Schoutsen 2025-05-20 14:00:27 -04:00 committed by GitHub
parent 37e13505cf
commit abcf925b79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 363 additions and 16 deletions
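The mechanism in brief: assistant deltas are queued for TTS from the start, but the pipeline only flips to the streaming TTS path once enough characters have arrived to suggest a long response (60 here); shorter responses keep the cached, non-streaming path, since streamed responses are never cached. A minimal standalone sketch of that gate, assuming nothing from Home Assistant's APIs (`StreamGate`, `on_delta`, `finish`, and `drain` are illustrative names, not from this commit):

```python
import asyncio
from collections.abc import AsyncGenerator

STREAM_RESPONSE_CHARS = 60  # threshold used by this commit


class StreamGate:
    """Buffer text deltas; flag streaming once the response looks long.

    Illustrative sketch only -- not the Home Assistant implementation.
    """

    def __init__(self) -> None:
        self._queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._chars = 0
        self.streaming = False

    def on_delta(self, content: str) -> None:
        """Queue every delta; trip the gate past the character threshold."""
        self._queue.put_nowait(content)
        self._chars += len(content)
        if not self.streaming and self._chars >= STREAM_RESPONSE_CHARS:
            self.streaming = True  # a consumer may now start draining

    def finish(self) -> None:
        """Mark the end of the response with a None sentinel."""
        self._queue.put_nowait(None)

    async def drain(self) -> AsyncGenerator[str, None]:
        """Yield queued deltas until the sentinel arrives."""
        while (chunk := await self._queue.get()) is not None:
            yield chunk
```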

View File

@@ -89,6 +89,8 @@ KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
 KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
     "pipeline_conversation_data"
 )

+# Number of response characters to handle before streaming the response
+STREAM_RESPONSE_CHARS = 60

 def validate_language(data: dict[str, Any]) -> Any:
@@ -552,7 +554,7 @@ class PipelineRun:
     event_callback: PipelineEventCallback
     language: str = None  # type: ignore[assignment]
     runner_data: Any | None = None
-    intent_agent: str | None = None
+    intent_agent: conversation.AgentInfo | None = None
     tts_audio_output: str | dict[str, Any] | None = None
     wake_word_settings: WakeWordSettings | None = None
     audio_settings: AudioSettings = field(default_factory=AudioSettings)
@@ -588,6 +590,9 @@ class PipelineRun:
     _intent_agent_only = False
     """If request should only be handled by agent, ignoring sentence triggers and local processing."""

+    _streamed_response_text = False
+    """If the conversation agent streamed response text to TTS result."""
+
     def __post_init__(self) -> None:
         """Set language for pipeline."""
         self.language = self.pipeline.language or self.hass.config.language
@@ -649,6 +654,11 @@
                 "token": self.tts_stream.token,
                 "url": self.tts_stream.url,
                 "mime_type": self.tts_stream.content_type,
+                "stream_response": (
+                    self.tts_stream.supports_streaming_input
+                    and self.intent_agent
+                    and self.intent_agent.supports_streaming
+                ),
             }

         self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
@@ -896,12 +906,12 @@
     ) -> str:
         """Run speech-to-text portion of pipeline. Returns the spoken text."""
         # Create a background task to prepare the conversation agent
-        if self.end_stage >= PipelineStage.INTENT:
+        if self.end_stage >= PipelineStage.INTENT and self.intent_agent:
             self.hass.async_create_background_task(
                 conversation.async_prepare_agent(
-                    self.hass, self.intent_agent, self.language
+                    self.hass, self.intent_agent.id, self.language
                 ),
-                f"prepare conversation agent {self.intent_agent}",
+                f"prepare conversation agent {self.intent_agent.id}",
             )

         if isinstance(self.stt_provider, stt.Provider):
@@ -1042,7 +1052,7 @@
                 message=f"Intent recognition engine {engine} is not found",
             )

-        self.intent_agent = agent_info.id
+        self.intent_agent = agent_info

    async def recognize_intent(
        self,
@@ -1075,7 +1085,7 @@
             PipelineEvent(
                 PipelineEventType.INTENT_START,
                 {
-                    "engine": self.intent_agent,
+                    "engine": self.intent_agent.id,
                     "language": input_language,
                     "intent_input": intent_input,
                     "conversation_id": conversation_id,
@@ -1092,11 +1102,11 @@
                 conversation_id=conversation_id,
                 device_id=device_id,
                 language=input_language,
-                agent_id=self.intent_agent,
+                agent_id=self.intent_agent.id,
                 extra_system_prompt=conversation_extra_system_prompt,
             )

-            agent_id = self.intent_agent
+            agent_id = self.intent_agent.id
             processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
             intent_response: intent.IntentResponse | None = None
             if not processed_locally and not self._intent_agent_only:
@@ -1118,7 +1128,7 @@
                     # If the LLM has API access, we filter out some sentences that are
                     # interfering with LLM operation.
                     if (
-                        intent_agent_state := self.hass.states.get(self.intent_agent)
+                        intent_agent_state := self.hass.states.get(self.intent_agent.id)
                     ) and intent_agent_state.attributes.get(
                         ATTR_SUPPORTED_FEATURES, 0
                     ) & conversation.ConversationEntityFeature.CONTROL:
@@ -1140,6 +1150,13 @@
                 agent_id = conversation.HOME_ASSISTANT_AGENT
                 processed_locally = True

+            if self.tts_stream and self.tts_stream.supports_streaming_input:
+                tts_input_stream: asyncio.Queue[str | None] | None = asyncio.Queue()
+            else:
+                tts_input_stream = None
+            chat_log_role = None
+            delta_character_count = 0
+
            @callback
            def chat_log_delta_listener(
                chat_log: conversation.ChatLog, delta: dict
@@ -1153,6 +1170,42 @@
                         },
                     )
                 )

+                if tts_input_stream is None:
+                    return
+
+                nonlocal chat_log_role
+
+                if role := delta.get("role"):
+                    chat_log_role = role
+
+                # We are only interested in assistant deltas with content
+                if chat_log_role != "assistant" or not (
+                    content := delta.get("content")
+                ):
+                    return
+
+                tts_input_stream.put_nowait(content)
+
+                if self._streamed_response_text:
+                    return
+
+                nonlocal delta_character_count
+                delta_character_count += len(content)
+
+                if delta_character_count < STREAM_RESPONSE_CHARS:
+                    return
+
+                # Streamed responses are not cached. We only start streaming text after
+                # we have received a couple of words that indicate it will be a long response.
+                self._streamed_response_text = True
+
+                async def tts_input_stream_generator() -> AsyncGenerator[str]:
+                    """Yield TTS input stream."""
+                    while (tts_input := await tts_input_stream.get()) is not None:
+                        yield tts_input
+
+                assert self.tts_stream is not None
+                self.tts_stream.async_set_message_stream(tts_input_stream_generator())

             with (
                 chat_session.async_get_chat_session(
@@ -1196,6 +1249,8 @@
             speech = conversation_result.response.speech.get("plain", {}).get(
                 "speech", ""
             )

+            if tts_input_stream and self._streamed_response_text:
+                tts_input_stream.put_nowait(None)
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during intent recognition")
@@ -1273,7 +1328,8 @@
             )
         )

-        self.tts_stream.async_set_message(tts_input)
+        if not self._streamed_response_text:
+            self.tts_stream.async_set_message(tts_input)

         tts_output = {
             "media_id": self.tts_stream.media_source_id,

View File

@@ -8,6 +8,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -107,6 +108,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -206,6 +208,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -305,6 +308,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -428,6 +432,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),

View File

@@ -1,5 +1,5 @@
 # serializer version: 1
-# name: test_chat_log_tts_streaming[to_stream_tts0-1]
+# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
   list([
     dict({
       'data': dict({
@@ -8,6 +8,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': True,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -153,6 +154,225 @@
     }),
   ])
 # ---
+# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
+  list([
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'language': 'en',
+        'pipeline': <ANY>,
+        'tts_output': dict({
+          'mime_type': 'audio/mpeg',
+          'stream_response': True,
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.RUN_START: 'run-start'>,
+    }),
+    dict({
+      'data': dict({
+        'conversation_id': 'mock-ulid',
+        'device_id': None,
+        'engine': 'test-agent',
+        'intent_input': 'Set a timer',
+        'language': 'en',
+        'prefer_local_intents': False,
+      }),
+      'type': <PipelineEventType.INTENT_START: 'intent-start'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'role': 'assistant',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'hello, ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'how ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'are ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '? ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': "I'm ",
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'doing ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'well',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': ', ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'thank ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '. ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'What ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'about ',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': 'you',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'chat_log_delta': dict({
+          'content': '?',
+        }),
+      }),
+      'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
+    }),
+    dict({
+      'data': dict({
+        'intent_output': dict({
+          'continue_conversation': True,
+          'conversation_id': <ANY>,
+          'response': dict({
+            'card': dict({
+            }),
+            'data': dict({
+              'failed': list([
+              ]),
+              'success': list([
+              ]),
+              'targets': list([
+              ]),
+            }),
+            'language': 'en',
+            'response_type': 'action_done',
+            'speech': dict({
+              'plain': dict({
+                'extra_data': None,
+                'speech': "hello, how are you? I'm doing well, thank you. What about you?",
+              }),
+            }),
+          }),
+        }),
+        'processed_locally': False,
+      }),
+      'type': <PipelineEventType.INTENT_END: 'intent-end'>,
+    }),
+    dict({
+      'data': dict({
+        'engine': 'tts.test',
+        'language': 'en_US',
+        'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
+        'voice': None,
+      }),
+      'type': <PipelineEventType.TTS_START: 'tts-start'>,
+    }),
+    dict({
+      'data': dict({
+        'tts_output': dict({
+          'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
+          'mime_type': 'audio/mpeg',
+          'token': 'mocked-token.mp3',
+          'url': '/api/tts_proxy/mocked-token.mp3',
+        }),
+      }),
+      'type': <PipelineEventType.TTS_END: 'tts-end'>,
+    }),
+    dict({
+      'data': None,
+      'type': <PipelineEventType.RUN_END: 'run-end'>,
+    }),
+  ])
+# ---
 # name: test_pipeline_language_used_instead_of_conversation_language
   list([
     dict({
@@ -321,6 +541,7 @@
         'pipeline': <ANY>,
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),

View File

@@ -10,6 +10,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -101,6 +102,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -204,6 +206,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -295,6 +298,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -408,6 +412,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'test_token.mp3',
           'url': '/api/tts_proxy/test_token.mp3',
         }),
@@ -616,6 +621,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -670,6 +676,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -686,6 +693,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -702,6 +710,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -718,6 +727,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -734,6 +744,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -868,6 +879,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -884,6 +896,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -941,6 +954,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -957,6 +971,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -1017,6 +1032,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),
@@ -1033,6 +1049,7 @@
         }),
         'tts_output': dict({
           'mime_type': 'audio/mpeg',
+          'stream_response': False,
           'token': 'mocked-token.mp3',
           'url': '/api/tts_proxy/mocked-token.mp3',
         }),

View File

@@ -1575,8 +1575,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
 @pytest.mark.parametrize(
-    ("to_stream_tts", "expected_chunks"),
+    ("to_stream_tts", "expected_chunks", "chunk_text"),
     [
+        # Size below STREAM_RESPONSE_CHARS
         (
             [
                 "hello,",
@@ -1588,7 +1589,33 @@ async def test_pipeline_language_used_instead_of_conversation_language(
                 "you",
                 "?",
             ],
-            1,
+            # We are not streaming, so 0 chunks via the streaming method
+            0,
+            "",
         ),
+        # Size above STREAM_RESPONSE_CHARS
+        (
+            [
+                "hello, ",
+                "how ",
+                "are ",
+                "you",
+                "? ",
+                "I'm ",
+                "doing ",
+                "well",
+                ", ",
+                "thank ",
+                "you",
+                ". ",
+                "What ",
+                "about ",
+                "you",
+                "?",
+            ],
+            # We are streaming, so equal to the number of items in the list above
+            16,
+            "hello, how are you? I'm doing well, thank you. What about you?",
+        ),
     ],
 )
@@ -1602,6 +1629,7 @@ async def test_chat_log_tts_streaming(
     pipeline_data: assist_pipeline.pipeline.PipelineData,
     to_stream_tts: list[str],
     expected_chunks: int,
+    chunk_text: str,
 ) -> None:
     """Test that chat log events are streamed to the TTS entity."""
     events: list[assist_pipeline.PipelineEvent] = []
@@ -1627,22 +1655,41 @@ async def test_chat_log_tts_streaming(
         ),
     )

+    received_tts = []
+
+    async def async_stream_tts_audio(
+        request: tts.TTSAudioRequest,
+    ) -> tts.TTSAudioResponse:
+        """Mock stream TTS audio."""
+
+        async def gen_data():
+            async for msg in request.message_gen:
+                received_tts.append(msg)
+                yield msg.encode()
+
+        return tts.TTSAudioResponse(
+            extension="mp3",
+            data_gen=gen_data(),
+        )
+
     async def async_get_tts_audio(
         message: str,
         language: str,
         options: dict[str, Any] | None = None,
-    ) -> tts.TTSAudioResponse:
+    ) -> tts.TtsAudioType:
         """Mock get TTS audio."""
         return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))

     mock_tts_entity.async_get_tts_audio = async_get_tts_audio
+    mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
+    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)

     with patch(
         "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
         return_value=conversation.AgentInfo(
             id="test-agent",
             name="Test Agent",
-            supports_streaming=False,
+            supports_streaming=True,
         ),
     ):
         await pipeline_input.validate()
@@ -1707,6 +1754,7 @@ async def test_chat_log_tts_streaming(
     streamed_text = "".join(to_stream_tts)
     assert tts_result == streamed_text
-    assert expected_chunks == 1
+    assert len(received_tts) == expected_chunks
+    assert "".join(received_tts) == chunk_text

     assert process_events(events) == snapshot
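
The parametrization above covers both paths in one test: the short response never trips streaming (0 chunks via the streaming method), while the long one streams all 16 deltas. An equivalent two-path check against the standalone `StreamGate` sketch could look like this (hypothetical test, not part of the commit):

```python
import asyncio

import pytest


@pytest.mark.parametrize(
    ("deltas", "expect_streaming"),
    [
        (["hi", " there"], False),  # 8 chars, below the threshold
        (["word " * 4] * 5, True),  # 100 chars, well above it
    ],
)
def test_stream_gate(deltas: list[str], expect_streaming: bool) -> None:
    gate = StreamGate()

    async def run() -> list[str]:
        for delta in deltas:
            gate.on_delta(delta)
        gate.finish()
        return [chunk async for chunk in gate.drain()]

    chunks = asyncio.run(run())
    assert gate.streaming is expect_streaming
    assert "".join(chunks) == "".join(deltas)
```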