Switch more TTS core to async generators (#140432)

* Switch more TTS core to async generators

* Document a design choice

* robust

* Add more tests

* Update comment

* Clarify and document TTSCache variables
This commit is contained in:
Paulus Schoutsen 2025-03-13 13:24:44 -04:00 committed by GitHub
parent b07ac301b9
commit 55895df54d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 330 additions and 166 deletions

View File

@ -17,7 +17,7 @@ import secrets
import subprocess
import tempfile
from time import monotonic
from typing import Any, Final, TypedDict
from typing import Any, Final
from aiohttp import web
import mutagen
@ -123,13 +123,94 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
class TTSCache(TypedDict):
"""Cached TTS file."""
class TTSCache:
"""Cached bytes of a TTS result."""
extension: str
voice: bytes
pending: asyncio.Task | None
last_used: float
_result_data: bytes | None = None
"""When fully loaded, contains the result data."""
_partial_data: list[bytes] | None = None
"""While loading, contains the data already received from the generator."""
_loading_error: Exception | None = None
"""If an error occurred while loading, contains the error."""
_consumers: list[asyncio.Queue[bytes | None]] | None = None
"""A queue for each current consumer to notify of new data while the generator is loading."""
def __init__(
    self, cache_key: str, extension: str, data_gen: AsyncGenerator[bytes]
) -> None:
    """Set up a cache entry for a single TTS result.

    cache_key identifies the entry, extension is the audio file extension
    and data_gen is the async generator that will produce the audio bytes.
    """
    # The generator is only stored here; it is iterated by async_load_data.
    self._data_gen = data_gen
    # A freshly created entry counts as just used.
    self.last_used = monotonic()
    self.extension = extension
    self.cache_key = cache_key
async def async_load_data(self) -> bytes:
    """Drain the data generator and return the complete result.

    Every chunk is recorded and forwarded to the queue of each registered
    consumer. When the generator stops (normally or with an error) every
    consumer queue receives a final None.

    Raises RuntimeError when loading was already started or has finished.
    An exception raised by the generator is stored as the loading error
    and re-raised.
    """
    if not (self._result_data is None and self._partial_data is None):
        raise RuntimeError("Data already being loaded")

    # Local aliases refer to the very same list objects stored on self, so
    # consumers that register while loading is in progress are still seen.
    received: list[bytes] = []
    listeners: list[asyncio.Queue[bytes | None]] = []
    self._partial_data = received
    self._consumers = listeners

    try:
        async for piece in self._data_gen:
            received.append(piece)
            for listener in listeners:
                listener.put_nowait(piece)
    except Exception as err:  # pylint: disable=broad-except
        self._loading_error = err
        raise
    finally:
        # None tells each consumer that no more chunks will arrive.
        for listener in listeners:
            listener.put_nowait(None)
        self._consumers = None

    # Only reached on success: swap the chunk list for the joined bytes.
    self._result_data = joined = b"".join(received)
    self._partial_data = None
    return joined
async def async_stream_data(self) -> AsyncGenerator[bytes]:
    """Stream the data.

    Will return all data already returned from the generator.
    Will listen for future data returned from the generator.

    Yields the complete result as a single chunk when loading has finished.

    Raises the stored loading error if one occurred, and RuntimeError if
    loading has not been started.
    """
    # Count the access itself as a use. Without this, the fast path below
    # and consumers that stop early never refresh last_used, so an entry
    # that is read often would still look unused to the cleanup.
    self.last_used = monotonic()

    if self._result_data is not None:
        yield self._result_data
        return

    if self._loading_error:
        raise self._loading_error

    if self._partial_data is None:
        raise RuntimeError("Data not being loaded")

    queue: asyncio.Queue[bytes | None] | None = None

    # Check if generator is still feeding data
    if self._consumers is not None:
        queue = asyncio.Queue()
        self._consumers.append(queue)

    # Iterate over a snapshot: there is no await between registering the
    # queue and taking the copy, so every chunk is delivered exactly once,
    # either from the snapshot or through the queue.
    for chunk in list(self._partial_data):
        yield chunk

    if self._loading_error:
        raise self._loading_error

    if queue is not None:
        # The loader puts None on the queue once the generator has stopped.
        while (chunk2 := await queue.get()) is not None:
            yield chunk2

    if self._loading_error:
        raise self._loading_error

    # A long-running stream counts as used until its last chunk is delivered.
    self.last_used = monotonic()
@callback
@ -194,10 +275,11 @@ async def async_get_media_source_audio(
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager = hass.data[DATA_TTS_MANAGER]
cache_key = manager.async_cache_message_in_memory(
cache = manager.async_cache_message_in_memory(
**media_source_id_to_kwargs(media_source_id)
)
return await manager.async_get_tts_audio(cache_key)
data = b"".join([chunk async for chunk in cache.async_stream_data()])
return cache.extension, data
@callback
@ -216,18 +298,19 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
return languages
async def async_convert_audio(
async def _async_convert_audio(
hass: HomeAssistant,
from_extension: str,
audio_bytes: bytes,
audio_bytes_gen: AsyncGenerator[bytes],
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> bytes:
) -> AsyncGenerator[bytes]:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
return await hass.async_add_executor_job(
audio_bytes = b"".join([chunk async for chunk in audio_bytes_gen])
data = await hass.async_add_executor_job(
lambda: _convert_audio(
ffmpeg_manager.binary,
from_extension,
@ -238,6 +321,7 @@ async def async_convert_audio(
to_sample_bytes=to_sample_bytes,
)
)
yield data
def _convert_audio(
@ -401,32 +485,33 @@ class ResultStream:
return f"/api/tts_proxy/{self.token}"
@cached_property
def _result_cache_key(self) -> asyncio.Future[str]:
"""Get the future that returns the cache key."""
def _result_cache(self) -> asyncio.Future[TTSCache]:
"""Get the future that returns the cache."""
return asyncio.Future()
@callback
def async_set_message_cache_key(self, cache_key: str) -> None:
"""Set cache key for message to be streamed."""
self._result_cache_key.set_result(cache_key)
def async_set_message_cache(self, cache: TTSCache) -> None:
"""Set cache containing message audio to be streamed."""
self._result_cache.set_result(cache)
@callback
def async_set_message(self, message: str) -> None:
"""Set message to be generated."""
cache_key = self._manager.async_cache_message_in_memory(
engine=self.engine,
message=message,
use_file_cache=self.use_file_cache,
language=self.language,
options=self.options,
self._result_cache.set_result(
self._manager.async_cache_message_in_memory(
engine=self.engine,
message=message,
use_file_cache=self.use_file_cache,
language=self.language,
options=self.options,
)
)
self._result_cache_key.set_result(cache_key)
async def async_stream_result(self) -> AsyncGenerator[bytes]:
"""Get the stream of this result."""
cache_key = await self._result_cache_key
_extension, data = await self._manager.async_get_tts_audio(cache_key)
yield data
cache = await self._result_cache
async for chunk in cache.async_stream_data():
yield chunk
def _hash_options(options: dict) -> str:
@ -483,7 +568,7 @@ class MemcacheCleanup:
now = monotonic()
for cache_key, info in list(memcache.items()):
if info["last_used"] + maxage < now:
if info.last_used + maxage < now:
_LOGGER.debug("Cleaning up %s", cache_key)
del memcache[cache_key]
@ -638,15 +723,18 @@ class SpeechManager:
if message is None:
return result_stream
cache_key = self._async_ensure_cached_in_memory(
engine=engine,
engine_instance=engine_instance,
message=message,
use_file_cache=use_file_cache,
language=language,
options=options,
# We added this method as an alternative to stream.async_set_message
# to avoid the options being processed twice
result_stream.async_set_message_cache(
self._async_ensure_cached_in_memory(
engine=engine,
engine_instance=engine_instance,
message=message,
use_file_cache=use_file_cache,
language=language,
options=options,
)
)
result_stream.async_set_message_cache_key(cache_key)
return result_stream
@ -658,7 +746,7 @@ class SpeechManager:
use_file_cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> str:
) -> TTSCache:
"""Make sure a message is cached in memory and return its TTSCache."""
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
@ -685,7 +773,7 @@ class SpeechManager:
use_file_cache: bool,
language: str,
options: dict,
) -> str:
) -> TTSCache:
"""Ensure a message is cached.
Requires options, language to be processed.
@ -697,62 +785,101 @@ class SpeechManager:
).lower()
# Is speech already in memory
if cache_key in self.mem_cache:
return cache_key
if cache := self.mem_cache.get(cache_key):
_LOGGER.debug("Found audio in cache for %s", message[0:32])
return cache
if use_file_cache and cache_key in self.file_cache:
coro = self._async_load_file_to_mem(cache_key)
store_to_disk = use_file_cache
if use_file_cache and (filename := self.file_cache.get(cache_key)):
_LOGGER.debug("Loading audio from disk for %s", message[0:32])
extension = os.path.splitext(filename)[1][1:]
data_gen = self._async_load_file(cache_key)
store_to_disk = False
else:
coro = self._async_generate_tts_audio(
engine_instance, cache_key, message, use_file_cache, language, options
_LOGGER.debug("Generating audio for %s", message[0:32])
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
data_gen = self._async_generate_tts_audio(
engine_instance, message, language, options
)
task = self.hass.async_create_task(coro, eager_start=False)
cache = TTSCache(
cache_key=cache_key,
extension=extension,
data_gen=data_gen,
)
def handle_error(future: asyncio.Future) -> None:
"""Handle error."""
if not (err := future.exception()):
return
self.mem_cache[cache_key] = cache
self.hass.async_create_background_task(
self._load_data_into_cache(
cache, engine_instance, message, store_to_disk, language, options
),
f"tts_load_data_into_cache_{engine_instance.name}",
)
self.memcache_cleanup.schedule()
return cache
async def _load_data_into_cache(
self,
cache: TTSCache,
engine_instance: TextToSpeechEntity | Provider,
message: str,
store_to_disk: bool,
language: str,
options: dict,
) -> None:
"""Load the TTS cache data and, if requested, store the result on disk."""
try:
data = await cache.async_load_data()
except Exception as err: # pylint: disable=broad-except # noqa: BLE001
# Truncate message so we don't flood the logs. Cutting off at 32 chars
# but since we add 3 dots to truncated message, we cut off at 35.
trunc_msg = message if len(message) < 35 else f"{message[0:32]}"
_LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
self.mem_cache.pop(cache_key, None)
_LOGGER.error("Error getting audio for %s: %s", trunc_msg, err)
self.mem_cache.pop(cache.cache_key, None)
return
task.add_done_callback(handle_error)
if not store_to_disk:
return
self.mem_cache[cache_key] = {
"extension": "",
"voice": b"",
"pending": task,
"last_used": monotonic(),
}
return cache_key
filename = f"{cache.cache_key}.{cache.extension}".lower()
async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
"""Fetch TTS audio."""
cached = self.mem_cache.get(cache_key)
if cached is None:
raise HomeAssistantError("Audio not cached")
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
cached["last_used"] = monotonic()
return cached["extension"], cached["voice"]
# Validate filename
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
filename
):
raise HomeAssistantError(
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
)
if cache.extension == "mp3":
name = (
engine_instance.name if isinstance(engine_instance.name, str) else "-"
)
data = self.write_tags(filename, data, name, message, language, options)
voice_file = os.path.join(self.cache_dir, filename)
def save_speech() -> None:
"""Store speech to filesystem."""
with open(voice_file, "wb") as speech:
speech.write(data)
try:
await self.hass.async_add_executor_job(save_speech)
except OSError as err:
_LOGGER.error("Can't write %s: %s", filename, err)
else:
self.file_cache[cache.cache_key] = filename
async def _async_generate_tts_audio(
self,
engine_instance: TextToSpeechEntity | Provider,
cache_key: str,
message: str,
cache_to_disk: bool,
language: str,
options: dict[str, Any],
) -> None:
"""Start loading of the TTS audio.
This method is a coroutine.
"""
) -> AsyncGenerator[bytes]:
"""Generate TTS audio from an engine."""
options = dict(options or {})
supported_options = engine_instance.supported_options or []
@ -800,6 +927,17 @@ class SpeechManager:
extension, data = await engine_instance.async_get_tts_audio(
message, language, options
)
if data is None or extension is None:
raise HomeAssistantError(
f"No TTS from {engine_instance.name} for '{message}'"
)
async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
yield data
data_gen = make_data_generator(data)
else:
async def message_gen() -> AsyncGenerator[str]:
@ -809,12 +947,7 @@ class SpeechManager:
TTSAudioRequest(language, options, message_gen())
)
extension = tts_result.extension
data = b"".join([chunk async for chunk in tts_result.data_gen])
if data is None or extension is None:
raise HomeAssistantError(
f"No TTS from {engine_instance.name} for '{message}'"
)
data_gen = tts_result.data_gen
# Only convert if we have a preferred format different than the
# expected format from the TTS system, or if a specific sample
@ -827,62 +960,21 @@ class SpeechManager:
)
if needs_conversion:
data = await async_convert_audio(
data_gen = _async_convert_audio(
self.hass,
extension,
data,
data_gen,
to_extension=final_extension,
to_sample_rate=sample_rate,
to_sample_channels=sample_channels,
to_sample_bytes=sample_bytes,
)
# Create file infos
filename = f"{cache_key}.{final_extension}".lower()
async for chunk in data_gen:
yield chunk
# Validate filename
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
filename
):
raise HomeAssistantError(
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
)
# Save to memory
if final_extension == "mp3":
data = self.write_tags(
filename, data, engine_instance.name, message, language, options
)
self._async_store_to_memcache(cache_key, final_extension, data)
if not cache_to_disk:
return
voice_file = os.path.join(self.cache_dir, filename)
def save_speech() -> None:
"""Store speech to filesystem."""
with open(voice_file, "wb") as speech:
speech.write(data)
# Don't await, we're going to do this in the background
task = self.hass.async_add_executor_job(save_speech)
def write_done(future: asyncio.Future) -> None:
"""Write is done task."""
if err := future.exception():
_LOGGER.error("Can't write %s: %s", filename, err)
else:
self.file_cache[cache_key] = filename
task.add_done_callback(write_done)
async def _async_load_file_to_mem(self, cache_key: str) -> None:
"""Load voice from file cache into memory.
This method is a coroutine.
"""
async def _async_load_file(self, cache_key: str) -> AsyncGenerator[bytes]:
"""Load TTS audio from disk."""
if not (filename := self.file_cache.get(cache_key)):
raise HomeAssistantError(f"Key {cache_key} not in file cache!")
@ -899,22 +991,7 @@ class SpeechManager:
del self.file_cache[cache_key]
raise HomeAssistantError(f"Can't read {voice_file}") from err
extension = os.path.splitext(filename)[1][1:]
self._async_store_to_memcache(cache_key, extension, data)
@callback
def _async_store_to_memcache(
self, cache_key: str, extension: str, data: bytes
) -> None:
"""Store data to memcache and set timer to remove it."""
self.mem_cache[cache_key] = {
"extension": extension,
"voice": data,
"pending": None,
"last_used": monotonic(),
}
self.memcache_cleanup.schedule()
yield data
@staticmethod
def write_tags(

View File

@ -168,7 +168,7 @@ async def test_service(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@ -230,7 +230,7 @@ async def test_service_default_language(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ (
@ -294,7 +294,7 @@ async def test_service_default_special_language(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@ -354,7 +354,7 @@ async def test_service_language(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
@ -470,7 +470,7 @@ async def test_service_options(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ (
@ -554,7 +554,7 @@ async def test_service_default_options(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ (
@ -628,7 +628,7 @@ async def test_merge_default_service_options(
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ (
@ -743,7 +743,7 @@ async def test_service_clear_cache(
# To make sure the file is persisted
assert len(calls) == 1
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
assert (
mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@ -1769,9 +1769,15 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
"""Test that ffmpeg failing during audio conversion will raise an error."""
assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
with pytest.raises(RuntimeError):
async def bad_data_gen():
yield bytes(0)
with pytest.raises(RuntimeError): # noqa: PT012
# Simulate a bad WAV file
await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
async for _chunk in tts._async_convert_audio(
hass, "wav", bad_data_gen(), "mp3"
):
pass
async def test_default_engine_prefer_entity(
@ -1846,3 +1852,86 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
assert stream2.extension == "wav"
result_data = b"".join([chunk async for chunk in stream2.async_stream_result()])
assert result_data == data
async def test_tts_cache() -> None:
    """Test TTSCache loading and streaming, with and without errors."""

    async def data_gen(queue: asyncio.Queue[bytes | None | Exception]):
        # None is the end-of-stream sentinel. An Exception instance is raised
        # from inside the generator to simulate a failing data source.
        while (chunk := await queue.get()) is not None:
            if isinstance(chunk, Exception):
                raise chunk
            yield chunk

    async def consume_cache(cache: tts.TTSCache) -> bytes:
        return b"".join([chunk async for chunk in cache.async_stream_data()])

    queue = asyncio.Queue()
    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
    assert cache.cache_key == "test-key"
    assert cache.extension == "mp3"

    for i in range(10):
        queue.put_nowait(f"{i}".encode())
    queue.put_nowait(None)

    assert await cache.async_load_data() == b"0123456789"

    # Loading a second time is not allowed
    with pytest.raises(RuntimeError):
        await cache.async_load_data()

    # When data is loaded, we get it all in 1 chunk
    chunks = [chunk async for chunk in cache.async_stream_data()]
    assert chunks == [b"0123456789"]

    # Show we can stream the data while it's still being generated
    queue = asyncio.Queue()
    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
    load_data_task = asyncio.create_task(cache.async_load_data())
    consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
    queue.put_nowait(b"0")
    await asyncio.sleep(0)
    queue.put_nowait(b"1")
    await asyncio.sleep(0)
    consume_mid_data_task = asyncio.create_task(consume_cache(cache))
    queue.put_nowait(b"2")
    await asyncio.sleep(0)
    queue.put_nowait(None)
    consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
    await asyncio.sleep(0)
    assert await load_data_task == b"012"
    assert await consume_post_data_loaded_task == b"012"
    assert await consume_mid_data_task == b"012"
    assert await consume_pre_data_loaded_task == b"012"

    # Now with errors
    queue = asyncio.Queue()
    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
    load_data_task = asyncio.create_task(cache.async_load_data())
    consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
    queue.put_nowait(b"0")
    await asyncio.sleep(0)
    queue.put_nowait(b"1")
    await asyncio.sleep(0)
    consume_mid_data_task = asyncio.create_task(consume_cache(cache))
    # The generator raises as soon as it reads the exception, so no
    # end-of-stream sentinel is needed afterwards.
    queue.put_nowait(ValueError("Boom!"))
    await asyncio.sleep(0)
    consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
    await asyncio.sleep(0)

    # The loader and every consumer must surface the generator's error,
    # regardless of when the consumer started streaming.
    with pytest.raises(ValueError, match="Boom!"):
        await load_data_task
    with pytest.raises(ValueError, match="Boom!"):
        await consume_post_data_loaded_task
    with pytest.raises(ValueError, match="Boom!"):
        await consume_mid_data_task
    with pytest.raises(ValueError, match="Boom!"):
        await consume_pre_data_loaded_task

View File

@ -150,17 +150,15 @@ async def test_get_tts_audio_connection_lost(
hass: HomeAssistant, init_wyoming_tts
) -> None:
"""Test streaming audio and losing connection."""
with (
patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient([None]),
),
pytest.raises(HomeAssistantError),
stream = tts.async_create_stream(hass, "tts.test_tts", "en-US")
with patch(
"homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient([None]),
):
await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)
stream.async_set_message("Hello world")
with pytest.raises(HomeAssistantError): # noqa: PT012
async for _chunk in stream.async_stream_result():
pass
async def test_get_tts_audio_audio_oserror(