Switch more TTS core to async generators (#140432)

* Switch more TTS core to async generators

* Document a design choice

* robust

* Add more tests

* Update comment

* Clarify and document TTSCache variables
Paulus Schoutsen 2025-03-13 13:24:44 -04:00 committed by GitHub
parent b07ac301b9
commit 55895df54d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 330 additions and 166 deletions
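The core of this change: TTSCache now wraps the engine's async generator and fans its chunks out to any number of concurrent consumers, replacing the previous TypedDict that held a finished bytes blob plus a pending task. A minimal, self-contained sketch of that fan-out technique (simplified names, error propagation omitted; an illustration of the pattern, not the exact class in the diff below):

import asyncio
from collections.abc import AsyncGenerator


class ChunkFanout:
    """Cache chunks from one async generator and replay/stream them to many readers."""

    def __init__(self, source: AsyncGenerator[bytes]) -> None:
        self._source = source
        self._chunks: list[bytes] = []
        self._done = False
        self._queues: list[asyncio.Queue[bytes | None]] = []

    async def load(self) -> bytes:
        """Drain the source once, pushing each chunk to every attached reader."""
        async for chunk in self._source:
            self._chunks.append(chunk)
            for queue in self._queues:
                queue.put_nowait(chunk)
        self._done = True
        for queue in self._queues:
            queue.put_nowait(None)  # sentinel: no more data
        return b"".join(self._chunks)

    async def stream(self) -> AsyncGenerator[bytes]:
        """Yield chunks received so far, then follow live data until the sentinel."""
        queue: asyncio.Queue[bytes | None] | None = None
        if not self._done:
            # Register before snapshotting so no chunk is missed or duplicated.
            queue = asyncio.Queue()
            self._queues.append(queue)
        for chunk in list(self._chunks):
            yield chunk
        if queue is not None:
            while (chunk := await queue.get()) is not None:
                yield chunk


async def demo() -> None:
    async def produce() -> AsyncGenerator[bytes]:
        for i in range(3):
            await asyncio.sleep(0)
            yield f"{i}".encode()

    fanout = ChunkFanout(produce())
    load_task = asyncio.create_task(fanout.load())
    # A reader attached before loading finishes still sees every chunk.
    assert b"".join([chunk async for chunk in fanout.stream()]) == b"012"
    assert await load_task == b"012"


asyncio.run(demo())

The real TTSCache in the diff below additionally stores a loading error and re-raises it to every consumer, and joins the chunks into a single result blob once loading completes.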

View File

@@ -17,7 +17,7 @@ import secrets
 import subprocess
 import tempfile
 from time import monotonic
-from typing import Any, Final, TypedDict
+from typing import Any, Final

 from aiohttp import web
 import mutagen
@@ -123,13 +123,94 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
 SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})


-class TTSCache(TypedDict):
-    """Cached TTS file."""
-
-    extension: str
-    voice: bytes
-    pending: asyncio.Task | None
-    last_used: float
+class TTSCache:
+    """Cached bytes of a TTS result."""
+
+    _result_data: bytes | None = None
+    """When fully loaded, contains the result data."""
+
+    _partial_data: list[bytes] | None = None
+    """While loading, contains the data already received from the generator."""
+
+    _loading_error: Exception | None = None
+    """If an error occurred while loading, contains the error."""
+
+    _consumers: list[asyncio.Queue[bytes | None]] | None = None
+    """A queue for each current consumer to notify of new data while the generator is loading."""
+
+    def __init__(
+        self,
+        cache_key: str,
+        extension: str,
+        data_gen: AsyncGenerator[bytes],
+    ) -> None:
+        """Initialize the TTS cache."""
+        self.cache_key = cache_key
+        self.extension = extension
+        self.last_used = monotonic()
+        self._data_gen = data_gen
+
+    async def async_load_data(self) -> bytes:
+        """Load the data from the generator."""
+        if self._result_data is not None or self._partial_data is not None:
+            raise RuntimeError("Data already being loaded")
+
+        self._partial_data = []
+        self._consumers = []
+
+        try:
+            async for chunk in self._data_gen:
+                self._partial_data.append(chunk)
+                for queue in self._consumers:
+                    queue.put_nowait(chunk)
+        except Exception as err:  # pylint: disable=broad-except
+            self._loading_error = err
+            raise
+        finally:
+            for queue in self._consumers:
+                queue.put_nowait(None)
+            self._consumers = None
+
+        self._result_data = b"".join(self._partial_data)
+        self._partial_data = None
+        return self._result_data
+
+    async def async_stream_data(self) -> AsyncGenerator[bytes]:
+        """Stream the data.
+
+        Will return all data already returned from the generator.
+        Will listen for future data returned from the generator.
+        Raises error if one occurred.
+        """
+        if self._result_data is not None:
+            yield self._result_data
+            return
+        if self._loading_error:
+            raise self._loading_error
+
+        if self._partial_data is None:
+            raise RuntimeError("Data not being loaded")
+
+        queue: asyncio.Queue[bytes | None] | None = None
+        # Check if generator is still feeding data
+        if self._consumers is not None:
+            queue = asyncio.Queue()
+            self._consumers.append(queue)
+
+        for chunk in list(self._partial_data):
+            yield chunk
+
+        if self._loading_error:
+            raise self._loading_error
+
+        if queue is not None:
+            while (chunk2 := await queue.get()) is not None:
+                yield chunk2
+
+        if self._loading_error:
+            raise self._loading_error
+
+        self.last_used = monotonic()


 @callback
@@ -194,10 +275,11 @@ async def async_get_media_source_audio(
 ) -> tuple[str, bytes]:
     """Get TTS audio as extension, data."""
     manager = hass.data[DATA_TTS_MANAGER]
-    cache_key = manager.async_cache_message_in_memory(
+    cache = manager.async_cache_message_in_memory(
         **media_source_id_to_kwargs(media_source_id)
     )
-    return await manager.async_get_tts_audio(cache_key)
+    data = b"".join([chunk async for chunk in cache.async_stream_data()])
+    return cache.extension, data


 @callback
@@ -216,18 +298,19 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
     return languages


-async def async_convert_audio(
+async def _async_convert_audio(
     hass: HomeAssistant,
     from_extension: str,
-    audio_bytes: bytes,
+    audio_bytes_gen: AsyncGenerator[bytes],
     to_extension: str,
     to_sample_rate: int | None = None,
     to_sample_channels: int | None = None,
     to_sample_bytes: int | None = None,
-) -> bytes:
+) -> AsyncGenerator[bytes]:
     """Convert audio to a preferred format using ffmpeg."""
     ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
-    return await hass.async_add_executor_job(
+    audio_bytes = b"".join([chunk async for chunk in audio_bytes_gen])
+    data = await hass.async_add_executor_job(
         lambda: _convert_audio(
             ffmpeg_manager.binary,
             from_extension,
@@ -238,6 +321,7 @@ async def async_convert_audio(
             to_sample_bytes=to_sample_bytes,
         )
     )
+    yield data


 def _convert_audio(
@@ -401,32 +485,33 @@ class ResultStream:
         return f"/api/tts_proxy/{self.token}"

     @cached_property
-    def _result_cache_key(self) -> asyncio.Future[str]:
-        """Get the future that returns the cache key."""
+    def _result_cache(self) -> asyncio.Future[TTSCache]:
+        """Get the future that returns the cache."""
         return asyncio.Future()

     @callback
-    def async_set_message_cache_key(self, cache_key: str) -> None:
-        """Set cache key for message to be streamed."""
-        self._result_cache_key.set_result(cache_key)
+    def async_set_message_cache(self, cache: TTSCache) -> None:
+        """Set cache containing message audio to be streamed."""
+        self._result_cache.set_result(cache)

     @callback
     def async_set_message(self, message: str) -> None:
         """Set message to be generated."""
-        cache_key = self._manager.async_cache_message_in_memory(
-            engine=self.engine,
-            message=message,
-            use_file_cache=self.use_file_cache,
-            language=self.language,
-            options=self.options,
+        self._result_cache.set_result(
+            self._manager.async_cache_message_in_memory(
+                engine=self.engine,
+                message=message,
+                use_file_cache=self.use_file_cache,
+                language=self.language,
+                options=self.options,
+            )
         )
-        self._result_cache_key.set_result(cache_key)

     async def async_stream_result(self) -> AsyncGenerator[bytes]:
         """Get the stream of this result."""
-        cache_key = await self._result_cache_key
-        _extension, data = await self._manager.async_get_tts_audio(cache_key)
-        yield data
+        cache = await self._result_cache
+        async for chunk in cache.async_stream_data():
+            yield chunk


 def _hash_options(options: dict) -> str:
@@ -483,7 +568,7 @@ class MemcacheCleanup:
         now = monotonic()

         for cache_key, info in list(memcache.items()):
-            if info["last_used"] + maxage < now:
+            if info.last_used + maxage < now:
                 _LOGGER.debug("Cleaning up %s", cache_key)
                 del memcache[cache_key]
@@ -638,15 +723,18 @@ class SpeechManager:
         if message is None:
             return result_stream

-        cache_key = self._async_ensure_cached_in_memory(
-            engine=engine,
-            engine_instance=engine_instance,
-            message=message,
-            use_file_cache=use_file_cache,
-            language=language,
-            options=options,
+        # We added this method as an alternative to stream.async_set_message
+        # to avoid the options being processed twice
+        result_stream.async_set_message_cache(
+            self._async_ensure_cached_in_memory(
+                engine=engine,
+                engine_instance=engine_instance,
+                message=message,
+                use_file_cache=use_file_cache,
+                language=language,
+                options=options,
+            )
         )
-        result_stream.async_set_message_cache_key(cache_key)

         return result_stream
@@ -658,7 +746,7 @@ class SpeechManager:
         use_file_cache: bool | None = None,
         language: str | None = None,
         options: dict | None = None,
-    ) -> str:
+    ) -> TTSCache:
         """Make sure a message is cached in memory and returns cache key."""
         if (engine_instance := get_engine_instance(self.hass, engine)) is None:
             raise HomeAssistantError(f"Provider {engine} not found")
@@ -685,7 +773,7 @@
         use_file_cache: bool,
         language: str,
         options: dict,
-    ) -> str:
+    ) -> TTSCache:
         """Ensure a message is cached.

         Requires options, language to be processed.
@@ -697,62 +785,101 @@
         ).lower()

         # Is speech already in memory
-        if cache_key in self.mem_cache:
-            return cache_key
+        if cache := self.mem_cache.get(cache_key):
+            _LOGGER.debug("Found audio in cache for %s", message[0:32])
+            return cache

-        if use_file_cache and cache_key in self.file_cache:
-            coro = self._async_load_file_to_mem(cache_key)
+        store_to_disk = use_file_cache
+
+        if use_file_cache and (filename := self.file_cache.get(cache_key)):
+            _LOGGER.debug("Loading audio from disk for %s", message[0:32])
+            extension = os.path.splitext(filename)[1][1:]
+            data_gen = self._async_load_file(cache_key)
+            store_to_disk = False
         else:
-            coro = self._async_generate_tts_audio(
-                engine_instance, cache_key, message, use_file_cache, language, options
+            _LOGGER.debug("Generating audio for %s", message[0:32])
+            extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
+            data_gen = self._async_generate_tts_audio(
+                engine_instance, message, language, options
             )

-        task = self.hass.async_create_task(coro, eager_start=False)
+        cache = TTSCache(
+            cache_key=cache_key,
+            extension=extension,
+            data_gen=data_gen,
+        )

-        def handle_error(future: asyncio.Future) -> None:
-            """Handle error."""
-            if not (err := future.exception()):
-                return
+        self.mem_cache[cache_key] = cache
+        self.hass.async_create_background_task(
+            self._load_data_into_cache(
+                cache, engine_instance, message, store_to_disk, language, options
+            ),
+            f"tts_load_data_into_cache_{engine_instance.name}",
+        )
+        self.memcache_cleanup.schedule()
+        return cache
+
+    async def _load_data_into_cache(
+        self,
+        cache: TTSCache,
+        engine_instance: TextToSpeechEntity | Provider,
+        message: str,
+        store_to_disk: bool,
+        language: str,
+        options: dict,
+    ) -> None:
+        """Load and process a finished loading TTS Cache."""
+        try:
+            data = await cache.async_load_data()
+        except Exception as err:  # pylint: disable=broad-except  # noqa: BLE001
             # Truncate message so we don't flood the logs. Cutting off at 32 chars
             # but since we add 3 dots to truncated message, we cut off at 35.
             trunc_msg = message if len(message) < 35 else f"{message[0:32]}…"
-            _LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
-            self.mem_cache.pop(cache_key, None)
+            _LOGGER.error("Error getting audio for %s: %s", trunc_msg, err)
+            self.mem_cache.pop(cache.cache_key, None)
+            return

-        task.add_done_callback(handle_error)
+        if not store_to_disk:
+            return

-        self.mem_cache[cache_key] = {
-            "extension": "",
-            "voice": b"",
-            "pending": task,
-            "last_used": monotonic(),
-        }
-        return cache_key
+        filename = f"{cache.cache_key}.{cache.extension}".lower()

-    async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
-        """Fetch TTS audio."""
-        cached = self.mem_cache.get(cache_key)
-        if cached is None:
-            raise HomeAssistantError("Audio not cached")
-        if pending := cached.get("pending"):
-            await pending
-            cached = self.mem_cache[cache_key]
-        cached["last_used"] = monotonic()
-        return cached["extension"], cached["voice"]
+        # Validate filename
+        if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
+            filename
+        ):
+            raise HomeAssistantError(
+                f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
+            )
+
+        if cache.extension == "mp3":
+            name = (
+                engine_instance.name if isinstance(engine_instance.name, str) else "-"
+            )
+            data = self.write_tags(filename, data, name, message, language, options)
+
+        voice_file = os.path.join(self.cache_dir, filename)
+
+        def save_speech() -> None:
+            """Store speech to filesystem."""
+            with open(voice_file, "wb") as speech:
+                speech.write(data)
+
+        try:
+            await self.hass.async_add_executor_job(save_speech)
+        except OSError as err:
+            _LOGGER.error("Can't write %s: %s", filename, err)
+        else:
+            self.file_cache[cache.cache_key] = filename

     async def _async_generate_tts_audio(
         self,
         engine_instance: TextToSpeechEntity | Provider,
-        cache_key: str,
         message: str,
-        cache_to_disk: bool,
         language: str,
         options: dict[str, Any],
-    ) -> None:
-        """Start loading of the TTS audio.
-
-        This method is a coroutine.
-        """
+    ) -> AsyncGenerator[bytes]:
+        """Generate TTS audio from an engine."""
         options = dict(options or {})
         supported_options = engine_instance.supported_options or []
@@ -800,6 +927,17 @@ class SpeechManager:
             extension, data = await engine_instance.async_get_tts_audio(
                 message, language, options
             )
+
+            if data is None or extension is None:
+                raise HomeAssistantError(
+                    f"No TTS from {engine_instance.name} for '{message}'"
+                )
+
+            async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
+                yield data
+
+            data_gen = make_data_generator(data)
+
         else:

             async def message_gen() -> AsyncGenerator[str]:
@@ -809,12 +947,7 @@
                 TTSAudioRequest(language, options, message_gen())
             )
             extension = tts_result.extension
-            data = b"".join([chunk async for chunk in tts_result.data_gen])
-
-        if data is None or extension is None:
-            raise HomeAssistantError(
-                f"No TTS from {engine_instance.name} for '{message}'"
-            )
+            data_gen = tts_result.data_gen

         # Only convert if we have a preferred format different than the
         # expected format from the TTS system, or if a specific sample
@@ -827,62 +960,21 @@
         )

         if needs_conversion:
-            data = await async_convert_audio(
+            data_gen = _async_convert_audio(
                 self.hass,
                 extension,
-                data,
+                data_gen,
                 to_extension=final_extension,
                 to_sample_rate=sample_rate,
                 to_sample_channels=sample_channels,
                 to_sample_bytes=sample_bytes,
             )

-        # Create file infos
-        filename = f"{cache_key}.{final_extension}".lower()
-
-        # Validate filename
-        if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
-            filename
-        ):
-            raise HomeAssistantError(
-                f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
-            )
-
-        # Save to memory
-        if final_extension == "mp3":
-            data = self.write_tags(
-                filename, data, engine_instance.name, message, language, options
-            )
-
-        self._async_store_to_memcache(cache_key, final_extension, data)
-
-        if not cache_to_disk:
-            return
-
-        voice_file = os.path.join(self.cache_dir, filename)
-
-        def save_speech() -> None:
-            """Store speech to filesystem."""
-            with open(voice_file, "wb") as speech:
-                speech.write(data)
-
-        # Don't await, we're going to do this in the background
-        task = self.hass.async_add_executor_job(save_speech)
-
-        def write_done(future: asyncio.Future) -> None:
-            """Write is done task."""
-            if err := future.exception():
-                _LOGGER.error("Can't write %s: %s", filename, err)
-            else:
-                self.file_cache[cache_key] = filename
-
-        task.add_done_callback(write_done)
+        async for chunk in data_gen:
+            yield chunk

-    async def _async_load_file_to_mem(self, cache_key: str) -> None:
-        """Load voice from file cache into memory.
-
-        This method is a coroutine.
-        """
+    async def _async_load_file(self, cache_key: str) -> AsyncGenerator[bytes]:
+        """Load TTS audio from disk."""
         if not (filename := self.file_cache.get(cache_key)):
             raise HomeAssistantError(f"Key {cache_key} not in file cache!")
@@ -899,22 +991,7 @@
             del self.file_cache[cache_key]
             raise HomeAssistantError(f"Can't read {voice_file}") from err

-        extension = os.path.splitext(filename)[1][1:]
-
-        self._async_store_to_memcache(cache_key, extension, data)
-
-    @callback
-    def _async_store_to_memcache(
-        self, cache_key: str, extension: str, data: bytes
-    ) -> None:
-        """Store data to memcache and set timer to remove it."""
-        self.mem_cache[cache_key] = {
-            "extension": extension,
-            "voice": data,
-            "pending": None,
-            "last_used": monotonic(),
-        }
-        self.memcache_cleanup.schedule()
+        yield data

     @staticmethod
     def write_tags(

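With the changes above, callers stream a result through ResultStream instead of awaiting a finished (extension, data) tuple. A rough usage sketch based on the API in this diff (the engine id and message are illustrative, and a configured hass is assumed):

stream = tts.async_create_stream(hass, "tts.my_engine", "en-US")
stream.async_set_message("Hello world")

audio = bytearray()
async for chunk in stream.async_stream_result():
    audio.extend(chunk)  # chunks arrive while the engine is still generating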
View File

@@ -168,7 +168,7 @@ async def test_service(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@@ -230,7 +230,7 @@ async def test_service_default_language(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -294,7 +294,7 @@ async def test_service_default_special_language(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@@ -354,7 +354,7 @@ async def test_service_language(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
@@ -470,7 +470,7 @@ async def test_service_options(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
    ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -554,7 +554,7 @@ async def test_service_default_options(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -628,7 +628,7 @@ async def test_merge_default_service_options(
     assert await get_media_source_url(
         hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
     ) == ("/api/tts_proxy/test_token.mp3")
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / (
@@ -743,7 +743,7 @@ async def test_service_clear_cache(
     # To make sure the file is persisted
     assert len(calls) == 1
     await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
-    await hass.async_block_till_done()
+    await hass.async_block_till_done(wait_background_tasks=True)
     assert (
         mock_tts_cache_dir
         / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
@@ -1769,9 +1769,15 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
     """Test that ffmpeg failing during audio conversion will raise an error."""
     assert await async_setup_component(hass, ffmpeg.DOMAIN, {})

-    with pytest.raises(RuntimeError):
+    async def bad_data_gen():
+        yield bytes(0)
+
+    with pytest.raises(RuntimeError):  # noqa: PT012
         # Simulate a bad WAV file
-        await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
+        async for _chunk in tts._async_convert_audio(
+            hass, "wav", bad_data_gen(), "mp3"
+        ):
+            pass


 async def test_default_engine_prefer_entity(
@@ -1846,3 +1852,86 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None:
     assert stream2.extension == "wav"
     result_data = b"".join([chunk async for chunk in stream2.async_stream_result()])
     assert result_data == data
+
+
+async def test_tts_cache() -> None:
+    """Test TTSCache."""
+
+    async def data_gen(queue: asyncio.Queue[bytes | None | Exception]):
+        while chunk := await queue.get():
+            if isinstance(chunk, Exception):
+                raise chunk
+            yield chunk
+
+    queue = asyncio.Queue()
+    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
+    assert cache.cache_key == "test-key"
+    assert cache.extension == "mp3"
+
+    for i in range(10):
+        queue.put_nowait(f"{i}".encode())
+    queue.put_nowait(None)
+
+    assert await cache.async_load_data() == b"0123456789"
+
+    with pytest.raises(RuntimeError):
+        await cache.async_load_data()
+
+    # When data is loaded, we get it all in 1 chunk
+    cur = 0
+    async for chunk in cache.async_stream_data():
+        assert chunk == b"0123456789"
+        cur += 1
+    assert cur == 1
+
+    # Show we can stream the data while it's still being generated
+    async def consume_cache(cache: tts.TTSCache):
+        return b"".join([chunk async for chunk in cache.async_stream_data()])
+
+    queue = asyncio.Queue()
+    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
+
+    load_data_task = asyncio.create_task(cache.async_load_data())
+    consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(b"0")
+    await asyncio.sleep(0)
+    queue.put_nowait(b"1")
+    await asyncio.sleep(0)
+    consume_mid_data_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(b"2")
+    await asyncio.sleep(0)
+    queue.put_nowait(None)
+    consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    await asyncio.sleep(0)
+    assert await load_data_task == b"012"
+    assert await consume_post_data_loaded_task == b"012"
+    assert await consume_mid_data_task == b"012"
+    assert await consume_pre_data_loaded_task == b"012"
+
+    # Now with errors
+    async def consume_cache(cache: tts.TTSCache):
+        return b"".join([chunk async for chunk in cache.async_stream_data()])
+
+    queue = asyncio.Queue()
+    cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
+
+    load_data_task = asyncio.create_task(cache.async_load_data())
+    consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(b"0")
+    await asyncio.sleep(0)
+    queue.put_nowait(b"1")
+    await asyncio.sleep(0)
+    consume_mid_data_task = asyncio.create_task(consume_cache(cache))
+    queue.put_nowait(ValueError("Boom!"))
+    await asyncio.sleep(0)
+    queue.put_nowait(None)
+    consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
+    await asyncio.sleep(0)
+    with pytest.raises(ValueError):
+        assert await load_data_task == b"012"
+    with pytest.raises(ValueError):
+        assert await consume_post_data_loaded_task == b"012"
+    with pytest.raises(ValueError):
+        assert await consume_mid_data_task == b"012"
+    with pytest.raises(ValueError):
+        assert await consume_pre_data_loaded_task == b"012"

View File

@@ -150,17 +150,15 @@ async def test_get_tts_audio_connection_lost(
     hass: HomeAssistant, init_wyoming_tts
 ) -> None:
     """Test streaming audio and losing connection."""
-    with (
-        patch(
-            "homeassistant.components.wyoming.tts.AsyncTcpClient",
-            MockAsyncTcpClient([None]),
-        ),
-        pytest.raises(HomeAssistantError),
-    ):
-        await tts.async_get_media_source_audio(
-            hass,
-            tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
-        )
+    stream = tts.async_create_stream(hass, "tts.test_tts", "en-US")
+    with patch(
+        "homeassistant.components.wyoming.tts.AsyncTcpClient",
+        MockAsyncTcpClient([None]),
+    ):
+        stream.async_set_message("Hello world")
+        with pytest.raises(HomeAssistantError):  # noqa: PT012
+            async for _chunk in stream.async_stream_result():
+                pass


 async def test_get_tts_audio_audio_oserror(