From e78f4d2a29db09949ee5bca39bc537de048831d3 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 19 May 2025 14:54:21 -0400 Subject: [PATCH] TTS to only use stream entity method when streaming request comes in (#145167) Co-authored-by: Franck Nijhof --- homeassistant/components/tts/__init__.py | 18 ++++++------- homeassistant/components/tts/entity.py | 12 +++++++++ homeassistant/components/tts/legacy.py | 14 +++++++++- .../assist_pipeline/test_pipeline.py | 27 ++++++------------- 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index da8a0f2324e..8292df07ef8 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -852,12 +852,9 @@ class SpeechManager: else: _LOGGER.debug("Generating audio for %s", message[0:32]) - async def message_stream() -> AsyncGenerator[str]: - yield message - extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT) data_gen = self._async_generate_tts_audio( - engine_instance, message_stream(), language, options + engine_instance, message, language, options ) cache = TTSCache( @@ -931,7 +928,7 @@ class SpeechManager: async def _async_generate_tts_audio( self, engine_instance: TextToSpeechEntity | Provider, - message_stream: AsyncGenerator[str], + message_or_stream: str | AsyncGenerator[str], language: str, options: dict[str, Any], ) -> AsyncGenerator[bytes]: @@ -979,9 +976,12 @@ class SpeechManager: if engine_instance.name is None or engine_instance.name is UNDEFINED: raise HomeAssistantError("TTS engine name is not set.") - if isinstance(engine_instance, Provider): - message = "".join([chunk async for chunk in message_stream]) - extension, data = await engine_instance.async_get_tts_audio( + if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str): + if isinstance(message_or_stream, str): + message = message_or_stream + else: + message = "".join([chunk async for chunk in message_or_stream]) + extension, data = await engine_instance.async_internal_get_tts_audio( message, language, options ) @@ -997,7 +997,7 @@ class SpeechManager: else: tts_result = await engine_instance.internal_async_stream_tts_audio( - TTSAudioRequest(language, options, message_stream) + TTSAudioRequest(language, options, message_or_stream) ) extension = tts_result.extension data_gen = tts_result.data_gen diff --git a/homeassistant/components/tts/entity.py b/homeassistant/components/tts/entity.py index 1f01a41c5ab..2c3fd446d2f 100644 --- a/homeassistant/components/tts/entity.py +++ b/homeassistant/components/tts/entity.py @@ -165,6 +165,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH self.async_write_ha_state() return await self.async_stream_tts_audio(request) + @final + async def async_internal_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Load tts audio file from the engine and update state. + + Return a tuple of file extension and data as bytes. + """ + self.__last_tts_loaded = dt_util.utcnow().isoformat() + self.async_write_ha_state() + return await self.async_get_tts_audio(message, language, options=options) + async def async_stream_tts_audio( self, request: TTSAudioRequest ) -> TTSAudioResponse: diff --git a/homeassistant/components/tts/legacy.py b/homeassistant/components/tts/legacy.py index 877ecc034d6..c3d7eb6fdd6 100644 --- a/homeassistant/components/tts/legacy.py +++ b/homeassistant/components/tts/legacy.py @@ -7,7 +7,7 @@ from collections.abc import Coroutine, Mapping from functools import partial import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, final import voluptuous as vol @@ -252,3 +252,15 @@ class Provider: return await self.hass.async_add_executor_job( partial(self.get_tts_audio, message, language, options=options) ) + + @final + async def async_internal_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Load tts audio file from provider. + + Proxies request to mimic the entity interface. + + Return a tuple of file extension and data as bytes. + """ + return await self.async_get_tts_audio(message, language, options) diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index f4e7c886d40..1714c909a18 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1627,25 +1627,15 @@ async def test_chat_log_tts_streaming( ), ) - received_tts = [] - - async def async_stream_tts_audio( - request: tts.TTSAudioRequest, + async def async_get_tts_audio( + message: str, + language: str, + options: dict[str, Any] | None = None, ) -> tts.TTSAudioResponse: - """Mock stream TTS audio.""" + """Mock get TTS audio.""" + return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts])) - async def gen_data(): - async for msg in request.message_gen: - received_tts.append(msg) - yield msg.encode() - - return tts.TTSAudioResponse( - extension="mp3", - data_gen=gen_data(), - ) - - mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio - mock_tts_entity.async_supports_streaming_input = Mock(return_value=True) + mock_tts_entity.async_get_tts_audio = async_get_tts_audio with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", @@ -1717,7 +1707,6 @@ async def test_chat_log_tts_streaming( streamed_text = "".join(to_stream_tts) assert tts_result == streamed_text - assert len(received_tts) == expected_chunks - assert "".join(received_tts) == streamed_text + assert expected_chunks == 1 assert process_events(events) == snapshot