Clean up Text-to-Speech (#143744)

This commit is contained in:
Paulus Schoutsen 2025-04-29 22:29:35 -04:00 committed by GitHub
parent 97084e9382
commit f980434046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 108 additions and 108 deletions

View File

@ -3,8 +3,8 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from collections.abc import AsyncGenerator, MutableMapping
from dataclasses import dataclass, field
from datetime import datetime
import hashlib
from http import HTTPStatus
@ -15,7 +15,7 @@ import os
import re
import secrets
from time import monotonic
from typing import Any, Final
from typing import Any, Final, Generic, Protocol, TypeVar
from aiohttp import web
import mutagen
@ -60,10 +60,10 @@ from .const import (
DOMAIN,
TtsAudioType,
)
from .entity import TextToSpeechEntity, TTSAudioRequest
from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse
from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs
from .media_source import generate_media_source_id, parse_media_source_id
from .models import Voice
__all__ = [
@ -79,6 +79,7 @@ __all__ = [
"Provider",
"ResultStream",
"SampleFormat",
"TTSAudioResponse",
"TextToSpeechEntity",
"TtsAudioType",
"Voice",
@ -264,7 +265,7 @@ def async_create_stream(
@callback
def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None:
    """Return a result stream given a token.

    The lookup is delegated to the TTS manager stored in ``hass.data`` so
    that the manager's own accessor is the single place that resolves tokens
    (it also refreshes the stream's ``last_used`` timestamp).

    Returns None when no stream is registered for ``token``.
    """
    # Go through the manager instead of reading ``token_to_stream`` directly;
    # a second, unreachable ``return`` used to follow the direct lookup.
    return hass.data[DATA_TTS_MANAGER].async_get_result_stream(token)
async def async_get_media_source_audio(
@ -272,12 +273,11 @@ async def async_get_media_source_audio(
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager = hass.data[DATA_TTS_MANAGER]
cache = manager.async_cache_message_in_memory(
**media_source_id_to_kwargs(media_source_id)
)
data = b"".join([chunk async for chunk in cache.async_stream_data()])
return cache.extension, data
parsed = parse_media_source_id(media_source_id)
stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"])
stream.async_set_message(parsed["message"])
data = b"".join([chunk async for chunk in stream.async_stream_result()])
return stream.extension, data
@callback
@ -457,6 +457,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
class ResultStream:
"""Class that will stream the result when available."""
last_used: float = field(default_factory=monotonic, init=False)
# Streaming/conversion properties
token: str
extension: str
@ -480,11 +482,6 @@ class ResultStream:
"""Get the future that returns the cache."""
return asyncio.Future()
@callback
def async_set_message_cache(self, cache: TTSCache) -> None:
"""Set cache containing message audio to be streamed."""
self._result_cache.set_result(cache)
@callback
def async_set_message(self, message: str) -> None:
"""Set message to be generated."""
@ -504,6 +501,8 @@ class ResultStream:
async for chunk in cache.async_stream_data():
yield chunk
self.last_used = monotonic()
def _hash_options(options: dict) -> str:
"""Hashes an options dictionary."""
@ -515,13 +514,25 @@ def _hash_options(options: dict) -> str:
return opts_hash.hexdigest()
class MemcacheCleanup:
class HasLastUsed(Protocol):
    """Protocol for objects that have a last_used attribute.

    Serves as the bound of the ``T`` type variable, so generic code
    parameterized on ``T`` may rely on every value exposing ``last_used``.
    """

    # Elsewhere in this file the attribute is assigned from ``monotonic()``.
    last_used: float
T = TypeVar("T", bound=HasLastUsed)
class DictCleaning(Generic[T]):
"""Helper to clean up the stale sessions."""
unsub: CALLBACK_TYPE | None = None
def __init__(
self, hass: HomeAssistant, maxage: float, memcache: dict[str, TTSCache]
self,
hass: HomeAssistant,
maxage: float,
memcache: MutableMapping[str, T],
) -> None:
"""Initialize the cleanup."""
self.hass = hass
@ -588,8 +599,9 @@ class SpeechManager:
self.file_cache: dict[str, str] = {}
self.mem_cache: dict[str, TTSCache] = {}
self.token_to_stream: dict[str, ResultStream] = {}
self.memcache_cleanup = MemcacheCleanup(
hass, memory_cache_maxage, self.mem_cache
self.memcache_cleanup = DictCleaning(hass, memory_cache_maxage, self.mem_cache)
self.token_to_stream_cleanup = DictCleaning(
hass, memory_cache_maxage, self.token_to_stream
)
def _init_cache(self) -> dict[str, str]:
@ -679,11 +691,21 @@ class SpeechManager:
return language, merged_options
@callback
def async_get_result_stream(
    self,
    token: str,
) -> ResultStream | None:
    """Return the result stream registered for a token.

    A successful lookup refreshes the stream's ``last_used`` timestamp with
    the current monotonic time.
    NOTE(review): presumably the timestamp is what the stale-entry cleanup
    inspects - confirm against the cleanup helper.

    Returns None when the token is unknown.
    """
    stream = self.token_to_stream.get(token)
    # Compare against None explicitly: a stream object that happens to be
    # falsy must still count as found and get its timestamp refreshed.
    if stream is not None:
        stream.last_used = monotonic()
    return stream
@callback
def async_create_result_stream(
self,
engine: str,
message: str | None = None,
use_file_cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
@ -710,23 +732,7 @@ class SpeechManager:
_manager=self,
)
self.token_to_stream[token] = result_stream
if message is None:
return result_stream
# We added this method as an alternative to stream.async_set_message
# to avoid the options being processed twice
result_stream.async_set_message_cache(
self._async_ensure_cached_in_memory(
engine=engine,
engine_instance=engine_instance,
message=message,
use_file_cache=use_file_cache,
language=language,
options=options,
)
)
self.token_to_stream_cleanup.schedule()
return result_stream
@callback
@ -734,41 +740,17 @@ class SpeechManager:
self,
engine: str,
message: str,
use_file_cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> TTSCache:
"""Make sure a message is cached in memory and returns cache key."""
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
if use_file_cache is None:
use_file_cache = self.use_file_cache
return self._async_ensure_cached_in_memory(
engine=engine,
engine_instance=engine_instance,
message=message,
use_file_cache=use_file_cache,
language=language,
options=options,
)
@callback
def _async_ensure_cached_in_memory(
self,
engine: str,
engine_instance: TextToSpeechEntity | Provider,
message: str,
use_file_cache: bool,
language: str,
options: dict,
) -> TTSCache:
"""Ensure a message is cached.
"""Make sure a message will be cached in memory and return the cache object.
Requires options, language to be processed.
"""
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
options_key = _hash_options(options) if options else "-"
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
cache_key = KEY_PATTERN.format(
@ -789,9 +771,13 @@ class SpeechManager:
store_to_disk = False
else:
_LOGGER.debug("Generating audio for %s", message[0:32])
async def message_stream() -> AsyncGenerator[str]:
yield message
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
data_gen = self._async_generate_tts_audio(
engine_instance, message, language, options
engine_instance, message_stream(), language, options
)
cache = TTSCache(
@ -799,7 +785,6 @@ class SpeechManager:
extension=extension,
data_gen=data_gen,
)
self.mem_cache[cache_key] = cache
self.hass.async_create_background_task(
self._load_data_into_cache(
@ -866,7 +851,7 @@ class SpeechManager:
async def _async_generate_tts_audio(
self,
engine_instance: TextToSpeechEntity | Provider,
message: str,
message_stream: AsyncGenerator[str],
language: str,
options: dict[str, Any],
) -> AsyncGenerator[bytes]:
@ -915,6 +900,7 @@ class SpeechManager:
raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider):
message = "".join([chunk async for chunk in message_stream])
extension, data = await engine_instance.async_get_tts_audio(
message, language, options
)
@ -930,12 +916,8 @@ class SpeechManager:
data_gen = make_data_generator(data)
else:
async def message_gen() -> AsyncGenerator[str]:
yield message
tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_gen())
TTSAudioRequest(language, options, message_stream)
)
extension = tts_result.extension
data_gen = tts_result.data_gen
@ -1096,7 +1078,6 @@ class TextToSpeechUrlView(HomeAssistantView):
try:
stream = self.manager.async_create_result_stream(
engine,
message,
use_file_cache=use_file_cache,
language=language,
options=options,
@ -1105,6 +1086,8 @@ class TextToSpeechUrlView(HomeAssistantView):
_LOGGER.error("Error on init tts: %s", err)
return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
stream.async_set_message(message)
base = get_url(self.manager.hass)
url = base + stream.url

View File

@ -69,14 +69,20 @@ class MediaSourceOptions(TypedDict):
"""Media source options."""
engine: str
message: str
language: str | None
options: dict | None
use_file_cache: bool | None
class ParsedMediaSourceId(TypedDict):
    """Parsed media source ID."""

    # Keyword arguments for creating a result stream; callers unpack this
    # mapping into ``async_create_result_stream(**parsed["options"])``.
    options: MediaSourceOptions
    # Text to speak; callers hand it to ``stream.async_set_message``.
    message: str
@callback
def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId:
"""Turn a media source ID into a message and stream options."""
parsed = URL(media_source_id)
if URL_QUERY_TTS_OPTIONS in parsed.query:
@ -94,7 +100,6 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
raise Unresolvable("No message specified.")
kwargs: MediaSourceOptions = {
"engine": parsed.name,
"message": parsed.query["message"],
"language": parsed.query.get("language"),
"options": options,
"use_file_cache": None,
@ -102,7 +107,7 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
if "cache" in parsed.query:
kwargs["use_file_cache"] = parsed.query["cache"] == "true"
return kwargs
return {"message": parsed.query["message"], "options": kwargs}
class TTSMediaSource(MediaSource):
@ -118,9 +123,11 @@ class TTSMediaSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
try:
parsed = parse_media_source_id(item.identifier)
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
**media_source_id_to_kwargs(item.identifier)
**parsed["options"]
)
stream.async_set_message(parsed["message"])
except Unresolvable:
raise
except HomeAssistantError as err:

View File

@ -42,6 +42,7 @@ from tests.typing import ClientSessionGenerator
DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test"
MOCK_DATA = b"123"
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
@ -164,7 +165,7 @@ class BaseProvider:
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS data."""
return ("mp3", b"")
return ("mp3", MOCK_DATA)
class MockTTSProvider(BaseProvider, Provider):

View File

@ -27,6 +27,7 @@ from homeassistant.util import dt as dt_util
from .common import (
DEFAULT_LANG,
MOCK_DATA,
TEST_DOMAIN,
MockResultStream,
MockTTS,
@ -808,7 +809,7 @@ async def test_service_receive_voice(
await hass.async_block_till_done()
client = await hass_client()
req = await client.get(url)
tts_data = b""
tts_data = MOCK_DATA
tts_data = tts.SpeechManager.write_tags(
f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3",
tts_data,
@ -879,7 +880,7 @@ async def test_service_receive_voice_german(
await hass.async_block_till_done()
client = await hass_client()
req = await client.get(url)
tts_data = b""
tts_data = MOCK_DATA
tts_data = tts.SpeechManager.write_tags(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3",
tts_data,
@ -1021,7 +1022,7 @@ async def test_setup_legacy_cache_dir(
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
tts_data = b""
tts_data = MOCK_DATA
cache_file = (
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
)
@ -1059,7 +1060,7 @@ async def test_setup_cache_dir(
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
tts_data = b""
tts_data = MOCK_DATA
cache_file = mock_tts_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
)
@ -1165,7 +1166,7 @@ async def test_legacy_cannot_retrieve_without_token(
hass_client: ClientSessionGenerator,
) -> None:
"""Verify that a TTS cannot be retrieved by filename directly."""
tts_data = b""
tts_data = MOCK_DATA
cache_file = (
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
)
@ -1188,7 +1189,7 @@ async def test_cannot_retrieve_without_token(
hass_client: ClientSessionGenerator,
) -> None:
"""Verify that a TTS cannot be retrieved by filename directly."""
tts_data = b""
tts_data = MOCK_DATA
cache_file = mock_tts_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
)
@ -1845,6 +1846,9 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
assert stream.language == mock_tts_entity.default_language
assert stream.options == (mock_tts_entity.default_options or {})
assert tts.async_get_stream(hass, stream.token) is stream
stream.async_set_message("beer")
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
assert result_data == MOCK_DATA
data = b"beer"
stream2 = MockResultStream(hass, "wav", data)

View File

@ -9,9 +9,8 @@ import pytest
from homeassistant.components import media_source
from homeassistant.components.media_player import BrowseError
from homeassistant.components.tts.media_source import (
MediaSourceOptions,
generate_media_source_id,
media_source_id_to_kwargs,
parse_media_source_id,
)
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -249,13 +248,13 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) ->
],
indirect=["setup"],
)
async def test_generate_media_source_id_and_media_source_id_to_kwargs(
async def test_generate_media_source_id_and_parse_media_source_id(
hass: HomeAssistant,
setup: str,
result_engine: str,
) -> None:
"""Test media_source_id and media_source_id_to_kwargs."""
kwargs: MediaSourceOptions = {
"""Test media_source_id and parse_media_source_id."""
kwargs = {
"engine": None,
"message": "hello",
"language": "en_US",
@ -263,12 +262,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
"cache": True,
}
media_source_id = generate_media_source_id(hass, **kwargs)
assert media_source_id_to_kwargs(media_source_id) == {
"engine": result_engine,
assert parse_media_source_id(media_source_id) == {
"message": "hello",
"language": "en_US",
"options": {"age": 5},
"use_file_cache": True,
"options": {
"engine": result_engine,
"language": "en_US",
"options": {"age": 5},
"use_file_cache": True,
},
}
kwargs = {
@ -279,12 +280,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
"cache": True,
}
media_source_id = generate_media_source_id(hass, **kwargs)
assert media_source_id_to_kwargs(media_source_id) == {
"engine": result_engine,
assert parse_media_source_id(media_source_id) == {
"message": "hello",
"language": "en_US",
"options": {"age": [5, 6]},
"use_file_cache": True,
"options": {
"engine": result_engine,
"language": "en_US",
"options": {"age": [5, 6]},
"use_file_cache": True,
},
}
kwargs = {
@ -295,10 +298,12 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
"cache": True,
}
media_source_id = generate_media_source_id(hass, **kwargs)
assert media_source_id_to_kwargs(media_source_id) == {
"engine": result_engine,
assert parse_media_source_id(media_source_id) == {
"message": "hello",
"language": "en_US",
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
"use_file_cache": True,
"options": {
"engine": result_engine,
"language": "en_US",
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
"use_file_cache": True,
},
}