mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 16:27:08 +00:00
Assist Satellite to use ChatSession for conversation ID (#137142)
* Assist Satellite to use ChatSession for conversation ID * Adjust for changes main branch * Ensure the initial message is in the chat log
This commit is contained in:
parent
4531a46557
commit
8acab6c646
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Final, Literal, final
|
from typing import Any, Literal, final
|
||||||
|
|
||||||
from homeassistant.components import conversation, media_source, stt, tts
|
from homeassistant.components import conversation, media_source, stt, tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
@ -28,14 +28,12 @@ from homeassistant.components.tts import (
|
|||||||
)
|
)
|
||||||
from homeassistant.core import Context, callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import entity
|
from homeassistant.helpers import chat_session, entity
|
||||||
from homeassistant.helpers.entity import EntityDescription
|
from homeassistant.helpers.entity import EntityDescription
|
||||||
|
|
||||||
from .const import AssistSatelliteEntityFeature
|
from .const import AssistSatelliteEntityFeature
|
||||||
from .errors import AssistSatelliteError, SatelliteBusyError
|
from .errors import AssistSatelliteError, SatelliteBusyError
|
||||||
|
|
||||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
_attr_vad_sensitivity_entity_id: str | None = None
|
_attr_vad_sensitivity_entity_id: str | None = None
|
||||||
|
|
||||||
_conversation_id: str | None = None
|
_conversation_id: str | None = None
|
||||||
_conversation_id_time: float | None = None
|
|
||||||
|
|
||||||
_run_has_tts: bool = False
|
_run_has_tts: bool = False
|
||||||
_is_announcing = False
|
_is_announcing = False
|
||||||
@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
else:
|
else:
|
||||||
self._extra_system_prompt = start_message or None
|
self._extra_system_prompt = start_message or None
|
||||||
|
|
||||||
|
with (
|
||||||
|
# Not passing in a conversation ID will force a new one to be created
|
||||||
|
chat_session.async_get_chat_session(self.hass) as session,
|
||||||
|
conversation.async_get_chat_log(self.hass, session) as chat_log,
|
||||||
|
):
|
||||||
|
self._conversation_id = session.conversation_id
|
||||||
|
|
||||||
|
if start_message:
|
||||||
|
async for _tool_response in chat_log.async_add_assistant_content(
|
||||||
|
conversation.AssistantContent(
|
||||||
|
agent_id=self.entity_id, content=start_message
|
||||||
|
)
|
||||||
|
):
|
||||||
|
pass # no tool responses.
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.async_start_conversation(announcement)
|
await self.async_start_conversation(announcement)
|
||||||
finally:
|
finally:
|
||||||
@ -325,51 +337,52 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
|
|
||||||
assert self._context is not None
|
assert self._context is not None
|
||||||
|
|
||||||
# Reset conversation id if necessary
|
|
||||||
if self._conversation_id_time and (
|
|
||||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
|
||||||
):
|
|
||||||
self._conversation_id = None
|
|
||||||
self._conversation_id_time = None
|
|
||||||
|
|
||||||
# Set entity state based on pipeline events
|
# Set entity state based on pipeline events
|
||||||
self._run_has_tts = False
|
self._run_has_tts = False
|
||||||
|
|
||||||
assert self.platform.config_entry is not None
|
assert self.platform.config_entry is not None
|
||||||
self._pipeline_task = self.platform.config_entry.async_create_background_task(
|
|
||||||
self.hass,
|
|
||||||
async_pipeline_from_audio_stream(
|
|
||||||
self.hass,
|
|
||||||
context=self._context,
|
|
||||||
event_callback=self._internal_on_pipeline_event,
|
|
||||||
stt_metadata=stt.SpeechMetadata(
|
|
||||||
language="", # set in async_pipeline_from_audio_stream
|
|
||||||
format=stt.AudioFormats.WAV,
|
|
||||||
codec=stt.AudioCodecs.PCM,
|
|
||||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
|
||||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
|
||||||
),
|
|
||||||
stt_stream=audio_stream,
|
|
||||||
pipeline_id=self._resolve_pipeline(),
|
|
||||||
conversation_id=self._conversation_id,
|
|
||||||
device_id=device_id,
|
|
||||||
tts_audio_output=self.tts_options,
|
|
||||||
wake_word_phrase=wake_word_phrase,
|
|
||||||
audio_settings=AudioSettings(
|
|
||||||
silence_seconds=self._resolve_vad_sensitivity()
|
|
||||||
),
|
|
||||||
start_stage=start_stage,
|
|
||||||
end_stage=end_stage,
|
|
||||||
conversation_extra_system_prompt=extra_system_prompt,
|
|
||||||
),
|
|
||||||
f"{self.entity_id}_pipeline",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
with chat_session.async_get_chat_session(
|
||||||
await self._pipeline_task
|
self.hass, self._conversation_id
|
||||||
finally:
|
) as session:
|
||||||
self._pipeline_task = None
|
# Store the conversation ID. If it is no longer valid, get_chat_session will reset it
|
||||||
|
self._conversation_id = session.conversation_id
|
||||||
|
self._pipeline_task = (
|
||||||
|
self.platform.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
async_pipeline_from_audio_stream(
|
||||||
|
self.hass,
|
||||||
|
context=self._context,
|
||||||
|
event_callback=self._internal_on_pipeline_event,
|
||||||
|
stt_metadata=stt.SpeechMetadata(
|
||||||
|
language="", # set in async_pipeline_from_audio_stream
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
stt_stream=audio_stream,
|
||||||
|
pipeline_id=self._resolve_pipeline(),
|
||||||
|
conversation_id=session.conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
|
tts_audio_output=self.tts_options,
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
audio_settings=AudioSettings(
|
||||||
|
silence_seconds=self._resolve_vad_sensitivity()
|
||||||
|
),
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
|
conversation_extra_system_prompt=extra_system_prompt,
|
||||||
|
),
|
||||||
|
f"{self.entity_id}_pipeline",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._pipeline_task
|
||||||
|
finally:
|
||||||
|
self._pipeline_task = None
|
||||||
|
|
||||||
async def _cancel_running_pipeline(self) -> None:
|
async def _cancel_running_pipeline(self) -> None:
|
||||||
"""Cancel the current pipeline if it's running."""
|
"""Cancel the current pipeline if it's running."""
|
||||||
@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
self._set_state(AssistSatelliteState.LISTENING)
|
self._set_state(AssistSatelliteState.LISTENING)
|
||||||
elif event.type is PipelineEventType.INTENT_START:
|
elif event.type is PipelineEventType.INTENT_START:
|
||||||
self._set_state(AssistSatelliteState.PROCESSING)
|
self._set_state(AssistSatelliteState.PROCESSING)
|
||||||
elif event.type is PipelineEventType.INTENT_END:
|
|
||||||
assert event.data is not None
|
|
||||||
# Update timeout
|
|
||||||
self._conversation_id_time = time.monotonic()
|
|
||||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
|
||||||
elif event.type is PipelineEventType.TTS_START:
|
elif event.type is PipelineEventType.TTS_START:
|
||||||
# Wait until tts_response_finished is called to return to waiting state
|
# Wait until tts_response_finished is called to return to waiting state
|
||||||
self._run_has_tts = True
|
self._run_has_tts = True
|
||||||
|
@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||||||
self, start_announcement: AssistSatelliteConfiguration
|
self, start_announcement: AssistSatelliteConfiguration
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a conversation from the satellite."""
|
"""Start a conversation from the satellite."""
|
||||||
self.start_conversations.append((self._extra_system_prompt, start_announcement))
|
self.start_conversations.append(
|
||||||
|
(self._conversation_id, self._extra_system_prompt, start_announcement)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
"""Test the Assist Satellite entity."""
|
"""Test the Assist Satellite entity."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import patch
|
from collections.abc import Generator
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -31,6 +32,14 @@ from . import ENTITY_ID
|
|||||||
from .conftest import MockAssistSatellite
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_chat_session_conversation_id() -> Generator[Mock]:
|
||||||
|
"""Mock the ulid library."""
|
||||||
|
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||||
|
mock_ulid_now.return_value = "mock-conversation-id"
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
|
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
|
||||||
"""Set up a pipeline with a TTS engine."""
|
"""Set up a pipeline with a TTS engine."""
|
||||||
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
"extra_system_prompt": "Better system prompt",
|
"extra_system_prompt": "Better system prompt",
|
||||||
},
|
},
|
||||||
(
|
(
|
||||||
|
"mock-conversation-id",
|
||||||
"Better system prompt",
|
"Better system prompt",
|
||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
"start_media_id": "media-source://given",
|
"start_media_id": "media-source://given",
|
||||||
},
|
},
|
||||||
(
|
(
|
||||||
|
"mock-conversation-id",
|
||||||
"Hello",
|
"Hello",
|
||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
(
|
(
|
||||||
{"start_media_id": "http://example.com/given.mp3"},
|
{"start_media_id": "http://example.com/given.mp3"},
|
||||||
(
|
(
|
||||||
|
"mock-conversation-id",
|
||||||
None,
|
None,
|
||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="",
|
message="",
|
||||||
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
|
||||||
async def test_start_conversation(
|
async def test_start_conversation(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: ConfigEntry,
|
init_components: ConfigEntry,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user