mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add tts entity (#91692)
* Add tts entity * Allow passing engine id to url view * Update async_resolve_engine * Add and update more tests * Fix assist pipeline tests temporarily * Move fixtures * Update notify platform * Complete legacy tests * Update media source tests * Update async_get_text_to_speech_languages * Address comment --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
458276a6a6
commit
1a18dc7425
@ -1,8 +1,11 @@
|
||||
"""Provide functionality for TTS."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import hashlib
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
@ -10,7 +13,7 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, final
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
@ -19,13 +22,30 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.const import PLATFORM_FORMAT
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
PLATFORM_FORMAT,
|
||||
STATE_UNAVAILABLE,
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import language as language_util
|
||||
from homeassistant.util import dt as dt_util, language as language_util
|
||||
|
||||
from .const import (
|
||||
ATTR_CACHE,
|
||||
@ -36,12 +56,14 @@ from .const import (
|
||||
CONF_CACHE,
|
||||
CONF_CACHE_DIR,
|
||||
CONF_TIME_MEMORY,
|
||||
DATA_TTS_MANAGER,
|
||||
DEFAULT_CACHE,
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_TIME_MEMORY,
|
||||
DOMAIN,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .helper import get_engine_instance
|
||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||
|
||||
@ -64,6 +86,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ATTR_PLATFORM = "platform"
|
||||
ATTR_AUDIO_OUTPUT = "audio_output"
|
||||
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
|
||||
|
||||
CONF_LANG = "language"
|
||||
|
||||
@ -71,7 +94,12 @@ BASE_URL_KEY = "tts_base_url"
|
||||
|
||||
SERVICE_CLEAR_CACHE = "clear_cache"
|
||||
|
||||
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}")
|
||||
_RE_LEGACY_VOICE_FILE = re.compile(
|
||||
r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}"
|
||||
)
|
||||
_RE_VOICE_FILE = re.compile(
|
||||
r"([a-f0-9]{40})_([^_]+)_([^_]+)_(tts\.[a-z_]+)\.[a-z0-9]{3,4}"
|
||||
)
|
||||
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||
|
||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
@ -91,20 +119,23 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
||||
|
||||
Returns None if no engines found or invalid engine passed in.
|
||||
"""
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
if engine is not None:
|
||||
if engine not in manager.providers:
|
||||
if not component.get_entity(engine) and engine not in manager.providers:
|
||||
return None
|
||||
return engine
|
||||
|
||||
if not manager.providers:
|
||||
return None
|
||||
|
||||
if "cloud" in manager.providers:
|
||||
return "cloud"
|
||||
|
||||
return next(iter(manager.providers))
|
||||
entity = next(iter(component.entities), None)
|
||||
|
||||
if entity is not None:
|
||||
return entity.entity_id
|
||||
|
||||
return next(iter(manager.providers), None)
|
||||
|
||||
|
||||
async def async_support_options(
|
||||
@ -114,9 +145,13 @@ async def async_support_options(
|
||||
options: dict | None = None,
|
||||
) -> bool:
|
||||
"""Return if an engine supports options."""
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
if (engine_instance := get_engine_instance(hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
try:
|
||||
manager.process_options(engine, language, options)
|
||||
manager.process_options(engine_instance, language, options)
|
||||
except HomeAssistantError:
|
||||
return False
|
||||
|
||||
@ -128,7 +163,7 @@ async def async_get_media_source_audio(
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Get TTS audio as extension, data."""
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
return await manager.async_get_tts_audio(
|
||||
**media_source_id_to_kwargs(media_source_id),
|
||||
)
|
||||
@ -139,7 +174,13 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
||||
"""Return a set with the union of languages supported by tts engines."""
|
||||
languages = set()
|
||||
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
for entity in component.entities:
|
||||
for language_tag in entity.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
for tts_engine in manager.providers.values():
|
||||
for language_tag in tts_engine.supported_languages:
|
||||
languages.add(language_tag)
|
||||
@ -173,7 +214,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
_LOGGER.exception("Error on cache init")
|
||||
return False
|
||||
|
||||
hass.data[DOMAIN] = tts
|
||||
hass.data[DATA_TTS_MANAGER] = tts
|
||||
component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
|
||||
component.register_shutdown()
|
||||
|
||||
hass.http.register_view(TextToSpeechView(tts))
|
||||
hass.http.register_view(TextToSpeechUrlView(tts))
|
||||
|
||||
@ -182,6 +229,18 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
if platform_setups:
|
||||
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||
|
||||
component.async_register_entity_service(
|
||||
"speak",
|
||||
{
|
||||
vol.Required(ATTR_MEDIA_PLAYER_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Required(ATTR_MESSAGE): cv.string,
|
||||
vol.Optional(ATTR_CACHE, default=DEFAULT_CACHE): cv.boolean,
|
||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||
vol.Optional(ATTR_OPTIONS): dict,
|
||||
},
|
||||
"async_speak",
|
||||
)
|
||||
|
||||
async def async_clear_cache_handle(service: ServiceCall) -> None:
|
||||
"""Handle clear cache service call."""
|
||||
await tts.async_clear_cache()
|
||||
@ -196,6 +255,129 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
|
||||
|
||||
class TextToSpeechEntity(RestoreEntity):
|
||||
"""Represent a single TTS engine."""
|
||||
|
||||
_attr_should_poll = False
|
||||
__last_tts_loaded: str | None = None
|
||||
|
||||
@property
|
||||
@final
|
||||
def state(self) -> str | None:
|
||||
"""Return the state of the entity."""
|
||||
if self.__last_tts_loaded is None:
|
||||
return None
|
||||
return self.__last_tts_loaded
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def default_language(self) -> str:
|
||||
"""Return the default language."""
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str] | None:
|
||||
"""Return a list of supported options like voice, emotions."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_options(self) -> Mapping[str, Any] | None:
|
||||
"""Return a mapping with the default options."""
|
||||
return None
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[str] | None:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return None
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the entity is added to hass."""
|
||||
await super().async_internal_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if (
|
||||
state is not None
|
||||
and state.state is not None
|
||||
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
):
|
||||
self.__last_tts_loaded = state.state
|
||||
|
||||
async def async_speak(
|
||||
self,
|
||||
media_player_entity_id: list[str],
|
||||
message: str,
|
||||
cache: bool,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> None:
|
||||
"""Speak via a Media Player."""
|
||||
await self.hass.services.async_call(
|
||||
DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
{
|
||||
ATTR_ENTITY_ID: media_player_entity_id,
|
||||
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
|
||||
self.hass,
|
||||
message=message,
|
||||
engine=self.entity_id,
|
||||
language=language,
|
||||
options=options,
|
||||
cache=cache,
|
||||
),
|
||||
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
||||
ATTR_MEDIA_ANNOUNCE: True,
|
||||
},
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
@final
|
||||
async def internal_async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Process an audio stream to TTS service.
|
||||
|
||||
Only streaming content is allowed!
|
||||
"""
|
||||
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self.async_get_tts_audio(
|
||||
message=message, language=language, options=options
|
||||
)
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
return await self.hass.async_add_executor_job(
|
||||
partial(self.get_tts_audio, message, language, options=options)
|
||||
)
|
||||
|
||||
|
||||
def _hash_options(options: dict) -> str:
|
||||
"""Hashes an options dictionary."""
|
||||
opts_hash = hashlib.blake2s(digest_size=5)
|
||||
@ -266,7 +448,7 @@ class SpeechManager:
|
||||
def async_register_legacy_engine(
|
||||
self, engine: str, provider: Provider, config: ConfigType
|
||||
) -> None:
|
||||
"""Register a TTS provider."""
|
||||
"""Register a legacy TTS engine."""
|
||||
provider.hass = self.hass
|
||||
if provider.name is None:
|
||||
provider.name = engine
|
||||
@ -279,23 +461,20 @@ class SpeechManager:
|
||||
@callback
|
||||
def process_options(
|
||||
self,
|
||||
engine: str,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> tuple[str, dict | None]:
|
||||
"""Validate and process options."""
|
||||
if (provider := self.providers.get(engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
# Languages
|
||||
language = language or provider.default_language
|
||||
language = language or engine_instance.default_language
|
||||
|
||||
if language is None or provider.supported_languages is None:
|
||||
if language is None or engine_instance.supported_languages is None:
|
||||
raise HomeAssistantError(f"Not supported language {language}")
|
||||
|
||||
if language not in provider.supported_languages:
|
||||
if language not in engine_instance.supported_languages:
|
||||
language_matches = language_util.matches(
|
||||
language, provider.supported_languages
|
||||
language, engine_instance.supported_languages
|
||||
)
|
||||
if language_matches:
|
||||
# Choose best match
|
||||
@ -304,7 +483,7 @@ class SpeechManager:
|
||||
raise HomeAssistantError(f"Not supported language {language}")
|
||||
|
||||
# Options
|
||||
if (default_options := provider.default_options) and options:
|
||||
if (default_options := engine_instance.default_options) and options:
|
||||
merged_options = dict(default_options)
|
||||
merged_options.update(options)
|
||||
options = merged_options
|
||||
@ -312,7 +491,7 @@ class SpeechManager:
|
||||
options = None if default_options is None else dict(default_options)
|
||||
|
||||
if options is not None:
|
||||
supported_options = provider.supported_options or []
|
||||
supported_options = engine_instance.supported_options or []
|
||||
invalid_opts = [
|
||||
opt_name for opt_name in options if opt_name not in supported_options
|
||||
]
|
||||
@ -333,7 +512,10 @@ class SpeechManager:
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
language, options = self.process_options(engine, language, options)
|
||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
language, options = self.process_options(engine_instance, language, options)
|
||||
cache_key = self._generate_cache_key(message, language, options, engine)
|
||||
use_cache = cache if cache is not None else self.use_cache
|
||||
|
||||
@ -344,10 +526,10 @@ class SpeechManager:
|
||||
elif use_cache and cache_key in self.file_cache:
|
||||
filename = self.file_cache[cache_key]
|
||||
self.hass.async_create_task(self._async_file_to_mem(cache_key))
|
||||
# Load speech from provider into memory
|
||||
# Load speech from engine into memory
|
||||
else:
|
||||
filename = await self._async_get_tts_audio(
|
||||
engine,
|
||||
engine_instance,
|
||||
cache_key,
|
||||
message,
|
||||
use_cache,
|
||||
@ -366,7 +548,10 @@ class SpeechManager:
|
||||
options: dict | None = None,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Fetch TTS audio."""
|
||||
language, options = self.process_options(engine, language, options)
|
||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
language, options = self.process_options(engine_instance, language, options)
|
||||
cache_key = self._generate_cache_key(message, language, options, engine)
|
||||
use_cache = cache if cache is not None else self.use_cache
|
||||
|
||||
@ -376,7 +561,7 @@ class SpeechManager:
|
||||
await self._async_file_to_mem(cache_key)
|
||||
else:
|
||||
await self._async_get_tts_audio(
|
||||
engine, cache_key, message, use_cache, language, options
|
||||
engine_instance, cache_key, message, use_cache, language, options
|
||||
)
|
||||
|
||||
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
|
||||
@ -403,7 +588,7 @@ class SpeechManager:
|
||||
|
||||
async def _async_get_tts_audio(
|
||||
self,
|
||||
engine: str,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
cache_key: str,
|
||||
message: str,
|
||||
cache: bool,
|
||||
@ -414,8 +599,6 @@ class SpeechManager:
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
provider = self.providers[engine]
|
||||
|
||||
if options is not None and ATTR_AUDIO_OUTPUT in options:
|
||||
expected_extension = options[ATTR_AUDIO_OUTPUT]
|
||||
else:
|
||||
@ -423,26 +606,38 @@ class SpeechManager:
|
||||
|
||||
async def get_tts_data() -> str:
|
||||
"""Handle data available."""
|
||||
extension, data = await provider.async_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
if engine_instance.name is None:
|
||||
raise HomeAssistantError("TTS engine name is not set.")
|
||||
|
||||
if isinstance(engine_instance, Provider):
|
||||
extension, data = await engine_instance.async_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
else:
|
||||
extension, data = await engine_instance.internal_async_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
||||
raise HomeAssistantError(
|
||||
f"No TTS from {engine_instance.name} for '{message}'"
|
||||
)
|
||||
|
||||
# Create file infos
|
||||
filename = f"{cache_key}.{extension}".lower()
|
||||
|
||||
# Validate filename
|
||||
if not _RE_VOICE_FILE.match(filename):
|
||||
if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
|
||||
filename
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"TTS filename '{filename}' from {engine} is invalid!"
|
||||
f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
|
||||
)
|
||||
|
||||
# Save to memory
|
||||
if extension == "mp3":
|
||||
data = self.write_tags(
|
||||
filename, data, provider, message, language, options
|
||||
filename, data, engine_instance.name, message, language, options
|
||||
)
|
||||
self._async_store_to_memcache(cache_key, filename, data)
|
||||
|
||||
@ -547,7 +742,9 @@ class SpeechManager:
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if not (record := _RE_VOICE_FILE.match(filename.lower())):
|
||||
if not (record := _RE_VOICE_FILE.match(filename.lower())) and not (
|
||||
record := _RE_LEGACY_VOICE_FILE.match(filename.lower())
|
||||
):
|
||||
raise HomeAssistantError("Wrong tts file format!")
|
||||
|
||||
cache_key = KEY_PATTERN.format(
|
||||
@ -570,7 +767,7 @@ class SpeechManager:
|
||||
def write_tags(
|
||||
filename: str,
|
||||
data: bytes,
|
||||
provider: Provider,
|
||||
engine_name: str,
|
||||
message: str,
|
||||
language: str,
|
||||
options: dict | None,
|
||||
@ -584,7 +781,7 @@ class SpeechManager:
|
||||
data_bytes.name = filename
|
||||
data_bytes.seek(0)
|
||||
|
||||
album = provider.name
|
||||
album = engine_name
|
||||
artist = language
|
||||
|
||||
if options is not None and (voice := options.get("voice")) is not None:
|
||||
@ -635,7 +832,9 @@ def _get_cache_files(cache_dir: str) -> dict[str, str]:
|
||||
|
||||
folder_data = os.listdir(cache_dir)
|
||||
for file_data in folder_data:
|
||||
if record := _RE_VOICE_FILE.match(file_data):
|
||||
if (record := _RE_VOICE_FILE.match(file_data)) or (
|
||||
record := _RE_LEGACY_VOICE_FILE.match(file_data)
|
||||
):
|
||||
key = KEY_PATTERN.format(
|
||||
record.group(1), record.group(2), record.group(3), record.group(4)
|
||||
)
|
||||
@ -660,12 +859,16 @@ class TextToSpeechUrlView(HomeAssistantView):
|
||||
data = await request.json()
|
||||
except ValueError:
|
||||
return self.json_message("Invalid JSON specified", HTTPStatus.BAD_REQUEST)
|
||||
if not data.get(ATTR_PLATFORM) and data.get(ATTR_MESSAGE):
|
||||
if (
|
||||
not data.get("engine_id")
|
||||
and not data.get(ATTR_PLATFORM)
|
||||
or not data.get(ATTR_MESSAGE)
|
||||
):
|
||||
return self.json_message(
|
||||
"Must specify platform and message", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
p_type = data[ATTR_PLATFORM]
|
||||
engine = data.get("engine_id") or data[ATTR_PLATFORM]
|
||||
message = data[ATTR_MESSAGE]
|
||||
cache = data.get(ATTR_CACHE)
|
||||
language = data.get(ATTR_LANGUAGE)
|
||||
@ -673,7 +876,7 @@ class TextToSpeechUrlView(HomeAssistantView):
|
||||
|
||||
try:
|
||||
path = await self.tts.async_get_url_path(
|
||||
p_type, message, cache=cache, language=language, options=options
|
||||
engine, message, cache=cache, language=language, options=options
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error on init tts: %s", err)
|
||||
@ -724,13 +927,26 @@ def websocket_list_engines(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List text to speech engines and, optionally, if they support a given language."""
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
country = msg.get("country")
|
||||
language = msg.get("language")
|
||||
providers = []
|
||||
provider_info: dict[str, Any]
|
||||
|
||||
for entity in component.entities:
|
||||
provider_info = {
|
||||
"engine_id": entity.entity_id,
|
||||
"supported_languages": entity.supported_languages,
|
||||
}
|
||||
if language:
|
||||
provider_info["supported_languages"] = language_util.matches(
|
||||
language, entity.supported_languages, country
|
||||
)
|
||||
providers.append(provider_info)
|
||||
for engine_id, provider in manager.providers.items():
|
||||
provider_info: dict[str, Any] = {
|
||||
provider_info = {
|
||||
"engine_id": engine_id,
|
||||
"supported_languages": provider.supported_languages,
|
||||
}
|
||||
@ -760,10 +976,9 @@ def websocket_list_engine_voices(
|
||||
engine_id = msg["engine_id"]
|
||||
language = msg["language"]
|
||||
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
engine = manager.providers.get(engine_id)
|
||||
engine_instance = get_engine_instance(hass, engine_id)
|
||||
|
||||
if not engine:
|
||||
if not engine_instance:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
@ -771,6 +986,6 @@ def websocket_list_engine_voices(
|
||||
)
|
||||
return
|
||||
|
||||
voices = {"voices": engine.async_get_supported_voices(language)}
|
||||
voices = {"voices": engine_instance.async_get_supported_voices(language)}
|
||||
|
||||
connection.send_message(websocket_api.result_message(msg["id"], voices))
|
||||
|
@ -16,4 +16,6 @@ DEFAULT_TIME_MEMORY = 300
|
||||
|
||||
DOMAIN = "tts"
|
||||
|
||||
DATA_TTS_MANAGER = "tts_manager"
|
||||
|
||||
TtsAudioType = tuple[str | None, bytes | None]
|
||||
|
26
homeassistant/components/tts/helper.py
Normal file
26
homeassistant/components/tts/helper.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Provide helper functions for the TTS."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DATA_TTS_MANAGER, DOMAIN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager, TextToSpeechEntity
|
||||
from .legacy import Provider
|
||||
|
||||
|
||||
def get_engine_instance(
|
||||
hass: HomeAssistant, engine: str
|
||||
) -> TextToSpeechEntity | Provider | None:
|
||||
"""Get engine instance."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
|
||||
if entity := component.get_entity(engine):
|
||||
return entity
|
||||
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
return manager.providers.get(engine)
|
@ -44,6 +44,7 @@ from .const import (
|
||||
CONF_CACHE_DIR,
|
||||
CONF_FIELDS,
|
||||
CONF_TIME_MEMORY,
|
||||
DATA_TTS_MANAGER,
|
||||
DEFAULT_CACHE,
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_TIME_MEMORY,
|
||||
@ -111,7 +112,7 @@ async def async_setup_legacy(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> list[Coroutine[Any, Any, None]]:
|
||||
"""Set up legacy text to speech providers."""
|
||||
tts: SpeechManager = hass.data[DOMAIN]
|
||||
tts: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
# Load service descriptions from tts/services.yaml
|
||||
services_yaml = Path(__file__).parent / "services.yaml"
|
||||
|
@ -17,12 +17,14 @@ from homeassistant.components.media_source import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.network import get_url
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DATA_TTS_MANAGER, DOMAIN
|
||||
from .helper import get_engine_instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager
|
||||
from . import SpeechManager, TextToSpeechEntity
|
||||
|
||||
|
||||
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
|
||||
@ -42,12 +44,16 @@ def generate_media_source_id(
|
||||
"""Generate a media source ID for text-to-speech."""
|
||||
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
|
||||
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
if (engine := async_resolve_engine(hass, engine)) is None:
|
||||
raise HomeAssistantError("Invalid TTS provider selected")
|
||||
|
||||
manager.process_options(engine, language, options)
|
||||
engine_instance = get_engine_instance(hass, engine)
|
||||
# We raise above if the engine is not resolved, so engine_instance can't be None
|
||||
assert engine_instance is not None
|
||||
|
||||
manager.process_options(engine_instance, language, options)
|
||||
params = {
|
||||
"message": message,
|
||||
}
|
||||
@ -107,7 +113,7 @@ class TTSMediaSource(MediaSource):
|
||||
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
try:
|
||||
url = await manager.async_get_url_path(
|
||||
@ -129,12 +135,15 @@ class TTSMediaSource(MediaSource):
|
||||
) -> BrowseMediaSource:
|
||||
"""Return media."""
|
||||
if item.identifier:
|
||||
provider, _, params = item.identifier.partition("?")
|
||||
return self._provider_item(provider, params)
|
||||
engine, _, params = item.identifier.partition("?")
|
||||
return self._engine_item(engine, params)
|
||||
|
||||
# Root. List providers.
|
||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||
children = [self._provider_item(provider) for provider in manager.providers]
|
||||
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
|
||||
component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN]
|
||||
children = [self._engine_item(engine) for engine in manager.providers] + [
|
||||
self._engine_item(entity.entity_id) for entity in component.entities
|
||||
]
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=None,
|
||||
@ -148,14 +157,19 @@ class TTSMediaSource(MediaSource):
|
||||
)
|
||||
|
||||
@callback
|
||||
def _provider_item(
|
||||
self, provider_domain: str, params: str | None = None
|
||||
) -> BrowseMediaSource:
|
||||
def _engine_item(self, engine: str, params: str | None = None) -> BrowseMediaSource:
|
||||
"""Return provider item."""
|
||||
manager: SpeechManager = self.hass.data[DOMAIN]
|
||||
if (provider := manager.providers.get(provider_domain)) is None:
|
||||
from . import TextToSpeechEntity # pylint: disable=import-outside-toplevel
|
||||
|
||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||
raise BrowseError("Unknown provider")
|
||||
|
||||
if isinstance(engine_instance, TextToSpeechEntity):
|
||||
assert engine_instance.platform is not None
|
||||
engine_domain = engine_instance.platform.domain
|
||||
else:
|
||||
engine_domain = engine
|
||||
|
||||
if params:
|
||||
params = f"?{params}"
|
||||
else:
|
||||
@ -163,11 +177,11 @@ class TTSMediaSource(MediaSource):
|
||||
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=f"{provider_domain}{params}",
|
||||
identifier=f"{engine}{params}",
|
||||
media_class=MediaClass.APP,
|
||||
media_content_type="provider",
|
||||
title=provider.name,
|
||||
thumbnail=f"https://brands.home-assistant.io/_/{provider_domain}/logo.png",
|
||||
title=engine_instance.name,
|
||||
thumbnail=f"https://brands.home-assistant.io/_/{engine_domain}/logo.png",
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
)
|
||||
|
@ -7,22 +7,26 @@ from typing import Any
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_ENTITY_ID, CONF_NAME
|
||||
from homeassistant.core import HomeAssistant, split_entity_id
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN
|
||||
from . import ATTR_LANGUAGE, ATTR_MEDIA_PLAYER_ENTITY_ID, ATTR_MESSAGE, DOMAIN
|
||||
|
||||
CONF_MEDIA_PLAYER = "media_player"
|
||||
CONF_TTS_SERVICE = "tts_service"
|
||||
ENTITY_LEGACY_PROVIDER_GROUP = "entity_or_legacy_provider"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_NAME): cv.string,
|
||||
vol.Required(CONF_TTS_SERVICE): cv.entity_id,
|
||||
vol.Exclusive(CONF_TTS_SERVICE, ENTITY_LEGACY_PROVIDER_GROUP): cv.entity_id,
|
||||
vol.Exclusive(CONF_ENTITY_ID, ENTITY_LEGACY_PROVIDER_GROUP): cv.entities_domain(
|
||||
DOMAIN
|
||||
),
|
||||
vol.Required(CONF_MEDIA_PLAYER): cv.entity_id,
|
||||
vol.Optional(ATTR_LANGUAGE): cv.string,
|
||||
}
|
||||
@ -44,7 +48,12 @@ class TTSNotificationService(BaseNotificationService):
|
||||
|
||||
def __init__(self, config: ConfigType) -> None:
|
||||
"""Initialize the service."""
|
||||
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
|
||||
self._target: str | None = None
|
||||
self._tts_service: str | None = None
|
||||
if entity_id := config.get(CONF_ENTITY_ID):
|
||||
self._target = entity_id
|
||||
else:
|
||||
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
|
||||
self._media_player = config[CONF_MEDIA_PLAYER]
|
||||
self._language = config.get(ATTR_LANGUAGE)
|
||||
|
||||
@ -54,13 +63,21 @@ class TTSNotificationService(BaseNotificationService):
|
||||
|
||||
data = {
|
||||
ATTR_MESSAGE: message,
|
||||
ATTR_ENTITY_ID: self._media_player,
|
||||
}
|
||||
service_name = ""
|
||||
|
||||
if self._tts_service:
|
||||
data[ATTR_ENTITY_ID] = self._media_player
|
||||
service_name = self._tts_service
|
||||
elif self._target:
|
||||
data[ATTR_ENTITY_ID] = self._target
|
||||
data[ATTR_MEDIA_PLAYER_ENTITY_ID] = self._media_player
|
||||
service_name = "speak"
|
||||
if self._language:
|
||||
data[ATTR_LANGUAGE] = self._language
|
||||
|
||||
await self.hass.services.async_call(
|
||||
DOMAIN,
|
||||
self._tts_service,
|
||||
service_name,
|
||||
data,
|
||||
)
|
||||
|
@ -40,6 +40,49 @@ say:
|
||||
selector:
|
||||
object:
|
||||
|
||||
speak:
|
||||
name: Speak
|
||||
description: Speak something using text-to-speech on a media player.
|
||||
target:
|
||||
entity:
|
||||
domain: tts
|
||||
fields:
|
||||
media_player_entity_id:
|
||||
name: Media Player Entity
|
||||
description: Name(s) of media player entities.
|
||||
required: true
|
||||
selector:
|
||||
entity:
|
||||
domain: media_player
|
||||
message:
|
||||
name: Message
|
||||
description: Text to speak on devices.
|
||||
example: "My name is hanna"
|
||||
required: true
|
||||
selector:
|
||||
text:
|
||||
cache:
|
||||
name: Cache
|
||||
description: Control file cache of this message.
|
||||
default: true
|
||||
selector:
|
||||
boolean:
|
||||
language:
|
||||
name: Language
|
||||
description: Language to use for speech generation.
|
||||
example: "ru"
|
||||
selector:
|
||||
text:
|
||||
options:
|
||||
name: Options
|
||||
description:
|
||||
A dictionary containing platform-specific options. Optional depending on
|
||||
the platform.
|
||||
advanced: true
|
||||
example: platform specific
|
||||
selector:
|
||||
object:
|
||||
|
||||
clear_cache:
|
||||
name: Clear TTS cache
|
||||
description: Remove all text-to-speech cache files and RAM cache.
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""Test fixtures for voice assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
@ -23,6 +25,7 @@ from tests.common import (
|
||||
mock_platform,
|
||||
)
|
||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||
init_cache_dir_side_effect,
|
||||
mock_get_cache_files,
|
||||
mock_init_cache_dir,
|
||||
)
|
||||
@ -33,11 +36,10 @@ _TRANSCRIPT = "test transcript"
|
||||
class BaseProvider:
|
||||
"""Mock STT provider."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, text: str) -> None:
|
||||
def __init__(self, text: str) -> None:
|
||||
"""Init test provider."""
|
||||
self.hass = hass
|
||||
self.text = text
|
||||
self.received = []
|
||||
self.received: list[bytes] = []
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
@ -115,7 +117,7 @@ class MockTTSProvider(tts.Provider):
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockTTS:
|
||||
class MockTTS(MockPlatform):
|
||||
"""A mock TTS platform."""
|
||||
|
||||
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
|
||||
@ -131,15 +133,15 @@ class MockTTS:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_stt_provider(hass) -> MockSttProvider:
|
||||
async def mock_stt_provider() -> MockSttProvider:
|
||||
"""Mock STT provider."""
|
||||
return MockSttProvider(hass, _TRANSCRIPT)
|
||||
return MockSttProvider(_TRANSCRIPT)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stt_provider_entity(hass) -> MockSttProviderEntity:
|
||||
def mock_stt_provider_entity() -> MockSttProviderEntity:
|
||||
"""Test provider entity fixture."""
|
||||
return MockSttProviderEntity(hass, _TRANSCRIPT)
|
||||
return MockSttProviderEntity(_TRANSCRIPT)
|
||||
|
||||
|
||||
class MockSttPlatform(MockPlatform):
|
||||
@ -170,8 +172,9 @@ async def init_components(
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
config_flow_fixture,
|
||||
init_cache_dir_side_effect, # noqa: F811
|
||||
mock_get_cache_files, # noqa: F811
|
||||
mock_init_cache_dir, # noqa: F811,
|
||||
mock_init_cache_dir, # noqa: F811
|
||||
):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
|
||||
|
@ -5,30 +5,50 @@ from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.tts import (
|
||||
CONF_LANG,
|
||||
DOMAIN as TTS_DOMAIN,
|
||||
PLATFORM_SCHEMA,
|
||||
Provider,
|
||||
TextToSpeechEntity,
|
||||
TtsAudioType,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockPlatform
|
||||
|
||||
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
||||
TEST_LANGUAGES = ["de", "en"]
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
DEFAULT_LANG = "en_US"
|
||||
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
||||
TEST_DOMAIN = "test"
|
||||
TEST_LANGUAGES = ["de", "en"]
|
||||
|
||||
|
||||
class MockProvider(Provider):
|
||||
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
|
||||
"""Get the media source url."""
|
||||
if media_source.DOMAIN not in hass.config.components:
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
return resolved.url
|
||||
|
||||
|
||||
class BaseProvider:
|
||||
"""Test speech API provider."""
|
||||
|
||||
def __init__(self, lang: str) -> None:
|
||||
"""Initialize test provider."""
|
||||
self._lang = lang
|
||||
self.name = "Test"
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
@ -59,6 +79,24 @@ class MockProvider(Provider):
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockProvider(BaseProvider, Provider):
|
||||
"""Test speech API provider."""
|
||||
|
||||
def __init__(self, lang: str) -> None:
|
||||
"""Initialize test provider."""
|
||||
super().__init__(lang)
|
||||
self.name = "Test"
|
||||
|
||||
|
||||
class MockTTSEntity(BaseProvider, TextToSpeechEntity):
|
||||
"""Test speech API provider."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the name of the entity."""
|
||||
return "Test"
|
||||
|
||||
|
||||
class MockTTS(MockPlatform):
|
||||
"""A mock TTS platform."""
|
||||
|
||||
@ -70,13 +108,9 @@ class MockTTS(MockPlatform):
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, provider: type[MockProvider] | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
def __init__(self, provider: MockProvider, **kwargs: Any) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__(**kwargs)
|
||||
if provider is None:
|
||||
provider = MockProvider
|
||||
self._provider = provider
|
||||
|
||||
async def async_get_engine(
|
||||
@ -86,4 +120,65 @@ class MockTTS(MockPlatform):
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> Provider | None:
|
||||
"""Set up a mock speech component."""
|
||||
return self._provider(config.get(CONF_LANG, DEFAULT_LANG))
|
||||
return self._provider
|
||||
|
||||
|
||||
async def mock_setup(
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
) -> None:
|
||||
"""Set up a test provider."""
|
||||
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", MockTTS(mock_provider))
|
||||
|
||||
await async_setup_component(
|
||||
hass, TTS_DOMAIN, {TTS_DOMAIN: {"platform": TEST_DOMAIN}}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, tts_entity: MockTTSEntity
|
||||
) -> MockConfigEntry:
|
||||
"""Set up a test tts platform via config entry."""
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setup(config_entry, TTS_DOMAIN)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, TTS_DOMAIN)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test tts platform via config entry."""
|
||||
async_add_entities([tts_entity])
|
||||
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
|
||||
|
||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
||||
|
@ -2,17 +2,28 @@
|
||||
|
||||
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.tts import _get_cache_files
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.config_entries import ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .common import MockTTS
|
||||
from .common import (
|
||||
DEFAULT_LANG,
|
||||
TEST_DOMAIN,
|
||||
MockProvider,
|
||||
MockTTS,
|
||||
MockTTSEntity,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
)
|
||||
|
||||
from tests.common import MockModule, mock_integration, mock_platform
|
||||
from tests.common import MockModule, mock_config_flow, mock_integration, mock_platform
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||
@ -37,15 +48,30 @@ def mock_get_cache_files():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_init_cache_dir():
|
||||
def mock_init_cache_dir(
|
||||
init_cache_dir_side_effect: Any,
|
||||
) -> Generator[MagicMock, None, None]:
|
||||
"""Mock the TTS cache dir in memory."""
|
||||
with patch(
|
||||
"homeassistant.components.tts._init_tts_cache_dir",
|
||||
side_effect=lambda hass, cache_dir: hass.config.path(cache_dir),
|
||||
side_effect=init_cache_dir_side_effect,
|
||||
) as mock_cache_dir:
|
||||
yield mock_cache_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_cache_dir_side_effect(
|
||||
hass: HomeAssistant,
|
||||
) -> Callable[[HomeAssistant, str], str]:
|
||||
"""Return the cache dir."""
|
||||
|
||||
def side_effect(hass: HomeAssistant, cache_dir: str) -> str:
|
||||
"""Return the cache dir."""
|
||||
return hass.config.path(cache_dir)
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
|
||||
"""Mock the TTS cache dir with empty dir."""
|
||||
@ -89,7 +115,48 @@ async def internal_url_mock(hass: HomeAssistant) -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_tts(hass: HomeAssistant) -> None:
|
||||
async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
|
||||
"""Mock TTS."""
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
mock_platform(hass, "test.tts", MockTTS())
|
||||
mock_platform(hass, "test.tts", MockTTS(mock_provider))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider() -> MockProvider:
|
||||
"""Test TTS provider."""
|
||||
return MockProvider(DEFAULT_LANG)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_entity() -> MockTTSEntity:
|
||||
"""Test TTS entity."""
|
||||
return MockTTSEntity(DEFAULT_LANG)
|
||||
|
||||
|
||||
class TTSFlow(ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, TTSFlow):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(name="setup")
|
||||
async def setup_fixture(
|
||||
hass: HomeAssistant,
|
||||
request: pytest.FixtureRequest,
|
||||
mock_provider: MockProvider,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
) -> None:
|
||||
"""Set up the test environment."""
|
||||
if request.param == "mock_setup":
|
||||
await mock_setup(hass, mock_provider)
|
||||
elif request.param == "mock_config_entry_setup":
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
else:
|
||||
raise RuntimeError("Invalid setup fixture")
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -3,32 +3,51 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.tts import DOMAIN, Provider
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.components.tts import ATTR_MESSAGE, DOMAIN, Provider
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import MockTTS
|
||||
from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS, get_media_source_url
|
||||
|
||||
from tests.common import (
|
||||
MockModule,
|
||||
assert_setup_component,
|
||||
async_mock_service,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
|
||||
class DefaultProvider(Provider):
|
||||
"""Test provider."""
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
|
||||
async def test_default_provider_attributes() -> None:
|
||||
"""Test default provider properties."""
|
||||
provider = Provider()
|
||||
"""Test default provider attributes."""
|
||||
provider = DefaultProvider()
|
||||
|
||||
assert provider.hass is None
|
||||
assert provider.name is None
|
||||
assert provider.default_language is None
|
||||
assert provider.supported_languages is None
|
||||
assert provider.supported_languages == SUPPORT_LANGUAGES
|
||||
assert provider.supported_options is None
|
||||
assert provider.default_options is None
|
||||
assert provider.async_get_supported_voices("test") is None
|
||||
|
||||
|
||||
async def test_deprecated_platform(hass: HomeAssistant) -> None:
|
||||
@ -56,8 +75,7 @@ async def test_invalid_platform(
|
||||
|
||||
|
||||
async def test_platform_setup_without_provider(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider
|
||||
) -> None:
|
||||
"""Test platform setup without provider returned."""
|
||||
|
||||
@ -74,7 +92,7 @@ async def test_platform_setup_without_provider(
|
||||
return None
|
||||
|
||||
mock_integration(hass, MockModule(domain="bad_tts"))
|
||||
mock_platform(hass, "bad_tts.tts", BadPlatform())
|
||||
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
|
||||
|
||||
await async_load_platform(
|
||||
hass,
|
||||
@ -91,6 +109,7 @@ async def test_platform_setup_without_provider(
|
||||
async def test_platform_setup_with_error(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
mock_provider: MockProvider,
|
||||
) -> None:
|
||||
"""Test platform setup with an error during setup."""
|
||||
|
||||
@ -107,7 +126,7 @@ async def test_platform_setup_with_error(
|
||||
raise Exception("Setup error") # pylint: disable=broad-exception-raised
|
||||
|
||||
mock_integration(hass, MockModule(domain="bad_tts"))
|
||||
mock_platform(hass, "bad_tts.tts", BadPlatform())
|
||||
mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
|
||||
|
||||
await async_load_platform(
|
||||
hass,
|
||||
@ -119,3 +138,58 @@ async def test_platform_setup_with_error(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "Error setting up platform: bad_tts" in caplog.text
|
||||
|
||||
|
||||
async def test_service_base_url_set(hass: HomeAssistant, mock_tts) -> None:
|
||||
"""Set up a TTS platform with ``base_url`` set and call service."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {DOMAIN: {"platform": "test", "base_url": "http://fnord"}}
|
||||
|
||||
with assert_setup_component(1, DOMAIN):
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"test_say",
|
||||
{
|
||||
ATTR_ENTITY_ID: "media_player.something",
|
||||
ATTR_MESSAGE: "There is someone at the door.",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
|
||||
assert (
|
||||
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== "http://fnord"
|
||||
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
|
||||
"_en-us_-_test.mp3"
|
||||
)
|
||||
|
||||
|
||||
async def test_service_without_cache_config(
|
||||
hass: HomeAssistant, empty_cache_dir, mock_tts
|
||||
) -> None:
|
||||
"""Set up a TTS platform without cache."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {DOMAIN: {"platform": "test", "cache": False}}
|
||||
|
||||
with assert_setup_component(1, DOMAIN):
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"test_say",
|
||||
{
|
||||
ATTR_ENTITY_ID: "media_player.something",
|
||||
ATTR_MESSAGE: "There is someone at the door.",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
await hass.async_block_till_done()
|
||||
assert not (
|
||||
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
|
||||
).is_file()
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Tests for TTS media source."""
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -8,33 +8,52 @@ from homeassistant.components.media_player.errors import BrowseError
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import (
|
||||
DEFAULT_LANG,
|
||||
MockProvider,
|
||||
MockTTSEntity,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
)
|
||||
|
||||
|
||||
class MSEntity(MockTTSEntity):
|
||||
"""Test speech API entity."""
|
||||
|
||||
get_tts_audio = MagicMock(return_value=("mp3", b""))
|
||||
|
||||
|
||||
class MSProvider(MockProvider):
|
||||
"""Test speech API provider."""
|
||||
|
||||
get_tts_audio = MagicMock(return_value=("mp3", b""))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_get_tts_audio(hass):
|
||||
async def setup_media_source(hass: HomeAssistant) -> None:
|
||||
"""Set up media source."""
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"tts",
|
||||
{
|
||||
"tts": {
|
||||
"platform": "demo",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio",
|
||||
return_value=("mp3", b""),
|
||||
) as mock_get_tts:
|
||||
yield mock_get_tts
|
||||
|
||||
|
||||
async def test_browsing(hass: HomeAssistant) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
("mock_provider", "mock_tts_entity"),
|
||||
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"setup",
|
||||
[
|
||||
"mock_setup",
|
||||
"mock_config_entry_setup",
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_browsing(hass: HomeAssistant, setup: str) -> None:
|
||||
"""Test browsing TTS media source."""
|
||||
item = await media_source.async_browse_media(hass, "media-source://tts")
|
||||
|
||||
assert item is not None
|
||||
assert item.title == "Text to Speech"
|
||||
assert item.children is not None
|
||||
assert len(item.children) == 1
|
||||
assert item.can_play is False
|
||||
assert item.can_expand is True
|
||||
@ -42,9 +61,10 @@ async def test_browsing(hass: HomeAssistant) -> None:
|
||||
item_child = await media_source.async_browse_media(
|
||||
hass, item.children[0].media_content_id
|
||||
)
|
||||
|
||||
assert item_child is not None
|
||||
assert item_child.media_content_id == item.children[0].media_content_id
|
||||
assert item_child.title == "Demo"
|
||||
assert item_child.title == "Test"
|
||||
assert item_child.children is None
|
||||
assert item_child.can_play is False
|
||||
assert item_child.can_expand is True
|
||||
@ -52,12 +72,13 @@ async def test_browsing(hass: HomeAssistant) -> None:
|
||||
item_child = await media_source.async_browse_media(
|
||||
hass, item.children[0].media_content_id + "?message=bla"
|
||||
)
|
||||
|
||||
assert item_child is not None
|
||||
assert (
|
||||
item_child.media_content_id
|
||||
== item.children[0].media_content_id + "?message=bla"
|
||||
)
|
||||
assert item_child.title == "Demo"
|
||||
assert item_child.title == "Test"
|
||||
assert item_child.children is None
|
||||
assert item_child.can_play is False
|
||||
assert item_child.can_expand is True
|
||||
@ -66,10 +87,14 @@ async def test_browsing(hass: HomeAssistant) -> None:
|
||||
await media_source.async_browse_media(hass, "media-source://tts/non-existing")
|
||||
|
||||
|
||||
async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
|
||||
"""Test resolving."""
|
||||
@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
|
||||
async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None:
|
||||
"""Test resolving legacy provider."""
|
||||
await mock_setup(hass, mock_provider)
|
||||
mock_get_tts_audio = mock_provider.get_tts_audio
|
||||
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/demo?message=Hello%20World", None
|
||||
hass, "media-source://tts/test?message=Hello%20World", None
|
||||
)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
@ -77,14 +102,14 @@ async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Hello World"
|
||||
assert language == "en"
|
||||
assert language == "en_US"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
|
||||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
media = await media_source.async_resolve_media(
|
||||
hass,
|
||||
"media-source://tts/demo?message=Bye%20World&language=de&voice=Paulus",
|
||||
"media-source://tts/test?message=Bye%20World&language=de&voice=Paulus",
|
||||
None,
|
||||
)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
@ -93,15 +118,62 @@ async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Bye World"
|
||||
assert language == "de"
|
||||
assert language == "de_DE"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
|
||||
|
||||
|
||||
async def test_resolving_errors(hass: HomeAssistant) -> None:
|
||||
@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
|
||||
async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None:
|
||||
"""Test resolving entity."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
mock_get_tts_audio = mock_tts_entity.get_tts_audio
|
||||
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/tts.test?message=Hello%20World", None
|
||||
)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Hello World"
|
||||
assert language == "en_US"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
|
||||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
media = await media_source.async_resolve_media(
|
||||
hass,
|
||||
"media-source://tts/tts.test?message=Bye%20World&language=de&voice=Paulus",
|
||||
None,
|
||||
)
|
||||
assert media.url.startswith("/api/tts_proxy/")
|
||||
assert media.mime_type == "audio/mpeg"
|
||||
|
||||
assert len(mock_get_tts_audio.mock_calls) == 1
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Bye World"
|
||||
assert language == "de_DE"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_provider", "mock_tts_entity"),
|
||||
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"setup",
|
||||
[
|
||||
"mock_setup",
|
||||
"mock_config_entry_setup",
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_resolving_errors(hass: HomeAssistant, setup: str) -> None:
|
||||
"""Test resolving."""
|
||||
# No message added
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await media_source.async_resolve_media(hass, "media-source://tts/demo", None)
|
||||
await media_source.async_resolve_media(hass, "media-source://tts/test", None)
|
||||
|
||||
# Non-existing provider
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
|
@ -1,28 +1,22 @@
|
||||
"""The tests for the TTS component."""
|
||||
import pytest
|
||||
import yarl
|
||||
|
||||
import homeassistant.components.media_player as media_player
|
||||
from homeassistant.components import media_player, notify, tts
|
||||
from homeassistant.components.media_player import (
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
import homeassistant.components.notify as notify
|
||||
import homeassistant.components.tts as tts
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import MockTTSEntity, mock_config_entry_setup
|
||||
|
||||
from tests.common import assert_setup_component, async_mock_service
|
||||
|
||||
|
||||
def relative_url(url):
|
||||
"""Convert an absolute url to a relative one."""
|
||||
return str(yarl.URL(url).relative())
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def internal_url_mock(hass):
|
||||
async def internal_url_mock(hass: HomeAssistant) -> None:
|
||||
"""Mock internal URL of the instance."""
|
||||
await async_process_ha_core_config(
|
||||
hass,
|
||||
@ -30,8 +24,8 @@ async def internal_url_mock(hass):
|
||||
)
|
||||
|
||||
|
||||
async def test_setup_platform(hass: HomeAssistant) -> None:
|
||||
"""Set up the tts platform ."""
|
||||
async def test_setup_legacy_platform(hass: HomeAssistant) -> None:
|
||||
"""Set up the tts notify platform ."""
|
||||
config = {
|
||||
notify.DOMAIN: {
|
||||
"platform": "tts",
|
||||
@ -46,7 +40,23 @@ async def test_setup_platform(hass: HomeAssistant) -> None:
|
||||
assert hass.services.has_service(notify.DOMAIN, "tts_test")
|
||||
|
||||
|
||||
async def test_setup_component_and_test_service(hass: HomeAssistant) -> None:
|
||||
async def test_setup_platform(hass: HomeAssistant) -> None:
|
||||
"""Set up the tts notify platform ."""
|
||||
config = {
|
||||
notify.DOMAIN: {
|
||||
"platform": "tts",
|
||||
"name": "tts_test",
|
||||
"entity_id": "tts.test",
|
||||
"media_player": "media_player.demo",
|
||||
}
|
||||
}
|
||||
with assert_setup_component(1, notify.DOMAIN):
|
||||
assert await async_setup_component(hass, notify.DOMAIN, config)
|
||||
|
||||
assert hass.services.has_service(notify.DOMAIN, "tts_test")
|
||||
|
||||
|
||||
async def test_setup_legacy_service(hass: HomeAssistant) -> None:
|
||||
"""Set up the demo platform and call service."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
@ -80,3 +90,38 @@ async def test_setup_component_and_test_service(hass: HomeAssistant) -> None:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_setup_service(
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
) -> None:
|
||||
"""Set up platform and call service."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
notify.DOMAIN: {
|
||||
"platform": "tts",
|
||||
"name": "tts_test",
|
||||
"entity_id": "tts.test",
|
||||
"media_player": "media_player.demo",
|
||||
"language": "en",
|
||||
},
|
||||
}
|
||||
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
with assert_setup_component(1, notify.DOMAIN):
|
||||
assert await async_setup_component(hass, notify.DOMAIN, config)
|
||||
|
||||
await hass.services.async_call(
|
||||
notify.DOMAIN,
|
||||
"tts_test",
|
||||
{
|
||||
tts.ATTR_MESSAGE: "There is someone at the door.",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user