mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Move legacy tts (#91538)
* Move legacy tts * Add error log on unknown platform * Add legacy tests and delint all tests * Consolidate log format * Add more legacy tests * Test default legacy provider attributes * Remove test generated files * Clean up after merge conflict
This commit is contained in:
parent
0ecd23baee
commit
9bd12f6503
@ -1,135 +1,77 @@
|
||||
"""Provide functionality for TTS."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
import functools as ft
|
||||
import hashlib
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
from typing import TypedDict
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
from mutagen.id3 import ID3, TextFrame as ID3Text
|
||||
import voluptuous as vol
|
||||
import yarl
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_DESCRIPTION,
|
||||
CONF_NAME,
|
||||
CONF_PLATFORM,
|
||||
PLATFORM_FORMAT,
|
||||
)
|
||||
from homeassistant.const import PLATFORM_FORMAT
|
||||
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.helpers.service import async_set_service_schema
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant.util.network import normalize_url
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import (
|
||||
ATTR_CACHE,
|
||||
ATTR_LANGUAGE,
|
||||
ATTR_MESSAGE,
|
||||
ATTR_OPTIONS,
|
||||
CONF_BASE_URL,
|
||||
CONF_CACHE,
|
||||
CONF_CACHE_DIR,
|
||||
CONF_TIME_MEMORY,
|
||||
DEFAULT_CACHE,
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_TIME_MEMORY,
|
||||
DOMAIN,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||
|
||||
__all__ = [
|
||||
"async_get_media_source_audio",
|
||||
"async_resolve_engine",
|
||||
"async_support_options",
|
||||
"ATTR_AUDIO_OUTPUT",
|
||||
"CONF_LANG",
|
||||
"DEFAULT_CACHE_DIR",
|
||||
"generate_media_source_id",
|
||||
"get_base_url",
|
||||
"PLATFORM_SCHEMA_BASE",
|
||||
"PLATFORM_SCHEMA",
|
||||
"Provider",
|
||||
"TtsAudioType",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
TtsAudioType = tuple[str | None, bytes | None]
|
||||
|
||||
ATTR_CACHE = "cache"
|
||||
ATTR_LANGUAGE = "language"
|
||||
ATTR_MESSAGE = "message"
|
||||
ATTR_OPTIONS = "options"
|
||||
ATTR_PLATFORM = "platform"
|
||||
ATTR_AUDIO_OUTPUT = "audio_output"
|
||||
|
||||
CONF_LANG = "language"
|
||||
|
||||
BASE_URL_KEY = "tts_base_url"
|
||||
|
||||
CONF_BASE_URL = "base_url"
|
||||
CONF_CACHE = "cache"
|
||||
CONF_CACHE_DIR = "cache_dir"
|
||||
CONF_LANG = "language"
|
||||
CONF_SERVICE_NAME = "service_name"
|
||||
CONF_TIME_MEMORY = "time_memory"
|
||||
|
||||
CONF_FIELDS = "fields"
|
||||
|
||||
DEFAULT_CACHE = True
|
||||
DEFAULT_CACHE_DIR = "tts"
|
||||
DEFAULT_TIME_MEMORY = 300
|
||||
|
||||
SERVICE_CLEAR_CACHE = "clear_cache"
|
||||
SERVICE_SAY = "say"
|
||||
|
||||
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
|
||||
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||
|
||||
|
||||
def _deprecated_platform(value: str) -> str:
|
||||
"""Validate if platform is deprecated."""
|
||||
if value == "google":
|
||||
raise vol.Invalid(
|
||||
"google tts service has been renamed to google_translate,"
|
||||
" please update your configuration."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def valid_base_url(value: str) -> str:
|
||||
"""Validate base url, return value."""
|
||||
url = yarl.URL(cv.url(value))
|
||||
|
||||
if url.path != "/":
|
||||
raise vol.Invalid("Path should be empty")
|
||||
|
||||
return normalize_url(value)
|
||||
|
||||
|
||||
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
|
||||
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
|
||||
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
|
||||
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
|
||||
vol.Coerce(int), vol.Range(min=60, max=57600)
|
||||
),
|
||||
vol.Optional(CONF_BASE_URL): valid_base_url,
|
||||
vol.Optional(CONF_SERVICE_NAME): cv.string,
|
||||
}
|
||||
)
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)
|
||||
|
||||
SCHEMA_SERVICE_SAY = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_MESSAGE): cv.string,
|
||||
vol.Optional(ATTR_CACHE): cv.boolean,
|
||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||
vol.Optional(ATTR_OPTIONS): dict,
|
||||
}
|
||||
)
|
||||
|
||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
|
||||
|
||||
@ -192,22 +134,23 @@ async def async_get_media_source_audio(
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up TTS."""
|
||||
tts = SpeechManager(hass)
|
||||
# Legacy config options
|
||||
conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
|
||||
use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
||||
cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
||||
time_memory: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
||||
base_url: str | None = conf.get(CONF_BASE_URL)
|
||||
if base_url is not None:
|
||||
_LOGGER.warning(
|
||||
"TTS base_url option is deprecated. Configure internal/external URL"
|
||||
" instead"
|
||||
)
|
||||
hass.data[BASE_URL_KEY] = base_url
|
||||
|
||||
tts = SpeechManager(hass, use_cache, cache_dir, time_memory, base_url)
|
||||
|
||||
try:
|
||||
conf = config[DOMAIN][0] if config.get(DOMAIN, []) else {}
|
||||
use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
||||
cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
||||
time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
||||
base_url = conf.get(CONF_BASE_URL)
|
||||
if base_url is not None:
|
||||
_LOGGER.warning(
|
||||
"TTS base_url option is deprecated. Configure internal/external URL"
|
||||
" instead"
|
||||
)
|
||||
hass.data[BASE_URL_KEY] = base_url
|
||||
|
||||
await tts.async_init_cache(use_cache, cache_dir, time_memory, base_url)
|
||||
await tts.async_init_cache()
|
||||
except (HomeAssistantError, KeyError):
|
||||
_LOGGER.exception("Error on cache init")
|
||||
return False
|
||||
@ -216,99 +159,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
hass.http.register_view(TextToSpeechView(tts))
|
||||
hass.http.register_view(TextToSpeechUrlView(tts))
|
||||
|
||||
# Load service descriptions from tts/services.yaml
|
||||
services_yaml = Path(__file__).parent / "services.yaml"
|
||||
services_dict = cast(
|
||||
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
|
||||
)
|
||||
platform_setups = await async_setup_legacy(hass, config)
|
||||
|
||||
async def async_setup_platform(
|
||||
p_type: str,
|
||||
p_config: ConfigType | None = None,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up a TTS platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
|
||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(platform, "async_get_engine"):
|
||||
provider = await platform.async_get_engine(
|
||||
hass, p_config, discovery_info
|
||||
)
|
||||
else:
|
||||
provider = await hass.async_add_executor_job(
|
||||
platform.get_engine, hass, p_config, discovery_info
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
_LOGGER.error("Error setting up platform %s", p_type)
|
||||
return
|
||||
|
||||
tts.async_register_engine(p_type, provider, p_config)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
async def async_say_handle(service: ServiceCall) -> None:
|
||||
"""Service handle for say."""
|
||||
entity_ids = service.data[ATTR_ENTITY_ID]
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
{
|
||||
ATTR_ENTITY_ID: entity_ids,
|
||||
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
||||
hass,
|
||||
engine=p_type,
|
||||
message=service.data[ATTR_MESSAGE],
|
||||
language=service.data.get(ATTR_LANGUAGE),
|
||||
options=service.data.get(ATTR_OPTIONS),
|
||||
cache=service.data.get(ATTR_CACHE),
|
||||
),
|
||||
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
||||
ATTR_MEDIA_ANNOUNCE: True,
|
||||
},
|
||||
blocking=True,
|
||||
context=service.context,
|
||||
)
|
||||
|
||||
service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
|
||||
hass.services.async_register(
|
||||
DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
|
||||
)
|
||||
|
||||
# Register the service description
|
||||
service_desc = {
|
||||
CONF_NAME: f"Say a TTS message with {p_type}",
|
||||
CONF_DESCRIPTION: (
|
||||
f"Say something using text-to-speech on a media player with {p_type}."
|
||||
),
|
||||
CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
|
||||
}
|
||||
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
|
||||
|
||||
setup_tasks = [
|
||||
asyncio.create_task(async_setup_platform(p_type, p_config))
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||
if p_type is not None
|
||||
]
|
||||
|
||||
if setup_tasks:
|
||||
await asyncio.wait(setup_tasks)
|
||||
|
||||
async def async_platform_discovered(
|
||||
platform: str, info: dict[str, Any] | None
|
||||
) -> None:
|
||||
"""Handle for discovered platform."""
|
||||
await async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||
if platform_setups:
|
||||
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||
|
||||
async def async_clear_cache_handle(service: ServiceCall) -> None:
|
||||
"""Handle clear cache service call."""
|
||||
@ -337,29 +191,30 @@ def _hash_options(options: dict) -> str:
|
||||
class SpeechManager:
|
||||
"""Representation of a speech store."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
use_cache: bool,
|
||||
cache_dir: str,
|
||||
time_memory: int,
|
||||
base_url: str | None,
|
||||
) -> None:
|
||||
"""Initialize a speech store."""
|
||||
self.hass = hass
|
||||
self.providers: dict[str, Provider] = {}
|
||||
|
||||
self.use_cache = DEFAULT_CACHE
|
||||
self.cache_dir = DEFAULT_CACHE_DIR
|
||||
self.time_memory = DEFAULT_TIME_MEMORY
|
||||
self.base_url: str | None = None
|
||||
self.use_cache = use_cache
|
||||
self.cache_dir = cache_dir
|
||||
self.time_memory = time_memory
|
||||
self.base_url = base_url
|
||||
self.file_cache: dict[str, str] = {}
|
||||
self.mem_cache: dict[str, TTSCache] = {}
|
||||
|
||||
async def async_init_cache(
|
||||
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
|
||||
) -> None:
|
||||
async def async_init_cache(self) -> None:
|
||||
"""Init config folder and load file cache."""
|
||||
self.use_cache = use_cache
|
||||
self.time_memory = time_memory
|
||||
self.base_url = base_url
|
||||
|
||||
try:
|
||||
self.cache_dir = await self.hass.async_add_executor_job(
|
||||
_init_tts_cache_dir, self.hass, cache_dir
|
||||
_init_tts_cache_dir, self.hass, self.cache_dir
|
||||
)
|
||||
except OSError as err:
|
||||
raise HomeAssistantError(f"Can't init cache dir {err}") from err
|
||||
@ -390,7 +245,7 @@ class SpeechManager:
|
||||
self.file_cache = {}
|
||||
|
||||
@callback
|
||||
def async_register_engine(
|
||||
def async_register_legacy_engine(
|
||||
self, engine: str, provider: Provider, config: ConfigType
|
||||
) -> None:
|
||||
"""Register a TTS provider."""
|
||||
@ -739,52 +594,6 @@ class SpeechManager:
|
||||
return data_bytes.getvalue()
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
"""Represent a single TTS provider."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
name: str | None = None
|
||||
|
||||
@property
|
||||
def default_language(self) -> str | None:
|
||||
"""Return the default language."""
|
||||
return None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str] | None:
|
||||
"""Return a list of supported options like voice, emotions."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_options(self) -> Mapping[str, Any] | None:
|
||||
"""Return a mapping with the default options."""
|
||||
return None
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self.hass
|
||||
return await self.hass.async_add_executor_job(
|
||||
ft.partial(self.get_tts_audio, message, language, options=options)
|
||||
)
|
||||
|
||||
|
||||
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
|
||||
"""Init cache folder."""
|
||||
if not os.path.isabs(cache_dir):
|
||||
|
@ -1,3 +1,19 @@
|
||||
"""Text-to-speech constants."""
|
||||
ATTR_CACHE = "cache"
|
||||
ATTR_LANGUAGE = "language"
|
||||
ATTR_MESSAGE = "message"
|
||||
ATTR_OPTIONS = "options"
|
||||
|
||||
CONF_BASE_URL = "base_url"
|
||||
CONF_CACHE = "cache"
|
||||
CONF_CACHE_DIR = "cache_dir"
|
||||
CONF_FIELDS = "fields"
|
||||
CONF_TIME_MEMORY = "time_memory"
|
||||
|
||||
DEFAULT_CACHE = True
|
||||
DEFAULT_CACHE_DIR = "tts"
|
||||
DEFAULT_TIME_MEMORY = 300
|
||||
|
||||
DOMAIN = "tts"
|
||||
|
||||
TtsAudioType = tuple[str | None, bytes | None]
|
||||
|
252
homeassistant/components/tts/legacy.py
Normal file
252
homeassistant/components/tts/legacy.py
Normal file
@ -0,0 +1,252 @@
|
||||
"""Provide the legacy TTS service provider interface."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Coroutine, Mapping
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
import yarl
|
||||
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_DESCRIPTION,
|
||||
CONF_NAME,
|
||||
CONF_PLATFORM,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.service import async_set_service_schema
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant.util.network import normalize_url
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
|
||||
from .const import (
|
||||
ATTR_CACHE,
|
||||
ATTR_LANGUAGE,
|
||||
ATTR_MESSAGE,
|
||||
ATTR_OPTIONS,
|
||||
CONF_BASE_URL,
|
||||
CONF_CACHE,
|
||||
CONF_CACHE_DIR,
|
||||
CONF_FIELDS,
|
||||
CONF_TIME_MEMORY,
|
||||
DEFAULT_CACHE,
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_TIME_MEMORY,
|
||||
DOMAIN,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .media_source import generate_media_source_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_SERVICE_NAME = "service_name"
|
||||
|
||||
|
||||
def _deprecated_platform(value: str) -> str:
|
||||
"""Validate if platform is deprecated."""
|
||||
if value == "google":
|
||||
raise vol.Invalid(
|
||||
"google tts service has been renamed to google_translate,"
|
||||
" please update your configuration."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _valid_base_url(value: str) -> str:
|
||||
"""Validate base url, return value."""
|
||||
url = yarl.URL(cv.url(value))
|
||||
|
||||
if url.path != "/":
|
||||
raise vol.Invalid("Path should be empty")
|
||||
|
||||
return normalize_url(value)
|
||||
|
||||
|
||||
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
|
||||
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
|
||||
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
|
||||
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
|
||||
vol.Coerce(int), vol.Range(min=60, max=57600)
|
||||
),
|
||||
vol.Optional(CONF_BASE_URL): _valid_base_url,
|
||||
vol.Optional(CONF_SERVICE_NAME): cv.string,
|
||||
}
|
||||
)
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)
|
||||
|
||||
SERVICE_SAY = "say"
|
||||
|
||||
SCHEMA_SERVICE_SAY = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_MESSAGE): cv.string,
|
||||
vol.Optional(ATTR_CACHE): cv.boolean,
|
||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||
vol.Optional(ATTR_OPTIONS): dict,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_legacy(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> list[Coroutine[Any, Any, None]]:
|
||||
"""Set up legacy text to speech providers."""
|
||||
tts: SpeechManager = hass.data[DOMAIN]
|
||||
|
||||
# Load service descriptions from tts/services.yaml
|
||||
services_yaml = Path(__file__).parent / "services.yaml"
|
||||
services_dict = cast(
|
||||
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
|
||||
)
|
||||
|
||||
async def async_setup_platform(
|
||||
p_type: str,
|
||||
p_config: ConfigType | None = None,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up a TTS platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
|
||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
_LOGGER.error("Unknown text to speech platform specified")
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(platform, "async_get_engine"):
|
||||
provider = await platform.async_get_engine(
|
||||
hass, p_config, discovery_info
|
||||
)
|
||||
else:
|
||||
provider = await hass.async_add_executor_job(
|
||||
platform.get_engine, hass, p_config, discovery_info
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
_LOGGER.error("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
tts.async_register_legacy_engine(p_type, provider, p_config)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
async def async_say_handle(service: ServiceCall) -> None:
|
||||
"""Service handle for say."""
|
||||
entity_ids = service.data[ATTR_ENTITY_ID]
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
{
|
||||
ATTR_ENTITY_ID: entity_ids,
|
||||
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
||||
hass,
|
||||
engine=p_type,
|
||||
message=service.data[ATTR_MESSAGE],
|
||||
language=service.data.get(ATTR_LANGUAGE),
|
||||
options=service.data.get(ATTR_OPTIONS),
|
||||
cache=service.data.get(ATTR_CACHE),
|
||||
),
|
||||
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
||||
ATTR_MEDIA_ANNOUNCE: True,
|
||||
},
|
||||
blocking=True,
|
||||
context=service.context,
|
||||
)
|
||||
|
||||
service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
|
||||
hass.services.async_register(
|
||||
DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
|
||||
)
|
||||
|
||||
# Register the service description
|
||||
service_desc = {
|
||||
CONF_NAME: f"Say a TTS message with {p_type}",
|
||||
CONF_DESCRIPTION: (
|
||||
f"Say something using text-to-speech on a media player with {p_type}."
|
||||
),
|
||||
CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
|
||||
}
|
||||
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
|
||||
|
||||
async def async_platform_discovered(
|
||||
platform: str, info: dict[str, Any] | None
|
||||
) -> None:
|
||||
"""Handle for discovered platform."""
|
||||
await async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||
|
||||
return [
|
||||
async_setup_platform(p_type, p_config)
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||
if p_type is not None
|
||||
]
|
||||
|
||||
|
||||
class Provider:
|
||||
"""Represent a single TTS provider."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
name: str | None = None
|
||||
|
||||
@property
|
||||
def default_language(self) -> str | None:
|
||||
"""Return the default language."""
|
||||
return None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str] | None:
|
||||
"""Return a list of supported options like voice, emotions."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_options(self) -> Mapping[str, Any] | None:
|
||||
"""Return a mapping with the default options."""
|
||||
return None
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self.hass
|
||||
return await self.hass.async_add_executor_job(
|
||||
partial(self.get_tts_audio, message, language, options=options)
|
||||
)
|
77
tests/components/tts/common.py
Normal file
77
tests/components/tts/common.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""Provide common tests tools for tts."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.tts import (
|
||||
CONF_LANG,
|
||||
PLATFORM_SCHEMA,
|
||||
Provider,
|
||||
TtsAudioType,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from tests.common import MockPlatform
|
||||
|
||||
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
|
||||
|
||||
DEFAULT_LANG = "en"
|
||||
|
||||
|
||||
class MockProvider(Provider):
|
||||
"""Test speech API provider."""
|
||||
|
||||
def __init__(self, lang: str) -> None:
|
||||
"""Initialize test provider."""
|
||||
self._lang = lang
|
||||
self.name = "Test"
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
"""Return the default language."""
|
||||
return self._lang
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return ["voice", "age"]
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockTTS(MockPlatform):
|
||||
"""A mock TTS platform."""
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, provider: type[MockProvider] | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__(**kwargs)
|
||||
if provider is None:
|
||||
provider = MockProvider
|
||||
self._provider = provider
|
||||
|
||||
async def async_get_engine(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> Provider | None:
|
||||
"""Set up a mock speech component."""
|
||||
return self._provider(config.get(CONF_LANG, DEFAULT_LANG))
|
@ -7,6 +7,12 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.tts import _get_cache_files
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .common import MockTTS
|
||||
|
||||
from tests.common import MockModule, mock_integration, mock_platform
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||
@ -71,3 +77,19 @@ def mutagen_mock():
|
||||
side_effect=lambda *args: args[1],
|
||||
) as mock_write_tags:
|
||||
yield mock_write_tags
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def internal_url_mock(hass: HomeAssistant) -> None:
|
||||
"""Mock internal URL of the instance."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"internal_url": "http://example.local:8123"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_tts(hass: HomeAssistant) -> None:
|
||||
"""Mock TTS."""
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
mock_platform(hass, "test.tts", MockTTS())
|
||||
|
@ -17,13 +17,14 @@ from homeassistant.components.media_player import (
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.components.media_source import Unresolvable
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.components.tts.legacy import _valid_base_url
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.network import normalize_url
|
||||
|
||||
from .common import MockProvider, MockTTS
|
||||
|
||||
from tests.common import (
|
||||
MockModule,
|
||||
assert_setup_component,
|
||||
@ -36,7 +37,7 @@ from tests.typing import ClientSessionGenerator
|
||||
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
|
||||
|
||||
|
||||
async def get_media_source_url(hass, media_content_id):
|
||||
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
@ -45,88 +46,14 @@ async def get_media_source_url(hass, media_content_id):
|
||||
return resolved.url
|
||||
|
||||
|
||||
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
|
||||
|
||||
DEFAULT_LANG = "en"
|
||||
|
||||
|
||||
class MockProvider(tts.Provider):
|
||||
"""Test speech API provider."""
|
||||
|
||||
def __init__(self, lang: str) -> None:
|
||||
"""Initialize test provider."""
|
||||
self._lang = lang
|
||||
self.name = "Test"
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
"""Return the default language."""
|
||||
return self._lang
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return ["voice", "age"]
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockTTS:
|
||||
"""A mock TTS platform."""
|
||||
|
||||
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA.extend(
|
||||
{vol.Optional(tts.CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
|
||||
)
|
||||
|
||||
def __init__(self, provider=None) -> None:
|
||||
"""Initialize."""
|
||||
if provider is None:
|
||||
provider = MockProvider
|
||||
self._provider = provider
|
||||
|
||||
async def async_get_engine(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> tts.Provider:
|
||||
"""Set up a mock speech component."""
|
||||
return self._provider(config.get(tts.CONF_LANG, DEFAULT_LANG))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_provider():
|
||||
def mock_provider() -> MockProvider:
|
||||
"""Test TTS provider."""
|
||||
return MockProvider("en")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def internal_url_mock(hass):
|
||||
"""Mock internal URL of the instance."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
{"internal_url": "http://example.local:8123"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_tts(hass):
|
||||
"""Mock TTS."""
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
mock_platform(hass, "test.tts", MockTTS())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_tts(hass, mock_tts):
|
||||
async def setup_tts(hass: HomeAssistant, mock_tts: None) -> None:
|
||||
"""Mock TTS."""
|
||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||
|
||||
@ -366,6 +293,8 @@ async def test_setup_component_and_test_with_service_options_def(
|
||||
config = {tts.DOMAIN: {"platform": "test"}}
|
||||
|
||||
class MockProviderWithDefaults(MockProvider):
|
||||
"""Mock provider with default options."""
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
return {"voice": "alex"}
|
||||
@ -413,6 +342,8 @@ async def test_setup_component_and_test_with_service_options_def_2(
|
||||
config = {tts.DOMAIN: {"platform": "test"}}
|
||||
|
||||
class MockProviderWithDefaults(MockProvider):
|
||||
"""Mock provider with default options."""
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
return {"voice": "alex"}
|
||||
@ -550,7 +481,10 @@ async def test_setup_component_and_test_service_clear_cache(
|
||||
|
||||
|
||||
async def test_setup_component_and_test_service_with_receive_voice(
|
||||
hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts,
|
||||
) -> None:
|
||||
"""Set up a TTS platform and call service and receive voice."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
@ -576,11 +510,12 @@ async def test_setup_component_and_test_service_with_receive_voice(
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
client = await hass_client()
|
||||
req = await client.get(url)
|
||||
_, tts_data = test_provider.get_tts_audio("bla", "en")
|
||||
_, tts_data = mock_provider.get_tts_audio("bla", "en")
|
||||
assert tts_data is not None
|
||||
tts_data = tts.SpeechManager.write_tags(
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
|
||||
tts_data,
|
||||
test_provider,
|
||||
mock_provider,
|
||||
message,
|
||||
"en",
|
||||
None,
|
||||
@ -596,7 +531,10 @@ async def test_setup_component_and_test_service_with_receive_voice(
|
||||
|
||||
|
||||
async def test_setup_component_and_test_service_with_receive_voice_german(
|
||||
hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts,
|
||||
) -> None:
|
||||
"""Set up a TTS platform and call service and receive voice."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
@ -619,11 +557,12 @@ async def test_setup_component_and_test_service_with_receive_voice_german(
|
||||
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
client = await hass_client()
|
||||
req = await client.get(url)
|
||||
_, tts_data = test_provider.get_tts_audio("bla", "de")
|
||||
_, tts_data = mock_provider.get_tts_audio("bla", "de")
|
||||
assert tts_data is not None
|
||||
tts_data = tts.SpeechManager.write_tags(
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3",
|
||||
tts_data,
|
||||
test_provider,
|
||||
mock_provider,
|
||||
"There is someone at the door.",
|
||||
"de",
|
||||
None,
|
||||
@ -722,12 +661,13 @@ async def test_setup_component_test_with_cache_call_service_without_cache(
|
||||
|
||||
|
||||
async def test_setup_component_test_with_cache_dir(
|
||||
hass: HomeAssistant, empty_cache_dir, test_provider
|
||||
hass: HomeAssistant, empty_cache_dir, mock_provider: MockProvider
|
||||
) -> None:
|
||||
"""Set up a TTS platform with cache and call service without cache."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
_, tts_data = test_provider.get_tts_audio("bla", "en")
|
||||
_, tts_data = mock_provider.get_tts_audio("bla", "en")
|
||||
assert tts_data is not None
|
||||
cache_file = (
|
||||
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
||||
)
|
||||
@ -738,12 +678,14 @@ async def test_setup_component_test_with_cache_dir(
|
||||
config = {tts.DOMAIN: {"platform": "test", "cache": True}}
|
||||
|
||||
class MockProviderBoom(MockProvider):
|
||||
"""Mock provider that blows up."""
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
# This should not be called, data should be fetched from cache
|
||||
raise Exception("Boom!")
|
||||
raise Exception("Boom!") # pylint: disable=broad-exception-raised
|
||||
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
mock_platform(hass, "test.tts", MockTTS(MockProviderBoom))
|
||||
@ -775,6 +717,8 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) -
|
||||
config = {tts.DOMAIN: {"platform": "test"}}
|
||||
|
||||
class MockProviderEmpty(MockProvider):
|
||||
"""Mock provider with empty get_tts_audio."""
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> tts.TtsAudioType:
|
||||
@ -803,13 +747,14 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) -
|
||||
|
||||
async def test_setup_component_load_cache_retrieve_without_mem_cache(
|
||||
hass: HomeAssistant,
|
||||
test_provider,
|
||||
mock_provider: MockProvider,
|
||||
empty_cache_dir,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts,
|
||||
) -> None:
|
||||
"""Set up component and load cache and get without mem cache."""
|
||||
_, tts_data = test_provider.get_tts_audio("bla", "en")
|
||||
_, tts_data = mock_provider.get_tts_audio("bla", "en")
|
||||
assert tts_data is not None
|
||||
cache_file = (
|
||||
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
|
||||
)
|
||||
@ -870,7 +815,7 @@ async def test_setup_component_and_web_get_url_bad_config(
|
||||
assert req.status == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
|
||||
async def test_tags_with_wave(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
||||
"""Set up a TTS platform and call service and receive voice."""
|
||||
|
||||
# below data represents an empty wav file
|
||||
@ -882,7 +827,7 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
|
||||
tagged_data = ORIG_WRITE_TAGS(
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
|
||||
tts_data,
|
||||
test_provider,
|
||||
mock_provider,
|
||||
"AI person is in front of your door.",
|
||||
"en",
|
||||
None,
|
||||
@ -904,9 +849,9 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
|
||||
)
|
||||
def test_valid_base_url(value) -> None:
|
||||
"""Test we validate base urls."""
|
||||
assert tts.valid_base_url(value) == normalize_url(value)
|
||||
assert _valid_base_url(value) == normalize_url(value)
|
||||
# Test we strip trailing `/`
|
||||
assert tts.valid_base_url(value + "/") == normalize_url(value)
|
||||
assert _valid_base_url(value + "/") == normalize_url(value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -926,7 +871,7 @@ def test_valid_base_url(value) -> None:
|
||||
def test_invalid_base_url(value) -> None:
|
||||
"""Test we catch bad base urls."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
tts.valid_base_url(value)
|
||||
_valid_base_url(value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -1000,9 +945,11 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
|
||||
async def test_fetching_in_async(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test async fetching of data."""
|
||||
tts_audio = asyncio.Future()
|
||||
tts_audio: asyncio.Future[bytes] = asyncio.Future()
|
||||
|
||||
class ProviderWithAsyncFetching(MockProvider):
|
||||
"""Provider that supports audio output option."""
|
||||
@ -1067,4 +1014,7 @@ async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
|
||||
|
||||
tts_audio = asyncio.Future()
|
||||
tts_audio.set_result(b"test 2")
|
||||
await tts.async_get_media_source_audio(hass, media_source_id) == ("mp3", b"test 2")
|
||||
assert await tts.async_get_media_source_audio(hass, media_source_id) == (
|
||||
"mp3",
|
||||
b"test 2",
|
||||
)
|
||||
|
121
tests/components/tts/test_legacy.py
Normal file
121
tests/components/tts/test_legacy.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""Test the legacy tts setup."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.tts import DOMAIN, Provider
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import MockTTS
|
||||
|
||||
from tests.common import (
|
||||
MockModule,
|
||||
assert_setup_component,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
|
||||
async def test_default_provider_attributes() -> None:
|
||||
"""Test default provider properties."""
|
||||
provider = Provider()
|
||||
|
||||
assert provider.hass is None
|
||||
assert provider.name is None
|
||||
assert provider.default_language is None
|
||||
assert provider.supported_languages is None
|
||||
assert provider.supported_options is None
|
||||
assert provider.default_options is None
|
||||
|
||||
|
||||
async def test_deprecated_platform(hass: HomeAssistant) -> None:
|
||||
"""Test deprecated google platform."""
|
||||
with assert_setup_component(0, DOMAIN):
|
||||
assert await async_setup_component(
|
||||
hass, DOMAIN, {DOMAIN: {"platform": "google"}}
|
||||
)
|
||||
|
||||
|
||||
async def test_invalid_platform(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test platform setup with an invalid platform."""
|
||||
await async_load_platform(
|
||||
hass,
|
||||
"tts",
|
||||
"bad_tts",
|
||||
{"tts": [{"platform": "bad_tts"}]},
|
||||
hass_config={"tts": [{"platform": "bad_tts"}]},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Unknown text to speech platform specified" in caplog.text
|
||||
|
||||
|
||||
async def test_platform_setup_without_provider(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test platform setup without provider returned."""
|
||||
|
||||
class BadPlatform(MockTTS):
|
||||
"""A mock TTS platform without provider."""
|
||||
|
||||
async def async_get_engine(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> Provider | None:
|
||||
"""Raise exception during platform setup."""
|
||||
return None
|
||||
|
||||
mock_integration(hass, MockModule(domain="bad_tts"))
|
||||
mock_platform(hass, "bad_tts.tts", BadPlatform())
|
||||
|
||||
await async_load_platform(
|
||||
hass,
|
||||
"tts",
|
||||
"bad_tts",
|
||||
{},
|
||||
hass_config={"tts": [{"platform": "bad_tts"}]},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Error setting up platform: bad_tts" in caplog.text
|
||||
|
||||
|
||||
async def test_platform_setup_with_error(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test platform setup with an error during setup."""
|
||||
|
||||
class BadPlatform(MockTTS):
|
||||
"""A mock TTS platform with a setup error."""
|
||||
|
||||
async def async_get_engine(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> Provider:
|
||||
"""Raise exception during platform setup."""
|
||||
raise Exception("Setup error") # pylint: disable=broad-exception-raised
|
||||
|
||||
mock_integration(hass, MockModule(domain="bad_tts"))
|
||||
mock_platform(hass, "bad_tts.tts", BadPlatform())
|
||||
|
||||
await async_load_platform(
|
||||
hass,
|
||||
"tts",
|
||||
"bad_tts",
|
||||
{},
|
||||
hass_config={"tts": [{"platform": "bad_tts"}]},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Error setting up platform: bad_tts" in caplog.text
|
Loading…
x
Reference in New Issue
Block a user