diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 7f6bef6e3c0..a009cfb1095 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -83,6 +83,7 @@ async def async_pipeline_from_audio_stream( event_callback: PipelineEventCallback, stt_metadata: stt.SpeechMetadata, stt_stream: AsyncIterable[bytes], + wake_word_phrase: str | None = None, pipeline_id: str | None = None, conversation_id: str | None = None, tts_audio_output: str | None = None, @@ -101,6 +102,7 @@ async def async_pipeline_from_audio_stream( device_id=device_id, stt_metadata=stt_metadata, stt_stream=stt_stream, + wake_word_phrase=wake_word_phrase, run=PipelineRun( hass, context=context, diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index 091b19db69e..ef1ed1177a6 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -10,6 +10,6 @@ DEFAULT_WAKE_WORD_TIMEOUT = 3 # seconds CONF_DEBUG_RECORDING_DIR = "debug_recording_dir" DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up" -DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds +WAKE_WORD_COOLDOWN = 2 # seconds EVENT_RECORDING = f"{DOMAIN}_recording" diff --git a/homeassistant/components/assist_pipeline/error.py b/homeassistant/components/assist_pipeline/error.py index 209e2611ec0..8b72331817c 100644 --- a/homeassistant/components/assist_pipeline/error.py +++ b/homeassistant/components/assist_pipeline/error.py @@ -38,6 +38,17 @@ class SpeechToTextError(PipelineError): """Error in speech-to-text portion of pipeline.""" +class DuplicateWakeUpDetectedError(WakeWordDetectionError): + """Error when multiple voice assistants wake up at the same time (same wake word).""" + + def __init__(self, wake_up_phrase: str) -> None: + """Set error message.""" + super().__init__( + "duplicate_wake_up_detected", + f"Duplicate wake-up detected for {wake_up_phrase}", + ) + + class IntentRecognitionError(PipelineError): """Error in intent recognition portion of pipeline.""" diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index a98f184094f..bf511f6cff5 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -55,10 +55,11 @@ from .const import ( CONF_DEBUG_RECORDING_DIR, DATA_CONFIG, DATA_LAST_WAKE_UP, - DEFAULT_WAKE_WORD_COOLDOWN, DOMAIN, + WAKE_WORD_COOLDOWN, ) from .error import ( + DuplicateWakeUpDetectedError, IntentRecognitionError, PipelineError, PipelineNotFound, @@ -453,9 +454,6 @@ class WakeWordSettings: audio_seconds_to_buffer: float = 0 """Seconds of audio to buffer before detection and forward to STT.""" - cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN - """Seconds after a wake word detection where other detections are ignored.""" - @dataclass(frozen=True) class AudioSettings: @@ -742,16 +740,22 @@ class PipelineRun: wake_word_output: dict[str, Any] = {} else: # Avoid duplicate detections by checking cooldown - wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}" - last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key) + last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get( + result.wake_word_phrase + ) if last_wake_up is not None: sec_since_last_wake_up = time.monotonic() - last_wake_up - if sec_since_last_wake_up < wake_word_settings.cooldown_seconds: - _LOGGER.debug("Duplicate wake word detection occurred") - raise WakeWordDetectionAborted + if sec_since_last_wake_up < WAKE_WORD_COOLDOWN: + _LOGGER.debug( + "Duplicate wake word detection occurred for %s", + result.wake_word_phrase, + ) + raise DuplicateWakeUpDetectedError(result.wake_word_phrase) # Record last wake up time to block duplicate detections - self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic() + self.hass.data[DATA_LAST_WAKE_UP][ + result.wake_word_phrase + ] = time.monotonic() if result.queued_audio: # Add audio that was pending at detection. @@ -1308,6 +1312,9 @@ class PipelineInput: stt_stream: AsyncIterable[bytes] | None = None """Input audio for stt. Required when start_stage = stt.""" + wake_word_phrase: str | None = None + """Optional key used to de-duplicate wake-ups for local wake word detection.""" + intent_input: str | None = None """Input for conversation agent. Required when start_stage = intent.""" @@ -1352,6 +1359,25 @@ class PipelineInput: assert self.stt_metadata is not None assert stt_processed_stream is not None + if self.wake_word_phrase is not None: + # Avoid duplicate wake-ups by checking cooldown + last_wake_up = self.run.hass.data[DATA_LAST_WAKE_UP].get( + self.wake_word_phrase + ) + if last_wake_up is not None: + sec_since_last_wake_up = time.monotonic() - last_wake_up + if sec_since_last_wake_up < WAKE_WORD_COOLDOWN: + _LOGGER.debug( + "Speech-to-text cancelled to avoid duplicate wake-up for %s", + self.wake_word_phrase, + ) + raise DuplicateWakeUpDetectedError(self.wake_word_phrase) + + # Record last wake up time to block duplicate detections + self.run.hass.data[DATA_LAST_WAKE_UP][ + self.wake_word_phrase + ] = time.monotonic() + stt_input_stream = stt_processed_stream if stt_audio_buffer: diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 6d60426e730..f7a6d3c43fa 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -97,7 +97,12 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: extra=vol.ALLOW_EXTRA, ), PipelineStage.STT: vol.Schema( - {vol.Required("input"): {vol.Required("sample_rate"): int}}, + { + vol.Required("input"): { + vol.Required("sample_rate"): int, + vol.Optional("wake_word_phrase"): str, + } + }, extra=vol.ALLOW_EXTRA, ), PipelineStage.INTENT: vol.Schema( @@ -149,12 +154,15 @@ async def websocket_run( msg_input = msg["input"] audio_queue: asyncio.Queue[bytes] = asyncio.Queue() incoming_sample_rate = msg_input["sample_rate"] + wake_word_phrase: str | None = None if start_stage == PipelineStage.WAKE_WORD: wake_word_settings = WakeWordSettings( timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT), audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0), ) + elif start_stage == PipelineStage.STT: + wake_word_phrase = msg["input"].get("wake_word_phrase") async def stt_stream() -> AsyncGenerator[bytes, None]: state = None @@ -189,6 +197,7 @@ async def websocket_run( channel=stt.AudioChannels.CHANNEL_MONO, ) input_args["stt_stream"] = stt_stream() + input_args["wake_word_phrase"] = wake_word_phrase # Audio settings audio_settings = AudioSettings( diff --git a/homeassistant/components/wake_word/models.py b/homeassistant/components/wake_word/models.py index 8e0699d97d0..c341df188ce 100644 --- a/homeassistant/components/wake_word/models.py +++ b/homeassistant/components/wake_word/models.py @@ -7,7 +7,13 @@ class WakeWord: """Wake word model.""" id: str + """Id of wake word model""" + name: str + """Name of wake word model""" + + phrase: str | None = None + """Wake word phrase used to trigger model""" @dataclass @@ -17,6 +23,9 @@ class DetectionResult: wake_word_id: str """Id of detected wake word""" + wake_word_phrase: str + """Normalized phrase for the detected wake word""" + timestamp: int | None """Timestamp of audio chunk with detected wake word""" diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index 14cf9f77683..830ba5a3435 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -6,6 +6,6 @@ "dependencies": ["assist_pipeline"], "documentation": "https://www.home-assistant.io/integrations/wyoming", "iot_class": "local_push", - "requirements": ["wyoming==1.5.2"], + "requirements": ["wyoming==1.5.3"], "zeroconf": ["_wyoming._tcp.local."] } diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index ea7a7d5df0c..9569c420a1e 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -1,4 +1,5 @@ """Support for Wyoming satellite services.""" + import asyncio from collections.abc import AsyncGenerator import io @@ -10,6 +11,7 @@ from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop from wyoming.client import AsyncTcpClient from wyoming.error import Error +from wyoming.info import Describe, Info from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import PauseSatellite, RunSatellite @@ -86,7 +88,9 @@ class WyomingSatellite: await self._connect_and_loop() except asyncio.CancelledError: raise # don't restart - except Exception: # pylint: disable=broad-exception-caught + except Exception as err: # pylint: disable=broad-exception-caught + _LOGGER.debug("%s: %s", err.__class__.__name__, str(err)) + # Ensure sensor is off (before restart) self.device.set_is_active(False) @@ -197,6 +201,8 @@ class WyomingSatellite: async def _run_pipeline_loop(self) -> None: """Run a pipeline one or more times.""" assert self._client is not None + client_info: Info | None = None + wake_word_phrase: str | None = None run_pipeline: RunPipeline | None = None send_ping = True @@ -209,6 +215,9 @@ class WyomingSatellite: ) pending = {pipeline_ended_task, client_event_task} + # Update info from satellite + await self._client.write_event(Describe().event()) + while self.is_running and (not self.device.is_muted): if send_ping: # Ensure satellite is still connected @@ -230,6 +239,9 @@ class WyomingSatellite: ) pending.add(pipeline_ended_task) + # Clear last wake word detection + wake_word_phrase = None + if (run_pipeline is not None) and run_pipeline.restart_on_end: # Automatically restart pipeline. # Used with "always on" streaming satellites. @@ -253,7 +265,7 @@ class WyomingSatellite: elif RunPipeline.is_type(client_event.type): # Satellite requested pipeline run run_pipeline = RunPipeline.from_event(client_event) - self._run_pipeline_once(run_pipeline) + self._run_pipeline_once(run_pipeline, wake_word_phrase) elif ( AudioChunk.is_type(client_event.type) and self._is_pipeline_running ): @@ -265,6 +277,32 @@ class WyomingSatellite: # Stop pipeline _LOGGER.debug("Client requested pipeline to stop") self._audio_queue.put_nowait(b"") + elif Info.is_type(client_event.type): + client_info = Info.from_event(client_event) + _LOGGER.debug("Updated client info: %s", client_info) + elif Detection.is_type(client_event.type): + detection = Detection.from_event(client_event) + wake_word_phrase = detection.name + + # Resolve wake word name/id to phrase if info is available. + # + # This allows us to deconflict multiple satellite wake-ups + # with the same wake word. + if (client_info is not None) and (client_info.wake is not None): + found_phrase = False + for wake_service in client_info.wake: + for wake_model in wake_service.models: + if wake_model.name == detection.name: + wake_word_phrase = ( + wake_model.phrase or wake_model.name + ) + found_phrase = True + break + + if found_phrase: + break + + _LOGGER.debug("Client detected wake word: %s", wake_word_phrase) else: _LOGGER.debug("Unexpected event from satellite: %s", client_event) @@ -274,7 +312,9 @@ class WyomingSatellite: ) pending.add(client_event_task) - def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None: + def _run_pipeline_once( + self, run_pipeline: RunPipeline, wake_word_phrase: str | None = None + ) -> None: """Run a pipeline once.""" _LOGGER.debug("Received run information: %s", run_pipeline) @@ -332,6 +372,7 @@ class WyomingSatellite: volume_multiplier=self.device.volume_multiplier, ), device_id=self.device.device_id, + wake_word_phrase=wake_word_phrase, ), name="wyoming satellite pipeline", ) diff --git a/homeassistant/components/wyoming/wake_word.py b/homeassistant/components/wyoming/wake_word.py index da05e8c9fe1..303a87e99bd 100644 --- a/homeassistant/components/wyoming/wake_word.py +++ b/homeassistant/components/wyoming/wake_word.py @@ -1,4 +1,5 @@ """Support for Wyoming wake-word-detection services.""" + import asyncio from collections.abc import AsyncIterable import logging @@ -49,7 +50,9 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): wake_service = service.info.wake[0] self._supported_wake_words = [ - wake_word.WakeWord(id=ww.name, name=ww.description or ww.name) + wake_word.WakeWord( + id=ww.name, name=ww.description or ww.name, phrase=ww.phrase + ) for ww in wake_service.models ] self._attr_name = wake_service.name @@ -64,7 +67,11 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): if info is not None: wake_service = info.wake[0] self._supported_wake_words = [ - wake_word.WakeWord(id=ww.name, name=ww.description or ww.name) + wake_word.WakeWord( + id=ww.name, + name=ww.description or ww.name, + phrase=ww.phrase, + ) for ww in wake_service.models ] @@ -140,6 +147,7 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): return wake_word.DetectionResult( wake_word_id=detection.name, + wake_word_phrase=self._get_phrase(detection.name), timestamp=detection.timestamp, queued_audio=queued_audio, ) @@ -183,3 +191,14 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity): _LOGGER.exception("Error processing audio stream: %s", err) return None + + def _get_phrase(self, model_id: str) -> str: + """Get wake word phrase for model id.""" + for ww_model in self._supported_wake_words: + if not ww_model.phrase: + continue + + if ww_model.id == model_id: + return ww_model.phrase + + return model_id diff --git a/requirements_all.txt b/requirements_all.txt index 1b90cb7c11a..88c3950190a 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2863,7 +2863,7 @@ wled==0.17.0 wolf-comm==0.0.4 # homeassistant.components.wyoming -wyoming==1.5.2 +wyoming==1.5.3 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index a9e15197752..ae01d969dd3 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2195,7 +2195,7 @@ wled==0.17.0 wolf-comm==0.0.4 # homeassistant.components.wyoming -wyoming==1.5.2 +wyoming==1.5.3 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 38c96871ed3..0c9d83200b4 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -201,16 +201,19 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity): if self.alternate_detections: detected_id = wake_words[self.detected_wake_word_index].id + detected_name = wake_words[self.detected_wake_word_index].name self.detected_wake_word_index = (self.detected_wake_word_index + 1) % len( wake_words ) else: detected_id = wake_words[0].id + detected_name = wake_words[0].name async for chunk, timestamp in stream: if chunk.startswith(b"wake word"): return wake_word.DetectionResult( wake_word_id=detected_id, + wake_word_phrase=detected_name, timestamp=timestamp, queued_audio=[(b"queued audio", 0)], ) @@ -240,6 +243,7 @@ class MockWakeWordEntity2(wake_word.WakeWordDetectionEntity): if chunk.startswith(b"wake word"): return wake_word.DetectionResult( wake_word_id=wake_words[0].id, + wake_word_phrase=wake_words[0].name, timestamp=timestamp, queued_audio=[(b"queued audio", 0)], ) diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index e822759d208..bbd0c9d333a 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -294,6 +294,7 @@ 'wake_word_output': dict({ 'timestamp': 2000, 'wake_word_id': 'test_ww', + 'wake_word_phrase': 'Test Wake Word', }), }), 'type': , diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index a050b009a8d..10a76bc9344 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -381,6 +381,7 @@ 'wake_word_output': dict({ 'timestamp': 0, 'wake_word_id': 'test_ww', + 'wake_word_phrase': 'Test Wake Word', }), }) # --- @@ -695,6 +696,46 @@ # name: test_pipeline_empty_tts_output.3 None # --- +# name: test_stt_cooldown_different_ids + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_stt_cooldown_different_ids.1 + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_stt_cooldown_same_id + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- +# name: test_stt_cooldown_same_id.1 + dict({ + 'language': 'en', + 'pipeline': , + 'runner_data': dict({ + 'stt_binary_handler_id': 1, + 'timeout': 300, + }), + }) +# --- # name: test_stt_provider_missing dict({ 'language': 'en', @@ -926,15 +967,14 @@ 'wake_word_output': dict({ 'timestamp': 0, 'wake_word_id': 'test_ww', + 'wake_word_phrase': 'Test Wake Word', }), }) # --- # name: test_wake_word_cooldown_different_entities.5 dict({ - 'wake_word_output': dict({ - 'timestamp': 0, - 'wake_word_id': 'test_ww', - }), + 'code': 'duplicate_wake_up_detected', + 'message': 'Duplicate wake-up detected for Test Wake Word', }) # --- # name: test_wake_word_cooldown_different_ids @@ -988,6 +1028,7 @@ 'wake_word_output': dict({ 'timestamp': 0, 'wake_word_id': 'test_ww', + 'wake_word_phrase': 'Test Wake Word', }), }) # --- @@ -996,6 +1037,7 @@ 'wake_word_output': dict({ 'timestamp': 0, 'wake_word_id': 'test_ww_2', + 'wake_word_phrase': 'Test Wake Word 2', }), }) # --- @@ -1045,3 +1087,18 @@ 'timeout': 3, }) # --- +# name: test_wake_word_cooldown_same_id.4 + dict({ + 'wake_word_output': dict({ + 'timestamp': 0, + 'wake_word_id': 'test_ww', + 'wake_word_phrase': 'Test Wake Word', + }), + }) +# --- +# name: test_wake_word_cooldown_same_id.5 + dict({ + 'code': 'duplicate_wake_up_detected', + 'message': 'Duplicate wake-up detected for Test Wake Word', + }) +# --- diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 3ea6be028c1..9138819de12 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -1,6 +1,7 @@ """Websocket tests for Voice Assistant integration.""" import asyncio import base64 +from typing import Any from unittest.mock import ANY, patch from syrupy.assertion import SnapshotAssertion @@ -1887,14 +1888,23 @@ async def test_wake_word_cooldown_same_id( await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") # Get response events + error_data: dict[str, Any] | None = None msg = await client_1.receive_json() event_type_1 = msg["event"]["type"] + assert msg["event"]["data"] == snapshot + if event_type_1 == "error": + error_data = msg["event"]["data"] msg = await client_2.receive_json() event_type_2 = msg["event"]["type"] + assert msg["event"]["data"] == snapshot + if event_type_2 == "error": + error_data = msg["event"]["data"] # One should be a wake up, one should be an error assert {event_type_1, event_type_2} == {"wake_word-end", "error"} + assert error_data is not None + assert error_data["code"] == "duplicate_wake_up_detected" async def test_wake_word_cooldown_different_ids( @@ -1989,7 +1999,7 @@ async def test_wake_word_cooldown_different_entities( hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion, ) -> None: - """Test that duplicate wake word detections are allowed with different entities.""" + """Test that duplicate wake word detections are blocked even with different wake word entities.""" client_pipeline = await hass_ws_client(hass) await client_pipeline.send_json_auto_id( { @@ -2049,7 +2059,7 @@ async def test_wake_word_cooldown_different_entities( } ) - # Use different wake word entity + # Use different wake word entity (but same wake word) await client_2.send_json_auto_id( { "type": "assist_pipeline/run", @@ -2099,18 +2109,23 @@ async def test_wake_word_cooldown_different_entities( await client_2.send_bytes(bytes([handler_id_2]) + b"wake word") # Get response events + error_data: dict[str, Any] | None = None msg = await client_1.receive_json() - assert msg["event"]["type"] == "wake_word-end", msg - ww_id_1 = msg["event"]["data"]["wake_word_output"]["wake_word_id"] + event_type_1 = msg["event"]["type"] assert msg["event"]["data"] == snapshot + if event_type_1 == "error": + error_data = msg["event"]["data"] msg = await client_2.receive_json() - assert msg["event"]["type"] == "wake_word-end", msg - ww_id_2 = msg["event"]["data"]["wake_word_output"]["wake_word_id"] + event_type_2 = msg["event"]["type"] assert msg["event"]["data"] == snapshot + if event_type_2 == "error": + error_data = msg["event"]["data"] - # Wake words should be the same - assert ww_id_1 == ww_id_2 + # One should be a wake up, one should be an error + assert {event_type_1, event_type_2} == {"wake_word-end", "error"} + assert error_data is not None + assert error_data["code"] == "duplicate_wake_up_detected" async def test_device_capture( @@ -2521,3 +2536,138 @@ async def test_pipeline_list_devices( "pipeline_entity": "select.test_assist_device_test_prefix_pipeline", } ] + + +async def test_stt_cooldown_same_id( + hass: HomeAssistant, + init_components, + mock_stt_provider, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test that two speech-to-text pipelines cannot run within the cooldown period if they have the same wake word.""" + client_1 = await hass_ws_client(hass) + client_2 = await hass_ws_client(hass) + + await client_1.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "wake_word_phrase": "ok_nabu", + }, + } + ) + + await client_2.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "wake_word_phrase": "ok_nabu", + }, + } + ) + + # result + msg = await client_1.receive_json() + assert msg["success"], msg + + msg = await client_2.receive_json() + assert msg["success"], msg + + # run start + msg = await client_1.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + + # Get response events + error_data: dict[str, Any] | None = None + msg = await client_1.receive_json() + event_type_1 = msg["event"]["type"] + if event_type_1 == "error": + error_data = msg["event"]["data"] + + msg = await client_2.receive_json() + event_type_2 = msg["event"]["type"] + if event_type_2 == "error": + error_data = msg["event"]["data"] + + # One should be a stt start, one should be an error + assert {event_type_1, event_type_2} == {"stt-start", "error"} + assert error_data is not None + assert error_data["code"] == "duplicate_wake_up_detected" + + +async def test_stt_cooldown_different_ids( + hass: HomeAssistant, + init_components, + mock_stt_provider, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test that two speech-to-text pipelines can run within the cooldown period if they have the different wake words.""" + client_1 = await hass_ws_client(hass) + client_2 = await hass_ws_client(hass) + + await client_1.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "wake_word_phrase": "ok_nabu", + }, + } + ) + + await client_2.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 16000, + "wake_word_phrase": "hey_jarvis", + }, + } + ) + + # result + msg = await client_1.receive_json() + assert msg["success"], msg + + msg = await client_2.receive_json() + assert msg["success"], msg + + # run start + msg = await client_1.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + + msg = await client_2.receive_json() + assert msg["event"]["type"] == "run-start" + msg["event"]["data"]["pipeline"] = ANY + assert msg["event"]["data"] == snapshot + + # Get response events + msg = await client_1.receive_json() + event_type_1 = msg["event"]["type"] + + msg = await client_2.receive_json() + event_type_2 = msg["event"]["type"] + + # Both should start stt + assert {event_type_1, event_type_2} == {"stt-start"} diff --git a/tests/components/wake_word/test_init.py b/tests/components/wake_word/test_init.py index 6b147229d47..0aac011d02a 100644 --- a/tests/components/wake_word/test_init.py +++ b/tests/components/wake_word/test_init.py @@ -1,4 +1,5 @@ """Test wake_word component setup.""" + import asyncio from collections.abc import AsyncIterable, Generator from functools import partial @@ -43,8 +44,12 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity): async def get_supported_wake_words(self) -> list[wake_word.WakeWord]: """Return a list of supported wake words.""" return [ - wake_word.WakeWord(id="test_ww", name="Test Wake Word"), - wake_word.WakeWord(id="test_ww_2", name="Test Wake Word 2"), + wake_word.WakeWord( + id="test_ww", name="Test Wake Word", phrase="Test Phrase" + ), + wake_word.WakeWord( + id="test_ww_2", name="Test Wake Word 2", phrase="Test Phrase 2" + ), ] async def _async_process_audio_stream( @@ -54,10 +59,18 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity): if wake_word_id is None: wake_word_id = (await self.get_supported_wake_words())[0].id + wake_word_phrase = wake_word_id + for ww in await self.get_supported_wake_words(): + if ww.id == wake_word_id: + wake_word_phrase = ww.phrase or ww.name + break + async for _chunk, timestamp in stream: if timestamp >= 2000: return wake_word.DetectionResult( - wake_word_id=wake_word_id, timestamp=timestamp + wake_word_id=wake_word_id, + wake_word_phrase=wake_word_phrase, + timestamp=timestamp, ) # Not detected @@ -159,10 +172,10 @@ async def test_config_entry_unload( @freeze_time("2023-06-22 10:30:00+00:00") @pytest.mark.parametrize( - ("wake_word_id", "expected_ww"), + ("wake_word_id", "expected_ww", "expected_phrase"), [ - (None, "test_ww"), - ("test_ww_2", "test_ww_2"), + (None, "test_ww", "Test Phrase"), + ("test_ww_2", "test_ww_2", "Test Phrase 2"), ], ) async def test_detected_entity( @@ -171,6 +184,7 @@ async def test_detected_entity( setup: MockProviderEntity, wake_word_id: str | None, expected_ww: str, + expected_phrase: str, ) -> None: """Test successful detection through entity.""" @@ -184,7 +198,9 @@ async def test_detected_entity( state = setup.state assert state is None result = await setup.async_process_audio_stream(three_second_stream(), wake_word_id) - assert result == wake_word.DetectionResult(expected_ww, 2048) + assert result == wake_word.DetectionResult( + wake_word_id=expected_ww, wake_word_phrase=expected_phrase, timestamp=2048 + ) assert state != setup.state assert setup.state == "2023-06-22T10:30:00+00:00" @@ -285,8 +301,8 @@ async def test_list_wake_words( assert msg["success"] assert msg["result"] == { "wake_words": [ - {"id": "test_ww", "name": "Test Wake Word"}, - {"id": "test_ww_2", "name": "Test Wake Word 2"}, + {"id": "test_ww", "name": "Test Wake Word", "phrase": "Test Phrase"}, + {"id": "test_ww_2", "name": "Test Wake Word 2", "phrase": "Test Phrase 2"}, ] } @@ -320,9 +336,10 @@ async def test_list_wake_words_timeout( """Test that the list_wake_words websocket command handles unknown entity.""" client = await hass_ws_client(hass) - with patch.object( - setup, "get_supported_wake_words", partial(asyncio.sleep, 1) - ), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0): + with ( + patch.object(setup, "get_supported_wake_words", partial(asyncio.sleep, 1)), + patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0), + ): await client.send_json( { "id": 5, diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 2adc9a21b6f..6b049b04c42 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -75,6 +75,7 @@ WAKE_WORD_INFO = Info( WakeModel( name="Test Model", description="Test Model", + phrase="Test Phrase", installed=True, attribution=TEST_ATTR, languages=["en-US"], diff --git a/tests/components/wyoming/snapshots/test_wake_word.ambr b/tests/components/wyoming/snapshots/test_wake_word.ambr index 41518634a51..27e40854ead 100644 --- a/tests/components/wyoming/snapshots/test_wake_word.ambr +++ b/tests/components/wyoming/snapshots/test_wake_word.ambr @@ -9,5 +9,6 @@ ]), 'timestamp': 0, 'wake_word_id': 'Test Model', + 'wake_word_phrase': 'Test Phrase', }) # --- diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index f568f7b6975..5cbbfd0a8c3 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -1,4 +1,5 @@ """Test Wyoming satellite.""" + from __future__ import annotations import asyncio @@ -12,6 +13,7 @@ from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.error import Error from wyoming.event import Event +from wyoming.info import Info from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite @@ -26,7 +28,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from . import SATELLITE_INFO, MockAsyncTcpClient +from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient from tests.common import MockConfigEntry @@ -207,19 +209,25 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: audio_chunk_received.set() break - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - async_pipeline_from_audio_stream, - ), patch( - "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", - return_value=("wav", get_test_wav()), - ), patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0): + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + return_value=("wav", get_test_wav()), + ), + patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0), + ): entry = await setup_config_entry(hass) device: SatelliteDevice = hass.data[wyoming.DOMAIN][ entry.entry_id @@ -433,14 +441,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None: self.device.set_is_muted(False) on_muted_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming._make_satellite", make_muted_satellite - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted", - on_muted, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch("homeassistant.components.wyoming._make_satellite", make_muted_satellite), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted", + on_muted, + ), ): entry = await setup_config_entry(hass) async with asyncio.timeout(1): @@ -462,16 +472,21 @@ async def test_satellite_restart(hass: HomeAssistant) -> None: self.stop() on_restart_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop", - side_effect=RuntimeError(), - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", - on_restart, - ), patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0): + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop", + side_effect=RuntimeError(), + ), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + on_restart, + ), + patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0), + ): await setup_config_entry(hass) async with asyncio.timeout(1): await on_restart_event.wait() @@ -497,19 +512,25 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: async def on_stopped(self): stopped_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect", - side_effect=ConnectionRefusedError(), - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", - on_reconnect, - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", - on_stopped, - ), patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0): + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect", + side_effect=ConnectionRefusedError(), + ), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", + on_reconnect, + ), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", + on_stopped, + ), + patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0), + ): await setup_config_entry(hass) async with asyncio.timeout(1): await reconnect_event.wait() @@ -524,17 +545,22 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None self.stop() on_restart_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - MockAsyncTcpClient([]), # no RunPipeline event - ), patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - ) as mock_run_pipeline, patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", - on_restart, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + MockAsyncTcpClient([]), # no RunPipeline event + ), + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline, + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + on_restart, + ), ): await setup_config_entry(hass) async with asyncio.timeout(1): @@ -564,20 +590,26 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None async def on_stopped(self): on_stopped_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - MockAsyncTcpClient(events), - ), patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - ) as mock_run_pipeline, patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", - on_restart, - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", - on_stopped, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + MockAsyncTcpClient(events), + ), + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline, + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + on_restart, + ), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", + on_stopped, + ), ): entry = await setup_config_entry(hass) device: SatelliteDevice = hass.data[wyoming.DOMAIN][ @@ -608,16 +640,20 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None: def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None: pipeline_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - wraps=_async_pipeline_from_audio_stream, - ) as mock_run_pipeline: + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + wraps=_async_pipeline_from_audio_stream, + ) as mock_run_pipeline, + ): await setup_config_entry(hass) async with asyncio.timeout(1): @@ -663,21 +699,27 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None: def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None: pipeline_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - wraps=_async_pipeline_from_audio_stream, - ) as mock_run_pipeline, patch( - "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", - return_value=("mp3", bytes(1)), - ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts", - _stream_tts, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + wraps=_async_pipeline_from_audio_stream, + ) as mock_run_pipeline, + patch( + "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + return_value=("mp3", bytes(1)), + ), + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts", + _stream_tts, + ), ): entry = await setup_config_entry(hass) @@ -752,15 +794,19 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None: pipeline_stopped.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - async_pipeline_from_audio_stream, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ), ): entry = await setup_config_entry(hass) device: SatelliteDevice = hass.data[wyoming.DOMAIN][ @@ -822,15 +868,19 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None: pipeline_stopped.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - async_pipeline_from_audio_stream, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ), ): entry = await setup_config_entry(hass) device: SatelliteDevice = hass.data[wyoming.DOMAIN][ @@ -873,7 +923,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None: start_stage_event = asyncio.Event() end_stage_event = asyncio.Event() - def _run_pipeline_once(self, run_pipeline): + def _run_pipeline_once(self, run_pipeline, wake_word_phrase): # Set bad start stage run_pipeline.start_stage = PipelineStage.INTENT run_pipeline.end_stage = PipelineStage.TTS @@ -892,15 +942,19 @@ async def test_invalid_stages(hass: HomeAssistant) -> None: except ValueError: end_stage_event.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once", - _run_pipeline_once, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once", + _run_pipeline_once, + ), ): entry = await setup_config_entry(hass) @@ -950,15 +1004,19 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None: pipeline_stopped.set() - with patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - async_pipeline_from_audio_stream, + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ), ): entry = await setup_config_entry(hass) @@ -982,3 +1040,46 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None: # Stop the satellite await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() + + +async def test_wake_word_phrase(hass: HomeAssistant) -> None: + """Test that wake word phrase from info is given to pipeline.""" + events = [ + # Fake local wake word detection + Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake).event(), + Detection(name="Test Model").event(), + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] + + pipeline_event = asyncio.Event() + + def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None: + pipeline_event.set() + + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ), + patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + wraps=_async_pipeline_from_audio_stream, + ) as mock_run_pipeline, + ): + await setup_config_entry(hass) + + async with asyncio.timeout(1): + await pipeline_event.wait() + + # async_pipeline_from_audio_stream will receive the wake word phrase for + # deconfliction. + mock_run_pipeline.assert_called_once() + assert ( + mock_run_pipeline.call_args.kwargs.get("wake_word_phrase") == "Test Phrase" + ) diff --git a/tests/components/wyoming/test_wake_word.py b/tests/components/wyoming/test_wake_word.py index 1ab869b1b0a..74b8483f7fc 100644 --- a/tests/components/wyoming/test_wake_word.py +++ b/tests/components/wyoming/test_wake_word.py @@ -1,4 +1,5 @@ """Test stt.""" + from __future__ import annotations import asyncio @@ -26,7 +27,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None: assert entity is not None assert (await entity.get_supported_wake_words()) == [ - wake_word.WakeWord(id="Test Model", name="Test Model") + wake_word.WakeWord(id="Test Model", name="Test Model", phrase="Test Phrase") ] @@ -59,6 +60,8 @@ async def test_streaming_audio( assert result is not None assert result == snapshot + assert result.wake_word_id == "Test Model" + assert result.wake_word_phrase == "Test Phrase" async def test_streaming_audio_connection_lost( @@ -100,10 +103,13 @@ async def test_streaming_audio_oserror( [Detection(name="Test Model", timestamp=1000).event()] ) - with patch( - "homeassistant.components.wyoming.wake_word.AsyncTcpClient", - mock_client, - ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): + with ( + patch( + "homeassistant.components.wyoming.wake_word.AsyncTcpClient", + mock_client, + ), + patch.object(mock_client, "read_event", side_effect=OSError("Boom!")), + ): result = await entity.async_process_audio_stream(audio_stream(), None) assert result is None @@ -171,7 +177,7 @@ async def test_dynamic_wake_word_info( # Original info assert (await entity.get_supported_wake_words()) == [ - wake_word.WakeWord("Test Model", "Test Model") + wake_word.WakeWord("Test Model", "Test Model", "Test Phrase") ] new_info = Info( @@ -185,6 +191,7 @@ async def test_dynamic_wake_word_info( WakeModel( name="ww1", description="Wake Word 1", + phrase="Wake Word Phrase 1", installed=True, attribution=TEST_ATTR, languages=[], @@ -193,6 +200,7 @@ async def test_dynamic_wake_word_info( WakeModel( name="ww2", description="Wake Word 2", + phrase="Wake Word Phrase 2", installed=True, attribution=TEST_ATTR, languages=[], @@ -210,6 +218,6 @@ async def test_dynamic_wake_word_info( return_value=new_info, ): assert (await entity.get_supported_wake_words()) == [ - wake_word.WakeWord("ww1", "Wake Word 1"), - wake_word.WakeWord("ww2", "Wake Word 2"), + wake_word.WakeWord("ww1", "Wake Word 1", "Wake Word Phrase 1"), + wake_word.WakeWord("ww2", "Wake Word 2", "Wake Word Phrase 2"), ]