Allow passing a wake word ID to detect wake word (#100832)

* Allow passing a wake word ID to detect wake word

* Do not inject default wake words in wake_word integration
This commit is contained in:
Paulus Schoutsen 2023-09-25 09:33:54 -04:00 committed by GitHub
parent 8a44adb447
commit 23b239ba77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 12 deletions

View File

@ -88,7 +88,7 @@ class WakeWordDetectionEntity(RestoreEntity):
@abstractmethod @abstractmethod
async def _async_process_audio_stream( async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]] self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> DetectionResult | None: ) -> DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps. """Try to detect wake word(s) in an audio stream with timestamps.
@ -96,13 +96,13 @@ class WakeWordDetectionEntity(RestoreEntity):
""" """
async def async_process_audio_stream( async def async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]] self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None = None
) -> DetectionResult | None: ) -> DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps. """Try to detect wake word(s) in an audio stream with timestamps.
Audio must be 16Khz sample rate with 16-bit mono PCM samples. Audio must be 16Khz sample rate with 16-bit mono PCM samples.
""" """
result = await self._async_process_audio_stream(stream) result = await self._async_process_audio_stream(stream, wake_word_id)
if result is not None: if result is not None:
# Update last detected only when there is a detection # Update last detected only when there is a detection
self.__last_detected = dt_util.utcnow().isoformat() self.__last_detected = dt_util.utcnow().isoformat()

View File

@ -58,7 +58,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
return self._supported_wake_words return self._supported_wake_words
async def _async_process_audio_stream( async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]] self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> wake_word.DetectionResult | None: ) -> wake_word.DetectionResult | None:
"""Try to detect one or more wake words in an audio stream. """Try to detect one or more wake words in an audio stream.

View File

@ -187,13 +187,15 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")] return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
async def _async_process_audio_stream( async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]] self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> wake_word.DetectionResult | None: ) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.""" """Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None:
wake_word_id = self.supported_wake_words[0].ww_id
async for chunk, timestamp in stream: async for chunk, timestamp in stream:
if chunk.startswith(b"wake word"): if chunk.startswith(b"wake word"):
return wake_word.DetectionResult( return wake_word.DetectionResult(
ww_id=self.supported_wake_words[0].ww_id, ww_id=wake_word_id,
timestamp=timestamp, timestamp=timestamp,
queued_audio=[(b"queued audio", 0)], queued_audio=[(b"queued audio", 0)],
) )

View File

@ -1,5 +1,8 @@
# serializer version: 1 # serializer version: 1
# name: test_detected_entity # name: test_detected_entity[None-test_ww]
None
# ---
# name: test_detected_entity[test_ww_2-test_ww_2]
None None
# --- # ---
# name: test_ws_detect # name: test_ws_detect

View File

@ -39,16 +39,22 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
@property @property
def supported_wake_words(self) -> list[wake_word.WakeWord]: def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")] return [
wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word"),
wake_word.WakeWord(ww_id="test_ww_2", name="Test Wake Word 2"),
]
async def _async_process_audio_stream( async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]] self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> wake_word.DetectionResult | None: ) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.""" """Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None:
wake_word_id = self.supported_wake_words[0].ww_id
async for _chunk, timestamp in stream: async for _chunk, timestamp in stream:
if timestamp >= 2000: if timestamp >= 2000:
return wake_word.DetectionResult( return wake_word.DetectionResult(
ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp ww_id=wake_word_id, timestamp=timestamp
) )
# Not detected # Not detected
@ -148,11 +154,20 @@ async def test_config_entry_unload(
assert config_entry.state == ConfigEntryState.NOT_LOADED assert config_entry.state == ConfigEntryState.NOT_LOADED
@pytest.mark.parametrize(
("ww_id", "expected_ww"),
[
(None, "test_ww"),
("test_ww_2", "test_ww_2"),
],
)
async def test_detected_entity( async def test_detected_entity(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
setup: MockProviderEntity, setup: MockProviderEntity,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
ww_id: str | None,
expected_ww: str,
) -> None: ) -> None:
"""Test successful detection through entity.""" """Test successful detection through entity."""
@ -164,8 +179,8 @@ async def test_detected_entity(
# Need 2 seconds to trigger # Need 2 seconds to trigger
state = setup.state state = setup.state
result = await setup.async_process_audio_stream(three_second_stream()) result = await setup.async_process_audio_stream(three_second_stream(), ww_id)
assert result == wake_word.DetectionResult("test_ww", 2048) assert result == wake_word.DetectionResult(expected_ww, 2048)
assert state != setup.state assert state != setup.state
assert state == snapshot assert state == snapshot