Use wake word settings in assist pipeline runs (#100864)

This commit is contained in:
Erik Montnemery 2023-09-25 18:58:10 +02:00 committed by GitHub
parent 11e8bf0b9c
commit d76c5ed351
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 7 deletions

View File

@ -476,7 +476,9 @@ class PipelineRun:
async def prepare_wake_word_detection(self) -> None:
"""Prepare wake-word-detection."""
entity_id = wake_word.async_default_entity(self.hass)
entity_id = self.pipeline.wake_word_entity or wake_word.async_default_entity(
self.hass
)
if entity_id is None:
raise WakeWordDetectionError(
code="wake-engine-missing",
@ -553,7 +555,8 @@ class PipelineRun:
audio_stream=stream,
stt_audio_buffer=stt_audio_buffer,
wake_word_vad=wake_word_vad,
)
),
self.pipeline.wake_word_id,
)
if stt_audio_buffer is not None:

View File

@ -96,7 +96,7 @@ class WakeWordDetectionEntity(RestoreEntity):
"""
async def async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None = None
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.

View File

@ -199,7 +199,7 @@ async def test_not_detected_entity(
# Need 2 seconds to trigger
state = setup.state
result = await setup.async_process_audio_stream(one_second_stream())
result = await setup.async_process_audio_stream(one_second_stream(), None)
assert result is None
# State should only change when there's a detection

View File

@ -54,7 +54,7 @@ async def test_streaming_audio(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
MockAsyncTcpClient(client_events),
):
result = await entity.async_process_audio_stream(audio_stream())
result = await entity.async_process_audio_stream(audio_stream(), None)
assert result is not None
assert result == snapshot
@ -78,7 +78,7 @@ async def test_streaming_audio_connection_lost(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
MockAsyncTcpClient([None]),
):
result = await entity.async_process_audio_stream(audio_stream())
result = await entity.async_process_audio_stream(audio_stream(), None)
assert result is None
@ -103,6 +103,6 @@ async def test_streaming_audio_oserror(
"homeassistant.components.wyoming.wake_word.AsyncTcpClient",
mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
result = await entity.async_process_audio_stream(audio_stream())
result = await entity.async_process_audio_stream(audio_stream(), None)
assert result is None