mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Switch more TTS core to async generators (#140432)
* Switch more TTS core to async generators * Document a design choice * robust * Add more tests * Update comment * Clarify and document TTSCache variables
This commit is contained in:
parent
b07ac301b9
commit
55895df54d
@ -17,7 +17,7 @@ import secrets
|
||||
import subprocess
|
||||
import tempfile
|
||||
from time import monotonic
|
||||
from typing import Any, Final, TypedDict
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
@ -123,13 +123,94 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
|
||||
|
||||
class TTSCache(TypedDict):
|
||||
"""Cached TTS file."""
|
||||
class TTSCache:
|
||||
"""Cached bytes of a TTS result."""
|
||||
|
||||
extension: str
|
||||
voice: bytes
|
||||
pending: asyncio.Task | None
|
||||
last_used: float
|
||||
_result_data: bytes | None = None
|
||||
"""When fully loaded, contains the result data."""
|
||||
|
||||
_partial_data: list[bytes] | None = None
|
||||
"""While loading, contains the data already received from the generator."""
|
||||
|
||||
_loading_error: Exception | None = None
|
||||
"""If an error occurred while loading, contains the error."""
|
||||
|
||||
_consumers: list[asyncio.Queue[bytes | None]] | None = None
|
||||
"""A queue for each current consumer to notify of new data while the generator is loading."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_key: str,
|
||||
extension: str,
|
||||
data_gen: AsyncGenerator[bytes],
|
||||
) -> None:
|
||||
"""Initialize the TTS cache."""
|
||||
self.cache_key = cache_key
|
||||
self.extension = extension
|
||||
self.last_used = monotonic()
|
||||
self._data_gen = data_gen
|
||||
|
||||
async def async_load_data(self) -> bytes:
|
||||
"""Load the data from the generator."""
|
||||
if self._result_data is not None or self._partial_data is not None:
|
||||
raise RuntimeError("Data already being loaded")
|
||||
|
||||
self._partial_data = []
|
||||
self._consumers = []
|
||||
|
||||
try:
|
||||
async for chunk in self._data_gen:
|
||||
self._partial_data.append(chunk)
|
||||
for queue in self._consumers:
|
||||
queue.put_nowait(chunk)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._loading_error = err
|
||||
raise
|
||||
finally:
|
||||
for queue in self._consumers:
|
||||
queue.put_nowait(None)
|
||||
self._consumers = None
|
||||
|
||||
self._result_data = b"".join(self._partial_data)
|
||||
self._partial_data = None
|
||||
return self._result_data
|
||||
|
||||
async def async_stream_data(self) -> AsyncGenerator[bytes]:
|
||||
"""Stream the data.
|
||||
|
||||
Will return all data already returned from the generator.
|
||||
Will listen for future data returned from the generator.
|
||||
Raises error if one occurred.
|
||||
"""
|
||||
if self._result_data is not None:
|
||||
yield self._result_data
|
||||
return
|
||||
if self._loading_error:
|
||||
raise self._loading_error
|
||||
|
||||
if self._partial_data is None:
|
||||
raise RuntimeError("Data not being loaded")
|
||||
|
||||
queue: asyncio.Queue[bytes | None] | None = None
|
||||
# Check if generator is still feeding data
|
||||
if self._consumers is not None:
|
||||
queue = asyncio.Queue()
|
||||
self._consumers.append(queue)
|
||||
|
||||
for chunk in list(self._partial_data):
|
||||
yield chunk
|
||||
|
||||
if self._loading_error:
|
||||
raise self._loading_error
|
||||
|
||||
if queue is not None:
|
||||
while (chunk2 := await queue.get()) is not None:
|
||||
yield chunk2
|
||||
|
||||
if self._loading_error:
|
||||
raise self._loading_error
|
||||
|
||||
self.last_used = monotonic()
|
||||
|
||||
|
||||
@callback
|
||||
@ -194,10 +275,11 @@ async def async_get_media_source_audio(
|
||||
) -> tuple[str, bytes]:
|
||||
"""Get TTS audio as extension, data."""
|
||||
manager = hass.data[DATA_TTS_MANAGER]
|
||||
cache_key = manager.async_cache_message_in_memory(
|
||||
cache = manager.async_cache_message_in_memory(
|
||||
**media_source_id_to_kwargs(media_source_id)
|
||||
)
|
||||
return await manager.async_get_tts_audio(cache_key)
|
||||
data = b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||
return cache.extension, data
|
||||
|
||||
|
||||
@callback
|
||||
@ -216,18 +298,19 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
||||
return languages
|
||||
|
||||
|
||||
async def async_convert_audio(
|
||||
async def _async_convert_audio(
|
||||
hass: HomeAssistant,
|
||||
from_extension: str,
|
||||
audio_bytes: bytes,
|
||||
audio_bytes_gen: AsyncGenerator[bytes],
|
||||
to_extension: str,
|
||||
to_sample_rate: int | None = None,
|
||||
to_sample_channels: int | None = None,
|
||||
to_sample_bytes: int | None = None,
|
||||
) -> bytes:
|
||||
) -> AsyncGenerator[bytes]:
|
||||
"""Convert audio to a preferred format using ffmpeg."""
|
||||
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
||||
return await hass.async_add_executor_job(
|
||||
audio_bytes = b"".join([chunk async for chunk in audio_bytes_gen])
|
||||
data = await hass.async_add_executor_job(
|
||||
lambda: _convert_audio(
|
||||
ffmpeg_manager.binary,
|
||||
from_extension,
|
||||
@ -238,6 +321,7 @@ async def async_convert_audio(
|
||||
to_sample_bytes=to_sample_bytes,
|
||||
)
|
||||
)
|
||||
yield data
|
||||
|
||||
|
||||
def _convert_audio(
|
||||
@ -401,32 +485,33 @@ class ResultStream:
|
||||
return f"/api/tts_proxy/{self.token}"
|
||||
|
||||
@cached_property
|
||||
def _result_cache_key(self) -> asyncio.Future[str]:
|
||||
"""Get the future that returns the cache key."""
|
||||
def _result_cache(self) -> asyncio.Future[TTSCache]:
|
||||
"""Get the future that returns the cache."""
|
||||
return asyncio.Future()
|
||||
|
||||
@callback
|
||||
def async_set_message_cache_key(self, cache_key: str) -> None:
|
||||
"""Set cache key for message to be streamed."""
|
||||
self._result_cache_key.set_result(cache_key)
|
||||
def async_set_message_cache(self, cache: TTSCache) -> None:
|
||||
"""Set cache containing message audio to be streamed."""
|
||||
self._result_cache.set_result(cache)
|
||||
|
||||
@callback
|
||||
def async_set_message(self, message: str) -> None:
|
||||
"""Set message to be generated."""
|
||||
cache_key = self._manager.async_cache_message_in_memory(
|
||||
engine=self.engine,
|
||||
message=message,
|
||||
use_file_cache=self.use_file_cache,
|
||||
language=self.language,
|
||||
options=self.options,
|
||||
self._result_cache.set_result(
|
||||
self._manager.async_cache_message_in_memory(
|
||||
engine=self.engine,
|
||||
message=message,
|
||||
use_file_cache=self.use_file_cache,
|
||||
language=self.language,
|
||||
options=self.options,
|
||||
)
|
||||
)
|
||||
self._result_cache_key.set_result(cache_key)
|
||||
|
||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of this result."""
|
||||
cache_key = await self._result_cache_key
|
||||
_extension, data = await self._manager.async_get_tts_audio(cache_key)
|
||||
yield data
|
||||
cache = await self._result_cache
|
||||
async for chunk in cache.async_stream_data():
|
||||
yield chunk
|
||||
|
||||
|
||||
def _hash_options(options: dict) -> str:
|
||||
@ -483,7 +568,7 @@ class MemcacheCleanup:
|
||||
now = monotonic()
|
||||
|
||||
for cache_key, info in list(memcache.items()):
|
||||
if info["last_used"] + maxage < now:
|
||||
if info.last_used + maxage < now:
|
||||
_LOGGER.debug("Cleaning up %s", cache_key)
|
||||
del memcache[cache_key]
|
||||
|
||||
@ -638,15 +723,18 @@ class SpeechManager:
|
||||
if message is None:
|
||||
return result_stream
|
||||
|
||||
cache_key = self._async_ensure_cached_in_memory(
|
||||
engine=engine,
|
||||
engine_instance=engine_instance,
|
||||
message=message,
|
||||
use_file_cache=use_file_cache,
|
||||
language=language,
|
||||
options=options,
|
||||
# We added this method as an alternative to stream.async_set_message
|
||||
# to avoid the options being processed twice
|
||||
result_stream.async_set_message_cache(
|
||||
self._async_ensure_cached_in_memory(
|
||||
engine=engine,
|
||||
engine_instance=engine_instance,
|
||||
message=message,
|
||||
use_file_cache=use_file_cache,
|
||||
language=language,
|
||||
options=options,
|
||||
)
|
||||
)
|
||||
result_stream.async_set_message_cache_key(cache_key)
|
||||
|
||||
return result_stream
|
||||
|
||||
@ -658,7 +746,7 @@ class SpeechManager:
|
||||
use_file_cache: bool | None = None,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
) -> TTSCache:
|
||||
"""Make sure a message is cached in memory and returns cache key."""
|
||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
@ -685,7 +773,7 @@ class SpeechManager:
|
||||
use_file_cache: bool,
|
||||
language: str,
|
||||
options: dict,
|
||||
) -> str:
|
||||
) -> TTSCache:
|
||||
"""Ensure a message is cached.
|
||||
|
||||
Requires options, language to be processed.
|
||||
@ -697,62 +785,101 @@ class SpeechManager:
|
||||
).lower()
|
||||
|
||||
# Is speech already in memory
|
||||
if cache_key in self.mem_cache:
|
||||
return cache_key
|
||||
if cache := self.mem_cache.get(cache_key):
|
||||
_LOGGER.debug("Found audio in cache for %s", message[0:32])
|
||||
return cache
|
||||
|
||||
if use_file_cache and cache_key in self.file_cache:
|
||||
coro = self._async_load_file_to_mem(cache_key)
|
||||
store_to_disk = use_file_cache
|
||||
|
||||
if use_file_cache and (filename := self.file_cache.get(cache_key)):
|
||||
_LOGGER.debug("Loading audio from disk for %s", message[0:32])
|
||||
extension = os.path.splitext(filename)[1][1:]
|
||||
data_gen = self._async_load_file(cache_key)
|
||||
store_to_disk = False
|
||||
else:
|
||||
coro = self._async_generate_tts_audio(
|
||||
engine_instance, cache_key, message, use_file_cache, language, options
|
||||
_LOGGER.debug("Generating audio for %s", message[0:32])
|
||||
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||
data_gen = self._async_generate_tts_audio(
|
||||
engine_instance, message, language, options
|
||||
)
|
||||
|
||||
task = self.hass.async_create_task(coro, eager_start=False)
|
||||
cache = TTSCache(
|
||||
cache_key=cache_key,
|
||||
extension=extension,
|
||||
data_gen=data_gen,
|
||||
)
|
||||
|
||||
def handle_error(future: asyncio.Future) -> None:
|
||||
"""Handle error."""
|
||||
if not (err := future.exception()):
|
||||
return
|
||||
self.mem_cache[cache_key] = cache
|
||||
self.hass.async_create_background_task(
|
||||
self._load_data_into_cache(
|
||||
cache, engine_instance, message, store_to_disk, language, options
|
||||
),
|
||||
f"tts_load_data_into_cache_{engine_instance.name}",
|
||||
)
|
||||
self.memcache_cleanup.schedule()
|
||||
return cache
|
||||
|
||||
async def _load_data_into_cache(
|
||||
self,
|
||||
cache: TTSCache,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
message: str,
|
||||
store_to_disk: bool,
|
||||
language: str,
|
||||
options: dict,
|
||||
) -> None:
|
||||
"""Load and process a finished loading TTS Cache."""
|
||||
try:
|
||||
data = await cache.async_load_data()
|
||||
except Exception as err: # pylint: disable=broad-except # noqa: BLE001
|
||||
# Truncate message so we don't flood the logs. Cutting off at 32 chars
|
||||
# but since we add 3 dots to truncated message, we cut off at 35.
|
||||
trunc_msg = message if len(message) < 35 else f"{message[0:32]}…"
|
||||
_LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
|
||||
self.mem_cache.pop(cache_key, None)
|
||||
_LOGGER.error("Error getting audio for %s: %s", trunc_msg, err)
|
||||
self.mem_cache.pop(cache.cache_key, None)
|
||||
return
|
||||
|
||||
task.add_done_callback(handle_error)
|
||||
if not store_to_disk:
|
||||
return
|
||||
|
||||
self.mem_cache[cache_key] = {
|
||||
"extension": "",
|
||||
"voice": b"",
|
||||
"pending": task,
|
||||
"last_used": monotonic(),
|
||||
}
|
||||
return cache_key
|
||||
filename = f"{cache.cache_key}.{cache.extension}".lower()
|
||||
|
||||
async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
|
||||
"""Fetch TTS audio."""
|
||||
cached = self.mem_cache.get(cache_key)
|
||||
if cached is None:
|
||||
raise HomeAssistantError("Audio not cached")
|
||||
if pending := cached.get("pending"):
|
||||
await pending
|
||||
cached = self.mem_cache[cache_key]
|
||||
cached["last_used"] = monotonic()
|
||||
return cached["extension"], cached["voice"]
|
||||
# Validate filename
|
||||
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
||||
filename
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
||||
)
|
||||
|
||||
if cache.extension == "mp3":
|
||||
name = (
|
||||
engine_instance.name if isinstance(engine_instance.name, str) else "-"
|
||||
)
|
||||
data = self.write_tags(filename, data, name, message, language, options)
|
||||
|
||||
voice_file = os.path.join(self.cache_dir, filename)
|
||||
|
||||
def save_speech() -> None:
|
||||
"""Store speech to filesystem."""
|
||||
with open(voice_file, "wb") as speech:
|
||||
speech.write(data)
|
||||
|
||||
try:
|
||||
await self.hass.async_add_executor_job(save_speech)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||
else:
|
||||
self.file_cache[cache.cache_key] = filename
|
||||
|
||||
async def _async_generate_tts_audio(
|
||||
self,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
cache_key: str,
|
||||
message: str,
|
||||
cache_to_disk: bool,
|
||||
language: str,
|
||||
options: dict[str, Any],
|
||||
) -> None:
|
||||
"""Start loading of the TTS audio.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
) -> AsyncGenerator[bytes]:
|
||||
"""Generate TTS audio from an engine."""
|
||||
options = dict(options or {})
|
||||
supported_options = engine_instance.supported_options or []
|
||||
|
||||
@ -800,6 +927,17 @@ class SpeechManager:
|
||||
extension, data = await engine_instance.async_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(
|
||||
f"No TTS from {engine_instance.name} for '{message}'"
|
||||
)
|
||||
|
||||
async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
|
||||
yield data
|
||||
|
||||
data_gen = make_data_generator(data)
|
||||
|
||||
else:
|
||||
|
||||
async def message_gen() -> AsyncGenerator[str]:
|
||||
@ -809,12 +947,7 @@ class SpeechManager:
|
||||
TTSAudioRequest(language, options, message_gen())
|
||||
)
|
||||
extension = tts_result.extension
|
||||
data = b"".join([chunk async for chunk in tts_result.data_gen])
|
||||
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(
|
||||
f"No TTS from {engine_instance.name} for '{message}'"
|
||||
)
|
||||
data_gen = tts_result.data_gen
|
||||
|
||||
# Only convert if we have a preferred format different than the
|
||||
# expected format from the TTS system, or if a specific sample
|
||||
@ -827,62 +960,21 @@ class SpeechManager:
|
||||
)
|
||||
|
||||
if needs_conversion:
|
||||
data = await async_convert_audio(
|
||||
data_gen = _async_convert_audio(
|
||||
self.hass,
|
||||
extension,
|
||||
data,
|
||||
data_gen,
|
||||
to_extension=final_extension,
|
||||
to_sample_rate=sample_rate,
|
||||
to_sample_channels=sample_channels,
|
||||
to_sample_bytes=sample_bytes,
|
||||
)
|
||||
|
||||
# Create file infos
|
||||
filename = f"{cache_key}.{final_extension}".lower()
|
||||
async for chunk in data_gen:
|
||||
yield chunk
|
||||
|
||||
# Validate filename
|
||||
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
||||
filename
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
||||
)
|
||||
|
||||
# Save to memory
|
||||
if final_extension == "mp3":
|
||||
data = self.write_tags(
|
||||
filename, data, engine_instance.name, message, language, options
|
||||
)
|
||||
|
||||
self._async_store_to_memcache(cache_key, final_extension, data)
|
||||
|
||||
if not cache_to_disk:
|
||||
return
|
||||
|
||||
voice_file = os.path.join(self.cache_dir, filename)
|
||||
|
||||
def save_speech() -> None:
|
||||
"""Store speech to filesystem."""
|
||||
with open(voice_file, "wb") as speech:
|
||||
speech.write(data)
|
||||
|
||||
# Don't await, we're going to do this in the background
|
||||
task = self.hass.async_add_executor_job(save_speech)
|
||||
|
||||
def write_done(future: asyncio.Future) -> None:
|
||||
"""Write is done task."""
|
||||
if err := future.exception():
|
||||
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||
else:
|
||||
self.file_cache[cache_key] = filename
|
||||
|
||||
task.add_done_callback(write_done)
|
||||
|
||||
async def _async_load_file_to_mem(self, cache_key: str) -> None:
|
||||
"""Load voice from file cache into memory.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
async def _async_load_file(self, cache_key: str) -> AsyncGenerator[bytes]:
|
||||
"""Load TTS audio from disk."""
|
||||
if not (filename := self.file_cache.get(cache_key)):
|
||||
raise HomeAssistantError(f"Key {cache_key} not in file cache!")
|
||||
|
||||
@ -899,22 +991,7 @@ class SpeechManager:
|
||||
del self.file_cache[cache_key]
|
||||
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
||||
|
||||
extension = os.path.splitext(filename)[1][1:]
|
||||
|
||||
self._async_store_to_memcache(cache_key, extension, data)
|
||||
|
||||
@callback
|
||||
def _async_store_to_memcache(
|
||||
self, cache_key: str, extension: str, data: bytes
|
||||
) -> None:
|
||||
"""Store data to memcache and set timer to remove it."""
|
||||
self.mem_cache[cache_key] = {
|
||||
"extension": extension,
|
||||
"voice": data,
|
||||
"pending": None,
|
||||
"last_used": monotonic(),
|
||||
}
|
||||
self.memcache_cleanup.schedule()
|
||||
yield data
|
||||
|
||||
@staticmethod
|
||||
def write_tags(
|
||||
|
@ -168,7 +168,7 @@ async def test_service(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
||||
@ -230,7 +230,7 @@ async def test_service_default_language(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ (
|
||||
@ -294,7 +294,7 @@ async def test_service_default_special_language(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
||||
@ -354,7 +354,7 @@ async def test_service_language(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
|
||||
@ -470,7 +470,7 @@ async def test_service_options(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ (
|
||||
@ -554,7 +554,7 @@ async def test_service_default_options(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ (
|
||||
@ -628,7 +628,7 @@ async def test_merge_default_service_options(
|
||||
assert await get_media_source_url(
|
||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||
) == ("/api/tts_proxy/test_token.mp3")
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ (
|
||||
@ -743,7 +743,7 @@ async def test_service_clear_cache(
|
||||
# To make sure the file is persisted
|
||||
assert len(calls) == 1
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert (
|
||||
mock_tts_cache_dir
|
||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
||||
@ -1769,9 +1769,15 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
|
||||
"""Test that ffmpeg failing during audio conversion will raise an error."""
|
||||
assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async def bad_data_gen():
|
||||
yield bytes(0)
|
||||
|
||||
with pytest.raises(RuntimeError): # noqa: PT012
|
||||
# Simulate a bad WAV file
|
||||
await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
|
||||
async for _chunk in tts._async_convert_audio(
|
||||
hass, "wav", bad_data_gen(), "mp3"
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
async def test_default_engine_prefer_entity(
|
||||
@ -1846,3 +1852,86 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
||||
assert stream2.extension == "wav"
|
||||
result_data = b"".join([chunk async for chunk in stream2.async_stream_result()])
|
||||
assert result_data == data
|
||||
|
||||
|
||||
async def test_tts_cache() -> None:
|
||||
"""Test TTSCache."""
|
||||
|
||||
async def data_gen(queue: asyncio.Queue[bytes | None | Exception]):
|
||||
while chunk := await queue.get():
|
||||
if isinstance(chunk, Exception):
|
||||
raise chunk
|
||||
yield chunk
|
||||
|
||||
queue = asyncio.Queue()
|
||||
cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
|
||||
assert cache.cache_key == "test-key"
|
||||
assert cache.extension == "mp3"
|
||||
|
||||
for i in range(10):
|
||||
queue.put_nowait(f"{i}".encode())
|
||||
queue.put_nowait(None)
|
||||
|
||||
assert await cache.async_load_data() == b"0123456789"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await cache.async_load_data()
|
||||
|
||||
# When data is loaded, we get it all in 1 chunk
|
||||
cur = 0
|
||||
async for chunk in cache.async_stream_data():
|
||||
assert chunk == b"0123456789"
|
||||
cur += 1
|
||||
assert cur == 1
|
||||
|
||||
# Show we can stream the data while it's still being generated
|
||||
async def consume_cache(cache: tts.TTSCache):
|
||||
return b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||
|
||||
queue = asyncio.Queue()
|
||||
cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
|
||||
|
||||
load_data_task = asyncio.create_task(cache.async_load_data())
|
||||
consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||
queue.put_nowait(b"0")
|
||||
await asyncio.sleep(0)
|
||||
queue.put_nowait(b"1")
|
||||
await asyncio.sleep(0)
|
||||
consume_mid_data_task = asyncio.create_task(consume_cache(cache))
|
||||
queue.put_nowait(b"2")
|
||||
await asyncio.sleep(0)
|
||||
queue.put_nowait(None)
|
||||
consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||
await asyncio.sleep(0)
|
||||
assert await load_data_task == b"012"
|
||||
assert await consume_post_data_loaded_task == b"012"
|
||||
assert await consume_mid_data_task == b"012"
|
||||
assert await consume_pre_data_loaded_task == b"012"
|
||||
|
||||
# Now with errors
|
||||
async def consume_cache(cache: tts.TTSCache):
|
||||
return b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||
|
||||
queue = asyncio.Queue()
|
||||
cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
|
||||
|
||||
load_data_task = asyncio.create_task(cache.async_load_data())
|
||||
consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||
queue.put_nowait(b"0")
|
||||
await asyncio.sleep(0)
|
||||
queue.put_nowait(b"1")
|
||||
await asyncio.sleep(0)
|
||||
consume_mid_data_task = asyncio.create_task(consume_cache(cache))
|
||||
queue.put_nowait(ValueError("Boom!"))
|
||||
await asyncio.sleep(0)
|
||||
queue.put_nowait(None)
|
||||
consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(ValueError):
|
||||
assert await load_data_task == b"012"
|
||||
with pytest.raises(ValueError):
|
||||
assert await consume_post_data_loaded_task == b"012"
|
||||
with pytest.raises(ValueError):
|
||||
assert await consume_mid_data_task == b"012"
|
||||
with pytest.raises(ValueError):
|
||||
assert await consume_pre_data_loaded_task == b"012"
|
||||
|
@ -150,17 +150,15 @@ async def test_get_tts_audio_connection_lost(
|
||||
hass: HomeAssistant, init_wyoming_tts
|
||||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
),
|
||||
pytest.raises(HomeAssistantError),
|
||||
stream = tts.async_create_stream(hass, "tts.test_tts", "en-US")
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
):
|
||||
await tts.async_get_media_source_audio(
|
||||
hass,
|
||||
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
|
||||
)
|
||||
stream.async_set_message("Hello world")
|
||||
with pytest.raises(HomeAssistantError): # noqa: PT012
|
||||
async for _chunk in stream.async_stream_result():
|
||||
pass
|
||||
|
||||
|
||||
async def test_get_tts_audio_audio_oserror(
|
||||
|
Loading…
x
Reference in New Issue
Block a user