Add tts entity (#91692)

* Add tts entity

* Allow passing engine id to url view

* Update async_resolve_engine

* Add and update more tests

* Fix assist pipeline tests temporarily

* Move fixtures

* Update notify platform

* Complete legacy tests

* Update media source tests

* Update async_get_text_to_speech_languages

* Address comment

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Martin Hjelmare 2023-04-21 04:55:46 +02:00 committed by GitHub
parent 458276a6a6
commit 1a18dc7425
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1925 additions and 644 deletions

View File

@ -1,8 +1,11 @@
"""Provide functionality for TTS."""
from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Mapping
from datetime import datetime
from functools import partial
import hashlib
from http import HTTPStatus
import io
@ -10,7 +13,7 @@ import logging
import mimetypes
import os
import re
from typing import Any, TypedDict
from typing import Any, TypedDict, final
from aiohttp import web
import mutagen
@ -19,13 +22,30 @@ import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import PLATFORM_FORMAT
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_ENTITY_ID,
PLATFORM_FORMAT,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import language as language_util
from homeassistant.util import dt as dt_util, language as language_util
from .const import (
ATTR_CACHE,
@ -36,12 +56,14 @@ from .const import (
CONF_CACHE,
CONF_CACHE_DIR,
CONF_TIME_MEMORY,
DATA_TTS_MANAGER,
DEFAULT_CACHE,
DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY,
DOMAIN,
TtsAudioType,
)
from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs
@ -64,6 +86,7 @@ _LOGGER = logging.getLogger(__name__)
ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
CONF_LANG = "language"
@ -71,7 +94,12 @@ BASE_URL_KEY = "tts_base_url"
SERVICE_CLEAR_CACHE = "clear_cache"
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
_RE_LEGACY_VOICE_FILE = re.compile(
r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}"
)
_RE_VOICE_FILE = re.compile(
r"([a-f0-9]{40})_([^_]+)_([^_]+)_(tts\.[a-z_]+)\.[a-z0-9]{3,4}"
)
KEY_PATTERN = "{0}_{1}_{2}_{3}"
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
@ -91,20 +119,23 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
Returns None if no engines found or invalid engine passed in.
"""
manager: SpeechManager = hass.data[DOMAIN]
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if engine is not None:
if engine not in manager.providers:
if not component.get_entity(engine) and engine not in manager.providers:
return None
return engine
if not manager.providers:
return None
if "cloud" in manager.providers:
return "cloud"
return next(iter(manager.providers))
entity = next(iter(component.entities), None)
if entity is not None:
return entity.entity_id
return next(iter(manager.providers), None)
async def async_support_options(
@ -114,9 +145,13 @@ async def async_support_options(
options: dict | None = None,
) -> bool:
"""Return if an engine supports options."""
manager: SpeechManager = hass.data[DOMAIN]
if (engine_instance := get_engine_instance(hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
try:
manager.process_options(engine, language, options)
manager.process_options(engine_instance, language, options)
except HomeAssistantError:
return False
@ -128,7 +163,7 @@ async def async_get_media_source_audio(
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager: SpeechManager = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
return await manager.async_get_tts_audio(
**media_source_id_to_kwargs(media_source_id),
)
@ -139,7 +174,13 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by tts engines."""
languages = set()
manager: SpeechManager = hass.data[DOMAIN]
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
for entity in component.entities:
for language_tag in entity.supported_languages:
languages.add(language_tag)
for tts_engine in manager.providers.values():
for language_tag in tts_engine.supported_languages:
languages.add(language_tag)
@ -173,7 +214,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER.exception("Error on cache init")
return False
hass.data[DOMAIN] = tts
hass.data[DATA_TTS_MANAGER] = tts
component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity](
_LOGGER, DOMAIN, hass
)
component.register_shutdown()
hass.http.register_view(TextToSpeechView(tts))
hass.http.register_view(TextToSpeechUrlView(tts))
@ -182,6 +229,18 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if platform_setups:
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
component.async_register_entity_service(
"speak",
{
vol.Required(ATTR_MEDIA_PLAYER_ENTITY_ID): cv.comp_entity_ids,
vol.Required(ATTR_MESSAGE): cv.string,
vol.Optional(ATTR_CACHE, default=DEFAULT_CACHE): cv.boolean,
vol.Optional(ATTR_LANGUAGE): cv.string,
vol.Optional(ATTR_OPTIONS): dict,
},
"async_speak",
)
async def async_clear_cache_handle(service: ServiceCall) -> None:
"""Handle clear cache service call."""
await tts.async_clear_cache()
@ -196,6 +255,129 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
class TextToSpeechEntity(RestoreEntity):
"""Represent a single TTS engine."""
_attr_should_poll = False
__last_tts_loaded: str | None = None
@property
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_tts_loaded is None:
return None
return self.__last_tts_loaded
@property
@abstractmethod
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@property
@abstractmethod
def default_language(self) -> str:
"""Return the default language."""
@property
def supported_options(self) -> list[str] | None:
"""Return a list of supported options like voice, emotions."""
return None
@property
def default_options(self) -> Mapping[str, Any] | None:
"""Return a mapping with the default options."""
return None
@callback
def async_get_supported_voices(self, language: str) -> list[str] | None:
"""Return a list of supported voices for a language."""
return None
async def async_internal_added_to_hass(self) -> None:
"""Call when the entity is added to hass."""
await super().async_internal_added_to_hass()
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_tts_loaded = state.state
async def async_speak(
self,
media_player_entity_id: list[str],
message: str,
cache: bool,
language: str | None = None,
options: dict | None = None,
) -> None:
"""Speak via a Media Player."""
await self.hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: media_player_entity_id,
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
self.hass,
message=message,
engine=self.entity_id,
language=language,
options=options,
cache=cache,
),
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
context=self._context,
)
@final
async def internal_async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Process an audio stream to TTS service.
Only streaming content is allowed!
"""
self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_get_tts_audio(
message=message, language=language, options=options
)
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from the engine."""
raise NotImplementedError()
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from the engine.
Return a tuple of file extension and data as bytes.
"""
return await self.hass.async_add_executor_job(
partial(self.get_tts_audio, message, language, options=options)
)
def _hash_options(options: dict) -> str:
"""Hashes an options dictionary."""
opts_hash = hashlib.blake2s(digest_size=5)
@ -266,7 +448,7 @@ class SpeechManager:
def async_register_legacy_engine(
self, engine: str, provider: Provider, config: ConfigType
) -> None:
"""Register a TTS provider."""
"""Register a legacy TTS engine."""
provider.hass = self.hass
if provider.name is None:
provider.name = engine
@ -279,23 +461,20 @@ class SpeechManager:
@callback
def process_options(
self,
engine: str,
engine_instance: TextToSpeechEntity | Provider,
language: str | None = None,
options: dict | None = None,
) -> tuple[str, dict | None]:
"""Validate and process options."""
if (provider := self.providers.get(engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
# Languages
language = language or provider.default_language
language = language or engine_instance.default_language
if language is None or provider.supported_languages is None:
if language is None or engine_instance.supported_languages is None:
raise HomeAssistantError(f"Not supported language {language}")
if language not in provider.supported_languages:
if language not in engine_instance.supported_languages:
language_matches = language_util.matches(
language, provider.supported_languages
language, engine_instance.supported_languages
)
if language_matches:
# Choose best match
@ -304,7 +483,7 @@ class SpeechManager:
raise HomeAssistantError(f"Not supported language {language}")
# Options
if (default_options := provider.default_options) and options:
if (default_options := engine_instance.default_options) and options:
merged_options = dict(default_options)
merged_options.update(options)
options = merged_options
@ -312,7 +491,7 @@ class SpeechManager:
options = None if default_options is None else dict(default_options)
if options is not None:
supported_options = provider.supported_options or []
supported_options = engine_instance.supported_options or []
invalid_opts = [
opt_name for opt_name in options if opt_name not in supported_options
]
@ -333,7 +512,10 @@ class SpeechManager:
This method is a coroutine.
"""
language, options = self.process_options(engine, language, options)
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache
@ -344,10 +526,10 @@ class SpeechManager:
elif use_cache and cache_key in self.file_cache:
filename = self.file_cache[cache_key]
self.hass.async_create_task(self._async_file_to_mem(cache_key))
# Load speech from provider into memory
# Load speech from engine into memory
else:
filename = await self._async_get_tts_audio(
engine,
engine_instance,
cache_key,
message,
use_cache,
@ -366,7 +548,10 @@ class SpeechManager:
options: dict | None = None,
) -> tuple[str, bytes]:
"""Fetch TTS audio."""
language, options = self.process_options(engine, language, options)
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache
@ -376,7 +561,7 @@ class SpeechManager:
await self._async_file_to_mem(cache_key)
else:
await self._async_get_tts_audio(
engine, cache_key, message, use_cache, language, options
engine_instance, cache_key, message, use_cache, language, options
)
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
@ -403,7 +588,7 @@ class SpeechManager:
async def _async_get_tts_audio(
self,
engine: str,
engine_instance: TextToSpeechEntity | Provider,
cache_key: str,
message: str,
cache: bool,
@ -414,8 +599,6 @@ class SpeechManager:
This method is a coroutine.
"""
provider = self.providers[engine]
if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT]
else:
@ -423,26 +606,38 @@ class SpeechManager:
async def get_tts_data() -> str:
"""Handle data available."""
extension, data = await provider.async_get_tts_audio(
message, language, options
)
if engine_instance.name is None:
raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider):
extension, data = await engine_instance.async_get_tts_audio(
message, language, options
)
else:
extension, data = await engine_instance.internal_async_get_tts_audio(
message, language, options
)
if data is None or extension is None:
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
raise HomeAssistantError(
f"No TTS from {engine_instance.name} for '{message}'"
)
# Create file infos
filename = f"{cache_key}.{extension}".lower()
# Validate filename
if not _RE_VOICE_FILE.match(filename):
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
filename
):
raise HomeAssistantError(
f"TTS filename '{filename}' from {engine} is invalid!"
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
)
# Save to memory
if extension == "mp3":
data = self.write_tags(
filename, data, provider, message, language, options
filename, data, engine_instance.name, message, language, options
)
self._async_store_to_memcache(cache_key, filename, data)
@ -547,7 +742,9 @@ class SpeechManager:
This method is a coroutine.
"""
if not (record := _RE_VOICE_FILE.match(filename.lower())):
if not (record := _RE_VOICE_FILE.match(filename.lower())) and not (
record := _RE_LEGACY_VOICE_FILE.match(filename.lower())
):
raise HomeAssistantError("Wrong tts file format!")
cache_key = KEY_PATTERN.format(
@ -570,7 +767,7 @@ class SpeechManager:
def write_tags(
filename: str,
data: bytes,
provider: Provider,
engine_name: str,
message: str,
language: str,
options: dict | None,
@ -584,7 +781,7 @@ class SpeechManager:
data_bytes.name = filename
data_bytes.seek(0)
album = provider.name
album = engine_name
artist = language
if options is not None and (voice := options.get("voice")) is not None:
@ -635,7 +832,9 @@ def _get_cache_files(cache_dir: str) -> dict[str, str]:
folder_data = os.listdir(cache_dir)
for file_data in folder_data:
if record := _RE_VOICE_FILE.match(file_data):
if (record := _RE_VOICE_FILE.match(file_data)) or (
record := _RE_LEGACY_VOICE_FILE.match(file_data)
):
key = KEY_PATTERN.format(
record.group(1), record.group(2), record.group(3), record.group(4)
)
@ -660,12 +859,16 @@ class TextToSpeechUrlView(HomeAssistantView):
data = await request.json()
except ValueError:
return self.json_message("Invalid JSON specified", HTTPStatus.BAD_REQUEST)
if not data.get(ATTR_PLATFORM) and data.get(ATTR_MESSAGE):
if (
not data.get("engine_id")
and not data.get(ATTR_PLATFORM)
or not data.get(ATTR_MESSAGE)
):
return self.json_message(
"Must specify platform and message", HTTPStatus.BAD_REQUEST
)
p_type = data[ATTR_PLATFORM]
engine = data.get("engine_id") or data[ATTR_PLATFORM]
message = data[ATTR_MESSAGE]
cache = data.get(ATTR_CACHE)
language = data.get(ATTR_LANGUAGE)
@ -673,7 +876,7 @@ class TextToSpeechUrlView(HomeAssistantView):
try:
path = await self.tts.async_get_url_path(
p_type, message, cache=cache, language=language, options=options
engine, message, cache=cache, language=language, options=options
)
except HomeAssistantError as err:
_LOGGER.error("Error on init tts: %s", err)
@ -724,13 +927,26 @@ def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List text to speech engines and, optionally, if they support a given language."""
manager: SpeechManager = hass.data[DOMAIN]
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
country = msg.get("country")
language = msg.get("language")
providers = []
provider_info: dict[str, Any]
for entity in component.entities:
provider_info = {
"engine_id": entity.entity_id,
"supported_languages": entity.supported_languages,
}
if language:
provider_info["supported_languages"] = language_util.matches(
language, entity.supported_languages, country
)
providers.append(provider_info)
for engine_id, provider in manager.providers.items():
provider_info: dict[str, Any] = {
provider_info = {
"engine_id": engine_id,
"supported_languages": provider.supported_languages,
}
@ -760,10 +976,9 @@ def websocket_list_engine_voices(
engine_id = msg["engine_id"]
language = msg["language"]
manager: SpeechManager = hass.data[DOMAIN]
engine = manager.providers.get(engine_id)
engine_instance = get_engine_instance(hass, engine_id)
if not engine:
if not engine_instance:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
@ -771,6 +986,6 @@ def websocket_list_engine_voices(
)
return
voices = {"voices": engine.async_get_supported_voices(language)}
voices = {"voices": engine_instance.async_get_supported_voices(language)}
connection.send_message(websocket_api.result_message(msg["id"], voices))

View File

@ -16,4 +16,6 @@ DEFAULT_TIME_MEMORY = 300
DOMAIN = "tts"
DATA_TTS_MANAGER = "tts_manager"
TtsAudioType = tuple[str | None, bytes | None]

View File

@ -0,0 +1,26 @@
"""Provide helper functions for the TTS."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_component import EntityComponent
from .const import DATA_TTS_MANAGER, DOMAIN
if TYPE_CHECKING:
from . import SpeechManager, TextToSpeechEntity
from .legacy import Provider
def get_engine_instance(
hass: HomeAssistant, engine: str
) -> TextToSpeechEntity | Provider | None:
"""Get engine instance."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
if entity := component.get_entity(engine):
return entity
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
return manager.providers.get(engine)

View File

@ -44,6 +44,7 @@ from .const import (
CONF_CACHE_DIR,
CONF_FIELDS,
CONF_TIME_MEMORY,
DATA_TTS_MANAGER,
DEFAULT_CACHE,
DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY,
@ -111,7 +112,7 @@ async def async_setup_legacy(
hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
"""Set up legacy text to speech providers."""
tts: SpeechManager = hass.data[DOMAIN]
tts: SpeechManager = hass.data[DATA_TTS_MANAGER]
# Load service descriptions from tts/services.yaml
services_yaml = Path(__file__).parent / "services.yaml"

View File

@ -17,12 +17,14 @@ from homeassistant.components.media_source import (
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.network import get_url
from .const import DOMAIN
from .const import DATA_TTS_MANAGER, DOMAIN
from .helper import get_engine_instance
if TYPE_CHECKING:
from . import SpeechManager
from . import SpeechManager, TextToSpeechEntity
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
@ -42,12 +44,16 @@ def generate_media_source_id(
"""Generate a media source ID for text-to-speech."""
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
manager: SpeechManager = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if (engine := async_resolve_engine(hass, engine)) is None:
raise HomeAssistantError("Invalid TTS provider selected")
manager.process_options(engine, language, options)
engine_instance = get_engine_instance(hass, engine)
# We raise above if the engine is not resolved, so engine_instance can't be None
assert engine_instance is not None
manager.process_options(engine_instance, language, options)
params = {
"message": message,
}
@ -107,7 +113,7 @@ class TTSMediaSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
manager: SpeechManager = self.hass.data[DOMAIN]
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
try:
url = await manager.async_get_url_path(
@ -129,12 +135,15 @@ class TTSMediaSource(MediaSource):
) -> BrowseMediaSource:
"""Return media."""
if item.identifier:
provider, _, params = item.identifier.partition("?")
return self._provider_item(provider, params)
engine, _, params = item.identifier.partition("?")
return self._engine_item(engine, params)
# Root. List providers.
manager: SpeechManager = self.hass.data[DOMAIN]
children = [self._provider_item(provider) for provider in manager.providers]
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN]
children = [self._engine_item(engine) for engine in manager.providers] + [
self._engine_item(entity.entity_id) for entity in component.entities
]
return BrowseMediaSource(
domain=DOMAIN,
identifier=None,
@ -148,14 +157,19 @@ class TTSMediaSource(MediaSource):
)
@callback
def _provider_item(
self, provider_domain: str, params: str | None = None
) -> BrowseMediaSource:
def _engine_item(self, engine: str, params: str | None = None) -> BrowseMediaSource:
"""Return provider item."""
manager: SpeechManager = self.hass.data[DOMAIN]
if (provider := manager.providers.get(provider_domain)) is None:
from . import TextToSpeechEntity # pylint: disable=import-outside-toplevel
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise BrowseError("Unknown provider")
if isinstance(engine_instance, TextToSpeechEntity):
assert engine_instance.platform is not None
engine_domain = engine_instance.platform.domain
else:
engine_domain = engine
if params:
params = f"?{params}"
else:
@ -163,11 +177,11 @@ class TTSMediaSource(MediaSource):
return BrowseMediaSource(
domain=DOMAIN,
identifier=f"{provider_domain}{params}",
identifier=f"{engine}{params}",
media_class=MediaClass.APP,
media_content_type="provider",
title=provider.name,
thumbnail=f"https://brands.home-assistant.io/_/{provider_domain}/logo.png",
title=engine_instance.name,
thumbnail=f"https://brands.home-assistant.io/_/{engine_domain}/logo.png",
can_play=False,
can_expand=True,
)

View File

@ -7,22 +7,26 @@ from typing import Any
import voluptuous as vol
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME
from homeassistant.const import ATTR_ENTITY_ID, CONF_ENTITY_ID, CONF_NAME
from homeassistant.core import HomeAssistant, split_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN
from . import ATTR_LANGUAGE, ATTR_MEDIA_PLAYER_ENTITY_ID, ATTR_MESSAGE, DOMAIN
CONF_MEDIA_PLAYER = "media_player"
CONF_TTS_SERVICE = "tts_service"
ENTITY_LEGACY_PROVIDER_GROUP = "entity_or_legacy_provider"
_LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_NAME): cv.string,
vol.Required(CONF_TTS_SERVICE): cv.entity_id,
vol.Exclusive(CONF_TTS_SERVICE, ENTITY_LEGACY_PROVIDER_GROUP): cv.entity_id,
vol.Exclusive(CONF_ENTITY_ID, ENTITY_LEGACY_PROVIDER_GROUP): cv.entities_domain(
DOMAIN
),
vol.Required(CONF_MEDIA_PLAYER): cv.entity_id,
vol.Optional(ATTR_LANGUAGE): cv.string,
}
@ -44,7 +48,12 @@ class TTSNotificationService(BaseNotificationService):
def __init__(self, config: ConfigType) -> None:
"""Initialize the service."""
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
self._target: str | None = None
self._tts_service: str | None = None
if entity_id := config.get(CONF_ENTITY_ID):
self._target = entity_id
else:
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
self._media_player = config[CONF_MEDIA_PLAYER]
self._language = config.get(ATTR_LANGUAGE)
@ -54,13 +63,21 @@ class TTSNotificationService(BaseNotificationService):
data = {
ATTR_MESSAGE: message,
ATTR_ENTITY_ID: self._media_player,
}
service_name = ""
if self._tts_service:
data[ATTR_ENTITY_ID] = self._media_player
service_name = self._tts_service
elif self._target:
data[ATTR_ENTITY_ID] = self._target
data[ATTR_MEDIA_PLAYER_ENTITY_ID] = self._media_player
service_name = "speak"
if self._language:
data[ATTR_LANGUAGE] = self._language
await self.hass.services.async_call(
DOMAIN,
self._tts_service,
service_name,
data,
)

View File

@ -40,6 +40,49 @@ say:
selector:
object:
speak:
name: Speak
description: Speak something using text-to-speech on a media player.
target:
entity:
domain: tts
fields:
media_player_entity_id:
name: Media Player Entity
description: Name(s) of media player entities.
required: true
selector:
entity:
domain: media_player
message:
name: Message
description: Text to speak on devices.
example: "My name is hanna"
required: true
selector:
text:
cache:
name: Cache
description: Control file cache of this message.
default: true
selector:
boolean:
language:
name: Language
description: Language to use for speech generation.
example: "ru"
selector:
text:
options:
name: Options
description:
A dictionary containing platform-specific options. Optional depending on
the platform.
advanced: true
example: platform specific
selector:
object:
clear_cache:
name: Clear TTS cache
description: Remove all text-to-speech cache files and RAM cache.

View File

@ -1,4 +1,6 @@
"""Test fixtures for voice assistant."""
from __future__ import annotations
from collections.abc import AsyncIterable, Generator
from typing import Any
from unittest.mock import AsyncMock
@ -23,6 +25,7 @@ from tests.common import (
mock_platform,
)
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
init_cache_dir_side_effect,
mock_get_cache_files,
mock_init_cache_dir,
)
@ -33,11 +36,10 @@ _TRANSCRIPT = "test transcript"
class BaseProvider:
"""Mock STT provider."""
def __init__(self, hass: HomeAssistant, text: str) -> None:
def __init__(self, text: str) -> None:
"""Init test provider."""
self.hass = hass
self.text = text
self.received = []
self.received: list[bytes] = []
@property
def supported_languages(self) -> list[str]:
@ -115,7 +117,7 @@ class MockTTSProvider(tts.Provider):
return ("mp3", b"")
class MockTTS:
class MockTTS(MockPlatform):
"""A mock TTS platform."""
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
@ -131,15 +133,15 @@ class MockTTS:
@pytest.fixture
async def mock_stt_provider(hass) -> MockSttProvider:
async def mock_stt_provider() -> MockSttProvider:
"""Mock STT provider."""
return MockSttProvider(hass, _TRANSCRIPT)
return MockSttProvider(_TRANSCRIPT)
@pytest.fixture
def mock_stt_provider_entity(hass) -> MockSttProviderEntity:
def mock_stt_provider_entity() -> MockSttProviderEntity:
"""Test provider entity fixture."""
return MockSttProviderEntity(hass, _TRANSCRIPT)
return MockSttProviderEntity(_TRANSCRIPT)
class MockSttPlatform(MockPlatform):
@ -170,8 +172,9 @@ async def init_components(
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
config_flow_fixture,
init_cache_dir_side_effect, # noqa: F811
mock_get_cache_files, # noqa: F811
mock_init_cache_dir, # noqa: F811,
mock_init_cache_dir, # noqa: F811
):
"""Initialize relevant components with empty configs."""

View File

@ -5,30 +5,50 @@ from typing import Any
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.tts import (
CONF_LANG,
DOMAIN as TTS_DOMAIN,
PLATFORM_SCHEMA,
Provider,
TextToSpeechEntity,
TtsAudioType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockPlatform
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_LANGUAGES = ["de", "en"]
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_integration,
mock_platform,
)
DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test"
TEST_LANGUAGES = ["de", "en"]
class MockProvider(Provider):
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
class BaseProvider:
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
self._lang = lang
self.name = "Test"
@property
def default_language(self) -> str:
@ -59,6 +79,24 @@ class MockProvider(Provider):
return ("mp3", b"")
class MockProvider(BaseProvider, Provider):
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
super().__init__(lang)
self.name = "Test"
class MockTTSEntity(BaseProvider, TextToSpeechEntity):
"""Test speech API provider."""
@property
def name(self) -> str:
"""Return the name of the entity."""
return "Test"
class MockTTS(MockPlatform):
"""A mock TTS platform."""
@ -70,13 +108,9 @@ class MockTTS(MockPlatform):
}
)
def __init__(
self, provider: type[MockProvider] | None = None, **kwargs: Any
) -> None:
def __init__(self, provider: MockProvider, **kwargs: Any) -> None:
"""Initialize."""
super().__init__(**kwargs)
if provider is None:
provider = MockProvider
self._provider = provider
async def async_get_engine(
@ -86,4 +120,65 @@ class MockTTS(MockPlatform):
discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None:
"""Set up a mock speech component."""
return self._provider(config.get(CONF_LANG, DEFAULT_LANG))
return self._provider
async def mock_setup(
hass: HomeAssistant,
mock_provider: MockProvider,
) -> None:
"""Set up a test provider."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", MockTTS(mock_provider))
await async_setup_component(
hass, TTS_DOMAIN, {TTS_DOMAIN: {"platform": TEST_DOMAIN}}
)
await hass.async_block_till_done()
async def mock_config_entry_setup(
hass: HomeAssistant, tts_entity: MockTTSEntity
) -> MockConfigEntry:
"""Set up a test tts platform via config entry."""
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(config_entry, TTS_DOMAIN)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, TTS_DOMAIN)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test tts platform via config entry."""
async_add_entities([tts_entity])
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry

View File

@ -2,17 +2,28 @@
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
"""
from unittest.mock import patch
from collections.abc import Callable, Generator
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.components.tts import _get_cache_files
from homeassistant.config import async_process_ha_core_config
from homeassistant.config_entries import ConfigFlow
from homeassistant.core import HomeAssistant
from .common import MockTTS
from .common import (
DEFAULT_LANG,
TEST_DOMAIN,
MockProvider,
MockTTS,
MockTTSEntity,
mock_config_entry_setup,
mock_setup,
)
from tests.common import MockModule, mock_integration, mock_platform
from tests.common import MockModule, mock_config_flow, mock_integration, mock_platform
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@ -37,15 +48,30 @@ def mock_get_cache_files():
@pytest.fixture(autouse=True)
def mock_init_cache_dir():
def mock_init_cache_dir(
init_cache_dir_side_effect: Any,
) -> Generator[MagicMock, None, None]:
"""Mock the TTS cache dir in memory."""
with patch(
"homeassistant.components.tts._init_tts_cache_dir",
side_effect=lambda hass, cache_dir: hass.config.path(cache_dir),
side_effect=init_cache_dir_side_effect,
) as mock_cache_dir:
yield mock_cache_dir
@pytest.fixture
def init_cache_dir_side_effect(
hass: HomeAssistant,
) -> Callable[[HomeAssistant, str], str]:
"""Return the cache dir."""
def side_effect(hass: HomeAssistant, cache_dir: str) -> str:
"""Return the cache dir."""
return hass.config.path(cache_dir)
return side_effect
@pytest.fixture
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
"""Mock the TTS cache dir with empty dir."""
@ -89,7 +115,48 @@ async def internal_url_mock(hass: HomeAssistant) -> None:
@pytest.fixture
async def mock_tts(hass: HomeAssistant) -> None:
async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
"""Mock TTS."""
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS())
mock_platform(hass, "test.tts", MockTTS(mock_provider))
@pytest.fixture
def mock_provider() -> MockProvider:
"""Test TTS provider."""
return MockProvider(DEFAULT_LANG)
@pytest.fixture
def mock_tts_entity() -> MockTTSEntity:
"""Test TTS entity."""
return MockTTSEntity(DEFAULT_LANG)
class TTSFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
with mock_config_flow(TEST_DOMAIN, TTSFlow):
yield
@pytest.fixture(name="setup")
async def setup_fixture(
hass: HomeAssistant,
request: pytest.FixtureRequest,
mock_provider: MockProvider,
mock_tts_entity: MockTTSEntity,
) -> None:
"""Set up the test environment."""
if request.param == "mock_setup":
await mock_setup(hass, mock_provider)
elif request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, mock_tts_entity)
else:
raise RuntimeError("Invalid setup fixture")

File diff suppressed because it is too large Load Diff

View File

@ -3,32 +3,51 @@ from __future__ import annotations
import pytest
from homeassistant.components.tts import DOMAIN, Provider
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.components.tts import ATTR_MESSAGE, DOMAIN, Provider
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from .common import MockTTS
from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS, get_media_source_url
from tests.common import (
MockModule,
assert_setup_component,
async_mock_service,
mock_integration,
mock_platform,
)
class DefaultProvider(Provider):
"""Test provider."""
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return SUPPORT_LANGUAGES
async def test_default_provider_attributes() -> None:
"""Test default provider properties."""
provider = Provider()
"""Test default provider attributes."""
provider = DefaultProvider()
assert provider.hass is None
assert provider.name is None
assert provider.default_language is None
assert provider.supported_languages is None
assert provider.supported_languages == SUPPORT_LANGUAGES
assert provider.supported_options is None
assert provider.default_options is None
assert provider.async_get_supported_voices("test") is None
async def test_deprecated_platform(hass: HomeAssistant) -> None:
@ -56,8 +75,7 @@ async def test_invalid_platform(
async def test_platform_setup_without_provider(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider
) -> None:
"""Test platform setup without provider returned."""
@ -74,7 +92,7 @@ async def test_platform_setup_without_provider(
return None
mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform())
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
await async_load_platform(
hass,
@ -91,6 +109,7 @@ async def test_platform_setup_without_provider(
async def test_platform_setup_with_error(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_provider: MockProvider,
) -> None:
"""Test platform setup with an error during setup."""
@ -107,7 +126,7 @@ async def test_platform_setup_with_error(
raise Exception("Setup error") # pylint: disable=broad-exception-raised
mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform())
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
await async_load_platform(
hass,
@ -119,3 +138,58 @@ async def test_platform_setup_with_error(
await hass.async_block_till_done()
assert "Error setting up platform: bad_tts" in caplog.text
async def test_service_base_url_set(hass: HomeAssistant, mock_tts) -> None:
"""Set up a TTS platform with ``base_url`` set and call service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
config = {DOMAIN: {"platform": "test", "base_url": "http://fnord"}}
with assert_setup_component(1, DOMAIN):
assert await async_setup_component(hass, DOMAIN, config)
await hass.services.async_call(
DOMAIN,
"test_say",
{
ATTR_ENTITY_ID: "media_player.something",
ATTR_MESSAGE: "There is someone at the door.",
},
blocking=True,
)
assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== "http://fnord"
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
"_en-us_-_test.mp3"
)
async def test_service_without_cache_config(
hass: HomeAssistant, empty_cache_dir, mock_tts
) -> None:
"""Set up a TTS platform without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
config = {DOMAIN: {"platform": "test", "cache": False}}
with assert_setup_component(1, DOMAIN):
assert await async_setup_component(hass, DOMAIN, config)
await hass.services.async_call(
DOMAIN,
"test_say",
{
ATTR_ENTITY_ID: "media_player.something",
ATTR_MESSAGE: "There is someone at the door.",
},
blocking=True,
)
assert len(calls) == 1
await hass.async_block_till_done()
assert not (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
).is_file()

View File

@ -1,5 +1,5 @@
"""Tests for TTS media source."""
from unittest.mock import patch
from unittest.mock import MagicMock
import pytest
@ -8,33 +8,52 @@ from homeassistant.components.media_player.errors import BrowseError
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .common import (
DEFAULT_LANG,
MockProvider,
MockTTSEntity,
mock_config_entry_setup,
mock_setup,
)
class MSEntity(MockTTSEntity):
"""Test speech API entity."""
get_tts_audio = MagicMock(return_value=("mp3", b""))
class MSProvider(MockProvider):
"""Test speech API provider."""
get_tts_audio = MagicMock(return_value=("mp3", b""))
@pytest.fixture(autouse=True)
async def mock_get_tts_audio(hass):
async def setup_media_source(hass: HomeAssistant) -> None:
"""Set up media source."""
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(
hass,
"tts",
{
"tts": {
"platform": "demo",
}
},
)
with patch(
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
return_value=("mp3", b""),
) as mock_get_tts:
yield mock_get_tts
async def test_browsing(hass: HomeAssistant) -> None:
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
"setup",
[
"mock_setup",
"mock_config_entry_setup",
],
indirect=["setup"],
)
async def test_browsing(hass: HomeAssistant, setup: str) -> None:
"""Test browsing TTS media source."""
item = await media_source.async_browse_media(hass, "media-source://tts")
assert item is not None
assert item.title == "Text to Speech"
assert item.children is not None
assert len(item.children) == 1
assert item.can_play is False
assert item.can_expand is True
@ -42,9 +61,10 @@ async def test_browsing(hass: HomeAssistant) -> None:
item_child = await media_source.async_browse_media(
hass, item.children[0].media_content_id
)
assert item_child is not None
assert item_child.media_content_id == item.children[0].media_content_id
assert item_child.title == "Demo"
assert item_child.title == "Test"
assert item_child.children is None
assert item_child.can_play is False
assert item_child.can_expand is True
@ -52,12 +72,13 @@ async def test_browsing(hass: HomeAssistant) -> None:
item_child = await media_source.async_browse_media(
hass, item.children[0].media_content_id + "?message=bla"
)
assert item_child is not None
assert (
item_child.media_content_id
== item.children[0].media_content_id + "?message=bla"
)
assert item_child.title == "Demo"
assert item_child.title == "Test"
assert item_child.children is None
assert item_child.can_play is False
assert item_child.can_expand is True
@ -66,10 +87,14 @@ async def test_browsing(hass: HomeAssistant) -> None:
await media_source.async_browse_media(hass, "media-source://tts/non-existing")
async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
"""Test resolving."""
@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None:
"""Test resolving legacy provider."""
await mock_setup(hass, mock_provider)
mock_get_tts_audio = mock_provider.get_tts_audio
media = await media_source.async_resolve_media(
hass, "media-source://tts/demo?message=Hello%20World", None
hass, "media-source://tts/test?message=Hello%20World", None
)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
@ -77,14 +102,14 @@ async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World"
assert language == "en"
assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
# Pass language and options
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass,
"media-source://tts/demo?message=Bye%20World&language=de&voice=Paulus",
"media-source://tts/test?message=Bye%20World&language=de&voice=Paulus",
None,
)
assert media.url.startswith("/api/tts_proxy/")
@ -93,15 +118,62 @@ async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Bye World"
assert language == "de"
assert language == "de_DE"
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
async def test_resolving_errors(hass: HomeAssistant) -> None:
@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None:
"""Test resolving entity."""
await mock_config_entry_setup(hass, mock_tts_entity)
mock_get_tts_audio = mock_tts_entity.get_tts_audio
media = await media_source.async_resolve_media(
hass, "media-source://tts/tts.test?message=Hello%20World", None
)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World"
assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
# Pass language and options
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass,
"media-source://tts/tts.test?message=Bye%20World&language=de&voice=Paulus",
None,
)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Bye World"
assert language == "de_DE"
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
"setup",
[
"mock_setup",
"mock_config_entry_setup",
],
indirect=["setup"],
)
async def test_resolving_errors(hass: HomeAssistant, setup: str) -> None:
"""Test resolving."""
# No message added
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(hass, "media-source://tts/demo", None)
await media_source.async_resolve_media(hass, "media-source://tts/test", None)
# Non-existing provider
with pytest.raises(media_source.Unresolvable):

View File

@ -1,28 +1,22 @@
"""The tests for the TTS component."""
import pytest
import yarl
import homeassistant.components.media_player as media_player
from homeassistant.components import media_player, notify, tts
from homeassistant.components.media_player import (
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
)
import homeassistant.components.notify as notify
import homeassistant.components.tts as tts
from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .common import MockTTSEntity, mock_config_entry_setup
from tests.common import assert_setup_component, async_mock_service
def relative_url(url):
"""Convert an absolute url to a relative one."""
return str(yarl.URL(url).relative())
@pytest.fixture(autouse=True)
async def internal_url_mock(hass):
async def internal_url_mock(hass: HomeAssistant) -> None:
"""Mock internal URL of the instance."""
await async_process_ha_core_config(
hass,
@ -30,8 +24,8 @@ async def internal_url_mock(hass):
)
async def test_setup_platform(hass: HomeAssistant) -> None:
"""Set up the tts platform ."""
async def test_setup_legacy_platform(hass: HomeAssistant) -> None:
"""Set up the tts notify platform ."""
config = {
notify.DOMAIN: {
"platform": "tts",
@ -46,7 +40,23 @@ async def test_setup_platform(hass: HomeAssistant) -> None:
assert hass.services.has_service(notify.DOMAIN, "tts_test")
async def test_setup_component_and_test_service(hass: HomeAssistant) -> None:
async def test_setup_platform(hass: HomeAssistant) -> None:
"""Set up the tts notify platform ."""
config = {
notify.DOMAIN: {
"platform": "tts",
"name": "tts_test",
"entity_id": "tts.test",
"media_player": "media_player.demo",
}
}
with assert_setup_component(1, notify.DOMAIN):
assert await async_setup_component(hass, notify.DOMAIN, config)
assert hass.services.has_service(notify.DOMAIN, "tts_test")
async def test_setup_legacy_service(hass: HomeAssistant) -> None:
"""Set up the demo platform and call service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -80,3 +90,38 @@ async def test_setup_component_and_test_service(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
assert len(calls) == 1
async def test_setup_service(
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
"""Set up platform and call service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
config = {
notify.DOMAIN: {
"platform": "tts",
"name": "tts_test",
"entity_id": "tts.test",
"media_player": "media_player.demo",
"language": "en",
},
}
await mock_config_entry_setup(hass, mock_tts_entity)
with assert_setup_component(1, notify.DOMAIN):
assert await async_setup_component(hass, notify.DOMAIN, config)
await hass.services.async_call(
notify.DOMAIN,
"tts_test",
{
tts.ATTR_MESSAGE: "There is someone at the door.",
},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 1