TTS to only use stream entity method when streaming request comes in (#145167)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2025-05-19 14:54:21 -04:00 committed by GitHub
parent 1f6faaacab
commit e78f4d2a29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 29 deletions

View File

@ -852,12 +852,9 @@ class SpeechManager:
else: else:
_LOGGER.debug("Generating audio for %s", message[0:32]) _LOGGER.debug("Generating audio for %s", message[0:32])
async def message_stream() -> AsyncGenerator[str]:
yield message
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT) extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
data_gen = self._async_generate_tts_audio( data_gen = self._async_generate_tts_audio(
engine_instance, message_stream(), language, options engine_instance, message, language, options
) )
cache = TTSCache( cache = TTSCache(
@ -931,7 +928,7 @@ class SpeechManager:
async def _async_generate_tts_audio( async def _async_generate_tts_audio(
self, self,
engine_instance: TextToSpeechEntity | Provider, engine_instance: TextToSpeechEntity | Provider,
message_stream: AsyncGenerator[str], message_or_stream: str | AsyncGenerator[str],
language: str, language: str,
options: dict[str, Any], options: dict[str, Any],
) -> AsyncGenerator[bytes]: ) -> AsyncGenerator[bytes]:
@ -979,9 +976,12 @@ class SpeechManager:
if engine_instance.name is None or engine_instance.name is UNDEFINED: if engine_instance.name is None or engine_instance.name is UNDEFINED:
raise HomeAssistantError("TTS engine name is not set.") raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider): if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
message = "".join([chunk async for chunk in message_stream]) if isinstance(message_or_stream, str):
extension, data = await engine_instance.async_get_tts_audio( message = message_or_stream
else:
message = "".join([chunk async for chunk in message_or_stream])
extension, data = await engine_instance.async_internal_get_tts_audio(
message, language, options message, language, options
) )
@ -997,7 +997,7 @@ class SpeechManager:
else: else:
tts_result = await engine_instance.internal_async_stream_tts_audio( tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_stream) TTSAudioRequest(language, options, message_or_stream)
) )
extension = tts_result.extension extension = tts_result.extension
data_gen = tts_result.data_gen data_gen = tts_result.data_gen

View File

@ -165,6 +165,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
self.async_write_ha_state() self.async_write_ha_state()
return await self.async_stream_tts_audio(request) return await self.async_stream_tts_audio(request)
@final
async def async_internal_get_tts_audio(
    self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
    """Load tts audio file from the engine and update state.

    Stores the ISO-formatted ``dt_util.utcnow()`` timestamp as the
    last-loaded marker and writes the entity state *before* delegating to
    ``async_get_tts_audio``, matching the order used by the streaming
    counterpart.

    Return a tuple of file extension and data as bytes.
    """
    # State is updated first so the entity reflects the request even while
    # the (possibly slow) engine call is still in flight.
    self.__last_tts_loaded = dt_util.utcnow().isoformat()
    self.async_write_ha_state()
    return await self.async_get_tts_audio(message, language, options=options)
async def async_stream_tts_audio( async def async_stream_tts_audio(
self, request: TTSAudioRequest self, request: TTSAudioRequest
) -> TTSAudioResponse: ) -> TTSAudioResponse:

View File

@ -7,7 +7,7 @@ from collections.abc import Coroutine, Mapping
from functools import partial from functools import partial
import logging import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, final
import voluptuous as vol import voluptuous as vol
@ -252,3 +252,15 @@ class Provider:
return await self.hass.async_add_executor_job( return await self.hass.async_add_executor_job(
partial(self.get_tts_audio, message, language, options=options) partial(self.get_tts_audio, message, language, options=options)
) )
@final
async def async_internal_get_tts_audio(
    self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
    """Load tts audio file from provider.

    Proxies request to mimic the entity interface, so callers can invoke
    the same method name on a provider as on an entity. The arguments are
    forwarded unchanged to ``async_get_tts_audio``.

    Return a tuple of file extension and data as bytes.
    """
    return await self.async_get_tts_audio(message, language, options)

View File

@ -1627,25 +1627,15 @@ async def test_chat_log_tts_streaming(
), ),
) )
received_tts = [] async def async_get_tts_audio(
message: str,
async def async_stream_tts_audio( language: str,
request: tts.TTSAudioRequest, options: dict[str, Any] | None = None,
) -> tts.TTSAudioResponse: ) -> tts.TTSAudioResponse:
"""Mock stream TTS audio.""" """Mock get TTS audio."""
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
async def gen_data(): mock_tts_entity.async_get_tts_audio = async_get_tts_audio
async for msg in request.message_gen:
received_tts.append(msg)
yield msg.encode()
return tts.TTSAudioResponse(
extension="mp3",
data_gen=gen_data(),
)
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
with patch( with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
@ -1717,7 +1707,6 @@ async def test_chat_log_tts_streaming(
streamed_text = "".join(to_stream_tts) streamed_text = "".join(to_stream_tts)
assert tts_result == streamed_text assert tts_result == streamed_text
assert len(received_tts) == expected_chunks assert expected_chunks == 1
assert "".join(received_tts) == streamed_text
assert process_events(events) == snapshot assert process_events(events) == snapshot