Assist Satellite to use ChatSession for conversation ID (#137142)

* Assist Satellite to use ChatSession for conversation ID

* Adjust for changes on the main branch

* Ensure the initial message is in the chat log
This commit is contained in:
Paulus Schoutsen 2025-02-03 10:13:09 -05:00 committed by GitHub
parent 4531a46557
commit 8acab6c646
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 75 additions and 52 deletions

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
import logging import logging
import time import time
from typing import Any, Final, Literal, final from typing import Any, Literal, final
from homeassistant.components import conversation, media_source, stt, tts from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
@ -28,14 +28,12 @@ from homeassistant.components.tts import (
) )
from homeassistant.core import Context, callback from homeassistant.core import Context, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity from homeassistant.helpers import chat_session, entity
from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.entity import EntityDescription
from .const import AssistSatelliteEntityFeature from .const import AssistSatelliteEntityFeature
from .errors import AssistSatelliteError, SatelliteBusyError from .errors import AssistSatelliteError, SatelliteBusyError
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
_attr_vad_sensitivity_entity_id: str | None = None _attr_vad_sensitivity_entity_id: str | None = None
_conversation_id: str | None = None _conversation_id: str | None = None
_conversation_id_time: float | None = None
_run_has_tts: bool = False _run_has_tts: bool = False
_is_announcing = False _is_announcing = False
@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity):
else: else:
self._extra_system_prompt = start_message or None self._extra_system_prompt = start_message or None
with (
# Not passing in a conversation ID will force a new one to be created
chat_session.async_get_chat_session(self.hass) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
self._conversation_id = session.conversation_id
if start_message:
async for _tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=self.entity_id, content=start_message
)
):
pass # no tool responses.
try: try:
await self.async_start_conversation(announcement) await self.async_start_conversation(announcement)
finally: finally:
@ -325,51 +337,52 @@ class AssistSatelliteEntity(entity.Entity):
assert self._context is not None assert self._context is not None
# Reset conversation id if necessary
if self._conversation_id_time and (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
self._conversation_id_time = None
# Set entity state based on pipeline events # Set entity state based on pipeline events
self._run_has_tts = False self._run_has_tts = False
assert self.platform.config_entry is not None assert self.platform.config_entry is not None
self._pipeline_task = self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
try: with chat_session.async_get_chat_session(
await self._pipeline_task self.hass, self._conversation_id
finally: ) as session:
self._pipeline_task = None # Store the conversation ID. If it is no longer valid, get_chat_session will reset it
self._conversation_id = session.conversation_id
self._pipeline_task = (
self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=session.conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
)
try:
await self._pipeline_task
finally:
self._pipeline_task = None
async def _cancel_running_pipeline(self) -> None: async def _cancel_running_pipeline(self) -> None:
"""Cancel the current pipeline if it's running.""" """Cancel the current pipeline if it's running."""
@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity):
self._set_state(AssistSatelliteState.LISTENING) self._set_state(AssistSatelliteState.LISTENING)
elif event.type is PipelineEventType.INTENT_START: elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING) self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.INTENT_END:
assert event.data is not None
# Update timeout
self._conversation_id_time = time.monotonic()
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type is PipelineEventType.TTS_START: elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state # Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True self._run_has_tts = True

View File

@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
self, start_announcement: AssistSatelliteConfiguration self, start_announcement: AssistSatelliteConfiguration
) -> None: ) -> None:
"""Start a conversation from the satellite.""" """Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement)) self.start_conversations.append(
(self._conversation_id, self._extra_system_prompt, start_announcement)
)
@pytest.fixture @pytest.fixture

View File

@ -1,7 +1,8 @@
"""Test the Assist Satellite entity.""" """Test the Assist Satellite entity."""
import asyncio import asyncio
from unittest.mock import patch from collections.abc import Generator
from unittest.mock import Mock, patch
import pytest import pytest
@ -31,6 +32,14 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite from .conftest import MockAssistSatellite
@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
    """Make chat session ID generation return a fixed value.

    Patches ``ulid_now`` as looked up in ``homeassistant.helpers.chat_session``
    so that it returns ``"mock-conversation-id"``, letting tests assert on a
    predictable conversation ID. Yields the mock that replaces ``ulid_now``.
    """
    patcher = patch(
        "homeassistant.helpers.chat_session.ulid_now",
        return_value="mock-conversation-id",
    )
    with patcher as fake_ulid_now:
        yield fake_ulid_now
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None: async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine.""" """Set up a pipeline with a TTS engine."""
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
"extra_system_prompt": "Better system prompt", "extra_system_prompt": "Better system prompt",
}, },
( (
"mock-conversation-id",
"Better system prompt", "Better system prompt",
AssistSatelliteAnnouncement( AssistSatelliteAnnouncement(
message="Hello", message="Hello",
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
"start_media_id": "media-source://given", "start_media_id": "media-source://given",
}, },
( (
"mock-conversation-id",
"Hello", "Hello",
AssistSatelliteAnnouncement( AssistSatelliteAnnouncement(
message="Hello", message="Hello",
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
( (
{"start_media_id": "http://example.com/given.mp3"}, {"start_media_id": "http://example.com/given.mp3"},
( (
"mock-conversation-id",
None, None,
AssistSatelliteAnnouncement( AssistSatelliteAnnouncement(
message="", message="",
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
), ),
], ],
) )
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
async def test_start_conversation( async def test_start_conversation(
hass: HomeAssistant, hass: HomeAssistant,
init_components: ConfigEntry, init_components: ConfigEntry,