mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Assist Satellite to use ChatSession for conversation ID (#137142)
* Assist Satellite to use ChatSession for conversation ID * Adjust for changes main branch * Ensure the initial message is in the chat log
This commit is contained in:
parent
4531a46557
commit
8acab6c646
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Final, Literal, final
|
||||
from typing import Any, Literal, final
|
||||
|
||||
from homeassistant.components import conversation, media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
@ -28,14 +28,12 @@ from homeassistant.components.tts import (
|
||||
)
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers import chat_session, entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
|
||||
from .const import AssistSatelliteEntityFeature
|
||||
from .errors import AssistSatelliteError, SatelliteBusyError
|
||||
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
_attr_vad_sensitivity_entity_id: str | None = None
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_run_has_tts: bool = False
|
||||
_is_announcing = False
|
||||
@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
else:
|
||||
self._extra_system_prompt = start_message or None
|
||||
|
||||
with (
|
||||
# Not passing in a conversation ID will force a new one to be created
|
||||
chat_session.async_get_chat_session(self.hass) as session,
|
||||
conversation.async_get_chat_log(self.hass, session) as chat_log,
|
||||
):
|
||||
self._conversation_id = session.conversation_id
|
||||
|
||||
if start_message:
|
||||
async for _tool_response in chat_log.async_add_assistant_content(
|
||||
conversation.AssistantContent(
|
||||
agent_id=self.entity_id, content=start_message
|
||||
)
|
||||
):
|
||||
pass # no tool responses.
|
||||
|
||||
try:
|
||||
await self.async_start_conversation(announcement)
|
||||
finally:
|
||||
@ -325,51 +337,52 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
|
||||
assert self._context is not None
|
||||
|
||||
# Reset conversation id if necessary
|
||||
if self._conversation_id_time and (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
self._conversation_id_time = None
|
||||
|
||||
# Set entity state based on pipeline events
|
||||
self._run_has_tts = False
|
||||
|
||||
assert self.platform.config_entry is not None
|
||||
self._pipeline_task = self.platform.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._internal_on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=self._resolve_pipeline(),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output=self.tts_options,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=self._resolve_vad_sensitivity()
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
conversation_extra_system_prompt=extra_system_prompt,
|
||||
),
|
||||
f"{self.entity_id}_pipeline",
|
||||
)
|
||||
|
||||
try:
|
||||
await self._pipeline_task
|
||||
finally:
|
||||
self._pipeline_task = None
|
||||
with chat_session.async_get_chat_session(
|
||||
self.hass, self._conversation_id
|
||||
) as session:
|
||||
# Store the conversation ID. If it is no longer valid, get_chat_session will reset it
|
||||
self._conversation_id = session.conversation_id
|
||||
self._pipeline_task = (
|
||||
self.platform.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._internal_on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=self._resolve_pipeline(),
|
||||
conversation_id=session.conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output=self.tts_options,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=self._resolve_vad_sensitivity()
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
conversation_extra_system_prompt=extra_system_prompt,
|
||||
),
|
||||
f"{self.entity_id}_pipeline",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await self._pipeline_task
|
||||
finally:
|
||||
self._pipeline_task = None
|
||||
|
||||
async def _cancel_running_pipeline(self) -> None:
|
||||
"""Cancel the current pipeline if it's running."""
|
||||
@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
self._set_state(AssistSatelliteState.LISTENING)
|
||||
elif event.type is PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type is PipelineEventType.INTENT_END:
|
||||
assert event.data is not None
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||
elif event.type is PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._run_has_tts = True
|
||||
|
@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
||||
self, start_announcement: AssistSatelliteConfiguration
|
||||
) -> None:
|
||||
"""Start a conversation from the satellite."""
|
||||
self.start_conversations.append((self._extra_system_prompt, start_announcement))
|
||||
self.start_conversations.append(
|
||||
(self._conversation_id, self._extra_system_prompt, start_announcement)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""Test the Assist Satellite entity."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -31,6 +32,14 @@ from . import ENTITY_ID
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chat_session_conversation_id() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-conversation-id"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
|
||||
"""Set up a pipeline with a TTS engine."""
|
||||
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
"extra_system_prompt": "Better system prompt",
|
||||
},
|
||||
(
|
||||
"mock-conversation-id",
|
||||
"Better system prompt",
|
||||
AssistSatelliteAnnouncement(
|
||||
message="Hello",
|
||||
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
"start_media_id": "media-source://given",
|
||||
},
|
||||
(
|
||||
"mock-conversation-id",
|
||||
"Hello",
|
||||
AssistSatelliteAnnouncement(
|
||||
message="Hello",
|
||||
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
(
|
||||
{"start_media_id": "http://example.com/given.mp3"},
|
||||
(
|
||||
"mock-conversation-id",
|
||||
None,
|
||||
AssistSatelliteAnnouncement(
|
||||
message="",
|
||||
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
|
||||
async def test_start_conversation(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
|
Loading…
x
Reference in New Issue
Block a user