Add language scores websocket command

This commit is contained in:
Michael Hansen 2025-03-12 16:54:54 -05:00
parent 358f78c7cd
commit 222330e7c5
3 changed files with 339 additions and 0 deletions

View File

@ -3,11 +3,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import asdict
from typing import Any from typing import Any
from aiohttp import web from aiohttp import web
from hassil.recognize import MISSING_ENTITY, RecognizeResult from hassil.recognize import MISSING_ENTITY, RecognizeResult
from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity
from home_assistant_intents import get_language_scores
import voluptuous as vol import voluptuous as vol
from homeassistant.components import http, websocket_api from homeassistant.components import http, websocket_api
@ -38,6 +40,7 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_list_agents) websocket_api.async_register_command(hass, websocket_list_agents)
websocket_api.async_register_command(hass, websocket_list_sentences) websocket_api.async_register_command(hass, websocket_list_sentences)
websocket_api.async_register_command(hass, websocket_hass_agent_debug) websocket_api.async_register_command(hass, websocket_hass_agent_debug)
websocket_api.async_register_command(hass, websocket_hass_agent_language_scores)
@websocket_api.websocket_command( @websocket_api.websocket_command(
@ -336,6 +339,21 @@ def _get_unmatched_slots(
return unmatched_slots return unmatched_slots
@websocket_api.websocket_command(
{vol.Required("type"): "conversation/agent/homeassistant/language_scores"}
)
@websocket_api.async_response
async def websocket_hass_agent_language_scores(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Get support scores per language."""
scores = await hass.async_add_executor_job(get_language_scores)
result = {lang_key: asdict(lang_scores) for lang_key, lang_scores in scores.items()}
connection.send_result(msg["id"], result)
class ConversationProcessView(http.HomeAssistantView): class ConversationProcessView(http.HomeAssistantView):
"""View to process text.""" """View to process text."""

View File

@ -729,3 +729,302 @@
]), ]),
}) })
# --- # ---
# name: test_ws_hass_language_scores
dict({
'af_ZA': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'ar_JO': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'bg_BG': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'bn_BD': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'bn_IN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'ca_ES': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'cs_CZ': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'da_DK': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'de_CH': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'de_DE': dict({
'cloud': 3,
'focused_local': 2,
'full_local': 3,
}),
'el_GR': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'en_US': dict({
'cloud': 3,
'focused_local': 2,
'full_local': 3,
}),
'es_MX': dict({
'cloud': 3,
'focused_local': 2,
'full_local': 3,
}),
'et_EE': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'eu_ES': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'fa_IR': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'fi_FI': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'fr_FR': dict({
'cloud': 3,
'focused_local': 2,
'full_local': 0,
}),
'gl_ES': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'gu_IN': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'he_IL': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'hi_IN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'hr_HR': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'hu_HU': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'id_ID': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'is_IS': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'it_IT': dict({
'cloud': 3,
'focused_local': 2,
'full_local': 3,
}),
'ka_GE': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'kn_IN': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'ko_KR': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'lb_LU': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'lt_LT': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'lv_LV': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'ml_IN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'mn_MN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'mr_IN': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'ms_MY': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'nb_NO': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'ne_NP': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'nl_BE': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'nl_NL': dict({
'cloud': 3,
'focused_local': 2,
'full_local': 0,
}),
'pl_PL': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'pt_BR': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 3,
}),
'pt_PT': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 3,
}),
'ro_RO': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'ru_RU': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'sk_SK': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'sl_SI': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'sr_RS': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'sv_SE': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
'sw_KE': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'te_IN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'th_TH': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'tr_TR': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'uk_UA': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'ur_IN': dict({
'cloud': 0,
'focused_local': 0,
'full_local': 0,
}),
'vi_VN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'zh_CN': dict({
'cloud': 1,
'focused_local': 0,
'full_local': 0,
}),
'zh_HK': dict({
'cloud': 3,
'focused_local': 0,
'full_local': 0,
}),
})
# ---

View File

@ -536,3 +536,25 @@ async def test_ws_hass_agent_debug_sentence_trigger(
# Trigger should not have been executed # Trigger should not have been executed
assert len(calls) == 0 assert len(calls) == 0
async def test_ws_hass_language_scores(
hass: HomeAssistant,
init_components,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test getting language support scores."""
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{"type": "conversation/agent/homeassistant/language_scores"}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == snapshot
# Sanity check
assert msg["result"]["en_US"] == {"cloud": 3, "focused_local": 2, "full_local": 3}