No cooldown when wake words have the same id (#101846)

* No cooldown when wake words have the same id

* Use wake word entity id in cooldown decision
This commit is contained in:
Michael Hansen 2023-10-23 12:12:34 -05:00 committed by GitHub
parent 54bcd70878
commit a52761171f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 452 additions and 10 deletions

View File

@ -9,7 +9,7 @@ from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN
from .error import PipelineNotFound from .error import PipelineNotFound
from .pipeline import ( from .pipeline import (
AudioSettings, AudioSettings,
@ -58,6 +58,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Assist pipeline integration.""" """Set up the Assist pipeline integration."""
hass.data[DATA_CONFIG] = config.get(DOMAIN, {}) hass.data[DATA_CONFIG] = config.get(DOMAIN, {})
# wake_word_id -> timestamp of last detection (monotonic_ns)
hass.data[DATA_LAST_WAKE_UP] = {}
await async_setup_pipeline_store(hass) await async_setup_pipeline_store(hass)
async_register_websocket_api(hass) async_register_websocket_api(hass)

View File

@ -681,7 +681,8 @@ class PipelineRun:
wake_word_output: dict[str, Any] = {} wake_word_output: dict[str, Any] = {}
else: else:
# Avoid duplicate detections by checking cooldown # Avoid duplicate detections by checking cooldown
last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP) wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}"
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key)
if last_wake_up is not None: if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds: if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
@ -689,7 +690,7 @@ class PipelineRun:
raise WakeWordDetectionAborted raise WakeWordDetectionAborted
# Record last wake up time to block duplicate detections # Record last wake up time to block duplicate detections
self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic() self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic()
if result.queued_audio: if result.queued_audio:
# Add audio that was pending at detection. # Add audio that was pending at detection.

View File

@ -181,6 +181,49 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
url_path = "wake_word.test" url_path = "wake_word.test"
_attr_name = "test" _attr_name = "test"
alternate_detections = False
detected_wake_word_index = 0
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
return [
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"),
]
async def _async_process_audio_stream(
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps."""
wake_words = await self.get_supported_wake_words()
if self.alternate_detections:
detected_id = wake_words[self.detected_wake_word_index].id
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
wake_words
)
else:
detected_id = wake_words[0].id
async for chunk, timestamp in stream:
if chunk.startswith(b"wake word"):
return wake_word.DetectionResult(
wake_word_id=detected_id,
timestamp=timestamp,
queued_audio=[(b"queued audio", 0)],
)
# Not detected
return None
class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity):
"""Second mock wake word entity to test cooldown."""
fail_process_audio = False
url_path = "wake_word.test2"
_attr_name = "test2"
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")] return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
@ -189,12 +232,12 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> wake_word.DetectionResult | None: ) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.""" """Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None: wake_words = await self.get_supported_wake_words()
wake_word_id = (await self.get_supported_wake_words())[0].id
async for chunk, timestamp in stream: async for chunk, timestamp in stream:
if chunk.startswith(b"wake word"): if chunk.startswith(b"wake word"):
return wake_word.DetectionResult( return wake_word.DetectionResult(
wake_word_id=wake_word_id, wake_word_id=wake_words[0].id,
timestamp=timestamp, timestamp=timestamp,
queued_audio=[(b"queued audio", 0)], queued_audio=[(b"queued audio", 0)],
) )
@ -209,6 +252,12 @@ async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity:
return MockWakeWordEntity() return MockWakeWordEntity()
@pytest.fixture
async def mock_wake_word_provider_entity2(hass) -> MockWakeWordEntity2:
"""Mock wake word provider."""
return MockWakeWordEntity2()
class MockFlow(ConfigFlow): class MockFlow(ConfigFlow):
"""Test flow.""" """Test flow."""
@ -229,6 +278,7 @@ async def init_supporting_components(
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSttProviderEntity,
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
mock_wake_word_provider_entity2: MockWakeWordEntity2,
config_flow_fixture, config_flow_fixture,
): ):
"""Initialize relevant components with empty configs.""" """Initialize relevant components with empty configs."""
@ -265,7 +315,9 @@ async def init_supporting_components(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up test wake word platform via config entry.""" """Set up test wake word platform via config entry."""
async_add_entities([mock_wake_word_provider_entity]) async_add_entities(
[mock_wake_word_provider_entity, mock_wake_word_provider_entity2]
)
mock_integration( mock_integration(
hass, hass,

View File

@ -717,3 +717,173 @@
'message': '', 'message': '',
}) })
# --- # ---
# name: test_wake_word_cooldown_different_entities
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown_different_entities.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown_different_entities.2
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown_different_entities.3
dict({
'entity_id': 'wake_word.test2',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown_different_entities.4
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww',
}),
})
# ---
# name: test_wake_word_cooldown_different_entities.5
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww',
}),
})
# ---
# name: test_wake_word_cooldown_different_ids
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown_different_ids.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown_different_ids.2
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown_different_ids.3
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown_different_ids.4
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww',
}),
})
# ---
# name: test_wake_word_cooldown_different_ids.5
dict({
'wake_word_output': dict({
'timestamp': 0,
'wake_word_id': 'test_ww_2',
}),
})
# ---
# name: test_wake_word_cooldown_same_id
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown_same_id.1
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 300,
}),
})
# ---
# name: test_wake_word_cooldown_same_id.2
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---
# name: test_wake_word_cooldown_same_id.3
dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'sample_rate': 16000,
}),
'timeout': 3,
})
# ---

View File

@ -9,7 +9,7 @@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .conftest import MockWakeWordEntity from .conftest import MockWakeWordEntity, MockWakeWordEntity2
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -1809,14 +1809,14 @@ async def test_audio_pipeline_with_enhancements(
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_wake_word_cooldown( async def test_wake_word_cooldown_same_id(
hass: HomeAssistant, hass: HomeAssistant,
init_components, init_components,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that duplicate wake word detections are blocked during the cooldown period.""" """Test that duplicate wake word detections with the same id are blocked during the cooldown period."""
client_1 = await hass_ws_client(hass) client_1 = await hass_ws_client(hass)
client_2 = await hass_ws_client(hass) client_2 = await hass_ws_client(hass)
@ -1888,3 +1888,219 @@ async def test_wake_word_cooldown(
# One should be a wake up, one should be an error # One should be a wake up, one should be an error
assert {event_type_1, event_type_2} == {"wake_word-end", "error"} assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
async def test_wake_word_cooldown_different_ids(
hass: HomeAssistant,
init_components,
mock_wake_word_provider_entity: MockWakeWordEntity,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test that duplicate wake word detections are allowed with different ids."""
with patch.object(mock_wake_word_provider_entity, "alternate_detections", True):
client_1 = await hass_ws_client(hass)
client_2 = await hass_ws_client(hass)
await client_1.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
}
)
await client_2.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
}
)
# result
msg = await client_1.receive_json()
assert msg["success"], msg
msg = await client_2.receive_json()
assert msg["success"], msg
# run start
msg = await client_1.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
assert msg["event"]["data"] == snapshot
# wake_word
msg = await client_1.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
# Wake both up at the same time, but they will have different wake word ids
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
# Get response events
msg = await client_1.receive_json()
event_type_1 = msg["event"]["type"]
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
event_type_2 = msg["event"]["type"]
assert msg["event"]["data"] == snapshot
# Both should wake up now
assert {event_type_1, event_type_2} == {"wake_word-end"}
async def test_wake_word_cooldown_different_entities(
hass: HomeAssistant,
init_components,
mock_wake_word_provider_entity: MockWakeWordEntity,
mock_wake_word_provider_entity2: MockWakeWordEntity2,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test that duplicate wake word detections are allowed with different entities."""
client_pipeline = await hass_ws_client(hass)
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "en-US",
"language": "en",
"name": "pipeline_with_wake_word_1",
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": mock_wake_word_provider_entity.entity_id,
"wake_word_id": "test_ww",
}
)
msg = await client_pipeline.receive_json()
assert msg["success"]
pipeline_id_1 = msg["result"]["id"]
await client_pipeline.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "en-US",
"language": "en",
"name": "pipeline_with_wake_word_2",
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": mock_wake_word_provider_entity2.entity_id,
"wake_word_id": "test_ww",
}
)
msg = await client_pipeline.receive_json()
assert msg["success"]
pipeline_id_2 = msg["result"]["id"]
# Wake word clients
client_1 = await hass_ws_client(hass)
client_2 = await hass_ws_client(hass)
await client_1.send_json_auto_id(
{
"type": "assist_pipeline/run",
"pipeline": pipeline_id_1,
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
}
)
# Use different wake word entity
await client_2.send_json_auto_id(
{
"type": "assist_pipeline/run",
"pipeline": pipeline_id_2,
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
}
)
# result
msg = await client_1.receive_json()
assert msg["success"], msg
msg = await client_2.receive_json()
assert msg["success"], msg
# run start
msg = await client_1.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
assert msg["event"]["data"] == snapshot
# wake_word
msg = await client_1.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
# Wake both up at the same time.
# They will have the same wake word id, but different entities.
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
# Get response events
msg = await client_1.receive_json()
assert msg["event"]["type"] == "wake_word-end", msg
ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
assert msg["event"]["data"] == snapshot
msg = await client_2.receive_json()
assert msg["event"]["type"] == "wake_word-end", msg
ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
assert msg["event"]["data"] == snapshot
# Wake words should be the same
assert ww_id_1 == ww_id_2