mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Assist Pipeline stream TTS when supported and long response (#145264)
* Assist Pipeline stream TTS when supported and long response * Indicate in run-start if streaming supported * Simplify a little bit * Trigger streaming based on characters * 60
This commit is contained in:
parent
37e13505cf
commit
abcf925b79
@ -89,6 +89,8 @@ KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
|
||||
KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
|
||||
"pipeline_conversation_data"
|
||||
)
|
||||
# Number of response parts to handle before streaming the response
|
||||
STREAM_RESPONSE_CHARS = 60
|
||||
|
||||
|
||||
def validate_language(data: dict[str, Any]) -> Any:
|
||||
@ -552,7 +554,7 @@ class PipelineRun:
|
||||
event_callback: PipelineEventCallback
|
||||
language: str = None # type: ignore[assignment]
|
||||
runner_data: Any | None = None
|
||||
intent_agent: str | None = None
|
||||
intent_agent: conversation.AgentInfo | None = None
|
||||
tts_audio_output: str | dict[str, Any] | None = None
|
||||
wake_word_settings: WakeWordSettings | None = None
|
||||
audio_settings: AudioSettings = field(default_factory=AudioSettings)
|
||||
@ -588,6 +590,9 @@ class PipelineRun:
|
||||
_intent_agent_only = False
|
||||
"""If request should only be handled by agent, ignoring sentence triggers and local processing."""
|
||||
|
||||
_streamed_response_text = False
|
||||
"""If the conversation agent streamed response text to TTS result."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set language for pipeline."""
|
||||
self.language = self.pipeline.language or self.hass.config.language
|
||||
@ -649,6 +654,11 @@ class PipelineRun:
|
||||
"token": self.tts_stream.token,
|
||||
"url": self.tts_stream.url,
|
||||
"mime_type": self.tts_stream.content_type,
|
||||
"stream_response": (
|
||||
self.tts_stream.supports_streaming_input
|
||||
and self.intent_agent
|
||||
and self.intent_agent.supports_streaming
|
||||
),
|
||||
}
|
||||
|
||||
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
@ -896,12 +906,12 @@ class PipelineRun:
|
||||
) -> str:
|
||||
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
|
||||
# Create a background task to prepare the conversation agent
|
||||
if self.end_stage >= PipelineStage.INTENT:
|
||||
if self.end_stage >= PipelineStage.INTENT and self.intent_agent:
|
||||
self.hass.async_create_background_task(
|
||||
conversation.async_prepare_agent(
|
||||
self.hass, self.intent_agent, self.language
|
||||
self.hass, self.intent_agent.id, self.language
|
||||
),
|
||||
f"prepare conversation agent {self.intent_agent}",
|
||||
f"prepare conversation agent {self.intent_agent.id}",
|
||||
)
|
||||
|
||||
if isinstance(self.stt_provider, stt.Provider):
|
||||
@ -1042,7 +1052,7 @@ class PipelineRun:
|
||||
message=f"Intent recognition engine {engine} is not found",
|
||||
)
|
||||
|
||||
self.intent_agent = agent_info.id
|
||||
self.intent_agent = agent_info
|
||||
|
||||
async def recognize_intent(
|
||||
self,
|
||||
@ -1075,7 +1085,7 @@ class PipelineRun:
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_START,
|
||||
{
|
||||
"engine": self.intent_agent,
|
||||
"engine": self.intent_agent.id,
|
||||
"language": input_language,
|
||||
"intent_input": intent_input,
|
||||
"conversation_id": conversation_id,
|
||||
@ -1092,11 +1102,11 @@ class PipelineRun:
|
||||
conversation_id=conversation_id,
|
||||
device_id=device_id,
|
||||
language=input_language,
|
||||
agent_id=self.intent_agent,
|
||||
agent_id=self.intent_agent.id,
|
||||
extra_system_prompt=conversation_extra_system_prompt,
|
||||
)
|
||||
|
||||
agent_id = self.intent_agent
|
||||
agent_id = self.intent_agent.id
|
||||
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
|
||||
intent_response: intent.IntentResponse | None = None
|
||||
if not processed_locally and not self._intent_agent_only:
|
||||
@ -1118,7 +1128,7 @@ class PipelineRun:
|
||||
# If the LLM has API access, we filter out some sentences that are
|
||||
# interfering with LLM operation.
|
||||
if (
|
||||
intent_agent_state := self.hass.states.get(self.intent_agent)
|
||||
intent_agent_state := self.hass.states.get(self.intent_agent.id)
|
||||
) and intent_agent_state.attributes.get(
|
||||
ATTR_SUPPORTED_FEATURES, 0
|
||||
) & conversation.ConversationEntityFeature.CONTROL:
|
||||
@ -1140,6 +1150,13 @@ class PipelineRun:
|
||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||
processed_locally = True
|
||||
|
||||
if self.tts_stream and self.tts_stream.supports_streaming_input:
|
||||
tts_input_stream: asyncio.Queue[str | None] | None = asyncio.Queue()
|
||||
else:
|
||||
tts_input_stream = None
|
||||
chat_log_role = None
|
||||
delta_character_count = 0
|
||||
|
||||
@callback
|
||||
def chat_log_delta_listener(
|
||||
chat_log: conversation.ChatLog, delta: dict
|
||||
@ -1153,6 +1170,42 @@ class PipelineRun:
|
||||
},
|
||||
)
|
||||
)
|
||||
if tts_input_stream is None:
|
||||
return
|
||||
|
||||
nonlocal chat_log_role
|
||||
|
||||
if role := delta.get("role"):
|
||||
chat_log_role = role
|
||||
|
||||
# We are only interested in assistant deltas with content
|
||||
if chat_log_role != "assistant" or not (
|
||||
content := delta.get("content")
|
||||
):
|
||||
return
|
||||
|
||||
tts_input_stream.put_nowait(content)
|
||||
|
||||
if self._streamed_response_text:
|
||||
return
|
||||
|
||||
nonlocal delta_character_count
|
||||
|
||||
delta_character_count += len(content)
|
||||
if delta_character_count < STREAM_RESPONSE_CHARS:
|
||||
return
|
||||
|
||||
# Streamed responses are not cached. We only start streaming text after
|
||||
# we have received a couple of words that indicates it will be a long response.
|
||||
self._streamed_response_text = True
|
||||
|
||||
async def tts_input_stream_generator() -> AsyncGenerator[str]:
|
||||
"""Yield TTS input stream."""
|
||||
while (tts_input := await tts_input_stream.get()) is not None:
|
||||
yield tts_input
|
||||
|
||||
assert self.tts_stream is not None
|
||||
self.tts_stream.async_set_message_stream(tts_input_stream_generator())
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
@ -1196,6 +1249,8 @@ class PipelineRun:
|
||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
if tts_input_stream and self._streamed_response_text:
|
||||
tts_input_stream.put_nowait(None)
|
||||
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during intent recognition")
|
||||
@ -1273,7 +1328,8 @@ class PipelineRun:
|
||||
)
|
||||
)
|
||||
|
||||
self.tts_stream.async_set_message(tts_input)
|
||||
if not self._streamed_response_text:
|
||||
self.tts_stream.async_set_message(tts_input)
|
||||
|
||||
tts_output = {
|
||||
"media_id": self.tts_stream.media_source_id,
|
||||
|
@ -8,6 +8,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -107,6 +108,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -206,6 +208,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -305,6 +308,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -428,6 +432,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
|
@ -1,5 +1,5 @@
|
||||
# serializer version: 1
|
||||
# name: test_chat_log_tts_streaming[to_stream_tts0-1]
|
||||
# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
@ -8,6 +8,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': True,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -153,6 +154,225 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': True,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'test-agent',
|
||||
'intent_input': 'Set a timer',
|
||||
'language': 'en',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'role': 'assistant',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'hello, ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'how ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'are ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'you',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '? ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': "I'm ",
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'doing ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'well',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': ', ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'thank ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'you',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '. ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'What ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'about ',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': 'you',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'chat_log_delta': dict({
|
||||
'content': '?',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': True,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "hello, how are you? I'm doing well, thank you. What about you?",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'tts.test',
|
||||
'language': 'en_US',
|
||||
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
|
||||
'voice': None,
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
@ -321,6 +541,7 @@
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
|
@ -10,6 +10,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -101,6 +102,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -204,6 +206,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -295,6 +298,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -408,6 +412,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
@ -616,6 +621,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -670,6 +676,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -686,6 +693,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -702,6 +710,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -718,6 +727,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -734,6 +744,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -868,6 +879,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -884,6 +896,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -941,6 +954,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -957,6 +971,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -1017,6 +1032,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
@ -1033,6 +1049,7 @@
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'stream_response': False,
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
|
@ -1575,8 +1575,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("to_stream_tts", "expected_chunks"),
|
||||
("to_stream_tts", "expected_chunks", "chunk_text"),
|
||||
[
|
||||
# Size below STREAM_RESPONSE_CHUNKS
|
||||
(
|
||||
[
|
||||
"hello,",
|
||||
@ -1588,7 +1589,33 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
1,
|
||||
# We are not streaming, so 0 chunks via streaming method
|
||||
0,
|
||||
"",
|
||||
),
|
||||
# Size above STREAM_RESPONSE_CHUNKS
|
||||
(
|
||||
[
|
||||
"hello, ",
|
||||
"how ",
|
||||
"are ",
|
||||
"you",
|
||||
"? ",
|
||||
"I'm ",
|
||||
"doing ",
|
||||
"well",
|
||||
", ",
|
||||
"thank ",
|
||||
"you",
|
||||
". ",
|
||||
"What ",
|
||||
"about ",
|
||||
"you",
|
||||
"?",
|
||||
],
|
||||
# We are streamed, so equal to count above list items
|
||||
16,
|
||||
"hello, how are you? I'm doing well, thank you. What about you?",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -1602,6 +1629,7 @@ async def test_chat_log_tts_streaming(
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
to_stream_tts: list[str],
|
||||
expected_chunks: int,
|
||||
chunk_text: str,
|
||||
) -> None:
|
||||
"""Test that chat log events are streamed to the TTS entity."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
@ -1627,22 +1655,41 @@ async def test_chat_log_tts_streaming(
|
||||
),
|
||||
)
|
||||
|
||||
received_tts = []
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
request: tts.TTSAudioRequest,
|
||||
) -> tts.TTSAudioResponse:
|
||||
"""Mock stream TTS audio."""
|
||||
|
||||
async def gen_data():
|
||||
async for msg in request.message_gen:
|
||||
received_tts.append(msg)
|
||||
yield msg.encode()
|
||||
|
||||
return tts.TTSAudioResponse(
|
||||
extension="mp3",
|
||||
data_gen=gen_data(),
|
||||
)
|
||||
|
||||
async def async_get_tts_audio(
|
||||
message: str,
|
||||
language: str,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tts.TTSAudioResponse:
|
||||
) -> tts.TtsAudioType:
|
||||
"""Mock get TTS audio."""
|
||||
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
||||
|
||||
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(
|
||||
id="test-agent",
|
||||
name="Test Agent",
|
||||
supports_streaming=False,
|
||||
supports_streaming=True,
|
||||
),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
@ -1707,6 +1754,7 @@ async def test_chat_log_tts_streaming(
|
||||
|
||||
streamed_text = "".join(to_stream_tts)
|
||||
assert tts_result == streamed_text
|
||||
assert expected_chunks == 1
|
||||
assert len(received_tts) == expected_chunks
|
||||
assert "".join(received_tts) == chunk_text
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user