From 8acab6c646b65a1aa79613d3d8a1170456449ca2 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Feb 2025 10:13:09 -0500 Subject: [PATCH] Assist Satellite to use ChatSession for conversation ID (#137142) * Assist Satellite to use ChatSession for conversation ID * Adjust for changes main branch * Ensure the initial message is in the chat log --- .../components/assist_satellite/entity.py | 108 ++++++++++-------- tests/components/assist_satellite/conftest.py | 4 +- .../assist_satellite/test_entity.py | 15 ++- 3 files changed, 75 insertions(+), 52 deletions(-) diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index 0229e0358b1..c901bc7d928 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from enum import StrEnum import logging import time -from typing import Any, Final, Literal, final +from typing import Any, Literal, final from homeassistant.components import conversation, media_source, stt, tts from homeassistant.components.assist_pipeline import ( @@ -28,14 +28,12 @@ from homeassistant.components.tts import ( ) from homeassistant.core import Context, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import entity +from homeassistant.helpers import chat_session, entity from homeassistant.helpers.entity import EntityDescription from .const import AssistSatelliteEntityFeature from .errors import AssistSatelliteError, SatelliteBusyError -_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes - _LOGGER = logging.getLogger(__name__) @@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity): _attr_vad_sensitivity_entity_id: str | None = None _conversation_id: str | None = None - _conversation_id_time: float | None = None _run_has_tts: bool = False _is_announcing = False @@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity): else: self._extra_system_prompt = start_message or None + with ( + # Not passing in a conversation ID will force a new one to be created + chat_session.async_get_chat_session(self.hass) as session, + conversation.async_get_chat_log(self.hass, session) as chat_log, + ): + self._conversation_id = session.conversation_id + + if start_message: + async for _tool_response in chat_log.async_add_assistant_content( + conversation.AssistantContent( + agent_id=self.entity_id, content=start_message + ) + ): + pass # no tool responses. + try: await self.async_start_conversation(announcement) finally: @@ -325,51 +337,52 @@ class AssistSatelliteEntity(entity.Entity): assert self._context is not None - # Reset conversation id if necessary - if self._conversation_id_time and ( - (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC - ): - self._conversation_id = None - self._conversation_id_time = None - # Set entity state based on pipeline events self._run_has_tts = False assert self.platform.config_entry is not None - self._pipeline_task = self.platform.config_entry.async_create_background_task( - self.hass, - async_pipeline_from_audio_stream( - self.hass, - context=self._context, - event_callback=self._internal_on_pipeline_event, - stt_metadata=stt.SpeechMetadata( - language="", # set in async_pipeline_from_audio_stream - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=audio_stream, - pipeline_id=self._resolve_pipeline(), - conversation_id=self._conversation_id, - device_id=device_id, - tts_audio_output=self.tts_options, - wake_word_phrase=wake_word_phrase, - audio_settings=AudioSettings( - silence_seconds=self._resolve_vad_sensitivity() - ), - start_stage=start_stage, - end_stage=end_stage, - conversation_extra_system_prompt=extra_system_prompt, - ), - f"{self.entity_id}_pipeline", - ) - try: - await self._pipeline_task - finally: - self._pipeline_task = None + with chat_session.async_get_chat_session( + self.hass, self._conversation_id + ) as session: + # Store the conversation ID. If it is no longer valid, get_chat_session will reset it + self._conversation_id = session.conversation_id + self._pipeline_task = ( + self.platform.config_entry.async_create_background_task( + self.hass, + async_pipeline_from_audio_stream( + self.hass, + context=self._context, + event_callback=self._internal_on_pipeline_event, + stt_metadata=stt.SpeechMetadata( + language="", # set in async_pipeline_from_audio_stream + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=audio_stream, + pipeline_id=self._resolve_pipeline(), + conversation_id=session.conversation_id, + device_id=device_id, + tts_audio_output=self.tts_options, + wake_word_phrase=wake_word_phrase, + audio_settings=AudioSettings( + silence_seconds=self._resolve_vad_sensitivity() + ), + start_stage=start_stage, + end_stage=end_stage, + conversation_extra_system_prompt=extra_system_prompt, + ), + f"{self.entity_id}_pipeline", + ) + ) + + try: + await self._pipeline_task + finally: + self._pipeline_task = None async def _cancel_running_pipeline(self) -> None: """Cancel the current pipeline if it's running.""" @@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity): self._set_state(AssistSatelliteState.LISTENING) elif event.type is PipelineEventType.INTENT_START: self._set_state(AssistSatelliteState.PROCESSING) - elif event.type is PipelineEventType.INTENT_END: - assert event.data is not None - # Update timeout - self._conversation_id_time = time.monotonic() - self._conversation_id = event.data["intent_output"]["conversation_id"] elif event.type is PipelineEventType.TTS_START: # Wait until tts_response_finished is called to return to waiting state self._run_has_tts = True diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py index 0cc0e94e149..79e4061bacc 100644 --- a/tests/components/assist_satellite/conftest.py +++ b/tests/components/assist_satellite/conftest.py @@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity): self, start_announcement: AssistSatelliteConfiguration ) -> None: """Start a conversation from the satellite.""" - self.start_conversations.append((self._extra_system_prompt, start_announcement)) + self.start_conversations.append( + (self._conversation_id, self._extra_system_prompt, start_announcement) + ) @pytest.fixture diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py index 46facb80844..b3437bf5c5d 100644 --- a/tests/components/assist_satellite/test_entity.py +++ b/tests/components/assist_satellite/test_entity.py @@ -1,7 +1,8 @@ """Test the Assist Satellite entity.""" import asyncio -from unittest.mock import patch +from collections.abc import Generator +from unittest.mock import Mock, patch import pytest @@ -31,6 +32,14 @@ from . import ENTITY_ID from .conftest import MockAssistSatellite +@pytest.fixture +def mock_chat_session_conversation_id() -> Generator[Mock]: + """Mock the ulid library.""" + with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: + mock_ulid_now.return_value = "mock-conversation-id" + yield mock_ulid_now + + @pytest.fixture(autouse=True) async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None: """Set up a pipeline with a TTS engine.""" @@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found( "extra_system_prompt": "Better system prompt", }, ( + "mock-conversation-id", "Better system prompt", AssistSatelliteAnnouncement( message="Hello", @@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found( "start_media_id": "media-source://given", }, ( + "mock-conversation-id", "Hello", AssistSatelliteAnnouncement( message="Hello", @@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found( ( {"start_media_id": "http://example.com/given.mp3"}, ( + "mock-conversation-id", None, AssistSatelliteAnnouncement( message="", @@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found( ), ], ) +@pytest.mark.usefixtures("mock_chat_session_conversation_id") async def test_start_conversation( hass: HomeAssistant, init_components: ConfigEntry,