mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Move legacy tts (#91538)
* Move legacy tts * Add error log on unknown platform * Add legacy tests and delint all tests * Consolidate log format * Add more legacy tests * Test default legacy provider attributes * Remove test generated files * Clean up after merge conflict
This commit is contained in:
parent
0ecd23baee
commit
9bd12f6503
@ -1,135 +1,77 @@
|
|||||||
"""Provide functionality for TTS."""
|
"""Provide functionality for TTS."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Mapping
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import functools as ft
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
from typing import TypedDict
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import mutagen
|
import mutagen
|
||||||
from mutagen.id3 import ID3, TextFrame as ID3Text
|
from mutagen.id3 import ID3, TextFrame as ID3Text
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
import yarl
|
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.media_player import (
|
from homeassistant.const import PLATFORM_FORMAT
|
||||||
ATTR_MEDIA_ANNOUNCE,
|
|
||||||
ATTR_MEDIA_CONTENT_ID,
|
|
||||||
ATTR_MEDIA_CONTENT_TYPE,
|
|
||||||
DOMAIN as DOMAIN_MP,
|
|
||||||
SERVICE_PLAY_MEDIA,
|
|
||||||
MediaType,
|
|
||||||
)
|
|
||||||
from homeassistant.const import (
|
|
||||||
ATTR_ENTITY_ID,
|
|
||||||
CONF_DESCRIPTION,
|
|
||||||
CONF_NAME,
|
|
||||||
CONF_PLATFORM,
|
|
||||||
PLATFORM_FORMAT,
|
|
||||||
)
|
|
||||||
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
|
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import config_per_platform, discovery
|
|
||||||
import homeassistant.helpers.config_validation as cv
|
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
from homeassistant.helpers.network import get_url
|
from homeassistant.helpers.network import get_url
|
||||||
from homeassistant.helpers.service import async_set_service_schema
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
||||||
from homeassistant.setup import async_prepare_setup_platform
|
|
||||||
from homeassistant.util.network import normalize_url
|
|
||||||
from homeassistant.util.yaml import load_yaml
|
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import (
|
||||||
|
ATTR_CACHE,
|
||||||
|
ATTR_LANGUAGE,
|
||||||
|
ATTR_MESSAGE,
|
||||||
|
ATTR_OPTIONS,
|
||||||
|
CONF_BASE_URL,
|
||||||
|
CONF_CACHE,
|
||||||
|
CONF_CACHE_DIR,
|
||||||
|
CONF_TIME_MEMORY,
|
||||||
|
DEFAULT_CACHE,
|
||||||
|
DEFAULT_CACHE_DIR,
|
||||||
|
DEFAULT_TIME_MEMORY,
|
||||||
|
DOMAIN,
|
||||||
|
TtsAudioType,
|
||||||
|
)
|
||||||
|
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"async_get_media_source_audio",
|
||||||
|
"async_resolve_engine",
|
||||||
|
"async_support_options",
|
||||||
|
"ATTR_AUDIO_OUTPUT",
|
||||||
|
"CONF_LANG",
|
||||||
|
"DEFAULT_CACHE_DIR",
|
||||||
|
"generate_media_source_id",
|
||||||
|
"get_base_url",
|
||||||
|
"PLATFORM_SCHEMA_BASE",
|
||||||
|
"PLATFORM_SCHEMA",
|
||||||
|
"Provider",
|
||||||
|
"TtsAudioType",
|
||||||
|
]
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
TtsAudioType = tuple[str | None, bytes | None]
|
|
||||||
|
|
||||||
ATTR_CACHE = "cache"
|
|
||||||
ATTR_LANGUAGE = "language"
|
|
||||||
ATTR_MESSAGE = "message"
|
|
||||||
ATTR_OPTIONS = "options"
|
|
||||||
ATTR_PLATFORM = "platform"
|
ATTR_PLATFORM = "platform"
|
||||||
ATTR_AUDIO_OUTPUT = "audio_output"
|
ATTR_AUDIO_OUTPUT = "audio_output"
|
||||||
|
|
||||||
|
CONF_LANG = "language"
|
||||||
|
|
||||||
BASE_URL_KEY = "tts_base_url"
|
BASE_URL_KEY = "tts_base_url"
|
||||||
|
|
||||||
CONF_BASE_URL = "base_url"
|
|
||||||
CONF_CACHE = "cache"
|
|
||||||
CONF_CACHE_DIR = "cache_dir"
|
|
||||||
CONF_LANG = "language"
|
|
||||||
CONF_SERVICE_NAME = "service_name"
|
|
||||||
CONF_TIME_MEMORY = "time_memory"
|
|
||||||
|
|
||||||
CONF_FIELDS = "fields"
|
|
||||||
|
|
||||||
DEFAULT_CACHE = True
|
|
||||||
DEFAULT_CACHE_DIR = "tts"
|
|
||||||
DEFAULT_TIME_MEMORY = 300
|
|
||||||
|
|
||||||
SERVICE_CLEAR_CACHE = "clear_cache"
|
SERVICE_CLEAR_CACHE = "clear_cache"
|
||||||
SERVICE_SAY = "say"
|
|
||||||
|
|
||||||
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
|
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
|
||||||
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||||
|
|
||||||
|
|
||||||
def _deprecated_platform(value: str) -> str:
|
|
||||||
"""Validate if platform is deprecated."""
|
|
||||||
if value == "google":
|
|
||||||
raise vol.Invalid(
|
|
||||||
"google tts service has been renamed to google_translate,"
|
|
||||||
" please update your configuration."
|
|
||||||
)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def valid_base_url(value: str) -> str:
|
|
||||||
"""Validate base url, return value."""
|
|
||||||
url = yarl.URL(cv.url(value))
|
|
||||||
|
|
||||||
if url.path != "/":
|
|
||||||
raise vol.Invalid("Path should be empty")
|
|
||||||
|
|
||||||
return normalize_url(value)
|
|
||||||
|
|
||||||
|
|
||||||
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
|
|
||||||
{
|
|
||||||
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
|
|
||||||
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
|
|
||||||
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
|
|
||||||
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
|
|
||||||
vol.Coerce(int), vol.Range(min=60, max=57600)
|
|
||||||
),
|
|
||||||
vol.Optional(CONF_BASE_URL): valid_base_url,
|
|
||||||
vol.Optional(CONF_SERVICE_NAME): cv.string,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)
|
|
||||||
|
|
||||||
SCHEMA_SERVICE_SAY = vol.Schema(
|
|
||||||
{
|
|
||||||
vol.Required(ATTR_MESSAGE): cv.string,
|
|
||||||
vol.Optional(ATTR_CACHE): cv.boolean,
|
|
||||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
|
||||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
|
||||||
vol.Optional(ATTR_OPTIONS): dict,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||||
|
|
||||||
|
|
||||||
@ -192,14 +134,12 @@ async def async_get_media_source_audio(
|
|||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up TTS."""
|
"""Set up TTS."""
|
||||||
tts = SpeechManager(hass)
|
# Legacy config options
|
||||||
|
conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
|
||||||
try:
|
use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
||||||
conf = config[DOMAIN][0] if config.get(DOMAIN, []) else {}
|
cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
||||||
use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
time_memory: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
||||||
cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
base_url: str | None = conf.get(CONF_BASE_URL)
|
||||||
time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
|
||||||
base_url = conf.get(CONF_BASE_URL)
|
|
||||||
if base_url is not None:
|
if base_url is not None:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"TTS base_url option is deprecated. Configure internal/external URL"
|
"TTS base_url option is deprecated. Configure internal/external URL"
|
||||||
@ -207,7 +147,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
)
|
)
|
||||||
hass.data[BASE_URL_KEY] = base_url
|
hass.data[BASE_URL_KEY] = base_url
|
||||||
|
|
||||||
await tts.async_init_cache(use_cache, cache_dir, time_memory, base_url)
|
tts = SpeechManager(hass, use_cache, cache_dir, time_memory, base_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await tts.async_init_cache()
|
||||||
except (HomeAssistantError, KeyError):
|
except (HomeAssistantError, KeyError):
|
||||||
_LOGGER.exception("Error on cache init")
|
_LOGGER.exception("Error on cache init")
|
||||||
return False
|
return False
|
||||||
@ -216,99 +159,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
hass.http.register_view(TextToSpeechView(tts))
|
hass.http.register_view(TextToSpeechView(tts))
|
||||||
hass.http.register_view(TextToSpeechUrlView(tts))
|
hass.http.register_view(TextToSpeechUrlView(tts))
|
||||||
|
|
||||||
# Load service descriptions from tts/services.yaml
|
platform_setups = await async_setup_legacy(hass, config)
|
||||||
services_yaml = Path(__file__).parent / "services.yaml"
|
|
||||||
services_dict = cast(
|
|
||||||
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_setup_platform(
|
if platform_setups:
|
||||||
p_type: str,
|
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||||
p_config: ConfigType | None = None,
|
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Set up a TTS platform."""
|
|
||||||
if p_config is None:
|
|
||||||
p_config = {}
|
|
||||||
|
|
||||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
|
||||||
if platform is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(platform, "async_get_engine"):
|
|
||||||
provider = await platform.async_get_engine(
|
|
||||||
hass, p_config, discovery_info
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
provider = await hass.async_add_executor_job(
|
|
||||||
platform.get_engine, hass, p_config, discovery_info
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider is None:
|
|
||||||
_LOGGER.error("Error setting up platform %s", p_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
tts.async_register_engine(p_type, provider, p_config)
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def async_say_handle(service: ServiceCall) -> None:
|
|
||||||
"""Service handle for say."""
|
|
||||||
entity_ids = service.data[ATTR_ENTITY_ID]
|
|
||||||
|
|
||||||
await hass.services.async_call(
|
|
||||||
DOMAIN_MP,
|
|
||||||
SERVICE_PLAY_MEDIA,
|
|
||||||
{
|
|
||||||
ATTR_ENTITY_ID: entity_ids,
|
|
||||||
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
|
||||||
hass,
|
|
||||||
engine=p_type,
|
|
||||||
message=service.data[ATTR_MESSAGE],
|
|
||||||
language=service.data.get(ATTR_LANGUAGE),
|
|
||||||
options=service.data.get(ATTR_OPTIONS),
|
|
||||||
cache=service.data.get(ATTR_CACHE),
|
|
||||||
),
|
|
||||||
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
|
||||||
ATTR_MEDIA_ANNOUNCE: True,
|
|
||||||
},
|
|
||||||
blocking=True,
|
|
||||||
context=service.context,
|
|
||||||
)
|
|
||||||
|
|
||||||
service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
|
|
||||||
hass.services.async_register(
|
|
||||||
DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register the service description
|
|
||||||
service_desc = {
|
|
||||||
CONF_NAME: f"Say a TTS message with {p_type}",
|
|
||||||
CONF_DESCRIPTION: (
|
|
||||||
f"Say something using text-to-speech on a media player with {p_type}."
|
|
||||||
),
|
|
||||||
CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
|
|
||||||
}
|
|
||||||
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
|
|
||||||
|
|
||||||
setup_tasks = [
|
|
||||||
asyncio.create_task(async_setup_platform(p_type, p_config))
|
|
||||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
|
||||||
if p_type is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
if setup_tasks:
|
|
||||||
await asyncio.wait(setup_tasks)
|
|
||||||
|
|
||||||
async def async_platform_discovered(
|
|
||||||
platform: str, info: dict[str, Any] | None
|
|
||||||
) -> None:
|
|
||||||
"""Handle for discovered platform."""
|
|
||||||
await async_setup_platform(platform, discovery_info=info)
|
|
||||||
|
|
||||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
|
||||||
|
|
||||||
async def async_clear_cache_handle(service: ServiceCall) -> None:
|
async def async_clear_cache_handle(service: ServiceCall) -> None:
|
||||||
"""Handle clear cache service call."""
|
"""Handle clear cache service call."""
|
||||||
@ -337,29 +191,30 @@ def _hash_options(options: dict) -> str:
|
|||||||
class SpeechManager:
|
class SpeechManager:
|
||||||
"""Representation of a speech store."""
|
"""Representation of a speech store."""
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
use_cache: bool,
|
||||||
|
cache_dir: str,
|
||||||
|
time_memory: int,
|
||||||
|
base_url: str | None,
|
||||||
|
) -> None:
|
||||||
"""Initialize a speech store."""
|
"""Initialize a speech store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.providers: dict[str, Provider] = {}
|
self.providers: dict[str, Provider] = {}
|
||||||
|
|
||||||
self.use_cache = DEFAULT_CACHE
|
self.use_cache = use_cache
|
||||||
self.cache_dir = DEFAULT_CACHE_DIR
|
self.cache_dir = cache_dir
|
||||||
self.time_memory = DEFAULT_TIME_MEMORY
|
self.time_memory = time_memory
|
||||||
self.base_url: str | None = None
|
self.base_url = base_url
|
||||||
self.file_cache: dict[str, str] = {}
|
self.file_cache: dict[str, str] = {}
|
||||||
self.mem_cache: dict[str, TTSCache] = {}
|
self.mem_cache: dict[str, TTSCache] = {}
|
||||||
|
|
||||||
async def async_init_cache(
|
async def async_init_cache(self) -> None:
|
||||||
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
|
|
||||||
) -> None:
|
|
||||||
"""Init config folder and load file cache."""
|
"""Init config folder and load file cache."""
|
||||||
self.use_cache = use_cache
|
|
||||||
self.time_memory = time_memory
|
|
||||||
self.base_url = base_url
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.cache_dir = await self.hass.async_add_executor_job(
|
self.cache_dir = await self.hass.async_add_executor_job(
|
||||||
_init_tts_cache_dir, self.hass, cache_dir
|
_init_tts_cache_dir, self.hass, self.cache_dir
|
||||||
)
|
)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise HomeAssistantError(f"Can't init cache dir {err}") from err
|
raise HomeAssistantError(f"Can't init cache dir {err}") from err
|
||||||
@ -390,7 +245,7 @@ class SpeechManager:
|
|||||||
self.file_cache = {}
|
self.file_cache = {}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_engine(
|
def async_register_legacy_engine(
|
||||||
self, engine: str, provider: Provider, config: ConfigType
|
self, engine: str, provider: Provider, config: ConfigType
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a TTS provider."""
|
"""Register a TTS provider."""
|
||||||
@ -739,52 +594,6 @@ class SpeechManager:
|
|||||||
return data_bytes.getvalue()
|
return data_bytes.getvalue()
|
||||||
|
|
||||||
|
|
||||||
class Provider(ABC):
|
|
||||||
"""Represent a single TTS provider."""
|
|
||||||
|
|
||||||
hass: HomeAssistant | None = None
|
|
||||||
name: str | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def default_language(self) -> str | None:
|
|
||||||
"""Return the default language."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return a list of supported languages."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_options(self) -> list[str] | None:
|
|
||||||
"""Return a list of supported options like voice, emotions."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def default_options(self) -> Mapping[str, Any] | None:
|
|
||||||
"""Return a mapping with the default options."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_tts_audio(
|
|
||||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
|
||||||
) -> TtsAudioType:
|
|
||||||
"""Load tts audio file from provider."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def async_get_tts_audio(
|
|
||||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
|
||||||
) -> TtsAudioType:
|
|
||||||
"""Load tts audio file from provider.
|
|
||||||
|
|
||||||
Return a tuple of file extension and data as bytes.
|
|
||||||
"""
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert self.hass
|
|
||||||
return await self.hass.async_add_executor_job(
|
|
||||||
ft.partial(self.get_tts_audio, message, language, options=options)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
|
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
|
||||||
"""Init cache folder."""
|
"""Init cache folder."""
|
||||||
if not os.path.isabs(cache_dir):
|
if not os.path.isabs(cache_dir):
|
||||||
|
@ -1,3 +1,19 @@
|
|||||||
"""Text-to-speech constants."""
|
"""Text-to-speech constants."""
|
||||||
|
ATTR_CACHE = "cache"
|
||||||
|
ATTR_LANGUAGE = "language"
|
||||||
|
ATTR_MESSAGE = "message"
|
||||||
|
ATTR_OPTIONS = "options"
|
||||||
|
|
||||||
|
CONF_BASE_URL = "base_url"
|
||||||
|
CONF_CACHE = "cache"
|
||||||
|
CONF_CACHE_DIR = "cache_dir"
|
||||||
|
CONF_FIELDS = "fields"
|
||||||
|
CONF_TIME_MEMORY = "time_memory"
|
||||||
|
|
||||||
|
DEFAULT_CACHE = True
|
||||||
|
DEFAULT_CACHE_DIR = "tts"
|
||||||
|
DEFAULT_TIME_MEMORY = 300
|
||||||
|
|
||||||
DOMAIN = "tts"
|
DOMAIN = "tts"
|
||||||
|
|
||||||
|
TtsAudioType = tuple[str | None, bytes | None]
|
||||||
|
252
homeassistant/components/tts/legacy.py
Normal file
252
homeassistant/components/tts/legacy.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
"""Provide the legacy TTS service provider interface."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Coroutine, Mapping
|
||||||
|
from functools import partial
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
import yarl
|
||||||
|
|
||||||
|
from homeassistant.components.media_player import (
|
||||||
|
ATTR_MEDIA_ANNOUNCE,
|
||||||
|
ATTR_MEDIA_CONTENT_ID,
|
||||||
|
ATTR_MEDIA_CONTENT_TYPE,
|
||||||
|
DOMAIN as DOMAIN_MP,
|
||||||
|
SERVICE_PLAY_MEDIA,
|
||||||
|
MediaType,
|
||||||
|
)
|
||||||
|
from homeassistant.const import (
|
||||||
|
ATTR_ENTITY_ID,
|
||||||
|
CONF_DESCRIPTION,
|
||||||
|
CONF_NAME,
|
||||||
|
CONF_PLATFORM,
|
||||||
|
)
|
||||||
|
from homeassistant.core import HomeAssistant, ServiceCall
|
||||||
|
from homeassistant.helpers import config_per_platform, discovery
|
||||||
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.helpers.service import async_set_service_schema
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
from homeassistant.setup import async_prepare_setup_platform
|
||||||
|
from homeassistant.util.network import normalize_url
|
||||||
|
from homeassistant.util.yaml import load_yaml
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
ATTR_CACHE,
|
||||||
|
ATTR_LANGUAGE,
|
||||||
|
ATTR_MESSAGE,
|
||||||
|
ATTR_OPTIONS,
|
||||||
|
CONF_BASE_URL,
|
||||||
|
CONF_CACHE,
|
||||||
|
CONF_CACHE_DIR,
|
||||||
|
CONF_FIELDS,
|
||||||
|
CONF_TIME_MEMORY,
|
||||||
|
DEFAULT_CACHE,
|
||||||
|
DEFAULT_CACHE_DIR,
|
||||||
|
DEFAULT_TIME_MEMORY,
|
||||||
|
DOMAIN,
|
||||||
|
TtsAudioType,
|
||||||
|
)
|
||||||
|
from .media_source import generate_media_source_id
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import SpeechManager
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CONF_SERVICE_NAME = "service_name"
|
||||||
|
|
||||||
|
|
||||||
|
def _deprecated_platform(value: str) -> str:
|
||||||
|
"""Validate if platform is deprecated."""
|
||||||
|
if value == "google":
|
||||||
|
raise vol.Invalid(
|
||||||
|
"google tts service has been renamed to google_translate,"
|
||||||
|
" please update your configuration."
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_base_url(value: str) -> str:
|
||||||
|
"""Validate base url, return value."""
|
||||||
|
url = yarl.URL(cv.url(value))
|
||||||
|
|
||||||
|
if url.path != "/":
|
||||||
|
raise vol.Invalid("Path should be empty")
|
||||||
|
|
||||||
|
return normalize_url(value)
|
||||||
|
|
||||||
|
|
||||||
|
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
|
||||||
|
{
|
||||||
|
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
|
||||||
|
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
|
||||||
|
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
|
||||||
|
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
|
||||||
|
vol.Coerce(int), vol.Range(min=60, max=57600)
|
||||||
|
),
|
||||||
|
vol.Optional(CONF_BASE_URL): _valid_base_url,
|
||||||
|
vol.Optional(CONF_SERVICE_NAME): cv.string,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)
|
||||||
|
|
||||||
|
SERVICE_SAY = "say"
|
||||||
|
|
||||||
|
SCHEMA_SERVICE_SAY = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required(ATTR_MESSAGE): cv.string,
|
||||||
|
vol.Optional(ATTR_CACHE): cv.boolean,
|
||||||
|
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||||
|
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||||
|
vol.Optional(ATTR_OPTIONS): dict,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_legacy(
|
||||||
|
hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> list[Coroutine[Any, Any, None]]:
|
||||||
|
"""Set up legacy text to speech providers."""
|
||||||
|
tts: SpeechManager = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
# Load service descriptions from tts/services.yaml
|
||||||
|
services_yaml = Path(__file__).parent / "services.yaml"
|
||||||
|
services_dict = cast(
|
||||||
|
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_setup_platform(
|
||||||
|
p_type: str,
|
||||||
|
p_config: ConfigType | None = None,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set up a TTS platform."""
|
||||||
|
if p_config is None:
|
||||||
|
p_config = {}
|
||||||
|
|
||||||
|
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||||
|
if platform is None:
|
||||||
|
_LOGGER.error("Unknown text to speech platform specified")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(platform, "async_get_engine"):
|
||||||
|
provider = await platform.async_get_engine(
|
||||||
|
hass, p_config, discovery_info
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = await hass.async_add_executor_job(
|
||||||
|
platform.get_engine, hass, p_config, discovery_info
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider is None:
|
||||||
|
_LOGGER.error("Error setting up platform: %s", p_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
tts.async_register_legacy_engine(p_type, provider, p_config)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def async_say_handle(service: ServiceCall) -> None:
|
||||||
|
"""Service handle for say."""
|
||||||
|
entity_ids = service.data[ATTR_ENTITY_ID]
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN_MP,
|
||||||
|
SERVICE_PLAY_MEDIA,
|
||||||
|
{
|
||||||
|
ATTR_ENTITY_ID: entity_ids,
|
||||||
|
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
||||||
|
hass,
|
||||||
|
engine=p_type,
|
||||||
|
message=service.data[ATTR_MESSAGE],
|
||||||
|
language=service.data.get(ATTR_LANGUAGE),
|
||||||
|
options=service.data.get(ATTR_OPTIONS),
|
||||||
|
cache=service.data.get(ATTR_CACHE),
|
||||||
|
),
|
||||||
|
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
||||||
|
ATTR_MEDIA_ANNOUNCE: True,
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
context=service.context,
|
||||||
|
)
|
||||||
|
|
||||||
|
service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
|
||||||
|
hass.services.async_register(
|
||||||
|
DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register the service description
|
||||||
|
service_desc = {
|
||||||
|
CONF_NAME: f"Say a TTS message with {p_type}",
|
||||||
|
CONF_DESCRIPTION: (
|
||||||
|
f"Say something using text-to-speech on a media player with {p_type}."
|
||||||
|
),
|
||||||
|
CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
|
||||||
|
}
|
||||||
|
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
|
||||||
|
|
||||||
|
async def async_platform_discovered(
|
||||||
|
platform: str, info: dict[str, Any] | None
|
||||||
|
) -> None:
|
||||||
|
"""Handle for discovered platform."""
|
||||||
|
await async_setup_platform(platform, discovery_info=info)
|
||||||
|
|
||||||
|
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||||
|
|
||||||
|
return [
|
||||||
|
async_setup_platform(p_type, p_config)
|
||||||
|
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||||
|
if p_type is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Provider:
|
||||||
|
"""Represent a single TTS provider."""
|
||||||
|
|
||||||
|
hass: HomeAssistant | None = None
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_language(self) -> str | None:
|
||||||
|
"""Return the default language."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_options(self) -> list[str] | None:
|
||||||
|
"""Return a list of supported options like voice, emotions."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_options(self) -> Mapping[str, Any] | None:
|
||||||
|
"""Return a mapping with the default options."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load tts audio file from provider."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def async_get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load tts audio file from provider.
|
||||||
|
|
||||||
|
Return a tuple of file extension and data as bytes.
|
||||||
|
"""
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self.hass
|
||||||
|
return await self.hass.async_add_executor_job(
|
||||||
|
partial(self.get_tts_audio, message, language, options=options)
|
||||||
|
)
|
77
tests/components/tts/common.py
Normal file
77
tests/components/tts/common.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
"""Provide common tests tools for tts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.tts import (
|
||||||
|
CONF_LANG,
|
||||||
|
PLATFORM_SCHEMA,
|
||||||
|
Provider,
|
||||||
|
TtsAudioType,
|
||||||
|
)
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
|
from tests.common import MockPlatform
|
||||||
|
|
||||||
|
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
|
||||||
|
|
||||||
|
DEFAULT_LANG = "en"
|
||||||
|
|
||||||
|
|
||||||
|
class MockProvider(Provider):
|
||||||
|
"""Test speech API provider."""
|
||||||
|
|
||||||
|
def __init__(self, lang: str) -> None:
|
||||||
|
"""Initialize test provider."""
|
||||||
|
self._lang = lang
|
||||||
|
self.name = "Test"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_language(self) -> str:
|
||||||
|
"""Return the default language."""
|
||||||
|
return self._lang
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return list of supported languages."""
|
||||||
|
return SUPPORT_LANGUAGES
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_options(self) -> list[str]:
|
||||||
|
"""Return list of supported options like voice, emotions."""
|
||||||
|
return ["voice", "age"]
|
||||||
|
|
||||||
|
def get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load TTS dat."""
|
||||||
|
return ("mp3", b"")
|
||||||
|
|
||||||
|
|
||||||
|
class MockTTS(MockPlatform):
|
||||||
|
"""A mock TTS platform."""
|
||||||
|
|
||||||
|
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||||
|
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, provider: type[MockProvider] | None = None, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if provider is None:
|
||||||
|
provider = MockProvider
|
||||||
|
self._provider = provider
|
||||||
|
|
||||||
|
async def async_get_engine(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> Provider | None:
|
||||||
|
"""Set up a mock speech component."""
|
||||||
|
return self._provider(config.get(CONF_LANG, DEFAULT_LANG))
|
@ -7,6 +7,12 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.tts import _get_cache_files
|
from homeassistant.components.tts import _get_cache_files
|
||||||
|
from homeassistant.config import async_process_ha_core_config
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .common import MockTTS
|
||||||
|
|
||||||
|
from tests.common import MockModule, mock_integration, mock_platform
|
||||||
|
|
||||||
|
|
||||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||||
@ -71,3 +77,19 @@ def mutagen_mock():
|
|||||||
side_effect=lambda *args: args[1],
|
side_effect=lambda *args: args[1],
|
||||||
) as mock_write_tags:
|
) as mock_write_tags:
|
||||||
yield mock_write_tags
|
yield mock_write_tags
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def internal_url_mock(hass: HomeAssistant) -> None:
|
||||||
|
"""Mock internal URL of the instance."""
|
||||||
|
await async_process_ha_core_config(
|
||||||
|
hass,
|
||||||
|
{"internal_url": "http://example.local:8123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_tts(hass: HomeAssistant) -> None:
|
||||||
|
"""Mock TTS."""
|
||||||
|
mock_integration(hass, MockModule(domain="test"))
|
||||||
|
mock_platform(hass, "test.tts", MockTTS())
|
||||||
|
@ -17,13 +17,14 @@ from homeassistant.components.media_player import (
|
|||||||
MediaType,
|
MediaType,
|
||||||
)
|
)
|
||||||
from homeassistant.components.media_source import Unresolvable
|
from homeassistant.components.media_source import Unresolvable
|
||||||
from homeassistant.config import async_process_ha_core_config
|
from homeassistant.components.tts.legacy import _valid_base_url
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util.network import normalize_url
|
from homeassistant.util.network import normalize_url
|
||||||
|
|
||||||
|
from .common import MockProvider, MockTTS
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockModule,
|
MockModule,
|
||||||
assert_setup_component,
|
assert_setup_component,
|
||||||
@ -36,7 +37,7 @@ from tests.typing import ClientSessionGenerator
|
|||||||
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
|
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
|
||||||
|
|
||||||
|
|
||||||
async def get_media_source_url(hass, media_content_id):
|
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
|
||||||
"""Get the media source url."""
|
"""Get the media source url."""
|
||||||
if media_source.DOMAIN not in hass.config.components:
|
if media_source.DOMAIN not in hass.config.components:
|
||||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||||
@ -45,88 +46,14 @@ async def get_media_source_url(hass, media_content_id):
|
|||||||
return resolved.url
|
return resolved.url
|
||||||
|
|
||||||
|
|
||||||
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
|
|
||||||
|
|
||||||
DEFAULT_LANG = "en"
|
|
||||||
|
|
||||||
|
|
||||||
class MockProvider(tts.Provider):
|
|
||||||
"""Test speech API provider."""
|
|
||||||
|
|
||||||
def __init__(self, lang: str) -> None:
|
|
||||||
"""Initialize test provider."""
|
|
||||||
self._lang = lang
|
|
||||||
self.name = "Test"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def default_language(self) -> str:
|
|
||||||
"""Return the default language."""
|
|
||||||
return self._lang
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return list of supported languages."""
|
|
||||||
return SUPPORT_LANGUAGES
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_options(self) -> list[str]:
|
|
||||||
"""Return list of supported options like voice, emotions."""
|
|
||||||
return ["voice", "age"]
|
|
||||||
|
|
||||||
def get_tts_audio(
|
|
||||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
|
||||||
) -> tts.TtsAudioType:
|
|
||||||
"""Load TTS dat."""
|
|
||||||
return ("mp3", b"")
|
|
||||||
|
|
||||||
|
|
||||||
class MockTTS:
|
|
||||||
"""A mock TTS platform."""
|
|
||||||
|
|
||||||
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA.extend(
|
|
||||||
{vol.Optional(tts.CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, provider=None) -> None:
|
|
||||||
"""Initialize."""
|
|
||||||
if provider is None:
|
|
||||||
provider = MockProvider
|
|
||||||
self._provider = provider
|
|
||||||
|
|
||||||
async def async_get_engine(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
config: ConfigType,
|
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
|
||||||
) -> tts.Provider:
|
|
||||||
"""Set up a mock speech component."""
|
|
||||||
return self._provider(config.get(tts.CONF_LANG, DEFAULT_LANG))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_provider():
|
def mock_provider() -> MockProvider:
|
||||||
"""Test TTS provider."""
|
"""Test TTS provider."""
|
||||||
return MockProvider("en")
|
return MockProvider("en")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
async def internal_url_mock(hass):
|
|
||||||
"""Mock internal URL of the instance."""
|
|
||||||
await async_process_ha_core_config(
|
|
||||||
hass,
|
|
||||||
{"internal_url": "http://example.local:8123"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_tts(hass):
|
async def setup_tts(hass: HomeAssistant, mock_tts: None) -> None:
|
||||||
"""Mock TTS."""
|
|
||||||
mock_integration(hass, MockModule(domain="test"))
|
|
||||||
mock_platform(hass, "test.tts", MockTTS())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def setup_tts(hass, mock_tts):
|
|
||||||
"""Mock TTS."""
|
"""Mock TTS."""
|
||||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||||
|
|
||||||
@ -366,6 +293,8 @@ async def test_setup_component_and_test_with_service_options_def(
|
|||||||
config = {tts.DOMAIN: {"platform": "test"}}
|
config = {tts.DOMAIN: {"platform": "test"}}
|
||||||
|
|
||||||
class MockProviderWithDefaults(MockProvider):
|
class MockProviderWithDefaults(MockProvider):
|
||||||
|
"""Mock provider with default options."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_options(self):
|
def default_options(self):
|
||||||
return {"voice": "alex"}
|
return {"voice": "alex"}
|
||||||
@ -413,6 +342,8 @@ async def test_setup_component_and_test_with_service_options_def_2(
|
|||||||
config = {tts.DOMAIN: {"platform": "test"}}
|
config = {tts.DOMAIN: {"platform": "test"}}
|
||||||
|
|
||||||
class MockProviderWithDefaults(MockProvider):
|
class MockProviderWithDefaults(MockProvider):
|
||||||
|
"""Mock provider with default options."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_options(self):
|
def default_options(self):
|
||||||
return {"voice": "alex"}
|
return {"voice": "alex"}
|
||||||
@ -550,7 +481,10 @@ async def test_setup_component_and_test_service_clear_cache(
|
|||||||
|
|
||||||
|
|
||||||
async def test_setup_component_and_test_service_with_receive_voice(
|
async def test_setup_component_and_test_service_with_receive_voice(
|
||||||
hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts
|
hass: HomeAssistant,
|
||||||
|
mock_provider: MockProvider,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
mock_tts,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up a TTS platform and call service and receive voice."""
|
"""Set up a TTS platform and call service and receive voice."""
|
||||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||||
@ -576,11 +510,12 @@ async def test_setup_component_and_test_service_with_receive_voice(
|
|||||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
req = await client.get(url)
|
req = await client.get(url)
|
||||||
_, tts_data = test_provider.get_tts_audio("bla", "en")
|
_, tts_data = mock_provider.get_tts_audio("bla", "en")
|
||||||
|
assert tts_data is not None
|
||||||
tts_data = tts.SpeechManager.write_tags(
|
tts_data = tts.SpeechManager.write_tags(
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
|
||||||
tts_data,
|
tts_data,
|
||||||
test_provider,
|
mock_provider,
|
||||||
message,
|
message,
|
||||||
"en",
|
"en",
|
||||||
None,
|
None,
|
||||||
@ -596,7 +531,10 @@ async def test_setup_component_and_test_service_with_receive_voice(
|
|||||||
|
|
||||||
|
|
||||||
async def test_setup_component_and_test_service_with_receive_voice_german(
|
async def test_setup_component_and_test_service_with_receive_voice_german(
|
||||||
hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts
|
hass: HomeAssistant,
|
||||||
|
mock_provider: MockProvider,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
mock_tts,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up a TTS platform and call service and receive voice."""
|
"""Set up a TTS platform and call service and receive voice."""
|
||||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||||
@ -619,11 +557,12 @@ async def test_setup_component_and_test_service_with_receive_voice_german(
|
|||||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
req = await client.get(url)
|
req = await client.get(url)
|
||||||
_, tts_data = test_provider.get_tts_audio("bla", "de")
|
_, tts_data = mock_provider.get_tts_audio("bla", "de")
|
||||||
|
assert tts_data is not None
|
||||||
tts_data = tts.SpeechManager.write_tags(
|
tts_data = tts.SpeechManager.write_tags(
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3",
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3",
|
||||||
tts_data,
|
tts_data,
|
||||||
test_provider,
|
mock_provider,
|
||||||
"There is someone at the door.",
|
"There is someone at the door.",
|
||||||
"de",
|
"de",
|
||||||
None,
|
None,
|
||||||
@ -722,12 +661,13 @@ async def test_setup_component_test_with_cache_call_service_without_cache(
|
|||||||
|
|
||||||
|
|
||||||
async def test_setup_component_test_with_cache_dir(
|
async def test_setup_component_test_with_cache_dir(
|
||||||
hass: HomeAssistant, empty_cache_dir, test_provider
|
hass: HomeAssistant, empty_cache_dir, mock_provider: MockProvider
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up a TTS platform with cache and call service without cache."""
|
"""Set up a TTS platform with cache and call service without cache."""
|
||||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||||
|
|
||||||
_, tts_data = test_provider.get_tts_audio("bla", "en")
|
_, tts_data = mock_provider.get_tts_audio("bla", "en")
|
||||||
|
assert tts_data is not None
|
||||||
cache_file = (
|
cache_file = (
|
||||||
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
||||||
)
|
)
|
||||||
@ -738,12 +678,14 @@ async def test_setup_component_test_with_cache_dir(
|
|||||||
config = {tts.DOMAIN: {"platform": "test", "cache": True}}
|
config = {tts.DOMAIN: {"platform": "test", "cache": True}}
|
||||||
|
|
||||||
class MockProviderBoom(MockProvider):
|
class MockProviderBoom(MockProvider):
|
||||||
|
"""Mock provider that blows up."""
|
||||||
|
|
||||||
def get_tts_audio(
|
def get_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
) -> tts.TtsAudioType:
|
) -> tts.TtsAudioType:
|
||||||
"""Load TTS dat."""
|
"""Load TTS dat."""
|
||||||
# This should not be called, data should be fetched from cache
|
# This should not be called, data should be fetched from cache
|
||||||
raise Exception("Boom!")
|
raise Exception("Boom!") # pylint: disable=broad-exception-raised
|
||||||
|
|
||||||
mock_integration(hass, MockModule(domain="test"))
|
mock_integration(hass, MockModule(domain="test"))
|
||||||
mock_platform(hass, "test.tts", MockTTS(MockProviderBoom))
|
mock_platform(hass, "test.tts", MockTTS(MockProviderBoom))
|
||||||
@ -775,6 +717,8 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) -
|
|||||||
config = {tts.DOMAIN: {"platform": "test"}}
|
config = {tts.DOMAIN: {"platform": "test"}}
|
||||||
|
|
||||||
class MockProviderEmpty(MockProvider):
|
class MockProviderEmpty(MockProvider):
|
||||||
|
"""Mock provider with empty get_tts_audio."""
|
||||||
|
|
||||||
def get_tts_audio(
|
def get_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
) -> tts.TtsAudioType:
|
) -> tts.TtsAudioType:
|
||||||
@ -803,13 +747,14 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) -
|
|||||||
|
|
||||||
async def test_setup_component_load_cache_retrieve_without_mem_cache(
|
async def test_setup_component_load_cache_retrieve_without_mem_cache(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
test_provider,
|
mock_provider: MockProvider,
|
||||||
empty_cache_dir,
|
empty_cache_dir,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
mock_tts,
|
mock_tts,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up component and load cache and get without mem cache."""
|
"""Set up component and load cache and get without mem cache."""
|
||||||
_, tts_data = test_provider.get_tts_audio("bla", "en")
|
_, tts_data = mock_provider.get_tts_audio("bla", "en")
|
||||||
|
assert tts_data is not None
|
||||||
cache_file = (
|
cache_file = (
|
||||||
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
||||||
)
|
)
|
||||||
@ -870,7 +815,7 @@ async def test_setup_component_and_web_get_url_bad_config(
|
|||||||
assert req.status == HTTPStatus.BAD_REQUEST
|
assert req.status == HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
|
||||||
async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
|
async def test_tags_with_wave(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
||||||
"""Set up a TTS platform and call service and receive voice."""
|
"""Set up a TTS platform and call service and receive voice."""
|
||||||
|
|
||||||
# below data represents an empty wav file
|
# below data represents an empty wav file
|
||||||
@ -882,7 +827,7 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
|
|||||||
tagged_data = ORIG_WRITE_TAGS(
|
tagged_data = ORIG_WRITE_TAGS(
|
||||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
|
||||||
tts_data,
|
tts_data,
|
||||||
test_provider,
|
mock_provider,
|
||||||
"AI person is in front of your door.",
|
"AI person is in front of your door.",
|
||||||
"en",
|
"en",
|
||||||
None,
|
None,
|
||||||
@ -904,9 +849,9 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
|
|||||||
)
|
)
|
||||||
def test_valid_base_url(value) -> None:
|
def test_valid_base_url(value) -> None:
|
||||||
"""Test we validate base urls."""
|
"""Test we validate base urls."""
|
||||||
assert tts.valid_base_url(value) == normalize_url(value)
|
assert _valid_base_url(value) == normalize_url(value)
|
||||||
# Test we strip trailing `/`
|
# Test we strip trailing `/`
|
||||||
assert tts.valid_base_url(value + "/") == normalize_url(value)
|
assert _valid_base_url(value + "/") == normalize_url(value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -926,7 +871,7 @@ def test_valid_base_url(value) -> None:
|
|||||||
def test_invalid_base_url(value) -> None:
|
def test_invalid_base_url(value) -> None:
|
||||||
"""Test we catch bad base urls."""
|
"""Test we catch bad base urls."""
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
tts.valid_base_url(value)
|
_valid_base_url(value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -1000,9 +945,11 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
|
async def test_fetching_in_async(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
) -> None:
|
||||||
"""Test async fetching of data."""
|
"""Test async fetching of data."""
|
||||||
tts_audio = asyncio.Future()
|
tts_audio: asyncio.Future[bytes] = asyncio.Future()
|
||||||
|
|
||||||
class ProviderWithAsyncFetching(MockProvider):
|
class ProviderWithAsyncFetching(MockProvider):
|
||||||
"""Provider that supports audio output option."""
|
"""Provider that supports audio output option."""
|
||||||
@ -1067,4 +1014,7 @@ async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
|
|||||||
|
|
||||||
tts_audio = asyncio.Future()
|
tts_audio = asyncio.Future()
|
||||||
tts_audio.set_result(b"test 2")
|
tts_audio.set_result(b"test 2")
|
||||||
await tts.async_get_media_source_audio(hass, media_source_id) == ("mp3", b"test 2")
|
assert await tts.async_get_media_source_audio(hass, media_source_id) == (
|
||||||
|
"mp3",
|
||||||
|
b"test 2",
|
||||||
|
)
|
||||||
|
121
tests/components/tts/test_legacy.py
Normal file
121
tests/components/tts/test_legacy.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
"""Test the legacy tts setup."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.tts import DOMAIN, Provider
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.discovery import async_load_platform
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from .common import MockTTS
|
||||||
|
|
||||||
|
from tests.common import (
|
||||||
|
MockModule,
|
||||||
|
assert_setup_component,
|
||||||
|
mock_integration,
|
||||||
|
mock_platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_provider_attributes() -> None:
|
||||||
|
"""Test default provider properties."""
|
||||||
|
provider = Provider()
|
||||||
|
|
||||||
|
assert provider.hass is None
|
||||||
|
assert provider.name is None
|
||||||
|
assert provider.default_language is None
|
||||||
|
assert provider.supported_languages is None
|
||||||
|
assert provider.supported_options is None
|
||||||
|
assert provider.default_options is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_deprecated_platform(hass: HomeAssistant) -> None:
|
||||||
|
"""Test deprecated google platform."""
|
||||||
|
with assert_setup_component(0, DOMAIN):
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass, DOMAIN, {DOMAIN: {"platform": "google"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_platform(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test platform setup with an invalid platform."""
|
||||||
|
await async_load_platform(
|
||||||
|
hass,
|
||||||
|
"tts",
|
||||||
|
"bad_tts",
|
||||||
|
{"tts": [{"platform": "bad_tts"}]},
|
||||||
|
hass_config={"tts": [{"platform": "bad_tts"}]},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert "Unknown text to speech platform specified" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_platform_setup_without_provider(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test platform setup without provider returned."""
|
||||||
|
|
||||||
|
class BadPlatform(MockTTS):
|
||||||
|
"""A mock TTS platform without provider."""
|
||||||
|
|
||||||
|
async def async_get_engine(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> Provider | None:
|
||||||
|
"""Raise exception during platform setup."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule(domain="bad_tts"))
|
||||||
|
mock_platform(hass, "bad_tts.tts", BadPlatform())
|
||||||
|
|
||||||
|
await async_load_platform(
|
||||||
|
hass,
|
||||||
|
"tts",
|
||||||
|
"bad_tts",
|
||||||
|
{},
|
||||||
|
hass_config={"tts": [{"platform": "bad_tts"}]},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert "Error setting up platform: bad_tts" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_platform_setup_with_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test platform setup with an error during setup."""
|
||||||
|
|
||||||
|
class BadPlatform(MockTTS):
|
||||||
|
"""A mock TTS platform with a setup error."""
|
||||||
|
|
||||||
|
async def async_get_engine(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> Provider:
|
||||||
|
"""Raise exception during platform setup."""
|
||||||
|
raise Exception("Setup error") # pylint: disable=broad-exception-raised
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule(domain="bad_tts"))
|
||||||
|
mock_platform(hass, "bad_tts.tts", BadPlatform())
|
||||||
|
|
||||||
|
await async_load_platform(
|
||||||
|
hass,
|
||||||
|
"tts",
|
||||||
|
"bad_tts",
|
||||||
|
{},
|
||||||
|
hass_config={"tts": [{"platform": "bad_tts"}]},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert "Error setting up platform: bad_tts" in caplog.text
|
Loading…
x
Reference in New Issue
Block a user