Send Wyoming Detect message during wake word detection (#100968)

* Send Detect message with desired wake word

* Add tests

* Fix test
This commit is contained in:
Michael Hansen 2023-09-26 19:24:02 -05:00 committed by GitHub
parent 9b574fd2c9
commit af8367a8c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 6 deletions

View File

@ -5,5 +5,5 @@
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==1.1.0"]
"requirements": ["wyoming==1.2.0"]
}

View File

@ -5,7 +5,7 @@ import logging
from wyoming.audio import AudioChunk, AudioStart
from wyoming.client import AsyncTcpClient
from wyoming.wake import Detection
from wyoming.wake import Detect, Detection
from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry
@ -71,6 +71,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
# Inform client which wake word we want to detect (None = default)
await client.write_event(
Detect(names=[wake_word_id] if wake_word_id else None).event()
)
await client.write_event(
AudioStart(
rate=16000,
@ -97,10 +102,20 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
break
if Detection.is_type(event.type):
# Successful detection
# Possible detection
detection = Detection.from_event(event)
_LOGGER.info(detection)
if wake_word_id and (detection.name != wake_word_id):
_LOGGER.warning(
"Expected wake word %s but got %s, skipping",
wake_word_id,
detection.name,
)
wake_task = asyncio.create_task(client.read_event())
pending.add(wake_task)
continue
# Retrieve queued audio
queued_audio: list[tuple[bytes, int]] | None = None
if audio_task in pending:

View File

@ -2718,7 +2718,7 @@ wled==0.16.0
wolf-smartset==0.1.11
# homeassistant.components.wyoming
wyoming==1.1.0
wyoming==1.2.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View File

@ -2021,7 +2021,7 @@ wled==0.16.0
wolf-smartset==0.1.11
# homeassistant.components.wyoming
wyoming==1.1.0
wyoming==1.2.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View File

@ -92,7 +92,11 @@ class MockAsyncTcpClient:
async def read_event(self):
"""Receive."""
await asyncio.sleep(0) # force context switch
return self.responses.pop(0)
if self.responses:
return self.responses.pop(0)
return None
async def __aenter__(self):
"""Enter."""

View File

@ -106,3 +106,54 @@ async def test_streaming_audio_oserror(
result = await entity.async_process_audio_stream(audio_stream(), None)
assert result is None
async def test_detect_message_with_wake_word(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test that specifying a wake word id produces a Detect message with that id."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
yield b"chunk1", 1000
mock_client = MockAsyncTcpClient(
[Detection(name="my-wake-word", timestamp=1000).event()]
)
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client,
):
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
assert isinstance(result, wake_word.DetectionResult)
assert result.wake_word_id == "my-wake-word"
async def test_detect_message_with_wrong_wake_word(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test that specifying a wake word id filters invalid detections."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
yield b"chunk1", 1000
mock_client = MockAsyncTcpClient(
[Detection(name="not-my-wake-word", timestamp=1000).event()],
)
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client,
):
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
assert result is None