From 28edbdc107818e5872fa060c8c815520902eaa0d Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 3 Feb 2025 11:07:45 -0600 Subject: [PATCH] Clear extra system prompt on start_conversation error (#137254) * Clear extra system prompt on start_conversation error * Update homeassistant/components/assist_satellite/entity.py Co-authored-by: Paulus Schoutsen --------- Co-authored-by: Paulus Schoutsen --- .../components/assist_satellite/entity.py | 5 ++ tests/components/voip/test_voip.py | 87 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index c901bc7d928..902cf731a5d 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -274,6 +274,11 @@ class AssistSatelliteEntity(entity.Entity): try: await self.async_start_conversation(announcement) + except Exception: + # Clear prompt on error + self._conversation_id = None + self._extra_system_prompt = None + raise finally: self._is_announcing = False diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py index 442f4a62392..3e3e5337417 100644 --- a/tests/components/voip/test_voip.py +++ b/tests/components/voip/test_voip.py @@ -1084,3 +1084,90 @@ async def test_start_conversation( # Wait for TTS await tts_sent.wait() await conversation_task + + +@pytest.mark.usefixtures("socket_enabled") +async def test_start_conversation_user_doesnt_pick_up( + hass: HomeAssistant, + voip_devices: VoIPDevices, + voip_device: VoIPDevice, +) -> None: + """Test start conversation when the user doesn't pick up.""" + assert await async_setup_component(hass, "voip", {}) + + pipeline = assist_pipeline.Pipeline( + conversation_engine="test engine", + conversation_language="en", + language="en", + name="test pipeline", + stt_engine="test stt", + stt_language="en", + tts_engine="test tts", + tts_language="en", + tts_voice=None, + wake_word_entity=None, + wake_word_id=None, + ) + + satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id) + assert isinstance(satellite, VoipAssistSatellite) + assert ( + satellite.supported_features + & assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION + ) + + # Protocol has already been mocked, but "outgoing_call" is not async + mock_protocol: AsyncMock = hass.data[DOMAIN].protocol + mock_protocol.outgoing_call = Mock() + + pipeline_started = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context: Context, + *args, + conversation_extra_system_prompt: str | None = None, + **kwargs, + ): + # System prompt should be not be set due to timeout (user not picking up) + assert conversation_extra_system_prompt is None + + pipeline_started.set() + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_get_pipeline", + return_value=pipeline, + ), + patch( + "homeassistant.components.voip.assist_satellite.VoipAssistSatellite.async_start_conversation", + side_effect=TimeoutError, + ), + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id", + return_value="test media id", + ), + ): + satellite.transport = Mock() + + # Error should clear system prompt + with pytest.raises(TimeoutError): + await hass.services.async_call( + assist_satellite.DOMAIN, + "start_conversation", + { + "entity_id": satellite.entity_id, + "start_message": "test announcement", + "extra_system_prompt": "test prompt", + }, + blocking=True, + ) + + # Trigger a pipeline so we can check if the system prompt was cleared + satellite.on_chunk(bytes(_ONE_SECOND)) + async with asyncio.timeout(1): + await pipeline_started.wait()