mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Add ElevenLabs text-to-speech integration (#115645)
* Add ElevenLabs text-to-speech integration * Remove commented out code * Use model_id instead of model_name for elevenlabs api * Apply suggestions from code review Co-authored-by: Sid <27780930+autinerd@users.noreply.github.com> * Use async client instead of sync * Add ElevenLabs code owner * Apply suggestions from code review Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Set entity title to voice * Rename to elevenlabs * Apply suggestions from code review Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Allow multiple voices and options flow * Sort default voice at beginning * Rework config flow to include default model and reloading on options flow * Add error to strings * Add ElevenLabsData and suggestions from code review * Shorten options and config flow * Fix comments * Fix comments * Add wip * Fix * Cleanup * Bump elevenlabs version * Add data description * Fix --------- Co-authored-by: Sid <27780930+autinerd@users.noreply.github.com> Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> Co-authored-by: Michael Hansen <mike@rhasspy.org> Co-authored-by: Joostlek <joostlek@outlook.com> Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
7bc2381a45
commit
5fefa606b6
@ -168,6 +168,7 @@ homeassistant.components.ecowitt.*
|
||||
homeassistant.components.efergy.*
|
||||
homeassistant.components.electrasmart.*
|
||||
homeassistant.components.electric_kiwi.*
|
||||
homeassistant.components.elevenlabs.*
|
||||
homeassistant.components.elgato.*
|
||||
homeassistant.components.elkm1.*
|
||||
homeassistant.components.emulated_hue.*
|
||||
|
@ -376,6 +376,8 @@ build.json @home-assistant/supervisor
|
||||
/tests/components/electrasmart/ @jafar-atili
|
||||
/homeassistant/components/electric_kiwi/ @mikey0000
|
||||
/tests/components/electric_kiwi/ @mikey0000
|
||||
/homeassistant/components/elevenlabs/ @sorgfresser
|
||||
/tests/components/elevenlabs/ @sorgfresser
|
||||
/homeassistant/components/elgato/ @frenck
|
||||
/tests/components/elgato/ @frenck
|
||||
/homeassistant/components/elkm1/ @gwww @bdraco
|
||||
|
71
homeassistant/components/elevenlabs/__init__.py
Normal file
71
homeassistant/components/elevenlabs/__init__.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""The ElevenLabs text-to-speech integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from elevenlabs import Model
|
||||
from elevenlabs.client import AsyncElevenLabs
|
||||
from elevenlabs.core import ApiError
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_API_KEY, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryError
|
||||
|
||||
from .const import CONF_MODEL
|
||||
|
||||
PLATFORMS: list[Platform] = [Platform.TTS]
|
||||
|
||||
|
||||
async def get_model_by_id(client: AsyncElevenLabs, model_id: str) -> Model | None:
|
||||
"""Get ElevenLabs model from their API by the model_id."""
|
||||
models = await client.models.get_all()
|
||||
for maybe_model in models:
|
||||
if maybe_model.model_id == model_id:
|
||||
return maybe_model
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(kw_only=True, slots=True)
|
||||
class ElevenLabsData:
|
||||
"""ElevenLabs data type."""
|
||||
|
||||
client: AsyncElevenLabs
|
||||
model: Model
|
||||
|
||||
|
||||
type EleventLabsConfigEntry = ConfigEntry[ElevenLabsData]
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: EleventLabsConfigEntry) -> bool:
|
||||
"""Set up ElevenLabs text-to-speech from a config entry."""
|
||||
entry.add_update_listener(update_listener)
|
||||
client = AsyncElevenLabs(api_key=entry.data[CONF_API_KEY])
|
||||
model_id = entry.options[CONF_MODEL]
|
||||
try:
|
||||
model = await get_model_by_id(client, model_id)
|
||||
except ApiError as err:
|
||||
raise ConfigEntryError("Auth failed") from err
|
||||
|
||||
if model is None or (not model.languages):
|
||||
raise ConfigEntryError("Model could not be resolved")
|
||||
|
||||
entry.runtime_data = ElevenLabsData(client=client, model=model)
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(
|
||||
hass: HomeAssistant, entry: EleventLabsConfigEntry
|
||||
) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
|
||||
async def update_listener(
|
||||
hass: HomeAssistant, config_entry: EleventLabsConfigEntry
|
||||
) -> None:
|
||||
"""Handle options update."""
|
||||
await hass.config_entries.async_reload(config_entry.entry_id)
|
145
homeassistant/components/elevenlabs/config_flow.py
Normal file
145
homeassistant/components/elevenlabs/config_flow.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Config flow for ElevenLabs text-to-speech integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from elevenlabs.client import AsyncElevenLabs
|
||||
from elevenlabs.core import ApiError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import (
|
||||
ConfigEntry,
|
||||
ConfigFlow,
|
||||
ConfigFlowResult,
|
||||
OptionsFlow,
|
||||
OptionsFlowWithConfigEntry,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY
|
||||
from homeassistant.helpers.selector import (
|
||||
SelectOptionDict,
|
||||
SelectSelector,
|
||||
SelectSelectorConfig,
|
||||
)
|
||||
|
||||
from .const import CONF_MODEL, CONF_VOICE, DEFAULT_MODEL, DOMAIN
|
||||
|
||||
USER_STEP_SCHEMA = vol.Schema({vol.Required(CONF_API_KEY): str})
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_voices_models(api_key: str) -> tuple[dict[str, str], dict[str, str]]:
|
||||
"""Get available voices and models as dicts."""
|
||||
client = AsyncElevenLabs(api_key=api_key)
|
||||
voices = (await client.voices.get_all()).voices
|
||||
models = await client.models.get_all()
|
||||
voices_dict = {
|
||||
voice.voice_id: voice.name
|
||||
for voice in sorted(voices, key=lambda v: v.name or "")
|
||||
if voice.name
|
||||
}
|
||||
models_dict = {
|
||||
model.model_id: model.name
|
||||
for model in sorted(models, key=lambda m: m.name or "")
|
||||
if model.name and model.can_do_text_to_speech
|
||||
}
|
||||
return voices_dict, models_dict
|
||||
|
||||
|
||||
class ElevenLabsConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for ElevenLabs text-to-speech."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the initial step."""
|
||||
errors: dict[str, str] = {}
|
||||
if user_input is not None:
|
||||
try:
|
||||
voices, _ = await get_voices_models(user_input[CONF_API_KEY])
|
||||
except ApiError:
|
||||
errors["base"] = "invalid_api_key"
|
||||
else:
|
||||
return self.async_create_entry(
|
||||
title="ElevenLabs",
|
||||
data=user_input,
|
||||
options={CONF_MODEL: DEFAULT_MODEL, CONF_VOICE: list(voices)[0]},
|
||||
)
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=USER_STEP_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def async_get_options_flow(
|
||||
config_entry: ConfigEntry,
|
||||
) -> OptionsFlow:
|
||||
"""Create the options flow."""
|
||||
return ElevenLabsOptionsFlow(config_entry)
|
||||
|
||||
|
||||
class ElevenLabsOptionsFlow(OptionsFlowWithConfigEntry):
|
||||
"""ElevenLabs options flow."""
|
||||
|
||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||
"""Initialize options flow."""
|
||||
super().__init__(config_entry)
|
||||
self.api_key: str = self.config_entry.data[CONF_API_KEY]
|
||||
# id -> name
|
||||
self.voices: dict[str, str] = {}
|
||||
self.models: dict[str, str] = {}
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Manage the options."""
|
||||
if not self.voices or not self.models:
|
||||
self.voices, self.models = await get_voices_models(self.api_key)
|
||||
|
||||
assert self.models and self.voices
|
||||
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
title="ElevenLabs",
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
schema = self.elevenlabs_config_option_schema()
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=schema,
|
||||
)
|
||||
|
||||
def elevenlabs_config_option_schema(self) -> vol.Schema:
|
||||
"""Elevenlabs options schema."""
|
||||
return self.add_suggested_values_to_schema(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required(
|
||||
CONF_MODEL,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(label=model_name, value=model_id)
|
||||
for model_id, model_name in self.models.items()
|
||||
]
|
||||
)
|
||||
),
|
||||
vol.Required(
|
||||
CONF_VOICE,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(label=voice_name, value=voice_id)
|
||||
for voice_id, voice_name in self.voices.items()
|
||||
]
|
||||
)
|
||||
),
|
||||
}
|
||||
),
|
||||
self.options,
|
||||
)
|
7
homeassistant/components/elevenlabs/const.py
Normal file
7
homeassistant/components/elevenlabs/const.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Constants for the ElevenLabs text-to-speech integration."""
|
||||
|
||||
CONF_VOICE = "voice"
|
||||
CONF_MODEL = "model"
|
||||
DOMAIN = "elevenlabs"
|
||||
|
||||
DEFAULT_MODEL = "eleven_multilingual_v2"
|
11
homeassistant/components/elevenlabs/manifest.json
Normal file
11
homeassistant/components/elevenlabs/manifest.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"domain": "elevenlabs",
|
||||
"name": "ElevenLabs",
|
||||
"codeowners": ["@sorgfresser"],
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/elevenlabs",
|
||||
"integration_type": "service",
|
||||
"iot_class": "cloud_polling",
|
||||
"loggers": ["elevenlabs"],
|
||||
"requirements": ["elevenlabs==1.6.1"]
|
||||
}
|
31
homeassistant/components/elevenlabs/strings.json
Normal file
31
homeassistant/components/elevenlabs/strings.json
Normal file
@ -0,0 +1,31 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"data": {
|
||||
"api_key": "[%key:common::config_flow::data::api_key%]"
|
||||
},
|
||||
"data_description": {
|
||||
"api_key": "Your Elevenlabs API key."
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"invalid_api_key": "[%key:common::config_flow::error::invalid_api_key%]"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"voice": "Voice",
|
||||
"model": "Model"
|
||||
},
|
||||
"data_description": {
|
||||
"voice": "Voice to use for the TTS.",
|
||||
"model": "ElevenLabs model to use. Please note that not all models support all languages equally well."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
116
homeassistant/components/elevenlabs/tts.py
Normal file
116
homeassistant/components/elevenlabs/tts.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""Support for the ElevenLabs text-to-speech service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from elevenlabs.client import AsyncElevenLabs
|
||||
from elevenlabs.core import ApiError
|
||||
from elevenlabs.types import Model, Voice as ElevenLabsVoice
|
||||
|
||||
from homeassistant.components.tts import (
|
||||
ATTR_VOICE,
|
||||
TextToSpeechEntity,
|
||||
TtsAudioType,
|
||||
Voice,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from . import EleventLabsConfigEntry
|
||||
from .const import CONF_VOICE, DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: EleventLabsConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up ElevenLabs tts platform via config entry."""
|
||||
client = config_entry.runtime_data.client
|
||||
voices = (await client.voices.get_all()).voices
|
||||
default_voice_id = config_entry.options[CONF_VOICE]
|
||||
async_add_entities(
|
||||
[
|
||||
ElevenLabsTTSEntity(
|
||||
client,
|
||||
config_entry.runtime_data.model,
|
||||
voices,
|
||||
default_voice_id,
|
||||
config_entry.entry_id,
|
||||
config_entry.title,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ElevenLabsTTSEntity(TextToSpeechEntity):
|
||||
"""The ElevenLabs API entity."""
|
||||
|
||||
_attr_supported_options = [ATTR_VOICE]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncElevenLabs,
|
||||
model: Model,
|
||||
voices: list[ElevenLabsVoice],
|
||||
default_voice_id: str,
|
||||
entry_id: str,
|
||||
title: str,
|
||||
) -> None:
|
||||
"""Init ElevenLabs TTS service."""
|
||||
self._client = client
|
||||
self._model = model
|
||||
self._default_voice_id = default_voice_id
|
||||
self._voices = sorted(
|
||||
(Voice(v.voice_id, v.name) for v in voices if v.name),
|
||||
key=lambda v: v.name,
|
||||
)
|
||||
# Default voice first
|
||||
voice_indices = [
|
||||
idx for idx, v in enumerate(self._voices) if v.voice_id == default_voice_id
|
||||
]
|
||||
if voice_indices:
|
||||
self._voices.insert(0, self._voices.pop(voice_indices[0]))
|
||||
self._attr_unique_id = entry_id
|
||||
self._attr_name = title
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, entry_id)},
|
||||
manufacturer="ElevenLabs",
|
||||
model=model.name,
|
||||
entry_type=DeviceEntryType.SERVICE,
|
||||
)
|
||||
self._attr_supported_languages = [
|
||||
lang.language_id for lang in self._model.languages or []
|
||||
]
|
||||
self._attr_default_language = self.supported_languages[0]
|
||||
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice]:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._voices
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine."""
|
||||
_LOGGER.debug("Getting TTS audio for %s", message)
|
||||
_LOGGER.debug("Options: %s", options)
|
||||
voice_id = options[ATTR_VOICE]
|
||||
try:
|
||||
audio = await self._client.generate(
|
||||
text=message,
|
||||
voice=voice_id,
|
||||
model=self._model.model_id,
|
||||
)
|
||||
bytes_combined = b"".join([byte_seg async for byte_seg in audio])
|
||||
except ApiError as exc:
|
||||
_LOGGER.warning(
|
||||
"Error during processing of TTS request %s", exc, exc_info=True
|
||||
)
|
||||
raise HomeAssistantError(exc) from exc
|
||||
return "mp3", bytes_combined
|
@ -150,6 +150,7 @@ FLOWS = {
|
||||
"efergy",
|
||||
"electrasmart",
|
||||
"electric_kiwi",
|
||||
"elevenlabs",
|
||||
"elgato",
|
||||
"elkm1",
|
||||
"elmax",
|
||||
|
@ -1516,6 +1516,12 @@
|
||||
"config_flow": true,
|
||||
"iot_class": "cloud_polling"
|
||||
},
|
||||
"elevenlabs": {
|
||||
"name": "ElevenLabs",
|
||||
"integration_type": "service",
|
||||
"config_flow": true,
|
||||
"iot_class": "cloud_polling"
|
||||
},
|
||||
"elgato": {
|
||||
"name": "Elgato",
|
||||
"integrations": {
|
||||
|
10
mypy.ini
10
mypy.ini
@ -1436,6 +1436,16 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.elevenlabs.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.elgato.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
@ -779,6 +779,9 @@ ecoaliface==0.4.0
|
||||
# homeassistant.components.electric_kiwi
|
||||
electrickiwi-api==0.8.5
|
||||
|
||||
# homeassistant.components.elevenlabs
|
||||
elevenlabs==1.6.1
|
||||
|
||||
# homeassistant.components.elgato
|
||||
elgato==5.1.2
|
||||
|
||||
|
@ -660,6 +660,9 @@ easyenergy==2.1.2
|
||||
# homeassistant.components.electric_kiwi
|
||||
electrickiwi-api==0.8.5
|
||||
|
||||
# homeassistant.components.elevenlabs
|
||||
elevenlabs==1.6.1
|
||||
|
||||
# homeassistant.components.elgato
|
||||
elgato==5.1.2
|
||||
|
||||
|
1
tests/components/elevenlabs/__init__.py
Normal file
1
tests/components/elevenlabs/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for the ElevenLabs integration."""
|
65
tests/components/elevenlabs/conftest.py
Normal file
65
tests/components/elevenlabs/conftest.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Common fixtures for the ElevenLabs text-to-speech tests."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from elevenlabs.core import ApiError
|
||||
from elevenlabs.types import GetVoicesResponse
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.elevenlabs.const import CONF_MODEL, CONF_VOICE
|
||||
from homeassistant.const import CONF_API_KEY
|
||||
|
||||
from .const import MOCK_MODELS, MOCK_VOICES
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
||||
"""Override async_setup_entry."""
|
||||
with patch(
|
||||
"homeassistant.components.elevenlabs.async_setup_entry", return_value=True
|
||||
) as mock_setup_entry:
|
||||
yield mock_setup_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_client() -> Generator[AsyncMock, None, None]:
|
||||
"""Override async ElevenLabs client."""
|
||||
client_mock = AsyncMock()
|
||||
client_mock.voices.get_all.return_value = GetVoicesResponse(voices=MOCK_VOICES)
|
||||
client_mock.models.get_all.return_value = MOCK_MODELS
|
||||
with patch(
|
||||
"elevenlabs.client.AsyncElevenLabs", return_value=client_mock
|
||||
) as mock_async_client:
|
||||
yield mock_async_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_client_fail() -> Generator[AsyncMock, None, None]:
|
||||
"""Override async ElevenLabs client."""
|
||||
with patch(
|
||||
"homeassistant.components.elevenlabs.config_flow.AsyncElevenLabs",
|
||||
return_value=AsyncMock(),
|
||||
) as mock_async_client:
|
||||
mock_async_client.side_effect = ApiError
|
||||
yield mock_async_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_entry() -> MockConfigEntry:
|
||||
"""Mock a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="elevenlabs",
|
||||
data={
|
||||
CONF_API_KEY: "api_key",
|
||||
},
|
||||
options={CONF_MODEL: "model1", CONF_VOICE: "voice1"},
|
||||
)
|
||||
entry.models = {
|
||||
"model1": "model1",
|
||||
}
|
||||
|
||||
entry.voices = {"voice1": "voice1"}
|
||||
return entry
|
52
tests/components/elevenlabs/const.py
Normal file
52
tests/components/elevenlabs/const.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Constants for the Testing of the ElevenLabs text-to-speech integration."""
|
||||
|
||||
from elevenlabs.types import LanguageResponse, Model, Voice
|
||||
|
||||
from homeassistant.components.elevenlabs.const import DEFAULT_MODEL
|
||||
|
||||
MOCK_VOICES = [
|
||||
Voice(
|
||||
voice_id="voice1",
|
||||
name="Voice 1",
|
||||
),
|
||||
Voice(
|
||||
voice_id="voice2",
|
||||
name="Voice 2",
|
||||
),
|
||||
]
|
||||
|
||||
MOCK_MODELS = [
|
||||
Model(
|
||||
model_id="model1",
|
||||
name="Model 1",
|
||||
can_do_text_to_speech=True,
|
||||
languages=[
|
||||
LanguageResponse(language_id="en", name="English"),
|
||||
LanguageResponse(language_id="de", name="German"),
|
||||
LanguageResponse(language_id="es", name="Spanish"),
|
||||
LanguageResponse(language_id="ja", name="Japanese"),
|
||||
],
|
||||
),
|
||||
Model(
|
||||
model_id="model2",
|
||||
name="Model 2",
|
||||
can_do_text_to_speech=True,
|
||||
languages=[
|
||||
LanguageResponse(language_id="en", name="English"),
|
||||
LanguageResponse(language_id="de", name="German"),
|
||||
LanguageResponse(language_id="es", name="Spanish"),
|
||||
LanguageResponse(language_id="ja", name="Japanese"),
|
||||
],
|
||||
),
|
||||
Model(
|
||||
model_id=DEFAULT_MODEL,
|
||||
name=DEFAULT_MODEL,
|
||||
can_do_text_to_speech=True,
|
||||
languages=[
|
||||
LanguageResponse(language_id="en", name="English"),
|
||||
LanguageResponse(language_id="de", name="German"),
|
||||
LanguageResponse(language_id="es", name="Spanish"),
|
||||
LanguageResponse(language_id="ja", name="Japanese"),
|
||||
],
|
||||
),
|
||||
]
|
94
tests/components/elevenlabs/test_config_flow.py
Normal file
94
tests/components/elevenlabs/test_config_flow.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Test the ElevenLabs text-to-speech config flow."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from homeassistant.components.elevenlabs.const import (
|
||||
CONF_MODEL,
|
||||
CONF_VOICE,
|
||||
DEFAULT_MODEL,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import CONF_API_KEY
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_user_step(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_async_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test user step create entry result."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert not result["errors"]
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_API_KEY: "api_key",
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "ElevenLabs"
|
||||
assert result["data"] == {
|
||||
"api_key": "api_key",
|
||||
}
|
||||
assert result["options"] == {CONF_MODEL: DEFAULT_MODEL, CONF_VOICE: "voice1"}
|
||||
|
||||
mock_setup_entry.assert_called_once()
|
||||
|
||||
|
||||
async def test_invalid_api_key(
|
||||
hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_async_client_fail: AsyncMock
|
||||
) -> None:
|
||||
"""Test user step with invalid api key."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert not result["errors"]
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_API_KEY: "api_key",
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"]
|
||||
|
||||
mock_setup_entry.assert_not_called()
|
||||
|
||||
|
||||
async def test_options_flow_init(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_async_client: AsyncMock,
|
||||
mock_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test options flow init."""
|
||||
mock_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(mock_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await hass.config_entries.options.async_init(mock_entry.entry_id)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={CONF_MODEL: "model1", CONF_VOICE: "voice1"},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert mock_entry.options == {CONF_MODEL: "model1", CONF_VOICE: "voice1"}
|
||||
|
||||
mock_setup_entry.assert_called_once()
|
270
tests/components/elevenlabs/test_tts.py
Normal file
270
tests/components/elevenlabs/test_tts.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""Tests for the ElevenLabs TTS entity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from elevenlabs.core import ApiError
|
||||
from elevenlabs.types import GetVoicesResponse
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.elevenlabs.const import CONF_MODEL, CONF_VOICE, DOMAIN
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
|
||||
from .const import MOCK_MODELS, MOCK_VOICES
|
||||
|
||||
from tests.common import MockConfigEntry, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock):
|
||||
"""Mock writing tags."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
|
||||
"""Mock the TTS cache dir with empty dir."""
|
||||
return mock_tts_cache_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
|
||||
"""Mock media player calls."""
|
||||
return async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_internal_url(hass: HomeAssistant) -> None:
|
||||
"""Set up internal url."""
|
||||
await async_process_ha_core_config(
|
||||
hass, {"internal_url": "http://example.local:8123"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="setup")
|
||||
async def setup_fixture(
|
||||
hass: HomeAssistant,
|
||||
config_data: dict[str, Any],
|
||||
config_options: dict[str, Any],
|
||||
request: pytest.FixtureRequest,
|
||||
mock_async_client: AsyncMock,
|
||||
) -> AsyncMock:
|
||||
"""Set up the test environment."""
|
||||
if request.param == "mock_config_entry_setup":
|
||||
await mock_config_entry_setup(hass, config_data, config_options)
|
||||
else:
|
||||
raise RuntimeError("Invalid setup fixture")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
return mock_async_client
|
||||
|
||||
|
||||
@pytest.fixture(name="config_data")
|
||||
def config_data_fixture() -> dict[str, Any]:
|
||||
"""Return config data."""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture(name="config_options")
|
||||
def config_options_fixture() -> dict[str, Any]:
|
||||
"""Return config options."""
|
||||
return {}
|
||||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, config_data: dict[str, Any], config_options: dict[str, Any]
|
||||
) -> None:
|
||||
"""Mock config entry setup."""
|
||||
default_config_data = {
|
||||
CONF_API_KEY: "api_key",
|
||||
}
|
||||
default_config_options = {
|
||||
CONF_VOICE: "voice1",
|
||||
CONF_MODEL: "model1",
|
||||
}
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=default_config_data | config_data,
|
||||
options=default_config_options | config_options,
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
client_mock = AsyncMock()
|
||||
client_mock.voices.get_all.return_value = GetVoicesResponse(voices=MOCK_VOICES)
|
||||
client_mock.models.get_all.return_value = MOCK_MODELS
|
||||
with patch(
|
||||
"homeassistant.components.elevenlabs.AsyncElevenLabs", return_value=client_mock
|
||||
):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_data",
|
||||
[
|
||||
{},
|
||||
{tts.CONF_LANG: "de"},
|
||||
{tts.CONF_LANG: "en"},
|
||||
{tts.CONF_LANG: "ja"},
|
||||
{tts.CONF_LANG: "es"},
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.mock_title",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
|
||||
},
|
||||
),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_tts_service_speak(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test tts service."""
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.generate.reset_mock()
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._client.generate.assert_called_once_with(
|
||||
text="There is a person at the front door.", voice="voice2", model="model1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.mock_title",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_LANGUAGE: "de",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
},
|
||||
),
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.mock_title",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_LANGUAGE: "es",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
},
|
||||
),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_tts_service_speak_lang_config(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call say with other langcodes in the config."""
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.generate.reset_mock()
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._client.generate.assert_called_once_with(
|
||||
text="There is a person at the front door.", voice="voice1", model="model1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.mock_title",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
},
|
||||
),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_tts_service_speak_error(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call say with http response 400."""
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.generate.reset_mock()
|
||||
tts_entity._client.generate.side_effect = ApiError
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
||||
tts_entity._client.generate.assert_called_once_with(
|
||||
text="There is a person at the front door.", voice="voice1", model="model1"
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user