Move legacy tts (#91538)

* Move legacy tts

* Add error log on unknown platform

* Add legacy tests and delint all tests

* Consolidate log format

* Add more legacy tests

* Test default legacy provider attributes

* Remove test generated files

* Clean up after merge conflict
This commit is contained in:
Martin Hjelmare 2023-04-17 19:01:50 +02:00 committed by GitHub
parent 0ecd23baee
commit 9bd12f6503
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 605 additions and 358 deletions

View File

@ -1,135 +1,77 @@
"""Provide functionality for TTS."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Mapping
from datetime import datetime
import functools as ft
import hashlib
from http import HTTPStatus
import io
import logging
import mimetypes
import os
from pathlib import Path
import re
from typing import TYPE_CHECKING, Any, TypedDict, cast
from typing import TypedDict
from aiohttp import web
import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol
import yarl
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_DESCRIPTION,
CONF_NAME,
CONF_PLATFORM,
PLATFORM_FORMAT,
)
from homeassistant.const import PLATFORM_FORMAT
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.network import normalize_url
from homeassistant.util.yaml import load_yaml
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .const import (
ATTR_CACHE,
ATTR_LANGUAGE,
ATTR_MESSAGE,
ATTR_OPTIONS,
CONF_BASE_URL,
CONF_CACHE,
CONF_CACHE_DIR,
CONF_TIME_MEMORY,
DEFAULT_CACHE,
DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY,
DOMAIN,
TtsAudioType,
)
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs
__all__ = [
"async_get_media_source_audio",
"async_resolve_engine",
"async_support_options",
"ATTR_AUDIO_OUTPUT",
"CONF_LANG",
"DEFAULT_CACHE_DIR",
"generate_media_source_id",
"get_base_url",
"PLATFORM_SCHEMA_BASE",
"PLATFORM_SCHEMA",
"Provider",
"TtsAudioType",
]
_LOGGER = logging.getLogger(__name__)
TtsAudioType = tuple[str | None, bytes | None]
ATTR_CACHE = "cache"
ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message"
ATTR_OPTIONS = "options"
ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
CONF_LANG = "language"
BASE_URL_KEY = "tts_base_url"
CONF_BASE_URL = "base_url"
CONF_CACHE = "cache"
CONF_CACHE_DIR = "cache_dir"
CONF_LANG = "language"
CONF_SERVICE_NAME = "service_name"
CONF_TIME_MEMORY = "time_memory"
CONF_FIELDS = "fields"
DEFAULT_CACHE = True
DEFAULT_CACHE_DIR = "tts"
DEFAULT_TIME_MEMORY = 300
SERVICE_CLEAR_CACHE = "clear_cache"
SERVICE_SAY = "say"
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
KEY_PATTERN = "{0}_{1}_{2}_{3}"
def _deprecated_platform(value: str) -> str:
"""Validate if platform is deprecated."""
if value == "google":
raise vol.Invalid(
"google tts service has been renamed to google_translate,"
" please update your configuration."
)
return value
def valid_base_url(value: str) -> str:
"""Validate base url, return value."""
url = yarl.URL(cv.url(value))
if url.path != "/":
raise vol.Invalid("Path should be empty")
return normalize_url(value)
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
vol.Coerce(int), vol.Range(min=60, max=57600)
),
vol.Optional(CONF_BASE_URL): valid_base_url,
vol.Optional(CONF_SERVICE_NAME): cv.string,
}
)
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)
SCHEMA_SERVICE_SAY = vol.Schema(
{
vol.Required(ATTR_MESSAGE): cv.string,
vol.Optional(ATTR_CACHE): cv.boolean,
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_LANGUAGE): cv.string,
vol.Optional(ATTR_OPTIONS): dict,
}
)
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
@ -192,22 +134,23 @@ async def async_get_media_source_audio(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
tts = SpeechManager(hass)
# Legacy config options
conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
time_memory: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
base_url: str | None = conf.get(CONF_BASE_URL)
if base_url is not None:
_LOGGER.warning(
"TTS base_url option is deprecated. Configure internal/external URL"
" instead"
)
hass.data[BASE_URL_KEY] = base_url
tts = SpeechManager(hass, use_cache, cache_dir, time_memory, base_url)
try:
conf = config[DOMAIN][0] if config.get(DOMAIN, []) else {}
use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE)
cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
base_url = conf.get(CONF_BASE_URL)
if base_url is not None:
_LOGGER.warning(
"TTS base_url option is deprecated. Configure internal/external URL"
" instead"
)
hass.data[BASE_URL_KEY] = base_url
await tts.async_init_cache(use_cache, cache_dir, time_memory, base_url)
await tts.async_init_cache()
except (HomeAssistantError, KeyError):
_LOGGER.exception("Error on cache init")
return False
@ -216,99 +159,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.http.register_view(TextToSpeechView(tts))
hass.http.register_view(TextToSpeechUrlView(tts))
# Load service descriptions from tts/services.yaml
services_yaml = Path(__file__).parent / "services.yaml"
services_dict = cast(
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
)
platform_setups = await async_setup_legacy(hass, config)
async def async_setup_platform(
p_type: str,
p_config: ConfigType | None = None,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up a TTS platform."""
if p_config is None:
p_config = {}
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
if platform is None:
return
try:
if hasattr(platform, "async_get_engine"):
provider = await platform.async_get_engine(
hass, p_config, discovery_info
)
else:
provider = await hass.async_add_executor_job(
platform.get_engine, hass, p_config, discovery_info
)
if provider is None:
_LOGGER.error("Error setting up platform %s", p_type)
return
tts.async_register_engine(p_type, provider, p_config)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform: %s", p_type)
return
async def async_say_handle(service: ServiceCall) -> None:
"""Service handle for say."""
entity_ids = service.data[ATTR_ENTITY_ID]
await hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: entity_ids,
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
hass,
engine=p_type,
message=service.data[ATTR_MESSAGE],
language=service.data.get(ATTR_LANGUAGE),
options=service.data.get(ATTR_OPTIONS),
cache=service.data.get(ATTR_CACHE),
),
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
context=service.context,
)
service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
hass.services.async_register(
DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
)
# Register the service description
service_desc = {
CONF_NAME: f"Say a TTS message with {p_type}",
CONF_DESCRIPTION: (
f"Say something using text-to-speech on a media player with {p_type}."
),
CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
}
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
setup_tasks = [
asyncio.create_task(async_setup_platform(p_type, p_config))
for p_type, p_config in config_per_platform(config, DOMAIN)
if p_type is not None
]
if setup_tasks:
await asyncio.wait(setup_tasks)
async def async_platform_discovered(
platform: str, info: dict[str, Any] | None
) -> None:
"""Handle for discovered platform."""
await async_setup_platform(platform, discovery_info=info)
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
if platform_setups:
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
async def async_clear_cache_handle(service: ServiceCall) -> None:
"""Handle clear cache service call."""
@ -337,29 +191,30 @@ def _hash_options(options: dict) -> str:
class SpeechManager:
"""Representation of a speech store."""
def __init__(self, hass: HomeAssistant) -> None:
def __init__(
self,
hass: HomeAssistant,
use_cache: bool,
cache_dir: str,
time_memory: int,
base_url: str | None,
) -> None:
"""Initialize a speech store."""
self.hass = hass
self.providers: dict[str, Provider] = {}
self.use_cache = DEFAULT_CACHE
self.cache_dir = DEFAULT_CACHE_DIR
self.time_memory = DEFAULT_TIME_MEMORY
self.base_url: str | None = None
self.use_cache = use_cache
self.cache_dir = cache_dir
self.time_memory = time_memory
self.base_url = base_url
self.file_cache: dict[str, str] = {}
self.mem_cache: dict[str, TTSCache] = {}
async def async_init_cache(
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
) -> None:
async def async_init_cache(self) -> None:
"""Init config folder and load file cache."""
self.use_cache = use_cache
self.time_memory = time_memory
self.base_url = base_url
try:
self.cache_dir = await self.hass.async_add_executor_job(
_init_tts_cache_dir, self.hass, cache_dir
_init_tts_cache_dir, self.hass, self.cache_dir
)
except OSError as err:
raise HomeAssistantError(f"Can't init cache dir {err}") from err
@ -390,7 +245,7 @@ class SpeechManager:
self.file_cache = {}
@callback
def async_register_engine(
def async_register_legacy_engine(
self, engine: str, provider: Provider, config: ConfigType
) -> None:
"""Register a TTS provider."""
@ -739,52 +594,6 @@ class SpeechManager:
return data_bytes.getvalue()
class Provider(ABC):
"""Represent a single TTS provider."""
hass: HomeAssistant | None = None
name: str | None = None
@property
def default_language(self) -> str | None:
"""Return the default language."""
return None
@property
@abstractmethod
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@property
def supported_options(self) -> list[str] | None:
"""Return a list of supported options like voice, emotions."""
return None
@property
def default_options(self) -> Mapping[str, Any] | None:
"""Return a mapping with the default options."""
return None
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from provider."""
raise NotImplementedError()
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from provider.
Return a tuple of file extension and data as bytes.
"""
if TYPE_CHECKING:
assert self.hass
return await self.hass.async_add_executor_job(
ft.partial(self.get_tts_audio, message, language, options=options)
)
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
"""Init cache folder."""
if not os.path.isabs(cache_dir):

View File

@ -1,3 +1,19 @@
"""Text-to-speech constants."""
ATTR_CACHE = "cache"
ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message"
ATTR_OPTIONS = "options"
CONF_BASE_URL = "base_url"
CONF_CACHE = "cache"
CONF_CACHE_DIR = "cache_dir"
CONF_FIELDS = "fields"
CONF_TIME_MEMORY = "time_memory"
DEFAULT_CACHE = True
DEFAULT_CACHE_DIR = "tts"
DEFAULT_TIME_MEMORY = 300
DOMAIN = "tts"
TtsAudioType = tuple[str | None, bytes | None]

View File

@ -0,0 +1,252 @@
"""Provide the legacy TTS service provider interface."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Coroutine, Mapping
from functools import partial
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
import voluptuous as vol
import yarl
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_DESCRIPTION,
CONF_NAME,
CONF_PLATFORM,
)
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.network import normalize_url
from homeassistant.util.yaml import load_yaml
from .const import (
ATTR_CACHE,
ATTR_LANGUAGE,
ATTR_MESSAGE,
ATTR_OPTIONS,
CONF_BASE_URL,
CONF_CACHE,
CONF_CACHE_DIR,
CONF_FIELDS,
CONF_TIME_MEMORY,
DEFAULT_CACHE,
DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY,
DOMAIN,
TtsAudioType,
)
from .media_source import generate_media_source_id
if TYPE_CHECKING:
from . import SpeechManager
_LOGGER = logging.getLogger(__name__)
CONF_SERVICE_NAME = "service_name"
def _deprecated_platform(value: str) -> str:
"""Validate if platform is deprecated."""
if value == "google":
raise vol.Invalid(
"google tts service has been renamed to google_translate,"
" please update your configuration."
)
return value
def _valid_base_url(value: str) -> str:
"""Validate base url, return value."""
url = yarl.URL(cv.url(value))
if url.path != "/":
raise vol.Invalid("Path should be empty")
return normalize_url(value)
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
vol.Coerce(int), vol.Range(min=60, max=57600)
),
vol.Optional(CONF_BASE_URL): _valid_base_url,
vol.Optional(CONF_SERVICE_NAME): cv.string,
}
)
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema)
SERVICE_SAY = "say"
SCHEMA_SERVICE_SAY = vol.Schema(
{
vol.Required(ATTR_MESSAGE): cv.string,
vol.Optional(ATTR_CACHE): cv.boolean,
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_LANGUAGE): cv.string,
vol.Optional(ATTR_OPTIONS): dict,
}
)
async def async_setup_legacy(
hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
"""Set up legacy text to speech providers."""
tts: SpeechManager = hass.data[DOMAIN]
# Load service descriptions from tts/services.yaml
services_yaml = Path(__file__).parent / "services.yaml"
services_dict = cast(
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
)
async def async_setup_platform(
p_type: str,
p_config: ConfigType | None = None,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up a TTS platform."""
if p_config is None:
p_config = {}
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
if platform is None:
_LOGGER.error("Unknown text to speech platform specified")
return
try:
if hasattr(platform, "async_get_engine"):
provider = await platform.async_get_engine(
hass, p_config, discovery_info
)
else:
provider = await hass.async_add_executor_job(
platform.get_engine, hass, p_config, discovery_info
)
if provider is None:
_LOGGER.error("Error setting up platform: %s", p_type)
return
tts.async_register_legacy_engine(p_type, provider, p_config)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform: %s", p_type)
return
async def async_say_handle(service: ServiceCall) -> None:
"""Service handle for say."""
entity_ids = service.data[ATTR_ENTITY_ID]
await hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: entity_ids,
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
hass,
engine=p_type,
message=service.data[ATTR_MESSAGE],
language=service.data.get(ATTR_LANGUAGE),
options=service.data.get(ATTR_OPTIONS),
cache=service.data.get(ATTR_CACHE),
),
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
context=service.context,
)
service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}")
hass.services.async_register(
DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY
)
# Register the service description
service_desc = {
CONF_NAME: f"Say a TTS message with {p_type}",
CONF_DESCRIPTION: (
f"Say something using text-to-speech on a media player with {p_type}."
),
CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS],
}
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
async def async_platform_discovered(
platform: str, info: dict[str, Any] | None
) -> None:
"""Handle for discovered platform."""
await async_setup_platform(platform, discovery_info=info)
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
return [
async_setup_platform(p_type, p_config)
for p_type, p_config in config_per_platform(config, DOMAIN)
if p_type is not None
]
class Provider:
"""Represent a single TTS provider."""
hass: HomeAssistant | None = None
name: str | None = None
@property
def default_language(self) -> str | None:
"""Return the default language."""
return None
@property
@abstractmethod
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@property
def supported_options(self) -> list[str] | None:
"""Return a list of supported options like voice, emotions."""
return None
@property
def default_options(self) -> Mapping[str, Any] | None:
"""Return a mapping with the default options."""
return None
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from provider."""
raise NotImplementedError()
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from provider.
Return a tuple of file extension and data as bytes.
"""
if TYPE_CHECKING:
assert self.hass
return await self.hass.async_add_executor_job(
partial(self.get_tts_audio, message, language, options=options)
)

View File

@ -0,0 +1,77 @@
"""Provide common tests tools for tts."""
from __future__ import annotations
from typing import Any
import voluptuous as vol
from homeassistant.components.tts import (
CONF_LANG,
PLATFORM_SCHEMA,
Provider,
TtsAudioType,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
DEFAULT_LANG = "en"
class MockProvider(Provider):
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
self._lang = lang
self.name = "Test"
@property
def default_language(self) -> str:
"""Return the default language."""
return self._lang
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return ["voice", "age"]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load TTS dat."""
return ("mp3", b"")
class MockTTS(MockPlatform):
"""A mock TTS platform."""
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
)
def __init__(
self, provider: type[MockProvider] | None = None, **kwargs: Any
) -> None:
"""Initialize."""
super().__init__(**kwargs)
if provider is None:
provider = MockProvider
self._provider = provider
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None:
"""Set up a mock speech component."""
return self._provider(config.get(CONF_LANG, DEFAULT_LANG))

View File

@ -7,6 +7,12 @@ from unittest.mock import patch
import pytest
from homeassistant.components.tts import _get_cache_files
from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant
from .common import MockTTS
from tests.common import MockModule, mock_integration, mock_platform
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@ -71,3 +77,19 @@ def mutagen_mock():
side_effect=lambda *args: args[1],
) as mock_write_tags:
yield mock_write_tags
@pytest.fixture(autouse=True)
async def internal_url_mock(hass: HomeAssistant) -> None:
"""Mock internal URL of the instance."""
await async_process_ha_core_config(
hass,
{"internal_url": "http://example.local:8123"},
)
@pytest.fixture
async def mock_tts(hass: HomeAssistant) -> None:
"""Mock TTS."""
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS())

View File

@ -17,13 +17,14 @@ from homeassistant.components.media_player import (
MediaType,
)
from homeassistant.components.media_source import Unresolvable
from homeassistant.config import async_process_ha_core_config
from homeassistant.components.tts.legacy import _valid_base_url
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from homeassistant.util.network import normalize_url
from .common import MockProvider, MockTTS
from tests.common import (
MockModule,
assert_setup_component,
@ -36,7 +37,7 @@ from tests.typing import ClientSessionGenerator
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
async def get_media_source_url(hass, media_content_id):
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
@ -45,88 +46,14 @@ async def get_media_source_url(hass, media_content_id):
return resolved.url
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
DEFAULT_LANG = "en"
class MockProvider(tts.Provider):
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
self._lang = lang
self.name = "Test"
@property
def default_language(self) -> str:
"""Return the default language."""
return self._lang
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return ["voice", "age"]
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> tts.TtsAudioType:
"""Load TTS dat."""
return ("mp3", b"")
class MockTTS:
"""A mock TTS platform."""
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA.extend(
{vol.Optional(tts.CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
)
def __init__(self, provider=None) -> None:
"""Initialize."""
if provider is None:
provider = MockProvider
self._provider = provider
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> tts.Provider:
"""Set up a mock speech component."""
return self._provider(config.get(tts.CONF_LANG, DEFAULT_LANG))
@pytest.fixture
def test_provider():
def mock_provider() -> MockProvider:
"""Test TTS provider."""
return MockProvider("en")
@pytest.fixture(autouse=True)
async def internal_url_mock(hass):
"""Mock internal URL of the instance."""
await async_process_ha_core_config(
hass,
{"internal_url": "http://example.local:8123"},
)
@pytest.fixture
async def mock_tts(hass):
"""Mock TTS."""
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS())
@pytest.fixture
async def setup_tts(hass, mock_tts):
async def setup_tts(hass: HomeAssistant, mock_tts: None) -> None:
"""Mock TTS."""
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
@ -366,6 +293,8 @@ async def test_setup_component_and_test_with_service_options_def(
config = {tts.DOMAIN: {"platform": "test"}}
class MockProviderWithDefaults(MockProvider):
"""Mock provider with default options."""
@property
def default_options(self):
return {"voice": "alex"}
@ -413,6 +342,8 @@ async def test_setup_component_and_test_with_service_options_def_2(
config = {tts.DOMAIN: {"platform": "test"}}
class MockProviderWithDefaults(MockProvider):
"""Mock provider with default options."""
@property
def default_options(self):
return {"voice": "alex"}
@ -550,7 +481,10 @@ async def test_setup_component_and_test_service_clear_cache(
async def test_setup_component_and_test_service_with_receive_voice(
hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts
hass: HomeAssistant,
mock_provider: MockProvider,
hass_client: ClientSessionGenerator,
mock_tts,
) -> None:
"""Set up a TTS platform and call service and receive voice."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -576,11 +510,12 @@ async def test_setup_component_and_test_service_with_receive_voice(
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
client = await hass_client()
req = await client.get(url)
_, tts_data = test_provider.get_tts_audio("bla", "en")
_, tts_data = mock_provider.get_tts_audio("bla", "en")
assert tts_data is not None
tts_data = tts.SpeechManager.write_tags(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
tts_data,
test_provider,
mock_provider,
message,
"en",
None,
@ -596,7 +531,10 @@ async def test_setup_component_and_test_service_with_receive_voice(
async def test_setup_component_and_test_service_with_receive_voice_german(
hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts
hass: HomeAssistant,
mock_provider: MockProvider,
hass_client: ClientSessionGenerator,
mock_tts,
) -> None:
"""Set up a TTS platform and call service and receive voice."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -619,11 +557,12 @@ async def test_setup_component_and_test_service_with_receive_voice_german(
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
client = await hass_client()
req = await client.get(url)
_, tts_data = test_provider.get_tts_audio("bla", "de")
_, tts_data = mock_provider.get_tts_audio("bla", "de")
assert tts_data is not None
tts_data = tts.SpeechManager.write_tags(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3",
tts_data,
test_provider,
mock_provider,
"There is someone at the door.",
"de",
None,
@ -722,12 +661,13 @@ async def test_setup_component_test_with_cache_call_service_without_cache(
async def test_setup_component_test_with_cache_dir(
hass: HomeAssistant, empty_cache_dir, test_provider
hass: HomeAssistant, empty_cache_dir, mock_provider: MockProvider
) -> None:
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
_, tts_data = test_provider.get_tts_audio("bla", "en")
_, tts_data = mock_provider.get_tts_audio("bla", "en")
assert tts_data is not None
cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
)
@ -738,12 +678,14 @@ async def test_setup_component_test_with_cache_dir(
config = {tts.DOMAIN: {"platform": "test", "cache": True}}
class MockProviderBoom(MockProvider):
"""Mock provider that blows up."""
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> tts.TtsAudioType:
"""Load TTS dat."""
# This should not be called, data should be fetched from cache
raise Exception("Boom!")
raise Exception("Boom!") # pylint: disable=broad-exception-raised
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS(MockProviderBoom))
@ -775,6 +717,8 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) -
config = {tts.DOMAIN: {"platform": "test"}}
class MockProviderEmpty(MockProvider):
"""Mock provider with empty get_tts_audio."""
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> tts.TtsAudioType:
@ -803,13 +747,14 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) -
async def test_setup_component_load_cache_retrieve_without_mem_cache(
hass: HomeAssistant,
test_provider,
mock_provider: MockProvider,
empty_cache_dir,
hass_client: ClientSessionGenerator,
mock_tts,
) -> None:
"""Set up component and load cache and get without mem cache."""
_, tts_data = test_provider.get_tts_audio("bla", "en")
_, tts_data = mock_provider.get_tts_audio("bla", "en")
assert tts_data is not None
cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
)
@ -870,7 +815,7 @@ async def test_setup_component_and_web_get_url_bad_config(
assert req.status == HTTPStatus.BAD_REQUEST
async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
async def test_tags_with_wave(hass: HomeAssistant, mock_provider: MockProvider) -> None:
"""Set up a TTS platform and call service and receive voice."""
# below data represents an empty wav file
@ -882,7 +827,7 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
tagged_data = ORIG_WRITE_TAGS(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
tts_data,
test_provider,
mock_provider,
"AI person is in front of your door.",
"en",
None,
@ -904,9 +849,9 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None:
)
def test_valid_base_url(value) -> None:
"""Test we validate base urls."""
assert tts.valid_base_url(value) == normalize_url(value)
assert _valid_base_url(value) == normalize_url(value)
# Test we strip trailing `/`
assert tts.valid_base_url(value + "/") == normalize_url(value)
assert _valid_base_url(value + "/") == normalize_url(value)
@pytest.mark.parametrize(
@ -926,7 +871,7 @@ def test_valid_base_url(value) -> None:
def test_invalid_base_url(value) -> None:
"""Test we catch bad base urls."""
with pytest.raises(vol.Invalid):
tts.valid_base_url(value)
_valid_base_url(value)
@pytest.mark.parametrize(
@ -1000,9 +945,11 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
)
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
async def test_fetching_in_async(
hass: HomeAssistant, hass_client: ClientSessionGenerator
) -> None:
"""Test async fetching of data."""
tts_audio = asyncio.Future()
tts_audio: asyncio.Future[bytes] = asyncio.Future()
class ProviderWithAsyncFetching(MockProvider):
"""Provider that supports audio output option."""
@ -1067,4 +1014,7 @@ async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
tts_audio = asyncio.Future()
tts_audio.set_result(b"test 2")
await tts.async_get_media_source_audio(hass, media_source_id) == ("mp3", b"test 2")
assert await tts.async_get_media_source_audio(hass, media_source_id) == (
"mp3",
b"test 2",
)

View File

@ -0,0 +1,121 @@
"""Test the legacy tts setup."""
from __future__ import annotations
import pytest
from homeassistant.components.tts import DOMAIN, Provider
from homeassistant.core import HomeAssistant
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from .common import MockTTS
from tests.common import (
MockModule,
assert_setup_component,
mock_integration,
mock_platform,
)
async def test_default_provider_attributes() -> None:
"""Test default provider properties."""
provider = Provider()
assert provider.hass is None
assert provider.name is None
assert provider.default_language is None
assert provider.supported_languages is None
assert provider.supported_options is None
assert provider.default_options is None
async def test_deprecated_platform(hass: HomeAssistant) -> None:
"""Test deprecated google platform."""
with assert_setup_component(0, DOMAIN):
assert await async_setup_component(
hass, DOMAIN, {DOMAIN: {"platform": "google"}}
)
async def test_invalid_platform(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test platform setup with an invalid platform."""
await async_load_platform(
hass,
"tts",
"bad_tts",
{"tts": [{"platform": "bad_tts"}]},
hass_config={"tts": [{"platform": "bad_tts"}]},
)
await hass.async_block_till_done()
assert "Unknown text to speech platform specified" in caplog.text
async def test_platform_setup_without_provider(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test platform setup without provider returned."""
class BadPlatform(MockTTS):
"""A mock TTS platform without provider."""
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None:
"""Raise exception during platform setup."""
return None
mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform())
await async_load_platform(
hass,
"tts",
"bad_tts",
{},
hass_config={"tts": [{"platform": "bad_tts"}]},
)
await hass.async_block_till_done()
assert "Error setting up platform: bad_tts" in caplog.text
async def test_platform_setup_with_error(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test platform setup with an error during setup."""
class BadPlatform(MockTTS):
"""A mock TTS platform with a setup error."""
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider:
"""Raise exception during platform setup."""
raise Exception("Setup error") # pylint: disable=broad-exception-raised
mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform())
await async_load_platform(
hass,
"tts",
"bad_tts",
{},
hass_config={"tts": [{"platform": "bad_tts"}]},
)
await hass.async_block_till_done()
assert "Error setting up platform: bad_tts" in caplog.text