Send Wyoming Detect message during wake word detection (#100968)

* Send Detect message with desired wake word

* Add tests

* Fix test
This commit is contained in:
Michael Hansen 2023-09-26 19:24:02 -05:00 committed by GitHub
parent 9b574fd2c9
commit af8367a8c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 6 deletions

View File

@ -5,5 +5,5 @@
"config_flow": true, "config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/wyoming", "documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push", "iot_class": "local_push",
"requirements": ["wyoming==1.1.0"] "requirements": ["wyoming==1.2.0"]
} }

View File

@ -5,7 +5,7 @@ import logging
from wyoming.audio import AudioChunk, AudioStart from wyoming.audio import AudioChunk, AudioStart
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.wake import Detection from wyoming.wake import Detect, Detection
from homeassistant.components import wake_word from homeassistant.components import wake_word
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -71,6 +71,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
try: try:
async with AsyncTcpClient(self.service.host, self.service.port) as client: async with AsyncTcpClient(self.service.host, self.service.port) as client:
# Inform client which wake word we want to detect (None = default)
await client.write_event(
Detect(names=[wake_word_id] if wake_word_id else None).event()
)
await client.write_event( await client.write_event(
AudioStart( AudioStart(
rate=16000, rate=16000,
@ -97,10 +102,20 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
break break
if Detection.is_type(event.type): if Detection.is_type(event.type):
# Successful detection # Possible detection
detection = Detection.from_event(event) detection = Detection.from_event(event)
_LOGGER.info(detection) _LOGGER.info(detection)
if wake_word_id and (detection.name != wake_word_id):
_LOGGER.warning(
"Expected wake word %s but got %s, skipping",
wake_word_id,
detection.name,
)
wake_task = asyncio.create_task(client.read_event())
pending.add(wake_task)
continue
# Retrieve queued audio # Retrieve queued audio
queued_audio: list[tuple[bytes, int]] | None = None queued_audio: list[tuple[bytes, int]] | None = None
if audio_task in pending: if audio_task in pending:

View File

@ -2718,7 +2718,7 @@ wled==0.16.0
wolf-smartset==0.1.11 wolf-smartset==0.1.11
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.1.0 wyoming==1.2.0
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View File

@ -2021,7 +2021,7 @@ wled==0.16.0
wolf-smartset==0.1.11 wolf-smartset==0.1.11
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.1.0 wyoming==1.2.0
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View File

@ -92,7 +92,11 @@ class MockAsyncTcpClient:
async def read_event(self): async def read_event(self):
"""Receive.""" """Receive."""
await asyncio.sleep(0) # force context switch await asyncio.sleep(0) # force context switch
return self.responses.pop(0)
if self.responses:
return self.responses.pop(0)
return None
async def __aenter__(self): async def __aenter__(self):
"""Enter.""" """Enter."""

View File

@ -106,3 +106,54 @@ async def test_streaming_audio_oserror(
result = await entity.async_process_audio_stream(audio_stream(), None) result = await entity.async_process_audio_stream(audio_stream(), None)
assert result is None assert result is None
async def test_detect_message_with_wake_word(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test that specifying a wake word id produces a Detect message with that id."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
yield b"chunk1", 1000
mock_client = MockAsyncTcpClient(
[Detection(name="my-wake-word", timestamp=1000).event()]
)
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client,
):
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
assert isinstance(result, wake_word.DetectionResult)
assert result.wake_word_id == "my-wake-word"
async def test_detect_message_with_wrong_wake_word(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test that specifying a wake word id filters invalid detections."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
async def audio_stream():
yield b"chunk1", 1000
mock_client = MockAsyncTcpClient(
[Detection(name="not-my-wake-word", timestamp=1000).event()],
)
with patch(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client,
):
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
assert result is None