mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Add WS API for listing languages supported by a full assist pipeline (#91669)
* Add WS API for listing languages supported by a full assist pipeline * Address review comments, change logic
This commit is contained in:
parent
03dcb915e3
commit
6d619579b4
@ -8,9 +8,11 @@ from typing import Any
|
||||
import async_timeout
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import stt, websocket_api
|
||||
from homeassistant.components import conversation, stt, tts, websocket_api
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import (
|
||||
@ -34,6 +36,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_run)
|
||||
websocket_api.async_register_command(hass, websocket_list_languages)
|
||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||
websocket_api.async_register_command(hass, websocket_get_run)
|
||||
|
||||
@ -271,3 +274,58 @@ def websocket_get_run(
|
||||
msg["id"],
|
||||
{"events": pipeline_runs[pipeline_run_id].events},
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/language/list",
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_list_languages(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""List languages which are supported by a complete pipeline.
|
||||
|
||||
This will return a list of languages which are supported by at least one stt, tts
|
||||
and conversation engine respectively.
|
||||
"""
|
||||
conv_language_tags = await conversation.async_get_conversation_languages(hass)
|
||||
stt_language_tags = stt.async_get_speech_to_text_languages(hass)
|
||||
tts_language_tags = tts.async_get_text_to_speech_languages(hass)
|
||||
pipeline_languages: set[str] | None = None
|
||||
|
||||
if conv_language_tags and conv_language_tags != MATCH_ALL:
|
||||
languages = set()
|
||||
for language_tag in conv_language_tags:
|
||||
dialect = language_util.Dialect.parse(language_tag)
|
||||
languages.add(dialect.language)
|
||||
pipeline_languages = languages
|
||||
|
||||
if stt_language_tags:
|
||||
languages = set()
|
||||
for language_tag in stt_language_tags:
|
||||
dialect = language_util.Dialect.parse(language_tag)
|
||||
languages.add(dialect.language)
|
||||
if pipeline_languages is not None:
|
||||
pipeline_languages &= languages
|
||||
else:
|
||||
pipeline_languages = languages
|
||||
|
||||
if tts_language_tags:
|
||||
languages = set()
|
||||
for language_tag in tts_language_tags:
|
||||
dialect = language_util.Dialect.parse(language_tag)
|
||||
languages.add(dialect.language)
|
||||
if pipeline_languages is not None:
|
||||
pipeline_languages &= languages
|
||||
else:
|
||||
pipeline_languages = languages
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{"languages": pipeline_languages},
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -13,6 +13,7 @@ from homeassistant import core
|
||||
from homeassistant.components import http, websocket_api
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv, intent, singleton
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
@ -115,6 +116,23 @@ def async_unset_agent(
|
||||
_get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
|
||||
|
||||
|
||||
async def async_get_conversation_languages(
|
||||
hass: HomeAssistant,
|
||||
) -> set[str] | Literal["*"]:
|
||||
"""Return a set with the union of languages supported by conversation agents."""
|
||||
agent_manager = _get_agent_manager(hass)
|
||||
languages = set()
|
||||
|
||||
for agent_info in agent_manager.async_get_agent_info():
|
||||
agent = await agent_manager.async_get_agent(agent_info.id)
|
||||
if agent.supported_languages == MATCH_ALL:
|
||||
return MATCH_ALL
|
||||
for language_tag in agent.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
return languages
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
agent_manager = _get_agent_manager(hass)
|
||||
|
@ -74,6 +74,24 @@ def async_get_speech_to_text_entity(
|
||||
return component.get_entity(entity_id)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
|
||||
"""Return a set with the union of languages supported by stt engines."""
|
||||
languages = set()
|
||||
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
|
||||
for entity in component.entities:
|
||||
for language_tag in entity.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
for engine in legacy_providers.values():
|
||||
for language_tag in engine.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
return languages
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
websocket_api.async_register_command(hass, websocket_list_engines)
|
||||
|
@ -134,6 +134,19 @@ async def async_get_media_source_audio(
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
||||
"""Return a set with the union of languages supported by tts engines."""
|
||||
languages = set()
|
||||
|
||||
manager: SpeechManager = hass.data[DOMAIN]
|
||||
for tts_engine in manager.providers.values():
|
||||
for language_tag in tts_engine.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
return languages
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up TTS."""
|
||||
websocket_api.async_register_command(hass, websocket_list_engines)
|
||||
|
@ -1130,3 +1130,19 @@ async def test_pipeline_debug_get_run_wrong_pipeline_run(
|
||||
"code": "not_found",
|
||||
"message": "pipeline_run_id blah not found",
|
||||
}
|
||||
|
||||
|
||||
async def test_list_pipeline_languages(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
) -> None:
|
||||
"""Test listing pipeline languages."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/language/list"})
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"languages": ["en"]}
|
||||
|
Loading…
x
Reference in New Issue
Block a user