mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Clear extra system prompt on start_conversation error (#137254)
* Clear extra system prompt on start_conversation error * Update homeassistant/components/assist_satellite/entity.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
58b7be7c2f
commit
28edbdc107
@ -274,6 +274,11 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self.async_start_conversation(announcement)
|
await self.async_start_conversation(announcement)
|
||||||
|
except Exception:
|
||||||
|
# Clear prompt on error
|
||||||
|
self._conversation_id = None
|
||||||
|
self._extra_system_prompt = None
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._is_announcing = False
|
self._is_announcing = False
|
||||||
|
|
||||||
|
@ -1084,3 +1084,90 @@ async def test_start_conversation(
|
|||||||
# Wait for TTS
|
# Wait for TTS
|
||||||
await tts_sent.wait()
|
await tts_sent.wait()
|
||||||
await conversation_task
|
await conversation_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("socket_enabled")
|
||||||
|
async def test_start_conversation_user_doesnt_pick_up(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test start conversation when the user doesn't pick up."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
pipeline = assist_pipeline.Pipeline(
|
||||||
|
conversation_engine="test engine",
|
||||||
|
conversation_language="en",
|
||||||
|
language="en",
|
||||||
|
name="test pipeline",
|
||||||
|
stt_engine="test stt",
|
||||||
|
stt_language="en",
|
||||||
|
tts_engine="test tts",
|
||||||
|
tts_language="en",
|
||||||
|
tts_voice=None,
|
||||||
|
wake_word_entity=None,
|
||||||
|
wake_word_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
assert (
|
||||||
|
satellite.supported_features
|
||||||
|
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||||
|
)
|
||||||
|
|
||||||
|
# Protocol has already been mocked, but "outgoing_call" is not async
|
||||||
|
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
|
||||||
|
mock_protocol.outgoing_call = Mock()
|
||||||
|
|
||||||
|
pipeline_started = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context: Context,
|
||||||
|
*args,
|
||||||
|
conversation_extra_system_prompt: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# System prompt should be not be set due to timeout (user not picking up)
|
||||||
|
assert conversation_extra_system_prompt is None
|
||||||
|
|
||||||
|
pipeline_started.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
|
||||||
|
return_value=pipeline,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite.async_start_conversation",
|
||||||
|
side_effect=TimeoutError,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||||
|
return_value="test media id",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
satellite.transport = Mock()
|
||||||
|
|
||||||
|
# Error should clear system prompt
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
await hass.services.async_call(
|
||||||
|
assist_satellite.DOMAIN,
|
||||||
|
"start_conversation",
|
||||||
|
{
|
||||||
|
"entity_id": satellite.entity_id,
|
||||||
|
"start_message": "test announcement",
|
||||||
|
"extra_system_prompt": "test prompt",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger a pipeline so we can check if the system prompt was cleared
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await pipeline_started.wait()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user