mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Text-to-Speech refactor (#139482)
* Refactor TTS * More cleanup * Cleanup * Consolidate more * Inline another function * Inline another function * Improve cleanup
This commit is contained in:
parent
49c27ae7bc
commit
70bb56e0fc
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -16,6 +17,7 @@ import re
|
|||||||
import secrets
|
import secrets
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from time import monotonic
|
||||||
from typing import Any, Final, TypedDict, final
|
from typing import Any, Final, TypedDict, final
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@ -37,11 +39,20 @@ from homeassistant.components.media_player import (
|
|||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ENTITY_ID,
|
ATTR_ENTITY_ID,
|
||||||
|
EVENT_HOMEASSISTANT_STOP,
|
||||||
PLATFORM_FORMAT,
|
PLATFORM_FORMAT,
|
||||||
STATE_UNAVAILABLE,
|
STATE_UNAVAILABLE,
|
||||||
STATE_UNKNOWN,
|
STATE_UNKNOWN,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
|
from homeassistant.core import (
|
||||||
|
CALLBACK_TYPE,
|
||||||
|
Event,
|
||||||
|
HassJob,
|
||||||
|
HassJobType,
|
||||||
|
HomeAssistant,
|
||||||
|
ServiceCall,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
@ -129,9 +140,10 @@ SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
|||||||
class TTSCache(TypedDict):
|
class TTSCache(TypedDict):
|
||||||
"""Cached TTS file."""
|
"""Cached TTS file."""
|
||||||
|
|
||||||
filename: str
|
extension: str
|
||||||
voice: bytes
|
voice: bytes
|
||||||
pending: asyncio.Task | None
|
pending: asyncio.Task | None
|
||||||
|
last_used: float
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -192,9 +204,11 @@ async def async_get_media_source_audio(
|
|||||||
media_source_id: str,
|
media_source_id: str,
|
||||||
) -> tuple[str, bytes]:
|
) -> tuple[str, bytes]:
|
||||||
"""Get TTS audio as extension, data."""
|
"""Get TTS audio as extension, data."""
|
||||||
return await hass.data[DATA_TTS_MANAGER].async_get_tts_audio(
|
manager = hass.data[DATA_TTS_MANAGER]
|
||||||
**media_source_id_to_kwargs(media_source_id),
|
cache_key = manager.async_cache_message_in_memory(
|
||||||
|
**media_source_id_to_kwargs(media_source_id)
|
||||||
)
|
)
|
||||||
|
return await manager.async_get_tts_audio(cache_key)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -306,11 +320,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
|
|
||||||
# Legacy config options
|
# Legacy config options
|
||||||
conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
|
conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
|
||||||
use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
use_file_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
||||||
cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
||||||
time_memory: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
memory_cache_maxage: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
||||||
|
|
||||||
tts = SpeechManager(hass, use_cache, cache_dir, time_memory)
|
tts = SpeechManager(hass, use_file_cache, cache_dir, memory_cache_maxage)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await tts.async_init_cache()
|
await tts.async_init_cache()
|
||||||
@ -383,6 +397,40 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResultStream:
|
||||||
|
"""Class that will stream the result when available."""
|
||||||
|
|
||||||
|
# Streaming/conversion properties
|
||||||
|
url: str
|
||||||
|
extension: str
|
||||||
|
content_type: str
|
||||||
|
|
||||||
|
# TTS properties
|
||||||
|
engine: str
|
||||||
|
use_file_cache: bool
|
||||||
|
language: str
|
||||||
|
options: dict
|
||||||
|
|
||||||
|
_manager: SpeechManager
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _result_cache_key(self) -> asyncio.Future[str]:
|
||||||
|
"""Get the future that returns the cache key."""
|
||||||
|
return asyncio.Future()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_set_message_cache_key(self, cache_key: str) -> None:
|
||||||
|
"""Set cache key for message to be streamed."""
|
||||||
|
self._result_cache_key.set_result(cache_key)
|
||||||
|
|
||||||
|
async def async_get_result(self) -> bytes:
|
||||||
|
"""Get the stream of this result."""
|
||||||
|
cache_key = await self._result_cache_key
|
||||||
|
_extension, data = await self._manager.async_get_tts_audio(cache_key)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
|
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
|
||||||
"""Represent a single TTS engine."""
|
"""Represent a single TTS engine."""
|
||||||
|
|
||||||
@ -521,29 +569,82 @@ def _hash_options(options: dict) -> str:
|
|||||||
return opts_hash.hexdigest()
|
return opts_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class MemcacheCleanup:
|
||||||
|
"""Helper to clean up the stale sessions."""
|
||||||
|
|
||||||
|
unsub: CALLBACK_TYPE | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, hass: HomeAssistant, maxage: float, memcache: dict[str, TTSCache]
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the cleanup."""
|
||||||
|
self.hass = hass
|
||||||
|
self.maxage = maxage
|
||||||
|
self.memcache = memcache
|
||||||
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
||||||
|
self.cleanup_job = HassJob(
|
||||||
|
self._cleanup, "chat_session_cleanup", job_type=HassJobType.Callback
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def schedule(self) -> None:
|
||||||
|
"""Schedule the cleanup."""
|
||||||
|
if self.unsub:
|
||||||
|
return
|
||||||
|
self.unsub = async_call_later(
|
||||||
|
self.hass,
|
||||||
|
self.maxage + 1,
|
||||||
|
self.cleanup_job,
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _on_hass_stop(self, event: Event) -> None:
|
||||||
|
"""Cancel the cleanup on shutdown."""
|
||||||
|
if self.unsub:
|
||||||
|
self.unsub()
|
||||||
|
self.unsub = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _cleanup(self, _now: datetime) -> None:
|
||||||
|
"""Clean up and schedule follow-up if necessary."""
|
||||||
|
self.unsub = None
|
||||||
|
memcache = self.memcache
|
||||||
|
maxage = self.maxage
|
||||||
|
now = monotonic()
|
||||||
|
|
||||||
|
for cache_key, info in list(memcache.items()):
|
||||||
|
if info["last_used"] + maxage < now:
|
||||||
|
_LOGGER.debug("Cleaning up %s", cache_key)
|
||||||
|
del memcache[cache_key]
|
||||||
|
|
||||||
|
# Still items left, check again in timeout time.
|
||||||
|
if memcache:
|
||||||
|
self.schedule()
|
||||||
|
|
||||||
|
|
||||||
class SpeechManager:
|
class SpeechManager:
|
||||||
"""Representation of a speech store."""
|
"""Representation of a speech store."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
use_cache: bool,
|
use_file_cache: bool,
|
||||||
cache_dir: str,
|
cache_dir: str,
|
||||||
time_memory: int,
|
memory_cache_maxage: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a speech store."""
|
"""Initialize a speech store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.providers: dict[str, Provider] = {}
|
self.providers: dict[str, Provider] = {}
|
||||||
|
|
||||||
self.use_cache = use_cache
|
self.use_file_cache = use_file_cache
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
self.time_memory = time_memory
|
self.memory_cache_maxage = memory_cache_maxage
|
||||||
self.file_cache: dict[str, str] = {}
|
self.file_cache: dict[str, str] = {}
|
||||||
self.mem_cache: dict[str, TTSCache] = {}
|
self.mem_cache: dict[str, TTSCache] = {}
|
||||||
|
self.token_to_stream: dict[str, ResultStream] = {}
|
||||||
# filename <-> token
|
self.memcache_cleanup = MemcacheCleanup(
|
||||||
self.filename_to_token: dict[str, str] = {}
|
hass, memory_cache_maxage, self.mem_cache
|
||||||
self.token_to_filename: dict[str, str] = {}
|
)
|
||||||
|
|
||||||
def _init_cache(self) -> dict[str, str]:
|
def _init_cache(self) -> dict[str, str]:
|
||||||
"""Init cache folder and fetch files."""
|
"""Init cache folder and fetch files."""
|
||||||
@ -563,18 +664,21 @@ class SpeechManager:
|
|||||||
|
|
||||||
async def async_clear_cache(self) -> None:
|
async def async_clear_cache(self) -> None:
|
||||||
"""Read file cache and delete files."""
|
"""Read file cache and delete files."""
|
||||||
self.mem_cache = {}
|
self.mem_cache.clear()
|
||||||
|
|
||||||
def remove_files() -> None:
|
def remove_files(files: list[str]) -> None:
|
||||||
"""Remove files from filesystem."""
|
"""Remove files from filesystem."""
|
||||||
for filename in self.file_cache.values():
|
for filename in files:
|
||||||
try:
|
try:
|
||||||
os.remove(os.path.join(self.cache_dir, filename))
|
os.remove(os.path.join(self.cache_dir, filename))
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
_LOGGER.warning("Can't remove cache file '%s': %s", filename, err)
|
_LOGGER.warning("Can't remove cache file '%s': %s", filename, err)
|
||||||
|
|
||||||
await self.hass.async_add_executor_job(remove_files)
|
task = self.hass.async_add_executor_job(
|
||||||
self.file_cache = {}
|
remove_files, list(self.file_cache.values())
|
||||||
|
)
|
||||||
|
self.file_cache.clear()
|
||||||
|
await task
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_legacy_engine(
|
def async_register_legacy_engine(
|
||||||
@ -629,107 +733,153 @@ class SpeechManager:
|
|||||||
|
|
||||||
return language, merged_options
|
return language, merged_options
|
||||||
|
|
||||||
async def async_get_url_path(
|
@callback
|
||||||
|
def async_create_result_stream(
|
||||||
self,
|
self,
|
||||||
engine: str,
|
engine: str,
|
||||||
message: str,
|
message: str | None = None,
|
||||||
cache: bool | None = None,
|
use_file_cache: bool | None = None,
|
||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
options: dict | None = None,
|
options: dict | None = None,
|
||||||
) -> str:
|
) -> ResultStream:
|
||||||
"""Get URL for play message.
|
"""Create a streaming URL where the rendered TTS can be retrieved."""
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||||
raise HomeAssistantError(f"Provider {engine} not found")
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
|
|
||||||
language, options = self.process_options(engine_instance, language, options)
|
language, options = self.process_options(engine_instance, language, options)
|
||||||
cache_key = self._generate_cache_key(message, language, options, engine)
|
if use_file_cache is None:
|
||||||
use_cache = cache if cache is not None else self.use_cache
|
use_file_cache = self.use_file_cache
|
||||||
|
|
||||||
# Is speech already in memory
|
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||||
if cache_key in self.mem_cache:
|
token = f"{secrets.token_urlsafe(16)}.{extension}"
|
||||||
filename = self.mem_cache[cache_key]["filename"]
|
content, _ = mimetypes.guess_type(token)
|
||||||
# Is file store in file cache
|
result_stream = ResultStream(
|
||||||
elif use_cache and cache_key in self.file_cache:
|
url=f"/api/tts_proxy/{token}",
|
||||||
filename = self.file_cache[cache_key]
|
extension=extension,
|
||||||
self.hass.async_create_task(self._async_file_to_mem(cache_key))
|
content_type=content or "audio/mpeg",
|
||||||
# Load speech from engine into memory
|
use_file_cache=use_file_cache,
|
||||||
else:
|
engine=engine,
|
||||||
filename = await self._async_get_tts_audio(
|
language=language,
|
||||||
engine_instance, cache_key, message, use_cache, language, options
|
options=options,
|
||||||
)
|
_manager=self,
|
||||||
|
)
|
||||||
|
self.token_to_stream[token] = result_stream
|
||||||
|
|
||||||
# Use a randomly generated token instead of exposing the filename
|
if message is None:
|
||||||
token = self.filename_to_token.get(filename)
|
return result_stream
|
||||||
if not token:
|
|
||||||
# Keep extension (.mp3, etc.)
|
|
||||||
token = secrets.token_urlsafe(16) + os.path.splitext(filename)[1]
|
|
||||||
|
|
||||||
# Map token <-> filename
|
cache_key = self._async_ensure_cached_in_memory(
|
||||||
self.filename_to_token[filename] = token
|
engine=engine,
|
||||||
self.token_to_filename[token] = filename
|
engine_instance=engine_instance,
|
||||||
|
message=message,
|
||||||
|
use_file_cache=use_file_cache,
|
||||||
|
language=language,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
result_stream.async_set_message_cache_key(cache_key)
|
||||||
|
|
||||||
return f"/api/tts_proxy/{token}"
|
return result_stream
|
||||||
|
|
||||||
async def async_get_tts_audio(
|
|
||||||
self,
|
|
||||||
engine: str,
|
|
||||||
message: str,
|
|
||||||
cache: bool | None = None,
|
|
||||||
language: str | None = None,
|
|
||||||
options: dict | None = None,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
"""Fetch TTS audio."""
|
|
||||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
|
||||||
raise HomeAssistantError(f"Provider {engine} not found")
|
|
||||||
|
|
||||||
language, options = self.process_options(engine_instance, language, options)
|
|
||||||
cache_key = self._generate_cache_key(message, language, options, engine)
|
|
||||||
use_cache = cache if cache is not None else self.use_cache
|
|
||||||
|
|
||||||
# If we have the file, load it into memory if necessary
|
|
||||||
if cache_key not in self.mem_cache:
|
|
||||||
if use_cache and cache_key in self.file_cache:
|
|
||||||
await self._async_file_to_mem(cache_key)
|
|
||||||
else:
|
|
||||||
await self._async_get_tts_audio(
|
|
||||||
engine_instance, cache_key, message, use_cache, language, options
|
|
||||||
)
|
|
||||||
|
|
||||||
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
|
|
||||||
cached = self.mem_cache[cache_key]
|
|
||||||
if pending := cached.get("pending"):
|
|
||||||
await pending
|
|
||||||
cached = self.mem_cache[cache_key]
|
|
||||||
return extension, cached["voice"]
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _generate_cache_key(
|
def async_cache_message_in_memory(
|
||||||
self,
|
self,
|
||||||
message: str,
|
|
||||||
language: str,
|
|
||||||
options: dict | None,
|
|
||||||
engine: str,
|
engine: str,
|
||||||
|
message: str,
|
||||||
|
use_file_cache: bool | None = None,
|
||||||
|
language: str | None = None,
|
||||||
|
options: dict | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a cache key for a message."""
|
"""Make sure a message is cached in memory and returns cache key."""
|
||||||
|
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||||
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
|
|
||||||
|
language, options = self.process_options(engine_instance, language, options)
|
||||||
|
if use_file_cache is None:
|
||||||
|
use_file_cache = self.use_file_cache
|
||||||
|
|
||||||
|
return self._async_ensure_cached_in_memory(
|
||||||
|
engine=engine,
|
||||||
|
engine_instance=engine_instance,
|
||||||
|
message=message,
|
||||||
|
use_file_cache=use_file_cache,
|
||||||
|
language=language,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_ensure_cached_in_memory(
|
||||||
|
self,
|
||||||
|
engine: str,
|
||||||
|
engine_instance: TextToSpeechEntity | Provider,
|
||||||
|
message: str,
|
||||||
|
use_file_cache: bool,
|
||||||
|
language: str,
|
||||||
|
options: dict,
|
||||||
|
) -> str:
|
||||||
|
"""Ensure a message is cached.
|
||||||
|
|
||||||
|
Requires options, language to be processed.
|
||||||
|
"""
|
||||||
options_key = _hash_options(options) if options else "-"
|
options_key = _hash_options(options) if options else "-"
|
||||||
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
||||||
return KEY_PATTERN.format(
|
cache_key = KEY_PATTERN.format(
|
||||||
msg_hash, language.replace("_", "-"), options_key, engine
|
msg_hash, language.replace("_", "-"), options_key, engine
|
||||||
).lower()
|
).lower()
|
||||||
|
|
||||||
async def _async_get_tts_audio(
|
# Is speech already in memory
|
||||||
|
if cache_key in self.mem_cache:
|
||||||
|
return cache_key
|
||||||
|
|
||||||
|
if use_file_cache and cache_key in self.file_cache:
|
||||||
|
coro = self._async_load_file_to_mem(cache_key)
|
||||||
|
else:
|
||||||
|
coro = self._async_generate_tts_audio(
|
||||||
|
engine_instance, cache_key, message, use_file_cache, language, options
|
||||||
|
)
|
||||||
|
|
||||||
|
task = self.hass.async_create_task(coro, eager_start=False)
|
||||||
|
|
||||||
|
def handle_error(future: asyncio.Future) -> None:
|
||||||
|
"""Handle error."""
|
||||||
|
if not (err := future.exception()):
|
||||||
|
return
|
||||||
|
# Truncate message so we don't flood the logs. Cutting off at 32 chars
|
||||||
|
# but since we add 3 dots to truncated message, we cut off at 35.
|
||||||
|
trunc_msg = message if len(message) < 35 else f"{message[0:32]}…"
|
||||||
|
_LOGGER.error("Error generating audio for %s: %s", trunc_msg, err)
|
||||||
|
self.mem_cache.pop(cache_key, None)
|
||||||
|
|
||||||
|
task.add_done_callback(handle_error)
|
||||||
|
|
||||||
|
self.mem_cache[cache_key] = {
|
||||||
|
"extension": "",
|
||||||
|
"voice": b"",
|
||||||
|
"pending": task,
|
||||||
|
"last_used": monotonic(),
|
||||||
|
}
|
||||||
|
return cache_key
|
||||||
|
|
||||||
|
async def async_get_tts_audio(self, cache_key: str) -> tuple[str, bytes]:
|
||||||
|
"""Fetch TTS audio."""
|
||||||
|
cached = self.mem_cache.get(cache_key)
|
||||||
|
if cached is None:
|
||||||
|
raise HomeAssistantError("Audio not cached")
|
||||||
|
if pending := cached.get("pending"):
|
||||||
|
await pending
|
||||||
|
cached = self.mem_cache[cache_key]
|
||||||
|
cached["last_used"] = monotonic()
|
||||||
|
return cached["extension"], cached["voice"]
|
||||||
|
|
||||||
|
async def _async_generate_tts_audio(
|
||||||
self,
|
self,
|
||||||
engine_instance: TextToSpeechEntity | Provider,
|
engine_instance: TextToSpeechEntity | Provider,
|
||||||
cache_key: str,
|
cache_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
cache: bool,
|
cache_to_disk: bool,
|
||||||
language: str,
|
language: str,
|
||||||
options: dict[str, Any],
|
options: dict[str, Any],
|
||||||
) -> str:
|
) -> None:
|
||||||
"""Receive TTS, store for view in cache and return filename.
|
"""Start loading of the TTS audio.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
@ -773,96 +923,66 @@ class SpeechManager:
|
|||||||
if sample_bytes is not None:
|
if sample_bytes is not None:
|
||||||
sample_bytes = int(sample_bytes)
|
sample_bytes = int(sample_bytes)
|
||||||
|
|
||||||
async def get_tts_data() -> str:
|
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
||||||
"""Handle data available."""
|
raise HomeAssistantError("TTS engine name is not set.")
|
||||||
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
|
||||||
raise HomeAssistantError("TTS engine name is not set.")
|
|
||||||
|
|
||||||
if isinstance(engine_instance, Provider):
|
if isinstance(engine_instance, Provider):
|
||||||
extension, data = await engine_instance.async_get_tts_audio(
|
extension, data = await engine_instance.async_get_tts_audio(
|
||||||
message, language, options
|
message, language, options
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
extension, data = await engine_instance.internal_async_get_tts_audio(
|
extension, data = await engine_instance.internal_async_get_tts_audio(
|
||||||
message, language, options
|
message, language, options
|
||||||
)
|
|
||||||
|
|
||||||
if data is None or extension is None:
|
|
||||||
raise HomeAssistantError(
|
|
||||||
f"No TTS from {engine_instance.name} for '{message}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only convert if we have a preferred format different than the
|
|
||||||
# expected format from the TTS system, or if a specific sample
|
|
||||||
# rate/format/channel count is requested.
|
|
||||||
needs_conversion = (
|
|
||||||
(final_extension != extension)
|
|
||||||
or (sample_rate is not None)
|
|
||||||
or (sample_channels is not None)
|
|
||||||
or (sample_bytes is not None)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if needs_conversion:
|
if data is None or extension is None:
|
||||||
data = await async_convert_audio(
|
raise HomeAssistantError(
|
||||||
self.hass,
|
f"No TTS from {engine_instance.name} for '{message}'"
|
||||||
extension,
|
)
|
||||||
data,
|
|
||||||
to_extension=final_extension,
|
|
||||||
to_sample_rate=sample_rate,
|
|
||||||
to_sample_channels=sample_channels,
|
|
||||||
to_sample_bytes=sample_bytes,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create file infos
|
# Only convert if we have a preferred format different than the
|
||||||
filename = f"{cache_key}.{final_extension}".lower()
|
# expected format from the TTS system, or if a specific sample
|
||||||
|
# rate/format/channel count is requested.
|
||||||
|
needs_conversion = (
|
||||||
|
(final_extension != extension)
|
||||||
|
or (sample_rate is not None)
|
||||||
|
or (sample_channels is not None)
|
||||||
|
or (sample_bytes is not None)
|
||||||
|
)
|
||||||
|
|
||||||
# Validate filename
|
if needs_conversion:
|
||||||
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
data = await async_convert_audio(
|
||||||
filename
|
self.hass,
|
||||||
):
|
extension,
|
||||||
raise HomeAssistantError(
|
data,
|
||||||
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
to_extension=final_extension,
|
||||||
)
|
to_sample_rate=sample_rate,
|
||||||
|
to_sample_channels=sample_channels,
|
||||||
# Save to memory
|
to_sample_bytes=sample_bytes,
|
||||||
if final_extension == "mp3":
|
)
|
||||||
data = self.write_tags(
|
|
||||||
filename, data, engine_instance.name, message, language, options
|
|
||||||
)
|
|
||||||
|
|
||||||
self._async_store_to_memcache(cache_key, filename, data)
|
|
||||||
|
|
||||||
if cache:
|
|
||||||
self.hass.async_create_task(
|
|
||||||
self._async_save_tts_audio(cache_key, filename, data)
|
|
||||||
)
|
|
||||||
|
|
||||||
return filename
|
|
||||||
|
|
||||||
audio_task = self.hass.async_create_task(get_tts_data(), eager_start=False)
|
|
||||||
|
|
||||||
def handle_error(_future: asyncio.Future) -> None:
|
|
||||||
"""Handle error."""
|
|
||||||
if audio_task.exception():
|
|
||||||
self.mem_cache.pop(cache_key, None)
|
|
||||||
|
|
||||||
audio_task.add_done_callback(handle_error)
|
|
||||||
|
|
||||||
|
# Create file infos
|
||||||
filename = f"{cache_key}.{final_extension}".lower()
|
filename = f"{cache_key}.{final_extension}".lower()
|
||||||
self.mem_cache[cache_key] = {
|
|
||||||
"filename": filename,
|
|
||||||
"voice": b"",
|
|
||||||
"pending": audio_task,
|
|
||||||
}
|
|
||||||
return filename
|
|
||||||
|
|
||||||
async def _async_save_tts_audio(
|
# Validate filename
|
||||||
self, cache_key: str, filename: str, data: bytes
|
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
||||||
) -> None:
|
filename
|
||||||
"""Store voice data to file and file_cache.
|
):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to memory
|
||||||
|
if final_extension == "mp3":
|
||||||
|
data = self.write_tags(
|
||||||
|
filename, data, engine_instance.name, message, language, options
|
||||||
|
)
|
||||||
|
|
||||||
|
self._async_store_to_memcache(cache_key, final_extension, data)
|
||||||
|
|
||||||
|
if not cache_to_disk:
|
||||||
|
return
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
voice_file = os.path.join(self.cache_dir, filename)
|
voice_file = os.path.join(self.cache_dir, filename)
|
||||||
|
|
||||||
def save_speech() -> None:
|
def save_speech() -> None:
|
||||||
@ -870,13 +990,19 @@ class SpeechManager:
|
|||||||
with open(voice_file, "wb") as speech:
|
with open(voice_file, "wb") as speech:
|
||||||
speech.write(data)
|
speech.write(data)
|
||||||
|
|
||||||
try:
|
# Don't await, we're going to do this in the background
|
||||||
await self.hass.async_add_executor_job(save_speech)
|
task = self.hass.async_add_executor_job(save_speech)
|
||||||
self.file_cache[cache_key] = filename
|
|
||||||
except OSError as err:
|
|
||||||
_LOGGER.error("Can't write %s: %s", filename, err)
|
|
||||||
|
|
||||||
async def _async_file_to_mem(self, cache_key: str) -> None:
|
def write_done(future: asyncio.Future) -> None:
|
||||||
|
"""Write is done task."""
|
||||||
|
if err := future.exception():
|
||||||
|
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||||
|
else:
|
||||||
|
self.file_cache[cache_key] = filename
|
||||||
|
|
||||||
|
task.add_done_callback(write_done)
|
||||||
|
|
||||||
|
async def _async_load_file_to_mem(self, cache_key: str) -> None:
|
||||||
"""Load voice from file cache into memory.
|
"""Load voice from file cache into memory.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
@ -897,64 +1023,22 @@ class SpeechManager:
|
|||||||
del self.file_cache[cache_key]
|
del self.file_cache[cache_key]
|
||||||
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
raise HomeAssistantError(f"Can't read {voice_file}") from err
|
||||||
|
|
||||||
self._async_store_to_memcache(cache_key, filename, data)
|
extension = os.path.splitext(filename)[1][1:]
|
||||||
|
|
||||||
|
self._async_store_to_memcache(cache_key, extension, data)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_store_to_memcache(
|
def _async_store_to_memcache(
|
||||||
self, cache_key: str, filename: str, data: bytes
|
self, cache_key: str, extension: str, data: bytes
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store data to memcache and set timer to remove it."""
|
"""Store data to memcache and set timer to remove it."""
|
||||||
self.mem_cache[cache_key] = {
|
self.mem_cache[cache_key] = {
|
||||||
"filename": filename,
|
"extension": extension,
|
||||||
"voice": data,
|
"voice": data,
|
||||||
"pending": None,
|
"pending": None,
|
||||||
|
"last_used": monotonic(),
|
||||||
}
|
}
|
||||||
|
self.memcache_cleanup.schedule()
|
||||||
@callback
|
|
||||||
def async_remove_from_mem(_: datetime) -> None:
|
|
||||||
"""Cleanup memcache."""
|
|
||||||
self.mem_cache.pop(cache_key, None)
|
|
||||||
|
|
||||||
async_call_later(
|
|
||||||
self.hass,
|
|
||||||
self.time_memory,
|
|
||||||
HassJob(
|
|
||||||
async_remove_from_mem,
|
|
||||||
name="tts remove_from_mem",
|
|
||||||
cancel_on_shutdown=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_read_tts(self, token: str) -> tuple[str | None, bytes]:
|
|
||||||
"""Read a voice file and return binary.
|
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
filename = self.token_to_filename.get(token)
|
|
||||||
if not filename:
|
|
||||||
raise HomeAssistantError(f"{token} was not recognized!")
|
|
||||||
|
|
||||||
if not (record := _RE_VOICE_FILE.match(filename.lower())) and not (
|
|
||||||
record := _RE_LEGACY_VOICE_FILE.match(filename.lower())
|
|
||||||
):
|
|
||||||
raise HomeAssistantError("Wrong tts file format!")
|
|
||||||
|
|
||||||
cache_key = KEY_PATTERN.format(
|
|
||||||
record.group(1), record.group(2), record.group(3), record.group(4)
|
|
||||||
)
|
|
||||||
|
|
||||||
if cache_key not in self.mem_cache:
|
|
||||||
if cache_key not in self.file_cache:
|
|
||||||
raise HomeAssistantError(f"{cache_key} not in cache!")
|
|
||||||
await self._async_file_to_mem(cache_key)
|
|
||||||
|
|
||||||
cached = self.mem_cache[cache_key]
|
|
||||||
if pending := cached.get("pending"):
|
|
||||||
await pending
|
|
||||||
cached = self.mem_cache[cache_key]
|
|
||||||
|
|
||||||
content, _ = mimetypes.guess_type(filename)
|
|
||||||
return content, cached["voice"]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_tags(
|
def write_tags(
|
||||||
@ -1042,9 +1126,9 @@ class TextToSpeechUrlView(HomeAssistantView):
|
|||||||
url = "/api/tts_get_url"
|
url = "/api/tts_get_url"
|
||||||
name = "api:tts:geturl"
|
name = "api:tts:geturl"
|
||||||
|
|
||||||
def __init__(self, tts: SpeechManager) -> None:
|
def __init__(self, manager: SpeechManager) -> None:
|
||||||
"""Initialize a tts view."""
|
"""Initialize a tts view."""
|
||||||
self.tts = tts
|
self.manager = manager
|
||||||
|
|
||||||
async def post(self, request: web.Request) -> web.Response:
|
async def post(self, request: web.Request) -> web.Response:
|
||||||
"""Generate speech and provide url."""
|
"""Generate speech and provide url."""
|
||||||
@ -1061,45 +1145,53 @@ class TextToSpeechUrlView(HomeAssistantView):
|
|||||||
|
|
||||||
engine = data.get("engine_id") or data[ATTR_PLATFORM]
|
engine = data.get("engine_id") or data[ATTR_PLATFORM]
|
||||||
message = data[ATTR_MESSAGE]
|
message = data[ATTR_MESSAGE]
|
||||||
cache = data.get(ATTR_CACHE)
|
use_file_cache = data.get(ATTR_CACHE)
|
||||||
language = data.get(ATTR_LANGUAGE)
|
language = data.get(ATTR_LANGUAGE)
|
||||||
options = data.get(ATTR_OPTIONS)
|
options = data.get(ATTR_OPTIONS)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = await self.tts.async_get_url_path(
|
stream = self.manager.async_create_result_stream(
|
||||||
engine, message, cache=cache, language=language, options=options
|
engine,
|
||||||
|
message,
|
||||||
|
use_file_cache=use_file_cache,
|
||||||
|
language=language,
|
||||||
|
options=options,
|
||||||
)
|
)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
_LOGGER.error("Error on init tts: %s", err)
|
_LOGGER.error("Error on init tts: %s", err)
|
||||||
return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
|
return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
base = get_url(self.tts.hass)
|
base = get_url(self.manager.hass)
|
||||||
url = base + path
|
url = base + stream.url
|
||||||
|
|
||||||
return self.json({"url": url, "path": path})
|
return self.json({"url": url, "path": stream.url})
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechView(HomeAssistantView):
|
class TextToSpeechView(HomeAssistantView):
|
||||||
"""TTS view to serve a speech audio."""
|
"""TTS view to serve a speech audio."""
|
||||||
|
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
url = "/api/tts_proxy/{filename}"
|
url = "/api/tts_proxy/{token}"
|
||||||
name = "api:tts_speech"
|
name = "api:tts_speech"
|
||||||
|
|
||||||
def __init__(self, tts: SpeechManager) -> None:
|
def __init__(self, manager: SpeechManager) -> None:
|
||||||
"""Initialize a tts view."""
|
"""Initialize a tts view."""
|
||||||
self.tts = tts
|
self.manager = manager
|
||||||
|
|
||||||
async def get(self, request: web.Request, filename: str) -> web.Response:
|
async def get(self, request: web.Request, token: str) -> web.Response:
|
||||||
"""Start a get request."""
|
"""Start a get request."""
|
||||||
try:
|
stream = self.manager.token_to_stream.get(token)
|
||||||
# filename is actually token, but we keep its name for compatibility
|
|
||||||
content, data = await self.tts.async_read_tts(filename)
|
if stream is None:
|
||||||
except HomeAssistantError as err:
|
|
||||||
_LOGGER.error("Error on load tts: %s", err)
|
|
||||||
return web.Response(status=HTTPStatus.NOT_FOUND)
|
return web.Response(status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
return web.Response(body=data, content_type=content)
|
try:
|
||||||
|
data = await stream.async_get_result()
|
||||||
|
except HomeAssistantError as err:
|
||||||
|
_LOGGER.error("Error on get tts: %s", err)
|
||||||
|
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
return web.Response(body=data, content_type=stream.content_type)
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import mimetypes
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
@ -73,7 +72,7 @@ class MediaSourceOptions(TypedDict):
|
|||||||
message: str
|
message: str
|
||||||
language: str | None
|
language: str | None
|
||||||
options: dict | None
|
options: dict | None
|
||||||
cache: bool | None
|
use_file_cache: bool | None
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -98,10 +97,10 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
|||||||
"message": parsed.query["message"],
|
"message": parsed.query["message"],
|
||||||
"language": parsed.query.get("language"),
|
"language": parsed.query.get("language"),
|
||||||
"options": options,
|
"options": options,
|
||||||
"cache": None,
|
"use_file_cache": None,
|
||||||
}
|
}
|
||||||
if "cache" in parsed.query:
|
if "cache" in parsed.query:
|
||||||
kwargs["cache"] = parsed.query["cache"] == "true"
|
kwargs["use_file_cache"] = parsed.query["cache"] == "true"
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
@ -119,7 +118,7 @@ class TTSMediaSource(MediaSource):
|
|||||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||||
"""Resolve media to a url."""
|
"""Resolve media to a url."""
|
||||||
try:
|
try:
|
||||||
url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path(
|
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
|
||||||
**media_source_id_to_kwargs(item.identifier)
|
**media_source_id_to_kwargs(item.identifier)
|
||||||
)
|
)
|
||||||
except Unresolvable:
|
except Unresolvable:
|
||||||
@ -127,9 +126,7 @@ class TTSMediaSource(MediaSource):
|
|||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
raise Unresolvable(str(err)) from err
|
raise Unresolvable(str(err)) from err
|
||||||
|
|
||||||
mime_type = mimetypes.guess_type(url)[0] or "audio/mpeg"
|
return PlayMedia(stream.url, stream.content_type)
|
||||||
|
|
||||||
return PlayMedia(url, mime_type)
|
|
||||||
|
|
||||||
async def async_browse_media(
|
async def async_browse_media(
|
||||||
self,
|
self,
|
||||||
|
@ -350,7 +350,7 @@ async def test_tts_service_speak_error(
|
|||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
tts_entity._client.generate.assert_called_once_with(
|
tts_entity._client.generate.assert_called_once_with(
|
||||||
|
@ -475,6 +475,6 @@ async def test_service_say_error(
|
|||||||
await retrieve_media(
|
await retrieve_media(
|
||||||
hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID]
|
hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
)
|
)
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
assert len(mock_gtts.mock_calls) == 2
|
assert len(mock_gtts.mock_calls) == 2
|
||||||
|
@ -155,7 +155,7 @@ async def test_service_say_http_error(
|
|||||||
await retrieve_media(
|
await retrieve_media(
|
||||||
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
)
|
)
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_speak.assert_called_once()
|
mock_speak.assert_called_once()
|
||||||
|
@ -366,7 +366,7 @@ async def test_service_say_error(
|
|||||||
await retrieve_media(
|
await retrieve_media(
|
||||||
hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID]
|
hass, hass_client, service_calls[1].data[ATTR_MEDIA_CONTENT_ID]
|
||||||
)
|
)
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(mock_tts.mock_calls) == 2
|
assert len(mock_tts.mock_calls) == 2
|
||||||
|
@ -1197,7 +1197,7 @@ async def test_service_get_tts_error(
|
|||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -268,7 +268,7 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
|||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"language": "en_US",
|
||||||
"options": {"age": 5},
|
"options": {"age": 5},
|
||||||
"cache": True,
|
"use_file_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@ -284,7 +284,7 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
|||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"language": "en_US",
|
||||||
"options": {"age": [5, 6]},
|
"options": {"age": [5, 6]},
|
||||||
"cache": True,
|
"use_file_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@ -300,5 +300,5 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
|||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"language": "en_US",
|
||||||
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
|
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
|
||||||
"cache": True,
|
"use_file_cache": True,
|
||||||
}
|
}
|
||||||
|
@ -200,7 +200,7 @@ async def test_service_say_error(
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
assert len(aioclient_mock.mock_calls) == 1
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||||
@ -234,7 +234,7 @@ async def test_service_say_timeout(
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
assert len(aioclient_mock.mock_calls) == 1
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||||
@ -273,7 +273,7 @@ async def test_service_say_error_msg(
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
assert len(aioclient_mock.mock_calls) == 1
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
assert aioclient_mock.mock_calls[0][2] == FORM_DATA
|
||||||
|
@ -223,7 +223,7 @@ async def test_service_say_timeout(
|
|||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(aioclient_mock.mock_calls) == 1
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
@ -269,7 +269,7 @@ async def test_service_say_http_error(
|
|||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert (
|
assert (
|
||||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
== HTTPStatus.NOT_FOUND
|
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user