Add wake word cooldown to avoid duplicate wake-ups (#101417)

This commit is contained in:
Michael Hansen 2023-10-06 02:18:35 -05:00 committed by GitHub
parent 48a23798d0
commit 244f6d8002
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 198 additions and 18 deletions

View File

@ -9,7 +9,7 @@ from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import DATA_CONFIG, DOMAIN from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
from .error import PipelineNotFound from .error import PipelineNotFound
from .pipeline import ( from .pipeline import (
AudioSettings, AudioSettings,
@ -45,7 +45,9 @@ __all__ = (
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.Schema( DOMAIN: vol.Schema(
{vol.Optional("debug_recording_dir"): str}, {
vol.Optional(CONF_DEBUG_RECORDING_DIR): str,
},
) )
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,

View File

@ -2,3 +2,12 @@
DOMAIN = "assist_pipeline" DOMAIN = "assist_pipeline"
DATA_CONFIG = f"{DOMAIN}.config" DATA_CONFIG = f"{DOMAIN}.config"
DEFAULT_PIPELINE_TIMEOUT = 60 * 5 # seconds
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"
DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds

View File

@ -48,7 +48,13 @@ from homeassistant.util import (
) )
from homeassistant.util.limited_size_dict import LimitedSizeDict from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DATA_CONFIG, DOMAIN from .const import (
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DEFAULT_WAKE_WORD_COOLDOWN,
DOMAIN,
)
from .error import ( from .error import (
IntentRecognitionError, IntentRecognitionError,
PipelineError, PipelineError,
@ -399,6 +405,9 @@ class WakeWordSettings:
audio_seconds_to_buffer: float = 0 audio_seconds_to_buffer: float = 0
"""Seconds of audio to buffer before detection and forward to STT.""" """Seconds of audio to buffer before detection and forward to STT."""
cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
"""Seconds after a wake word detection where other detections are ignored."""
@dataclass(frozen=True) @dataclass(frozen=True)
class AudioSettings: class AudioSettings:
@ -603,6 +612,8 @@ class PipelineRun:
) )
) )
wake_word_settings = self.wake_word_settings or WakeWordSettings()
# Remove language since it doesn't apply to wake words yet # Remove language since it doesn't apply to wake words yet
metadata_dict.pop("language", None) metadata_dict.pop("language", None)
@ -612,6 +623,7 @@ class PipelineRun:
{ {
"entity_id": self.wake_word_entity_id, "entity_id": self.wake_word_entity_id,
"metadata": metadata_dict, "metadata": metadata_dict,
"timeout": wake_word_settings.timeout or 0,
}, },
) )
) )
@ -619,8 +631,6 @@ class PipelineRun:
if self.debug_recording_queue is not None: if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}") self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
wake_word_settings = self.wake_word_settings or WakeWordSettings()
wake_word_vad: VoiceActivityTimeout | None = None wake_word_vad: VoiceActivityTimeout | None = None
if (wake_word_settings.timeout is not None) and ( if (wake_word_settings.timeout is not None) and (
wake_word_settings.timeout > 0 wake_word_settings.timeout > 0
@ -670,6 +680,17 @@ class PipelineRun:
if result is None: if result is None:
wake_word_output: dict[str, Any] = {} wake_word_output: dict[str, Any] = {}
else: else:
# Avoid duplicate detections by checking cooldown
last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
_LOGGER.debug("Duplicate wake word detection occurred")
raise WakeWordDetectionAborted
# Record last wake up time to block duplicate detections
self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic()
if result.queued_audio: if result.queued_audio:
# Add audio that was pending at detection. # Add audio that was pending at detection.
# #
@ -1032,7 +1053,7 @@ class PipelineRun:
# Directory to save audio for each pipeline run. # Directory to save audio for each pipeline run.
# Configured in YAML for assist_pipeline. # Configured in YAML for assist_pipeline.
if debug_recording_dir := self.hass.data[DATA_CONFIG].get( if debug_recording_dir := self.hass.data[DATA_CONFIG].get(
"debug_recording_dir" CONF_DEBUG_RECORDING_DIR
): ):
if device_id is None: if device_id is None:
# <debug_recording_dir>/<pipeline.name>/<run.id> # <debug_recording_dir>/<pipeline.name>/<run.id>

View File

@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util from homeassistant.util import language as language_util
from .const import DOMAIN from .const import DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, DOMAIN
from .error import PipelineNotFound from .error import PipelineNotFound
from .pipeline import ( from .pipeline import (
AudioSettings, AudioSettings,
@ -30,9 +30,6 @@ from .pipeline import (
async_get_pipeline, async_get_pipeline,
) )
DEFAULT_TIMEOUT = 60 * 5 # seconds
DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -117,7 +114,7 @@ async def websocket_run(
) )
return return
timeout = msg.get("timeout", DEFAULT_TIMEOUT) timeout = msg.get("timeout", DEFAULT_PIPELINE_TIMEOUT)
start_stage = PipelineStage(msg["start_stage"]) start_stage = PipelineStage(msg["start_stage"])
end_stage = PipelineStage(msg["end_stage"]) end_stage = PipelineStage(msg["end_stage"])
handler_id: int | None = None handler_id: int | None = None

View File

@ -285,6 +285,7 @@
'format': <AudioFormats.WAV: 'wav'>, 'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>, 'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}), }),
'timeout': 0,
}), }),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>, 'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}), }),
@ -396,6 +397,7 @@
'format': <AudioFormats.WAV: 'wav'>, 'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>, 'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}), }),
'timeout': 0,
}), }),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>, 'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}), }),

View File

@ -373,6 +373,7 @@
'format': 'wav', 'format': 'wav',
'sample_rate': 16000, 'sample_rate': 16000,
}), }),
'timeout': 0,
}) })
# --- # ---
# name: test_audio_pipeline_with_wake_word_no_timeout.2 # name: test_audio_pipeline_with_wake_word_no_timeout.2
@ -474,6 +475,7 @@
'format': 'wav', 'format': 'wav',
'sample_rate': 16000, 'sample_rate': 16000,
}), }),
'timeout': 1,
}) })
# --- # ---
# name: test_audio_pipeline_with_wake_word_timeout.2 # name: test_audio_pipeline_with_wake_word_timeout.2
@ -655,3 +657,63 @@
# name: test_tts_failed.2 # name: test_tts_failed.2
None None
# --- # ---
# name: test_wake_word_cooldown
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown.2
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown.3
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown.4
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww',
}),
})
# ---
# name: test_wake_word_cooldown.5
dict({
'code': 'wake_word_detection_aborted',
'message': '',
})
# ---

View File

@ -10,6 +10,10 @@ import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt from homeassistant.components import assist_pipeline, stt
from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -395,8 +399,8 @@ async def test_pipeline_save_audio(
temp_dir = Path(temp_dir_str) temp_dir = Path(temp_dir_str)
assert await async_setup_component( assert await async_setup_component(
hass, hass,
"assist_pipeline", DOMAIN,
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
) )
pipeline = assist_pipeline.async_get_pipeline(hass) pipeline = assist_pipeline.async_get_pipeline(hass)
@ -476,8 +480,8 @@ async def test_pipeline_saved_audio_with_device_id(
temp_dir = Path(temp_dir_str) temp_dir = Path(temp_dir_str)
assert await async_setup_component( assert await async_setup_component(
hass, hass,
"assist_pipeline", DOMAIN,
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
) )
def event_callback(event: assist_pipeline.PipelineEvent): def event_callback(event: assist_pipeline.PipelineEvent):
@ -529,8 +533,8 @@ async def test_pipeline_saved_audio_write_error(
temp_dir = Path(temp_dir_str) temp_dir = Path(temp_dir_str)
assert await async_setup_component( assert await async_setup_component(
hass, hass,
"assist_pipeline", DOMAIN,
{"assist_pipeline": {"debug_recording_dir": temp_dir_str}}, {DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
) )
def event_callback(event: assist_pipeline.PipelineEvent): def event_callback(event: assist_pipeline.PipelineEvent):

View File

@ -9,6 +9,8 @@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .conftest import MockWakeWordEntity
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -266,7 +268,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
events.append(msg["event"]) events.append(msg["event"])
# "audio" # "audio"
await client.send_bytes(bytes([1]) + b"wake word") await client.send_bytes(bytes([handler_id]) + b"wake word")
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-end" assert msg["event"]["type"] == "wake_word-end"
@ -1805,3 +1807,84 @@ async def test_audio_pipeline_with_enhancements(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_wake_word_cooldown(
    hass: HomeAssistant,
    init_components,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    hass_ws_client: WebSocketGenerator,
    snapshot: SnapshotAssertion,
) -> None:
    """Check that a second wake word detection inside the cooldown is rejected.

    Two websocket clients each start a pipeline run at the wake word stage and
    are then sent wake word audio one right after the other. Between them, the
    two runs must produce one ``wake_word-end`` event and one ``error`` event.
    """
    # Order matters throughout: snapshot comparisons are matched by call order,
    # so the first client is always handled before the second.
    clients = (await hass_ws_client(hass), await hass_ws_client(hass))

    # Start an identical run on each connection (fresh payload per client)
    for ws in clients:
        await ws.send_json_auto_id(
            {
                "type": "assist_pipeline/run",
                "start_stage": "wake_word",
                "end_stage": "tts",
                "input": {
                    "sample_rate": 16000,
                    "no_vad": True,
                    "no_chunking": True,
                },
            }
        )

    # result
    for ws in clients:
        reply = await ws.receive_json()
        assert reply["success"], reply

    # run start
    handler_ids = []
    for ws in clients:
        reply = await ws.receive_json()
        assert reply["event"]["type"] == "run-start"
        run_data = reply["event"]["data"]
        # The pipeline id differs per run, so it is masked before comparing
        run_data["pipeline"] = ANY
        handler_ids.append(run_data["runner_data"]["stt_binary_handler_id"])
        assert run_data == snapshot

    # wake_word
    for ws in clients:
        reply = await ws.receive_json()
        assert reply["event"]["type"] == "wake_word-start"
        assert reply["event"]["data"] == snapshot

    # Wake both up at the same time
    for ws, handler_id in zip(clients, handler_ids):
        await ws.send_bytes(bytes([handler_id]) + b"wake word")

    # Get response events
    event_types = set()
    for ws in clients:
        reply = await ws.receive_json()
        event_types.add(reply["event"]["type"])

    # One should be a wake up, one should be an error
    assert event_types == {"wake_word-end", "error"}