mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Clean up Text-to-Speech (#143744)
This commit is contained in:
parent
97084e9382
commit
f980434046
@ -3,8 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator, MutableMapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@ -15,7 +15,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
from typing import Any, Final
|
from typing import Any, Final, Generic, Protocol, TypeVar
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import mutagen
|
import mutagen
|
||||||
@ -60,10 +60,10 @@ from .const import (
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
TtsAudioType,
|
TtsAudioType,
|
||||||
)
|
)
|
||||||
from .entity import TextToSpeechEntity, TTSAudioRequest
|
from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse
|
||||||
from .helper import get_engine_instance
|
from .helper import get_engine_instance
|
||||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
from .media_source import generate_media_source_id, parse_media_source_id
|
||||||
from .models import Voice
|
from .models import Voice
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -79,6 +79,7 @@ __all__ = [
|
|||||||
"Provider",
|
"Provider",
|
||||||
"ResultStream",
|
"ResultStream",
|
||||||
"SampleFormat",
|
"SampleFormat",
|
||||||
|
"TTSAudioResponse",
|
||||||
"TextToSpeechEntity",
|
"TextToSpeechEntity",
|
||||||
"TtsAudioType",
|
"TtsAudioType",
|
||||||
"Voice",
|
"Voice",
|
||||||
@ -264,7 +265,7 @@ def async_create_stream(
|
|||||||
@callback
|
@callback
|
||||||
def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None:
|
def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None:
|
||||||
"""Return a result stream given a token."""
|
"""Return a result stream given a token."""
|
||||||
return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token)
|
return hass.data[DATA_TTS_MANAGER].async_get_result_stream(token)
|
||||||
|
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
async def async_get_media_source_audio(
|
||||||
@ -272,12 +273,11 @@ async def async_get_media_source_audio(
|
|||||||
media_source_id: str,
|
media_source_id: str,
|
||||||
) -> tuple[str, bytes]:
|
) -> tuple[str, bytes]:
|
||||||
"""Get TTS audio as extension, data."""
|
"""Get TTS audio as extension, data."""
|
||||||
manager = hass.data[DATA_TTS_MANAGER]
|
parsed = parse_media_source_id(media_source_id)
|
||||||
cache = manager.async_cache_message_in_memory(
|
stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"])
|
||||||
**media_source_id_to_kwargs(media_source_id)
|
stream.async_set_message(parsed["message"])
|
||||||
)
|
data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||||
data = b"".join([chunk async for chunk in cache.async_stream_data()])
|
return stream.extension, data
|
||||||
return cache.extension, data
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -457,6 +457,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
class ResultStream:
|
class ResultStream:
|
||||||
"""Class that will stream the result when available."""
|
"""Class that will stream the result when available."""
|
||||||
|
|
||||||
|
last_used: float = field(default_factory=monotonic, init=False)
|
||||||
|
|
||||||
# Streaming/conversion properties
|
# Streaming/conversion properties
|
||||||
token: str
|
token: str
|
||||||
extension: str
|
extension: str
|
||||||
@ -480,11 +482,6 @@ class ResultStream:
|
|||||||
"""Get the future that returns the cache."""
|
"""Get the future that returns the cache."""
|
||||||
return asyncio.Future()
|
return asyncio.Future()
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_set_message_cache(self, cache: TTSCache) -> None:
|
|
||||||
"""Set cache containing message audio to be streamed."""
|
|
||||||
self._result_cache.set_result(cache)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_message(self, message: str) -> None:
|
def async_set_message(self, message: str) -> None:
|
||||||
"""Set message to be generated."""
|
"""Set message to be generated."""
|
||||||
@ -504,6 +501,8 @@ class ResultStream:
|
|||||||
async for chunk in cache.async_stream_data():
|
async for chunk in cache.async_stream_data():
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
self.last_used = monotonic()
|
||||||
|
|
||||||
|
|
||||||
def _hash_options(options: dict) -> str:
|
def _hash_options(options: dict) -> str:
|
||||||
"""Hashes an options dictionary."""
|
"""Hashes an options dictionary."""
|
||||||
@ -515,13 +514,25 @@ def _hash_options(options: dict) -> str:
|
|||||||
return opts_hash.hexdigest()
|
return opts_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class MemcacheCleanup:
|
class HasLastUsed(Protocol):
|
||||||
|
"""Protocol for objects that have a last_used attribute."""
|
||||||
|
|
||||||
|
last_used: float
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=HasLastUsed)
|
||||||
|
|
||||||
|
|
||||||
|
class DictCleaning(Generic[T]):
|
||||||
"""Helper to clean up the stale sessions."""
|
"""Helper to clean up the stale sessions."""
|
||||||
|
|
||||||
unsub: CALLBACK_TYPE | None = None
|
unsub: CALLBACK_TYPE | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hass: HomeAssistant, maxage: float, memcache: dict[str, TTSCache]
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
maxage: float,
|
||||||
|
memcache: MutableMapping[str, T],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the cleanup."""
|
"""Initialize the cleanup."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
@ -588,8 +599,9 @@ class SpeechManager:
|
|||||||
self.file_cache: dict[str, str] = {}
|
self.file_cache: dict[str, str] = {}
|
||||||
self.mem_cache: dict[str, TTSCache] = {}
|
self.mem_cache: dict[str, TTSCache] = {}
|
||||||
self.token_to_stream: dict[str, ResultStream] = {}
|
self.token_to_stream: dict[str, ResultStream] = {}
|
||||||
self.memcache_cleanup = MemcacheCleanup(
|
self.memcache_cleanup = DictCleaning(hass, memory_cache_maxage, self.mem_cache)
|
||||||
hass, memory_cache_maxage, self.mem_cache
|
self.token_to_stream_cleanup = DictCleaning(
|
||||||
|
hass, memory_cache_maxage, self.token_to_stream
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_cache(self) -> dict[str, str]:
|
def _init_cache(self) -> dict[str, str]:
|
||||||
@ -679,11 +691,21 @@ class SpeechManager:
|
|||||||
|
|
||||||
return language, merged_options
|
return language, merged_options
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_result_stream(
|
||||||
|
self,
|
||||||
|
token: str,
|
||||||
|
) -> ResultStream | None:
|
||||||
|
"""Return a result stream given a token."""
|
||||||
|
stream = self.token_to_stream.get(token, None)
|
||||||
|
if stream:
|
||||||
|
stream.last_used = monotonic()
|
||||||
|
return stream
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_result_stream(
|
def async_create_result_stream(
|
||||||
self,
|
self,
|
||||||
engine: str,
|
engine: str,
|
||||||
message: str | None = None,
|
|
||||||
use_file_cache: bool | None = None,
|
use_file_cache: bool | None = None,
|
||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
options: dict | None = None,
|
options: dict | None = None,
|
||||||
@ -710,23 +732,7 @@ class SpeechManager:
|
|||||||
_manager=self,
|
_manager=self,
|
||||||
)
|
)
|
||||||
self.token_to_stream[token] = result_stream
|
self.token_to_stream[token] = result_stream
|
||||||
|
self.token_to_stream_cleanup.schedule()
|
||||||
if message is None:
|
|
||||||
return result_stream
|
|
||||||
|
|
||||||
# We added this method as an alternative to stream.async_set_message
|
|
||||||
# to avoid the options being processed twice
|
|
||||||
result_stream.async_set_message_cache(
|
|
||||||
self._async_ensure_cached_in_memory(
|
|
||||||
engine=engine,
|
|
||||||
engine_instance=engine_instance,
|
|
||||||
message=message,
|
|
||||||
use_file_cache=use_file_cache,
|
|
||||||
language=language,
|
|
||||||
options=options,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return result_stream
|
return result_stream
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -734,41 +740,17 @@ class SpeechManager:
|
|||||||
self,
|
self,
|
||||||
engine: str,
|
engine: str,
|
||||||
message: str,
|
message: str,
|
||||||
use_file_cache: bool | None = None,
|
|
||||||
language: str | None = None,
|
|
||||||
options: dict | None = None,
|
|
||||||
) -> TTSCache:
|
|
||||||
"""Make sure a message is cached in memory and returns cache key."""
|
|
||||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
|
||||||
raise HomeAssistantError(f"Provider {engine} not found")
|
|
||||||
|
|
||||||
language, options = self.process_options(engine_instance, language, options)
|
|
||||||
if use_file_cache is None:
|
|
||||||
use_file_cache = self.use_file_cache
|
|
||||||
|
|
||||||
return self._async_ensure_cached_in_memory(
|
|
||||||
engine=engine,
|
|
||||||
engine_instance=engine_instance,
|
|
||||||
message=message,
|
|
||||||
use_file_cache=use_file_cache,
|
|
||||||
language=language,
|
|
||||||
options=options,
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_ensure_cached_in_memory(
|
|
||||||
self,
|
|
||||||
engine: str,
|
|
||||||
engine_instance: TextToSpeechEntity | Provider,
|
|
||||||
message: str,
|
|
||||||
use_file_cache: bool,
|
use_file_cache: bool,
|
||||||
language: str,
|
language: str,
|
||||||
options: dict,
|
options: dict,
|
||||||
) -> TTSCache:
|
) -> TTSCache:
|
||||||
"""Ensure a message is cached.
|
"""Make sure a message will be cached in memory and returns cache object.
|
||||||
|
|
||||||
Requires options, language to be processed.
|
Requires options, language to be processed.
|
||||||
"""
|
"""
|
||||||
|
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||||
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
|
|
||||||
options_key = _hash_options(options) if options else "-"
|
options_key = _hash_options(options) if options else "-"
|
||||||
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
|
||||||
cache_key = KEY_PATTERN.format(
|
cache_key = KEY_PATTERN.format(
|
||||||
@ -789,9 +771,13 @@ class SpeechManager:
|
|||||||
store_to_disk = False
|
store_to_disk = False
|
||||||
else:
|
else:
|
||||||
_LOGGER.debug("Generating audio for %s", message[0:32])
|
_LOGGER.debug("Generating audio for %s", message[0:32])
|
||||||
|
|
||||||
|
async def message_stream() -> AsyncGenerator[str]:
|
||||||
|
yield message
|
||||||
|
|
||||||
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||||
data_gen = self._async_generate_tts_audio(
|
data_gen = self._async_generate_tts_audio(
|
||||||
engine_instance, message, language, options
|
engine_instance, message_stream(), language, options
|
||||||
)
|
)
|
||||||
|
|
||||||
cache = TTSCache(
|
cache = TTSCache(
|
||||||
@ -799,7 +785,6 @@ class SpeechManager:
|
|||||||
extension=extension,
|
extension=extension,
|
||||||
data_gen=data_gen,
|
data_gen=data_gen,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mem_cache[cache_key] = cache
|
self.mem_cache[cache_key] = cache
|
||||||
self.hass.async_create_background_task(
|
self.hass.async_create_background_task(
|
||||||
self._load_data_into_cache(
|
self._load_data_into_cache(
|
||||||
@ -866,7 +851,7 @@ class SpeechManager:
|
|||||||
async def _async_generate_tts_audio(
|
async def _async_generate_tts_audio(
|
||||||
self,
|
self,
|
||||||
engine_instance: TextToSpeechEntity | Provider,
|
engine_instance: TextToSpeechEntity | Provider,
|
||||||
message: str,
|
message_stream: AsyncGenerator[str],
|
||||||
language: str,
|
language: str,
|
||||||
options: dict[str, Any],
|
options: dict[str, Any],
|
||||||
) -> AsyncGenerator[bytes]:
|
) -> AsyncGenerator[bytes]:
|
||||||
@ -915,6 +900,7 @@ class SpeechManager:
|
|||||||
raise HomeAssistantError("TTS engine name is not set.")
|
raise HomeAssistantError("TTS engine name is not set.")
|
||||||
|
|
||||||
if isinstance(engine_instance, Provider):
|
if isinstance(engine_instance, Provider):
|
||||||
|
message = "".join([chunk async for chunk in message_stream])
|
||||||
extension, data = await engine_instance.async_get_tts_audio(
|
extension, data = await engine_instance.async_get_tts_audio(
|
||||||
message, language, options
|
message, language, options
|
||||||
)
|
)
|
||||||
@ -930,12 +916,8 @@ class SpeechManager:
|
|||||||
data_gen = make_data_generator(data)
|
data_gen = make_data_generator(data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def message_gen() -> AsyncGenerator[str]:
|
|
||||||
yield message
|
|
||||||
|
|
||||||
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
||||||
TTSAudioRequest(language, options, message_gen())
|
TTSAudioRequest(language, options, message_stream)
|
||||||
)
|
)
|
||||||
extension = tts_result.extension
|
extension = tts_result.extension
|
||||||
data_gen = tts_result.data_gen
|
data_gen = tts_result.data_gen
|
||||||
@ -1096,7 +1078,6 @@ class TextToSpeechUrlView(HomeAssistantView):
|
|||||||
try:
|
try:
|
||||||
stream = self.manager.async_create_result_stream(
|
stream = self.manager.async_create_result_stream(
|
||||||
engine,
|
engine,
|
||||||
message,
|
|
||||||
use_file_cache=use_file_cache,
|
use_file_cache=use_file_cache,
|
||||||
language=language,
|
language=language,
|
||||||
options=options,
|
options=options,
|
||||||
@ -1105,6 +1086,8 @@ class TextToSpeechUrlView(HomeAssistantView):
|
|||||||
_LOGGER.error("Error on init tts: %s", err)
|
_LOGGER.error("Error on init tts: %s", err)
|
||||||
return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
|
return self.json({"error": err}, HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
stream.async_set_message(message)
|
||||||
|
|
||||||
base = get_url(self.manager.hass)
|
base = get_url(self.manager.hass)
|
||||||
url = base + stream.url
|
url = base + stream.url
|
||||||
|
|
||||||
|
@ -69,14 +69,20 @@ class MediaSourceOptions(TypedDict):
|
|||||||
"""Media source options."""
|
"""Media source options."""
|
||||||
|
|
||||||
engine: str
|
engine: str
|
||||||
message: str
|
|
||||||
language: str | None
|
language: str | None
|
||||||
options: dict | None
|
options: dict | None
|
||||||
use_file_cache: bool | None
|
use_file_cache: bool | None
|
||||||
|
|
||||||
|
|
||||||
|
class ParsedMediaSourceId(TypedDict):
|
||||||
|
"""Parsed media source ID."""
|
||||||
|
|
||||||
|
options: MediaSourceOptions
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId:
|
||||||
"""Turn a media source ID into options."""
|
"""Turn a media source ID into options."""
|
||||||
parsed = URL(media_source_id)
|
parsed = URL(media_source_id)
|
||||||
if URL_QUERY_TTS_OPTIONS in parsed.query:
|
if URL_QUERY_TTS_OPTIONS in parsed.query:
|
||||||
@ -94,7 +100,6 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
|||||||
raise Unresolvable("No message specified.")
|
raise Unresolvable("No message specified.")
|
||||||
kwargs: MediaSourceOptions = {
|
kwargs: MediaSourceOptions = {
|
||||||
"engine": parsed.name,
|
"engine": parsed.name,
|
||||||
"message": parsed.query["message"],
|
|
||||||
"language": parsed.query.get("language"),
|
"language": parsed.query.get("language"),
|
||||||
"options": options,
|
"options": options,
|
||||||
"use_file_cache": None,
|
"use_file_cache": None,
|
||||||
@ -102,7 +107,7 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
|
|||||||
if "cache" in parsed.query:
|
if "cache" in parsed.query:
|
||||||
kwargs["use_file_cache"] = parsed.query["cache"] == "true"
|
kwargs["use_file_cache"] = parsed.query["cache"] == "true"
|
||||||
|
|
||||||
return kwargs
|
return {"message": parsed.query["message"], "options": kwargs}
|
||||||
|
|
||||||
|
|
||||||
class TTSMediaSource(MediaSource):
|
class TTSMediaSource(MediaSource):
|
||||||
@ -118,9 +123,11 @@ class TTSMediaSource(MediaSource):
|
|||||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||||
"""Resolve media to a url."""
|
"""Resolve media to a url."""
|
||||||
try:
|
try:
|
||||||
|
parsed = parse_media_source_id(item.identifier)
|
||||||
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
|
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
|
||||||
**media_source_id_to_kwargs(item.identifier)
|
**parsed["options"]
|
||||||
)
|
)
|
||||||
|
stream.async_set_message(parsed["message"])
|
||||||
except Unresolvable:
|
except Unresolvable:
|
||||||
raise
|
raise
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
|
@ -42,6 +42,7 @@ from tests.typing import ClientSessionGenerator
|
|||||||
DEFAULT_LANG = "en_US"
|
DEFAULT_LANG = "en_US"
|
||||||
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
||||||
TEST_DOMAIN = "test"
|
TEST_DOMAIN = "test"
|
||||||
|
MOCK_DATA = b"123"
|
||||||
|
|
||||||
|
|
||||||
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
|
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
|
||||||
@ -164,7 +165,7 @@ class BaseProvider:
|
|||||||
self, message: str, language: str, options: dict[str, Any]
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
"""Load TTS dat."""
|
"""Load TTS dat."""
|
||||||
return ("mp3", b"")
|
return ("mp3", MOCK_DATA)
|
||||||
|
|
||||||
|
|
||||||
class MockTTSProvider(BaseProvider, Provider):
|
class MockTTSProvider(BaseProvider, Provider):
|
||||||
|
@ -27,6 +27,7 @@ from homeassistant.util import dt as dt_util
|
|||||||
|
|
||||||
from .common import (
|
from .common import (
|
||||||
DEFAULT_LANG,
|
DEFAULT_LANG,
|
||||||
|
MOCK_DATA,
|
||||||
TEST_DOMAIN,
|
TEST_DOMAIN,
|
||||||
MockResultStream,
|
MockResultStream,
|
||||||
MockTTS,
|
MockTTS,
|
||||||
@ -808,7 +809,7 @@ async def test_service_receive_voice(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
req = await client.get(url)
|
req = await client.get(url)
|
||||||
tts_data = b""
|
tts_data = MOCK_DATA
|
||||||
tts_data = tts.SpeechManager.write_tags(
|
tts_data = tts.SpeechManager.write_tags(
|
||||||
f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3",
|
f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3",
|
||||||
tts_data,
|
tts_data,
|
||||||
@ -879,7 +880,7 @@ async def test_service_receive_voice_german(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
req = await client.get(url)
|
req = await client.get(url)
|
||||||
tts_data = b""
|
tts_data = MOCK_DATA
|
||||||
tts_data = tts.SpeechManager.write_tags(
|
tts_data = tts.SpeechManager.write_tags(
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3",
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3",
|
||||||
tts_data,
|
tts_data,
|
||||||
@ -1021,7 +1022,7 @@ async def test_setup_legacy_cache_dir(
|
|||||||
"""Set up a TTS platform with cache and call service without cache."""
|
"""Set up a TTS platform with cache and call service without cache."""
|
||||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||||
|
|
||||||
tts_data = b""
|
tts_data = MOCK_DATA
|
||||||
cache_file = (
|
cache_file = (
|
||||||
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
|
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
|
||||||
)
|
)
|
||||||
@ -1059,7 +1060,7 @@ async def test_setup_cache_dir(
|
|||||||
"""Set up a TTS platform with cache and call service without cache."""
|
"""Set up a TTS platform with cache and call service without cache."""
|
||||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||||
|
|
||||||
tts_data = b""
|
tts_data = MOCK_DATA
|
||||||
cache_file = mock_tts_cache_dir / (
|
cache_file = mock_tts_cache_dir / (
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
|
||||||
)
|
)
|
||||||
@ -1165,7 +1166,7 @@ async def test_legacy_cannot_retrieve_without_token(
|
|||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Verify that a TTS cannot be retrieved by filename directly."""
|
"""Verify that a TTS cannot be retrieved by filename directly."""
|
||||||
tts_data = b""
|
tts_data = MOCK_DATA
|
||||||
cache_file = (
|
cache_file = (
|
||||||
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
||||||
)
|
)
|
||||||
@ -1188,7 +1189,7 @@ async def test_cannot_retrieve_without_token(
|
|||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Verify that a TTS cannot be retrieved by filename directly."""
|
"""Verify that a TTS cannot be retrieved by filename directly."""
|
||||||
tts_data = b""
|
tts_data = MOCK_DATA
|
||||||
cache_file = mock_tts_cache_dir / (
|
cache_file = mock_tts_cache_dir / (
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
|
||||||
)
|
)
|
||||||
@ -1845,6 +1846,9 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
|||||||
assert stream.language == mock_tts_entity.default_language
|
assert stream.language == mock_tts_entity.default_language
|
||||||
assert stream.options == (mock_tts_entity.default_options or {})
|
assert stream.options == (mock_tts_entity.default_options or {})
|
||||||
assert tts.async_get_stream(hass, stream.token) is stream
|
assert tts.async_get_stream(hass, stream.token) is stream
|
||||||
|
stream.async_set_message("beer")
|
||||||
|
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||||
|
assert result_data == MOCK_DATA
|
||||||
|
|
||||||
data = b"beer"
|
data = b"beer"
|
||||||
stream2 = MockResultStream(hass, "wav", data)
|
stream2 = MockResultStream(hass, "wav", data)
|
||||||
|
@ -9,9 +9,8 @@ import pytest
|
|||||||
from homeassistant.components import media_source
|
from homeassistant.components import media_source
|
||||||
from homeassistant.components.media_player import BrowseError
|
from homeassistant.components.media_player import BrowseError
|
||||||
from homeassistant.components.tts.media_source import (
|
from homeassistant.components.tts.media_source import (
|
||||||
MediaSourceOptions,
|
|
||||||
generate_media_source_id,
|
generate_media_source_id,
|
||||||
media_source_id_to_kwargs,
|
parse_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -249,13 +248,13 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) ->
|
|||||||
],
|
],
|
||||||
indirect=["setup"],
|
indirect=["setup"],
|
||||||
)
|
)
|
||||||
async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
async def test_generate_media_source_id_and_parse_media_source_id(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup: str,
|
setup: str,
|
||||||
result_engine: str,
|
result_engine: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test media_source_id and media_source_id_to_kwargs."""
|
"""Test media_source_id and parse_media_source_id."""
|
||||||
kwargs: MediaSourceOptions = {
|
kwargs = {
|
||||||
"engine": None,
|
"engine": None,
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"language": "en_US",
|
||||||
@ -263,12 +262,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
|||||||
"cache": True,
|
"cache": True,
|
||||||
}
|
}
|
||||||
media_source_id = generate_media_source_id(hass, **kwargs)
|
media_source_id = generate_media_source_id(hass, **kwargs)
|
||||||
assert media_source_id_to_kwargs(media_source_id) == {
|
assert parse_media_source_id(media_source_id) == {
|
||||||
"engine": result_engine,
|
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"options": {
|
||||||
"options": {"age": 5},
|
"engine": result_engine,
|
||||||
"use_file_cache": True,
|
"language": "en_US",
|
||||||
|
"options": {"age": 5},
|
||||||
|
"use_file_cache": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@ -279,12 +280,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
|||||||
"cache": True,
|
"cache": True,
|
||||||
}
|
}
|
||||||
media_source_id = generate_media_source_id(hass, **kwargs)
|
media_source_id = generate_media_source_id(hass, **kwargs)
|
||||||
assert media_source_id_to_kwargs(media_source_id) == {
|
assert parse_media_source_id(media_source_id) == {
|
||||||
"engine": result_engine,
|
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"options": {
|
||||||
"options": {"age": [5, 6]},
|
"engine": result_engine,
|
||||||
"use_file_cache": True,
|
"language": "en_US",
|
||||||
|
"options": {"age": [5, 6]},
|
||||||
|
"use_file_cache": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@ -295,10 +298,12 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs(
|
|||||||
"cache": True,
|
"cache": True,
|
||||||
}
|
}
|
||||||
media_source_id = generate_media_source_id(hass, **kwargs)
|
media_source_id = generate_media_source_id(hass, **kwargs)
|
||||||
assert media_source_id_to_kwargs(media_source_id) == {
|
assert parse_media_source_id(media_source_id) == {
|
||||||
"engine": result_engine,
|
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"language": "en_US",
|
"options": {
|
||||||
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
|
"engine": result_engine,
|
||||||
"use_file_cache": True,
|
"language": "en_US",
|
||||||
|
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
|
||||||
|
"use_file_cache": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user