diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 83492f5f6bc..e0ab446153c 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -15,7 +15,9 @@ from aiohttp.web_exceptions import ( HTTPNotFound, HTTPUnsupportedMediaType, ) +import voluptuous as vol +from homeassistant.components import websocket_api from homeassistant.components.http import HomeAssistantView from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN @@ -23,7 +25,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType -from homeassistant.util import dt as dt_util +from homeassistant.util import dt as dt_util, language as language_util from .const import ( DATA_PROVIDERS, @@ -74,6 +76,8 @@ def async_get_speech_to_text_entity( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" + websocket_api.async_register_command(hass, websocket_list_engines) + component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity]( _LOGGER, DOMAIN, hass ) @@ -376,3 +380,42 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata: ) except ValueError as err: raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err + + +@websocket_api.websocket_command( + { + "type": "stt/engine/list", + vol.Optional("language"): str, + } +) +@callback +def websocket_list_engines( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """List speech to text engines and, optionally, if they support a given language.""" + component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN] + legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS] + + language = msg.get("language") + providers = [] + provider_info: dict[str, Any] + + for entity in component.entities: + provider_info = {"engine_id": entity.entity_id} + if language: + provider_info["language_supported"] = bool( + language_util.matches(language, entity.supported_languages) + ) + providers.append(provider_info) + + for engine_id, provider in legacy_providers.items(): + provider_info = {"engine_id": engine_id} + if language: + provider_info["language_supported"] = bool( + language_util.matches(language, provider.supported_languages) + ) + providers.append(provider_info) + + connection.send_message( + websocket_api.result_message(msg["id"], {"providers": providers}) + ) diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 483037e7ee1..fd3dbba0e1f 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -35,7 +35,7 @@ from tests.common import ( mock_platform, mock_restore_cache, ) -from tests.typing import ClientSessionGenerator +from tests.typing import ClientSessionGenerator, WebSocketGenerator TEST_DOMAIN = "test" @@ -375,3 +375,48 @@ async def test_restore_state( state = hass.states.get(entity_id) assert state assert state.state == timestamp + + +@pytest.mark.parametrize( + ("setup", "engine_id"), + [("mock_setup", "test"), ("mock_config_entry_setup", "stt.test")], + indirect=["setup"], +) +async def test_ws_list_engines( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + setup: str, + engine_id: str, +) -> None: + """Test listing speech to text engines.""" + client = await hass_ws_client() + + await client.send_json_auto_id({"type": "stt/engine/list"}) + + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == {"providers": [{"engine_id": engine_id}]} + + await client.send_json_auto_id({"type": "stt/engine/list", "language": "smurfish"}) + + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "providers": [{"engine_id": engine_id, "language_supported": False}] + } + + await client.send_json_auto_id({"type": "stt/engine/list", "language": "en"}) + + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "providers": [{"engine_id": engine_id, "language_supported": True}] + } + + await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"}) + + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "providers": [{"engine_id": engine_id, "language_supported": True}] + }