From af8367a8c6afa7a7f7660c0357a082392113f22b Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 26 Sep 2023 19:24:02 -0500 Subject: [PATCH] Send Wyoming Detect message during wake word detection (#100968) * Send Detect message with desired wake word * Add tests * Fix test --- .../components/wyoming/manifest.json | 2 +- homeassistant/components/wyoming/wake_word.py | 19 ++++++- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- tests/components/wyoming/__init__.py | 6 ++- tests/components/wyoming/test_wake_word.py | 51 +++++++++++++++++++ 6 files changed, 76 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index 810092094d1..ddb5407e1ce 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -5,5 +5,5 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/wyoming", "iot_class": "local_push", - "requirements": ["wyoming==1.1.0"] + "requirements": ["wyoming==1.2.0"] } diff --git a/homeassistant/components/wyoming/wake_word.py b/homeassistant/components/wyoming/wake_word.py index 45d33b2a28c..c9010425c52 100644 --- a/homeassistant/components/wyoming/wake_word.py +++ b/homeassistant/components/wyoming/wake_word.py @@ -5,7 +5,7 @@ import logging from wyoming.audio import AudioChunk, AudioStart from wyoming.client import AsyncTcpClient -from wyoming.wake import Detection +from wyoming.wake import Detect, Detection from homeassistant.components import wake_word from homeassistant.config_entries import ConfigEntry @@ -71,6 +71,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): try: async with AsyncTcpClient(self.service.host, self.service.port) as client: + # Inform client which wake word we want to detect (None = default) + await client.write_event( + Detect(names=[wake_word_id] if wake_word_id else None).event() + ) + await client.write_event( AudioStart( rate=16000, @@ -97,10 +102,20 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): break if Detection.is_type(event.type): - # Successful detection + # Possible detection detection = Detection.from_event(event) _LOGGER.info(detection) + if wake_word_id and (detection.name != wake_word_id): + _LOGGER.warning( + "Expected wake word %s but got %s, skipping", + wake_word_id, + detection.name, + ) + wake_task = asyncio.create_task(client.read_event()) + pending.add(wake_task) + continue + # Retrieve queued audio queued_audio: list[tuple[bytes, int]] | None = None if audio_task in pending: diff --git a/requirements_all.txt b/requirements_all.txt index 03c969afecd..6e9ac7f86f6 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2718,7 +2718,7 @@ wled==0.16.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.1.0 +wyoming==1.2.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index ba09e8975b9..5a9e1e838f7 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2021,7 +2021,7 @@ wled==0.16.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.1.0 +wyoming==1.2.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index c326228ec8b..e04ff4eda03 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -92,7 +92,11 @@ class MockAsyncTcpClient: async def read_event(self): """Receive.""" await asyncio.sleep(0) # force context switch - return self.responses.pop(0) + + if self.responses: + return self.responses.pop(0) + + return None async def __aenter__(self): """Enter.""" diff --git a/tests/components/wyoming/test_wake_word.py b/tests/components/wyoming/test_wake_word.py index eec5a16ff25..b3c09d4e816 100644 --- a/tests/components/wyoming/test_wake_word.py +++ b/tests/components/wyoming/test_wake_word.py @@ -106,3 +106,54 @@ async def test_streaming_audio_oserror( result = await entity.async_process_audio_stream(audio_stream(), None) assert result is None + + +async def test_detect_message_with_wake_word( + hass: HomeAssistant, init_wyoming_wake_word +) -> None: + """Test that specifying a wake word id produces a Detect message with that id.""" + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + async def audio_stream(): + yield b"chunk1", 1000 + + mock_client = MockAsyncTcpClient( + [Detection(name="my-wake-word", timestamp=1000).event()] + ) + + with patch( + "homeassistant.components.wyoming.wake_word.AsyncTcpClient", + mock_client, + ): + result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word") + + assert isinstance(result, wake_word.DetectionResult) + assert result.wake_word_id == "my-wake-word" + + +async def test_detect_message_with_wrong_wake_word( + hass: HomeAssistant, init_wyoming_wake_word +) -> None: + """Test that specifying a wake word id filters invalid detections.""" + entity = wake_word.async_get_wake_word_detection_entity( + hass, "wake_word.test_wake_word" + ) + assert entity is not None + + async def audio_stream(): + yield b"chunk1", 1000 + + mock_client = MockAsyncTcpClient( + [Detection(name="not-my-wake-word", timestamp=1000).event()], + ) + + with patch( + "homeassistant.components.wyoming.wake_word.AsyncTcpClient", + mock_client, + ): + result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word") + + assert result is None