From 768c499b6fa35efb9b1c8186b7d7d0cc59ad84e6 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 20 Apr 2023 14:57:48 +0200 Subject: [PATCH] Include matching languages in WS stt/engine/list (#91731) * Include matching languages in WS stt/engine/list * Allow specifying country --- homeassistant/components/stt/__init__.py | 20 +++++++++----- tests/components/stt/test_init.py | 34 +++++++++++++++++++----- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 5c9c2136aa8..ef217aed2a2 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -413,6 +413,7 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata: { "type": "stt/engine/list", vol.Optional("language"): str, + vol.Optional("country"): str, } ) @callback @@ -423,23 +424,30 @@ def websocket_list_engines( component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN] legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS] + country = msg.get("country") language = msg.get("language") providers = [] provider_info: dict[str, Any] for entity in component.entities: - provider_info = {"engine_id": entity.entity_id} + provider_info = { + "engine_id": entity.entity_id, + "supported_languages": entity.supported_languages, + } if language: - provider_info["language_supported"] = bool( - language_util.matches(language, entity.supported_languages) + provider_info["supported_languages"] = language_util.matches( + language, entity.supported_languages, country ) providers.append(provider_info) for engine_id, provider in legacy_providers.items(): - provider_info = {"engine_id": engine_id} + provider_info = { + "engine_id": engine_id, + "supported_languages": provider.supported_languages, + } if language: - provider_info["language_supported"] = bool( - language_util.matches(language, provider.supported_languages) + provider_info["supported_languages"] = language_util.matches( + language, provider.supported_languages, country ) providers.append(provider_info) diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 5f3f491904a..cd2b7274387 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -52,7 +52,7 @@ class BaseProvider: @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" - return ["de-DE", "en-US"] + return ["de", "de-CH", "en-US"] @property def supported_formats(self) -> list[AudioFormats]: @@ -213,7 +213,7 @@ async def test_get_provider_info( response = await client.get(f"/api/stt/{TEST_DOMAIN}") assert response.status == HTTPStatus.OK assert await response.json() == { - "languages": ["de-DE", "en-US"], + "languages": ["de", "de-CH", "en-US"], "formats": ["wav", "ogg"], "codecs": ["pcm", "opus"], "sample_rates": [16000], @@ -398,14 +398,18 @@ async def test_ws_list_engines( msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"providers": [{"engine_id": engine_id}]} + assert msg["result"] == { + "providers": [ + {"engine_id": engine_id, "supported_languages": ["de", "de-CH", "en-US"]} + ] + } await client.send_json_auto_id({"type": "stt/engine/list", "language": "smurfish"}) msg = await client.receive_json() assert msg["success"] assert msg["result"] == { - "providers": [{"engine_id": engine_id, "language_supported": False}] + "providers": [{"engine_id": engine_id, "supported_languages": []}] } await client.send_json_auto_id({"type": "stt/engine/list", "language": "en"}) @@ -413,7 +417,7 @@ async def test_ws_list_engines( msg = await client.receive_json() assert msg["success"] assert msg["result"] == { - "providers": [{"engine_id": engine_id, "language_supported": True}] + "providers": [{"engine_id": engine_id, "supported_languages": ["en-US"]}] } await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"}) @@ -421,5 +425,23 @@ async def test_ws_list_engines( msg = await client.receive_json() assert msg["success"] assert msg["result"] == { - "providers": [{"engine_id": engine_id, "language_supported": True}] + "providers": [{"engine_id": engine_id, "supported_languages": ["en-US"]}] + } + + await client.send_json_auto_id({"type": "stt/engine/list", "language": "de"}) + msg = await client.receive_json() + assert msg["type"] == "result" + assert msg["success"] + assert msg["result"] == { + "providers": [{"engine_id": engine_id, "supported_languages": ["de", "de-CH"]}] + } + + await client.send_json_auto_id( + {"type": "stt/engine/list", "language": "de", "country": "ch"} + ) + msg = await client.receive_json() + assert msg["type"] == "result" + assert msg["success"] + assert msg["result"] == { + "providers": [{"engine_id": engine_id, "supported_languages": ["de", "de-CH"]}] }