Add WS API to tts (#91330)

* Add WS API to tts

* Use language util, change from entity_id to engine_id

* Fix rebase mistake
This commit is contained in:
Erik Montnemery 2023-04-17 22:52:19 +02:00 committed by GitHub
parent 2819ad9a16
commit e32dacc62d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 125 additions and 2 deletions

View File

@ -10,13 +10,14 @@ import logging
import mimetypes
import os
import re
from typing import TypedDict
from typing import Any, TypedDict
from aiohttp import web
import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import PLATFORM_FORMAT
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
@ -24,6 +25,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import language as language_util
from .const import (
ATTR_CACHE,
@ -134,6 +136,9 @@ async def async_get_media_source_audio(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
websocket_api.async_register_command(hass, websocket_list_engines)
websocket_api.async_register_command(hass, websocket_list_engine_voices)
# Legacy config options
conf = config[DOMAIN][0] if config.get(DOMAIN) else {}
use_cache: bool = conf.get(CONF_CACHE, DEFAULT_CACHE)
@ -685,3 +690,60 @@ class TextToSpeechView(HomeAssistantView):
def get_base_url(hass: HomeAssistant) -> str:
"""Get base URL."""
return hass.data[BASE_URL_KEY] or get_url(hass)
@websocket_api.websocket_command(
{
"type": "tts/engine/list",
vol.Optional("language"): str,
}
)
@callback
def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List text to speech engines and, optionally, if they support a given language."""
manager: SpeechManager = hass.data[DOMAIN]
language = msg.get("language")
providers = []
for engine_id, provider in manager.providers.items():
provider_info: dict[str, Any] = {"engine_id": engine_id}
if language:
provider_info["language_supported"] = bool(
language_util.matches(language, provider.supported_languages)
)
providers.append(provider_info)
connection.send_message(
websocket_api.result_message(msg["id"], {"providers": providers})
)
@websocket_api.websocket_command(
{
"type": "tts/engine/voices",
vol.Required("engine_id"): str,
vol.Required("language"): str,
}
)
@callback
def websocket_list_engine_voices(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List voices for a given language."""
voices = {
"voices": [
# placeholder until TTS refactoring
{
"voice_id": "voice_1",
"name": "James Earl Jones",
},
{
"voice_id": "voice_2",
"name": "Fran Drescher",
},
]
}
connection.send_message(websocket_api.result_message(msg["id"], voices))

View File

@ -32,7 +32,7 @@ from tests.common import (
mock_integration,
mock_platform,
)
from tests.typing import ClientSessionGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
@ -1018,3 +1018,64 @@ async def test_fetching_in_async(
"mp3",
b"test 2",
)
async def test_ws_list_engines(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, setup_tts
) -> None:
"""Test streaming audio and getting response."""
client = await hass_ws_client()
await client.send_json_auto_id({"type": "tts/engine/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"providers": [{"engine_id": "test"}]}
await client.send_json_auto_id({"type": "tts/engine/list", "language": "smurfish"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": "test", "language_supported": False}]
}
await client.send_json_auto_id({"type": "tts/engine/list", "language": "en"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": "test", "language_supported": True}]
}
await client.send_json_auto_id({"type": "tts/engine/list", "language": "en-UK"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": "test", "language_supported": True}]
}
async def test_ws_list_voices(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, setup_tts
) -> None:
"""Test streaming audio and getting response."""
client = await hass_ws_client()
await client.send_json_auto_id(
{
"type": "tts/engine/voices",
"engine_id": "smurf",
"language": "smurfish",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"voices": [
{"voice_id": "voice_1", "name": "James Earl Jones"},
{"voice_id": "voice_2", "name": "Fran Drescher"},
]
}