Set responding state in assist satellite start_conversation (#141388)

* Set responding state in async_start_conversation

* Check idle state
This commit is contained in:
Michael Hansen 2025-03-25 13:30:44 -05:00 committed by GitHub
parent c8745cc339
commit 7319637bd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 0 deletions

View File

@ -262,6 +262,8 @@ class AssistSatelliteEntity(entity.Entity):
raise SatelliteBusyError raise SatelliteBusyError
self._is_announcing = True self._is_announcing = True
self._set_state(AssistSatelliteState.RESPONDING)
# Provide our start info to the LLM so it understands context of incoming message # Provide our start info to the LLM so it understands context of incoming message
if extra_system_prompt is not None: if extra_system_prompt is not None:
self._extra_system_prompt = extra_system_prompt self._extra_system_prompt = extra_system_prompt
@ -291,6 +293,7 @@ class AssistSatelliteEntity(entity.Entity):
raise raise
finally: finally:
self._is_announcing = False self._is_announcing = False
self._set_state(AssistSatelliteState.IDLE)
async def async_start_conversation( async def async_start_conversation(
self, start_announcement: AssistSatelliteAnnouncement self, start_announcement: AssistSatelliteAnnouncement

View File

@ -594,6 +594,13 @@ async def test_start_conversation(
expected_params: tuple[str, str], expected_params: tuple[str, str],
) -> None: ) -> None:
"""Test starting a conversation on a device.""" """Test starting a conversation on a device."""
original_start_conversation = entity.async_start_conversation
async def async_start_conversation(start_announcement):
# Verify state change
assert entity.state == AssistSatelliteState.RESPONDING
await original_start_conversation(start_announcement)
await async_update_pipeline( await async_update_pipeline(
hass, hass,
async_get_pipeline(hass), async_get_pipeline(hass),
@ -620,6 +627,7 @@ async def test_start_conversation(
mime_type="audio/mp3", mime_type="audio/mp3",
), ),
), ),
patch.object(entity, "async_start_conversation", new=async_start_conversation),
): ):
await hass.services.async_call( await hass.services.async_call(
"assist_satellite", "assist_satellite",
@ -628,6 +636,7 @@ async def test_start_conversation(
target={"entity_id": "assist_satellite.test_entity"}, target={"entity_id": "assist_satellite.test_entity"},
blocking=True, blocking=True,
) )
assert entity.state == AssistSatelliteState.IDLE
assert entity.start_conversations[0] == expected_params assert entity.start_conversations[0] == expected_params