From a52761171f2c4cc12a18fba0cd13fec4deb7bcf1 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 23 Oct 2023 12:12:34 -0500 Subject: [PATCH] No cooldown when wake words have the same id (#101846) * No cooldown when wake words have the same id * Use wake word entity id in cooldown decision --- .../components/assist_pipeline/__init__.py | 5 +- .../components/assist_pipeline/pipeline.py | 5 +- tests/components/assist_pipeline/conftest.py | 60 ++++- .../snapshots/test_websocket.ambr | 170 ++++++++++++++ .../assist_pipeline/test_websocket.py | 222 +++++++++++++++++- 5 files changed, 452 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index fab4c3178bc..64fe9e1f5f4 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -9,7 +9,7 @@ from homeassistant.components import stt from homeassistant.core import Context, HomeAssistant from homeassistant.helpers.typing import ConfigType -from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN +from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN from .error import PipelineNotFound from .pipeline import ( AudioSettings, @@ -58,6 +58,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Assist pipeline integration.""" hass.data[DATA_CONFIG] = config.get(DOMAIN, {}) + # "{wake_word_entity_id}.{wake_word_id}" key -> time.monotonic() timestamp (seconds) of last detection + hass.data[DATA_LAST_WAKE_UP] = {} + await async_setup_pipeline_store(hass) async_register_websocket_api(hass) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 6ec031baf3b..bb34a223af6 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -681,7 +681,8 @@ class PipelineRun: wake_word_output: dict[str, Any] 
= {} else: # Avoid duplicate detections by checking cooldown - last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP) + wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}" + last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key) if last_wake_up is not None: sec_since_last_wake_up = time.monotonic() - last_wake_up if sec_since_last_wake_up < wake_word_settings.cooldown_seconds: @@ -689,7 +690,7 @@ class PipelineRun: raise WakeWordDetectionAborted # Record last wake up time to block duplicate detections - self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic() + self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic() if result.queued_audio: # Add audio that was pending at detection. diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 1a3144ee069..97f80a33d1d 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -181,6 +181,49 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): url_path = "wake_word.test" _attr_name = "test" + alternate_detections = False + detected_wake_word_index = 0 + + async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: + """Return a list of supported wake words.""" + return [ + wake_word.WakeWord(id="test_ww", name="Test Wake Word"), + wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"), + ] + + async def _async_process_audio_stream( + self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None + ) -> wake_word.DetectionResult | None: + """Try to detect wake word(s) in an audio stream with timestamps.""" + wake_words = await self.get_supported_wake_words() + + if self.alternate_detections: + detected_id = wake_words[self.detected_wake_word_index].id + self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len( + wake_words + ) + else: + detected_id = wake_words[0].id + + async for chunk, timestamp in stream: + if chunk.startswith(b"wake 
word"): + return wake_word.DetectionResult( + wake_word_id=detected_id, + timestamp=timestamp, + queued_audio=[(b"queued audio", 0)], + ) + + # Not detected + return None + + +class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity): + """Second mock wake word entity to test cooldown.""" + + fail_process_audio = False + url_path = "wake_word.test2" + _attr_name = "test2" + async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: """Return a list of supported wake words.""" return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")] @@ -189,12 +232,12 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None ) -> wake_word.DetectionResult | None: """Try to detect wake word(s) in an audio stream with timestamps.""" - if wake_word_id is None: - wake_word_id = (await self.get_supported_wake_words())[0].id + wake_words = await self.get_supported_wake_words() + async for chunk, timestamp in stream: if chunk.startswith(b"wake word"): return wake_word.DetectionResult( - wake_word_id=wake_word_id, + wake_word_id=wake_words[0].id, timestamp=timestamp, queued_audio=[(b"queued audio", 0)], ) @@ -209,6 +252,12 @@ async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity: return MockWakeWordEntity() +@pytest.fixture +async def mock_wake_word_provider_entity2(hass) -> MockWakeWordEntity2: + """Mock wake word provider.""" + return MockWakeWordEntity2() + + class MockFlow(ConfigFlow): """Test flow.""" @@ -229,6 +278,7 @@ async def init_supporting_components( mock_stt_provider_entity: MockSttProviderEntity, mock_tts_provider: MockTTSProvider, mock_wake_word_provider_entity: MockWakeWordEntity, + mock_wake_word_provider_entity2: MockWakeWordEntity2, config_flow_fixture, ): """Initialize relevant components with empty configs.""" @@ -265,7 +315,9 @@ async def init_supporting_components( async_add_entities: AddEntitiesCallback, ) -> None: """Set up test wake word platform via 
config entry.""" - async_add_entities([mock_wake_word_provider_entity]) + async_add_entities( + [mock_wake_word_provider_entity, mock_wake_word_provider_entity2] + ) mock_integration( hass, diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index b8c668f3fd0..9eb7e1e5a05 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -717,3 +717,173 @@ 'message': '', }) # --- +# name: test_wake_word_cooldown_different_entities + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown_different_entities.1 + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown_different_entities.2 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown_different_entities.3 + dict({ + 'entity_id': 'wake_word.test2', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown_different_entities.4 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'wake_word_id': 'test_ww', + }), + }) +# --- +# name: test_wake_word_cooldown_different_entities.5 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'wake_word_id': 'test_ww', + }), + }) +# --- +# name: test_wake_word_cooldown_different_ids + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown_different_ids.1 + dict({ + 'language': 'en', + 
'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown_different_ids.2 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown_different_ids.3 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown_different_ids.4 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'wake_word_id': 'test_ww', + }), + }) +# --- +# name: test_wake_word_cooldown_different_ids.5 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'wake_word_id': 'test_ww_2', + }), + }) +# --- +# name: test_wake_word_cooldown_same_id + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown_same_id.1 + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_wake_word_cooldown_same_id.2 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- +# name: test_wake_word_cooldown_same_id.3 + dict({ + 'entity_id': 'wake_word.test', + 'metadata': dict({ + 'bit_rate': 16, + 'channel': 1, + 'codec': 'pcm', + 'format': 'wav', + 'sample_rate': 16000, + }), + 'timeout': 3, + }) +# --- diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 28b31e5b19c..9a4e78a29af 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -9,7 +9,7 
@@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from .conftest import MockWakeWordEntity +from .conftest import MockWakeWordEntity, MockWakeWordEntity2 from tests.typing import WebSocketGenerator @@ -1809,14 +1809,14 @@ async def test_audio_pipeline_with_enhancements( assert msg["result"] == {"events": events} -async def test_wake_word_cooldown( +async def test_wake_word_cooldown_same_id( hass: HomeAssistant, init_components, mock_wake_word_provider_entity: MockWakeWordEntity, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion, ) -> None: - """Test that duplicate wake word detections are blocked during the cooldown period.""" + """Test that duplicate wake word detections with the same id are blocked during the cooldown period.""" client_1 = await hass_ws_client(hass) client_2 = await hass_ws_client(hass) @@ -1888,3 +1888,219 @@ async def test_wake_word_cooldown( # One should be a wake up, one should be an error assert {event_type_1, event_type_2} == {"wake_word-end", "error"} + + +async def test_wake_word_cooldown_different_ids( + hass: HomeAssistant, + init_components, + mock_wake_word_provider_entity: MockWakeWordEntity, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test that duplicate wake word detections are allowed with different ids.""" + with patch.object(mock_wake_word_provider_entity, "alternate_detections", True): + client_1 = await hass_ws_client(hass) + client_2 = await hass_ws_client(hass) + + await client_1.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + } + ) + + await client_2.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, 
+ "no_vad": True, + "no_chunking": True, + }, + } + ) + + # result + msg = await client_1.receive_json() + assert msg["success"], msg + + msg = await client_2.receive_json() + assert msg["success"], msg + + # run start + msg = await client_1.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + assert msg["event"]["data"] == snapshot + + # wake_word + msg = await client_1.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + + # Wake both up at the same time, but they will have different wake word ids + await client_1.send_bytes(bytes([handler_id_1]) + b"wake word") + await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") + + # Get response events + msg = await client_1.receive_json() + event_type_1 = msg["event"]["type"] + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + event_type_2 = msg["event"]["type"] + assert msg["event"]["data"] == snapshot + + # Both should wake up now + assert {event_type_1, event_type_2} == {"wake_word-end"} + + +async def test_wake_word_cooldown_different_entities( + hass: HomeAssistant, + init_components, + mock_wake_word_provider_entity: MockWakeWordEntity, + mock_wake_word_provider_entity2: MockWakeWordEntity2, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test that duplicate wake word detections are allowed with different entities.""" + client_pipeline = await hass_ws_client(hass) + await 
client_pipeline.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": "en-US", + "language": "en", + "name": "pipeline_with_wake_word_1", + "stt_engine": "test", + "stt_language": "en-US", + "tts_engine": "test", + "tts_language": "en-US", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": mock_wake_word_provider_entity.entity_id, + "wake_word_id": "test_ww", + } + ) + msg = await client_pipeline.receive_json() + assert msg["success"] + pipeline_id_1 = msg["result"]["id"] + + await client_pipeline.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": "en-US", + "language": "en", + "name": "pipeline_with_wake_word_2", + "stt_engine": "test", + "stt_language": "en-US", + "tts_engine": "test", + "tts_language": "en-US", + "tts_voice": "Arnold Schwarzenegger", + "wake_word_entity": mock_wake_word_provider_entity2.entity_id, + "wake_word_id": "test_ww", + } + ) + msg = await client_pipeline.receive_json() + assert msg["success"] + pipeline_id_2 = msg["result"]["id"] + + # Wake word clients + client_1 = await hass_ws_client(hass) + client_2 = await hass_ws_client(hass) + + await client_1.send_json_auto_id( + { + "type": "assist_pipeline/run", + "pipeline": pipeline_id_1, + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + } + ) + + # Use different wake word entity + await client_2.send_json_auto_id( + { + "type": "assist_pipeline/run", + "pipeline": pipeline_id_2, + "start_stage": "wake_word", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "no_vad": True, + "no_chunking": True, + }, + } + ) + + # result + msg = await client_1.receive_json() + assert msg["success"], msg + + msg = await client_2.receive_json() + assert msg["success"], msg + + # run start + msg = await client_1.receive_json() + assert 
msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] + assert msg["event"]["data"] == snapshot + + # wake_word + msg = await client_1.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "wake_word-start" + assert msg["event"]["data"] == snapshot + + # Wake both up at the same time. + # They will have the same wake word id, but different entities. + await client_1.send_bytes(bytes([handler_id_1]) + b"wake word") + await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") + + # Get response events + msg = await client_1.receive_json() + assert msg["event"]["type"] == "wake_word-end", msg + ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"] + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "wake_word-end", msg + ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"] + assert msg["event"]["data"] == snapshot + + # Wake words should be the same + assert ww_id_1 == ww_id_2