From a9bcfe5700026a875175788e196e79b02406a73e Mon Sep 17 00:00:00 2001
From: Erik Montnemery
Date: Tue, 26 Sep 2023 20:24:55 +0200
Subject: [PATCH] Abort wake word detection when assist pipeline is modified
 (#100918)

---
 .../components/assist_pipeline/error.py        |  8 +++
 .../components/assist_pipeline/pipeline.py     | 68 ++++++++++++++++---
 .../assist_pipeline/snapshots/test_init.ambr   | 35 ++++++++++
 tests/components/assist_pipeline/test_init.py  | 64 +++++++++++++++++
 4 files changed, 165 insertions(+), 10 deletions(-)

diff --git a/homeassistant/components/assist_pipeline/error.py b/homeassistant/components/assist_pipeline/error.py
index 094913424b6..209e2611ec0 100644
--- a/homeassistant/components/assist_pipeline/error.py
+++ b/homeassistant/components/assist_pipeline/error.py
@@ -22,6 +22,14 @@ class WakeWordDetectionError(PipelineError):
     """Error in wake-word-detection portion of pipeline."""
 
 
+class WakeWordDetectionAborted(WakeWordDetectionError):
+    """Wake-word-detection was aborted."""
+
+    def __init__(self) -> None:
+        """Set error message."""
+        super().__init__("wake_word_detection_aborted", "")
+
+
 class WakeWordTimeoutError(WakeWordDetectionError):
     """Timeout when wake word was not detected."""
 
diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py
index a66408a01de..7e4c71671ad 100644
--- a/homeassistant/components/assist_pipeline/pipeline.py
+++ b/homeassistant/components/assist_pipeline/pipeline.py
@@ -32,6 +32,7 @@ from homeassistant.components.tts.media_source import (
 from homeassistant.core import Context, HomeAssistant, callback
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers.collection import (
+    CHANGE_UPDATED,
     CollectionError,
     ItemNotFound,
     SerializedStorageCollection,
@@ -54,6 +55,7 @@ from .error import (
     PipelineNotFound,
     SpeechToTextError,
     TextToSpeechError,
+    WakeWordDetectionAborted,
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
@@ -470,11 +472,13 @@ class PipelineRun:
     audio_settings: AudioSettings = field(default_factory=AudioSettings)
 
     id: str = field(default_factory=ulid_util.ulid)
-    stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
-    tts_engine: str = field(init=False)
+    stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
+    tts_engine: str = field(init=False, repr=False)
     tts_options: dict | None = field(init=False, default=None)
-    wake_word_entity_id: str = field(init=False)
-    wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False)
+    wake_word_entity_id: str = field(init=False, repr=False)
+    wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
+
+    abort_wake_word_detection: bool = field(init=False, default=False)
 
     debug_recording_thread: Thread | None = None
     """Thread that records audio to debug_recording_dir"""
@@ -485,7 +489,7 @@ class PipelineRun:
     audio_processor: AudioProcessor | None = None
     """VAD/noise suppression/auto gain"""
 
-    audio_processor_buffer: AudioBuffer = field(init=False)
+    audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
     """Buffer used when splitting audio into chunks for audio processing"""
 
     def __post_init__(self) -> None:
@@ -504,6 +508,7 @@ class PipelineRun:
                 size_limit=STORED_PIPELINE_RUNS
             )
         pipeline_data.pipeline_debug[self.pipeline.id][self.id] = PipelineRunDebug()
+        pipeline_data.pipeline_runs.add_run(self)
 
         # Initialize with audio settings
         self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
@@ -548,6 +553,9 @@ class PipelineRun:
             )
         )
 
+        pipeline_data: PipelineData = self.hass.data[DOMAIN]
+        pipeline_data.pipeline_runs.remove_run(self)
+
     async def prepare_wake_word_detection(self) -> None:
         """Prepare wake-word-detection."""
         entity_id = self.pipeline.wake_word_entity or wake_word.async_default_entity(
@@ -638,6 +646,8 @@ class PipelineRun:
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
                 audio_chunks_for_stt.extend(stt_audio_buffer)
+        except WakeWordDetectionAborted:
+            raise
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
@@ -696,6 +706,9 @@ class PipelineRun:
         """
         chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         async for chunk in audio_stream:
+            if self.abort_wake_word_detection:
+                raise WakeWordDetectionAborted
+
             if self.debug_recording_queue is not None:
                 self.debug_recording_queue.put_nowait(chunk.audio)
 
@@ -1547,13 +1560,48 @@ class PipelineStorageCollectionWebsocket(
         connection.send_result(msg["id"])
 
 
-@dataclass
+class PipelineRuns:
+    """Class managing pipelineruns."""
+
+    def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
+        """Initialize."""
+        self._pipeline_runs: dict[str, list[PipelineRun]] = {}
+        self._pipeline_store = pipeline_store
+        pipeline_store.async_add_listener(self._change_listener)
+
+    def add_run(self, pipeline_run: PipelineRun) -> None:
+        """Add pipeline run."""
+        pipeline_id = pipeline_run.pipeline.id
+        if pipeline_id not in self._pipeline_runs:
+            self._pipeline_runs[pipeline_id] = []
+        self._pipeline_runs[pipeline_id].append(pipeline_run)
+
+    def remove_run(self, pipeline_run: PipelineRun) -> None:
+        """Remove pipeline run."""
+        pipeline_id = pipeline_run.pipeline.id
+        self._pipeline_runs[pipeline_id].remove(pipeline_run)
+
+    async def _change_listener(
+        self, change_type: str, item_id: str, change: dict
+    ) -> None:
+        """Handle pipeline store changes."""
+        if change_type != CHANGE_UPDATED:
+            return
+        if pipeline_runs := self._pipeline_runs.get(item_id):
+            # Create a temporary list in case the list is modified while we iterate
+            for pipeline_run in list(pipeline_runs):
+                pipeline_run.abort_wake_word_detection = True
+
+
 class PipelineData:
     """Store and debug data stored in hass.data."""
 
-    pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
-    pipeline_store: PipelineStorageCollection
-    pipeline_devices: set[str] = field(default_factory=set, init=False)
+    def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
+        """Initialize."""
+        self.pipeline_store = pipeline_store
+        self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
+        self.pipeline_devices: set[str] = set()
+        self.pipeline_runs = PipelineRuns(pipeline_store)
 
 
 @dataclass
@@ -1605,4 +1653,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
         PIPELINE_FIELDS,
         PIPELINE_FIELDS,
     ).async_setup(hass)
-    return PipelineData({}, pipeline_store)
+    return PipelineData(pipeline_store)
diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr
index 2108b84460e..3f0582f2bfb 100644
--- a/tests/components/assist_pipeline/snapshots/test_init.ambr
+++ b/tests/components/assist_pipeline/snapshots/test_init.ambr
@@ -377,3 +377,38 @@
     }),
   ])
 # ---
+# name: test_wake_word_detection_aborted
+  list([
+    dict({
+      'data': dict({
+        'language': 'en',
+        'pipeline': <ANY>,
+      }),
+      'type': <PipelineEventType.RUN_START: 'run-start'>,
+    }),
+    dict({
+      'data': dict({
+        'entity_id': 'wake_word.test',
+        'metadata': dict({
+          'bit_rate': <AudioBitRates.BITRATE_16: 16>,
+          'channel': <AudioChannels.CHANNEL_MONO: 1>,
+          'codec': <AudioCodecs.PCM: 'pcm'>,
+          'format': <AudioFormats.WAV: 'wav'>,
+          'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
+        }),
+      }),
+      'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
+    }),
+    dict({
+      'data': dict({
+        'code': 'wake_word_detection_aborted',
+        'message': '',
+      }),
+      'type': <PipelineEventType.ERROR: 'error'>,
+    }),
+    dict({
+      'data': None,
+      'type': <PipelineEventType.RUN_END: 'run-end'>,
+    }),
+  ])
+# ---
diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py
index b41e23d7a0d..5258736c89f 100644
--- a/tests/components/assist_pipeline/test_init.py
+++ b/tests/components/assist_pipeline/test_init.py
@@ -563,3 +563,67 @@ async def test_pipeline_saved_audio_write_error(
         start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
         end_stage=assist_pipeline.PipelineStage.STT,
     )
+
+
+async def test_wake_word_detection_aborted(
+    hass: HomeAssistant,
+    mock_stt_provider: MockSttProvider,
+    mock_wake_word_provider_entity: MockWakeWordEntity,
+    init_components,
+    pipeline_data: assist_pipeline.pipeline.PipelineData,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test creating a pipeline from an audio stream with wake word."""
+
+    events: list[assist_pipeline.PipelineEvent] = []
+
+    async def audio_data():
+        yield b"silence!"
+        yield b"wake word!"
+        yield b"part1"
+        yield b"part2"
+        yield b""
+
+    pipeline_store = pipeline_data.pipeline_store
+    pipeline_id = pipeline_store.async_get_preferred_item()
+    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
+
+    pipeline_input = assist_pipeline.pipeline.PipelineInput(
+        conversation_id=None,
+        device_id=None,
+        stt_metadata=stt.SpeechMetadata(
+            language="",
+            format=stt.AudioFormats.WAV,
+            codec=stt.AudioCodecs.PCM,
+            bit_rate=stt.AudioBitRates.BITRATE_16,
+            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+            channel=stt.AudioChannels.CHANNEL_MONO,
+        ),
+        stt_stream=audio_data(),
+        run=assist_pipeline.pipeline.PipelineRun(
+            hass,
+            context=Context(),
+            pipeline=pipeline,
+            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+            end_stage=assist_pipeline.PipelineStage.TTS,
+            event_callback=events.append,
+            tts_audio_output=None,
+            wake_word_settings=assist_pipeline.WakeWordSettings(
+                audio_seconds_to_buffer=1.5
+            ),
+            audio_settings=assist_pipeline.AudioSettings(
+                is_vad_enabled=False, is_chunking_enabled=False
+            ),
+        ),
+    )
+    await pipeline_input.validate()
+
+    updates = pipeline.to_json()
+    updates.pop("id")
+    await pipeline_store.async_update_item(
+        pipeline_id,
+        updates,
+    )
+    await pipeline_input.execute()
+
+    assert process_events(events) == snapshot
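
Independent of Home Assistant's internals, the change above is a small observer pattern: each active PipelineRun registers itself in a per-pipeline registry, the registry subscribes to storage-collection updates, and updating a pipeline flips the run's abort_wake_word_detection flag, which the audio chunk loop checks before every chunk and turns into a WakeWordDetectionAborted error. The sketch below is a minimal standalone illustration of that flow; FakeStore, RunRegistry and Run are hypothetical stand-ins, not the integration's real classes.

import asyncio
from collections.abc import AsyncIterable, Awaitable, Callable


class WakeWordDetectionAborted(Exception):
    """Raised when a run is aborted because its pipeline was updated."""


class FakeStore:
    """Stand-in for the storage collection: notifies listeners about updates."""

    def __init__(self) -> None:
        self._listeners: list[Callable[[str, str], Awaitable[None]]] = []

    def async_add_listener(self, listener: Callable[[str, str], Awaitable[None]]) -> None:
        self._listeners.append(listener)

    async def async_update_item(self, item_id: str) -> None:
        for listener in self._listeners:
            await listener("updated", item_id)


class RunRegistry:
    """Stand-in for PipelineRuns: maps a pipeline id to its active runs."""

    def __init__(self, store: FakeStore) -> None:
        self._runs: dict[str, list["Run"]] = {}
        store.async_add_listener(self._on_store_change)

    def add_run(self, run: "Run") -> None:
        self._runs.setdefault(run.pipeline_id, []).append(run)

    def remove_run(self, run: "Run") -> None:
        self._runs[run.pipeline_id].remove(run)

    async def _on_store_change(self, change_type: str, item_id: str) -> None:
        if change_type != "updated":
            return
        # Copy the list: an aborted run may unregister itself while we iterate.
        for run in list(self._runs.get(item_id, [])):
            run.abort_wake_word_detection = True


class Run:
    """Stand-in for PipelineRun: consumes audio chunks until aborted."""

    def __init__(self, pipeline_id: str, registry: RunRegistry) -> None:
        self.pipeline_id = pipeline_id
        self.abort_wake_word_detection = False
        self._registry = registry
        registry.add_run(self)

    async def process(self, audio_stream: AsyncIterable[bytes]) -> None:
        try:
            async for _chunk in audio_stream:
                # Checked before every chunk, mirroring the patched chunk loop.
                if self.abort_wake_word_detection:
                    raise WakeWordDetectionAborted
        finally:
            self._registry.remove_run(self)


async def main() -> None:
    store = FakeStore()
    registry = RunRegistry(store)
    run = Run("pipeline-1", registry)

    async def audio_stream() -> AsyncIterable[bytes]:
        while True:
            yield b"chunk"
            await asyncio.sleep(0.01)

    task = asyncio.create_task(run.process(audio_stream()))
    await asyncio.sleep(0.05)
    await store.async_update_item("pipeline-1")  # "editing" the pipeline aborts the run
    try:
        await task
    except WakeWordDetectionAborted:
        print("wake word detection aborted")


asyncio.run(main())

Copying the run list before iterating in the listener mirrors the comment in PipelineRuns._change_listener: an aborted run may remove itself from the registry while the listener is still looping over it.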