mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Add WS API to stt
(#91329)
This commit is contained in:
parent
e3ff7d048a
commit
b5817e40f7
@ -15,7 +15,9 @@ from aiohttp.web_exceptions import (
|
|||||||
HTTPNotFound,
|
HTTPNotFound,
|
||||||
HTTPUnsupportedMediaType,
|
HTTPUnsupportedMediaType,
|
||||||
)
|
)
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
@ -23,7 +25,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.restore_state import RestoreEntity
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util, language as language_util
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
DATA_PROVIDERS,
|
DATA_PROVIDERS,
|
||||||
@ -74,6 +76,8 @@ def async_get_speech_to_text_entity(
|
|||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up STT."""
|
"""Set up STT."""
|
||||||
|
websocket_api.async_register_command(hass, websocket_list_engines)
|
||||||
|
|
||||||
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
|
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
|
||||||
_LOGGER, DOMAIN, hass
|
_LOGGER, DOMAIN, hass
|
||||||
)
|
)
|
||||||
@ -376,3 +380,42 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
|||||||
)
|
)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
"type": "stt/engine/list",
|
||||||
|
vol.Optional("language"): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@callback
|
||||||
|
def websocket_list_engines(
|
||||||
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
|
) -> None:
|
||||||
|
"""List speech to text engines and, optionally, if they support a given language."""
|
||||||
|
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||||
|
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
|
||||||
|
|
||||||
|
language = msg.get("language")
|
||||||
|
providers = []
|
||||||
|
provider_info: dict[str, Any]
|
||||||
|
|
||||||
|
for entity in component.entities:
|
||||||
|
provider_info = {"engine_id": entity.entity_id}
|
||||||
|
if language:
|
||||||
|
provider_info["language_supported"] = bool(
|
||||||
|
language_util.matches(language, entity.supported_languages)
|
||||||
|
)
|
||||||
|
providers.append(provider_info)
|
||||||
|
|
||||||
|
for engine_id, provider in legacy_providers.items():
|
||||||
|
provider_info = {"engine_id": engine_id}
|
||||||
|
if language:
|
||||||
|
provider_info["language_supported"] = bool(
|
||||||
|
language_util.matches(language, provider.supported_languages)
|
||||||
|
)
|
||||||
|
providers.append(provider_info)
|
||||||
|
|
||||||
|
connection.send_message(
|
||||||
|
websocket_api.result_message(msg["id"], {"providers": providers})
|
||||||
|
)
|
||||||
|
@ -35,7 +35,7 @@ from tests.common import (
|
|||||||
mock_platform,
|
mock_platform,
|
||||||
mock_restore_cache,
|
mock_restore_cache,
|
||||||
)
|
)
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
TEST_DOMAIN = "test"
|
TEST_DOMAIN = "test"
|
||||||
|
|
||||||
@ -375,3 +375,48 @@ async def test_restore_state(
|
|||||||
state = hass.states.get(entity_id)
|
state = hass.states.get(entity_id)
|
||||||
assert state
|
assert state
|
||||||
assert state.state == timestamp
|
assert state.state == timestamp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("setup", "engine_id"),
|
||||||
|
[("mock_setup", "test"), ("mock_config_entry_setup", "stt.test")],
|
||||||
|
indirect=["setup"],
|
||||||
|
)
|
||||||
|
async def test_ws_list_engines(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
setup: str,
|
||||||
|
engine_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing speech to text engines."""
|
||||||
|
client = await hass_ws_client()
|
||||||
|
|
||||||
|
await client.send_json_auto_id({"type": "stt/engine/list"})
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"providers": [{"engine_id": engine_id}]}
|
||||||
|
|
||||||
|
await client.send_json_auto_id({"type": "stt/engine/list", "language": "smurfish"})
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"providers": [{"engine_id": engine_id, "language_supported": False}]
|
||||||
|
}
|
||||||
|
|
||||||
|
await client.send_json_auto_id({"type": "stt/engine/list", "language": "en"})
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"providers": [{"engine_id": engine_id, "language_supported": True}]
|
||||||
|
}
|
||||||
|
|
||||||
|
await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"})
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"providers": [{"engine_id": engine_id, "language_supported": True}]
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user