mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 18:27:09 +00:00
Extract conversation ID generation to helper (#137062)
* Extract conversation ID generation to helper * Allow nested get_chat_log calls
This commit is contained in:
parent
30314ca32b
commit
2f6640707b
@ -33,7 +33,7 @@ from homeassistant.components.tts import (
|
|||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import chat_session, intent
|
||||||
from homeassistant.helpers.collection import (
|
from homeassistant.helpers.collection import (
|
||||||
CHANGE_UPDATED,
|
CHANGE_UPDATED,
|
||||||
CollectionError,
|
CollectionError,
|
||||||
@ -1094,13 +1094,18 @@ class PipelineRun:
|
|||||||
|
|
||||||
# It was already handled, create response and add to chat history
|
# It was already handled, create response and add to chat history
|
||||||
if intent_response is not None:
|
if intent_response is not None:
|
||||||
async with conversation.async_get_chat_session(
|
async with (
|
||||||
self.hass, user_input
|
chat_session.async_get_chat_session(
|
||||||
) as chat_session:
|
self.hass, user_input.conversation_id
|
||||||
|
) as session,
|
||||||
|
conversation.async_get_chat_log(
|
||||||
|
self.hass, session, user_input
|
||||||
|
) as chat_log,
|
||||||
|
):
|
||||||
speech: str = intent_response.speech.get("plain", {}).get(
|
speech: str = intent_response.speech.get("plain", {}).get(
|
||||||
"speech", ""
|
"speech", ""
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
conversation.Content(
|
conversation.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
@ -1109,7 +1114,7 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
conversation_result = conversation.ConversationResult(
|
conversation_result = conversation.ConversationResult(
|
||||||
response=intent_response,
|
response=intent_response,
|
||||||
conversation_id=chat_session.conversation_id,
|
conversation_id=session.conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -48,20 +48,14 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
|||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_conversation_http
|
||||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
from .session import (
|
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
|
||||||
ChatSession,
|
|
||||||
Content,
|
|
||||||
ConverseError,
|
|
||||||
NativeContent,
|
|
||||||
async_get_chat_session,
|
|
||||||
)
|
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"HOME_ASSISTANT_AGENT",
|
"HOME_ASSISTANT_AGENT",
|
||||||
"OLD_HOME_ASSISTANT_AGENT",
|
"OLD_HOME_ASSISTANT_AGENT",
|
||||||
"ChatSession",
|
"ChatLog",
|
||||||
"Content",
|
"Content",
|
||||||
"ConversationEntity",
|
"ConversationEntity",
|
||||||
"ConversationEntityFeature",
|
"ConversationEntityFeature",
|
||||||
@ -73,7 +67,7 @@ __all__ = [
|
|||||||
"async_conversation_trace_append",
|
"async_conversation_trace_append",
|
||||||
"async_converse",
|
"async_converse",
|
||||||
"async_get_agent_info",
|
"async_get_agent_info",
|
||||||
"async_get_chat_session",
|
"async_get_chat_log",
|
||||||
"async_set_agent",
|
"async_set_agent",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_unset_agent",
|
"async_unset_agent",
|
||||||
|
@ -42,6 +42,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
|||||||
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
|
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
|
chat_session,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
floor_registry as fr,
|
floor_registry as fr,
|
||||||
@ -62,7 +63,7 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
from .session import Content, async_get_chat_session
|
from .session import Content, async_get_chat_log
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -348,7 +349,12 @@ class DefaultAgent(ConversationEntity):
|
|||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
response: intent.IntentResponse | None = None
|
response: intent.IntentResponse | None = None
|
||||||
async with async_get_chat_session(self.hass, user_input) as chat_session:
|
async with (
|
||||||
|
chat_session.async_get_chat_session(
|
||||||
|
self.hass, user_input.conversation_id
|
||||||
|
) as session,
|
||||||
|
async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||||
|
):
|
||||||
# Check if a trigger matched
|
# Check if a trigger matched
|
||||||
if trigger_result := await self.async_recognize_sentence_trigger(
|
if trigger_result := await self.async_recognize_sentence_trigger(
|
||||||
user_input
|
user_input
|
||||||
@ -373,7 +379,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
)
|
)
|
||||||
|
|
||||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
Content(
|
Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
@ -382,7 +388,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ConversationResult(
|
return ConversationResult(
|
||||||
response=response, conversation_id=chat_session.conversation_id
|
response=response, conversation_id=session.conversation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_process_intent_result(
|
async def _async_process_intent_result(
|
||||||
|
@ -5,25 +5,16 @@ from __future__ import annotations
|
|||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.core import (
|
|
||||||
CALLBACK_TYPE,
|
|
||||||
Event,
|
|
||||||
HassJob,
|
|
||||||
HassJobType,
|
|
||||||
HomeAssistant,
|
|
||||||
callback,
|
|
||||||
)
|
|
||||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
from homeassistant.helpers import intent, llm, template
|
from homeassistant.helpers import chat_session, intent, llm, template
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.util import dt as dt_util
|
||||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
from homeassistant.util.json import JsonObjectType
|
from homeassistant.util.json import JsonObjectType
|
||||||
|
|
||||||
@ -31,100 +22,36 @@ from . import trace
|
|||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
|
|
||||||
DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey(
|
DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
|
||||||
"conversation_chat_session"
|
|
||||||
)
|
|
||||||
DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey(
|
|
||||||
"conversation_chat_session_cleanup"
|
|
||||||
)
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
|
||||||
|
|
||||||
|
|
||||||
class SessionCleanup:
|
|
||||||
"""Helper to clean up the history."""
|
|
||||||
|
|
||||||
unsub: CALLBACK_TYPE | None = None
|
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
|
||||||
"""Initialize the history cleanup."""
|
|
||||||
self.hass = hass
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
|
||||||
self.cleanup_job = HassJob(
|
|
||||||
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def schedule(self) -> None:
|
|
||||||
"""Schedule the cleanup."""
|
|
||||||
if self.unsub:
|
|
||||||
return
|
|
||||||
self.unsub = async_call_later(
|
|
||||||
self.hass,
|
|
||||||
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
|
||||||
self.cleanup_job,
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _on_hass_stop(self, event: Event) -> None:
|
|
||||||
"""Cancel the cleanup on shutdown."""
|
|
||||||
if self.unsub:
|
|
||||||
self.unsub()
|
|
||||||
self.unsub = None
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _cleanup(self, now: datetime) -> None:
|
|
||||||
"""Clean up the history and schedule follow-up if necessary."""
|
|
||||||
self.unsub = None
|
|
||||||
all_history = self.hass.data[DATA_CHAT_HISTORY]
|
|
||||||
|
|
||||||
# We mutate original object because current commands could be
|
|
||||||
# yielding history based on it.
|
|
||||||
for conversation_id, history in list(all_history.items()):
|
|
||||||
if history.last_updated + CONVERSATION_TIMEOUT < now:
|
|
||||||
del all_history[conversation_id]
|
|
||||||
|
|
||||||
# Still conversations left, check again in timeout time.
|
|
||||||
if all_history:
|
|
||||||
self.schedule()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_get_chat_session(
|
async def async_get_chat_log(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
session: chat_session.ChatSession,
|
||||||
user_input: ConversationInput,
|
user_input: ConversationInput,
|
||||||
) -> AsyncGenerator[ChatSession]:
|
) -> AsyncGenerator[ChatLog]:
|
||||||
"""Return chat session."""
|
"""Return chat log for a specific chat session."""
|
||||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||||
if all_history is None:
|
if all_history is None:
|
||||||
all_history = {}
|
all_history = {}
|
||||||
hass.data[DATA_CHAT_HISTORY] = all_history
|
hass.data[DATA_CHAT_HISTORY] = all_history
|
||||||
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
|
|
||||||
|
|
||||||
history: ChatSession | None = None
|
history = all_history.get(session.conversation_id)
|
||||||
|
|
||||||
if user_input.conversation_id is None:
|
|
||||||
conversation_id = ulid_util.ulid_now()
|
|
||||||
|
|
||||||
elif history := all_history.get(user_input.conversation_id):
|
|
||||||
conversation_id = user_input.conversation_id
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
|
||||||
# If an old OLID is passed in, we will generate a new one to indicate
|
|
||||||
# a new conversation was started. If the user picks their own, they
|
|
||||||
# want to track a conversation and we respect it.
|
|
||||||
try:
|
|
||||||
ulid_util.ulid_to_bytes(user_input.conversation_id)
|
|
||||||
conversation_id = ulid_util.ulid_now()
|
|
||||||
except ValueError:
|
|
||||||
conversation_id = user_input.conversation_id
|
|
||||||
|
|
||||||
if history:
|
if history:
|
||||||
history = replace(history, messages=history.messages.copy())
|
history = replace(history, messages=history.messages.copy())
|
||||||
else:
|
else:
|
||||||
history = ChatSession(hass, conversation_id, user_input.agent_id)
|
history = ChatLog(hass, session.conversation_id, user_input.agent_id)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def do_cleanup() -> None:
|
||||||
|
"""Handle cleanup."""
|
||||||
|
all_history.pop(session.conversation_id)
|
||||||
|
|
||||||
|
session.async_on_cleanup(do_cleanup)
|
||||||
|
|
||||||
message: Content = Content(
|
message: Content = Content(
|
||||||
role="user",
|
role="user",
|
||||||
@ -142,8 +69,7 @@ async def async_get_chat_session(
|
|||||||
return
|
return
|
||||||
|
|
||||||
history.last_updated = dt_util.utcnow()
|
history.last_updated = dt_util.utcnow()
|
||||||
all_history[conversation_id] = history
|
all_history[session.conversation_id] = history
|
||||||
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
|
|
||||||
|
|
||||||
|
|
||||||
class ConverseError(HomeAssistantError):
|
class ConverseError(HomeAssistantError):
|
||||||
@ -187,8 +113,8 @@ class NativeContent[_NativeT]:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatSession[_NativeT]:
|
class ChatLog[_NativeT]:
|
||||||
"""Class holding all information for a specific conversation."""
|
"""Class holding the chat history of a specific conversation."""
|
||||||
|
|
||||||
hass: HomeAssistant
|
hass: HomeAssistant
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
|
@ -18,7 +18,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -209,15 +209,18 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
async with conversation.async_get_chat_session(
|
async with (
|
||||||
self.hass, user_input
|
chat_session.async_get_chat_session(
|
||||||
) as session:
|
self.hass, user_input.conversation_id
|
||||||
return await self._async_handle_message(user_input, session)
|
) as session,
|
||||||
|
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||||
|
):
|
||||||
|
return await self._async_handle_message(user_input, chat_log)
|
||||||
|
|
||||||
async def _async_handle_message(
|
async def _async_handle_message(
|
||||||
self,
|
self,
|
||||||
user_input: conversation.ConversationInput,
|
user_input: conversation.ConversationInput,
|
||||||
session: conversation.ChatSession[genai_types.ContentDict],
|
session: conversation.ChatLog[genai_types.ContentDict],
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Call the API."""
|
"""Call the API."""
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
from . import OpenAIConfigEntry
|
from . import OpenAIConfigEntry
|
||||||
@ -155,15 +155,18 @@ class OpenAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
async with conversation.async_get_chat_session(
|
async with (
|
||||||
self.hass, user_input
|
chat_session.async_get_chat_session(
|
||||||
) as session:
|
self.hass, user_input.conversation_id
|
||||||
return await self._async_handle_message(user_input, session)
|
) as session,
|
||||||
|
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||||
|
):
|
||||||
|
return await self._async_handle_message(user_input, chat_log)
|
||||||
|
|
||||||
async def _async_handle_message(
|
async def _async_handle_message(
|
||||||
self,
|
self,
|
||||||
user_input: conversation.ConversationInput,
|
user_input: conversation.ConversationInput,
|
||||||
session: conversation.ChatSession[ChatCompletionMessageParam],
|
session: conversation.ChatLog[ChatCompletionMessageParam],
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Call the API."""
|
"""Call the API."""
|
||||||
assert user_input.agent_id
|
assert user_input.agent_id
|
||||||
|
160
homeassistant/helpers/chat_session.py
Normal file
160
homeassistant/helpers/chat_session.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
"""Helper to organize chat sessions between integrations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
|
from homeassistant.core import (
|
||||||
|
CALLBACK_TYPE,
|
||||||
|
Event,
|
||||||
|
HassJob,
|
||||||
|
HassJobType,
|
||||||
|
HomeAssistant,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||||
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
|
from .event import async_call_later
|
||||||
|
|
||||||
|
DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session")
|
||||||
|
DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")
|
||||||
|
|
||||||
|
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||||
|
|
||||||
|
current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||||
|
"current_session", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatSession:
|
||||||
|
"""Represent a chat session."""
|
||||||
|
|
||||||
|
conversation_id: str
|
||||||
|
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||||
|
_cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_updated(self) -> None:
|
||||||
|
"""Update the last updated time."""
|
||||||
|
self.last_updated = dt_util.utcnow()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_on_cleanup(self, cb: CALLBACK_TYPE) -> None:
|
||||||
|
"""Register a callback to clean up the session."""
|
||||||
|
self._cleanup_callbacks.append(cb)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_cleanup(self) -> None:
|
||||||
|
"""Call all clean up callbacks."""
|
||||||
|
for cb in self._cleanup_callbacks:
|
||||||
|
cb()
|
||||||
|
self._cleanup_callbacks.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCleanup:
|
||||||
|
"""Helper to clean up the stale sessions."""
|
||||||
|
|
||||||
|
unsub: CALLBACK_TYPE | None = None
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
"""Initialize the session cleanup."""
|
||||||
|
self.hass = hass
|
||||||
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
||||||
|
self.cleanup_job = HassJob(
|
||||||
|
self._cleanup, "chat_session_cleanup", job_type=HassJobType.Callback
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def schedule(self) -> None:
|
||||||
|
"""Schedule the cleanup."""
|
||||||
|
if self.unsub:
|
||||||
|
return
|
||||||
|
self.unsub = async_call_later(
|
||||||
|
self.hass,
|
||||||
|
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
||||||
|
self.cleanup_job,
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _on_hass_stop(self, event: Event) -> None:
|
||||||
|
"""Cancel the cleanup on shutdown."""
|
||||||
|
if self.unsub:
|
||||||
|
self.unsub()
|
||||||
|
self.unsub = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _cleanup(self, now: datetime) -> None:
|
||||||
|
"""Clean up the history and schedule follow-up if necessary."""
|
||||||
|
self.unsub = None
|
||||||
|
all_sessions = self.hass.data[DATA_CHAT_SESSION]
|
||||||
|
|
||||||
|
# We mutate original object because current commands could be
|
||||||
|
# yielding session based on it.
|
||||||
|
for conversation_id, session in list(all_sessions.items()):
|
||||||
|
if session.last_updated + CONVERSATION_TIMEOUT < now:
|
||||||
|
del all_sessions[conversation_id]
|
||||||
|
session.async_cleanup()
|
||||||
|
|
||||||
|
# Still conversations left, check again in timeout time.
|
||||||
|
if all_sessions:
|
||||||
|
self.schedule()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def async_get_chat_session(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
) -> AsyncGenerator[ChatSession]:
|
||||||
|
"""Return a chat session."""
|
||||||
|
if session := current_session.get():
|
||||||
|
# If a session is already active and it's the requested conversation ID,
|
||||||
|
# return that. We won't update the last updated time in this case.
|
||||||
|
if session.conversation_id == conversation_id:
|
||||||
|
yield session
|
||||||
|
return
|
||||||
|
|
||||||
|
# If it's not the same conversation ID, we will create a new session
|
||||||
|
# because it might be a conversation agent calling a tool that is talking
|
||||||
|
# to another LLM.
|
||||||
|
session = None
|
||||||
|
|
||||||
|
all_sessions = hass.data.get(DATA_CHAT_SESSION)
|
||||||
|
if all_sessions is None:
|
||||||
|
all_sessions = {}
|
||||||
|
hass.data[DATA_CHAT_SESSION] = all_sessions
|
||||||
|
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
|
||||||
|
|
||||||
|
if conversation_id is None:
|
||||||
|
conversation_id = ulid_util.ulid_now()
|
||||||
|
|
||||||
|
elif conversation_id in all_sessions:
|
||||||
|
session = all_sessions[conversation_id]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||||
|
# If an old ULID is passed in, we will generate a new one to indicate
|
||||||
|
# a new conversation was started. If the user picks their own, they
|
||||||
|
# want to track a conversation and we respect it.
|
||||||
|
try:
|
||||||
|
ulid_util.ulid_to_bytes(conversation_id)
|
||||||
|
conversation_id = ulid_util.ulid_now()
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
session = ChatSession(conversation_id)
|
||||||
|
|
||||||
|
current_session.set(session)
|
||||||
|
yield session
|
||||||
|
current_session.set(None)
|
||||||
|
|
||||||
|
session.last_updated = dt_util.utcnow()
|
||||||
|
all_sessions[conversation_id] = session
|
||||||
|
hass.data[DATA_CHAT_SESSION_CLEANUP].schedule()
|
@ -8,10 +8,17 @@ import pytest
|
|||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.conversation import ConversationInput, session
|
from homeassistant.components.conversation import (
|
||||||
|
Content,
|
||||||
|
ConversationInput,
|
||||||
|
ConverseError,
|
||||||
|
NativeContent,
|
||||||
|
async_get_chat_log,
|
||||||
|
)
|
||||||
|
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import chat_session, llm
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed
|
from tests.common import async_fire_time_changed
|
||||||
@ -38,127 +45,69 @@ def mock_ulid() -> Generator[Mock]:
|
|||||||
yield mock_ulid_now
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("start_id", "given_id"),
|
|
||||||
[
|
|
||||||
(None, "mock-ulid"),
|
|
||||||
# This ULID is not known as a session
|
|
||||||
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
|
|
||||||
("not-a-ulid", "not-a-ulid"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
async def test_conversation_id(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_conversation_input: ConversationInput,
|
|
||||||
mock_ulid: Mock,
|
|
||||||
start_id: str | None,
|
|
||||||
given_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Test conversation ID generation."""
|
|
||||||
mock_conversation_input.conversation_id = start_id
|
|
||||||
async with session.async_get_chat_session(
|
|
||||||
hass, mock_conversation_input
|
|
||||||
) as chat_session:
|
|
||||||
assert chat_session.conversation_id == given_id
|
|
||||||
|
|
||||||
|
|
||||||
async def test_cleanup(
|
async def test_cleanup(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Mock cleanup of the conversation session."""
|
"""Test cleanup of the chat log."""
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
assert len(chat_session.messages) == 2
|
):
|
||||||
conversation_id = chat_session.conversation_id
|
conversation_id = session.conversation_id
|
||||||
|
# Add message so it persists
|
||||||
# Generate session entry.
|
chat_log.async_add_message(
|
||||||
async with session.async_get_chat_session(
|
Content(
|
||||||
hass, mock_conversation_input
|
|
||||||
) as chat_session:
|
|
||||||
# Because we didn't add a message to the session in the last block,
|
|
||||||
# the conversation was not be persisted and we get a new ID
|
|
||||||
assert chat_session.conversation_id != conversation_id
|
|
||||||
conversation_id = chat_session.conversation_id
|
|
||||||
chat_session.async_add_message(
|
|
||||||
session.Content(
|
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id=mock_conversation_input.agent_id,
|
||||||
content="Hey!",
|
content="",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert len(chat_session.messages) == 3
|
|
||||||
|
|
||||||
# Reuse conversation ID to ensure we can chat with same session
|
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
||||||
mock_conversation_input.conversation_id = conversation_id
|
|
||||||
async with session.async_get_chat_session(
|
|
||||||
hass, mock_conversation_input
|
|
||||||
) as chat_session:
|
|
||||||
assert len(chat_session.messages) == 4
|
|
||||||
assert chat_session.conversation_id == conversation_id
|
|
||||||
|
|
||||||
# Set the last updated to be older than the timeout
|
# Set the last updated to be older than the timeout
|
||||||
hass.data[session.DATA_CHAT_HISTORY][conversation_id].last_updated = (
|
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||||
dt_util.utcnow() + session.CONVERSATION_TIMEOUT
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
|
||||||
)
|
)
|
||||||
|
|
||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
|
hass,
|
||||||
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not be cleaned up, but it should have scheduled another cleanup
|
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
|
||||||
mock_conversation_input.conversation_id = conversation_id
|
|
||||||
async with session.async_get_chat_session(
|
|
||||||
hass, mock_conversation_input
|
|
||||||
) as chat_session:
|
|
||||||
assert len(chat_session.messages) == 4
|
|
||||||
assert chat_session.conversation_id == conversation_id
|
|
||||||
|
|
||||||
async_fire_time_changed(
|
|
||||||
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# It should be cleaned up now and we start a new conversation
|
|
||||||
async with session.async_get_chat_session(
|
|
||||||
hass, mock_conversation_input
|
|
||||||
) as chat_session:
|
|
||||||
assert chat_session.conversation_id != conversation_id
|
|
||||||
assert len(chat_session.messages) == 2
|
|
||||||
|
|
||||||
|
|
||||||
async def test_add_message(
|
async def test_add_message(
|
||||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test filtering of messages."""
|
"""Test filtering of messages."""
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
assert len(chat_session.messages) == 2
|
):
|
||||||
|
assert len(chat_log.messages) == 2
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(role="system", agent_id=None, content="")
|
Content(role="system", agent_id=None, content="")
|
||||||
)
|
)
|
||||||
|
|
||||||
# No 2 user messages in a row
|
# No 2 user messages in a row
|
||||||
assert chat_session.messages[1].role == "user"
|
assert chat_log.messages[1].role == "user"
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
|
||||||
session.Content(role="user", agent_id=None, content="")
|
|
||||||
)
|
|
||||||
|
|
||||||
# No 2 assistant messages in a row
|
# No 2 assistant messages in a row
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
|
||||||
session.Content(role="assistant", agent_id=None, content="")
|
assert len(chat_log.messages) == 3
|
||||||
)
|
assert chat_log.messages[-1].role == "assistant"
|
||||||
assert len(chat_session.messages) == 3
|
|
||||||
assert chat_session.messages[-1].role == "assistant"
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(role="assistant", agent_id=None, content="")
|
Content(role="assistant", agent_id=None, content="")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -166,66 +115,65 @@ async def test_message_filtering(
|
|||||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test filtering of messages."""
|
"""Test filtering of messages."""
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
messages = chat_session.async_get_messages(agent_id=None)
|
):
|
||||||
|
messages = chat_log.async_get_messages(agent_id=None)
|
||||||
assert len(messages) == 2
|
assert len(messages) == 2
|
||||||
assert messages[0] == session.Content(
|
assert messages[0] == Content(
|
||||||
role="system",
|
role="system",
|
||||||
agent_id=None,
|
agent_id=None,
|
||||||
content="",
|
content="",
|
||||||
)
|
)
|
||||||
assert messages[1] == session.Content(
|
assert messages[1] == Content(
|
||||||
role="user",
|
role="user",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content=mock_conversation_input.text,
|
content=mock_conversation_input.text,
|
||||||
)
|
)
|
||||||
# Cannot add a second user message in a row
|
# Cannot add a second user message in a row
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(
|
Content(
|
||||||
role="user",
|
role="user",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(
|
Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Different agent, native messages will be filtered out.
|
# Different agent, native messages will be filtered out.
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.NativeContent(agent_id="another-mock-agent-id", content=1)
|
NativeContent(agent_id="another-mock-agent-id", content=1)
|
||||||
)
|
|
||||||
chat_session.async_add_message(
|
|
||||||
session.NativeContent(agent_id="mock-agent-id", content=1)
|
|
||||||
)
|
)
|
||||||
|
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
|
||||||
# A non-native message from another agent is not filtered out.
|
# A non-native message from another agent is not filtered out.
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(
|
Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="another-mock-agent-id",
|
agent_id="another-mock-agent-id",
|
||||||
content="Hi!",
|
content="Hi!",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(chat_session.messages) == 6
|
assert len(chat_log.messages) == 6
|
||||||
|
|
||||||
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
|
||||||
assert len(messages) == 5
|
assert len(messages) == 5
|
||||||
|
|
||||||
assert messages[2] == session.Content(
|
assert messages[2] == Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
|
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
|
||||||
assert messages[4] == session.Content(
|
assert messages[4] == Content(
|
||||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
|
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -235,18 +183,19 @@ async def test_llm_api(
|
|||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when we reference an LLM API."""
|
"""Test when we reference an LLM API."""
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(chat_session.llm_api, llm.APIInstance)
|
assert isinstance(chat_log.llm_api, llm.APIInstance)
|
||||||
assert chat_session.llm_api.api.id == "assist"
|
assert chat_log.llm_api.api.id == "assist"
|
||||||
|
|
||||||
|
|
||||||
async def test_unknown_llm_api(
|
async def test_unknown_llm_api(
|
||||||
@ -255,11 +204,12 @@ async def test_unknown_llm_api(
|
|||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when we reference an LLM API that does not exists."""
|
"""Test when we reference an LLM API that does not exists."""
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
with pytest.raises(session.ConverseError) as exc_info:
|
):
|
||||||
await chat_session.async_update_llm_data(
|
with pytest.raises(ConverseError) as exc_info:
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api="unknown-api",
|
user_llm_hass_api="unknown-api",
|
||||||
@ -276,11 +226,12 @@ async def test_template_error(
|
|||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that template error handling works."""
|
"""Test that template error handling works."""
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
with pytest.raises(session.ConverseError) as exc_info:
|
):
|
||||||
await chat_session.async_update_llm_data(
|
with pytest.raises(ConverseError) as exc_info:
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
@ -300,13 +251,14 @@ async def test_template_variables(
|
|||||||
mock_user.name = "Test User"
|
mock_user.name = "Test User"
|
||||||
mock_conversation_input.context = Context(user_id=mock_user.id)
|
mock_conversation_input.context = Context(user_id=mock_user.id)
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
|
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
|
||||||
):
|
):
|
||||||
await chat_session.async_update_llm_data(
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
@ -318,12 +270,12 @@ async def test_template_variables(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_session.user_name == "Test User"
|
assert chat_log.user_name == "Test User"
|
||||||
|
|
||||||
assert "The instance name is test home." in chat_session.messages[0].content
|
assert "The instance name is test home." in chat_log.messages[0].content
|
||||||
assert "The user name is Test User." in chat_session.messages[0].content
|
assert "The user name is Test User." in chat_log.messages[0].content
|
||||||
assert "The user id is 12345." in chat_session.messages[0].content
|
assert "The user id is 12345." in chat_log.messages[0].content
|
||||||
assert "The calling platform is test." in chat_session.messages[0].content
|
assert "The calling platform is test." in chat_log.messages[0].content
|
||||||
|
|
||||||
|
|
||||||
async def test_extra_systen_prompt(
|
async def test_extra_systen_prompt(
|
||||||
@ -336,82 +288,86 @@ async def test_extra_systen_prompt(
|
|||||||
)
|
)
|
||||||
mock_conversation_input.extra_system_prompt = extra_system_prompt
|
mock_conversation_input.extra_system_prompt = extra_system_prompt
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(
|
Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_session.extra_system_prompt == extra_system_prompt
|
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||||
assert chat_session.messages[0].content.endswith(extra_system_prompt)
|
assert chat_log.messages[0].content.endswith(extra_system_prompt)
|
||||||
|
|
||||||
# Verify that follow-up conversations with no system prompt take previous one
|
# Verify that follow-up conversations with no system prompt take previous one
|
||||||
mock_conversation_input.conversation_id = chat_session.conversation_id
|
conversation_id = chat_log.conversation_id
|
||||||
mock_conversation_input.extra_system_prompt = None
|
mock_conversation_input.extra_system_prompt = None
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_session.extra_system_prompt == extra_system_prompt
|
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||||
assert chat_session.messages[0].content.endswith(extra_system_prompt)
|
assert chat_log.messages[0].content.endswith(extra_system_prompt)
|
||||||
|
|
||||||
# Verify that we take new system prompts
|
# Verify that we take new system prompts
|
||||||
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_log.async_add_message(
|
||||||
session.Content(
|
Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_session.extra_system_prompt == extra_system_prompt2
|
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||||
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
|
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
|
||||||
assert extra_system_prompt not in chat_session.messages[0].content
|
assert extra_system_prompt not in chat_log.messages[0].content
|
||||||
|
|
||||||
# Verify that follow-up conversations with no system prompt take previous one
|
# Verify that follow-up conversations with no system prompt take previous one
|
||||||
mock_conversation_input.extra_system_prompt = None
|
mock_conversation_input.extra_system_prompt = None
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_session.extra_system_prompt == extra_system_prompt2
|
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||||
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
|
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
|
||||||
|
|
||||||
|
|
||||||
async def test_tool_call(
|
async def test_tool_call(
|
||||||
@ -434,16 +390,17 @@ async def test_tool_call(
|
|||||||
) as mock_get_tools:
|
) as mock_get_tools:
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
result = await chat_session.async_call_tool(
|
result = await chat_log.async_call_tool(
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "Test Param"},
|
tool_args={"param1": "Test Param"},
|
||||||
@ -473,16 +430,17 @@ async def test_tool_call_exception(
|
|||||||
) as mock_get_tools:
|
) as mock_get_tools:
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
async with session.async_get_chat_session(
|
async with (
|
||||||
hass, mock_conversation_input
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
) as chat_session:
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
await chat_session.async_update_llm_data(
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
result = await chat_session.async_call_tool(
|
result = await chat_log.async_call_tool(
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "Test Param"},
|
tool_args={"param1": "Test Param"},
|
||||||
|
98
tests/helpers/test_chat_session.py
Normal file
98
tests/helpers/test_chat_session.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""Test the chat session helper."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import timedelta
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import chat_session
|
||||||
|
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||||
|
|
||||||
|
from tests.common import async_fire_time_changed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ulid() -> Generator[Mock]:
|
||||||
|
"""Mock the ulid library."""
|
||||||
|
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
||||||
|
mock_ulid_now.return_value = "mock-ulid"
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("start_id", "given_id"),
|
||||||
|
[
|
||||||
|
(None, "mock-ulid"),
|
||||||
|
# This ULID is not known as a session
|
||||||
|
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
|
||||||
|
("not-a-ulid", "not-a-ulid"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_conversation_id(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
start_id: str | None,
|
||||||
|
given_id: str,
|
||||||
|
mock_ulid: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""Test conversation ID generation."""
|
||||||
|
async with chat_session.async_get_chat_session(hass, start_id) as session:
|
||||||
|
assert session.conversation_id == given_id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_context_var(hass: HomeAssistant) -> None:
|
||||||
|
"""Test context var."""
|
||||||
|
async with chat_session.async_get_chat_session(hass) as session:
|
||||||
|
async with chat_session.async_get_chat_session(
|
||||||
|
hass, session.conversation_id
|
||||||
|
) as session2:
|
||||||
|
assert session is session2
|
||||||
|
|
||||||
|
async with chat_session.async_get_chat_session(hass, None) as session2:
|
||||||
|
assert session.conversation_id != session2.conversation_id
|
||||||
|
|
||||||
|
async with chat_session.async_get_chat_session(
|
||||||
|
hass, "something else"
|
||||||
|
) as session2:
|
||||||
|
assert session.conversation_id != session2.conversation_id
|
||||||
|
|
||||||
|
async with chat_session.async_get_chat_session(
|
||||||
|
hass, ulid_util.ulid_now()
|
||||||
|
) as session2:
|
||||||
|
assert session.conversation_id != session2.conversation_id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cleanup(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test cleanup of the chat session."""
|
||||||
|
async with chat_session.async_get_chat_session(hass) as session:
|
||||||
|
conversation_id = session.conversation_id
|
||||||
|
|
||||||
|
# Reuse conversation ID to ensure we can chat with same session
|
||||||
|
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
|
assert session.conversation_id == conversation_id
|
||||||
|
|
||||||
|
# Set the last updated to be older than the timeout
|
||||||
|
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||||
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
|
||||||
|
)
|
||||||
|
|
||||||
|
async_fire_time_changed(
|
||||||
|
hass,
|
||||||
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not be cleaned up, but it should have scheduled another cleanup
|
||||||
|
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
|
assert session.conversation_id == conversation_id
|
||||||
|
|
||||||
|
async_fire_time_changed(
|
||||||
|
hass,
|
||||||
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# It should be cleaned up now and we start a new conversation
|
||||||
|
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
|
assert session.conversation_id != conversation_id
|
Loading…
x
Reference in New Issue
Block a user