mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 07:07:28 +00:00
Assist Pipeline stream TTS when supported and long response (#145264)
* Assist Pipeline stream TTS when supported and long response * Indicate in run-start if streaming supported * Simplify a little bit * Trigger streaming based on characters * 60
This commit is contained in:
parent
37e13505cf
commit
abcf925b79
@ -89,6 +89,8 @@ KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
|
|||||||
KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
|
KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = HassKey(
|
||||||
"pipeline_conversation_data"
|
"pipeline_conversation_data"
|
||||||
)
|
)
|
||||||
|
# Number of response parts to handle before streaming the response
|
||||||
|
STREAM_RESPONSE_CHARS = 60
|
||||||
|
|
||||||
|
|
||||||
def validate_language(data: dict[str, Any]) -> Any:
|
def validate_language(data: dict[str, Any]) -> Any:
|
||||||
@ -552,7 +554,7 @@ class PipelineRun:
|
|||||||
event_callback: PipelineEventCallback
|
event_callback: PipelineEventCallback
|
||||||
language: str = None # type: ignore[assignment]
|
language: str = None # type: ignore[assignment]
|
||||||
runner_data: Any | None = None
|
runner_data: Any | None = None
|
||||||
intent_agent: str | None = None
|
intent_agent: conversation.AgentInfo | None = None
|
||||||
tts_audio_output: str | dict[str, Any] | None = None
|
tts_audio_output: str | dict[str, Any] | None = None
|
||||||
wake_word_settings: WakeWordSettings | None = None
|
wake_word_settings: WakeWordSettings | None = None
|
||||||
audio_settings: AudioSettings = field(default_factory=AudioSettings)
|
audio_settings: AudioSettings = field(default_factory=AudioSettings)
|
||||||
@ -588,6 +590,9 @@ class PipelineRun:
|
|||||||
_intent_agent_only = False
|
_intent_agent_only = False
|
||||||
"""If request should only be handled by agent, ignoring sentence triggers and local processing."""
|
"""If request should only be handled by agent, ignoring sentence triggers and local processing."""
|
||||||
|
|
||||||
|
_streamed_response_text = False
|
||||||
|
"""If the conversation agent streamed response text to TTS result."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
self.language = self.pipeline.language or self.hass.config.language
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
@ -649,6 +654,11 @@ class PipelineRun:
|
|||||||
"token": self.tts_stream.token,
|
"token": self.tts_stream.token,
|
||||||
"url": self.tts_stream.url,
|
"url": self.tts_stream.url,
|
||||||
"mime_type": self.tts_stream.content_type,
|
"mime_type": self.tts_stream.content_type,
|
||||||
|
"stream_response": (
|
||||||
|
self.tts_stream.supports_streaming_input
|
||||||
|
and self.intent_agent
|
||||||
|
and self.intent_agent.supports_streaming
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||||
@ -896,12 +906,12 @@ class PipelineRun:
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
|
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
|
||||||
# Create a background task to prepare the conversation agent
|
# Create a background task to prepare the conversation agent
|
||||||
if self.end_stage >= PipelineStage.INTENT:
|
if self.end_stage >= PipelineStage.INTENT and self.intent_agent:
|
||||||
self.hass.async_create_background_task(
|
self.hass.async_create_background_task(
|
||||||
conversation.async_prepare_agent(
|
conversation.async_prepare_agent(
|
||||||
self.hass, self.intent_agent, self.language
|
self.hass, self.intent_agent.id, self.language
|
||||||
),
|
),
|
||||||
f"prepare conversation agent {self.intent_agent}",
|
f"prepare conversation agent {self.intent_agent.id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(self.stt_provider, stt.Provider):
|
if isinstance(self.stt_provider, stt.Provider):
|
||||||
@ -1042,7 +1052,7 @@ class PipelineRun:
|
|||||||
message=f"Intent recognition engine {engine} is not found",
|
message=f"Intent recognition engine {engine} is not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.intent_agent = agent_info.id
|
self.intent_agent = agent_info
|
||||||
|
|
||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self,
|
self,
|
||||||
@ -1075,7 +1085,7 @@ class PipelineRun:
|
|||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_START,
|
PipelineEventType.INTENT_START,
|
||||||
{
|
{
|
||||||
"engine": self.intent_agent,
|
"engine": self.intent_agent.id,
|
||||||
"language": input_language,
|
"language": input_language,
|
||||||
"intent_input": intent_input,
|
"intent_input": intent_input,
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
@ -1092,11 +1102,11 @@ class PipelineRun:
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
language=input_language,
|
language=input_language,
|
||||||
agent_id=self.intent_agent,
|
agent_id=self.intent_agent.id,
|
||||||
extra_system_prompt=conversation_extra_system_prompt,
|
extra_system_prompt=conversation_extra_system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_id = self.intent_agent
|
agent_id = self.intent_agent.id
|
||||||
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
|
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
|
||||||
intent_response: intent.IntentResponse | None = None
|
intent_response: intent.IntentResponse | None = None
|
||||||
if not processed_locally and not self._intent_agent_only:
|
if not processed_locally and not self._intent_agent_only:
|
||||||
@ -1118,7 +1128,7 @@ class PipelineRun:
|
|||||||
# If the LLM has API access, we filter out some sentences that are
|
# If the LLM has API access, we filter out some sentences that are
|
||||||
# interfering with LLM operation.
|
# interfering with LLM operation.
|
||||||
if (
|
if (
|
||||||
intent_agent_state := self.hass.states.get(self.intent_agent)
|
intent_agent_state := self.hass.states.get(self.intent_agent.id)
|
||||||
) and intent_agent_state.attributes.get(
|
) and intent_agent_state.attributes.get(
|
||||||
ATTR_SUPPORTED_FEATURES, 0
|
ATTR_SUPPORTED_FEATURES, 0
|
||||||
) & conversation.ConversationEntityFeature.CONTROL:
|
) & conversation.ConversationEntityFeature.CONTROL:
|
||||||
@ -1140,6 +1150,13 @@ class PipelineRun:
|
|||||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||||
processed_locally = True
|
processed_locally = True
|
||||||
|
|
||||||
|
if self.tts_stream and self.tts_stream.supports_streaming_input:
|
||||||
|
tts_input_stream: asyncio.Queue[str | None] | None = asyncio.Queue()
|
||||||
|
else:
|
||||||
|
tts_input_stream = None
|
||||||
|
chat_log_role = None
|
||||||
|
delta_character_count = 0
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def chat_log_delta_listener(
|
def chat_log_delta_listener(
|
||||||
chat_log: conversation.ChatLog, delta: dict
|
chat_log: conversation.ChatLog, delta: dict
|
||||||
@ -1153,6 +1170,42 @@ class PipelineRun:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if tts_input_stream is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
nonlocal chat_log_role
|
||||||
|
|
||||||
|
if role := delta.get("role"):
|
||||||
|
chat_log_role = role
|
||||||
|
|
||||||
|
# We are only interested in assistant deltas with content
|
||||||
|
if chat_log_role != "assistant" or not (
|
||||||
|
content := delta.get("content")
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
tts_input_stream.put_nowait(content)
|
||||||
|
|
||||||
|
if self._streamed_response_text:
|
||||||
|
return
|
||||||
|
|
||||||
|
nonlocal delta_character_count
|
||||||
|
|
||||||
|
delta_character_count += len(content)
|
||||||
|
if delta_character_count < STREAM_RESPONSE_CHARS:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Streamed responses are not cached. We only start streaming text after
|
||||||
|
# we have received a couple of words that indicates it will be a long response.
|
||||||
|
self._streamed_response_text = True
|
||||||
|
|
||||||
|
async def tts_input_stream_generator() -> AsyncGenerator[str]:
|
||||||
|
"""Yield TTS input stream."""
|
||||||
|
while (tts_input := await tts_input_stream.get()) is not None:
|
||||||
|
yield tts_input
|
||||||
|
|
||||||
|
assert self.tts_stream is not None
|
||||||
|
self.tts_stream.async_set_message_stream(tts_input_stream_generator())
|
||||||
|
|
||||||
with (
|
with (
|
||||||
chat_session.async_get_chat_session(
|
chat_session.async_get_chat_session(
|
||||||
@ -1196,6 +1249,8 @@ class PipelineRun:
|
|||||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||||
"speech", ""
|
"speech", ""
|
||||||
)
|
)
|
||||||
|
if tts_input_stream and self._streamed_response_text:
|
||||||
|
tts_input_stream.put_nowait(None)
|
||||||
|
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during intent recognition")
|
_LOGGER.exception("Unexpected error during intent recognition")
|
||||||
@ -1273,6 +1328,7 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self._streamed_response_text:
|
||||||
self.tts_stream.async_set_message(tts_input)
|
self.tts_stream.async_set_message(tts_input)
|
||||||
|
|
||||||
tts_output = {
|
tts_output = {
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -107,6 +108,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -206,6 +208,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -305,6 +308,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -428,6 +432,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_chat_log_tts_streaming[to_stream_tts0-1]
|
# name: test_chat_log_tts_streaming[to_stream_tts0-0-]
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
@ -8,6 +8,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': True,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -153,6 +154,225 @@
|
|||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_chat_log_tts_streaming[to_stream_tts1-16-hello, how are you? I'm doing well, thank you. What about you?]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'tts_output': dict({
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': True,
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'test-agent',
|
||||||
|
'intent_input': 'Set a timer',
|
||||||
|
'language': 'en',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'role': 'assistant',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'hello, ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'how ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'are ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'you',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '? ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': "I'm ",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'doing ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'well',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': ', ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'thank ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'you',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '. ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'What ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'about ',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': 'you',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'chat_log_delta': dict({
|
||||||
|
'content': '?',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_PROGRESS: 'intent-progress'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'continue_conversation': True,
|
||||||
|
'conversation_id': <ANY>,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': "hello, how are you? I'm doing well, thank you. What about you?",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'engine': 'tts.test',
|
||||||
|
'language': 'en_US',
|
||||||
|
'tts_input': "hello, how are you? I'm doing well, thank you. What about you?",
|
||||||
|
'voice': None,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
@ -321,6 +541,7 @@
|
|||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -101,6 +102,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -204,6 +206,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -295,6 +298,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -408,6 +412,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'test_token.mp3',
|
'token': 'test_token.mp3',
|
||||||
'url': '/api/tts_proxy/test_token.mp3',
|
'url': '/api/tts_proxy/test_token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -616,6 +621,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -670,6 +676,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -686,6 +693,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -702,6 +710,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -718,6 +727,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -734,6 +744,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -868,6 +879,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -884,6 +896,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -941,6 +954,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -957,6 +971,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -1017,6 +1032,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
@ -1033,6 +1049,7 @@
|
|||||||
}),
|
}),
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
|
'stream_response': False,
|
||||||
'token': 'mocked-token.mp3',
|
'token': 'mocked-token.mp3',
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
}),
|
}),
|
||||||
|
@ -1575,8 +1575,9 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("to_stream_tts", "expected_chunks"),
|
("to_stream_tts", "expected_chunks", "chunk_text"),
|
||||||
[
|
[
|
||||||
|
# Size below STREAM_RESPONSE_CHUNKS
|
||||||
(
|
(
|
||||||
[
|
[
|
||||||
"hello,",
|
"hello,",
|
||||||
@ -1588,7 +1589,33 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
|||||||
"you",
|
"you",
|
||||||
"?",
|
"?",
|
||||||
],
|
],
|
||||||
1,
|
# We are not streaming, so 0 chunks via streaming method
|
||||||
|
0,
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
# Size above STREAM_RESPONSE_CHUNKS
|
||||||
|
(
|
||||||
|
[
|
||||||
|
"hello, ",
|
||||||
|
"how ",
|
||||||
|
"are ",
|
||||||
|
"you",
|
||||||
|
"? ",
|
||||||
|
"I'm ",
|
||||||
|
"doing ",
|
||||||
|
"well",
|
||||||
|
", ",
|
||||||
|
"thank ",
|
||||||
|
"you",
|
||||||
|
". ",
|
||||||
|
"What ",
|
||||||
|
"about ",
|
||||||
|
"you",
|
||||||
|
"?",
|
||||||
|
],
|
||||||
|
# We are streamed, so equal to count above list items
|
||||||
|
16,
|
||||||
|
"hello, how are you? I'm doing well, thank you. What about you?",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1602,6 +1629,7 @@ async def test_chat_log_tts_streaming(
|
|||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
to_stream_tts: list[str],
|
to_stream_tts: list[str],
|
||||||
expected_chunks: int,
|
expected_chunks: int,
|
||||||
|
chunk_text: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that chat log events are streamed to the TTS entity."""
|
"""Test that chat log events are streamed to the TTS entity."""
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
@ -1627,22 +1655,41 @@ async def test_chat_log_tts_streaming(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
received_tts = []
|
||||||
|
|
||||||
|
async def async_stream_tts_audio(
|
||||||
|
request: tts.TTSAudioRequest,
|
||||||
|
) -> tts.TTSAudioResponse:
|
||||||
|
"""Mock stream TTS audio."""
|
||||||
|
|
||||||
|
async def gen_data():
|
||||||
|
async for msg in request.message_gen:
|
||||||
|
received_tts.append(msg)
|
||||||
|
yield msg.encode()
|
||||||
|
|
||||||
|
return tts.TTSAudioResponse(
|
||||||
|
extension="mp3",
|
||||||
|
data_gen=gen_data(),
|
||||||
|
)
|
||||||
|
|
||||||
async def async_get_tts_audio(
|
async def async_get_tts_audio(
|
||||||
message: str,
|
message: str,
|
||||||
language: str,
|
language: str,
|
||||||
options: dict[str, Any] | None = None,
|
options: dict[str, Any] | None = None,
|
||||||
) -> tts.TTSAudioResponse:
|
) -> tts.TtsAudioType:
|
||||||
"""Mock get TTS audio."""
|
"""Mock get TTS audio."""
|
||||||
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
||||||
|
|
||||||
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||||
|
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||||
|
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
return_value=conversation.AgentInfo(
|
return_value=conversation.AgentInfo(
|
||||||
id="test-agent",
|
id="test-agent",
|
||||||
name="Test Agent",
|
name="Test Agent",
|
||||||
supports_streaming=False,
|
supports_streaming=True,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
@ -1707,6 +1754,7 @@ async def test_chat_log_tts_streaming(
|
|||||||
|
|
||||||
streamed_text = "".join(to_stream_tts)
|
streamed_text = "".join(to_stream_tts)
|
||||||
assert tts_result == streamed_text
|
assert tts_result == streamed_text
|
||||||
assert expected_chunks == 1
|
assert len(received_tts) == expected_chunks
|
||||||
|
assert "".join(received_tts) == chunk_text
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user