mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 01:37:08 +00:00
Send Wyoming Detect message during wake word detection (#100968)
* Send Detect message with desired wake word * Add tests * Fix test
This commit is contained in:
parent
9b574fd2c9
commit
af8367a8c6
@ -5,5 +5,5 @@
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"iot_class": "local_push",
|
||||
"requirements": ["wyoming==1.1.0"]
|
||||
"requirements": ["wyoming==1.2.0"]
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import logging
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioStart
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.wake import Detection
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@ -71,6 +71,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
# Inform client which wake word we want to detect (None = default)
|
||||
await client.write_event(
|
||||
Detect(names=[wake_word_id] if wake_word_id else None).event()
|
||||
)
|
||||
|
||||
await client.write_event(
|
||||
AudioStart(
|
||||
rate=16000,
|
||||
@ -97,10 +102,20 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||
break
|
||||
|
||||
if Detection.is_type(event.type):
|
||||
# Successful detection
|
||||
# Possible detection
|
||||
detection = Detection.from_event(event)
|
||||
_LOGGER.info(detection)
|
||||
|
||||
if wake_word_id and (detection.name != wake_word_id):
|
||||
_LOGGER.warning(
|
||||
"Expected wake word %s but got %s, skipping",
|
||||
wake_word_id,
|
||||
detection.name,
|
||||
)
|
||||
wake_task = asyncio.create_task(client.read_event())
|
||||
pending.add(wake_task)
|
||||
continue
|
||||
|
||||
# Retrieve queued audio
|
||||
queued_audio: list[tuple[bytes, int]] | None = None
|
||||
if audio_task in pending:
|
||||
|
@ -2718,7 +2718,7 @@ wled==0.16.0
|
||||
wolf-smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.1.0
|
||||
wyoming==1.2.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
@ -2021,7 +2021,7 @@ wled==0.16.0
|
||||
wolf-smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.1.0
|
||||
wyoming==1.2.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
@ -92,7 +92,11 @@ class MockAsyncTcpClient:
|
||||
async def read_event(self):
|
||||
"""Receive."""
|
||||
await asyncio.sleep(0) # force context switch
|
||||
return self.responses.pop(0)
|
||||
|
||||
if self.responses:
|
||||
return self.responses.pop(0)
|
||||
|
||||
return None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter."""
|
||||
|
@ -106,3 +106,54 @@ async def test_streaming_audio_oserror(
|
||||
result = await entity.async_process_audio_stream(audio_stream(), None)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_detect_message_with_wake_word(
|
||||
hass: HomeAssistant, init_wyoming_wake_word
|
||||
) -> None:
|
||||
"""Test that specifying a wake word id produces a Detect message with that id."""
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield b"chunk1", 1000
|
||||
|
||||
mock_client = MockAsyncTcpClient(
|
||||
[Detection(name="my-wake-word", timestamp=1000).event()]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||
mock_client,
|
||||
):
|
||||
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
|
||||
|
||||
assert isinstance(result, wake_word.DetectionResult)
|
||||
assert result.wake_word_id == "my-wake-word"
|
||||
|
||||
|
||||
async def test_detect_message_with_wrong_wake_word(
|
||||
hass: HomeAssistant, init_wyoming_wake_word
|
||||
) -> None:
|
||||
"""Test that specifying a wake word id filters invalid detections."""
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
yield b"chunk1", 1000
|
||||
|
||||
mock_client = MockAsyncTcpClient(
|
||||
[Detection(name="not-my-wake-word", timestamp=1000).event()],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
|
||||
mock_client,
|
||||
):
|
||||
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
|
||||
|
||||
assert result is None
|
||||
|
Loading…
x
Reference in New Issue
Block a user