Assist Pipeline stream TTS when supported and long response (#145264)

* Assist Pipeline stream TTS when supported and long response

* Indicate in run-start if streaming supported

* Simplify a little bit

* Trigger streaming based on characters

* 60
This commit is contained in:
Paulus Schoutsen
2025-05-20 14:00:27 -04:00
committed by GitHub
parent 37e13505cf
commit abcf925b79
5 changed files with 363 additions and 16 deletions

View File

@@ -89,6 +89,8 @@ KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
"pipeline_conversation_data"
)
# Number of response parts to handle before streaming the response
STREAM_RESPONSE_CHARS = 60
def validate_language(data: dict[str, Any]) -> Any:
@@ -552,7 +554,7 @@ class PipelineRun:
event_callback: PipelineEventCallback
language: str = None # type: ignore[assignment]
runner_data: Any | None = None
intent_agent: str | None = None
intent_agent: conversation.AgentInfo | None = None
tts_audio_output: str | dict[str, Any] | None = None
wake_word_settings: WakeWordSettings | None = None
audio_settings: AudioSettings = field(default_factory=AudioSettings)
@@ -588,6 +590,9 @@ class PipelineRun:
_intent_agent_only = False
"""If request should only be handled by agent, ignoring sentence triggers and local processing."""
_streamed_response_text = False
"""If the conversation agent streamed response text to TTS result."""
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
@@ -649,6 +654,11 @@ class PipelineRun:
"token": self.tts_stream.token,
"url": self.tts_stream.url,
"mime_type": self.tts_stream.content_type,
"stream_response": (
self.tts_stream.supports_streaming_input
and self.intent_agent
and self.intent_agent.supports_streaming
),
}
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
@@ -896,12 +906,12 @@ class PipelineRun:
) -> str:
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
# Create a background task to prepare the conversation agent
if self.end_stage >= PipelineStage.INTENT:
if self.end_stage >= PipelineStage.INTENT and self.intent_agent:
self.hass.async_create_background_task(
conversation.async_prepare_agent(
self.hass, self.intent_agent, self.language
self.hass, self.intent_agent.id, self.language
),
f"prepare conversation agent {self.intent_agent}",
f"prepare conversation agent {self.intent_agent.id}",
)
if isinstance(self.stt_provider, stt.Provider):
@@ -1042,7 +1052,7 @@ class PipelineRun:
message=f"Intent recognition engine {engine} is not found",
)
self.intent_agent = agent_info.id
self.intent_agent = agent_info
async def recognize_intent(
self,
@@ -1075,7 +1085,7 @@ class PipelineRun:
PipelineEvent(
PipelineEventType.INTENT_START,
{
"engine": self.intent_agent,
"engine": self.intent_agent.id,
"language": input_language,
"intent_input": intent_input,
"conversation_id": conversation_id,
@@ -1092,11 +1102,11 @@ class PipelineRun:
conversation_id=conversation_id,
device_id=device_id,
language=input_language,
agent_id=self.intent_agent,
agent_id=self.intent_agent.id,
extra_system_prompt=conversation_extra_system_prompt,
)
agent_id = self.intent_agent
agent_id = self.intent_agent.id
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
intent_response: intent.IntentResponse | None = None
if not processed_locally and not self._intent_agent_only:
@@ -1118,7 +1128,7 @@ class PipelineRun:
# If the LLM has API access, we filter out some sentences that are
# interfering with LLM operation.
if (
intent_agent_state := self.hass.states.get(self.intent_agent)
intent_agent_state := self.hass.states.get(self.intent_agent.id)
) and intent_agent_state.attributes.get(
ATTR_SUPPORTED_FEATURES, 0
) & conversation.ConversationEntityFeature.CONTROL:
@@ -1140,6 +1150,13 @@ class PipelineRun:
agent_id = conversation.HOME_ASSISTANT_AGENT
processed_locally = True
if self.tts_stream and self.tts_stream.supports_streaming_input:
tts_input_stream: asyncio.Queue[str | None] | None = asyncio.Queue()
else:
tts_input_stream = None
chat_log_role = None
delta_character_count = 0
@callback
def chat_log_delta_listener(
chat_log: conversation.ChatLog, delta: dict
@@ -1153,6 +1170,42 @@ class PipelineRun:
},
)
)
if tts_input_stream is None:
return
nonlocal chat_log_role
if role := delta.get("role"):
chat_log_role = role
# We are only interested in assistant deltas with content
if chat_log_role != "assistant" or not (
content := delta.get("content")
):
return
tts_input_stream.put_nowait(content)
if self._streamed_response_text:
return
nonlocal delta_character_count
delta_character_count += len(content)
if delta_character_count < STREAM_RESPONSE_CHARS:
return
# Streamed responses are not cached. We only start streaming text after
# we have received a couple of words that indicates it will be a long response.
self._streamed_response_text = True
async def tts_input_stream_generator() -> AsyncGenerator[str]:
"""Yield TTS input stream."""
while (tts_input := await tts_input_stream.get()) is not None:
yield tts_input
assert self.tts_stream is not None
self.tts_stream.async_set_message_stream(tts_input_stream_generator())
with (
chat_session.async_get_chat_session(
@@ -1196,6 +1249,8 @@ class PipelineRun:
speech = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
if tts_input_stream and self._streamed_response_text:
tts_input_stream.put_nowait(None)
except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition")
@@ -1273,7 +1328,8 @@ class PipelineRun:
)
)
self.tts_stream.async_set_message(tts_input)
if not self._streamed_response_text:
self.tts_stream.async_set_message(tts_input)
tts_output = {
"media_id": self.tts_stream.media_source_id,