diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 9c5b995b5de..3eb321eaeb6 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -1,135 +1,77 @@ """Provide functionality for TTS.""" from __future__ import annotations -from abc import ABC, abstractmethod import asyncio -from collections.abc import Mapping from datetime import datetime -import functools as ft import hashlib from http import HTTPStatus import io import logging import mimetypes import os -from pathlib import Path import re -from typing import TYPE_CHECKING, Any, TypedDict, cast +from typing import TypedDict from aiohttp import web import mutagen from mutagen.id3 import ID3, TextFrame as ID3Text import voluptuous as vol -import yarl from homeassistant.components.http import HomeAssistantView -from homeassistant.components.media_player import ( - ATTR_MEDIA_ANNOUNCE, - ATTR_MEDIA_CONTENT_ID, - ATTR_MEDIA_CONTENT_TYPE, - DOMAIN as DOMAIN_MP, - SERVICE_PLAY_MEDIA, - MediaType, -) -from homeassistant.const import ( - ATTR_ENTITY_ID, - CONF_DESCRIPTION, - CONF_NAME, - CONF_PLATFORM, - PLATFORM_FORMAT, -) +from homeassistant.const import PLATFORM_FORMAT from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import config_per_platform, discovery -import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import async_call_later from homeassistant.helpers.network import get_url -from homeassistant.helpers.service import async_set_service_schema -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from homeassistant.setup import async_prepare_setup_platform -from homeassistant.util.network import normalize_url -from homeassistant.util.yaml import load_yaml +from homeassistant.helpers.typing import ConfigType -from .const import DOMAIN +from .const import ( + ATTR_CACHE, + ATTR_LANGUAGE, + ATTR_MESSAGE, + ATTR_OPTIONS, + CONF_BASE_URL, + CONF_CACHE, + CONF_CACHE_DIR, + CONF_TIME_MEMORY, + DEFAULT_CACHE, + DEFAULT_CACHE_DIR, + DEFAULT_TIME_MEMORY, + DOMAIN, + TtsAudioType, +) +from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy from .media_source import generate_media_source_id, media_source_id_to_kwargs +__all__ = [ + "async_get_media_source_audio", + "async_resolve_engine", + "async_support_options", + "ATTR_AUDIO_OUTPUT", + "CONF_LANG", + "DEFAULT_CACHE_DIR", + "generate_media_source_id", + "get_base_url", + "PLATFORM_SCHEMA_BASE", + "PLATFORM_SCHEMA", + "Provider", + "TtsAudioType", +] + _LOGGER = logging.getLogger(__name__) -TtsAudioType = tuple[str | None, bytes | None] - -ATTR_CACHE = "cache" -ATTR_LANGUAGE = "language" -ATTR_MESSAGE = "message" -ATTR_OPTIONS = "options" ATTR_PLATFORM = "platform" ATTR_AUDIO_OUTPUT = "audio_output" +CONF_LANG = "language" + BASE_URL_KEY = "tts_base_url" -CONF_BASE_URL = "base_url" -CONF_CACHE = "cache" -CONF_CACHE_DIR = "cache_dir" -CONF_LANG = "language" -CONF_SERVICE_NAME = "service_name" -CONF_TIME_MEMORY = "time_memory" - -CONF_FIELDS = "fields" - -DEFAULT_CACHE = True -DEFAULT_CACHE_DIR = "tts" -DEFAULT_TIME_MEMORY = 300 - SERVICE_CLEAR_CACHE = "clear_cache" -SERVICE_SAY = "say" _RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}") KEY_PATTERN = "{0}_{1}_{2}_{3}" - -def _deprecated_platform(value: str) -> str: - """Validate if platform is deprecated.""" - if value == "google": - raise vol.Invalid( - "google tts service has been renamed to google_translate," - " please update your configuration." - ) - return value - - -def valid_base_url(value: str) -> str: - """Validate base url, return value.""" - url = yarl.URL(cv.url(value)) - - if url.path != "/": - raise vol.Invalid("Path should be empty") - - return normalize_url(value) - - -PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( - { - vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform), - vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean, - vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string, - vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All( - vol.Coerce(int), vol.Range(min=60, max=57600) - ), - vol.Optional(CONF_BASE_URL): valid_base_url, - vol.Optional(CONF_SERVICE_NAME): cv.string, - } -) -PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema) - -SCHEMA_SERVICE_SAY = vol.Schema( - { - vol.Required(ATTR_MESSAGE): cv.string, - vol.Optional(ATTR_CACHE): cv.boolean, - vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids, - vol.Optional(ATTR_LANGUAGE): cv.string, - vol.Optional(ATTR_OPTIONS): dict, - } -) - SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) @@ -192,22 +134,23 @@ async def async_get_media_source_audio( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up TTS.""" - tts = SpeechManager(hass) + # Legacy config options + conf = config[DOMAIN][0] if config.get(DOMAIN) else {} + use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE) + cache_dir: str = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR) + time_memory: int = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY) + base_url: str | None = conf.get(CONF_BASE_URL) + if base_url is not None: + _LOGGER.warning( + "TTS base_url option is deprecated. Configure internal/external URL" + " instead" + ) + hass.data[BASE_URL_KEY] = base_url + + tts = SpeechManager(hass, use_cache, cache_dir, time_memory, base_url) try: - conf = config[DOMAIN][0] if config.get(DOMAIN, []) else {} - use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE) - cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR) - time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY) - base_url = conf.get(CONF_BASE_URL) - if base_url is not None: - _LOGGER.warning( - "TTS base_url option is deprecated. Configure internal/external URL" - " instead" - ) - hass.data[BASE_URL_KEY] = base_url - - await tts.async_init_cache(use_cache, cache_dir, time_memory, base_url) + await tts.async_init_cache() except (HomeAssistantError, KeyError): _LOGGER.exception("Error on cache init") return False @@ -216,99 +159,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.http.register_view(TextToSpeechView(tts)) hass.http.register_view(TextToSpeechUrlView(tts)) - # Load service descriptions from tts/services.yaml - services_yaml = Path(__file__).parent / "services.yaml" - services_dict = cast( - dict, await hass.async_add_executor_job(load_yaml, str(services_yaml)) - ) + platform_setups = await async_setup_legacy(hass, config) - async def async_setup_platform( - p_type: str, - p_config: ConfigType | None = None, - discovery_info: DiscoveryInfoType | None = None, - ) -> None: - """Set up a TTS platform.""" - if p_config is None: - p_config = {} - - platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) - if platform is None: - return - - try: - if hasattr(platform, "async_get_engine"): - provider = await platform.async_get_engine( - hass, p_config, discovery_info - ) - else: - provider = await hass.async_add_executor_job( - platform.get_engine, hass, p_config, discovery_info - ) - - if provider is None: - _LOGGER.error("Error setting up platform %s", p_type) - return - - tts.async_register_engine(p_type, provider, p_config) - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Error setting up platform: %s", p_type) - return - - async def async_say_handle(service: ServiceCall) -> None: - """Service handle for say.""" - entity_ids = service.data[ATTR_ENTITY_ID] - - await hass.services.async_call( - DOMAIN_MP, - SERVICE_PLAY_MEDIA, - { - ATTR_ENTITY_ID: entity_ids, - ATTR_MEDIA_CONTENT_ID: generate_media_source_id( - hass, - engine=p_type, - message=service.data[ATTR_MESSAGE], - language=service.data.get(ATTR_LANGUAGE), - options=service.data.get(ATTR_OPTIONS), - cache=service.data.get(ATTR_CACHE), - ), - ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC, - ATTR_MEDIA_ANNOUNCE: True, - }, - blocking=True, - context=service.context, - ) - - service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}") - hass.services.async_register( - DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY - ) - - # Register the service description - service_desc = { - CONF_NAME: f"Say a TTS message with {p_type}", - CONF_DESCRIPTION: ( - f"Say something using text-to-speech on a media player with {p_type}." - ), - CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS], - } - async_set_service_schema(hass, DOMAIN, service_name, service_desc) - - setup_tasks = [ - asyncio.create_task(async_setup_platform(p_type, p_config)) - for p_type, p_config in config_per_platform(config, DOMAIN) - if p_type is not None - ] - - if setup_tasks: - await asyncio.wait(setup_tasks) - - async def async_platform_discovered( - platform: str, info: dict[str, Any] | None - ) -> None: - """Handle for discovered platform.""" - await async_setup_platform(platform, discovery_info=info) - - discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) + if platform_setups: + await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups]) async def async_clear_cache_handle(service: ServiceCall) -> None: """Handle clear cache service call.""" @@ -337,29 +191,30 @@ def _hash_options(options: dict) -> str: class SpeechManager: """Representation of a speech store.""" - def __init__(self, hass: HomeAssistant) -> None: + def __init__( + self, + hass: HomeAssistant, + use_cache: bool, + cache_dir: str, + time_memory: int, + base_url: str | None, + ) -> None: """Initialize a speech store.""" self.hass = hass self.providers: dict[str, Provider] = {} - self.use_cache = DEFAULT_CACHE - self.cache_dir = DEFAULT_CACHE_DIR - self.time_memory = DEFAULT_TIME_MEMORY - self.base_url: str | None = None + self.use_cache = use_cache + self.cache_dir = cache_dir + self.time_memory = time_memory + self.base_url = base_url self.file_cache: dict[str, str] = {} self.mem_cache: dict[str, TTSCache] = {} - async def async_init_cache( - self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None - ) -> None: + async def async_init_cache(self) -> None: """Init config folder and load file cache.""" - self.use_cache = use_cache - self.time_memory = time_memory - self.base_url = base_url - try: self.cache_dir = await self.hass.async_add_executor_job( - _init_tts_cache_dir, self.hass, cache_dir + _init_tts_cache_dir, self.hass, self.cache_dir ) except OSError as err: raise HomeAssistantError(f"Can't init cache dir {err}") from err @@ -390,7 +245,7 @@ class SpeechManager: self.file_cache = {} @callback - def async_register_engine( + def async_register_legacy_engine( self, engine: str, provider: Provider, config: ConfigType ) -> None: """Register a TTS provider.""" @@ -739,52 +594,6 @@ class SpeechManager: return data_bytes.getvalue() -class Provider(ABC): - """Represent a single TTS provider.""" - - hass: HomeAssistant | None = None - name: str | None = None - - @property - def default_language(self) -> str | None: - """Return the default language.""" - return None - - @property - @abstractmethod - def supported_languages(self) -> list[str]: - """Return a list of supported languages.""" - - @property - def supported_options(self) -> list[str] | None: - """Return a list of supported options like voice, emotions.""" - return None - - @property - def default_options(self) -> Mapping[str, Any] | None: - """Return a mapping with the default options.""" - return None - - def get_tts_audio( - self, message: str, language: str, options: dict[str, Any] | None = None - ) -> TtsAudioType: - """Load tts audio file from provider.""" - raise NotImplementedError() - - async def async_get_tts_audio( - self, message: str, language: str, options: dict[str, Any] | None = None - ) -> TtsAudioType: - """Load tts audio file from provider. - - Return a tuple of file extension and data as bytes. - """ - if TYPE_CHECKING: - assert self.hass - return await self.hass.async_add_executor_job( - ft.partial(self.get_tts_audio, message, language, options=options) - ) - - def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str: """Init cache folder.""" if not os.path.isabs(cache_dir): diff --git a/homeassistant/components/tts/const.py b/homeassistant/components/tts/const.py index 492e995b87f..ac066de48c7 100644 --- a/homeassistant/components/tts/const.py +++ b/homeassistant/components/tts/const.py @@ -1,3 +1,19 @@ """Text-to-speech constants.""" +ATTR_CACHE = "cache" +ATTR_LANGUAGE = "language" +ATTR_MESSAGE = "message" +ATTR_OPTIONS = "options" + +CONF_BASE_URL = "base_url" +CONF_CACHE = "cache" +CONF_CACHE_DIR = "cache_dir" +CONF_FIELDS = "fields" +CONF_TIME_MEMORY = "time_memory" + +DEFAULT_CACHE = True +DEFAULT_CACHE_DIR = "tts" +DEFAULT_TIME_MEMORY = 300 DOMAIN = "tts" + +TtsAudioType = tuple[str | None, bytes | None] diff --git a/homeassistant/components/tts/legacy.py b/homeassistant/components/tts/legacy.py new file mode 100644 index 00000000000..1f21d249504 --- /dev/null +++ b/homeassistant/components/tts/legacy.py @@ -0,0 +1,252 @@ +"""Provide the legacy TTS service provider interface.""" +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Coroutine, Mapping +from functools import partial +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast + +import voluptuous as vol +import yarl + +from homeassistant.components.media_player import ( + ATTR_MEDIA_ANNOUNCE, + ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, + DOMAIN as DOMAIN_MP, + SERVICE_PLAY_MEDIA, + MediaType, +) +from homeassistant.const import ( + ATTR_ENTITY_ID, + CONF_DESCRIPTION, + CONF_NAME, + CONF_PLATFORM, +) +from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.helpers import config_per_platform, discovery +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.service import async_set_service_schema +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.setup import async_prepare_setup_platform +from homeassistant.util.network import normalize_url +from homeassistant.util.yaml import load_yaml + +from .const import ( + ATTR_CACHE, + ATTR_LANGUAGE, + ATTR_MESSAGE, + ATTR_OPTIONS, + CONF_BASE_URL, + CONF_CACHE, + CONF_CACHE_DIR, + CONF_FIELDS, + CONF_TIME_MEMORY, + DEFAULT_CACHE, + DEFAULT_CACHE_DIR, + DEFAULT_TIME_MEMORY, + DOMAIN, + TtsAudioType, +) +from .media_source import generate_media_source_id + +if TYPE_CHECKING: + from . import SpeechManager + +_LOGGER = logging.getLogger(__name__) + +CONF_SERVICE_NAME = "service_name" + + +def _deprecated_platform(value: str) -> str: + """Validate if platform is deprecated.""" + if value == "google": + raise vol.Invalid( + "google tts service has been renamed to google_translate," + " please update your configuration." + ) + return value + + +def _valid_base_url(value: str) -> str: + """Validate base url, return value.""" + url = yarl.URL(cv.url(value)) + + if url.path != "/": + raise vol.Invalid("Path should be empty") + + return normalize_url(value) + + +PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( + { + vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform), + vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean, + vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string, + vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All( + vol.Coerce(int), vol.Range(min=60, max=57600) + ), + vol.Optional(CONF_BASE_URL): _valid_base_url, + vol.Optional(CONF_SERVICE_NAME): cv.string, + } +) +PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE.extend(PLATFORM_SCHEMA.schema) + +SERVICE_SAY = "say" + +SCHEMA_SERVICE_SAY = vol.Schema( + { + vol.Required(ATTR_MESSAGE): cv.string, + vol.Optional(ATTR_CACHE): cv.boolean, + vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids, + vol.Optional(ATTR_LANGUAGE): cv.string, + vol.Optional(ATTR_OPTIONS): dict, + } +) + + +async def async_setup_legacy( + hass: HomeAssistant, config: ConfigType +) -> list[Coroutine[Any, Any, None]]: + """Set up legacy text to speech providers.""" + tts: SpeechManager = hass.data[DOMAIN] + + # Load service descriptions from tts/services.yaml + services_yaml = Path(__file__).parent / "services.yaml" + services_dict = cast( + dict, await hass.async_add_executor_job(load_yaml, str(services_yaml)) + ) + + async def async_setup_platform( + p_type: str, + p_config: ConfigType | None = None, + discovery_info: DiscoveryInfoType | None = None, + ) -> None: + """Set up a TTS platform.""" + if p_config is None: + p_config = {} + + platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) + if platform is None: + _LOGGER.error("Unknown text to speech platform specified") + return + + try: + if hasattr(platform, "async_get_engine"): + provider = await platform.async_get_engine( + hass, p_config, discovery_info + ) + else: + provider = await hass.async_add_executor_job( + platform.get_engine, hass, p_config, discovery_info + ) + + if provider is None: + _LOGGER.error("Error setting up platform: %s", p_type) + return + + tts.async_register_legacy_engine(p_type, provider, p_config) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error setting up platform: %s", p_type) + return + + async def async_say_handle(service: ServiceCall) -> None: + """Service handle for say.""" + entity_ids = service.data[ATTR_ENTITY_ID] + + await hass.services.async_call( + DOMAIN_MP, + SERVICE_PLAY_MEDIA, + { + ATTR_ENTITY_ID: entity_ids, + ATTR_MEDIA_CONTENT_ID: generate_media_source_id( + hass, + engine=p_type, + message=service.data[ATTR_MESSAGE], + language=service.data.get(ATTR_LANGUAGE), + options=service.data.get(ATTR_OPTIONS), + cache=service.data.get(ATTR_CACHE), + ), + ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC, + ATTR_MEDIA_ANNOUNCE: True, + }, + blocking=True, + context=service.context, + ) + + service_name = p_config.get(CONF_SERVICE_NAME, f"{p_type}_{SERVICE_SAY}") + hass.services.async_register( + DOMAIN, service_name, async_say_handle, schema=SCHEMA_SERVICE_SAY + ) + + # Register the service description + service_desc = { + CONF_NAME: f"Say a TTS message with {p_type}", + CONF_DESCRIPTION: ( + f"Say something using text-to-speech on a media player with {p_type}." + ), + CONF_FIELDS: services_dict[SERVICE_SAY][CONF_FIELDS], + } + async_set_service_schema(hass, DOMAIN, service_name, service_desc) + + async def async_platform_discovered( + platform: str, info: dict[str, Any] | None + ) -> None: + """Handle for discovered platform.""" + await async_setup_platform(platform, discovery_info=info) + + discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) + + return [ + async_setup_platform(p_type, p_config) + for p_type, p_config in config_per_platform(config, DOMAIN) + if p_type is not None + ] + + +class Provider: + """Represent a single TTS provider.""" + + hass: HomeAssistant | None = None + name: str | None = None + + @property + def default_language(self) -> str | None: + """Return the default language.""" + return None + + @property + @abstractmethod + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + + @property + def supported_options(self) -> list[str] | None: + """Return a list of supported options like voice, emotions.""" + return None + + @property + def default_options(self) -> Mapping[str, Any] | None: + """Return a mapping with the default options.""" + return None + + def get_tts_audio( + self, message: str, language: str, options: dict[str, Any] | None = None + ) -> TtsAudioType: + """Load tts audio file from provider.""" + raise NotImplementedError() + + async def async_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] | None = None + ) -> TtsAudioType: + """Load tts audio file from provider. + + Return a tuple of file extension and data as bytes. + """ + if TYPE_CHECKING: + assert self.hass + return await self.hass.async_add_executor_job( + partial(self.get_tts_audio, message, language, options=options) + ) diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py new file mode 100644 index 00000000000..fbdfad54e18 --- /dev/null +++ b/tests/components/tts/common.py @@ -0,0 +1,77 @@ +"""Provide common tests tools for tts.""" +from __future__ import annotations + +from typing import Any + +import voluptuous as vol + +from homeassistant.components.tts import ( + CONF_LANG, + PLATFORM_SCHEMA, + Provider, + TtsAudioType, +) +from homeassistant.core import HomeAssistant +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType + +from tests.common import MockPlatform + +SUPPORT_LANGUAGES = ["de", "en", "en_US"] + +DEFAULT_LANG = "en" + + +class MockProvider(Provider): + """Test speech API provider.""" + + def __init__(self, lang: str) -> None: + """Initialize test provider.""" + self._lang = lang + self.name = "Test" + + @property + def default_language(self) -> str: + """Return the default language.""" + return self._lang + + @property + def supported_languages(self) -> list[str]: + """Return list of supported languages.""" + return SUPPORT_LANGUAGES + + @property + def supported_options(self) -> list[str]: + """Return list of supported options like voice, emotions.""" + return ["voice", "age"] + + def get_tts_audio( + self, message: str, language: str, options: dict[str, Any] | None = None + ) -> TtsAudioType: + """Load TTS dat.""" + return ("mp3", b"") + + +class MockTTS(MockPlatform): + """A mock TTS platform.""" + + PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( + {vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)} + ) + + def __init__( + self, provider: type[MockProvider] | None = None, **kwargs: Any + ) -> None: + """Initialize.""" + super().__init__(**kwargs) + if provider is None: + provider = MockProvider + self._provider = provider + + async def async_get_engine( + self, + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, + ) -> Provider | None: + """Set up a mock speech component.""" + return self._provider(config.get(CONF_LANG, DEFAULT_LANG)) diff --git a/tests/components/tts/conftest.py b/tests/components/tts/conftest.py index c251bdcb8bf..ca1416cd0a4 100644 --- a/tests/components/tts/conftest.py +++ b/tests/components/tts/conftest.py @@ -7,6 +7,12 @@ from unittest.mock import patch import pytest from homeassistant.components.tts import _get_cache_files +from homeassistant.config import async_process_ha_core_config +from homeassistant.core import HomeAssistant + +from .common import MockTTS + +from tests.common import MockModule, mock_integration, mock_platform @pytest.hookimpl(tryfirst=True, hookwrapper=True) @@ -71,3 +77,19 @@ def mutagen_mock(): side_effect=lambda *args: args[1], ) as mock_write_tags: yield mock_write_tags + + +@pytest.fixture(autouse=True) +async def internal_url_mock(hass: HomeAssistant) -> None: + """Mock internal URL of the instance.""" + await async_process_ha_core_config( + hass, + {"internal_url": "http://example.local:8123"}, + ) + + +@pytest.fixture +async def mock_tts(hass: HomeAssistant) -> None: + """Mock TTS.""" + mock_integration(hass, MockModule(domain="test")) + mock_platform(hass, "test.tts", MockTTS()) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 0a4f5a273be..cfb3b2d7efc 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -17,13 +17,14 @@ from homeassistant.components.media_player import ( MediaType, ) from homeassistant.components.media_source import Unresolvable -from homeassistant.config import async_process_ha_core_config +from homeassistant.components.tts.legacy import _valid_base_url from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_setup_component from homeassistant.util.network import normalize_url +from .common import MockProvider, MockTTS + from tests.common import ( MockModule, assert_setup_component, @@ -36,7 +37,7 @@ from tests.typing import ClientSessionGenerator ORIG_WRITE_TAGS = tts.SpeechManager.write_tags -async def get_media_source_url(hass, media_content_id): +async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str: """Get the media source url.""" if media_source.DOMAIN not in hass.config.components: assert await async_setup_component(hass, media_source.DOMAIN, {}) @@ -45,88 +46,14 @@ async def get_media_source_url(hass, media_content_id): return resolved.url -SUPPORT_LANGUAGES = ["de", "en", "en_US"] - -DEFAULT_LANG = "en" - - -class MockProvider(tts.Provider): - """Test speech API provider.""" - - def __init__(self, lang: str) -> None: - """Initialize test provider.""" - self._lang = lang - self.name = "Test" - - @property - def default_language(self) -> str: - """Return the default language.""" - return self._lang - - @property - def supported_languages(self) -> list[str]: - """Return list of supported languages.""" - return SUPPORT_LANGUAGES - - @property - def supported_options(self) -> list[str]: - """Return list of supported options like voice, emotions.""" - return ["voice", "age"] - - def get_tts_audio( - self, message: str, language: str, options: dict[str, Any] | None = None - ) -> tts.TtsAudioType: - """Load TTS dat.""" - return ("mp3", b"") - - -class MockTTS: - """A mock TTS platform.""" - - PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA.extend( - {vol.Optional(tts.CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)} - ) - - def __init__(self, provider=None) -> None: - """Initialize.""" - if provider is None: - provider = MockProvider - self._provider = provider - - async def async_get_engine( - self, - hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, - ) -> tts.Provider: - """Set up a mock speech component.""" - return self._provider(config.get(tts.CONF_LANG, DEFAULT_LANG)) - - @pytest.fixture -def test_provider(): +def mock_provider() -> MockProvider: """Test TTS provider.""" return MockProvider("en") -@pytest.fixture(autouse=True) -async def internal_url_mock(hass): - """Mock internal URL of the instance.""" - await async_process_ha_core_config( - hass, - {"internal_url": "http://example.local:8123"}, - ) - - @pytest.fixture -async def mock_tts(hass): - """Mock TTS.""" - mock_integration(hass, MockModule(domain="test")) - mock_platform(hass, "test.tts", MockTTS()) - - -@pytest.fixture -async def setup_tts(hass, mock_tts): +async def setup_tts(hass: HomeAssistant, mock_tts: None) -> None: """Mock TTS.""" assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) @@ -366,6 +293,8 @@ async def test_setup_component_and_test_with_service_options_def( config = {tts.DOMAIN: {"platform": "test"}} class MockProviderWithDefaults(MockProvider): + """Mock provider with default options.""" + @property def default_options(self): return {"voice": "alex"} @@ -413,6 +342,8 @@ async def test_setup_component_and_test_with_service_options_def_2( config = {tts.DOMAIN: {"platform": "test"}} class MockProviderWithDefaults(MockProvider): + """Mock provider with default options.""" + @property def default_options(self): return {"voice": "alex"} @@ -550,7 +481,10 @@ async def test_setup_component_and_test_service_clear_cache( async def test_setup_component_and_test_service_with_receive_voice( - hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts + hass: HomeAssistant, + mock_provider: MockProvider, + hass_client: ClientSessionGenerator, + mock_tts, ) -> None: """Set up a TTS platform and call service and receive voice.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -576,11 +510,12 @@ async def test_setup_component_and_test_service_with_receive_voice( url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) client = await hass_client() req = await client.get(url) - _, tts_data = test_provider.get_tts_audio("bla", "en") + _, tts_data = mock_provider.get_tts_audio("bla", "en") + assert tts_data is not None tts_data = tts.SpeechManager.write_tags( "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3", tts_data, - test_provider, + mock_provider, message, "en", None, @@ -596,7 +531,10 @@ async def test_setup_component_and_test_service_with_receive_voice( async def test_setup_component_and_test_service_with_receive_voice_german( - hass: HomeAssistant, test_provider, hass_client: ClientSessionGenerator, mock_tts + hass: HomeAssistant, + mock_provider: MockProvider, + hass_client: ClientSessionGenerator, + mock_tts, ) -> None: """Set up a TTS platform and call service and receive voice.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -619,11 +557,12 @@ async def test_setup_component_and_test_service_with_receive_voice_german( url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) client = await hass_client() req = await client.get(url) - _, tts_data = test_provider.get_tts_audio("bla", "de") + _, tts_data = mock_provider.get_tts_audio("bla", "de") + assert tts_data is not None tts_data = tts.SpeechManager.write_tags( "42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3", tts_data, - test_provider, + mock_provider, "There is someone at the door.", "de", None, @@ -722,12 +661,13 @@ async def test_setup_component_test_with_cache_call_service_without_cache( async def test_setup_component_test_with_cache_dir( - hass: HomeAssistant, empty_cache_dir, test_provider + hass: HomeAssistant, empty_cache_dir, mock_provider: MockProvider ) -> None: """Set up a TTS platform with cache and call service without cache.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) - _, tts_data = test_provider.get_tts_audio("bla", "en") + _, tts_data = mock_provider.get_tts_audio("bla", "en") + assert tts_data is not None cache_file = ( empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" ) @@ -738,12 +678,14 @@ async def test_setup_component_test_with_cache_dir( config = {tts.DOMAIN: {"platform": "test", "cache": True}} class MockProviderBoom(MockProvider): + """Mock provider that blows up.""" + def get_tts_audio( self, message: str, language: str, options: dict[str, Any] | None = None ) -> tts.TtsAudioType: """Load TTS dat.""" # This should not be called, data should be fetched from cache - raise Exception("Boom!") + raise Exception("Boom!") # pylint: disable=broad-exception-raised mock_integration(hass, MockModule(domain="test")) mock_platform(hass, "test.tts", MockTTS(MockProviderBoom)) @@ -775,6 +717,8 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) - config = {tts.DOMAIN: {"platform": "test"}} class MockProviderEmpty(MockProvider): + """Mock provider with empty get_tts_audio.""" + def get_tts_audio( self, message: str, language: str, options: dict[str, Any] | None = None ) -> tts.TtsAudioType: @@ -803,13 +747,14 @@ async def test_setup_component_test_with_error_on_get_tts(hass: HomeAssistant) - async def test_setup_component_load_cache_retrieve_without_mem_cache( hass: HomeAssistant, - test_provider, + mock_provider: MockProvider, empty_cache_dir, hass_client: ClientSessionGenerator, mock_tts, ) -> None: """Set up component and load cache and get without mem cache.""" - _, tts_data = test_provider.get_tts_audio("bla", "en") + _, tts_data = mock_provider.get_tts_audio("bla", "en") + assert tts_data is not None cache_file = ( empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" ) @@ -870,7 +815,7 @@ async def test_setup_component_and_web_get_url_bad_config( assert req.status == HTTPStatus.BAD_REQUEST -async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None: +async def test_tags_with_wave(hass: HomeAssistant, mock_provider: MockProvider) -> None: """Set up a TTS platform and call service and receive voice.""" # below data represents an empty wav file @@ -882,7 +827,7 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None: tagged_data = ORIG_WRITE_TAGS( "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav", tts_data, - test_provider, + mock_provider, "AI person is in front of your door.", "en", None, @@ -904,9 +849,9 @@ async def test_tags_with_wave(hass: HomeAssistant, test_provider) -> None: ) def test_valid_base_url(value) -> None: """Test we validate base urls.""" - assert tts.valid_base_url(value) == normalize_url(value) + assert _valid_base_url(value) == normalize_url(value) # Test we strip trailing `/` - assert tts.valid_base_url(value + "/") == normalize_url(value) + assert _valid_base_url(value + "/") == normalize_url(value) @pytest.mark.parametrize( @@ -926,7 +871,7 @@ def test_valid_base_url(value) -> None: def test_invalid_base_url(value) -> None: """Test we catch bad base urls.""" with pytest.raises(vol.Invalid): - tts.valid_base_url(value) + _valid_base_url(value) @pytest.mark.parametrize( @@ -1000,9 +945,11 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None: ) -async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None: +async def test_fetching_in_async( + hass: HomeAssistant, hass_client: ClientSessionGenerator +) -> None: """Test async fetching of data.""" - tts_audio = asyncio.Future() + tts_audio: asyncio.Future[bytes] = asyncio.Future() class ProviderWithAsyncFetching(MockProvider): """Provider that supports audio output option.""" @@ -1067,4 +1014,7 @@ async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None: tts_audio = asyncio.Future() tts_audio.set_result(b"test 2") - await tts.async_get_media_source_audio(hass, media_source_id) == ("mp3", b"test 2") + assert await tts.async_get_media_source_audio(hass, media_source_id) == ( + "mp3", + b"test 2", + ) diff --git a/tests/components/tts/test_legacy.py b/tests/components/tts/test_legacy.py new file mode 100644 index 00000000000..42b7159df6f --- /dev/null +++ b/tests/components/tts/test_legacy.py @@ -0,0 +1,121 @@ +"""Test the legacy tts setup.""" +from __future__ import annotations + +import pytest + +from homeassistant.components.tts import DOMAIN, Provider +from homeassistant.core import HomeAssistant +from homeassistant.helpers.discovery import async_load_platform +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.setup import async_setup_component + +from .common import MockTTS + +from tests.common import ( + MockModule, + assert_setup_component, + mock_integration, + mock_platform, +) + + +async def test_default_provider_attributes() -> None: + """Test default provider properties.""" + provider = Provider() + + assert provider.hass is None + assert provider.name is None + assert provider.default_language is None + assert provider.supported_languages is None + assert provider.supported_options is None + assert provider.default_options is None + + +async def test_deprecated_platform(hass: HomeAssistant) -> None: + """Test deprecated google platform.""" + with assert_setup_component(0, DOMAIN): + assert await async_setup_component( + hass, DOMAIN, {DOMAIN: {"platform": "google"}} + ) + + +async def test_invalid_platform( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test platform setup with an invalid platform.""" + await async_load_platform( + hass, + "tts", + "bad_tts", + {"tts": [{"platform": "bad_tts"}]}, + hass_config={"tts": [{"platform": "bad_tts"}]}, + ) + await hass.async_block_till_done() + + assert "Unknown text to speech platform specified" in caplog.text + + +async def test_platform_setup_without_provider( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test platform setup without provider returned.""" + + class BadPlatform(MockTTS): + """A mock TTS platform without provider.""" + + async def async_get_engine( + self, + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, + ) -> Provider | None: + """Raise exception during platform setup.""" + return None + + mock_integration(hass, MockModule(domain="bad_tts")) + mock_platform(hass, "bad_tts.tts", BadPlatform()) + + await async_load_platform( + hass, + "tts", + "bad_tts", + {}, + hass_config={"tts": [{"platform": "bad_tts"}]}, + ) + await hass.async_block_till_done() + + assert "Error setting up platform: bad_tts" in caplog.text + + +async def test_platform_setup_with_error( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test platform setup with an error during setup.""" + + class BadPlatform(MockTTS): + """A mock TTS platform with a setup error.""" + + async def async_get_engine( + self, + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, + ) -> Provider: + """Raise exception during platform setup.""" + raise Exception("Setup error") # pylint: disable=broad-exception-raised + + mock_integration(hass, MockModule(domain="bad_tts")) + mock_platform(hass, "bad_tts.tts", BadPlatform()) + + await async_load_platform( + hass, + "tts", + "bad_tts", + {}, + hass_config={"tts": [{"platform": "bad_tts"}]}, + ) + await hass.async_block_till_done() + + assert "Error setting up platform: bad_tts" in caplog.text