mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 05:37:44 +00:00
Clean up Text-to-Speech (#143744)
This commit is contained in:
parent
97084e9382
commit
f980434046
@ -3,8 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import AsyncGenerator, MutableMapping
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from http import HTTPStatus
|
||||
@ -15,7 +15,7 @@ import os
|
||||
import re
|
||||
import secrets
|
||||
from time import monotonic
|
||||
from typing import Any, Final
|
||||
from typing import Any, Final, Generic, Protocol, TypeVar
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
@ -60,10 +60,10 @@ from .const import (
|
||||
DOMAIN,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .entity import TextToSpeechEntity, TTSAudioRequest
|
||||
from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse
|
||||
from .helper import get_engine_instance
|
||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||
from .media_source import generate_media_source_id, parse_media_source_id
|
||||
from .models import Voice
|
||||
|
||||
__all__ = [
|
||||
@ -79,6 +79,7 @@ __all__ = [
|
||||
"Provider",
|
||||
"ResultStream",
|
||||
"SampleFormat",
|
||||
"TTSAudioResponse",
|
||||
"TextToSpeechEntity",
|
||||
"TtsAudioType",
|
||||
"Voice",
|
||||
@ -264,7 +265,7 @@ def async_create_stream(
|
||||
@callback
|
||||
def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None:
|
||||
"""Return a result stream given a token."""
|
||||
return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token)
|
||||
return hass.data[DATA_TTS_MANAGER].async_get_result_stream(token)
|
||||
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
@ -272,12 +273,11 @@ async def async_get_media_source_audio(
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Get TTS audio as extension, data."""
|
||||
manager = hass.data[DATA_TTS_MANAGER]
|
||||
cache = manager.async_cache_message_in_memory(
|
||||
**media_source_id_to_kwargs(media_source_id)
|
||||
)
|
||||
data = b"".join([chunk async for chunk in cache.async_stream_data()])
|
||||
return cache.extension, data
|
||||
parsed = parse_media_source_id(media_source_id)
|
||||
stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"])
|
||||
stream.async_set_message(parsed["message"])
|
||||
data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
return stream.extension, data
|
||||
|
||||
|
||||
@callback
|
||||
@ -457,6 +457,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
class ResultStream:
|
||||
"""Class that will stream the result when available."""
|
||||
|
||||
last_used: float = field(default_factory=monotonic, init=False)
|
||||
|
||||
# Streaming/conversion properties
|
||||
token: str
|
||||
extension: str
|
||||
@ -480,11 +482,6 @@ class ResultStream:
|
||||
"""Get the future that returns the cache."""
|
||||
return asyncio.Future()
|
||||
|
||||
@callback
|
||||
def async_set_message_cache(self, cache: TTSCache) -> None:
|
||||
"""Set cache containing message audio to be streamed."""
|
||||
self._result_cache.set_result(cache)
|
||||
|
||||
@callback
|
||||
def async_set_message(self, message: str) -> None:
|
||||
"""Set message to be generated."""
|
||||
@ -504,6 +501,8 @@ class ResultStream:
|
||||
async for chunk in cache.async_stream_data():
|
||||
yield chunk
|
||||
|
||||
self.last_used = monotonic()
|
||||
|
||||
|
||||
def _hash_options(options: dict) -> str:
|
||||
"""Hashes an options dictionary."""
|
||||
@ -515,13 +514,25 @@ def _hash_options(options: dict) -> str:
|
||||
return opts_hash.hexdigest()
|
||||
|
||||
|
||||
class MemcacheCleanup:
|
||||
class HasLastUsed(Protocol):
|
||||
"""Protocol for objects that have a last_used attribute."""
|
||||
|
||||
last_used: float
|
||||
|
||||
|
||||
T = TypeVar("T", bound=HasLastUsed)
|
||||
|
||||
|
||||
class DictCleaning(Generic[T]):
|
||||
"""Helper to clean up the stale sessions."""
|
||||
|
||||
unsub: CALLBACK_TYPE | None = None
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, maxage: float, memcache: dict[str, TTSCache]
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
maxage: float,
|
||||
memcache: MutableMapping[str, T],
|
||||
) -> None:
|
||||
"""Initialize the cleanup."""
|
||||
self.hass = hass
|
||||
@ -588,8 +599,9 @@ class SpeechManager:
|
||||
self.file_cache: dict[str, str] = {}
|
||||
self.mem_cache: dict[str, TTSCache] = {}
|
||||
self.token_to_stream: dict[str, ResultStream] = {}
|
||||
self.memcache_cleanup = MemcacheCleanup(
|
||||
hass, memory_cache_maxage, self.mem_cache
|
||||
self.memcache_cleanup = DictCleaning(hass, memory_cache_maxage, self.mem_cache)
|
||||
self.token_to_stream_cleanup = DictCleaning(
|
||||
hass, memory_cache_maxage, self.token_to_stream
|
||||
)
|
||||
|
||||
def _init_cache(self) -> dict[str, str]:
|
||||
@ -679,11 +691,21 @@ class SpeechManager:
|
||||
|
||||
return language, merged_options
|
||||
|
||||
@callback
|
||||
def async_get_result_stream(
|
||||
self,
|
||||
token: str,
|
||||
) -> ResultStream | None:
|
||||
"""Return a result stream given a token."""
|
||||
stream = self.token_to_stream.get(token, None)
|
||||
if stream:
|
||||
stream.last_used = monotonic()
|
||||
return stream
|
||||
|
||||
@callback
|
||||
def async_create_result_stream(
|
||||
self,
|
||||
engine: str,
|
||||
message: str | None = None,
|
||||
use_file_cache: bool | None = None,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
@ -710,23 +732,7 @@ class SpeechManager:
|
||||
_manager=self,
|
||||
)
|
||||
self.token_to_stream[token] = result_stream
|
||||
|
||||
if message is None:
|
||||
return result_stream
|
||||
|
||||
# We added this method as an alternative to stream.async_set_message
|
||||
# to avoid the options being processed twice
|
||||
result_stream.async_set_message_cache(
|
||||
self._async_ensure_cached_in_memory(
|
||||
engine=engine,
|
||||
engine_instance=engine_instance,
|
||||
message=message,
|
||||
use_file_cache=use_file_cache,
|
||||
language=language,
|
||||
options=options,
|
||||
)
|
||||
)
|
||||
|
||||
self.token_to_stream_cleanup.schedule()
|
||||
return result_stream
|
||||
|
||||
@callback
|
||||
@ -734,41 +740,17 @@ class SpeechManager:
|
||||
self,
|
||||
engine: str,
|
||||
message: str,
|
||||
use_file_cache: bool | None = None,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> TTSCache:
|
||||
"""Make sure a message is cached in memory and returns cache key."""
|
||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
language, options = self.process_options(engine_instance, language, options)
|
||||
if use_file_cache is None:
|
||||
use_file_cache = self.use_file_cache
|
||||
|
||||
return self._async_ensure_cached_in_memory(
|
||||
engine=engine,
|
||||
engine_instance=engine_instance,
|
||||
message=message,
|
||||
use_file_cache=use_file_cache,
|
||||
language=language,
|
||||
options=options,
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_ensure_cached_in_memory(
|
||||
self,
|
||||
engine: str,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
message: str,
|
||||
use_file_cache: bool,
|
||||
language: str,
|
||||
options: dict,
|
||||
) -> TTSCache:
|
||||
"""Ensure a message is cached.
|
||||
"""Make sure a message will be cached in memory and returns cache object.
|
||||
|
||||
Requires options, language to be processed.
|
||||
"""
|
||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
options_key = _hash_options(options) if options else "-"
|
||||
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
||||
cache_key = KEY_PATTERN.format(
|
||||
@ -789,9 +771,13 @@ class SpeechManager:
|
||||
store_to_disk = False
|
||||
else:
|
||||
_LOGGER.debug("Generating audio for %s", message[0:32])
|
||||
|
||||
async def message_stream() -> AsyncGenerator[str]:
|
||||
yield message
|
||||
|
||||
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||
data_gen = self._async_generate_tts_audio(
|
||||
engine_instance, message, language, options
|
||||
engine_instance, message_stream(), language, options
|
||||
)
|
||||
|
||||
cache = TTSCache(
|
||||
@ -799,7 +785,6 @@ class SpeechManager:
|
||||
extension=extension,
|
||||
data_gen=data_gen,
|
||||
)
|
||||
|
||||
self.mem_cache[cache_key] = cache
|
||||
self.hass.async_create_background_task(
|
||||
self._load_data_into_cache(
|
||||
@ -866,7 +851,7 @@ class SpeechManager:
|
||||
async def _async_generate_tts_audio(
|
||||
self,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
message: str,
|
||||
message_stream: AsyncGenerator[str],
|
||||
language: str,
|
||||
options: dict[str, Any],
|
||||
) -> AsyncGenerator[bytes]:
|
||||
@ -915,6 +900,7 @@ class SpeechManager:
|
||||
raise HomeAssistantError("TTS engine name is not set.")
|
||||
|
||||
if isinstance(engine_instance, Provider):
|
||||
message = "".join([chunk async for chunk in message_stream])
|
||||
extension, data = await engine_instance.async_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
@ -930,12 +916,8 @@ class SpeechManager:
|
||||
data_gen = make_data_generator(data)
|
||||
|
||||
else:
|
||||
|
||||
async def message_gen() -> AsyncGenerator[str]:
|
||||
yield message
|
||||
|
||||
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
||||
TTSAudioRequest(language, options, message_gen())
|
||||
TTSAudioRequest(language, options, message_stream)
|
||||
)
|
||||
extension = tts_result.extension
|
||||
data_gen = tts_result.data_gen
|
||||
@ -1096,7 +1078,6 @@ class TextToSpeechUrlView(HomeAssistantView):
|
||||
try:
|
||||
stream = self.manager.async_create_result_stream(
|
||||
engine,
|
||||
message,
|
||||
use_file_cache=use_file_cache,
|
||||
language=language,
|
||||
options=options,
|
||||
@ -1105,6 +1086,8 @@ class TextToSpeechUrlView(HomeAssistantView):
|
||||
_LOGGER.error("Error on init tts: %s", err)
|
||||
return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
|
||||
|
||||
stream.async_set_message(message)
|
||||
|
||||
base = get_url(self.manager.hass)
|
||||
url = base + stream.url
|
||||
|
||||
|
@ -69,14 +69,20 @@ class MediaSourceOptions(TypedDict):
|
||||
"""Media source options."""
|
||||
|
||||
engine: str
|
||||
message: str
|
||||
language: str | None
|
||||
options: dict | None
|
||||
use_file_cache: bool | None
|
||||
|
||||
|
||||
class ParsedMediaSourceId(TypedDict):
|
||||
"""Parsed media source ID."""
|
||||
|
||||
options: MediaSourceOptions
|
||||
message: str
|
||||
|
||||
|
||||
@callback
|
||||
def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
||||
def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId:
|
||||
"""Turn a media source ID into options."""
|
||||
parsed = URL(media_source_id)
|
||||
if URL_QUERY_TTS_OPTIONS in parsed.query:
|
||||
@ -94,7 +100,6 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
||||
raise Unresolvable("No message specified.")
|
||||
kwargs: MediaSourceOptions = {
|
||||
"engine": parsed.name,
|
||||
"message": parsed.query["message"],
|
||||
"language": parsed.query.get("language"),
|
||||
"options": options,
|
||||
"use_file_cache": None,
|
||||
@ -102,7 +107,7 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
||||
if "cache" in parsed.query:
|
||||
kwargs["use_file_cache"] = parsed.query["cache"] == "true"
|
||||
|
||||
return kwargs
|
||||
return {"message": parsed.query["message"], "options": kwargs}
|
||||
|
||||
|
||||
class TTSMediaSource(MediaSource):
|
||||
@ -118,9 +123,11 @@ class TTSMediaSource(MediaSource):
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
try:
|
||||
parsed = parse_media_source_id(item.identifier)
|
||||
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
|
||||
**media_source_id_to_kwargs(item.identifier)
|
||||
**parsed["options"]
|
||||
)
|
||||
stream.async_set_message(parsed["message"])
|
||||
except Unresolvable:
|
||||
raise
|
||||
except HomeAssistantError as err:
|
||||
|
@ -42,6 +42,7 @@ from tests.typing import ClientSessionGenerator
|
||||
DEFAULT_LANG = "en_US"
|
||||
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
||||
TEST_DOMAIN = "test"
|
||||
MOCK_DATA = b"123"
|
||||
|
||||
|
||||
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
|
||||
@ -164,7 +165,7 @@ class BaseProvider:
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
return ("mp3", b"")
|
||||
return ("mp3", MOCK_DATA)
|
||||
|
||||
|
||||
class MockTTSProvider(BaseProvider, Provider):
|
||||
|
@ -27,6 +27,7 @@ from homeassistant.util import dt as dt_util
|
||||
|
||||
from .common import (
|
||||
DEFAULT_LANG,
|
||||
MOCK_DATA,
|
||||
TEST_DOMAIN,
|
||||
MockResultStream,
|
||||
MockTTS,
|
||||
@ -808,7 +809,7 @@ async def test_service_receive_voice(
|
||||
await hass.async_block_till_done()
|
||||
client = await hass_client()
|
||||
req = await client.get(url)
|
||||
tts_data = b""
|
||||
tts_data = MOCK_DATA
|
||||
tts_data = tts.SpeechManager.write_tags(
|
||||
f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3",
|
||||
tts_data,
|
||||
@ -879,7 +880,7 @@ async def test_service_receive_voice_german(
|
||||
await hass.async_block_till_done()
|
||||
client = await hass_client()
|
||||
req = await client.get(url)
|
||||
tts_data = b""
|
||||
tts_data = MOCK_DATA
|
||||
tts_data = tts.SpeechManager.write_tags(
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3",
|
||||
tts_data,
|
||||
@ -1021,7 +1022,7 @@ async def test_setup_legacy_cache_dir(
|
||||
"""Set up a TTS platform with cache and call service without cache."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
tts_data = b""
|
||||
tts_data = MOCK_DATA
|
||||
cache_file = (
|
||||
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
|
||||
)
|
||||
@ -1059,7 +1060,7 @@ async def test_setup_cache_dir(
|
||||
"""Set up a TTS platform with cache and call service without cache."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
tts_data = b""
|
||||
tts_data = MOCK_DATA
|
||||
cache_file = mock_tts_cache_dir / (
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
|
||||
)
|
||||
@ -1165,7 +1166,7 @@ async def test_legacy_cannot_retrieve_without_token(
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Verify that a TTS cannot be retrieved by filename directly."""
|
||||
tts_data = b""
|
||||
tts_data = MOCK_DATA
|
||||
cache_file = (
|
||||
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
||||
)
|
||||
@ -1188,7 +1189,7 @@ async def test_cannot_retrieve_without_token(
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Verify that a TTS cannot be retrieved by filename directly."""
|
||||
tts_data = b""
|
||||
tts_data = MOCK_DATA
|
||||
cache_file = mock_tts_cache_dir / (
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
|
||||
)
|
||||
@ -1845,6 +1846,9 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
||||
assert stream.language == mock_tts_entity.default_language
|
||||
assert stream.options == (mock_tts_entity.default_options or {})
|
||||
assert tts.async_get_stream(hass, stream.token) is stream
|
||||
stream.async_set_message("beer")
|
||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
assert result_data == MOCK_DATA
|
||||
|
||||
data = b"beer"
|
||||
stream2 = MockResultStream(hass, "wav", data)
|
||||
|
@ -9,9 +9,8 @@ import pytest
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.media_player import BrowseError
|
||||
from homeassistant.components.tts.media_source import (
|
||||
MediaSourceOptions,
|
||||
generate_media_source_id,
|
||||
media_source_id_to_kwargs,
|
||||
parse_media_source_id,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
@ -249,13 +248,13 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) ->
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
||||
async def test_generate_media_source_id_and_parse_media_source_id(
|
||||
hass: HomeAssistant,
|
||||
setup: str,
|
||||
result_engine: str,
|
||||
) -> None:
|
||||
"""Test media_source_id and media_source_id_to_kwargs."""
|
||||
kwargs: MediaSourceOptions = {
|
||||
"""Test media_source_id and parse_media_source_id."""
|
||||
kwargs = {
|
||||
"engine": None,
|
||||
"message": "hello",
|
||||
"language": "en_US",
|
||||
@ -263,12 +262,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
||||
"cache": True,
|
||||
}
|
||||
media_source_id = generate_media_source_id(hass, **kwargs)
|
||||
assert media_source_id_to_kwargs(media_source_id) == {
|
||||
"engine": result_engine,
|
||||
assert parse_media_source_id(media_source_id) == {
|
||||
"message": "hello",
|
||||
"language": "en_US",
|
||||
"options": {"age": 5},
|
||||
"use_file_cache": True,
|
||||
"options": {
|
||||
"engine": result_engine,
|
||||
"language": "en_US",
|
||||
"options": {"age": 5},
|
||||
"use_file_cache": True,
|
||||
},
|
||||
}
|
||||
|
||||
kwargs = {
|
||||
@ -279,12 +280,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
||||
"cache": True,
|
||||
}
|
||||
media_source_id = generate_media_source_id(hass, **kwargs)
|
||||
assert media_source_id_to_kwargs(media_source_id) == {
|
||||
"engine": result_engine,
|
||||
assert parse_media_source_id(media_source_id) == {
|
||||
"message": "hello",
|
||||
"language": "en_US",
|
||||
"options": {"age": [5, 6]},
|
||||
"use_file_cache": True,
|
||||
"options": {
|
||||
"engine": result_engine,
|
||||
"language": "en_US",
|
||||
"options": {"age": [5, 6]},
|
||||
"use_file_cache": True,
|
||||
},
|
||||
}
|
||||
|
||||
kwargs = {
|
||||
@ -295,10 +298,12 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
||||
"cache": True,
|
||||
}
|
||||
media_source_id = generate_media_source_id(hass, **kwargs)
|
||||
assert media_source_id_to_kwargs(media_source_id) == {
|
||||
"engine": result_engine,
|
||||
assert parse_media_source_id(media_source_id) == {
|
||||
"message": "hello",
|
||||
"language": "en_US",
|
||||
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
|
||||
"use_file_cache": True,
|
||||
"options": {
|
||||
"engine": result_engine,
|
||||
"language": "en_US",
|
||||
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
|
||||
"use_file_cache": True,
|
||||
},
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user