TTS to only use stream entity method when streaming request comes in (#145167)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2025-05-19 14:54:21 -04:00 committed by GitHub
parent 1f6faaacab
commit e78f4d2a29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 29 deletions

View File

@ -852,12 +852,9 @@ class SpeechManager:
else:
_LOGGER.debug("Generating audio for %s", message[0:32])
async def message_stream() -> AsyncGenerator[str]:
yield message
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
data_gen = self._async_generate_tts_audio(
engine_instance, message_stream(), language, options
engine_instance, message, language, options
)
cache = TTSCache(
@ -931,7 +928,7 @@ class SpeechManager:
async def _async_generate_tts_audio(
self,
engine_instance: TextToSpeechEntity | Provider,
message_stream: AsyncGenerator[str],
message_or_stream: str | AsyncGenerator[str],
language: str,
options: dict[str, Any],
) -> AsyncGenerator[bytes]:
@ -979,9 +976,12 @@ class SpeechManager:
if engine_instance.name is None or engine_instance.name is UNDEFINED:
raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider):
message = "".join([chunk async for chunk in message_stream])
extension, data = await engine_instance.async_get_tts_audio(
if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
if isinstance(message_or_stream, str):
message = message_or_stream
else:
message = "".join([chunk async for chunk in message_or_stream])
extension, data = await engine_instance.async_internal_get_tts_audio(
message, language, options
)
@ -997,7 +997,7 @@ class SpeechManager:
else:
tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_stream)
TTSAudioRequest(language, options, message_or_stream)
)
extension = tts_result.extension
data_gen = tts_result.data_gen

View File

@ -165,6 +165,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
self.async_write_ha_state()
return await self.async_stream_tts_audio(request)
@final
async def async_internal_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine and update state.
Return a tuple of file extension and data as bytes.
"""
self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_get_tts_audio(message, language, options=options)
async def async_stream_tts_audio(
self, request: TTSAudioRequest
) -> TTSAudioResponse:

View File

@ -7,7 +7,7 @@ from collections.abc import Coroutine, Mapping
from functools import partial
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, final
import voluptuous as vol
@ -252,3 +252,15 @@ class Provider:
return await self.hass.async_add_executor_job(
partial(self.get_tts_audio, message, language, options=options)
)
@final
async def async_internal_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from provider.
Proxies request to mimic the entity interface.
Return a tuple of file extension and data as bytes.
"""
return await self.async_get_tts_audio(message, language, options)

View File

@ -1627,25 +1627,15 @@ async def test_chat_log_tts_streaming(
),
)
received_tts = []
async def async_stream_tts_audio(
request: tts.TTSAudioRequest,
async def async_get_tts_audio(
message: str,
language: str,
options: dict[str, Any] | None = None,
) -> tts.TTSAudioResponse:
"""Mock stream TTS audio."""
"""Mock get TTS audio."""
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
async def gen_data():
async for msg in request.message_gen:
received_tts.append(msg)
yield msg.encode()
return tts.TTSAudioResponse(
extension="mp3",
data_gen=gen_data(),
)
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
@ -1717,7 +1707,6 @@ async def test_chat_log_tts_streaming(
streamed_text = "".join(to_stream_tts)
assert tts_result == streamed_text
assert len(received_tts) == expected_chunks
assert "".join(received_tts) == streamed_text
assert expected_chunks == 1
assert process_events(events) == snapshot