From 5dab9ba01ba76f259cd39fe40dca65c649499ce5 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 30 Apr 2025 08:21:19 -0400 Subject: [PATCH] Allow streaming text into TTS ResultStream objects (#143745) Allow streaming messages into TTS ResultStream --- homeassistant/components/tts/__init__.py | 50 +++++++++++++++++++++++- tests/components/tts/test_init.py | 28 +++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 22c388cae9f..44badaa73d2 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -42,7 +42,7 @@ from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import async_call_later from homeassistant.helpers.network import get_url from homeassistant.helpers.typing import UNDEFINED, ConfigType -from homeassistant.util import language as language_util +from homeassistant.util import language as language_util, ulid as ulid_util from .const import ( ATTR_CACHE, @@ -495,6 +495,18 @@ class ResultStream: ) ) + @callback + def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None: + """Set a stream that will generate the message.""" + self._result_cache.set_result( + self._manager.async_cache_message_stream_in_memory( + engine=self.engine, + message_stream=message_stream, + language=self.language, + options=self.options, + ) + ) + async def async_stream_result(self) -> AsyncGenerator[bytes]: """Get the stream of this result.""" cache = await self._result_cache @@ -735,6 +747,42 @@ class SpeechManager: self.token_to_stream_cleanup.schedule() return result_stream + @callback + def async_cache_message_stream_in_memory( + self, + engine: str, + message_stream: AsyncGenerator[str], + language: str, + options: dict, + ) -> TTSCache: + """Make sure a message stream will be cached in memory and returns cache object. + + Requires options, language to be processed. + """ + if (engine_instance := get_engine_instance(self.hass, engine)) is None: + raise HomeAssistantError(f"Provider {engine} not found") + + cache_key = ulid_util.ulid_now() + extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT) + data_gen = self._async_generate_tts_audio( + engine_instance, message_stream, language, options + ) + + cache = TTSCache( + cache_key=cache_key, + extension=extension, + data_gen=data_gen, + ) + self.mem_cache[cache_key] = cache + self.hass.async_create_background_task( + self._load_data_into_cache( + cache, engine_instance, "[Streaming TTS]", False, language, options + ), + f"tts_load_data_into_cache_{engine_instance.name}", + ) + self.memcache_cleanup.schedule() + return cache + @callback def async_cache_message_in_memory( self, diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 99f4b008c68..45424be8481 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1842,6 +1842,7 @@ async def test_default_engine_prefer_cloud_entity( async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None: """Test creating streams.""" await mock_config_entry_setup(hass, mock_tts_entity) + stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) assert stream.language == mock_tts_entity.default_language assert stream.options == (mock_tts_entity.default_options or {}) @@ -1850,6 +1851,33 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) assert result_data == MOCK_DATA + async def async_stream_tts_audio( + request: tts.TTSAudioRequest, + ) -> tts.TTSAudioResponse: + """Mock stream TTS audio.""" + + async def gen_data(): + async for msg in request.message_gen: + yield msg.encode() + + return tts.TTSAudioResponse( + extension="mp3", + data_gen=gen_data(), + ) + + mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio + + async def stream_message(): + """Mock stream message.""" + yield "he" + yield "ll" + yield "o" + + stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) + stream.async_set_message_stream(stream_message()) + result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) + assert result_data == b"hello" + data = b"beer" stream2 = MockResultStream(hass, "wav", data) assert tts.async_get_stream(hass, stream2.token) is stream2