Add tts entity (#91692)

* Add tts entity

* Allow passing engine id to url view

* Update async_resolve_engine

* Add and update more tests

* Fix assist pipeline tests temporarily

* Move fixtures

* Update notify platform

* Complete legacy tests

* Update media source tests

* Update async_get_text_to_speech_languages

* Address comment

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Martin Hjelmare 2023-04-21 04:55:46 +02:00 committed by GitHub
parent 458276a6a6
commit 1a18dc7425
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1925 additions and 644 deletions

View File

@ -1,8 +1,11 @@
"""Provide functionality for TTS.""" """Provide functionality for TTS."""
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from functools import partial
import hashlib import hashlib
from http import HTTPStatus from http import HTTPStatus
import io import io
@ -10,7 +13,7 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
from typing import Any, TypedDict from typing import Any, TypedDict, final
from aiohttp import web from aiohttp import web
import mutagen import mutagen
@ -19,13 +22,30 @@ import voluptuous as vol
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import PLATFORM_FORMAT from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_ENTITY_ID,
PLATFORM_FORMAT,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url from homeassistant.helpers.network import get_url
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import language as language_util from homeassistant.util import dt as dt_util, language as language_util
from .const import ( from .const import (
ATTR_CACHE, ATTR_CACHE,
@ -36,12 +56,14 @@ from .const import (
CONF_CACHE, CONF_CACHE,
CONF_CACHE_DIR, CONF_CACHE_DIR,
CONF_TIME_MEMORY, CONF_TIME_MEMORY,
DATA_TTS_MANAGER,
DEFAULT_CACHE, DEFAULT_CACHE,
DEFAULT_CACHE_DIR, DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY, DEFAULT_TIME_MEMORY,
DOMAIN, DOMAIN,
TtsAudioType, TtsAudioType,
) )
from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs from .media_source import generate_media_source_id, media_source_id_to_kwargs
@ -64,6 +86,7 @@ _LOGGER = logging.getLogger(__name__)
ATTR_PLATFORM = "platform" ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output" ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
CONF_LANG = "language" CONF_LANG = "language"
@ -71,7 +94,12 @@ BASE_URL_KEY = "tts_base_url"
SERVICE_CLEAR_CACHE = "clear_cache" SERVICE_CLEAR_CACHE = "clear_cache"
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}") _RE_LEGACY_VOICE_FILE = re.compile(
r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]{3,4}"
)
_RE_VOICE_FILE = re.compile(
r"([a-f0-9]{40})_([^_]+)_([^_]+)_(tts\.[a-z_]+)\.[a-z0-9]{3,4}"
)
KEY_PATTERN = "{0}_{1}_{2}_{3}" KEY_PATTERN = "{0}_{1}_{2}_{3}"
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
@ -91,20 +119,23 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
Returns None if no engines found or invalid engine passed in. Returns None if no engines found or invalid engine passed in.
""" """
manager: SpeechManager = hass.data[DOMAIN] component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if engine is not None: if engine is not None:
if engine not in manager.providers: if not component.get_entity(engine) and engine not in manager.providers:
return None return None
return engine return engine
if not manager.providers:
return None
if "cloud" in manager.providers: if "cloud" in manager.providers:
return "cloud" return "cloud"
return next(iter(manager.providers)) entity = next(iter(component.entities), None)
if entity is not None:
return entity.entity_id
return next(iter(manager.providers), None)
async def async_support_options( async def async_support_options(
@ -114,9 +145,13 @@ async def async_support_options(
options: dict | None = None, options: dict | None = None,
) -> bool: ) -> bool:
"""Return if an engine supports options.""" """Return if an engine supports options."""
manager: SpeechManager = hass.data[DOMAIN] if (engine_instance := get_engine_instance(hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
try: try:
manager.process_options(engine, language, options) manager.process_options(engine_instance, language, options)
except HomeAssistantError: except HomeAssistantError:
return False return False
@ -128,7 +163,7 @@ async def async_get_media_source_audio(
media_source_id: str, media_source_id: str,
) -> tuple[str, bytes]: ) -> tuple[str, bytes]:
"""Get TTS audio as extension, data.""" """Get TTS audio as extension, data."""
manager: SpeechManager = hass.data[DOMAIN] manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
return await manager.async_get_tts_audio( return await manager.async_get_tts_audio(
**media_source_id_to_kwargs(media_source_id), **media_source_id_to_kwargs(media_source_id),
) )
@ -139,7 +174,13 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by tts engines.""" """Return a set with the union of languages supported by tts engines."""
languages = set() languages = set()
manager: SpeechManager = hass.data[DOMAIN] component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
for entity in component.entities:
for language_tag in entity.supported_languages:
languages.add(language_tag)
for tts_engine in manager.providers.values(): for tts_engine in manager.providers.values():
for language_tag in tts_engine.supported_languages: for language_tag in tts_engine.supported_languages:
languages.add(language_tag) languages.add(language_tag)
@ -173,7 +214,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER.exception("Error on cache init") _LOGGER.exception("Error on cache init")
return False return False
hass.data[DOMAIN] = tts hass.data[DATA_TTS_MANAGER] = tts
component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity](
_LOGGER, DOMAIN, hass
)
component.register_shutdown()
hass.http.register_view(TextToSpeechView(tts)) hass.http.register_view(TextToSpeechView(tts))
hass.http.register_view(TextToSpeechUrlView(tts)) hass.http.register_view(TextToSpeechUrlView(tts))
@ -182,6 +229,18 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if platform_setups: if platform_setups:
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups]) await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
component.async_register_entity_service(
"speak",
{
vol.Required(ATTR_MEDIA_PLAYER_ENTITY_ID): cv.comp_entity_ids,
vol.Required(ATTR_MESSAGE): cv.string,
vol.Optional(ATTR_CACHE, default=DEFAULT_CACHE): cv.boolean,
vol.Optional(ATTR_LANGUAGE): cv.string,
vol.Optional(ATTR_OPTIONS): dict,
},
"async_speak",
)
async def async_clear_cache_handle(service: ServiceCall) -> None: async def async_clear_cache_handle(service: ServiceCall) -> None:
"""Handle clear cache service call.""" """Handle clear cache service call."""
await tts.async_clear_cache() await tts.async_clear_cache()
@ -196,6 +255,129 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up a config entry by delegating to the TTS entity component.

    The component is read from ``hass.data[DOMAIN]`` and its result is
    returned unchanged.
    """
    tts_component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
    setup_ok = await tts_component.async_setup_entry(entry)
    return setup_ok
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a config entry by delegating to the TTS entity component.

    The component is read from ``hass.data[DOMAIN]`` and its result is
    returned unchanged.
    """
    tts_component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
    unload_ok = await tts_component.async_unload_entry(entry)
    return unload_ok
class TextToSpeechEntity(RestoreEntity):
    """Base entity for one text-to-speech engine.

    The entity state is the ISO-formatted ``dt_util.utcnow()`` value recorded
    the last time audio was requested through
    ``internal_async_get_tts_audio``. It is restored from the last stored
    state when the entity is added to hass.
    """

    _attr_should_poll = False
    # The double underscore triggers Python name mangling, which keeps this
    # attribute private to this class.
    __last_tts_loaded: str | None = None

    @property
    @final
    def state(self) -> str | None:
        """Return the time of the last audio request, or None if unset."""
        return self.__last_tts_loaded

    @property
    @abstractmethod
    def supported_languages(self) -> list[str]:
        """Return the languages this engine supports."""

    @property
    @abstractmethod
    def default_language(self) -> str:
        """Return the language used when the caller passes none."""

    @property
    def supported_options(self) -> list[str] | None:
        """Return the names of supported options (e.g. voice, emotions).

        The base implementation returns None.
        """
        return None

    @property
    def default_options(self) -> Mapping[str, Any] | None:
        """Return a mapping of default options.

        The base implementation returns None.
        """
        return None

    @callback
    def async_get_supported_voices(self, language: str) -> list[str] | None:
        """Return the voices available for ``language``.

        The base implementation returns None.
        """
        return None

    async def async_internal_added_to_hass(self) -> None:
        """Restore the previously stored state when the entity is added."""
        await super().async_internal_added_to_hass()
        last_state = await self.async_get_last_state()
        # Nothing usable to restore.
        if last_state is None or last_state.state is None:
            return
        # Placeholder states carry no timestamp, so they are not restored.
        if last_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            return
        self.__last_tts_loaded = last_state.state

    async def async_speak(
        self,
        media_player_entity_id: list[str],
        message: str,
        cache: bool,
        language: str | None = None,
        options: dict | None = None,
    ) -> None:
        """Speak a message via media players.

        Generates a TTS media source ID for this entity and passes it to the
        media player domain's play-media service as an announcement. The
        service call is blocking and uses this entity's context.
        """
        media_content_id = generate_media_source_id(
            self.hass,
            message=message,
            engine=self.entity_id,
            language=language,
            options=options,
            cache=cache,
        )
        service_data = {
            ATTR_ENTITY_ID: media_player_entity_id,
            ATTR_MEDIA_CONTENT_ID: media_content_id,
            ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
            ATTR_MEDIA_ANNOUNCE: True,
        }
        await self.hass.services.async_call(
            DOMAIN_MP,
            SERVICE_PLAY_MEDIA,
            service_data,
            blocking=True,
            context=self._context,
        )

    @final
    async def internal_async_get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> TtsAudioType:
        """Record the request time, then load TTS audio from the engine.

        The state is updated and written before the audio is fetched. The
        result of ``async_get_tts_audio`` is returned unchanged.
        """
        self.__last_tts_loaded = dt_util.utcnow().isoformat()
        self.async_write_ha_state()
        audio = await self.async_get_tts_audio(
            message=message, language=language, options=options
        )
        return audio

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> TtsAudioType:
        """Load TTS audio from the engine (synchronous variant).

        The base implementation raises NotImplementedError.
        """
        raise NotImplementedError()

    async def async_get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> TtsAudioType:
        """Load TTS audio from the engine.

        The default implementation runs ``get_tts_audio`` in the executor.

        Return a tuple of file extension and data as bytes.
        """
        load_audio = partial(self.get_tts_audio, message, language, options=options)
        return await self.hass.async_add_executor_job(load_audio)
def _hash_options(options: dict) -> str: def _hash_options(options: dict) -> str:
"""Hashes an options dictionary.""" """Hashes an options dictionary."""
opts_hash = hashlib.blake2s(digest_size=5) opts_hash = hashlib.blake2s(digest_size=5)
@ -266,7 +448,7 @@ class SpeechManager:
def async_register_legacy_engine( def async_register_legacy_engine(
self, engine: str, provider: Provider, config: ConfigType self, engine: str, provider: Provider, config: ConfigType
) -> None: ) -> None:
"""Register a TTS provider.""" """Register a legacy TTS engine."""
provider.hass = self.hass provider.hass = self.hass
if provider.name is None: if provider.name is None:
provider.name = engine provider.name = engine
@ -279,23 +461,20 @@ class SpeechManager:
@callback @callback
def process_options( def process_options(
self, self,
engine: str, engine_instance: TextToSpeechEntity | Provider,
language: str | None = None, language: str | None = None,
options: dict | None = None, options: dict | None = None,
) -> tuple[str, dict | None]: ) -> tuple[str, dict | None]:
"""Validate and process options.""" """Validate and process options."""
if (provider := self.providers.get(engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
# Languages # Languages
language = language or provider.default_language language = language or engine_instance.default_language
if language is None or provider.supported_languages is None: if language is None or engine_instance.supported_languages is None:
raise HomeAssistantError(f"Not supported language {language}") raise HomeAssistantError(f"Not supported language {language}")
if language not in provider.supported_languages: if language not in engine_instance.supported_languages:
language_matches = language_util.matches( language_matches = language_util.matches(
language, provider.supported_languages language, engine_instance.supported_languages
) )
if language_matches: if language_matches:
# Choose best match # Choose best match
@ -304,7 +483,7 @@ class SpeechManager:
raise HomeAssistantError(f"Not supported language {language}") raise HomeAssistantError(f"Not supported language {language}")
# Options # Options
if (default_options := provider.default_options) and options: if (default_options := engine_instance.default_options) and options:
merged_options = dict(default_options) merged_options = dict(default_options)
merged_options.update(options) merged_options.update(options)
options = merged_options options = merged_options
@ -312,7 +491,7 @@ class SpeechManager:
options = None if default_options is None else dict(default_options) options = None if default_options is None else dict(default_options)
if options is not None: if options is not None:
supported_options = provider.supported_options or [] supported_options = engine_instance.supported_options or []
invalid_opts = [ invalid_opts = [
opt_name for opt_name in options if opt_name not in supported_options opt_name for opt_name in options if opt_name not in supported_options
] ]
@ -333,7 +512,10 @@ class SpeechManager:
This method is a coroutine. This method is a coroutine.
""" """
language, options = self.process_options(engine, language, options) if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
cache_key = self._generate_cache_key(message, language, options, engine) cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache use_cache = cache if cache is not None else self.use_cache
@ -344,10 +526,10 @@ class SpeechManager:
elif use_cache and cache_key in self.file_cache: elif use_cache and cache_key in self.file_cache:
filename = self.file_cache[cache_key] filename = self.file_cache[cache_key]
self.hass.async_create_task(self._async_file_to_mem(cache_key)) self.hass.async_create_task(self._async_file_to_mem(cache_key))
# Load speech from provider into memory # Load speech from engine into memory
else: else:
filename = await self._async_get_tts_audio( filename = await self._async_get_tts_audio(
engine, engine_instance,
cache_key, cache_key,
message, message,
use_cache, use_cache,
@ -366,7 +548,10 @@ class SpeechManager:
options: dict | None = None, options: dict | None = None,
) -> tuple[str, bytes]: ) -> tuple[str, bytes]:
"""Fetch TTS audio.""" """Fetch TTS audio."""
language, options = self.process_options(engine, language, options) if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
language, options = self.process_options(engine_instance, language, options)
cache_key = self._generate_cache_key(message, language, options, engine) cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache use_cache = cache if cache is not None else self.use_cache
@ -376,7 +561,7 @@ class SpeechManager:
await self._async_file_to_mem(cache_key) await self._async_file_to_mem(cache_key)
else: else:
await self._async_get_tts_audio( await self._async_get_tts_audio(
engine, cache_key, message, use_cache, language, options engine_instance, cache_key, message, use_cache, language, options
) )
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:] extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
@ -403,7 +588,7 @@ class SpeechManager:
async def _async_get_tts_audio( async def _async_get_tts_audio(
self, self,
engine: str, engine_instance: TextToSpeechEntity | Provider,
cache_key: str, cache_key: str,
message: str, message: str,
cache: bool, cache: bool,
@ -414,8 +599,6 @@ class SpeechManager:
This method is a coroutine. This method is a coroutine.
""" """
provider = self.providers[engine]
if options is not None and ATTR_AUDIO_OUTPUT in options: if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT] expected_extension = options[ATTR_AUDIO_OUTPUT]
else: else:
@ -423,26 +606,38 @@ class SpeechManager:
async def get_tts_data() -> str: async def get_tts_data() -> str:
"""Handle data available.""" """Handle data available."""
extension, data = await provider.async_get_tts_audio( if engine_instance.name is None:
raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider):
extension, data = await engine_instance.async_get_tts_audio(
message, language, options
)
else:
extension, data = await engine_instance.internal_async_get_tts_audio(
message, language, options message, language, options
) )
if data is None or extension is None: if data is None or extension is None:
raise HomeAssistantError(f"No TTS from {engine} for '{message}'") raise HomeAssistantError(
f"No TTS from {engine_instance.name} for '{message}'"
)
# Create file infos # Create file infos
filename = f"{cache_key}.{extension}".lower() filename = f"{cache_key}.{extension}".lower()
# Validate filename # Validate filename
if not _RE_VOICE_FILE.match(filename): if not _RE_VOICE_FILE.match(filename) and not _RE_LEGACY_VOICE_FILE.match(
filename
):
raise HomeAssistantError( raise HomeAssistantError(
f"TTS filename '{filename}' from {engine} is invalid!" f"TTS filename '{filename}' from {engine_instance.name} is invalid!"
) )
# Save to memory # Save to memory
if extension == "mp3": if extension == "mp3":
data = self.write_tags( data = self.write_tags(
filename, data, provider, message, language, options filename, data, engine_instance.name, message, language, options
) )
self._async_store_to_memcache(cache_key, filename, data) self._async_store_to_memcache(cache_key, filename, data)
@ -547,7 +742,9 @@ class SpeechManager:
This method is a coroutine. This method is a coroutine.
""" """
if not (record := _RE_VOICE_FILE.match(filename.lower())): if not (record := _RE_VOICE_FILE.match(filename.lower())) and not (
record := _RE_LEGACY_VOICE_FILE.match(filename.lower())
):
raise HomeAssistantError("Wrong tts file format!") raise HomeAssistantError("Wrong tts file format!")
cache_key = KEY_PATTERN.format( cache_key = KEY_PATTERN.format(
@ -570,7 +767,7 @@ class SpeechManager:
def write_tags( def write_tags(
filename: str, filename: str,
data: bytes, data: bytes,
provider: Provider, engine_name: str,
message: str, message: str,
language: str, language: str,
options: dict | None, options: dict | None,
@ -584,7 +781,7 @@ class SpeechManager:
data_bytes.name = filename data_bytes.name = filename
data_bytes.seek(0) data_bytes.seek(0)
album = provider.name album = engine_name
artist = language artist = language
if options is not None and (voice := options.get("voice")) is not None: if options is not None and (voice := options.get("voice")) is not None:
@ -635,7 +832,9 @@ def _get_cache_files(cache_dir: str) -> dict[str, str]:
folder_data = os.listdir(cache_dir) folder_data = os.listdir(cache_dir)
for file_data in folder_data: for file_data in folder_data:
if record := _RE_VOICE_FILE.match(file_data): if (record := _RE_VOICE_FILE.match(file_data)) or (
record := _RE_LEGACY_VOICE_FILE.match(file_data)
):
key = KEY_PATTERN.format( key = KEY_PATTERN.format(
record.group(1), record.group(2), record.group(3), record.group(4) record.group(1), record.group(2), record.group(3), record.group(4)
) )
@ -660,12 +859,16 @@ class TextToSpeechUrlView(HomeAssistantView):
data = await request.json() data = await request.json()
except ValueError: except ValueError:
return self.json_message("Invalid JSON specified", HTTPStatus.BAD_REQUEST) return self.json_message("Invalid JSON specified", HTTPStatus.BAD_REQUEST)
if not data.get(ATTR_PLATFORM) and data.get(ATTR_MESSAGE): if (
not data.get("engine_id")
and not data.get(ATTR_PLATFORM)
or not data.get(ATTR_MESSAGE)
):
return self.json_message( return self.json_message(
"Must specify platform and message", HTTPStatus.BAD_REQUEST "Must specify platform and message", HTTPStatus.BAD_REQUEST
) )
p_type = data[ATTR_PLATFORM] engine = data.get("engine_id") or data[ATTR_PLATFORM]
message = data[ATTR_MESSAGE] message = data[ATTR_MESSAGE]
cache = data.get(ATTR_CACHE) cache = data.get(ATTR_CACHE)
language = data.get(ATTR_LANGUAGE) language = data.get(ATTR_LANGUAGE)
@ -673,7 +876,7 @@ class TextToSpeechUrlView(HomeAssistantView):
try: try:
path = await self.tts.async_get_url_path( path = await self.tts.async_get_url_path(
p_type, message, cache=cache, language=language, options=options engine, message, cache=cache, language=language, options=options
) )
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.error("Error on init tts: %s", err) _LOGGER.error("Error on init tts: %s", err)
@ -724,13 +927,26 @@ def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None: ) -> None:
"""List text to speech engines and, optionally, if they support a given language.""" """List text to speech engines and, optionally, if they support a given language."""
manager: SpeechManager = hass.data[DOMAIN] component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
country = msg.get("country") country = msg.get("country")
language = msg.get("language") language = msg.get("language")
providers = [] providers = []
provider_info: dict[str, Any]
for entity in component.entities:
provider_info = {
"engine_id": entity.entity_id,
"supported_languages": entity.supported_languages,
}
if language:
provider_info["supported_languages"] = language_util.matches(
language, entity.supported_languages, country
)
providers.append(provider_info)
for engine_id, provider in manager.providers.items(): for engine_id, provider in manager.providers.items():
provider_info: dict[str, Any] = { provider_info = {
"engine_id": engine_id, "engine_id": engine_id,
"supported_languages": provider.supported_languages, "supported_languages": provider.supported_languages,
} }
@ -760,10 +976,9 @@ def websocket_list_engine_voices(
engine_id = msg["engine_id"] engine_id = msg["engine_id"]
language = msg["language"] language = msg["language"]
manager: SpeechManager = hass.data[DOMAIN] engine_instance = get_engine_instance(hass, engine_id)
engine = manager.providers.get(engine_id)
if not engine: if not engine_instance:
connection.send_error( connection.send_error(
msg["id"], msg["id"],
websocket_api.const.ERR_NOT_FOUND, websocket_api.const.ERR_NOT_FOUND,
@ -771,6 +986,6 @@ def websocket_list_engine_voices(
) )
return return
voices = {"voices": engine.async_get_supported_voices(language)} voices = {"voices": engine_instance.async_get_supported_voices(language)}
connection.send_message(websocket_api.result_message(msg["id"], voices)) connection.send_message(websocket_api.result_message(msg["id"], voices))

View File

@ -16,4 +16,6 @@ DEFAULT_TIME_MEMORY = 300
DOMAIN = "tts" DOMAIN = "tts"
DATA_TTS_MANAGER = "tts_manager"
TtsAudioType = tuple[str | None, bytes | None] TtsAudioType = tuple[str | None, bytes | None]

View File

@ -0,0 +1,26 @@
"""Provide helper functions for the TTS."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_component import EntityComponent
from .const import DATA_TTS_MANAGER, DOMAIN
if TYPE_CHECKING:
from . import SpeechManager, TextToSpeechEntity
from .legacy import Provider
def get_engine_instance(
    hass: HomeAssistant, engine: str
) -> TextToSpeechEntity | Provider | None:
    """Look up a TTS engine by entity ID or legacy provider key.

    Entities of the TTS entity component are checked first. If none matches,
    the speech manager's legacy providers are consulted. Returns None when
    neither lookup finds the engine.
    """
    entity_component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
    entity = entity_component.get_entity(engine)
    if entity:
        return entity

    # The manager is only read when no entity matched.
    speech_manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
    return speech_manager.providers.get(engine)

View File

@ -44,6 +44,7 @@ from .const import (
CONF_CACHE_DIR, CONF_CACHE_DIR,
CONF_FIELDS, CONF_FIELDS,
CONF_TIME_MEMORY, CONF_TIME_MEMORY,
DATA_TTS_MANAGER,
DEFAULT_CACHE, DEFAULT_CACHE,
DEFAULT_CACHE_DIR, DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY, DEFAULT_TIME_MEMORY,
@ -111,7 +112,7 @@ async def async_setup_legacy(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]: ) -> list[Coroutine[Any, Any, None]]:
"""Set up legacy text to speech providers.""" """Set up legacy text to speech providers."""
tts: SpeechManager = hass.data[DOMAIN] tts: SpeechManager = hass.data[DATA_TTS_MANAGER]
# Load service descriptions from tts/services.yaml # Load service descriptions from tts/services.yaml
services_yaml = Path(__file__).parent / "services.yaml" services_yaml = Path(__file__).parent / "services.yaml"

View File

@ -17,12 +17,14 @@ from homeassistant.components.media_source import (
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.network import get_url from homeassistant.helpers.network import get_url
from .const import DOMAIN from .const import DATA_TTS_MANAGER, DOMAIN
from .helper import get_engine_instance
if TYPE_CHECKING: if TYPE_CHECKING:
from . import SpeechManager from . import SpeechManager, TextToSpeechEntity
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource: async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
@ -42,12 +44,16 @@ def generate_media_source_id(
"""Generate a media source ID for text-to-speech.""" """Generate a media source ID for text-to-speech."""
from . import async_resolve_engine # pylint: disable=import-outside-toplevel from . import async_resolve_engine # pylint: disable=import-outside-toplevel
manager: SpeechManager = hass.data[DOMAIN] manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if (engine := async_resolve_engine(hass, engine)) is None: if (engine := async_resolve_engine(hass, engine)) is None:
raise HomeAssistantError("Invalid TTS provider selected") raise HomeAssistantError("Invalid TTS provider selected")
manager.process_options(engine, language, options) engine_instance = get_engine_instance(hass, engine)
# We raise above if the engine is not resolved, so engine_instance can't be None
assert engine_instance is not None
manager.process_options(engine_instance, language, options)
params = { params = {
"message": message, "message": message,
} }
@ -107,7 +113,7 @@ class TTSMediaSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url.""" """Resolve media to a url."""
manager: SpeechManager = self.hass.data[DOMAIN] manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
try: try:
url = await manager.async_get_url_path( url = await manager.async_get_url_path(
@ -129,12 +135,15 @@ class TTSMediaSource(MediaSource):
) -> BrowseMediaSource: ) -> BrowseMediaSource:
"""Return media.""" """Return media."""
if item.identifier: if item.identifier:
provider, _, params = item.identifier.partition("?") engine, _, params = item.identifier.partition("?")
return self._provider_item(provider, params) return self._engine_item(engine, params)
# Root. List providers. # Root. List providers.
manager: SpeechManager = self.hass.data[DOMAIN] manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
children = [self._provider_item(provider) for provider in manager.providers] component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN]
children = [self._engine_item(engine) for engine in manager.providers] + [
self._engine_item(entity.entity_id) for entity in component.entities
]
return BrowseMediaSource( return BrowseMediaSource(
domain=DOMAIN, domain=DOMAIN,
identifier=None, identifier=None,
@ -148,14 +157,19 @@ class TTSMediaSource(MediaSource):
) )
@callback @callback
def _provider_item( def _engine_item(self, engine: str, params: str | None = None) -> BrowseMediaSource:
self, provider_domain: str, params: str | None = None
) -> BrowseMediaSource:
"""Return provider item.""" """Return provider item."""
manager: SpeechManager = self.hass.data[DOMAIN] from . import TextToSpeechEntity # pylint: disable=import-outside-toplevel
if (provider := manager.providers.get(provider_domain)) is None:
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise BrowseError("Unknown provider") raise BrowseError("Unknown provider")
if isinstance(engine_instance, TextToSpeechEntity):
assert engine_instance.platform is not None
engine_domain = engine_instance.platform.domain
else:
engine_domain = engine
if params: if params:
params = f"?{params}" params = f"?{params}"
else: else:
@ -163,11 +177,11 @@ class TTSMediaSource(MediaSource):
return BrowseMediaSource( return BrowseMediaSource(
domain=DOMAIN, domain=DOMAIN,
identifier=f"{provider_domain}{params}", identifier=f"{engine}{params}",
media_class=MediaClass.APP, media_class=MediaClass.APP,
media_content_type="provider", media_content_type="provider",
title=provider.name, title=engine_instance.name,
thumbnail=f"https://brands.home-assistant.io/_/{provider_domain}/logo.png", thumbnail=f"https://brands.home-assistant.io/_/{engine_domain}/logo.png",
can_play=False, can_play=False,
can_expand=True, can_expand=True,
) )

View File

@ -7,22 +7,26 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME from homeassistant.const import ATTR_ENTITY_ID, CONF_ENTITY_ID, CONF_NAME
from homeassistant.core import HomeAssistant, split_entity_id from homeassistant.core import HomeAssistant, split_entity_id
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN from . import ATTR_LANGUAGE, ATTR_MEDIA_PLAYER_ENTITY_ID, ATTR_MESSAGE, DOMAIN
CONF_MEDIA_PLAYER = "media_player" CONF_MEDIA_PLAYER = "media_player"
CONF_TTS_SERVICE = "tts_service" CONF_TTS_SERVICE = "tts_service"
ENTITY_LEGACY_PROVIDER_GROUP = "entity_or_legacy_provider"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{ {
vol.Required(CONF_NAME): cv.string, vol.Required(CONF_NAME): cv.string,
vol.Required(CONF_TTS_SERVICE): cv.entity_id, vol.Exclusive(CONF_TTS_SERVICE, ENTITY_LEGACY_PROVIDER_GROUP): cv.entity_id,
vol.Exclusive(CONF_ENTITY_ID, ENTITY_LEGACY_PROVIDER_GROUP): cv.entities_domain(
DOMAIN
),
vol.Required(CONF_MEDIA_PLAYER): cv.entity_id, vol.Required(CONF_MEDIA_PLAYER): cv.entity_id,
vol.Optional(ATTR_LANGUAGE): cv.string, vol.Optional(ATTR_LANGUAGE): cv.string,
} }
@ -44,6 +48,11 @@ class TTSNotificationService(BaseNotificationService):
def __init__(self, config: ConfigType) -> None:
    """Initialize the service.

    The speech engine is selected by exactly one of two config keys:
    ``CONF_ENTITY_ID`` (stored as the target) or ``CONF_TTS_SERVICE`` (a
    legacy service reference, of which only the part after the domain is
    kept). ``CONF_MEDIA_PLAYER`` is always read; the language is optional.

    Raises:
        KeyError: If neither ``CONF_ENTITY_ID`` nor ``CONF_TTS_SERVICE`` is
            present, or if ``CONF_MEDIA_PLAYER`` is missing.
    """
    self._target: str | None = None
    self._tts_service: str | None = None

    if entity_id := config.get(CONF_ENTITY_ID):
        self._target = entity_id
    elif tts_service := config.get(CONF_TTS_SERVICE):
        # Legacy provider: drop the domain, keep only the service name.
        _, self._tts_service = split_entity_id(tts_service)
    else:
        # Both keys are only ``vol.Exclusive`` in PLATFORM_SCHEMA, so a config
        # with neither of them passes validation. Keep raising KeyError (the
        # type the plain lookup used to raise) but say what is actually wrong.
        raise KeyError(
            f"Either '{CONF_ENTITY_ID}' or '{CONF_TTS_SERVICE}' must be configured"
        )

    self._media_player = config[CONF_MEDIA_PLAYER]
    self._language = config.get(ATTR_LANGUAGE)
@ -54,13 +63,21 @@ class TTSNotificationService(BaseNotificationService):
data = { data = {
ATTR_MESSAGE: message, ATTR_MESSAGE: message,
ATTR_ENTITY_ID: self._media_player,
} }
service_name = ""
if self._tts_service:
data[ATTR_ENTITY_ID] = self._media_player
service_name = self._tts_service
elif self._target:
data[ATTR_ENTITY_ID] = self._target
data[ATTR_MEDIA_PLAYER_ENTITY_ID] = self._media_player
service_name = "speak"
if self._language: if self._language:
data[ATTR_LANGUAGE] = self._language data[ATTR_LANGUAGE] = self._language
await self.hass.services.async_call( await self.hass.services.async_call(
DOMAIN, DOMAIN,
self._tts_service, service_name,
data, data,
) )

View File

@ -40,6 +40,49 @@ say:
selector: selector:
object: object:
# Service description for `speak`: targets a tts entity and takes the
# media player(s), the message and optional cache/language/options fields.
speak:
  name: Speak
  description: Speak something using text-to-speech on a media player.
  # The service is called on entities of the tts domain.
  target:
    entity:
      domain: tts
  fields:
    # Required: where the generated speech is played.
    media_player_entity_id:
      name: Media Player Entity
      description: Name(s) of media player entities.
      required: true
      selector:
        entity:
          domain: media_player
    # Required: the text to convert to speech.
    message:
      name: Message
      description: Text to speak on devices.
      example: "My name is hanna"
      required: true
      selector:
        text:
    # Optional: defaults to true.
    cache:
      name: Cache
      description: Control file cache of this message.
      default: true
      selector:
        boolean:
    language:
      name: Language
      description: Language to use for speech generation.
      example: "ru"
      selector:
        text:
    # Optional, advanced: free-form object passed as platform options.
    options:
      name: Options
      description:
        A dictionary containing platform-specific options. Optional depending on
        the platform.
      advanced: true
      example: platform specific
      selector:
        object:
clear_cache: clear_cache:
name: Clear TTS cache name: Clear TTS cache
description: Remove all text-to-speech cache files and RAM cache. description: Remove all text-to-speech cache files and RAM cache.

View File

@ -1,4 +1,6 @@
"""Test fixtures for voice assistant.""" """Test fixtures for voice assistant."""
from __future__ import annotations
from collections.abc import AsyncIterable, Generator from collections.abc import AsyncIterable, Generator
from typing import Any from typing import Any
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@ -23,6 +25,7 @@ from tests.common import (
mock_platform, mock_platform,
) )
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
init_cache_dir_side_effect,
mock_get_cache_files, mock_get_cache_files,
mock_init_cache_dir, mock_init_cache_dir,
) )
@ -33,11 +36,10 @@ _TRANSCRIPT = "test transcript"
class BaseProvider: class BaseProvider:
"""Mock STT provider.""" """Mock STT provider."""
def __init__(self, hass: HomeAssistant, text: str) -> None: def __init__(self, text: str) -> None:
"""Init test provider.""" """Init test provider."""
self.hass = hass
self.text = text self.text = text
self.received = [] self.received: list[bytes] = []
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
@ -115,7 +117,7 @@ class MockTTSProvider(tts.Provider):
return ("mp3", b"") return ("mp3", b"")
class MockTTS: class MockTTS(MockPlatform):
"""A mock TTS platform.""" """A mock TTS platform."""
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
@ -131,15 +133,15 @@ class MockTTS:
@pytest.fixture
async def mock_stt_provider() -> MockSttProvider:
    """Return a mock STT provider built from the module's test transcript."""
    provider = MockSttProvider(_TRANSCRIPT)
    return provider
@pytest.fixture
def mock_stt_provider_entity() -> MockSttProviderEntity:
    """Return a mock STT provider entity built from the module's test transcript."""
    provider_entity = MockSttProviderEntity(_TRANSCRIPT)
    return provider_entity
class MockSttPlatform(MockPlatform): class MockSttPlatform(MockPlatform):
@ -170,8 +172,9 @@ async def init_components(
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSttProviderEntity,
config_flow_fixture, config_flow_fixture,
init_cache_dir_side_effect, # noqa: F811
mock_get_cache_files, # noqa: F811 mock_get_cache_files, # noqa: F811
mock_init_cache_dir, # noqa: F811, mock_init_cache_dir, # noqa: F811
): ):
"""Initialize relevant components with empty configs.""" """Initialize relevant components with empty configs."""

View File

@ -5,30 +5,50 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.tts import ( from homeassistant.components.tts import (
CONF_LANG, CONF_LANG,
DOMAIN as TTS_DOMAIN,
PLATFORM_SCHEMA, PLATFORM_SCHEMA,
Provider, Provider,
TextToSpeechEntity,
TtsAudioType, TtsAudioType,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockPlatform from tests.common import (
MockConfigEntry,
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"] MockModule,
TEST_LANGUAGES = ["de", "en"] MockPlatform,
mock_integration,
mock_platform,
)
DEFAULT_LANG = "en_US" DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test"
TEST_LANGUAGES = ["de", "en"]
class MockProvider(Provider): async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
"""Get the media source url."""
if media_source.DOMAIN not in hass.config.components:
assert await async_setup_component(hass, media_source.DOMAIN, {})
resolved = await media_source.async_resolve_media(hass, media_content_id, None)
return resolved.url
class BaseProvider:
"""Test speech API provider.""" """Test speech API provider."""
def __init__(self, lang: str) -> None: def __init__(self, lang: str) -> None:
"""Initialize test provider.""" """Initialize test provider."""
self._lang = lang self._lang = lang
self.name = "Test"
@property @property
def default_language(self) -> str: def default_language(self) -> str:
@ -59,6 +79,24 @@ class MockProvider(Provider):
return ("mp3", b"") return ("mp3", b"")
class MockProvider(BaseProvider, Provider):
    """Mock TTS provider combining the shared test behaviour with ``Provider``."""

    def __init__(self, lang: str) -> None:
        """Initialize the base with ``lang`` and set the provider name."""
        super().__init__(lang)
        # The name is stored on the instance after the base initializer ran.
        self.name = "Test"
class MockTTSEntity(BaseProvider, TextToSpeechEntity):
    """Mock TTS entity combining the shared test behaviour with the entity base."""

    @property
    def name(self) -> str:
        """Return the fixed name of this test entity."""
        return "Test"
class MockTTS(MockPlatform): class MockTTS(MockPlatform):
"""A mock TTS platform.""" """A mock TTS platform."""
@ -70,13 +108,9 @@ class MockTTS(MockPlatform):
} }
) )
def __init__( def __init__(self, provider: MockProvider, **kwargs: Any) -> None:
self, provider: type[MockProvider] | None = None, **kwargs: Any
) -> None:
"""Initialize.""" """Initialize."""
super().__init__(**kwargs) super().__init__(**kwargs)
if provider is None:
provider = MockProvider
self._provider = provider self._provider = provider
async def async_get_engine( async def async_get_engine(
@ -86,4 +120,65 @@ class MockTTS(MockPlatform):
discovery_info: DiscoveryInfoType | None = None, discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None: ) -> Provider | None:
"""Set up a mock speech component.""" """Set up a mock speech component."""
return self._provider(config.get(CONF_LANG, DEFAULT_LANG)) return self._provider
async def mock_setup(
    hass: HomeAssistant,
    mock_provider: MockProvider,
) -> None:
    """Set up the test integration with a TTS platform that serves ``mock_provider``.

    Registers a mock integration for ``TEST_DOMAIN``, registers a ``MockTTS``
    platform wrapping ``mock_provider``, sets up the TTS component with that
    platform and waits until pending work is done.
    """
    platform_path = f"{TEST_DOMAIN}.{TTS_DOMAIN}"
    tts_config = {TTS_DOMAIN: {"platform": TEST_DOMAIN}}

    mock_integration(hass, MockModule(domain=TEST_DOMAIN))
    mock_platform(hass, platform_path, MockTTS(mock_provider))

    # NOTE: the boolean result of the component setup is discarded here.
    await async_setup_component(hass, TTS_DOMAIN, tts_config)
    await hass.async_block_till_done()
async def mock_config_entry_setup(
    hass: HomeAssistant, tts_entity: MockTTSEntity
) -> MockConfigEntry:
    """Set up the test TTS platform through a config entry.

    Registers a mock integration whose entry setup and unload are forwarded to
    the TTS domain, registers a platform that adds ``tts_entity``, then creates
    a config entry for ``TEST_DOMAIN`` and sets it up.

    Returns the config entry that was added and set up.
    """

    async def forward_entry_setup(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Forward the config entry setup to the TTS domain."""
        await hass.config_entries.async_forward_entry_setup(config_entry, TTS_DOMAIN)
        return True

    async def forward_entry_unload(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Forward the config entry unload to the TTS domain."""
        await hass.config_entries.async_forward_entry_unload(config_entry, TTS_DOMAIN)
        return True

    async def add_tts_entity(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Add the supplied TTS entity when the platform is set up."""
        async_add_entities([tts_entity])

    integration = MockModule(
        TEST_DOMAIN,
        async_setup_entry=forward_entry_setup,
        async_unload_entry=forward_entry_unload,
    )
    mock_integration(hass, integration)

    entry_platform = MockPlatform(async_setup_entry=add_tts_entity)
    mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", entry_platform)

    entry = MockConfigEntry(domain=TEST_DOMAIN)
    entry.add_to_hass(hass)

    assert await hass.config_entries.async_setup(entry.entry_id)
    await hass.async_block_till_done()
    return entry

View File

@ -2,17 +2,28 @@
From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
""" """
from unittest.mock import patch from collections.abc import Callable, Generator
from typing import Any
from unittest.mock import MagicMock, patch
import pytest import pytest
from homeassistant.components.tts import _get_cache_files from homeassistant.components.tts import _get_cache_files
from homeassistant.config import async_process_ha_core_config from homeassistant.config import async_process_ha_core_config
from homeassistant.config_entries import ConfigFlow
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .common import MockTTS from .common import (
DEFAULT_LANG,
TEST_DOMAIN,
MockProvider,
MockTTS,
MockTTSEntity,
mock_config_entry_setup,
mock_setup,
)
from tests.common import MockModule, mock_integration, mock_platform from tests.common import MockModule, mock_config_flow, mock_integration, mock_platform
@pytest.hookimpl(tryfirst=True, hookwrapper=True) @pytest.hookimpl(tryfirst=True, hookwrapper=True)
@ -37,15 +48,30 @@ def mock_get_cache_files():
@pytest.fixture(autouse=True)
def mock_init_cache_dir(
    init_cache_dir_side_effect: Any,
) -> Generator[MagicMock, None, None]:
    """Patch the TTS cache dir initializer for the duration of each test.

    The patched function uses ``init_cache_dir_side_effect`` as its side
    effect; the resulting mock is yielded to the test.
    """
    cache_dir_patch = patch(
        "homeassistant.components.tts._init_tts_cache_dir",
        side_effect=init_cache_dir_side_effect,
    )
    with cache_dir_patch as mock_cache_dir:
        yield mock_cache_dir
@pytest.fixture
def init_cache_dir_side_effect(
    hass: HomeAssistant,
) -> Callable[[HomeAssistant, str], str]:
    """Return a callable that resolves a cache dir against the config path."""

    def resolve_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
        """Return ``cache_dir`` resolved through ``hass.config.path``."""
        resolved = hass.config.path(cache_dir)
        return resolved

    return resolve_cache_dir
@pytest.fixture @pytest.fixture
def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request): def empty_cache_dir(tmp_path, mock_init_cache_dir, mock_get_cache_files, request):
"""Mock the TTS cache dir with empty dir.""" """Mock the TTS cache dir with empty dir."""
@ -89,7 +115,48 @@ async def internal_url_mock(hass: HomeAssistant) -> None:
@pytest.fixture
async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
    """Register the ``test`` integration and its TTS platform for ``mock_provider``."""
    test_integration = MockModule(domain="test")
    tts_platform = MockTTS(mock_provider)

    mock_integration(hass, test_integration)
    mock_platform(hass, "test.tts", tts_platform)
@pytest.fixture
def mock_provider() -> MockProvider:
    """Return a mock TTS provider using the default test language."""
    provider = MockProvider(DEFAULT_LANG)
    return provider
@pytest.fixture
def mock_tts_entity() -> MockTTSEntity:
    """Return a mock TTS entity using the default test language."""
    entity = MockTTSEntity(DEFAULT_LANG)
    return entity
class TTSFlow(ConfigFlow):
    """Empty config flow registered for the test domain."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
    """Register ``TTSFlow`` as the config flow of the test domain for each test."""
    flow_platform = f"{TEST_DOMAIN}.config_flow"
    mock_platform(hass, flow_platform)

    # The flow stays registered while the test body runs.
    with mock_config_flow(TEST_DOMAIN, TTSFlow):
        yield
@pytest.fixture(name="setup")
async def setup_fixture(
    hass: HomeAssistant,
    request: pytest.FixtureRequest,
    mock_provider: MockProvider,
    mock_tts_entity: MockTTSEntity,
) -> None:
    """Set up the test environment selected by the fixture parameter.

    ``request.param`` must be ``"mock_setup"`` (platform serving
    ``mock_provider``) or ``"mock_config_entry_setup"`` (config entry adding
    ``mock_tts_entity``).

    Raises:
        RuntimeError: If the parameter is neither of the two names.
    """
    setup_kind = request.param

    if setup_kind == "mock_setup":
        await mock_setup(hass, mock_provider)
        return

    if setup_kind == "mock_config_entry_setup":
        await mock_config_entry_setup(hass, mock_tts_entity)
        return

    raise RuntimeError("Invalid setup fixture")

File diff suppressed because it is too large Load Diff

View File

@ -3,32 +3,51 @@ from __future__ import annotations
import pytest import pytest
from homeassistant.components.tts import DOMAIN, Provider from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.components.tts import ATTR_MESSAGE, DOMAIN, Provider
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.discovery import async_load_platform from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import MockTTS from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS, get_media_source_url
from tests.common import ( from tests.common import (
MockModule, MockModule,
assert_setup_component, assert_setup_component,
async_mock_service,
mock_integration, mock_integration,
mock_platform, mock_platform,
) )
class DefaultProvider(Provider):
    """Provider that overrides only the supported languages."""

    @property
    def supported_languages(self) -> list[str]:
        """Return the shared list of supported test languages."""
        return SUPPORT_LANGUAGES
async def test_default_provider_attributes() -> None:
    """Test the attribute values of a provider that only sets its languages."""
    provider = DefaultProvider()

    # Checked in the same order as before so the first failure is unchanged.
    for attribute in ("hass", "name", "default_language"):
        assert getattr(provider, attribute) is None

    assert provider.supported_languages == SUPPORT_LANGUAGES

    for attribute in ("supported_options", "default_options"):
        assert getattr(provider, attribute) is None

    voices = provider.async_get_supported_voices("test")
    assert voices is None
async def test_deprecated_platform(hass: HomeAssistant) -> None: async def test_deprecated_platform(hass: HomeAssistant) -> None:
@ -56,8 +75,7 @@ async def test_invalid_platform(
async def test_platform_setup_without_provider( async def test_platform_setup_without_provider(
hass: HomeAssistant, hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test platform setup without provider returned.""" """Test platform setup without provider returned."""
@ -74,7 +92,7 @@ async def test_platform_setup_without_provider(
return None return None
mock_integration(hass, MockModule(domain="bad_tts")) mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform()) mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
await async_load_platform( await async_load_platform(
hass, hass,
@ -91,6 +109,7 @@ async def test_platform_setup_without_provider(
async def test_platform_setup_with_error( async def test_platform_setup_with_error(
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
mock_provider: MockProvider,
) -> None: ) -> None:
"""Test platform setup with an error during setup.""" """Test platform setup with an error during setup."""
@ -107,7 +126,7 @@ async def test_platform_setup_with_error(
raise Exception("Setup error") # pylint: disable=broad-exception-raised raise Exception("Setup error") # pylint: disable=broad-exception-raised
mock_integration(hass, MockModule(domain="bad_tts")) mock_integration(hass, MockModule(domain="bad_tts"))
mock_platform(hass, "bad_tts.tts", BadPlatform()) mock_platform(hass, "bad_tts.tts", BadPlatform(mock_provider))
await async_load_platform( await async_load_platform(
hass, hass,
@ -119,3 +138,58 @@ async def test_platform_setup_with_error(
await hass.async_block_till_done() await hass.async_block_till_done()
assert "Error setting up platform: bad_tts" in caplog.text assert "Error setting up platform: bad_tts" in caplog.text
async def test_service_base_url_set(hass: HomeAssistant, mock_tts) -> None:
    """Test that a configured ``base_url`` prefixes the resolved media URL."""
    play_media_calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)

    tts_config = {DOMAIN: {"platform": "test", "base_url": "http://fnord"}}
    with assert_setup_component(1, DOMAIN):
        assert await async_setup_component(hass, DOMAIN, tts_config)

    service_data = {
        ATTR_ENTITY_ID: "media_player.something",
        ATTR_MESSAGE: "There is someone at the door.",
    }
    await hass.services.async_call(DOMAIN, "test_say", service_data, blocking=True)

    assert len(play_media_calls) == 1
    play_data = play_media_calls[0].data
    assert play_data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC

    expected_url = (
        "http://fnord"
        "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
        "_en-us_-_test.mp3"
    )
    resolved_url = await get_media_source_url(hass, play_data[ATTR_MEDIA_CONTENT_ID])
    assert resolved_url == expected_url
async def test_service_without_cache_config(
    hass: HomeAssistant, empty_cache_dir, mock_tts
) -> None:
    """Test that no cache file is written when the platform sets ``cache: False``."""
    play_media_calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)

    tts_config = {DOMAIN: {"platform": "test", "cache": False}}
    with assert_setup_component(1, DOMAIN):
        assert await async_setup_component(hass, DOMAIN, tts_config)

    service_data = {
        ATTR_ENTITY_ID: "media_player.something",
        ATTR_MESSAGE: "There is someone at the door.",
    }
    await hass.services.async_call(DOMAIN, "test_say", service_data, blocking=True)

    assert len(play_media_calls) == 1
    await hass.async_block_till_done()

    cache_file = (
        empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
    )
    assert not cache_file.is_file()

View File

@ -1,5 +1,5 @@
"""Tests for TTS media source.""" """Tests for TTS media source."""
from unittest.mock import patch from unittest.mock import MagicMock
import pytest import pytest
@ -8,33 +8,52 @@ from homeassistant.components.media_player.errors import BrowseError
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import (
@pytest.fixture(autouse=True) DEFAULT_LANG,
async def mock_get_tts_audio(hass): MockProvider,
"""Set up media source.""" MockTTSEntity,
assert await async_setup_component(hass, "media_source", {}) mock_config_entry_setup,
assert await async_setup_component( mock_setup,
hass,
"tts",
{
"tts": {
"platform": "demo",
}
},
) )
with patch(
"homeassistant.components.demo.tts.DemoProvider.get_tts_audio", class MSEntity(MockTTSEntity):
return_value=("mp3", b""), """Test speech API entity."""
) as mock_get_tts:
yield mock_get_tts get_tts_audio = MagicMock(return_value=("mp3", b""))
async def test_browsing(hass: HomeAssistant) -> None: class MSProvider(MockProvider):
"""Test speech API provider."""
get_tts_audio = MagicMock(return_value=("mp3", b""))
@pytest.fixture(autouse=True)
async def setup_media_source(hass: HomeAssistant) -> None:
    """Set up the media source component before each test."""
    media_source_ready = await async_setup_component(hass, "media_source", {})
    assert media_source_ready
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
"setup",
[
"mock_setup",
"mock_config_entry_setup",
],
indirect=["setup"],
)
async def test_browsing(hass: HomeAssistant, setup: str) -> None:
"""Test browsing TTS media source.""" """Test browsing TTS media source."""
item = await media_source.async_browse_media(hass, "media-source://tts") item = await media_source.async_browse_media(hass, "media-source://tts")
assert item is not None assert item is not None
assert item.title == "Text to Speech" assert item.title == "Text to Speech"
assert item.children is not None
assert len(item.children) == 1 assert len(item.children) == 1
assert item.can_play is False assert item.can_play is False
assert item.can_expand is True assert item.can_expand is True
@ -42,9 +61,10 @@ async def test_browsing(hass: HomeAssistant) -> None:
item_child = await media_source.async_browse_media( item_child = await media_source.async_browse_media(
hass, item.children[0].media_content_id hass, item.children[0].media_content_id
) )
assert item_child is not None assert item_child is not None
assert item_child.media_content_id == item.children[0].media_content_id assert item_child.media_content_id == item.children[0].media_content_id
assert item_child.title == "Demo" assert item_child.title == "Test"
assert item_child.children is None assert item_child.children is None
assert item_child.can_play is False assert item_child.can_play is False
assert item_child.can_expand is True assert item_child.can_expand is True
@ -52,12 +72,13 @@ async def test_browsing(hass: HomeAssistant) -> None:
item_child = await media_source.async_browse_media( item_child = await media_source.async_browse_media(
hass, item.children[0].media_content_id + "?message=bla" hass, item.children[0].media_content_id + "?message=bla"
) )
assert item_child is not None assert item_child is not None
assert ( assert (
item_child.media_content_id item_child.media_content_id
== item.children[0].media_content_id + "?message=bla" == item.children[0].media_content_id + "?message=bla"
) )
assert item_child.title == "Demo" assert item_child.title == "Test"
assert item_child.children is None assert item_child.children is None
assert item_child.can_play is False assert item_child.can_play is False
assert item_child.can_expand is True assert item_child.can_expand is True
@ -66,10 +87,14 @@ async def test_browsing(hass: HomeAssistant) -> None:
await media_source.async_browse_media(hass, "media-source://tts/non-existing") await media_source.async_browse_media(hass, "media-source://tts/non-existing")
async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None: @pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
"""Test resolving.""" async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider) -> None:
"""Test resolving legacy provider."""
await mock_setup(hass, mock_provider)
mock_get_tts_audio = mock_provider.get_tts_audio
media = await media_source.async_resolve_media( media = await media_source.async_resolve_media(
hass, "media-source://tts/demo?message=Hello%20World", None hass, "media-source://tts/test?message=Hello%20World", None
) )
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg" assert media.mime_type == "audio/mpeg"
@ -77,14 +102,14 @@ async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
assert len(mock_get_tts_audio.mock_calls) == 1 assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World" assert message == "Hello World"
assert language == "en" assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
# Pass language and options # Pass language and options
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media( media = await media_source.async_resolve_media(
hass, hass,
"media-source://tts/demo?message=Bye%20World&language=de&voice=Paulus", "media-source://tts/test?message=Bye%20World&language=de&voice=Paulus",
None, None,
) )
assert media.url.startswith("/api/tts_proxy/") assert media.url.startswith("/api/tts_proxy/")
@ -93,15 +118,62 @@ async def test_resolving(hass: HomeAssistant, mock_get_tts_audio) -> None:
assert len(mock_get_tts_audio.mock_calls) == 1 assert len(mock_get_tts_audio.mock_calls) == 1
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Bye World" assert message == "Bye World"
assert language == "de" assert language == "de_DE"
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"} assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
async def test_resolving_errors(hass: HomeAssistant) -> None: @pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None:
    """Test resolving media for a TTS entity set up through a config entry."""
    await mock_config_entry_setup(hass, mock_tts_entity)
    audio_mock = mock_tts_entity.get_tts_audio

    # Message only: the engine is called once with the default language.
    media = await media_source.async_resolve_media(
        hass, "media-source://tts/tts.test?message=Hello%20World", None
    )
    assert media.url.startswith("/api/tts_proxy/")
    assert media.mime_type == "audio/mpeg"

    assert len(audio_mock.mock_calls) == 1
    recorded = audio_mock.mock_calls[0]
    message, language = recorded[1]
    assert message == "Hello World"
    assert language == "en_US"
    assert recorded[2]["options"] is None

    # Pass language and options
    audio_mock.reset_mock()
    media = await media_source.async_resolve_media(
        hass,
        "media-source://tts/tts.test?message=Bye%20World&language=de&voice=Paulus",
        None,
    )
    assert media.url.startswith("/api/tts_proxy/")
    assert media.mime_type == "audio/mpeg"

    assert len(audio_mock.mock_calls) == 1
    recorded = audio_mock.mock_calls[0]
    message, language = recorded[1]
    assert message == "Bye World"
    assert language == "de_DE"
    assert recorded[2]["options"] == {"voice": "Paulus"}
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MSProvider(DEFAULT_LANG), MSEntity(DEFAULT_LANG))],
)
@pytest.mark.parametrize(
"setup",
[
"mock_setup",
"mock_config_entry_setup",
],
indirect=["setup"],
)
async def test_resolving_errors(hass: HomeAssistant, setup: str) -> None:
"""Test resolving.""" """Test resolving."""
# No message added # No message added
with pytest.raises(media_source.Unresolvable): with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(hass, "media-source://tts/demo", None) await media_source.async_resolve_media(hass, "media-source://tts/test", None)
# Non-existing provider # Non-existing provider
with pytest.raises(media_source.Unresolvable): with pytest.raises(media_source.Unresolvable):

View File

@ -1,28 +1,22 @@
"""The tests for the TTS component.""" """The tests for the TTS component."""
import pytest import pytest
import yarl
import homeassistant.components.media_player as media_player from homeassistant.components import media_player, notify, tts
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
DOMAIN as DOMAIN_MP, DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA, SERVICE_PLAY_MEDIA,
) )
import homeassistant.components.notify as notify
import homeassistant.components.tts as tts
from homeassistant.config import async_process_ha_core_config from homeassistant.config import async_process_ha_core_config
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import MockTTSEntity, mock_config_entry_setup
from tests.common import assert_setup_component, async_mock_service from tests.common import assert_setup_component, async_mock_service
def relative_url(url):
"""Convert an absolute url to a relative one."""
return str(yarl.URL(url).relative())
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def internal_url_mock(hass): async def internal_url_mock(hass: HomeAssistant) -> None:
"""Mock internal URL of the instance.""" """Mock internal URL of the instance."""
await async_process_ha_core_config( await async_process_ha_core_config(
hass, hass,
@ -30,8 +24,8 @@ async def internal_url_mock(hass):
) )
async def test_setup_platform(hass: HomeAssistant) -> None: async def test_setup_legacy_platform(hass: HomeAssistant) -> None:
"""Set up the tts platform .""" """Set up the tts notify platform ."""
config = { config = {
notify.DOMAIN: { notify.DOMAIN: {
"platform": "tts", "platform": "tts",
@ -46,7 +40,23 @@ async def test_setup_platform(hass: HomeAssistant) -> None:
assert hass.services.has_service(notify.DOMAIN, "tts_test") assert hass.services.has_service(notify.DOMAIN, "tts_test")
async def test_setup_component_and_test_service(hass: HomeAssistant) -> None: async def test_setup_platform(hass: HomeAssistant) -> None:
"""Set up the tts notify platform ."""
config = {
notify.DOMAIN: {
"platform": "tts",
"name": "tts_test",
"entity_id": "tts.test",
"media_player": "media_player.demo",
}
}
with assert_setup_component(1, notify.DOMAIN):
assert await async_setup_component(hass, notify.DOMAIN, config)
assert hass.services.has_service(notify.DOMAIN, "tts_test")
async def test_setup_legacy_service(hass: HomeAssistant) -> None:
"""Set up the demo platform and call service.""" """Set up the demo platform and call service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -80,3 +90,38 @@ async def test_setup_component_and_test_service(hass: HomeAssistant) -> None:
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
async def test_setup_service(
    hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
    """Test that the entity-based notify service results in one play media call."""
    play_media_calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)

    notify_platform = {
        "platform": "tts",
        "name": "tts_test",
        "entity_id": "tts.test",
        "media_player": "media_player.demo",
        "language": "en",
    }
    config = {notify.DOMAIN: notify_platform}

    # The TTS entity is set up first, then the notify platform that uses it.
    await mock_config_entry_setup(hass, mock_tts_entity)

    with assert_setup_component(1, notify.DOMAIN):
        assert await async_setup_component(hass, notify.DOMAIN, config)

    service_data = {tts.ATTR_MESSAGE: "There is someone at the door."}
    await hass.services.async_call(
        notify.DOMAIN, "tts_test", service_data, blocking=True
    )

    await hass.async_block_till_done()

    assert len(play_media_calls) == 1