mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 00:37:13 +00:00
Allow passing a wake word ID to detect wake word (#100832)
* Allow passing a wake word ID to detect wake word * Do not inject default wake words in wake_word integration
This commit is contained in:
parent
8a44adb447
commit
23b239ba77
@ -88,7 +88,7 @@ class WakeWordDetectionEntity(RestoreEntity):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _async_process_audio_stream(
|
async def _async_process_audio_stream(
|
||||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||||
) -> DetectionResult | None:
|
) -> DetectionResult | None:
|
||||||
"""Try to detect wake word(s) in an audio stream with timestamps.
|
"""Try to detect wake word(s) in an audio stream with timestamps.
|
||||||
|
|
||||||
@ -96,13 +96,13 @@ class WakeWordDetectionEntity(RestoreEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
async def async_process_audio_stream(
|
||||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None = None
|
||||||
) -> DetectionResult | None:
|
) -> DetectionResult | None:
|
||||||
"""Try to detect wake word(s) in an audio stream with timestamps.
|
"""Try to detect wake word(s) in an audio stream with timestamps.
|
||||||
|
|
||||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||||
"""
|
"""
|
||||||
result = await self._async_process_audio_stream(stream)
|
result = await self._async_process_audio_stream(stream, wake_word_id)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
# Update last detected only when there is a detection
|
# Update last detected only when there is a detection
|
||||||
self.__last_detected = dt_util.utcnow().isoformat()
|
self.__last_detected = dt_util.utcnow().isoformat()
|
||||||
|
@ -58,7 +58,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||||||
return self._supported_wake_words
|
return self._supported_wake_words
|
||||||
|
|
||||||
async def _async_process_audio_stream(
|
async def _async_process_audio_stream(
|
||||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||||
) -> wake_word.DetectionResult | None:
|
) -> wake_word.DetectionResult | None:
|
||||||
"""Try to detect one or more wake words in an audio stream.
|
"""Try to detect one or more wake words in an audio stream.
|
||||||
|
|
||||||
|
@ -187,13 +187,15 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||||||
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
||||||
|
|
||||||
async def _async_process_audio_stream(
|
async def _async_process_audio_stream(
|
||||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||||
) -> wake_word.DetectionResult | None:
|
) -> wake_word.DetectionResult | None:
|
||||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||||
|
if wake_word_id is None:
|
||||||
|
wake_word_id = self.supported_wake_words[0].ww_id
|
||||||
async for chunk, timestamp in stream:
|
async for chunk, timestamp in stream:
|
||||||
if chunk.startswith(b"wake word"):
|
if chunk.startswith(b"wake word"):
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
ww_id=self.supported_wake_words[0].ww_id,
|
ww_id=wake_word_id,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
queued_audio=[(b"queued audio", 0)],
|
queued_audio=[(b"queued audio", 0)],
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_detected_entity
|
# name: test_detected_entity[None-test_ww]
|
||||||
|
None
|
||||||
|
# ---
|
||||||
|
# name: test_detected_entity[test_ww_2-test_ww_2]
|
||||||
None
|
None
|
||||||
# ---
|
# ---
|
||||||
# name: test_ws_detect
|
# name: test_ws_detect
|
||||||
|
@ -39,16 +39,22 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
|||||||
@property
|
@property
|
||||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||||
"""Return a list of supported wake words."""
|
"""Return a list of supported wake words."""
|
||||||
return [wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word")]
|
return [
|
||||||
|
wake_word.WakeWord(ww_id="test_ww", name="Test Wake Word"),
|
||||||
|
wake_word.WakeWord(ww_id="test_ww_2", name="Test Wake Word 2"),
|
||||||
|
]
|
||||||
|
|
||||||
async def _async_process_audio_stream(
|
async def _async_process_audio_stream(
|
||||||
self, stream: AsyncIterable[tuple[bytes, int]]
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||||
) -> wake_word.DetectionResult | None:
|
) -> wake_word.DetectionResult | None:
|
||||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||||
|
if wake_word_id is None:
|
||||||
|
wake_word_id = self.supported_wake_words[0].ww_id
|
||||||
|
|
||||||
async for _chunk, timestamp in stream:
|
async for _chunk, timestamp in stream:
|
||||||
if timestamp >= 2000:
|
if timestamp >= 2000:
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
ww_id=self.supported_wake_words[0].ww_id, timestamp=timestamp
|
ww_id=wake_word_id, timestamp=timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
# Not detected
|
# Not detected
|
||||||
@ -148,11 +154,20 @@ async def test_config_entry_unload(
|
|||||||
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ww_id", "expected_ww"),
|
||||||
|
[
|
||||||
|
(None, "test_ww"),
|
||||||
|
("test_ww_2", "test_ww_2"),
|
||||||
|
],
|
||||||
|
)
|
||||||
async def test_detected_entity(
|
async def test_detected_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
setup: MockProviderEntity,
|
setup: MockProviderEntity,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
|
ww_id: str | None,
|
||||||
|
expected_ww: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful detection through entity."""
|
"""Test successful detection through entity."""
|
||||||
|
|
||||||
@ -164,8 +179,8 @@ async def test_detected_entity(
|
|||||||
|
|
||||||
# Need 2 seconds to trigger
|
# Need 2 seconds to trigger
|
||||||
state = setup.state
|
state = setup.state
|
||||||
result = await setup.async_process_audio_stream(three_second_stream())
|
result = await setup.async_process_audio_stream(three_second_stream(), ww_id)
|
||||||
assert result == wake_word.DetectionResult("test_ww", 2048)
|
assert result == wake_word.DetectionResult(expected_ww, 2048)
|
||||||
|
|
||||||
assert state != setup.state
|
assert state != setup.state
|
||||||
assert state == snapshot
|
assert state == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user