mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Switch more TTS core to async generators (#140432)
* Switch more TTS core to async generators
* Document a design choice
* robust
* Add more tests
* Update comment
* Clarify and document TTSCache variables
This commit is contained in:
parent
b07ac301b9
commit
55895df54d
@ -17,7 +17,7 @@ import secrets
|
|||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
from typing import Any, Final, TypedDict
|
from typing import Any, Final
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import mutagen
|
import mutagen
|
||||||
@ -123,13 +123,94 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
|||||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||||
|
|
||||||
|
|
||||||
class TTSCache(TypedDict):
|
class TTSCache:
|
||||||
"""Cached TTS file."""
|
"""Cached bytes of a TTS result."""
|
||||||
|
|
||||||
extension: str
|
_result_data: bytes | None = None
|
||||||
voice: bytes
|
"""When fully loaded, contains the result data."""
|
||||||
pending: asyncio.Task | None
|
|
||||||
last_used: float
|
_partial_data: list[bytes] | None = None
|
||||||
|
"""While loading, contains the data already received from the generator."""
|
||||||
|
|
||||||
|
_loading_error: Exception | None = None
|
||||||
|
"""If an error occurred while loading, contains the error."""
|
||||||
|
|
||||||
|
_consumers: list[asyncio.Queue[bytes | None]] | None = None
|
||||||
|
"""A queue for each current consumer to notify of new data while the generator is loading."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cache_key: str,
|
||||||
|
extension: str,
|
||||||
|
data_gen: AsyncGenerator[bytes],
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the TTS cache."""
|
||||||
|
self.cache_key = cache_key
|
||||||
|
self.extension = extension
|
||||||
|
self.last_used = monotonic()
|
||||||
|
self._data_gen = data_gen
|
||||||
|
|
||||||
|
async def async_load_data(self) -> bytes:
|
||||||
|
"""Load the data from the generator."""
|
||||||
|
if self._result_data is not None or self._partial_data is not None:
|
||||||
|
raise RuntimeError("Data already being loaded")
|
||||||
|
|
||||||
|
self._partial_data = []
|
||||||
|
self._consumers = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in self._data_gen:
|
||||||
|
self._partial_data.append(chunk)
|
||||||
|
for queue in self._consumers:
|
||||||
|
queue.put_nowait(chunk)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
self._loading_error = err
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
for queue in self._consumers:
|
||||||
|
queue.put_nowait(None)
|
||||||
|
self._consumers = None
|
||||||
|
|
||||||
|
self._result_data = b"".join(self._partial_data)
|
||||||
|
self._partial_data = None
|
||||||
|
return self._result_data
|
||||||
|
|
||||||
|
async def async_stream_data(self) -> AsyncGenerator[bytes]:
|
||||||
|
"""Stream the data.
|
||||||
|
|
||||||
|
Will return all data already returned from the generator.
|
||||||
|
Will listen for future data returned from the generator.
|
||||||
|
Raises error if one occurred.
|
||||||
|
"""
|
||||||
|
if self._result_data is not None:
|
||||||
|
yield self._result_data
|
||||||
|
return
|
||||||
|
if self._loading_error:
|
||||||
|
raise self._loading_error
|
||||||
|
|
||||||
|
if self._partial_data is None:
|
||||||
|
raise RuntimeError("Data not being loaded")
|
||||||
|
|
||||||
|
queue: asyncio.Queue[bytes | None] | None = None
|
||||||
|
# Check if generator is still feeding data
|
||||||
|
if self._consumers is not None:
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
self._consumers.append(queue)
|
||||||
|
|
||||||
|
for chunk in list(self._partial_data):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if self._loading_error:
|
||||||
|
raise self._loading_error
|
||||||
|
|
||||||
|
if queue is not None:
|
||||||
|
while (chunk2 := await queue.get()) is not None:
|
||||||
|
yield chunk2
|
||||||
|
|
||||||
|
if self._loading_error:
|
||||||
|
raise self._loading_error
|
||||||
|
|
||||||
|
self.last_used = monotonic()
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -194,10 +275,11 @@ async def async_get_media_source_audio(
|
|||||||
) -> tuple[str, bytes]:
|
) -> tuple[str, bytes]:
|
||||||
"""Get TTS audio as extension, data."""
|
"""Get TTS audio as extension, data."""
|
||||||
manager = hass.data[DATA_TTS_MANAGER]
|
manager = hass.data[DATA_TTS_MANAGER]
|
||||||
cache_key = manager.async_cache_message_in_memory(
|
cache = manager.async_cache_message_in_memory(
|
||||||
**media_source_id_to_kwargs(media_source_id)
|
**media_source_id_to_kwargs(media_source_id)
|
||||||
)
|
)
|
||||||
return await manager.async_get_tts_audio(cache_key)
|
data = b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||||
|
return cache.extension, data
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -216,18 +298,19 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
|||||||
return languages
|
return languages
|
||||||
|
|
||||||
|
|
||||||
async def async_convert_audio(
|
async def _async_convert_audio(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
from_extension: str,
|
from_extension: str,
|
||||||
audio_bytes: bytes,
|
audio_bytes_gen: AsyncGenerator[bytes],
|
||||||
to_extension: str,
|
to_extension: str,
|
||||||
to_sample_rate: int | None = None,
|
to_sample_rate: int | None = None,
|
||||||
to_sample_channels: int | None = None,
|
to_sample_channels: int | None = None,
|
||||||
to_sample_bytes: int | None = None,
|
to_sample_bytes: int | None = None,
|
||||||
) -> bytes:
|
) -> AsyncGenerator[bytes]:
|
||||||
"""Convert audio to a preferred format using ffmpeg."""
|
"""Convert audio to a preferred format using ffmpeg."""
|
||||||
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
||||||
return await hass.async_add_executor_job(
|
audio_bytes = b"".join([chunk async for chunk in audio_bytes_gen])
|
||||||
|
data = await hass.async_add_executor_job(
|
||||||
lambda: _convert_audio(
|
lambda: _convert_audio(
|
||||||
ffmpeg_manager.binary,
|
ffmpeg_manager.binary,
|
||||||
from_extension,
|
from_extension,
|
||||||
@ -238,6 +321,7 @@ async def async_convert_audio(
|
|||||||
to_sample_bytes=to_sample_bytes,
|
to_sample_bytes=to_sample_bytes,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
yield data
|
||||||
|
|
||||||
|
|
||||||
def _convert_audio(
|
def _convert_audio(
|
||||||
@ -401,32 +485,33 @@ class ResultStream:
|
|||||||
return f"/api/tts_proxy/{self.token}"
|
return f"/api/tts_proxy/{self.token}"
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _result_cache_key(self) -> asyncio.Future[str]:
|
def _result_cache(self) -> asyncio.Future[TTSCache]:
|
||||||
"""Get the future that returns the cache key."""
|
"""Get the future that returns the cache."""
|
||||||
return asyncio.Future()
|
return asyncio.Future()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_message_cache_key(self, cache_key: str) -> None:
|
def async_set_message_cache(self, cache: TTSCache) -> None:
|
||||||
"""Set cache key for message to be streamed."""
|
"""Set cache containing message audio to be streamed."""
|
||||||
self._result_cache_key.set_result(cache_key)
|
self._result_cache.set_result(cache)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_message(self, message: str) -> None:
|
def async_set_message(self, message: str) -> None:
|
||||||
"""Set message to be generated."""
|
"""Set message to be generated."""
|
||||||
cache_key = self._manager.async_cache_message_in_memory(
|
self._result_cache.set_result(
|
||||||
engine=self.engine,
|
self._manager.async_cache_message_in_memory(
|
||||||
message=message,
|
engine=self.engine,
|
||||||
use_file_cache=self.use_file_cache,
|
message=message,
|
||||||
language=self.language,
|
use_file_cache=self.use_file_cache,
|
||||||
options=self.options,
|
language=self.language,
|
||||||
|
options=self.options,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self._result_cache_key.set_result(cache_key)
|
|
||||||
|
|
||||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||||
"""Get the stream of this result."""
|
"""Get the stream of this result."""
|
||||||
cache_key = await self._result_cache_key
|
cache = await self._result_cache
|
||||||
_extension, data = await self._manager.async_get_tts_audio(cache_key)
|
async for chunk in cache.async_stream_data():
|
||||||
yield data
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
def _hash_options(options: dict) -> str:
|
def _hash_options(options: dict) -> str:
|
||||||
@ -483,7 +568,7 @@ class MemcacheCleanup:
|
|||||||
now = monotonic()
|
now = monotonic()
|
||||||
|
|
||||||
for cache_key, info in list(memcache.items()):
|
for cache_key, info in list(memcache.items()):
|
||||||
if info["last_used"] + maxage < now:
|
if info.last_used + maxage < now:
|
||||||
_LOGGER.debug("Cleaning up %s", cache_key)
|
_LOGGER.debug("Cleaning up %s", cache_key)
|
||||||
del memcache[cache_key]
|
del memcache[cache_key]
|
||||||
|
|
||||||
@ -638,15 +723,18 @@ class SpeechManager:
|
|||||||
if message is None:
|
if message is None:
|
||||||
return result_stream
|
return result_stream
|
||||||
|
|
||||||
cache_key = self._async_ensure_cached_in_memory(
|
# We added this method as an alternative to stream.async_set_message
|
||||||
engine=engine,
|
# to avoid the options being processed twice
|
||||||
engine_instance=engine_instance,
|
result_stream.async_set_message_cache(
|
||||||
message=message,
|
self._async_ensure_cached_in_memory(
|
||||||
use_file_cache=use_file_cache,
|
engine=engine,
|
||||||
language=language,
|
engine_instance=engine_instance,
|
||||||
options=options,
|
message=message,
|
||||||
|
use_file_cache=use_file_cache,
|
||||||
|
language=language,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result_stream.async_set_message_cache_key(cache_key)
|
|
||||||
|
|
||||||
return result_stream
|
return result_stream
|
||||||
|
|
||||||
@ -658,7 +746,7 @@ class SpeechManager:
|
|||||||
use_file_cache: bool | None = None,
|
use_file_cache: bool | None = None,
|
||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
options: dict | None = None,
|
options: dict | None = None,
|
||||||
) -> str:
|
) -> TTSCache:
|
||||||
"""Make sure a message is cached in memory and returns cache key."""
|
"""Make sure a message is cached in memory and returns cache key."""
|
||||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||||
raise HomeAssistantError(f"Provider {engine} not found")
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
@ -685,7 +773,7 @@ class SpeechManager:
|
|||||||
use_file_cache: bool,
|
use_file_cache: bool,
|
||||||
language: str,
|
language: str,
|
||||||
options: dict,
|
options: dict,
|
||||||
) -> str:
|
) -> TTSCache:
|
||||||
"""Ensure a message is cached.
|
"""Ensure a message is cached.
|
||||||
|
|
||||||
Requires options, language to be processed.
|
Requires options, language to be processed.
|
||||||
@ -697,62 +785,101 @@ class SpeechManager:
|
|||||||
).lower()
|
).lower()
|
||||||
|
|
||||||
# Is speech already in memory
|
# Is speech already in memory
|
||||||
if cache_key in self.mem_cache:
|
if cache := self.mem_cache.get(cache_key):
|
||||||
return cache_key
|
_LOGGER.debug("Found audio in cache for %s", message[0:32])
|
||||||
|
return cache
|
||||||
|
|
||||||
if use_file_cache and cache_key in self.file_cache:
|
store_to_disk = use_file_cache
|
||||||
coro = self._async_load_file_to_mem(cache_key)
|
|
||||||
|
if use_file_cache and (filename := self.file_cache.get(cache_key)):
|
||||||
|
_LOGGER.debug("Loading audio from disk for %s", message[0:32])
|
||||||
|
extension = os.path.splitext(filename)[1][1:]
|
||||||
|
data_gen = self._async_load_file(cache_key)
|
||||||
|
store_to_disk = False
|
||||||
else:
|
else:
|
||||||
coro = self._async_generate_tts_audio(
|
_LOGGER.debug("Generating audio for %s", message[0:32])
|
||||||
engine_instance, cache_key, message, use_file_cache, language, options
|
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||||
|
data_gen = self._async_generate_tts_audio(
|
||||||
|
engine_instance, message, language, options
|
||||||
)
|
)
|
||||||
|
|
||||||
task = self.hass.async_create_task(coro, eager_start=False)
|
cache = TTSCache(
|
||||||
|
cache_key=cache_key,
|
||||||
|
extension=extension,
|
||||||
|
data_gen=data_gen,
|
||||||
|
)
|
||||||
|
|
||||||
def handle_error(future: asyncio.Future) -> None:
|
self.mem_cache[cache_key] = cache
|
||||||
"""Handle error."""
|
self.hass.async_create_background_task(
|
||||||
if not (err := future.exception()):
|
self._load_data_into_cache(
|
||||||
return
|
cache, engine_instance, message, store_to_disk, language, options
|
||||||
|
),
|
||||||
|
f"tts_load_data_into_cache_{engine_instance.name}",
|
||||||
|
)
|
||||||
|
self.memcache_cleanup.schedule()
|
||||||
|
return cache
|
||||||
|
|
||||||
|
async def _load_data_into_cache(
|
||||||
|
self,
|
||||||
|
cache: TTSCache,
|
||||||
|
engine_instance: TextToSpeechEntity | Provider,
|
||||||
|
message: str,
|
||||||
|
store_to_disk: bool,
|
||||||
|
language: str,
|
||||||
|
options: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Load and process a finished loading TTS Cache."""
|
||||||
|
try:
|
||||||
|
data = await cache.async_load_data()
|
||||||
|
except Exception as err: # pylint: disable=broad-except # noqa: BLE001
|
||||||
# Truncate message so we don't flood the logs. Cutting off at 32 chars
|
# Truncate message so we don't flood the logs. Cutting off at 32 chars
|
||||||
# but since we add 3 dots to truncated message, we cut off at 35.
|
# but since we add 3 dots to truncated message, we cut off at 35.
|
||||||
trunc_msg = message if len(message) < 35 else f"{message[0:32]}…"
|
trunc_msg = message if len(message) < 35 else f"{message[0:32]}…"
|
||||||
_LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
|
_LOGGER.error("Error getting audio for %s: %s", trunc_msg, err)
|
||||||
self.mem_cache.pop(cache_key, None)
|
self.mem_cache.pop(cache.cache_key, None)
|
||||||
|
return
|
||||||
|
|
||||||
task.add_done_callback(handle_error)
|
if not store_to_disk:
|
||||||
|
return
|
||||||
|
|
||||||
self.mem_cache[cache_key] = {
|
filename = f"{cache.cache_key}.{cache.extension}".lower()
|
||||||
"extension": "",
|
|
||||||
"voice": b"",
|
|
||||||
"pending": task,
|
|
||||||
"last_used": monotonic(),
|
|
||||||
}
|
|
||||||
return cache_key
|
|
||||||
|
|
||||||
async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
|
# Validate filename
|
||||||
"""Fetch TTS audio."""
|
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
||||||
cached = self.mem_cache.get(cache_key)
|
filename
|
||||||
if cached is None:
|
):
|
||||||
raise HomeAssistantError("Audio not cached")
|
raise HomeAssistantError(
|
||||||
if pending := cached.get("pending"):
|
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
||||||
await pending
|
)
|
||||||
cached = self.mem_cache[cache_key]
|
|
||||||
cached["last_used"] = monotonic()
|
if cache.extension == "mp3":
|
||||||
return cached["extension"], cached["voice"]
|
name = (
|
||||||
|
engine_instance.name if isinstance(engine_instance.name, str) else "-"
|
||||||
|
)
|
||||||
|
data = self.write_tags(filename, data, name, message, language, options)
|
||||||
|
|
||||||
|
voice_file = os.path.join(self.cache_dir, filename)
|
||||||
|
|
||||||
|
def save_speech() -> None:
|
||||||
|
"""Store speech to filesystem."""
|
||||||
|
with open(voice_file, "wb") as speech:
|
||||||
|
speech.write(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.hass.async_add_executor_job(save_speech)
|
||||||
|
except OSError as err:
|
||||||
|
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||||
|
else:
|
||||||
|
self.file_cache[cache.cache_key] = filename
|
||||||
|
|
||||||
async def _async_generate_tts_audio(
|
async def _async_generate_tts_audio(
|
||||||
self,
|
self,
|
||||||
engine_instance: TextToSpeechEntity | Provider,
|
engine_instance: TextToSpeechEntity | Provider,
|
||||||
cache_key: str,
|
|
||||||
message: str,
|
message: str,
|
||||||
cache_to_disk: bool,
|
|
||||||
language: str,
|
language: str,
|
||||||
options: dict[str, Any],
|
options: dict[str, Any],
|
||||||
) -> None:
|
) -> AsyncGenerator[bytes]:
|
||||||
"""Start loading of the TTS audio.
|
"""Generate TTS audio from an engine."""
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
options = dict(options or {})
|
options = dict(options or {})
|
||||||
supported_options = engine_instance.supported_options or []
|
supported_options = engine_instance.supported_options or []
|
||||||
|
|
||||||
@ -800,6 +927,17 @@ class SpeechManager:
|
|||||||
extension, data = await engine_instance.async_get_tts_audio(
|
extension, data = await engine_instance.async_get_tts_audio(
|
||||||
message, language, options
|
message, language, options
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data is None or extension is None:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"No TTS from {engine_instance.name} for '{message}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def make_data_generator(data: bytes) -> AsyncGenerator[bytes]:
|
||||||
|
yield data
|
||||||
|
|
||||||
|
data_gen = make_data_generator(data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def message_gen() -> AsyncGenerator[str]:
|
async def message_gen() -> AsyncGenerator[str]:
|
||||||
@ -809,12 +947,7 @@ class SpeechManager:
|
|||||||
TTSAudioRequest(language, options, message_gen())
|
TTSAudioRequest(language, options, message_gen())
|
||||||
)
|
)
|
||||||
extension = tts_result.extension
|
extension = tts_result.extension
|
||||||
data = b"".join([chunk async for chunk in tts_result.data_gen])
|
data_gen = tts_result.data_gen
|
||||||
|
|
||||||
if data is None or extension is None:
|
|
||||||
raise HomeAssistantError(
|
|
||||||
f"No TTS from {engine_instance.name} for '{message}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only convert if we have a preferred format different than the
|
# Only convert if we have a preferred format different than the
|
||||||
# expected format from the TTS system, or if a specific sample
|
# expected format from the TTS system, or if a specific sample
|
||||||
@ -827,62 +960,21 @@ class SpeechManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if needs_conversion:
|
if needs_conversion:
|
||||||
data = await async_convert_audio(
|
data_gen = _async_convert_audio(
|
||||||
self.hass,
|
self.hass,
|
||||||
extension,
|
extension,
|
||||||
data,
|
data_gen,
|
||||||
to_extension=final_extension,
|
to_extension=final_extension,
|
||||||
to_sample_rate=sample_rate,
|
to_sample_rate=sample_rate,
|
||||||
to_sample_channels=sample_channels,
|
to_sample_channels=sample_channels,
|
||||||
to_sample_bytes=sample_bytes,
|
to_sample_bytes=sample_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create file infos
|
async for chunk in data_gen:
|
||||||
filename = f"{cache_key}.{final_extension}".lower()
|
yield chunk
|
||||||
|
|
||||||
# Validate filename
|
async def _async_load_file(self, cache_key: str) -> AsyncGenerator[bytes]:
|
||||||
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
"""Load TTS audio from disk."""
|
||||||
filename
|
|
||||||
):
|
|
||||||
raise HomeAssistantError(
|
|
||||||
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save to memory
|
|
||||||
if final_extension == "mp3":
|
|
||||||
data = self.write_tags(
|
|
||||||
filename, data, engine_instance.name, message, language, options
|
|
||||||
)
|
|
||||||
|
|
||||||
self._async_store_to_memcache(cache_key, final_extension, data)
|
|
||||||
|
|
||||||
if not cache_to_disk:
|
|
||||||
return
|
|
||||||
|
|
||||||
voice_file = os.path.join(self.cache_dir, filename)
|
|
||||||
|
|
||||||
def save_speech() -> None:
|
|
||||||
"""Store speech to filesystem."""
|
|
||||||
with open(voice_file, "wb") as speech:
|
|
||||||
speech.write(data)
|
|
||||||
|
|
||||||
# Don't await, we're going to do this in the background
|
|
||||||
task = self.hass.async_add_executor_job(save_speech)
|
|
||||||
|
|
||||||
def write_done(future: asyncio.Future) -> None:
|
|
||||||
"""Write is done task."""
|
|
||||||
if err := future.exception():
|
|
||||||
_LOGGER.error("Can't write %s: %s", filename, err)
|
|
||||||
else:
|
|
||||||
self.file_cache[cache_key] = filename
|
|
||||||
|
|
||||||
task.add_done_callback(write_done)
|
|
||||||
|
|
||||||
async def _async_load_file_to_mem(self, cache_key: str) -> None:
|
|
||||||
"""Load voice from file cache into memory.
|
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
if not (filename := self.file_cache.get(cache_key)):
|
if not (filename := self.file_cache.get(cache_key)):
|
||||||
raise HomeAssistantError(f"Key {cache_key} not in file cache!")
|
raise HomeAssistantError(f"Key {cache_key} not in file cache!")
|
||||||
|
|
||||||
@ -899,22 +991,7 @@ class SpeechManager:
|
|||||||
del self.file_cache[cache_key]
|
del self.file_cache[cache_key]
|
||||||
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
||||||
|
|
||||||
extension = os.path.splitext(filename)[1][1:]
|
yield data
|
||||||
|
|
||||||
self._async_store_to_memcache(cache_key, extension, data)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_store_to_memcache(
|
|
||||||
self, cache_key: str, extension: str, data: bytes
|
|
||||||
) -> None:
|
|
||||||
"""Store data to memcache and set timer to remove it."""
|
|
||||||
self.mem_cache[cache_key] = {
|
|
||||||
"extension": extension,
|
|
||||||
"voice": data,
|
|
||||||
"pending": None,
|
|
||||||
"last_used": monotonic(),
|
|
||||||
}
|
|
||||||
self.memcache_cleanup.schedule()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_tags(
|
def write_tags(
|
||||||
|
@ -168,7 +168,7 @@ async def test_service(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
||||||
@ -230,7 +230,7 @@ async def test_service_default_language(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ (
|
/ (
|
||||||
@ -294,7 +294,7 @@ async def test_service_default_special_language(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
||||||
@ -354,7 +354,7 @@ async def test_service_language(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
|
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
|
||||||
@ -470,7 +470,7 @@ async def test_service_options(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ (
|
/ (
|
||||||
@ -554,7 +554,7 @@ async def test_service_default_options(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ (
|
/ (
|
||||||
@ -628,7 +628,7 @@ async def test_merge_default_service_options(
|
|||||||
assert await get_media_source_url(
|
assert await get_media_source_url(
|
||||||
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
) == ("/api/tts_proxy/test_token.mp3")
|
) == ("/api/tts_proxy/test_token.mp3")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ (
|
/ (
|
||||||
@ -743,7 +743,7 @@ async def test_service_clear_cache(
|
|||||||
# To make sure the file is persisted
|
# To make sure the file is persisted
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done(wait_background_tasks=True)
|
||||||
assert (
|
assert (
|
||||||
mock_tts_cache_dir
|
mock_tts_cache_dir
|
||||||
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
|
||||||
@ -1769,9 +1769,15 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
|
|||||||
"""Test that ffmpeg failing during audio conversion will raise an error."""
|
"""Test that ffmpeg failing during audio conversion will raise an error."""
|
||||||
assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
|
assert await async_setup_component(hass, ffmpeg.DOMAIN, {})
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
async def bad_data_gen():
|
||||||
|
yield bytes(0)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError): # noqa: PT012
|
||||||
# Simulate a bad WAV file
|
# Simulate a bad WAV file
|
||||||
await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
|
async for _chunk in tts._async_convert_audio(
|
||||||
|
hass, "wav", bad_data_gen(), "mp3"
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def test_default_engine_prefer_entity(
|
async def test_default_engine_prefer_entity(
|
||||||
@ -1846,3 +1852,86 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
|||||||
assert stream2.extension == "wav"
|
assert stream2.extension == "wav"
|
||||||
result_data = b"".join([chunk async for chunk in stream2.async_stream_result()])
|
result_data = b"".join([chunk async for chunk in stream2.async_stream_result()])
|
||||||
assert result_data == data
|
assert result_data == data
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_cache() -> None:
|
||||||
|
"""Test TTSCache."""
|
||||||
|
|
||||||
|
async def data_gen(queue: asyncio.Queue[bytes | None | Exception]):
|
||||||
|
while chunk := await queue.get():
|
||||||
|
if isinstance(chunk, Exception):
|
||||||
|
raise chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
|
||||||
|
assert cache.cache_key == "test-key"
|
||||||
|
assert cache.extension == "mp3"
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
queue.put_nowait(f"{i}".encode())
|
||||||
|
queue.put_nowait(None)
|
||||||
|
|
||||||
|
assert await cache.async_load_data() == b"0123456789"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await cache.async_load_data()
|
||||||
|
|
||||||
|
# When data is loaded, we get it all in 1 chunk
|
||||||
|
cur = 0
|
||||||
|
async for chunk in cache.async_stream_data():
|
||||||
|
assert chunk == b"0123456789"
|
||||||
|
cur += 1
|
||||||
|
assert cur == 1
|
||||||
|
|
||||||
|
# Show we can stream the data while it's still being generated
|
||||||
|
async def consume_cache(cache: tts.TTSCache):
|
||||||
|
return b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||||
|
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
|
||||||
|
|
||||||
|
load_data_task = asyncio.create_task(cache.async_load_data())
|
||||||
|
consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||||
|
queue.put_nowait(b"0")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
queue.put_nowait(b"1")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
consume_mid_data_task = asyncio.create_task(consume_cache(cache))
|
||||||
|
queue.put_nowait(b"2")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
queue.put_nowait(None)
|
||||||
|
consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert await load_data_task == b"012"
|
||||||
|
assert await consume_post_data_loaded_task == b"012"
|
||||||
|
assert await consume_mid_data_task == b"012"
|
||||||
|
assert await consume_pre_data_loaded_task == b"012"
|
||||||
|
|
||||||
|
# Now with errors
|
||||||
|
async def consume_cache(cache: tts.TTSCache):
|
||||||
|
return b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||||
|
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
cache = tts.TTSCache("test-key", "mp3", data_gen(queue))
|
||||||
|
|
||||||
|
load_data_task = asyncio.create_task(cache.async_load_data())
|
||||||
|
consume_pre_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||||
|
queue.put_nowait(b"0")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
queue.put_nowait(b"1")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
consume_mid_data_task = asyncio.create_task(consume_cache(cache))
|
||||||
|
queue.put_nowait(ValueError("Boom!"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
queue.put_nowait(None)
|
||||||
|
consume_post_data_loaded_task = asyncio.create_task(consume_cache(cache))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert await load_data_task == b"012"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert await consume_post_data_loaded_task == b"012"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert await consume_mid_data_task == b"012"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert await consume_pre_data_loaded_task == b"012"
|
||||||
|
@ -150,17 +150,15 @@ async def test_get_tts_audio_connection_lost(
|
|||||||
hass: HomeAssistant, init_wyoming_tts
|
hass: HomeAssistant, init_wyoming_tts
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming audio and losing connection."""
|
"""Test streaming audio and losing connection."""
|
||||||
with (
|
stream = tts.async_create_stream(hass, "tts.test_tts", "en-US")
|
||||||
patch(
|
with patch(
|
||||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
MockAsyncTcpClient([None]),
|
MockAsyncTcpClient([None]),
|
||||||
),
|
|
||||||
pytest.raises(HomeAssistantError),
|
|
||||||
):
|
):
|
||||||
await tts.async_get_media_source_audio(
|
stream.async_set_message("Hello world")
|
||||||
hass,
|
with pytest.raises(HomeAssistantError): # noqa: PT012
|
||||||
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
|
async for _chunk in stream.async_stream_result():
|
||||||
)
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def test_get_tts_audio_audio_oserror(
|
async def test_get_tts_audio_audio_oserror(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user