mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 00:37:13 +00:00
Allow passing a wake word ID to detect wake word (#100832)
* Allow passing a wake word ID to detect wake word * Do not inject default wake words in wake_word integration
This commit is contained in:
parent
8a44adb447
commit
23b239ba77
@ -88,7 +88,7 @@ class WakeWordDetectionEntity(RestoreEntity):
|
||||
|
||||
@abstractmethod
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||
) -> DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps.
|
||||
|
||||
@ -96,13 +96,13 @@ class WakeWordDetectionEntity(RestoreEntity):
|
||||
"""
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None = None
|
||||
) -> DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps.
|
||||
|
||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||
"""
|
||||
result = await self._async_process_audio_stream(stream)
|
||||
result = await self._async_process_audio_stream(stream, wake_word_id)
|
||||
if result is not None:
|
||||
# Update last detected only when there is a detection
|
||||
self.__last_detected = dt_util.utcnow().isoformat()
|
||||
|
@ -58,7 +58,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
||||
return self._supported_wake_words
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect one or more wake words in an audio stream.
|
||||
|
||||
|
@ -187,13 +187,15 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
||||
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
if wake_word_id is None:
|
||||
wake_word_id = self.supported_wake_words[0].ww_id
|
||||
async for chunk, timestamp in stream:
|
||||
if chunk.startswith(b"wake word"):
|
||||
return wake_word.DetectionResult(
|
||||
ww_id=self.supported_wake_words[0].ww_id,
|
||||
ww_id=wake_word_id,
|
||||
timestamp=timestamp,
|
||||
queued_audio=[(b"queued audio", 0)],
|
||||
)
|
||||
|
@ -1,5 +1,8 @@
|
||||
# serializer version: 1
|
||||
# name: test_detected_entity
|
||||
# name: test_detected_entity[None-test_ww]
|
||||
None
|
||||
# ---
|
||||
# name: test_detected_entity[test_ww_2-test_ww_2]
|
||||
None
|
||||
# ---
|
||||
# name: test_ws_detect
|
||||
|
@ -39,16 +39,22 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
||||
return [
|
||||
wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word"),
|
||||
wake_word.WakeWord(ww_id="test_ww_2", name="Test Wake Word 2"),
|
||||
]
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
||||
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
if wake_word_id is None:
|
||||
wake_word_id = self.supported_wake_words[0].ww_id
|
||||
|
||||
async for _chunk, timestamp in stream:
|
||||
if timestamp >= 2000:
|
||||
return wake_word.DetectionResult(
|
||||
ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp
|
||||
ww_id=wake_word_id, timestamp=timestamp
|
||||
)
|
||||
|
||||
# Not detected
|
||||
@ -148,11 +154,20 @@ async def test_config_entry_unload(
|
||||
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("ww_id", "expected_ww"),
|
||||
[
|
||||
(None, "test_ww"),
|
||||
("test_ww_2", "test_ww_2"),
|
||||
],
|
||||
)
|
||||
async def test_detected_entity(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
setup: MockProviderEntity,
|
||||
snapshot: SnapshotAssertion,
|
||||
ww_id: str | None,
|
||||
expected_ww: str,
|
||||
) -> None:
|
||||
"""Test successful detection through entity."""
|
||||
|
||||
@ -164,8 +179,8 @@ async def test_detected_entity(
|
||||
|
||||
# Need 2 seconds to trigger
|
||||
state = setup.state
|
||||
result = await setup.async_process_audio_stream(three_second_stream())
|
||||
assert result == wake_word.DetectionResult("test_ww", 2048)
|
||||
result = await setup.async_process_audio_stream(three_second_stream(), ww_id)
|
||||
assert result == wake_word.DetectionResult(expected_ww, 2048)
|
||||
|
||||
assert state != setup.state
|
||||
assert state == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user