Add WS API for listing languages supported by a full assist pipeline (#91669)

* Add WS API for listing languages supported by a full assist pipeline

* Address review comments, change logic
This commit is contained in:
Erik Montnemery 2023-04-20 14:55:17 +02:00 committed by GitHub
parent 03dcb915e3
commit 6d619579b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 125 additions and 2 deletions

View File

@ -8,9 +8,11 @@ from typing import Any
import async_timeout import async_timeout
import voluptuous as vol import voluptuous as vol
from homeassistant.components import stt, websocket_api from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util
from .const import DOMAIN from .const import DOMAIN
from .pipeline import ( from .pipeline import (
@ -34,6 +36,7 @@ _LOGGER = logging.getLogger(__name__)
def async_register_websocket_api(hass: HomeAssistant) -> None: def async_register_websocket_api(hass: HomeAssistant) -> None:
"""Register the websocket API.""" """Register the websocket API."""
websocket_api.async_register_command(hass, websocket_run) websocket_api.async_register_command(hass, websocket_run)
websocket_api.async_register_command(hass, websocket_list_languages)
websocket_api.async_register_command(hass, websocket_list_runs) websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run) websocket_api.async_register_command(hass, websocket_get_run)
@ -271,3 +274,58 @@ def websocket_get_run(
msg["id"], msg["id"],
{"events": pipeline_runs[pipeline_run_id].events}, {"events": pipeline_runs[pipeline_run_id].events},
) )
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/language/list",
}
)
@websocket_api.async_response
async def websocket_list_languages(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List languages which are supported by a complete pipeline.
This will return a list of languages which are supported by at least one stt, tts
and conversation engine respectively.
"""
conv_language_tags = await conversation.async_get_conversation_languages(hass)
stt_language_tags = stt.async_get_speech_to_text_languages(hass)
tts_language_tags = tts.async_get_text_to_speech_languages(hass)
pipeline_languages: set[str] | None = None
if conv_language_tags and conv_language_tags != MATCH_ALL:
languages = set()
for language_tag in conv_language_tags:
dialect = language_util.Dialect.parse(language_tag)
languages.add(dialect.language)
pipeline_languages = languages
if stt_language_tags:
languages = set()
for language_tag in stt_language_tags:
dialect = language_util.Dialect.parse(language_tag)
languages.add(dialect.language)
if pipeline_languages is not None:
pipeline_languages &= languages
else:
pipeline_languages = languages
if tts_language_tags:
languages = set()
for language_tag in tts_language_tags:
dialect = language_util.Dialect.parse(language_tag)
languages.add(dialect.language)
if pipeline_languages is not None:
pipeline_languages &= languages
else:
pipeline_languages = languages
connection.send_result(
msg["id"],
{"languages": pipeline_languages},
)

View File

@ -5,7 +5,7 @@ import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import re import re
from typing import Any from typing import Any, Literal
import voluptuous as vol import voluptuous as vol
@ -13,6 +13,7 @@ from homeassistant import core
from homeassistant.components import http, websocket_api from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent, singleton from homeassistant.helpers import config_validation as cv, intent, singleton
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -115,6 +116,23 @@ def async_unset_agent(
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id) _get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
async def async_get_conversation_languages(
hass: HomeAssistant,
) -> set[str] | Literal["*"]:
"""Return a set with the union of languages supported by conversation agents."""
agent_manager = _get_agent_manager(hass)
languages = set()
for agent_info in agent_manager.async_get_agent_info():
agent = await agent_manager.async_get_agent(agent_info.id)
if agent.supported_languages == MATCH_ALL:
return MATCH_ALL
for language_tag in agent.supported_languages:
languages.add(language_tag)
return languages
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
agent_manager = _get_agent_manager(hass) agent_manager = _get_agent_manager(hass)

View File

@ -74,6 +74,24 @@ def async_get_speech_to_text_entity(
return component.get_entity(entity_id) return component.get_entity(entity_id)
@callback
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by stt engines."""
languages = set()
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
for entity in component.entities:
for language_tag in entity.supported_languages:
languages.add(language_tag)
for engine in legacy_providers.values():
for language_tag in engine.supported_languages:
languages.add(language_tag)
return languages
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT.""" """Set up STT."""
websocket_api.async_register_command(hass, websocket_list_engines) websocket_api.async_register_command(hass, websocket_list_engines)

View File

@ -134,6 +134,19 @@ async def async_get_media_source_audio(
) )
@callback
def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by tts engines."""
languages = set()
manager: SpeechManager = hass.data[DOMAIN]
for tts_engine in manager.providers.values():
for language_tag in tts_engine.supported_languages:
languages.add(language_tag)
return languages
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS.""" """Set up TTS."""
websocket_api.async_register_command(hass, websocket_list_engines) websocket_api.async_register_command(hass, websocket_list_engines)

View File

@ -1130,3 +1130,19 @@ async def test_pipeline_debug_get_run_wrong_pipeline_run(
"code": "not_found", "code": "not_found",
"message": "pipeline_run_id blah not found", "message": "pipeline_run_id blah not found",
} }
async def test_list_pipeline_languages(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
) -> None:
"""Test listing pipeline languages."""
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "assist_pipeline/language/list"})
# result
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"languages": ["en"]}