From 23b239ba77424e3eee1b4d6e8c917328453bc02c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 25 Sep 2023 09:33:54 -0400 Subject: [PATCH] Allow passing a wake word ID to detect wake word (#100832) * Allow passing a wake word ID to detect wake word * Do not inject default wake words in wake_word integration --- .../components/wake_word/__init__.py | 6 ++--- homeassistant/components/wyoming/wake_word.py | 2 +- tests/components/assist_pipeline/conftest.py | 6 +++-- .../wake_word/snapshots/test_init.ambr | 5 +++- tests/components/wake_word/test_init.py | 25 +++++++++++++++---- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/homeassistant/components/wake_word/__init__.py b/homeassistant/components/wake_word/__init__.py index b308cf98912..c29789a5fc8 100644 --- a/homeassistant/components/wake_word/__init__.py +++ b/homeassistant/components/wake_word/__init__.py @@ -88,7 +88,7 @@ class WakeWordDetectionEntity(RestoreEntity): @abstractmethod async def _async_process_audio_stream( - self, stream: AsyncIterable[tuple[bytes, int]] + self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None ) -> DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps. @@ -96,13 +96,13 @@ class WakeWordDetectionEntity(RestoreEntity): """ async def async_process_audio_stream( - self, stream: AsyncIterable[tuple[bytes, int]] + self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None = None ) -> DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps. Audio must be 16Khz sample rate with 16-bit mono PCM samples. """ - result = await self._async_process_audio_stream(stream) + result = await self._async_process_audio_stream(stream, wake_word_id) if result is not None: # Update last detected only when there is a detection self.__last_detected = dt_util.utcnow().isoformat() diff --git a/homeassistant/components/wyoming/wake_word.py b/homeassistant/components/wyoming/wake_word.py index 0e7fb3c4429..710e3408c5a 100644 --- a/homeassistant/components/wyoming/wake_word.py +++ b/homeassistant/components/wyoming/wake_word.py @@ -58,7 +58,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): return self._supported_wake_words async def _async_process_audio_stream( - self, stream: AsyncIterable[tuple[bytes, int]] + self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None ) -> wake_word.DetectionResult | None: """Try to detect one or more wake words in an audio stream. diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index d2ec3553cf0..0ea92dd42fd 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -187,13 +187,15 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")] async def _async_process_audio_stream( - self, stream: AsyncIterable[tuple[bytes, int]] + self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None ) -> wake_word.DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps.""" + if wake_word_id is None: + wake_word_id = self.supported_wake_words[0].ww_id async for chunk, timestamp in stream: if chunk.startswith(b"wake word"): return wake_word.DetectionResult( - ww_id=self.supported_wake_words[0].ww_id, + ww_id=wake_word_id, timestamp=timestamp, queued_audio=[(b"queued audio", 0)], ) diff --git a/tests/components/wake_word/snapshots/test_init.ambr b/tests/components/wake_word/snapshots/test_init.ambr index cf7c09cd730..60439d1109b 100644 --- a/tests/components/wake_word/snapshots/test_init.ambr +++ b/tests/components/wake_word/snapshots/test_init.ambr @@ -1,5 +1,8 @@ # serializer version: 1 -# name: test_detected_entity +# name: test_detected_entity[None-test_ww] + None +# --- +# name: test_detected_entity[test_ww_2-test_ww_2] None # --- # name: test_ws_detect diff --git a/tests/components/wake_word/test_init.py b/tests/components/wake_word/test_init.py index d37cb3aa540..e54bfc97214 100644 --- a/tests/components/wake_word/test_init.py +++ b/tests/components/wake_word/test_init.py @@ -39,16 +39,22 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity): @property def supported_wake_words(self) -> list[wake_word.WakeWord]: """Return a list of supported wake words.""" - return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")] + return [ + wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word"), + wake_word.WakeWord(ww_id="test_ww_2", name="Test Wake Word 2"), + ] async def _async_process_audio_stream( - self, stream: AsyncIterable[tuple[bytes, int]] + self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None ) -> wake_word.DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps.""" + if wake_word_id is None: + wake_word_id = self.supported_wake_words[0].ww_id + async for _chunk, timestamp in stream: if timestamp >= 2000: return wake_word.DetectionResult( - ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp + ww_id=wake_word_id, timestamp=timestamp ) # Not detected @@ -148,11 +154,20 @@ async def test_config_entry_unload( assert config_entry.state == ConfigEntryState.NOT_LOADED +@pytest.mark.parametrize( + ("ww_id", "expected_ww"), + [ + (None, "test_ww"), + ("test_ww_2", "test_ww_2"), + ], +) async def test_detected_entity( hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity, snapshot: SnapshotAssertion, + ww_id: str | None, + expected_ww: str, ) -> None: """Test successful detection through entity.""" @@ -164,8 +179,8 @@ async def test_detected_entity( # Need 2 seconds to trigger state = setup.state - result = await setup.async_process_audio_stream(three_second_stream()) - assert result == wake_word.DetectionResult("test_ww", 2048) + result = await setup.async_process_audio_stream(three_second_stream(), ww_id) + assert result == wake_word.DetectionResult(expected_ww, 2048) assert state != setup.state assert state == snapshot