mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 02:07:09 +00:00
No cooldown when wake words have the same id (#101846)
* No cooldown when wake words have the same id * Use wake word entity id in cooldown decision
This commit is contained in:
parent
54bcd70878
commit
a52761171f
@ -9,7 +9,7 @@ from homeassistant.components import stt
|
|||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DOMAIN
|
from .const import CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, DOMAIN
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
AudioSettings,
|
AudioSettings,
|
||||||
@ -58,6 +58,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
"""Set up the Assist pipeline integration."""
|
"""Set up the Assist pipeline integration."""
|
||||||
hass.data[DATA_CONFIG] = config.get(DOMAIN, {})
|
hass.data[DATA_CONFIG] = config.get(DOMAIN, {})
|
||||||
|
|
||||||
|
# wake_word_id -> timestamp of last detection (monotonic_ns)
|
||||||
|
hass.data[DATA_LAST_WAKE_UP] = {}
|
||||||
|
|
||||||
await async_setup_pipeline_store(hass)
|
await async_setup_pipeline_store(hass)
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
|
@ -681,7 +681,8 @@ class PipelineRun:
|
|||||||
wake_word_output: dict[str, Any] = {}
|
wake_word_output: dict[str, Any] = {}
|
||||||
else:
|
else:
|
||||||
# Avoid duplicate detections by checking cooldown
|
# Avoid duplicate detections by checking cooldown
|
||||||
last_wake_up = self.hass.data.get(DATA_LAST_WAKE_UP)
|
wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}"
|
||||||
|
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key)
|
||||||
if last_wake_up is not None:
|
if last_wake_up is not None:
|
||||||
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
sec_since_last_wake_up = time.monotonic() - last_wake_up
|
||||||
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
|
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
|
||||||
@ -689,7 +690,7 @@ class PipelineRun:
|
|||||||
raise WakeWordDetectionAborted
|
raise WakeWordDetectionAborted
|
||||||
|
|
||||||
# Record last wake up time to block duplicate detections
|
# Record last wake up time to block duplicate detections
|
||||||
self.hass.data[DATA_LAST_WAKE_UP] = time.monotonic()
|
self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic()
|
||||||
|
|
||||||
if result.queued_audio:
|
if result.queued_audio:
|
||||||
# Add audio that was pending at detection.
|
# Add audio that was pending at detection.
|
||||||
|
@ -181,6 +181,49 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||||||
url_path = "wake_word.test"
|
url_path = "wake_word.test"
|
||||||
_attr_name = "test"
|
_attr_name = "test"
|
||||||
|
|
||||||
|
alternate_detections = False
|
||||||
|
detected_wake_word_index = 0
|
||||||
|
|
||||||
|
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||||
|
"""Return a list of supported wake words."""
|
||||||
|
return [
|
||||||
|
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
|
||||||
|
wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _async_process_audio_stream(
|
||||||
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||||
|
) -> wake_word.DetectionResult | None:
|
||||||
|
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||||
|
wake_words = await self.get_supported_wake_words()
|
||||||
|
|
||||||
|
if self.alternate_detections:
|
||||||
|
detected_id = wake_words[self.detected_wake_word_index].id
|
||||||
|
self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len(
|
||||||
|
wake_words
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
detected_id = wake_words[0].id
|
||||||
|
|
||||||
|
async for chunk, timestamp in stream:
|
||||||
|
if chunk.startswith(b"wake word"):
|
||||||
|
return wake_word.DetectionResult(
|
||||||
|
wake_word_id=detected_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
queued_audio=[(b"queued audio", 0)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Not detected
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity):
|
||||||
|
"""Second mock wake word entity to test cooldown."""
|
||||||
|
|
||||||
|
fail_process_audio = False
|
||||||
|
url_path = "wake_word.test2"
|
||||||
|
_attr_name = "test2"
|
||||||
|
|
||||||
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||||
"""Return a list of supported wake words."""
|
"""Return a list of supported wake words."""
|
||||||
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
|
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
|
||||||
@ -189,12 +232,12 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||||||
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
|
||||||
) -> wake_word.DetectionResult | None:
|
) -> wake_word.DetectionResult | None:
|
||||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||||
if wake_word_id is None:
|
wake_words = await self.get_supported_wake_words()
|
||||||
wake_word_id = (await self.get_supported_wake_words())[0].id
|
|
||||||
async for chunk, timestamp in stream:
|
async for chunk, timestamp in stream:
|
||||||
if chunk.startswith(b"wake word"):
|
if chunk.startswith(b"wake word"):
|
||||||
return wake_word.DetectionResult(
|
return wake_word.DetectionResult(
|
||||||
wake_word_id=wake_word_id,
|
wake_word_id=wake_words[0].id,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
queued_audio=[(b"queued audio", 0)],
|
queued_audio=[(b"queued audio", 0)],
|
||||||
)
|
)
|
||||||
@ -209,6 +252,12 @@ async def mock_wake_word_provider_entity(hass) -> MockWakeWordEntity:
|
|||||||
return MockWakeWordEntity()
|
return MockWakeWordEntity()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_wake_word_provider_entity2(hass) -> MockWakeWordEntity2:
|
||||||
|
"""Mock wake word provider."""
|
||||||
|
return MockWakeWordEntity2()
|
||||||
|
|
||||||
|
|
||||||
class MockFlow(ConfigFlow):
|
class MockFlow(ConfigFlow):
|
||||||
"""Test flow."""
|
"""Test flow."""
|
||||||
|
|
||||||
@ -229,6 +278,7 @@ async def init_supporting_components(
|
|||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSttProviderEntity,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_provider: MockTTSProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
|
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
||||||
config_flow_fixture,
|
config_flow_fixture,
|
||||||
):
|
):
|
||||||
"""Initialize relevant components with empty configs."""
|
"""Initialize relevant components with empty configs."""
|
||||||
@ -265,7 +315,9 @@ async def init_supporting_components(
|
|||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up test wake word platform via config entry."""
|
"""Set up test wake word platform via config entry."""
|
||||||
async_add_entities([mock_wake_word_provider_entity])
|
async_add_entities(
|
||||||
|
[mock_wake_word_provider_entity, mock_wake_word_provider_entity2]
|
||||||
|
)
|
||||||
|
|
||||||
mock_integration(
|
mock_integration(
|
||||||
hass,
|
hass,
|
||||||
|
@ -717,3 +717,173 @@
|
|||||||
'message': '',
|
'message': '',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_entities
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_entities.1
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_entities.2
|
||||||
|
dict({
|
||||||
|
'entity_id': 'wake_word.test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
'timeout': 3,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_entities.3
|
||||||
|
dict({
|
||||||
|
'entity_id': 'wake_word.test2',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
'timeout': 3,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_entities.4
|
||||||
|
dict({
|
||||||
|
'wake_word_output': dict({
|
||||||
|
'timestamp': 0,
|
||||||
|
'wake_word_id': 'test_ww',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_entities.5
|
||||||
|
dict({
|
||||||
|
'wake_word_output': dict({
|
||||||
|
'timestamp': 0,
|
||||||
|
'wake_word_id': 'test_ww',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_ids
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_ids.1
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_ids.2
|
||||||
|
dict({
|
||||||
|
'entity_id': 'wake_word.test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
'timeout': 3,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_ids.3
|
||||||
|
dict({
|
||||||
|
'entity_id': 'wake_word.test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
'timeout': 3,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_ids.4
|
||||||
|
dict({
|
||||||
|
'wake_word_output': dict({
|
||||||
|
'timestamp': 0,
|
||||||
|
'wake_word_id': 'test_ww',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_different_ids.5
|
||||||
|
dict({
|
||||||
|
'wake_word_output': dict({
|
||||||
|
'timestamp': 0,
|
||||||
|
'wake_word_id': 'test_ww_2',
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_same_id
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_same_id.1
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_same_id.2
|
||||||
|
dict({
|
||||||
|
'entity_id': 'wake_word.test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
'timeout': 3,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_cooldown_same_id.3
|
||||||
|
dict({
|
||||||
|
'entity_id': 'wake_word.test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': 16,
|
||||||
|
'channel': 1,
|
||||||
|
'codec': 'pcm',
|
||||||
|
'format': 'wav',
|
||||||
|
'sample_rate': 16000,
|
||||||
|
}),
|
||||||
|
'timeout': 3,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
@ -9,7 +9,7 @@ from homeassistant.components.assist_pipeline.pipeline import Pipeline, Pipeline
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from .conftest import MockWakeWordEntity
|
from .conftest import MockWakeWordEntity, MockWakeWordEntity2
|
||||||
|
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
@ -1809,14 +1809,14 @@ async def test_audio_pipeline_with_enhancements(
|
|||||||
assert msg["result"] == {"events": events}
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_cooldown(
|
async def test_wake_word_cooldown_same_id(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components,
|
init_components,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that duplicate wake word detections are blocked during the cooldown period."""
|
"""Test that duplicate wake word detections with the same id are blocked during the cooldown period."""
|
||||||
client_1 = await hass_ws_client(hass)
|
client_1 = await hass_ws_client(hass)
|
||||||
client_2 = await hass_ws_client(hass)
|
client_2 = await hass_ws_client(hass)
|
||||||
|
|
||||||
@ -1888,3 +1888,219 @@ async def test_wake_word_cooldown(
|
|||||||
|
|
||||||
# One should be a wake up, one should be an error
|
# One should be a wake up, one should be an error
|
||||||
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
assert {event_type_1, event_type_2} == {"wake_word-end", "error"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wake_word_cooldown_different_ids(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that duplicate wake word detections are allowed with different ids."""
|
||||||
|
with patch.object(mock_wake_word_provider_entity, "alternate_detections", True):
|
||||||
|
client_1 = await hass_ws_client(hass)
|
||||||
|
client_2 = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client_1.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "wake_word",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await client_2.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "wake_word",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# wake_word
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "wake_word-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "wake_word-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Wake both up at the same time, but they will have different wake word ids
|
||||||
|
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
|
||||||
|
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||||
|
|
||||||
|
# Get response events
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
event_type_1 = msg["event"]["type"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
event_type_2 = msg["event"]["type"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Both should wake up now
|
||||||
|
assert {event_type_1, event_type_2} == {"wake_word-end"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wake_word_cooldown_different_entities(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
|
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that duplicate wake word detections are allowed with different entities."""
|
||||||
|
client_pipeline = await hass_ws_client(hass)
|
||||||
|
await client_pipeline.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": "en-US",
|
||||||
|
"language": "en",
|
||||||
|
"name": "pipeline_with_wake_word_1",
|
||||||
|
"stt_engine": "test",
|
||||||
|
"stt_language": "en-US",
|
||||||
|
"tts_engine": "test",
|
||||||
|
"tts_language": "en-US",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
"wake_word_entity": mock_wake_word_provider_entity.entity_id,
|
||||||
|
"wake_word_id": "test_ww",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id_1 = msg["result"]["id"]
|
||||||
|
|
||||||
|
await client_pipeline.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": "en-US",
|
||||||
|
"language": "en",
|
||||||
|
"name": "pipeline_with_wake_word_2",
|
||||||
|
"stt_engine": "test",
|
||||||
|
"stt_language": "en-US",
|
||||||
|
"tts_engine": "test",
|
||||||
|
"tts_language": "en-US",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
"wake_word_entity": mock_wake_word_provider_entity2.entity_id,
|
||||||
|
"wake_word_id": "test_ww",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client_pipeline.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id_2 = msg["result"]["id"]
|
||||||
|
|
||||||
|
# Wake word clients
|
||||||
|
client_1 = await hass_ws_client(hass)
|
||||||
|
client_2 = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client_1.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"pipeline": pipeline_id_1,
|
||||||
|
"start_stage": "wake_word",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use different wake word entity
|
||||||
|
await client_2.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"pipeline": pipeline_id_2,
|
||||||
|
"start_stage": "wake_word",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"no_vad": True,
|
||||||
|
"no_chunking": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["success"], msg
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
handler_id_1 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
handler_id_2 = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# wake_word
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "wake_word-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "wake_word-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Wake both up at the same time.
|
||||||
|
# They will have the same wake word id, but different entities.
|
||||||
|
await client_1.send_bytes(bytes([handler_id_1]) + b"wake word")
|
||||||
|
await client_2.send_bytes(bytes([handler_id_2]) + b"wake word")
|
||||||
|
|
||||||
|
# Get response events
|
||||||
|
msg = await client_1.receive_json()
|
||||||
|
assert msg["event"]["type"] == "wake_word-end", msg
|
||||||
|
ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
msg = await client_2.receive_json()
|
||||||
|
assert msg["event"]["type"] == "wake_word-end", msg
|
||||||
|
ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"]
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
|
||||||
|
# Wake words should be the same
|
||||||
|
assert ww_id_1 == ww_id_2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user