Clear extra system prompt on start_conversation error (#137254)

* Clear extra system prompt on start_conversation error

* Update homeassistant/components/assist_satellite/entity.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2025-02-03 11:07:45 -06:00 committed by GitHub
parent 58b7be7c2f
commit 28edbdc107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 0 deletions

View File

@ -274,6 +274,11 @@ class AssistSatelliteEntity(entity.Entity):
try:
await self.async_start_conversation(announcement)
except Exception:
# Clear prompt on error
self._conversation_id = None
self._extra_system_prompt = None
raise
finally:
self._is_announcing = False

View File

@ -1084,3 +1084,90 @@ async def test_start_conversation(
# Wait for TTS
await tts_sent.wait()
await conversation_task
@pytest.mark.usefixtures("socket_enabled")
async def test_start_conversation_user_doesnt_pick_up(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test start conversation when the user doesn't pick up."""
assert await async_setup_component(hass, "voip", {})
pipeline = assist_pipeline.Pipeline(
conversation_engine="test engine",
conversation_language="en",
language="en",
name="test pipeline",
stt_engine="test stt",
stt_language="en",
tts_engine="test tts",
tts_language="en",
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
assert (
satellite.supported_features
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
)
# Protocol has already been mocked, but "outgoing_call" is not async
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
mock_protocol.outgoing_call = Mock()
pipeline_started = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context: Context,
*args,
conversation_extra_system_prompt: str | None = None,
**kwargs,
):
# System prompt should be not be set due to timeout (user not picking up)
assert conversation_extra_system_prompt is None
pipeline_started.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
return_value=pipeline,
),
patch(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite.async_start_conversation",
side_effect=TimeoutError,
),
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
return_value="test media id",
),
):
satellite.transport = Mock()
# Error should clear system prompt
with pytest.raises(TimeoutError):
await hass.services.async_call(
assist_satellite.DOMAIN,
"start_conversation",
{
"entity_id": satellite.entity_id,
"start_message": "test announcement",
"extra_system_prompt": "test prompt",
},
blocking=True,
)
# Trigger a pipeline so we can check if the system prompt was cleared
satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1):
await pipeline_started.wait()