From f3e6d6dfc0fe520528c204ae7f4840c9e0898880 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 19 Apr 2023 13:47:49 +0200 Subject: [PATCH] Add async_get_supported_voices to tts.Provider (#91649) * Add async_get_supported_voices to tts.Provider * Update WS API --- homeassistant/components/cloud/tts.py | 6 ++++ homeassistant/components/tts/__init__.py | 28 +++++++++--------- homeassistant/components/tts/legacy.py | 7 ++++- tests/components/cloud/test_tts.py | 1 + tests/components/tts/common.py | 9 +++++- tests/components/tts/test_init.py | 36 +++++++++++++++++++----- 6 files changed, 65 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index 438bfa580d7..a9c607d27cd 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -10,6 +10,7 @@ from homeassistant.components.tts import ( PLATFORM_SCHEMA, Provider, ) +from homeassistant.core import callback from .const import DOMAIN @@ -95,6 +96,11 @@ class CloudProvider(Provider): """Return list of supported options like voice, emotion.""" return [ATTR_GENDER, ATTR_VOICE, ATTR_AUDIO_OUTPUT] + @callback + def async_get_supported_voices(self, language: str) -> list[str] | None: + """Return a list of supported voices for a language.""" + return TTS_VOICES.get(language) + @property def default_options(self): """Return a dict include default options.""" diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 27b9a754000..b237fd6a069 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -739,18 +739,20 @@ def websocket_list_engine_voices( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """List voices for a given language.""" - voices = { - "voices": [ - # placeholder until TTS refactoring - { - "voice_id": "voice_1", - "name": "James Earl Jones", - }, - { - "voice_id": "voice_2", - "name": "Fran Drescher", - }, - ] - } + engine_id = msg["engine_id"] + language = msg["language"] + + manager: SpeechManager = hass.data[DOMAIN] + engine = manager.providers.get(engine_id) + + if not engine: + connection.send_error( + msg["id"], + websocket_api.const.ERR_NOT_FOUND, + f"tts engine {engine_id} not found", + ) + return + + voices = {"voices": engine.async_get_supported_voices(language)} connection.send_message(websocket_api.result_message(msg["id"], voices)) diff --git a/homeassistant/components/tts/legacy.py b/homeassistant/components/tts/legacy.py index 1f21d249504..aa98ed49a6e 100644 --- a/homeassistant/components/tts/legacy.py +++ b/homeassistant/components/tts/legacy.py @@ -25,7 +25,7 @@ from homeassistant.const import ( CONF_NAME, CONF_PLATFORM, ) -from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers import config_per_platform, discovery import homeassistant.helpers.config_validation as cv from homeassistant.helpers.service import async_set_service_schema @@ -227,6 +227,11 @@ class Provider: """Return a list of supported options like voice, emotions.""" return None + @callback + def async_get_supported_voices(self, language: str) -> list[str] | None: + """Return a list of supported voices for a language.""" + return None + @property def default_options(self) -> Mapping[str, Any] | None: """Return a mapping with the default options.""" diff --git a/tests/components/cloud/test_tts.py b/tests/components/cloud/test_tts.py index 39de6e1536d..be023cd0d57 100644 --- a/tests/components/cloud/test_tts.py +++ b/tests/components/cloud/test_tts.py @@ -78,6 +78,7 @@ async def test_provider_properties(cloud_with_prefs) -> None: ) assert provider.supported_options == ["gender", "voice", "audio_output"] assert "nl-NL" in provider.supported_languages + assert "ColetteNeural" in provider.async_get_supported_voices("nl-NL") async def test_get_tts_audio(cloud_with_prefs) -> None: diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index 9e175089fc8..6e77c89177e 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -11,7 +11,7 @@ from homeassistant.components.tts import ( Provider, TtsAudioType, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from tests.common import MockPlatform @@ -40,6 +40,13 @@ class MockProvider(Provider): """Return list of supported languages.""" return SUPPORT_LANGUAGES + @callback + def async_get_supported_voices(self, language: str) -> list[str] | None: + """Return list of supported languages.""" + if language == "en-US": + return ["James Earl Jones", "Fran Drescher"] + return None + @property def supported_options(self) -> list[str]: """Return list of supported options like voice, emotions.""" diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index bc176f08fda..9ab71f2512f 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1077,16 +1077,38 @@ async def test_ws_list_voices( await client.send_json_auto_id( { "type": "tts/engine/voices", - "engine_id": "smurf", + "engine_id": "smurf_tts", + "language": "smurfish", + } + ) + + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": "tts engine smurf_tts not found", + } + + await client.send_json_auto_id( + { + "type": "tts/engine/voices", + "engine_id": "test", "language": "smurfish", } ) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == { - "voices": [ - {"voice_id": "voice_1", "name": "James Earl Jones"}, - {"voice_id": "voice_2", "name": "Fran Drescher"}, - ] - } + assert msg["result"] == {"voices": None} + + await client.send_json_auto_id( + { + "type": "tts/engine/voices", + "engine_id": "test", + "language": "en-US", + } + ) + + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"voices": ["James Earl Jones", "Fran Drescher"]}