Assist Satellite to use ChatSession for conversation ID (#137142)

* Assist Satellite to use ChatSession for conversation ID

* Adjust for changes main branch

* Ensure the initial message is in the chat log
This commit is contained in:
Paulus Schoutsen 2025-02-03 10:13:09 -05:00 committed by GitHub
parent 4531a46557
commit 8acab6c646
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 75 additions and 52 deletions

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from enum import StrEnum
import logging
import time
from typing import Any, Final, Literal, final
from typing import Any, Literal, final
from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import (
@ -28,14 +28,12 @@ from homeassistant.components.tts import (
)
from homeassistant.core import Context, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.helpers import chat_session, entity
from homeassistant.helpers.entity import EntityDescription
from .const import AssistSatelliteEntityFeature
from .errors import AssistSatelliteError, SatelliteBusyError
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
_LOGGER = logging.getLogger(__name__)
@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
_attr_vad_sensitivity_entity_id: str | None = None
_conversation_id: str | None = None
_conversation_id_time: float | None = None
_run_has_tts: bool = False
_is_announcing = False
@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity):
else:
self._extra_system_prompt = start_message or None
with (
# Not passing in a conversation ID will force a new one to be created
chat_session.async_get_chat_session(self.hass) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
self._conversation_id = session.conversation_id
if start_message:
async for _tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=self.entity_id, content=start_message
)
):
pass # no tool responses.
try:
await self.async_start_conversation(announcement)
finally:
@ -325,51 +337,52 @@ class AssistSatelliteEntity(entity.Entity):
assert self._context is not None
# Reset conversation id if necessary
if self._conversation_id_time and (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
self._conversation_id_time = None
# Set entity state based on pipeline events
self._run_has_tts = False
assert self.platform.config_entry is not None
self._pipeline_task = self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
try:
await self._pipeline_task
finally:
self._pipeline_task = None
with chat_session.async_get_chat_session(
self.hass, self._conversation_id
) as session:
# Store the conversation ID. If it is no longer valid, get_chat_session will reset it
self._conversation_id = session.conversation_id
self._pipeline_task = (
self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=session.conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
)
try:
await self._pipeline_task
finally:
self._pipeline_task = None
async def _cancel_running_pipeline(self) -> None:
"""Cancel the current pipeline if it's running."""
@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity):
self._set_state(AssistSatelliteState.LISTENING)
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.INTENT_END:
assert event.data is not None
# Update timeout
self._conversation_id_time = time.monotonic()
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True

View File

@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
self, start_announcement: AssistSatelliteConfiguration
) -> None:
"""Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement))
self.start_conversations.append(
(self._conversation_id, self._extra_system_prompt, start_announcement)
)
@pytest.fixture

View File

@ -1,7 +1,8 @@
"""Test the Assist Satellite entity."""
import asyncio
from unittest.mock import patch
from collections.abc import Generator
from unittest.mock import Mock, patch
import pytest
@ -31,6 +32,14 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite
@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-conversation-id"
yield mock_ulid_now
@pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine."""
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
"extra_system_prompt": "Better system prompt",
},
(
"mock-conversation-id",
"Better system prompt",
AssistSatelliteAnnouncement(
message="Hello",
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
"start_media_id": "media-source://given",
},
(
"mock-conversation-id",
"Hello",
AssistSatelliteAnnouncement(
message="Hello",
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
(
{"start_media_id": "http://example.com/given.mp3"},
(
"mock-conversation-id",
None,
AssistSatelliteAnnouncement(
message="",
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
),
],
)
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
async def test_start_conversation(
hass: HomeAssistant,
init_components: ConfigEntry,