Text-to-Speech refactor (#139482)

* Refactor TTS

* More cleanup

* Cleanup

* Consolidate more

* Inline another function

* Inline another function

* Improve cleanup
This commit is contained in:
Paulus Schoutsen 2025-02-28 18:36:12 +00:00 committed by GitHub
parent 49c27ae7bc
commit 70bb56e0fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 357 additions and 268 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
import hashlib import hashlib
@ -16,6 +17,7 @@ import re
import secrets import secrets
import subprocess import subprocess
import tempfile import tempfile
from time import monotonic
from typing import Any, Final, TypedDict, final from typing import Any, Final, TypedDict, final
from aiohttp import web from aiohttp import web
@ -37,11 +39,20 @@ from homeassistant.components.media_player import (
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
EVENT_HOMEASSISTANT_STOP,
PLATFORM_FORMAT, PLATFORM_FORMAT,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HassJobType,
HomeAssistant,
ServiceCall,
callback,
)
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
@ -129,9 +140,10 @@ SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
class TTSCache(TypedDict): class TTSCache(TypedDict):
"""Cached TTS file.""" """Cached TTS file."""
filename: str extension: str
voice: bytes voice: bytes
pending: asyncio.Task | None pending: asyncio.Task | None
last_used: float
@callback @callback
@ -192,9 +204,11 @@ async def async_get_media_source_audio(
media_source_id: str, media_source_id: str,
) -> tuple[str, bytes]: ) -> tuple[str, bytes]:
"""Get TTS audio as extension, data.""" """Get TTS audio as extension, data."""
return await hass.data[DATA_TTS_MANAGER].async_get_tts_audio( manager = hass.data[DATA_TTS_MANAGER]
**media_source_id_to_kwargs(media_source_id), cache_key = manager.async_cache_message_in_memory(
**media_source_id_to_kwargs(media_source_id)
) )
return await manager.async_get_tts_audio(cache_key)
@callback @callback
@ -306,11 +320,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
# Legacy config options # Legacy config options
conf = config[DOMAIN][0] if config.get(DOMAIN) else {} conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE) use_file_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR) cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
time_memory: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY) memory_cache_maxage: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
tts = SpeechManager(hass, use_cache, cache_dir, time_memory) tts = SpeechManager(hass, use_file_cache, cache_dir, memory_cache_maxage)
try: try:
await tts.async_init_cache() await tts.async_init_cache()
@ -383,6 +397,40 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
} }
@dataclass
class ResultStream:
"""Class that will stream the result when available."""
# Streaming/conversion properties
url: str
extension: str
content_type: str
# TTS properties
engine: str
use_file_cache: bool
language: str
options: dict
_manager: SpeechManager
@cached_property
def _result_cache_key(self) -> asyncio.Future[str]:
"""Get the future that returns the cache key."""
return asyncio.Future()
@callback
def async_set_message_cache_key(self, cache_key: str) -> None:
"""Set cache key for message to be streamed."""
self._result_cache_key.set_result(cache_key)
async def async_get_result(self) -> bytes:
"""Get the stream of this result."""
cache_key = await self._result_cache_key
_extension, data = await self._manager.async_get_tts_audio(cache_key)
return data
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_): class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
"""Represent a single TTS engine.""" """Represent a single TTS engine."""
@ -521,29 +569,82 @@ def _hash_options(options: dict) -> str:
return opts_hash.hexdigest() return opts_hash.hexdigest()
class MemcacheCleanup:
"""Helper to clean up the stale sessions."""
unsub: CALLBACK_TYPE | None = None
def __init__(
self, hass: HomeAssistant, maxage: float, memcache: dict[str, TTSCache]
) -> None:
"""Initialize the cleanup."""
self.hass = hass
self.maxage = maxage
self.memcache = memcache
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
self.cleanup_job = HassJob(
self._cleanup, "chat_session_cleanup", job_type=HassJobType.Callback
)
@callback
def schedule(self) -> None:
"""Schedule the cleanup."""
if self.unsub:
return
self.unsub = async_call_later(
self.hass,
self.maxage + 1,
self.cleanup_job,
)
@callback
def _on_hass_stop(self, event: Event) -> None:
"""Cancel the cleanup on shutdown."""
if self.unsub:
self.unsub()
self.unsub = None
@callback
def _cleanup(self, _now: datetime) -> None:
"""Clean up and schedule follow-up if necessary."""
self.unsub = None
memcache = self.memcache
maxage = self.maxage
now = monotonic()
for cache_key, info in list(memcache.items()):
if info["last_used"] + maxage < now:
_LOGGER.debug("Cleaning up %s", cache_key)
del memcache[cache_key]
# Still items left, check again in timeout time.
if memcache:
self.schedule()
class SpeechManager: class SpeechManager:
"""Representation of a speech store.""" """Representation of a speech store."""
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
use_cache: bool, use_file_cache: bool,
cache_dir: str, cache_dir: str,
time_memory: int, memory_cache_maxage: int,
) -> None: ) -> None:
"""Initialize a speech store.""" """Initialize a speech store."""
self.hass = hass self.hass = hass
self.providers: dict[str, Provider] = {} self.providers: dict[str, Provider] = {}
self.use_cache = use_cache self.use_file_cache = use_file_cache
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.time_memory = time_memory self.memory_cache_maxage = memory_cache_maxage
self.file_cache: dict[str, str] = {} self.file_cache: dict[str, str] = {}
self.mem_cache: dict[str, TTSCache] = {} self.mem_cache: dict[str, TTSCache] = {}
self.token_to_stream: dict[str, ResultStream] = {}
# filename <-> token self.memcache_cleanup = MemcacheCleanup(
self.filename_to_token: dict[str, str] = {} hass, memory_cache_maxage, self.mem_cache
self.token_to_filename: dict[str, str] = {} )
def _init_cache(self) -> dict[str, str]: def _init_cache(self) -> dict[str, str]:
"""Init cache folder and fetch files.""" """Init cache folder and fetch files."""
@ -563,18 +664,21 @@ class SpeechManager:
async def async_clear_cache(self) -> None: async def async_clear_cache(self) -> None:
"""Read file cache and delete files.""" """Read file cache and delete files."""
self.mem_cache = {} self.mem_cache.clear()
def remove_files() -> None: def remove_files(files: list[str]) -> None:
"""Remove files from filesystem.""" """Remove files from filesystem."""
for filename in self.file_cache.values(): for filename in files:
try: try:
os.remove(os.path.join(self.cache_dir, filename)) os.remove(os.path.join(self.cache_dir, filename))
except OSError as err: except OSError as err:
_LOGGER.warning("Can't remove cache file '%s': %s", filename, err) _LOGGER.warning("Can't remove cache file '%s': %s", filename, err)
await self.hass.async_add_executor_job(remove_files) task = self.hass.async_add_executor_job(
self.file_cache = {} remove_files, list(self.file_cache.values())
)
self.file_cache.clear()
await task
@callback @callback
def async_register_legacy_engine( def async_register_legacy_engine(
@ -629,107 +733,153 @@ class SpeechManager:
return language, merged_options return language, merged_options
async def async_get_url_path( @callback
def async_create_result_stream(
self, self,
engine: str, engine: str,
message: str, message: str | None = None,
cache: bool | None = None, use_file_cache: bool | None = None,
language: str | None = None, language: str | None = None,
options: dict | None = None, options: dict | None = None,
) -> str: ) -> ResultStream:
"""Get URL for play message. """Create a streaming URL where the rendered TTS can be retrieved."""
This method is a coroutine.
"""
if (engine_instance := get_engine_instance(self.hass, engine)) is None: if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found") raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options) language, options = self.process_options(engine_instance, language, options)
cache_key = self._generate_cache_key(message, language, options, engine) if use_file_cache is None:
use_cache = cache if cache is not None else self.use_cache use_file_cache = self.use_file_cache
# Is speech already in memory extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
if cache_key in self.mem_cache: token = f"{secrets.token_urlsafe(16)}.{extension}"
filename = self.mem_cache[cache_key]["filename"] content, _ = mimetypes.guess_type(token)
# Is file store in file cache result_stream = ResultStream(
elif use_cache and cache_key in self.file_cache: url=f"/api/tts_proxy/{token}",
filename = self.file_cache[cache_key] extension=extension,
self.hass.async_create_task(self._async_file_to_mem(cache_key)) content_type=content or "audio/mpeg",
# Load speech from engine into memory use_file_cache=use_file_cache,
else: engine=engine,
filename = await self._async_get_tts_audio( language=language,
engine_instance, cache_key, message, use_cache, language, options options=options,
_manager=self,
) )
self.token_to_stream[token] = result_stream
# Use a randomly generated token instead of exposing the filename if message is None:
token = self.filename_to_token.get(filename) return result_stream
if not token:
# Keep extension (.mp3, etc.)
token = secrets.token_urlsafe(16) + os.path.splitext(filename)[1]
# Map token <-> filename cache_key = self._async_ensure_cached_in_memory(
self.filename_to_token[filename] = token engine=engine,
self.token_to_filename[token] = filename engine_instance=engine_instance,
message=message,
return f"/api/tts_proxy/{token}" use_file_cache=use_file_cache,
language=language,
async def async_get_tts_audio( options=options,
self,
engine: str,
message: str,
cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> tuple[str, bytes]:
"""Fetch TTS audio."""
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache
# If we have the file, load it into memory if necessary
if cache_key not in self.mem_cache:
if use_cache and cache_key in self.file_cache:
await self._async_file_to_mem(cache_key)
else:
await self._async_get_tts_audio(
engine_instance, cache_key, message, use_cache, language, options
) )
result_stream.async_set_message_cache_key(cache_key)
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:] return result_stream
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
return extension, cached["voice"]
@callback @callback
def _generate_cache_key( def async_cache_message_in_memory(
self, self,
message: str,
language: str,
options: dict | None,
engine: str, engine: str,
message: str,
use_file_cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> str: ) -> str:
"""Generate a cache key for a message.""" """Make sure a message is cached in memory and returns cache key."""
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
if use_file_cache is None:
use_file_cache = self.use_file_cache
return self._async_ensure_cached_in_memory(
engine=engine,
engine_instance=engine_instance,
message=message,
use_file_cache=use_file_cache,
language=language,
options=options,
)
@callback
def _async_ensure_cached_in_memory(
self,
engine: str,
engine_instance: TextToSpeechEntity | Provider,
message: str,
use_file_cache: bool,
language: str,
options: dict,
) -> str:
"""Ensure a message is cached.
Requires options, language to be processed.
"""
options_key = _hash_options(options) if options else "-" options_key = _hash_options(options) if options else "-"
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest() msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
return KEY_PATTERN.format( cache_key = KEY_PATTERN.format(
msg_hash, language.replace("_", "-"), options_key, engine msg_hash, language.replace("_", "-"), options_key, engine
).lower() ).lower()
async def _async_get_tts_audio( # Is speech already in memory
if cache_key in self.mem_cache:
return cache_key
if use_file_cache and cache_key in self.file_cache:
coro = self._async_load_file_to_mem(cache_key)
else:
coro = self._async_generate_tts_audio(
engine_instance, cache_key, message, use_file_cache, language, options
)
task = self.hass.async_create_task(coro, eager_start=False)
def handle_error(future: asyncio.Future) -> None:
"""Handle error."""
if not (err := future.exception()):
return
# Truncate message so we don't flood the logs. Cutting off at 32 chars
# but since we add 3 dots to truncated message, we cut off at 35.
trunc_msg = message if len(message) < 35 else f"{message[0:32]}"
_LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
self.mem_cache.pop(cache_key, None)
task.add_done_callback(handle_error)
self.mem_cache[cache_key] = {
"extension": "",
"voice": b"",
"pending": task,
"last_used": monotonic(),
}
return cache_key
async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
"""Fetch TTS audio."""
cached = self.mem_cache.get(cache_key)
if cached is None:
raise HomeAssistantError("Audio not cached")
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
cached["last_used"] = monotonic()
return cached["extension"], cached["voice"]
async def _async_generate_tts_audio(
self, self,
engine_instance: TextToSpeechEntity | Provider, engine_instance: TextToSpeechEntity | Provider,
cache_key: str, cache_key: str,
message: str, message: str,
cache: bool, cache_to_disk: bool,
language: str, language: str,
options: dict[str, Any], options: dict[str, Any],
) -> str: ) -> None:
"""Receive TTS, store for view in cache and return filename. """Start loading of the TTS audio.
This method is a coroutine. This method is a coroutine.
""" """
@ -773,8 +923,6 @@ class SpeechManager:
if sample_bytes is not None: if sample_bytes is not None:
sample_bytes = int(sample_bytes) sample_bytes = int(sample_bytes)
async def get_tts_data() -> str:
"""Handle data available."""
if engine_instance.name is None or engine_instance.name is UNDEFINED: if engine_instance.name is None or engine_instance.name is UNDEFINED:
raise HomeAssistantError("TTS engine name is not set.") raise HomeAssistantError("TTS engine name is not set.")
@ -830,39 +978,11 @@ class SpeechManager:
filename, data, engine_instance.name, message, language, options filename, data, engine_instance.name, message, language, options
) )
self._async_store_to_memcache(cache_key, filename, data) self._async_store_to_memcache(cache_key, final_extension, data)
if cache: if not cache_to_disk:
self.hass.async_create_task( return
self._async_save_tts_audio(cache_key, filename, data)
)
return filename
audio_task = self.hass.async_create_task(get_tts_data(), eager_start=False)
def handle_error(_future: asyncio.Future) -> None:
"""Handle error."""
if audio_task.exception():
self.mem_cache.pop(cache_key, None)
audio_task.add_done_callback(handle_error)
filename = f"{cache_key}.{final_extension}".lower()
self.mem_cache[cache_key] = {
"filename": filename,
"voice": b"",
"pending": audio_task,
}
return filename
async def _async_save_tts_audio(
self, cache_key: str, filename: str, data: bytes
) -> None:
"""Store voice data to file and file_cache.
This method is a coroutine.
"""
voice_file = os.path.join(self.cache_dir, filename) voice_file = os.path.join(self.cache_dir, filename)
def save_speech() -> None: def save_speech() -> None:
@ -870,13 +990,19 @@ class SpeechManager:
with open(voice_file, "wb") as speech: with open(voice_file, "wb") as speech:
speech.write(data) speech.write(data)
try: # Don't await, we're going to do this in the background
await self.hass.async_add_executor_job(save_speech) task = self.hass.async_add_executor_job(save_speech)
self.file_cache[cache_key] = filename
except OSError as err:
_LOGGER.error("Can't write %s: %s", filename, err)
async def _async_file_to_mem(self, cache_key: str) -> None: def write_done(future: asyncio.Future) -> None:
"""Write is done task."""
if err := future.exception():
_LOGGER.error("Can't write %s: %s", filename, err)
else:
self.file_cache[cache_key] = filename
task.add_done_callback(write_done)
async def _async_load_file_to_mem(self, cache_key: str) -> None:
"""Load voice from file cache into memory. """Load voice from file cache into memory.
This method is a coroutine. This method is a coroutine.
@ -897,64 +1023,22 @@ class SpeechManager:
del self.file_cache[cache_key] del self.file_cache[cache_key]
raise HomeAssistantError(f"Can't read {voice_file}") from err raise HomeAssistantError(f"Can't read {voice_file}") from err
self._async_store_to_memcache(cache_key, filename, data) extension = os.path.splitext(filename)[1][1:]
self._async_store_to_memcache(cache_key, extension, data)
@callback @callback
def _async_store_to_memcache( def _async_store_to_memcache(
self, cache_key: str, filename: str, data: bytes self, cache_key: str, extension: str, data: bytes
) -> None: ) -> None:
"""Store data to memcache and set timer to remove it.""" """Store data to memcache and set timer to remove it."""
self.mem_cache[cache_key] = { self.mem_cache[cache_key] = {
"filename": filename, "extension": extension,
"voice": data, "voice": data,
"pending": None, "pending": None,
"last_used": monotonic(),
} }
self.memcache_cleanup.schedule()
@callback
def async_remove_from_mem(_: datetime) -> None:
"""Cleanup memcache."""
self.mem_cache.pop(cache_key, None)
async_call_later(
self.hass,
self.time_memory,
HassJob(
async_remove_from_mem,
name="tts remove_from_mem",
cancel_on_shutdown=True,
),
)
async def async_read_tts(self, token: str) -> tuple[str | None, bytes]:
"""Read a voice file and return binary.
This method is a coroutine.
"""
filename = self.token_to_filename.get(token)
if not filename:
raise HomeAssistantError(f"{token} was not recognized!")
if not (record := _RE_VOICE_FILE.match(filename.lower())) and not (
record := _RE_LEGACY_VOICE_FILE.match(filename.lower())
):
raise HomeAssistantError("Wrong tts file format!")
cache_key = KEY_PATTERN.format(
record.group(1), record.group(2), record.group(3), record.group(4)
)
if cache_key not in self.mem_cache:
if cache_key not in self.file_cache:
raise HomeAssistantError(f"{cache_key} not in cache!")
await self._async_file_to_mem(cache_key)
cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
content, _ = mimetypes.guess_type(filename)
return content, cached["voice"]
@staticmethod @staticmethod
def write_tags( def write_tags(
@ -1042,9 +1126,9 @@ class TextToSpeechUrlView(HomeAssistantView):
url = "/api/tts_get_url" url = "/api/tts_get_url"
name = "api:tts:geturl" name = "api:tts:geturl"
def __init__(self, tts: SpeechManager) -> None: def __init__(self, manager: SpeechManager) -> None:
"""Initialize a tts view.""" """Initialize a tts view."""
self.tts = tts self.manager = manager
async def post(self, request: web.Request) -> web.Response: async def post(self, request: web.Request) -> web.Response:
"""Generate speech and provide url.""" """Generate speech and provide url."""
@ -1061,45 +1145,53 @@ class TextToSpeechUrlView(HomeAssistantView):
engine = data.get("engine_id") or data[ATTR_PLATFORM] engine = data.get("engine_id") or data[ATTR_PLATFORM]
message = data[ATTR_MESSAGE] message = data[ATTR_MESSAGE]
cache = data.get(ATTR_CACHE) use_file_cache = data.get(ATTR_CACHE)
language = data.get(ATTR_LANGUAGE) language = data.get(ATTR_LANGUAGE)
options = data.get(ATTR_OPTIONS) options = data.get(ATTR_OPTIONS)
try: try:
path = await self.tts.async_get_url_path( stream = self.manager.async_create_result_stream(
engine, message, cache=cache, language=language, options=options engine,
message,
use_file_cache=use_file_cache,
language=language,
options=options,
) )
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.error("Error on init tts: %s", err) _LOGGER.error("Error on init tts: %s", err)
return self.json({"error": err}, HTTPStatus.BAD_REQUEST) return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
base = get_url(self.tts.hass) base = get_url(self.manager.hass)
url = base + path url = base + stream.url
return self.json({"url": url, "path": path}) return self.json({"url": url, "path": stream.url})
class TextToSpeechView(HomeAssistantView): class TextToSpeechView(HomeAssistantView):
"""TTS view to serve a speech audio.""" """TTS view to serve a speech audio."""
requires_auth = False requires_auth = False
url = "/api/tts_proxy/{filename}" url = "/api/tts_proxy/{token}"
name = "api:tts_speech" name = "api:tts_speech"
def __init__(self, tts: SpeechManager) -> None: def __init__(self, manager: SpeechManager) -> None:
"""Initialize a tts view.""" """Initialize a tts view."""
self.tts = tts self.manager = manager
async def get(self, request: web.Request, filename: str) -> web.Response: async def get(self, request: web.Request, token: str) -> web.Response:
"""Start a get request.""" """Start a get request."""
try: stream = self.manager.token_to_stream.get(token)
# filename is actually token, but we keep its name for compatibility
content, data = await self.tts.async_read_tts(filename) if stream is None:
except HomeAssistantError as err:
_LOGGER.error("Error on load tts: %s", err)
return web.Response(status=HTTPStatus.NOT_FOUND) return web.Response(status=HTTPStatus.NOT_FOUND)
return web.Response(body=data, content_type=content) try:
data = await stream.async_get_result()
except HomeAssistantError as err:
_LOGGER.error("Error on get tts: %s", err)
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.Response(body=data, content_type=stream.content_type)
@websocket_api.websocket_command( @websocket_api.websocket_command(

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import json import json
import mimetypes
from typing import TypedDict from typing import TypedDict
from yarl import URL from yarl import URL
@ -73,7 +72,7 @@ class MediaSourceOptions(TypedDict):
message: str message: str
language: str | None language: str | None
options: dict | None options: dict | None
cache: bool | None use_file_cache: bool | None
@callback @callback
@ -98,10 +97,10 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
"message": parsed.query["message"], "message": parsed.query["message"],
"language": parsed.query.get("language"), "language": parsed.query.get("language"),
"options": options, "options": options,
"cache": None, "use_file_cache": None,
} }
if "cache" in parsed.query: if "cache" in parsed.query:
kwargs["cache"] = parsed.query["cache"] == "true" kwargs["use_file_cache"] = parsed.query["cache"] == "true"
return kwargs return kwargs
@ -119,7 +118,7 @@ class TTSMediaSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url.""" """Resolve media to a url."""
try: try:
url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path( stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
**media_source_id_to_kwargs(item.identifier) **media_source_id_to_kwargs(item.identifier)
) )
except Unresolvable: except Unresolvable:
@ -127,9 +126,7 @@ class TTSMediaSource(MediaSource):
except HomeAssistantError as err: except HomeAssistantError as err:
raise Unresolvable(str(err)) from err raise Unresolvable(str(err)) from err
mime_type = mimetypes.guess_type(url)[0] or "audio/mpeg" return PlayMedia(stream.url, stream.content_type)
return PlayMedia(url, mime_type)
async def async_browse_media( async def async_browse_media(
self, self,

View File

@ -350,7 +350,7 @@ async def test_tts_service_speak_error(
assert len(calls) == 1 assert len(calls) == 1
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
tts_entity._client.generate.assert_called_once_with( tts_entity._client.generate.assert_called_once_with(

View File

@ -475,6 +475,6 @@ async def test_service_say_error(
await retrieve_media( await retrieve_media(
hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID] hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID]
) )
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
assert len(mock_gtts.mock_calls) == 2 assert len(mock_gtts.mock_calls) == 2

View File

@ -155,7 +155,7 @@ async def test_service_say_http_error(
await retrieve_media( await retrieve_media(
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID] hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) )
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
mock_speak.assert_called_once() mock_speak.assert_called_once()

View File

@ -366,7 +366,7 @@ async def test_service_say_error(
await retrieve_media( await retrieve_media(
hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID] hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID]
) )
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
assert len(mock_tts.mock_calls) == 2 assert len(mock_tts.mock_calls) == 2

View File

@ -1197,7 +1197,7 @@ async def test_service_get_tts_error(
assert len(calls) == 1 assert len(calls) == 1
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )

View File

@ -268,7 +268,7 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
"message": "hello", "message": "hello",
"language": "en_US", "language": "en_US",
"options": {"age": 5}, "options": {"age": 5},
"cache": True, "use_file_cache": True,
} }
kwargs = { kwargs = {
@ -284,7 +284,7 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
"message": "hello", "message": "hello",
"language": "en_US", "language": "en_US",
"options": {"age": [5, 6]}, "options": {"age": [5, 6]},
"cache": True, "use_file_cache": True,
} }
kwargs = { kwargs = {
@ -300,5 +300,5 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
"message": "hello", "message": "hello",
"language": "en_US", "language": "en_US",
"options": {"age": {"k1": [5, 6], "k2": "v2"}}, "options": {"age": {"k1": [5, 6], "k2": "v2"}},
"cache": True, "use_file_cache": True,
} }

View File

@ -200,7 +200,7 @@ async def test_service_say_error(
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA
@ -234,7 +234,7 @@ async def test_service_say_timeout(
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA
@ -273,7 +273,7 @@ async def test_service_say_error_msg(
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[0][2] == FORM_DATA assert aioclient_mock.mock_calls[0][2] == FORM_DATA

View File

@ -223,7 +223,7 @@ async def test_service_say_timeout(
assert len(calls) == 1 assert len(calls) == 1
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
@ -269,7 +269,7 @@ async def test_service_say_http_error(
assert len(calls) == 1 assert len(calls) == 1
assert ( assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.NOT_FOUND == HTTPStatus.INTERNAL_SERVER_ERROR
) )