From 257686fcfeff4f382fbe932b66b4c9699654dc0c Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 11 Oct 2023 12:21:32 -0500 Subject: [PATCH] Dynamic wake word loading for Wyoming (#101827) * Change supported_wake_words property to async method * Add test * Add timeout + test --------- Co-authored-by: Paulus Schoutsen --- .../components/wake_word/__init__.py | 20 +++++-- homeassistant/components/wyoming/wake_word.py | 20 +++++-- tests/components/assist_pipeline/conftest.py | 5 +- tests/components/wake_word/test_init.py | 35 ++++++++++-- tests/components/wyoming/test_wake_word.py | 57 ++++++++++++++++++- 5 files changed, 120 insertions(+), 17 deletions(-) diff --git a/homeassistant/components/wake_word/__init__.py b/homeassistant/components/wake_word/__init__.py index 6c55bd8e7e7..8c8fb85b8b3 100644 --- a/homeassistant/components/wake_word/__init__.py +++ b/homeassistant/components/wake_word/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import abstractmethod +import asyncio from collections.abc import AsyncIterable import logging from typing import final @@ -34,6 +35,8 @@ _LOGGER = logging.getLogger(__name__) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) +TIMEOUT_FETCH_WAKE_WORDS = 10 + @callback def async_default_entity(hass: HomeAssistant) -> str | None: @@ -86,9 +89,8 @@ class WakeWordDetectionEntity(RestoreEntity): """Return the state of the entity.""" return self.__last_detected - @property @abstractmethod - def supported_wake_words(self) -> list[WakeWord]: + async def get_supported_wake_words(self) -> list[WakeWord]: """Return a list of supported wake words.""" @abstractmethod @@ -133,8 +135,9 @@ class WakeWordDetectionEntity(RestoreEntity): vol.Required("entity_id"): cv.entity_domain(DOMAIN), } ) +@websocket_api.async_response @callback -def websocket_entity_info( +async def websocket_entity_info( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Get info about wake word entity.""" @@ -147,7 +150,16 @@ def websocket_entity_info( ) return + try: + async with asyncio.timeout(TIMEOUT_FETCH_WAKE_WORDS): + wake_words = await entity.get_supported_wake_words() + except asyncio.TimeoutError: + connection.send_error( + msg["id"], websocket_api.const.ERR_TIMEOUT, "Timeout fetching wake words" + ) + return + connection.send_result( msg["id"], - {"wake_words": entity.supported_wake_words}, + {"wake_words": wake_words}, ) diff --git a/homeassistant/components/wyoming/wake_word.py b/homeassistant/components/wyoming/wake_word.py index d4cbd9b9263..fce8bbf6327 100644 --- a/homeassistant/components/wyoming/wake_word.py +++ b/homeassistant/components/wyoming/wake_word.py @@ -13,7 +13,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN -from .data import WyomingService +from .data import WyomingService, load_wyoming_info from .error import WyomingError _LOGGER = logging.getLogger(__name__) @@ -28,7 +28,7 @@ async def async_setup_entry( service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] async_add_entities( [ - WyomingWakeWordProvider(config_entry, service), + WyomingWakeWordProvider(hass, config_entry, service), ] ) @@ -38,10 +38,12 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): def __init__( self, + hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService, ) -> None: """Set up provider.""" + self.hass = hass self.service = service wake_service = service.info.wake[0] @@ -52,9 +54,19 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): self._attr_name = wake_service.name self._attr_unique_id = f"{config_entry.entry_id}-wake_word" - @property - def supported_wake_words(self) -> list[wake_word.WakeWord]: + async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: """Return a list of supported wake words.""" + info = await load_wyoming_info( + self.service.host, self.service.port, retries=0, timeout=1 + ) + + if info is not None: + wake_service = info.wake[0] + self._supported_wake_words = [ + wake_word.WakeWord(id=ww.name, name=ww.description or ww.name) + for ww in wake_service.models + ] + return self._supported_wake_words async def _async_process_audio_stream( diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index cde2666c1ea..1a3144ee069 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -181,8 +181,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): url_path = "wake_word.test" _attr_name = "test" - @property - def supported_wake_words(self) -> list[wake_word.WakeWord]: + async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: """Return a list of supported wake words.""" return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")] @@ -191,7 +190,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): ) -> wake_word.DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps.""" if wake_word_id is None: - wake_word_id = self.supported_wake_words[0].id + wake_word_id = (await self.get_supported_wake_words())[0].id async for chunk, timestamp in stream: if chunk.startswith(b"wake word"): return wake_word.DetectionResult( diff --git a/tests/components/wake_word/test_init.py b/tests/components/wake_word/test_init.py index 5d1cc5a4b3f..6b147229d47 100644 --- a/tests/components/wake_word/test_init.py +++ b/tests/components/wake_word/test_init.py @@ -1,6 +1,9 @@ """Test wake_word component setup.""" +import asyncio from collections.abc import AsyncIterable, Generator +from functools import partial from pathlib import Path +from unittest.mock import patch from freezegun import freeze_time import pytest @@ -37,8 +40,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity): url_path = "wake_word.test" _attr_name = "test" - @property - def supported_wake_words(self) -> list[wake_word.WakeWord]: + async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: """Return a list of supported wake words.""" return [ wake_word.WakeWord(id="test_ww", name="Test Wake Word"), @@ -50,7 +52,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity): ) -> wake_word.DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps.""" if wake_word_id is None: - wake_word_id = self.supported_wake_words[0].id + wake_word_id = (await self.get_supported_wake_words())[0].id async for _chunk, timestamp in stream: if timestamp >= 2000: @@ -294,7 +296,7 @@ async def test_list_wake_words_unknown_entity( setup: MockProviderEntity, hass_ws_client: WebSocketGenerator, ) -> None: - """Test that the list_wake_words websocket command works.""" + """Test that the list_wake_words websocket command handles unknown entity.""" client = await hass_ws_client(hass) await client.send_json( { @@ -308,3 +310,28 @@ async def test_list_wake_words_unknown_entity( assert not msg["success"] assert msg["error"] == {"code": "not_found", "message": "Entity not found"} + + +async def test_list_wake_words_timeout( + hass: HomeAssistant, + setup: MockProviderEntity, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test that the list_wake_words websocket command handles unknown entity.""" + client = await hass_ws_client(hass) + + with patch.object( + setup, "get_supported_wake_words", partial(asyncio.sleep, 1) + ), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0): + await client.send_json( + { + "id": 5, + "type": "wake_word/info", + "entity_id": setup.entity_id, + } + ) + + msg = await client.receive_json() + + assert not msg["success"] + assert msg["error"] == {"code": "timeout", "message": "Timeout fetching wake words"} diff --git a/tests/components/wyoming/test_wake_word.py b/tests/components/wyoming/test_wake_word.py index b3c09d4e816..36a6daf0452 100644 --- a/tests/components/wyoming/test_wake_word.py +++ b/tests/components/wyoming/test_wake_word.py @@ -6,12 +6,13 @@ from unittest.mock import patch from syrupy.assertion import SnapshotAssertion from wyoming.asr import Transcript +from wyoming.info import Info, WakeModel, WakeProgram from wyoming.wake import Detection from homeassistant.components import wake_word from homeassistant.core import HomeAssistant -from . import MockAsyncTcpClient +from . import TEST_ATTR, MockAsyncTcpClient async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None: @@ -24,7 +25,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None: ) assert entity is not None - assert entity.supported_wake_words == [ + assert (await entity.get_supported_wake_words()) == [ wake_word.WakeWord(id="Test Model", name="Test Model") ] @@ -157,3 +158,55 @@ async def test_detect_message_with_wrong_wake_word( result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word") assert result is None + + +async def test_dynamic_wake_word_info( + hass: HomeAssistant, init_wyoming_wake_word +) -> None: + """Test that supported wake words are loaded dynamically.""" + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + # Original info + assert (await entity.get_supported_wake_words()) == [ + wake_word.WakeWord("Test Model", "Test Model") + ] + + new_info = Info( + wake=[ + WakeProgram( + name="dynamic", + description="Dynamic Wake Word", + installed=True, + attribution=TEST_ATTR, + models=[ + WakeModel( + name="ww1", + description="Wake Word 1", + installed=True, + attribution=TEST_ATTR, + languages=[], + ), + WakeModel( + name="ww2", + description="Wake Word 2", + installed=True, + attribution=TEST_ATTR, + languages=[], + ), + ], + ) + ] + ) + + # Different Wyoming info will be fetched + with patch( + "homeassistant.components.wyoming.wake_word.load_wyoming_info", + return_value=new_info, + ): + assert (await entity.get_supported_wake_words()) == [ + wake_word.WakeWord("ww1", "Wake Word 1"), + wake_word.WakeWord("ww2", "Wake Word 2"), + ]