From 55895df54dba28802e1f0abc0953f37b18e09793 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen
Date: Thu, 13 Mar 2025 13:24:44 -0400
Subject: [PATCH] Switch more TTS core to async generators (#140432)

* Switch more TTS core to async generators

* Document a design choice

* robust

* Add more tests

* Update comment

* Clarify and document TTSCache variables
---
 homeassistant/components/tts/__init__.py | 369 ++++++++++++++---------
 tests/components/tts/test_init.py        | 109 ++++++-
 tests/components/wyoming/test_tts.py     |  18 +-
 3 files changed, 330 insertions(+), 166 deletions(-)

diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py
index 6fc25e32091..350b03a2e80 100644
--- a/homeassistant/components/tts/__init__.py
+++ b/homeassistant/components/tts/__init__.py
@@ -17,7 +17,7 @@ import secrets
 import subprocess
 import tempfile
 from time import monotonic
-from typing import Any, Final, TypedDict
+from typing import Any, Final
 
 from aiohttp import web
 import mutagen
@@ -123,13 +123,94 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
 SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
 
 
-class TTSCache(TypedDict):
-    """Cached TTS file."""
+class TTSCache:
+    """Cached bytes of a TTS result."""
 
-    extension: str
-    voice: bytes
-    pending: asyncio.Task | None
-    last_used: float
+    _result_data: bytes | None = None
+    """When fully loaded, contains the result data."""
+
+    _partial_data: list[bytes] | None = None
+    """While loading, contains the data already received from the generator."""
+
+    _loading_error: Exception | None = None
+    """If an error occurred while loading, contains the error."""
+
+    _consumers: list[asyncio.Queue[bytes | None]] | None = None
+    """A queue for each current consumer to notify of new data while the generator is loading."""
+
+    def __init__(
+        self,
+        cache_key: str,
+        extension: str,
+        data_gen: AsyncGenerator[bytes],
+    ) -> None:
+        """Initialize the TTS cache."""
+        self.cache_key = cache_key
+        self.extension = extension
+        self.last_used = monotonic()
+        self._data_gen = data_gen
+
+    async def async_load_data(self) -> bytes:
+        """Load the data from the generator."""
+        if self._result_data is not None or self._partial_data is not None:
+            raise RuntimeError("Data already being loaded")
+
+        self._partial_data = []
+        self._consumers = []
+
+        try:
+            async for chunk in self._data_gen:
+                self._partial_data.append(chunk)
+                for queue in self._consumers:
+                    queue.put_nowait(chunk)
+        except Exception as err:  # pylint: disable=broad-except
+            self._loading_error = err
+            raise
+        finally:
+            for queue in self._consumers:
+                queue.put_nowait(None)
+            self._consumers = None
+
+        self._result_data = b"".join(self._partial_data)
+        self._partial_data = None
+        return self._result_data
+
+    async def async_stream_data(self) -> AsyncGenerator[bytes]:
+        """Stream the data.
+
+        Will return all data already returned from the generator.
+        Will listen for future data returned from the generator.
+        Raises error if one occurred.
+        """
+        if self._result_data is not None:
+            yield self._result_data
+            return
+        if self._loading_error:
+            raise self._loading_error
+
+        if self._partial_data is None:
+            raise RuntimeError("Data not being loaded")
+
+        queue: asyncio.Queue[bytes | None] | None = None
+        # Check if generator is still feeding data
+        if self._consumers is not None:
+            queue = asyncio.Queue()
+            self._consumers.append(queue)
+
+        for chunk in list(self._partial_data):
+            yield chunk
+
+        if self._loading_error:
+            raise self._loading_error
+
+        if queue is not None:
+            while (chunk2 := await queue.get()) is not None:
+                yield chunk2
+
+        if self._loading_error:
+            raise self._loading_error
+
+        self.last_used = monotonic()
 
 
 @callback
@@ -194,10 +275,11 @@ async def async_get_media_source_audio(
 ) -> tuple[str, bytes]:
     """Get TTS audio as extension, data."""
     manager = hass.data[DATA_TTS_MANAGER]
-    cache_key = manager.async_cache_message_in_memory(
+    cache = manager.async_cache_message_in_memory(
         **media_source_id_to_kwargs(media_source_id)
     )
-    return await manager.async_get_tts_audio(cache_key)
+    data = b"".join([chunk async for chunk in cache.async_stream_data()])
+    return cache.extension, data
 
 
 @callback
@@ -216,18 +298,19 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
     return languages
 
 
-async def async_convert_audio(
+async def _async_convert_audio(
     hass: HomeAssistant,
     from_extension: str,
-    audio_bytes: bytes,
+    audio_bytes_gen: AsyncGenerator[bytes],
     to_extension: str,
     to_sample_rate: int | None = None,
     to_sample_channels: int | None = None,
     to_sample_bytes: int | None = None,
-) -> bytes:
+) -> AsyncGenerator[bytes]:
     """Convert audio to a preferred format using ffmpeg."""
     ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
-    return await hass.async_add_executor_job(
+    audio_bytes = b"".join([chunk async for chunk in audio_bytes_gen])
+    data = await hass.async_add_executor_job(
         lambda: _convert_audio(
             ffmpeg_manager.binary,
             from_extension,
@@ -238,6 +321,7 @@ async def async_convert_audio(
             to_sample_bytes=to_sample_bytes,
         )
     )
+    yield data
 
 
 def _convert_audio(
@@ -401,32 +485,33 @@ class ResultStream:
         return f"/api/tts_proxy/{self.token}"
 
     @cached_property
-    def _result_cache_key(self) -> asyncio.Future[str]:
-        """Get the future that returns the cache key."""
+    def _result_cache(self) -> asyncio.Future[TTSCache]:
+        """Get the future that returns the cache."""
         return asyncio.Future()
 
     @callback
-    def async_set_message_cache_key(self, cache_key: str) -> None:
-        """Set cache key for message to be streamed."""
-        self._result_cache_key.set_result(cache_key)
+    def async_set_message_cache(self, cache: TTSCache) -> None:
+        """Set cache containing message audio to be streamed."""
+        self._result_cache.set_result(cache)
 
     @callback
     def async_set_message(self, message: str) -> None:
         """Set message to be generated."""
-        cache_key = self._manager.async_cache_message_in_memory(
-            engine=self.engine,
-            message=message,
-            use_file_cache=self.use_file_cache,
-            language=self.language,
-            options=self.options,
+        self._result_cache.set_result(
+            self._manager.async_cache_message_in_memory(
+                engine=self.engine,
+                message=message,
+                use_file_cache=self.use_file_cache,
+                language=self.language,
+                options=self.options,
+            )
         )
-        self._result_cache_key.set_result(cache_key)
 
     async def async_stream_result(self) -> AsyncGenerator[bytes]:
         """Get the stream of this result."""
-        cache_key = await self._result_cache_key
-        _extension, data = await self._manager.async_get_tts_audio(cache_key)
-        yield data
+        cache = await self._result_cache
+        async for chunk in cache.async_stream_data():
+            yield chunk
 
 
 def _hash_options(options: dict) -> str:
@@ -483,7 +568,7 @@ class MemcacheCleanup:
         now = monotonic()
 
         for cache_key, info in list(memcache.items()):
-            if info["last_used"] + maxage < now:
+            if info.last_used + maxage < now:
                 _LOGGER.debug("Cleaning up %s", cache_key)
                 del memcache[cache_key]
 
@@ -638,15 +723,18 @@ class SpeechManager:
         if message is None:
             return result_stream
 
-        cache_key = self._async_ensure_cached_in_memory(
-            engine=engine,
-            engine_instance=engine_instance,
-            message=message,
-            use_file_cache=use_file_cache,
-            language=language,
-            options=options,
+        # We added this method as an alternative to stream.async_set_message
+        # to avoid the options being processed twice
+        result_stream.async_set_message_cache(
+            self._async_ensure_cached_in_memory(
+                engine=engine,
+                engine_instance=engine_instance,
+                message=message,
+                use_file_cache=use_file_cache,
+                language=language,
+                options=options,
+            )
         )
-        result_stream.async_set_message_cache_key(cache_key)
 
         return result_stream
 
@@ -658,7 +746,7 @@ class SpeechManager:
         use_file_cache: bool | None = None,
         language: str | None = None,
        options: dict | None = None,
-    ) -> str:
+    ) -> TTSCache:
         """Make sure a message is cached in memory and returns cache key."""
         if (engine_instance := get_engine_instance(self.hass, engine)) is None:
             raise HomeAssistantError(f"Provider {engine} not found")
@@ -685,7 +773,7 @@ class SpeechManager:
         use_file_cache: bool,
         language: str,
         options: dict,
-    ) -> str:
+    ) -> TTSCache:
         """Ensure a message is cached.
 
         Requires options, language to be processed.
@@ -697,62 +785,101 @@ class SpeechManager:
         ).lower()
 
         # Is speech already in memory
-        if cache_key in self.mem_cache:
-            return cache_key
+        if cache := self.mem_cache.get(cache_key):
+            _LOGGER.debug("Found audio in cache for %s", message[0:32])
+            return cache
 
-        if use_file_cache and cache_key in self.file_cache:
-            coro = self._async_load_file_to_mem(cache_key)
+        store_to_disk = use_file_cache
+
+        if use_file_cache and (filename := self.file_cache.get(cache_key)):
+            _LOGGER.debug("Loading audio from disk for %s", message[0:32])
+            extension = os.path.splitext(filename)[1][1:]
+            data_gen = self._async_load_file(cache_key)
+            store_to_disk = False
         else:
-            coro = self._async_generate_tts_audio(
-                engine_instance, cache_key, message, use_file_cache, language, options
+            _LOGGER.debug("Generating audio for %s", message[0:32])
+            extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
+            data_gen = self._async_generate_tts_audio(
+                engine_instance, message, language, options
             )
 
-        task = self.hass.async_create_task(coro, eager_start=False)
+        cache = TTSCache(
+            cache_key=cache_key,
+            extension=extension,
+            data_gen=data_gen,
+        )
 
-        def handle_error(future: asyncio.Future) -> None:
-            """Handle error."""
-            if not (err := future.exception()):
-                return
+        self.mem_cache[cache_key] = cache
+        self.hass.async_create_background_task(
+            self._load_data_into_cache(
+                cache, engine_instance, message, store_to_disk, language, options
+            ),
+            f"tts_load_data_into_cache_{engine_instance.name}",
+        )
+        self.memcache_cleanup.schedule()
+        return cache
+
+    async def _load_data_into_cache(
+        self,
+        cache: TTSCache,
+        engine_instance: TextToSpeechEntity | Provider,
+        message: str,
+        store_to_disk: bool,
+        language: str,
+        options: dict,
+    ) -> None:
+        """Load and process a finished loading TTS Cache."""
+        try:
+            data = await cache.async_load_data()
+        except Exception as err:  # pylint: disable=broad-except  # noqa: BLE001
             # Truncate message so we don't flood the logs. Cutting off at 32 chars
             # but since we add 3 dots to truncated message, we cut off at 35.
             trunc_msg = message if len(message) < 35 else f"{message[0:32]}…"
-            _LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
-            self.mem_cache.pop(cache_key, None)
+            _LOGGER.error("Error getting audio for %s: %s", trunc_msg, err)
+            self.mem_cache.pop(cache.cache_key, None)
+            return
 
-        task.add_done_callback(handle_error)
+        if not store_to_disk:
+            return
 
-        self.mem_cache[cache_key] = {
-            "extension": "",
-            "voice": b"",
-            "pending": task,
-            "last_used": monotonic(),
-        }
-        return cache_key
+        filename = f"{cache.cache_key}.{cache.extension}".lower()
 
-    async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
-        """Fetch TTS audio."""
-        cached = self.mem_cache.get(cache_key)
-        if cached is None:
-            raise HomeAssistantError("Audio not cached")
-        if pending := cached.get("pending"):
-            await pending
-            cached = self.mem_cache[cache_key]
-        cached["last_used"] = monotonic()
-        return cached["extension"], cached["voice"]
+        # Validate filename
+        if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
+            filename
+        ):
+            raise HomeAssistantError(
+                f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
+            )
+
+        if cache.extension == "mp3":
+            name = (
+                engine_instance.name if isinstance(engine_instance.name, str) else "-"
+            )
+            data = self.write_tags(filename, data, name, message, language, options)
+
+        voice_file = os.path.join(self.cache_dir, filename)
+
+        def save_speech() -> None:
+            """Store speech to filesystem."""
+            with open(voice_file, "wb") as speech:
+                speech.write(data)
+
+        try:
+            await self.hass.async_add_executor_job(save_speech)
+        except OSError as err:
+            _LOGGER.error("Can't write %s: %s", filename, err)
+        else:
+            self.file_cache[cache.cache_key] = filename
 
     async def _async_generate_tts_audio(
         self,
         engine_instance: TextToSpeechEntity | Provider,
-        cache_key: str,
         message: str,
-        cache_to_disk: bool,
         language: str,
         options: dict[str, Any],
-    ) -> None:
-        """Start loading of the TTS audio.
-
-        This method is a coroutine.
-        """
+    ) -> AsyncGenerator[bytes]:
+        """Generate TTS audio from an engine."""
         options = dict(options or {})
         supported_options = engine_instance.supported_options or []
 
@@ -800,6 +927,17 @@ class SpeechManager:
             extension, data = await engine_instance.async_get_tts_audio(
                 message, language, options
             )
+
+            if data is None or extension is None:
+                raise HomeAssistantError(
+                    f"No TTS from {engine_instance.name} for '{message}'"
+                )
+
+            async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
+                yield data
+
+            data_gen = make_data_generator(data)
+
         else:
 
             async def message_gen() -> AsyncGenerator[str]:
@@ -809,12 +947,7 @@ class SpeechManager:
                 TTSAudioRequest(language, options, message_gen())
             )
             extension = tts_result.extension
-            data = b"".join([chunk async for chunk in tts_result.data_gen])
-
-            if data is None or extension is None:
-                raise HomeAssistantError(
-                    f"No TTS from {engine_instance.name} for '{message}'"
-                )
+            data_gen = tts_result.data_gen
 
         # Only convert if we have a preferred format different than the
         # expected format from the TTS system, or if a specific sample
@@ -827,62 +960,21 @@ class SpeechManager:
         )
 
         if needs_conversion:
-            data = await async_convert_audio(
+            data_gen = _async_convert_audio(
                 self.hass,
                 extension,
-                data,
+                data_gen,
                 to_extension=final_extension,
                 to_sample_rate=sample_rate,
                 to_sample_channels=sample_channels,
                 to_sample_bytes=sample_bytes,
             )
 
-        # Create file infos
-        filename = f"{cache_key}.{final_extension}".lower()
+        async for chunk in data_gen:
+            yield chunk
 
-        # Validate filename
-        if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
-            filename
-        ):
-            raise HomeAssistantError(
-                f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
-            )
-
-        # Save to memory
-        if final_extension == "mp3":
-            data = self.write_tags(
-                filename, data, engine_instance.name, message, language, options
-            )
-
-        self._async_store_to_memcache(cache_key, final_extension, data)
-
-        if not cache_to_disk:
-            return
-
-        voice_file = os.path.join(self.cache_dir, filename)
-
-        def save_speech() -> None:
-            """Store speech to filesystem."""
-            with open(voice_file, "wb") as speech:
-                speech.write(data)
-
-        # Don't await, we're going to do this in the background
-        task = self.hass.async_add_executor_job(save_speech)
-
-        def write_done(future: asyncio.Future) -> None:
-            """Write is done task."""
-            if err := future.exception():
-                _LOGGER.error("Can't write %s: %s", filename, err)
-            else:
-                self.file_cache[cache_key] = filename
-
-        task.add_done_callback(write_done)
-
-    async def _async_load_file_to_mem(self, cache_key: str) -> None:
-        """Load voice from file cache into memory.
-
-        This method is a coroutine.
-        """
+    async def _async_load_file(self, cache_key: str) -> AsyncGenerator[bytes]:
+        """Load TTS audio from disk."""
         if not (filename := self.file_cache.get(cache_key)):
             raise HomeAssistantError(f"Key {cache_key} not in file cache!")
 
@@ -899,22 +991,7 @@ class SpeechManager:
             del self.file_cache[cache_key]
             raise HomeAssistantError(f"Can't read {voice_file}") from err
 
-        extension = os.path.splitext(filename)[1][1:]
-
-        self._async_store_to_memcache(cache_key, extension, data)
-
-    @callback
-    def _async_store_to_memcache(
-        self, cache_key: str, extension: str, data: bytes
-    ) -> None:
-        """Store data to memcache and set timer to remove it."""
-        self.mem_cache[cache_key] = {
-            "extension": extension,
-            "voice": data,
-            "pending": None,
-            "last_used": monotonic(),
-        }
-        self.memcache_cleanup.schedule()
+        yield data
 
     @staticmethod
     def write_tags(
diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py
index 8bdd17cf3e9..be14e006610 100644
--- a/tests/components/tts/test_init.py
+++ b/tests/components/tts/test_init.py
@@ -168,7 +168,7 @@ async def test_service(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@@ -230,7 +230,7 @@ async def test_service_default_language(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -294,7 +294,7 @@ async def test_service_default_special_language(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@@ -354,7 +354,7 @@ async def test_service_language(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
@@ -470,7 +470,7 @@ async def test_service_options(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -554,7 +554,7 @@ async def test_service_default_options(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -628,7 +628,7 @@ async def test_merge_default_service_options(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -743,7 +743,7 @@ async def test_service_clear_cache(
     # To make sure the file is persisted
     assert len(calls) == 1
     await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@@ -1769,9 +1769,15 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
     """Test that ffmpeg failing during audio conversion will raise an error."""
     assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
 
-    with pytest.raises(RuntimeError):
+    async def bad_data_gen():
+        yield bytes(0)
+
+    with pytest.raises(RuntimeError):  # noqa: PT012
         # Simulate a bad WAV file
-        await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
+        async for _chunk in tts._async_convert_audio(
+            hass, "wav", bad_data_gen(), "mp3"
+        ):
+            pass
 
 
 async def test_default_engine_prefer_entity(
@@ -1846,3 +1852,86 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
     assert stream2.extension == "wav"
     result_data = b"".join([chunk async for chunk in stream2.async_stream_result()])
     assert result_data == data
+
+
+async def test_tts_cache() -> None:
+    """Test TTSCache."""
+
+    async def data_gen(queue: asyncio.Queue[bytes | None | Exception]):
+        while chunk := await queue.get():
+            if isinstance(chunk, Exception):
+                raise chunk
+            yield chunk
+
+    queue = asyncio.Queue()
+    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
+    assert cache.cache_key == "test-key"
+    assert cache.extension == "mp3"
+
+    for i in range(10):
+        queue.put_nowait(f"{i}".encode())
+    queue.put_nowait(None)
+
+    assert await cache.async_load_data() == b"0123456789"
+
+    with pytest.raises(RuntimeError):
+        await cache.async_load_data()
+
+    # When data is loaded, we get it all in 1 chunk
+    cur = 0
+    async for chunk in cache.async_stream_data():
+        assert chunk == b"0123456789"
+        cur += 1
+    assert cur == 1
+
+    # Show we can stream the data while it's still being generated
+    async def consume_cache(cache: tts.TTSCache):
+        return b"".join([chunk async for chunk in cache.async_stream_data()])
+
+    queue = asyncio.Queue()
+    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
+
+    load_data_task = asyncio.create_task(cache.async_load_data())
+    consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(b"0")
+    await asyncio.sleep(0)
+    queue.put_nowait(b"1")
+    await asyncio.sleep(0)
+    consume_mid_data_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(b"2")
+    await asyncio.sleep(0)
+    queue.put_nowait(None)
+    consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    await asyncio.sleep(0)
+    assert await load_data_task == b"012"
+    assert await consume_post_data_loaded_task == b"012"
+    assert await consume_mid_data_task == b"012"
+    assert await consume_pre_data_loaded_task == b"012"
+
+    # Now with errors
+    async def consume_cache(cache: tts.TTSCache):
+        return b"".join([chunk async for chunk in cache.async_stream_data()])
+
+    queue = asyncio.Queue()
+    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
+
+    load_data_task = asyncio.create_task(cache.async_load_data())
+    consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(b"0")
+    await asyncio.sleep(0)
+    queue.put_nowait(b"1")
+    await asyncio.sleep(0)
+    consume_mid_data_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(ValueError("Boom!"))
+    await asyncio.sleep(0)
+    queue.put_nowait(None)
+    consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    await asyncio.sleep(0)
+    with pytest.raises(ValueError):
+        assert await load_data_task == b"012"
+    with pytest.raises(ValueError):
+        assert await consume_post_data_loaded_task == b"012"
+    with pytest.raises(ValueError):
+        assert await consume_mid_data_task == b"012"
+    with pytest.raises(ValueError):
+        assert await consume_pre_data_loaded_task == b"012"
diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py
index 263804787b1..73fb68b44e5 100644
--- a/tests/components/wyoming/test_tts.py
+++ b/tests/components/wyoming/test_tts.py
@@ -150,17 +150,15 @@ async def test_get_tts_audio_connection_lost(
     hass: HomeAssistant, init_wyoming_tts
 ) -> None:
     """Test streaming audio and losing connection."""
-    with (
-        patch(
-            "homeassistant.components.wyoming.tts.AsyncTcpClient",
-            MockAsyncTcpClient([None]),
-        ),
-        pytest.raises(HomeAssistantError),
+    stream = tts.async_create_stream(hass, "tts.test_tts", "en-US")
+    with patch(
+        "homeassistant.components.wyoming.tts.AsyncTcpClient",
+        MockAsyncTcpClient([None]),
     ):
-        await tts.async_get_media_source_audio(
-            hass,
-            tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
-        )
+        stream.async_set_message("Hello world")
+    with pytest.raises(HomeAssistantError):  # noqa: PT012
+        async for _chunk in stream.async_stream_result():
+            pass
 
 
 async def test_get_tts_audio_audio_oserror(
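
Illustrative usage sketch (not part of the patch): how the TTSCache introduced above buffers a TTS result while a consumer streams it before loading has finished, mirroring what the new test_tts_cache test exercises. The class and its async_load_data / async_stream_data methods come from the patch; the cache key, chunk values, and the demo() wrapper are made up for the example, and it assumes a Home Assistant development environment so the import resolves.

import asyncio

from homeassistant.components.tts import TTSCache


async def demo() -> None:
    async def data_gen():
        # Simulated TTS engine output arriving in chunks.
        for chunk in (b"chunk-1", b"chunk-2", b"chunk-3"):
            await asyncio.sleep(0)
            yield chunk

    cache = TTSCache(cache_key="demo-key", extension="mp3", data_gen=data_gen())

    # Fill the cache in the background; in the patch, SpeechManager does this
    # via hass.async_create_background_task(self._load_data_into_cache(...)).
    load_task = asyncio.create_task(cache.async_load_data())
    await asyncio.sleep(0)  # let async_load_data start before streaming

    # A consumer first replays the chunks buffered so far, then waits on its
    # per-consumer queue for the remaining chunks until the generator finishes.
    streamed = b"".join([chunk async for chunk in cache.async_stream_data()])
    full = await load_task
    assert streamed == full == b"chunk-1chunk-2chunk-3"


asyncio.run(demo())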