diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index 6f0e588052a..897f9ed244b 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -3,6 +3,7 @@ from abc import abstractmethod import asyncio from collections.abc import AsyncIterable +import contextlib from enum import StrEnum import logging import time @@ -73,6 +74,7 @@ class AssistSatelliteEntity(entity.Entity): _is_announcing = False _wake_word_intercept_future: asyncio.Future[str | None] | None = None _attr_tts_options: dict[str, Any] | None = None + _pipeline_task: asyncio.Task | None = None __assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD @@ -131,6 +133,8 @@ class AssistSatelliteEntity(entity.Entity): Calls async_announce with message and media id. """ + await self._cancel_running_pipeline() + if message is None: message = "" @@ -176,7 +180,7 @@ class AssistSatelliteEntity(entity.Entity): await self.async_announce(message, media_id) finally: self._is_announcing = False - self.tts_response_finished() + self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) async def async_announce(self, message: str, media_id: str) -> None: """Announce media on the satellite. @@ -193,6 +197,8 @@ class AssistSatelliteEntity(entity.Entity): wake_word_phrase: str | None = None, ) -> None: """Triggers an Assist pipeline in Home Assistant from a satellite.""" + await self._cancel_running_pipeline() + if self._wake_word_intercept_future and start_stage in ( PipelineStage.WAKE_WORD, PipelineStage.STT, @@ -248,31 +254,50 @@ class AssistSatelliteEntity(entity.Entity): # Set entity state based on pipeline events self._run_has_tts = False - await async_pipeline_from_audio_stream( + assert self.platform.config_entry is not None + self._pipeline_task = self.platform.config_entry.async_create_background_task( self.hass, - context=self._context, - event_callback=self._internal_on_pipeline_event, - stt_metadata=stt.SpeechMetadata( - language="", # set in async_pipeline_from_audio_stream - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, + async_pipeline_from_audio_stream( + self.hass, + context=self._context, + event_callback=self._internal_on_pipeline_event, + stt_metadata=stt.SpeechMetadata( + language="", # set in async_pipeline_from_audio_stream + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=audio_stream, + pipeline_id=self._resolve_pipeline(), + conversation_id=self._conversation_id, + device_id=device_id, + tts_audio_output=self.tts_options, + wake_word_phrase=wake_word_phrase, + audio_settings=AudioSettings( + silence_seconds=self._resolve_vad_sensitivity() + ), + start_stage=start_stage, + end_stage=end_stage, ), - stt_stream=audio_stream, - pipeline_id=self._resolve_pipeline(), - conversation_id=self._conversation_id, - device_id=device_id, - tts_audio_output=self.tts_options, - wake_word_phrase=wake_word_phrase, - audio_settings=AudioSettings( - silence_seconds=self._resolve_vad_sensitivity() - ), - start_stage=start_stage, - end_stage=end_stage, + f"{self.entity_id}_pipeline", ) + try: + await self._pipeline_task + finally: + self._pipeline_task = None + + async def _cancel_running_pipeline(self) -> None: + """Cancel the current pipeline if it's running.""" + if self._pipeline_task is not None: + self._pipeline_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._pipeline_task + + self._pipeline_task = None + @abstractmethod def on_pipeline_event(self, event: PipelineEvent) -> None: """Handle pipeline events.""" diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py index a46f754dd4e..3e58239f921 100644 --- a/tests/components/assist_satellite/test_entity.py +++ b/tests/components/assist_satellite/test_entity.py @@ -93,6 +93,55 @@ async def test_entity_state( assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD +async def test_new_pipeline_cancels_pipeline( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, +) -> None: + """Test that a new pipeline run cancels any running pipeline.""" + pipeline1_started = asyncio.Event() + pipeline1_finished = asyncio.Event() + pipeline1_cancelled = asyncio.Event() + pipeline2_finished = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, **kwargs): + if not pipeline1_started.is_set(): + # First pipeline run + pipeline1_started.set() + + # Wait for pipeline to be cancelled + try: + await pipeline1_finished.wait() + except asyncio.CancelledError: + pipeline1_cancelled.set() + raise + else: + # Second pipeline run + pipeline2_finished.set() + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + ): + hass.async_create_task( + entity.async_accept_pipeline_from_satellite( + object(), # type: ignore[arg-type] + ) + ) + + async with asyncio.timeout(1): + await pipeline1_started.wait() + + # Start a second pipeline + await entity.async_accept_pipeline_from_satellite( + object(), # type: ignore[arg-type] + ) + await pipeline1_cancelled.wait() + await pipeline2_finished.wait() + + @pytest.mark.parametrize( ("service_data", "expected_params"), [ @@ -210,6 +259,48 @@ async def test_announce_busy( await announce_task +async def test_announce_cancels_pipeline( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, +) -> None: + """Test that announcements cancel any running pipeline.""" + media_id = "https://www.home-assistant.io/resolved.mp3" + pipeline_started = asyncio.Event() + pipeline_finished = asyncio.Event() + pipeline_cancelled = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, **kwargs): + pipeline_started.set() + + # Wait for pipeline to be cancelled + try: + await pipeline_finished.wait() + except asyncio.CancelledError: + pipeline_cancelled.set() + raise + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch.object(entity, "async_announce") as mock_async_announce, + ): + hass.async_create_task( + entity.async_accept_pipeline_from_satellite( + object(), # type: ignore[arg-type] + ) + ) + + async with asyncio.timeout(1): + await pipeline_started.wait() + await entity.async_internal_announce(None, media_id) + await pipeline_cancelled.wait() + + mock_async_announce.assert_called_once() + + async def test_context_refresh( hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite ) -> None: