Fix Assist Satellite making up conversation IDs (#125933)

This commit is contained in:
Paulus Schoutsen 2024-09-13 23:21:31 -04:00 committed by GitHub
parent 6d212ea24e
commit 1b913b8088
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 23 deletions

View File

@ -28,7 +28,6 @@ from homeassistant.components.tts import (
from homeassistant.core import Context, callback from homeassistant.core import Context, callback
from homeassistant.helpers import entity from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .const import AssistSatelliteEntityFeature from .const import AssistSatelliteEntityFeature
from .errors import AssistSatelliteError, SatelliteBusyError from .errors import AssistSatelliteError, SatelliteBusyError
@ -240,16 +239,11 @@ class AssistSatelliteEntity(entity.Entity):
assert self._context is not None assert self._context is not None
# Reset conversation id if necessary # Reset conversation id if necessary
if (self._conversation_id_time is None) or ( if self._conversation_id_time and (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
): ):
self._conversation_id = None self._conversation_id = None
self._conversation_id_time = None
if self._conversation_id is None:
self._conversation_id = ulid.ulid()
# Update timeout
self._conversation_id_time = time.monotonic()
# Set entity state based on pipeline events # Set entity state based on pipeline events
self._run_has_tts = False self._run_has_tts = False
@ -311,6 +305,11 @@ class AssistSatelliteEntity(entity.Entity):
self._set_state(AssistSatelliteState.LISTENING_COMMAND) self._set_state(AssistSatelliteState.LISTENING_COMMAND)
elif event.type is PipelineEventType.INTENT_START: elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING) self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.INTENT_END:
assert event.data is not None
# Update timeout
self._conversation_id_time = time.monotonic()
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type is PipelineEventType.TTS_START: elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state # Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True self._run_has_tts = True

View File

@ -69,22 +69,34 @@ async def test_entity_state(
assert kwargs["start_stage"] == PipelineStage.STT assert kwargs["start_stage"] == PipelineStage.STT
assert kwargs["end_stage"] == PipelineStage.TTS assert kwargs["end_stage"] == PipelineStage.TTS
for event_type, expected_state in ( for event_type, event_data, expected_state in (
(PipelineEventType.RUN_START, AssistSatelliteState.LISTENING_WAKE_WORD), (PipelineEventType.RUN_START, {}, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD), (PipelineEventType.RUN_END, {}, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD), (
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD), PipelineEventType.WAKE_WORD_START,
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND), {},
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND), AssistSatelliteState.LISTENING_WAKE_WORD,
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND), ),
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND), (PipelineEventType.WAKE_WORD_END, {}, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING), (PipelineEventType.STT_START, {}, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING), (PipelineEventType.STT_VAD_START, {}, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING), (PipelineEventType.STT_VAD_END, {}, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING), (PipelineEventType.STT_END, {}, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING), (PipelineEventType.INTENT_START, {}, AssistSatelliteState.PROCESSING),
(
PipelineEventType.INTENT_END,
{
"intent_output": {
"conversation_id": "mock-conversation-id",
}
},
AssistSatelliteState.PROCESSING,
),
(PipelineEventType.TTS_START, {}, AssistSatelliteState.RESPONDING),
(PipelineEventType.TTS_END, {}, AssistSatelliteState.RESPONDING),
(PipelineEventType.ERROR, {}, AssistSatelliteState.RESPONDING),
): ):
kwargs["event_callback"](PipelineEvent(event_type, {})) kwargs["event_callback"](PipelineEvent(event_type, event_data))
state = hass.states.get(ENTITY_ID) state = hass.states.get(ENTITY_ID)
assert state.state == expected_state, event_type assert state.state == expected_state, event_type