Extract conversation ID generation to helper (#137062)

* Extract conversation ID generation to helper

* Allow nested get_chat_log calls
This commit is contained in:
Paulus Schoutsen 2025-02-01 20:54:00 -05:00 committed by GitHub
parent 30314ca32b
commit 2f6640707b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 457 additions and 304 deletions

View File

@ -33,7 +33,7 @@ from homeassistant.components.tts import (
from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.helpers import chat_session, intent
from homeassistant.helpers.collection import (
CHANGE_UPDATED,
CollectionError,
@ -1094,13 +1094,18 @@ class PipelineRun:
# It was already handled, create response and add to chat history
if intent_response is not None:
async with conversation.async_get_chat_session(
self.hass, user_input
) as chat_session:
async with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(
self.hass, session, user_input
) as chat_log,
):
speech: str = intent_response.speech.get("plain", {}).get(
"speech", ""
)
chat_session.async_add_message(
chat_log.async_add_message(
conversation.Content(
role="assistant",
agent_id=agent_id,
@ -1109,7 +1114,7 @@ class PipelineRun:
)
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=chat_session.conversation_id,
conversation_id=session.conversation_id,
)
else:

View File

@ -48,20 +48,14 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import (
ChatSession,
Content,
ConverseError,
NativeContent,
async_get_chat_session,
)
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [
"DOMAIN",
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"ChatSession",
"ChatLog",
"Content",
"ConversationEntity",
"ConversationEntityFeature",
@ -73,7 +67,7 @@ __all__ = [
"async_conversation_trace_append",
"async_converse",
"async_get_agent_info",
"async_get_chat_session",
"async_get_chat_log",
"async_set_agent",
"async_setup",
"async_unset_agent",

View File

@ -42,6 +42,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
from homeassistant.helpers import (
area_registry as ar,
chat_session,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
@ -62,7 +63,7 @@ from .const import (
)
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .session import Content, async_get_chat_session
from .session import Content, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__)
@ -348,7 +349,12 @@ class DefaultAgent(ConversationEntity):
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
response: intent.IntentResponse | None = None
async with async_get_chat_session(self.hass, user_input) as chat_session:
async with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
async_get_chat_log(self.hass, session, user_input) as chat_log,
):
# Check if a trigger matched
if trigger_result := await self.async_recognize_sentence_trigger(
user_input
@ -373,7 +379,7 @@ class DefaultAgent(ConversationEntity):
)
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_session.async_add_message(
chat_log.async_add_message(
Content(
role="assistant",
agent_id=user_input.agent_id,
@ -382,7 +388,7 @@ class DefaultAgent(ConversationEntity):
)
return ConversationResult(
response=response, conversation_id=chat_session.conversation_id
response=response, conversation_id=session.conversation_id
)
async def _async_process_intent_result(

View File

@ -5,25 +5,16 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from datetime import datetime
import logging
from typing import Literal
import voluptuous as vol
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HassJobType,
HomeAssistant,
callback,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.event import async_call_later
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.helpers import chat_session, intent, llm, template
from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
@ -31,100 +22,36 @@ from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult
DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey(
"conversation_chat_session"
)
DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey(
"conversation_chat_session_cleanup"
)
DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
LOGGER = logging.getLogger(__name__)
CONVERSATION_TIMEOUT = timedelta(minutes=5)
class SessionCleanup:
"""Helper to clean up the history."""
unsub: CALLBACK_TYPE | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the history cleanup."""
self.hass = hass
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
self.cleanup_job = HassJob(
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
)
@callback
def schedule(self) -> None:
"""Schedule the cleanup."""
if self.unsub:
return
self.unsub = async_call_later(
self.hass,
CONVERSATION_TIMEOUT.total_seconds() + 1,
self.cleanup_job,
)
@callback
def _on_hass_stop(self, event: Event) -> None:
"""Cancel the cleanup on shutdown."""
if self.unsub:
self.unsub()
self.unsub = None
@callback
def _cleanup(self, now: datetime) -> None:
"""Clean up the history and schedule follow-up if necessary."""
self.unsub = None
all_history = self.hass.data[DATA_CHAT_HISTORY]
# We mutate original object because current commands could be
# yielding history based on it.
for conversation_id, history in list(all_history.items()):
if history.last_updated + CONVERSATION_TIMEOUT < now:
del all_history[conversation_id]
# Still conversations left, check again in timeout time.
if all_history:
self.schedule()
@asynccontextmanager
async def async_get_chat_session(
async def async_get_chat_log(
hass: HomeAssistant,
session: chat_session.ChatSession,
user_input: ConversationInput,
) -> AsyncGenerator[ChatSession]:
"""Return chat session."""
) -> AsyncGenerator[ChatLog]:
"""Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
if all_history is None:
all_history = {}
hass.data[DATA_CHAT_HISTORY] = all_history
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
history: ChatSession | None = None
if user_input.conversation_id is None:
conversation_id = ulid_util.ulid_now()
elif history := all_history.get(user_input.conversation_id):
conversation_id = user_input.conversation_id
else:
# Conversation IDs are ULIDs. We generate a new one if not provided.
# If an old OLID is passed in, we will generate a new one to indicate
# a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it.
try:
ulid_util.ulid_to_bytes(user_input.conversation_id)
conversation_id = ulid_util.ulid_now()
except ValueError:
conversation_id = user_input.conversation_id
history = all_history.get(session.conversation_id)
if history:
history = replace(history, messages=history.messages.copy())
else:
history = ChatSession(hass, conversation_id, user_input.agent_id)
history = ChatLog(hass, session.conversation_id, user_input.agent_id)
@callback
def do_cleanup() -> None:
"""Handle cleanup."""
all_history.pop(session.conversation_id)
session.async_on_cleanup(do_cleanup)
message: Content = Content(
role="user",
@ -142,8 +69,7 @@ async def async_get_chat_session(
return
history.last_updated = dt_util.utcnow()
all_history[conversation_id] = history
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
all_history[session.conversation_id] = history
class ConverseError(HomeAssistantError):
@ -187,8 +113,8 @@ class NativeContent[_NativeT]:
@dataclass
class ChatSession[_NativeT]:
"""Class holding all information for a specific conversation."""
class ChatLog[_NativeT]:
"""Class holding the chat history of a specific conversation."""
hass: HomeAssistant
conversation_id: str

View File

@ -18,7 +18,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import (
@ -209,15 +209,18 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
async with conversation.async_get_chat_session(
self.hass, user_input
) as session:
return await self._async_handle_message(user_input, session)
async with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatSession[genai_types.ContentDict],
session: conversation.ChatLog[genai_types.ContentDict],
) -> conversation.ConversationResult:
"""Call the API."""

View File

@ -23,7 +23,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import OpenAIConfigEntry
@ -155,15 +155,18 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
async with conversation.async_get_chat_session(
self.hass, user_input
) as session:
return await self._async_handle_message(user_input, session)
async with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatSession[ChatCompletionMessageParam],
session: conversation.ChatLog[ChatCompletionMessageParam],
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id

View File

@ -0,0 +1,160 @@
"""Helper to organize chat sessions between integrations."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HassJobType,
HomeAssistant,
callback,
)
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.hass_dict import HassKey
from .event import async_call_later
DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session")
DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")
CONVERSATION_TIMEOUT = timedelta(minutes=5)
current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
@dataclass
class ChatSession:
"""Represent a chat session."""
conversation_id: str
last_updated: datetime = field(default_factory=dt_util.utcnow)
_cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
@callback
def async_updated(self) -> None:
"""Update the last updated time."""
self.last_updated = dt_util.utcnow()
@callback
def async_on_cleanup(self, cb: CALLBACK_TYPE) -> None:
"""Register a callback to clean up the session."""
self._cleanup_callbacks.append(cb)
@callback
def async_cleanup(self) -> None:
"""Call all clean up callbacks."""
for cb in self._cleanup_callbacks:
cb()
self._cleanup_callbacks.clear()
class SessionCleanup:
"""Helper to clean up the stale sessions."""
unsub: CALLBACK_TYPE | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the session cleanup."""
self.hass = hass
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
self.cleanup_job = HassJob(
self._cleanup, "chat_session_cleanup", job_type=HassJobType.Callback
)
@callback
def schedule(self) -> None:
"""Schedule the cleanup."""
if self.unsub:
return
self.unsub = async_call_later(
self.hass,
CONVERSATION_TIMEOUT.total_seconds() + 1,
self.cleanup_job,
)
@callback
def _on_hass_stop(self, event: Event) -> None:
"""Cancel the cleanup on shutdown."""
if self.unsub:
self.unsub()
self.unsub = None
@callback
def _cleanup(self, now: datetime) -> None:
"""Clean up the history and schedule follow-up if necessary."""
self.unsub = None
all_sessions = self.hass.data[DATA_CHAT_SESSION]
# We mutate original object because current commands could be
# yielding session based on it.
for conversation_id, session in list(all_sessions.items()):
if session.last_updated + CONVERSATION_TIMEOUT < now:
del all_sessions[conversation_id]
session.async_cleanup()
# Still conversations left, check again in timeout time.
if all_sessions:
self.schedule()
@asynccontextmanager
async def async_get_chat_session(
hass: HomeAssistant,
conversation_id: str | None = None,
) -> AsyncGenerator[ChatSession]:
"""Return a chat session."""
if session := current_session.get():
# If a session is already active and it's the requested conversation ID,
# return that. We won't update the last updated time in this case.
if session.conversation_id == conversation_id:
yield session
return
# If it's not the same conversation ID, we will create a new session
# because it might be a conversation agent calling a tool that is talking
# to another LLM.
session = None
all_sessions = hass.data.get(DATA_CHAT_SESSION)
if all_sessions is None:
all_sessions = {}
hass.data[DATA_CHAT_SESSION] = all_sessions
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
if conversation_id is None:
conversation_id = ulid_util.ulid_now()
elif conversation_id in all_sessions:
session = all_sessions[conversation_id]
else:
# Conversation IDs are ULIDs. We generate a new one if not provided.
# If an old ULID is passed in, we will generate a new one to indicate
# a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it.
try:
ulid_util.ulid_to_bytes(conversation_id)
conversation_id = ulid_util.ulid_now()
except ValueError:
pass
if session is None:
session = ChatSession(conversation_id)
current_session.set(session)
yield session
current_session.set(None)
session.last_updated = dt_util.utcnow()
all_sessions[conversation_id] = session
hass.data[DATA_CHAT_SESSION_CLEANUP].schedule()

View File

@ -8,10 +8,17 @@ import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components.conversation import ConversationInput, session
from homeassistant.components.conversation import (
Content,
ConversationInput,
ConverseError,
NativeContent,
async_get_chat_log,
)
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers import chat_session, llm
from homeassistant.util import dt as dt_util
from tests.common import async_fire_time_changed
@ -38,127 +45,69 @@ def mock_ulid() -> Generator[Mock]:
yield mock_ulid_now
@pytest.mark.parametrize(
("start_id", "given_id"),
[
(None, "mock-ulid"),
# This ULID is not known as a session
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
("not-a-ulid", "not-a-ulid"),
],
)
async def test_conversation_id(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
mock_ulid: Mock,
start_id: str | None,
given_id: str,
) -> None:
"""Test conversation ID generation."""
mock_conversation_input.conversation_id = start_id
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert chat_session.conversation_id == given_id
async def test_cleanup(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
) -> None:
"""Mock cleanup of the conversation session."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert len(chat_session.messages) == 2
conversation_id = chat_session.conversation_id
# Generate session entry.
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
# Because we didn't add a message to the session in the last block,
# the conversation was not be persisted and we get a new ID
assert chat_session.conversation_id != conversation_id
conversation_id = chat_session.conversation_id
chat_session.async_add_message(
session.Content(
"""Test cleanup of the chat log."""
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
conversation_id = session.conversation_id
# Add message so it persists
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
agent_id=mock_conversation_input.agent_id,
content="",
)
)
assert len(chat_session.messages) == 3
# Reuse conversation ID to ensure we can chat with same session
mock_conversation_input.conversation_id = conversation_id
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert len(chat_session.messages) == 4
assert chat_session.conversation_id == conversation_id
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
# Set the last updated to be older than the timeout
hass.data[session.DATA_CHAT_HISTORY][conversation_id].last_updated = (
dt_util.utcnow() + session.CONVERSATION_TIMEOUT
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
)
async_fire_time_changed(
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
)
# Should not be cleaned up, but it should have scheduled another cleanup
mock_conversation_input.conversation_id = conversation_id
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert len(chat_session.messages) == 4
assert chat_session.conversation_id == conversation_id
async_fire_time_changed(
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1)
)
# It should be cleaned up now and we start a new conversation
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert chat_session.conversation_id != conversation_id
assert len(chat_session.messages) == 2
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
async def test_add_message(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert len(chat_session.messages) == 2
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
assert len(chat_log.messages) == 2
with pytest.raises(ValueError):
chat_session.async_add_message(
session.Content(role="system", agent_id=None, content="")
chat_log.async_add_message(
Content(role="system", agent_id=None, content="")
)
# No 2 user messages in a row
assert chat_session.messages[1].role == "user"
assert chat_log.messages[1].role == "user"
with pytest.raises(ValueError):
chat_session.async_add_message(
session.Content(role="user", agent_id=None, content="")
)
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
# No 2 assistant messages in a row
chat_session.async_add_message(
session.Content(role="assistant", agent_id=None, content="")
)
assert len(chat_session.messages) == 3
assert chat_session.messages[-1].role == "assistant"
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
assert len(chat_log.messages) == 3
assert chat_log.messages[-1].role == "assistant"
with pytest.raises(ValueError):
chat_session.async_add_message(
session.Content(role="assistant", agent_id=None, content="")
chat_log.async_add_message(
Content(role="assistant", agent_id=None, content="")
)
@ -166,66 +115,65 @@ async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
messages = chat_session.async_get_messages(agent_id=None)
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
messages = chat_log.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == session.Content(
assert messages[0] == Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == session.Content(
assert messages[1] == Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_session.async_add_message(
session.Content(
chat_log.async_add_message(
Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)
chat_session.async_add_message(
session.Content(
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
# Different agent, native messages will be filtered out.
chat_session.async_add_message(
session.NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_session.async_add_message(
session.NativeContent(agent_id="mock-agent-id", content=1)
chat_log.async_add_message(
NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
# A non-native message from another agent is not filtered out.
chat_session.async_add_message(
session.Content(
chat_log.async_add_message(
Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
)
)
assert len(chat_session.messages) == 6
assert len(chat_log.messages) == 6
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5
assert messages[2] == session.Content(
assert messages[2] == Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == session.Content(
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
@ -235,18 +183,19 @@ async def test_llm_api(
mock_conversation_input: ConversationInput,
) -> None:
"""Test when we reference an LLM API."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
)
assert isinstance(chat_session.llm_api, llm.APIInstance)
assert chat_session.llm_api.api.id == "assist"
assert isinstance(chat_log.llm_api, llm.APIInstance)
assert chat_log.llm_api.api.id == "assist"
async def test_unknown_llm_api(
@ -255,11 +204,12 @@ async def test_unknown_llm_api(
snapshot: SnapshotAssertion,
) -> None:
"""Test when we reference an LLM API that does not exists."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
with pytest.raises(session.ConverseError) as exc_info:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
with pytest.raises(ConverseError) as exc_info:
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="unknown-api",
@ -276,11 +226,12 @@ async def test_template_error(
snapshot: SnapshotAssertion,
) -> None:
"""Test that template error handling works."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
with pytest.raises(session.ConverseError) as exc_info:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
with pytest.raises(ConverseError) as exc_info:
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
@ -300,13 +251,14 @@ async def test_template_variables(
mock_user.name = "Test User"
mock_conversation_input.context = Context(user_id=mock_user.id)
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
with patch(
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
):
await chat_session.async_update_llm_data(
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
@ -318,12 +270,12 @@ async def test_template_variables(
),
)
assert chat_session.user_name == "Test User"
assert chat_log.user_name == "Test User"
assert "The instance name is test home." in chat_session.messages[0].content
assert "The user name is Test User." in chat_session.messages[0].content
assert "The user id is 12345." in chat_session.messages[0].content
assert "The calling platform is test." in chat_session.messages[0].content
assert "The instance name is test home." in chat_log.messages[0].content
assert "The user name is Test User." in chat_log.messages[0].content
assert "The user id is 12345." in chat_log.messages[0].content
assert "The calling platform is test." in chat_log.messages[0].content
async def test_extra_systen_prompt(
@ -336,82 +288,86 @@ async def test_extra_systen_prompt(
)
mock_conversation_input.extra_system_prompt = extra_system_prompt
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_session.async_add_message(
session.Content(
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
assert chat_session.extra_system_prompt == extra_system_prompt
assert chat_session.messages[0].content.endswith(extra_system_prompt)
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
# Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.conversation_id = chat_session.conversation_id
conversation_id = chat_log.conversation_id
mock_conversation_input.extra_system_prompt = None
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt=None,
)
assert chat_session.extra_system_prompt == extra_system_prompt
assert chat_session.messages[0].content.endswith(extra_system_prompt)
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
# Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_session.async_add_message(
session.Content(
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
assert chat_session.extra_system_prompt == extra_system_prompt2
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_session.messages[0].content
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.messages[0].content
# Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt=None,
)
assert chat_session.extra_system_prompt == extra_system_prompt2
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
async def test_tool_call(
@ -434,16 +390,17 @@ async def test_tool_call(
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_session.async_call_tool(
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
@ -473,16 +430,17 @@ async def test_tool_call_exception(
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
await chat_session.async_update_llm_data(
async with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_session.async_call_tool(
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},

View File

@ -0,0 +1,98 @@
"""Test the chat session helper."""
from collections.abc import Generator
from datetime import timedelta
from unittest.mock import Mock, patch
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session
from homeassistant.util import dt as dt_util, ulid as ulid_util
from tests.common import async_fire_time_changed
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@pytest.mark.parametrize(
("start_id", "given_id"),
[
(None, "mock-ulid"),
# This ULID is not known as a session
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
("not-a-ulid", "not-a-ulid"),
],
)
async def test_conversation_id(
hass: HomeAssistant,
start_id: str | None,
given_id: str,
mock_ulid: Mock,
) -> None:
"""Test conversation ID generation."""
async with chat_session.async_get_chat_session(hass, start_id) as session:
assert session.conversation_id == given_id
async def test_context_var(hass: HomeAssistant) -> None:
"""Test context var."""
async with chat_session.async_get_chat_session(hass) as session:
async with chat_session.async_get_chat_session(
hass, session.conversation_id
) as session2:
assert session is session2
async with chat_session.async_get_chat_session(hass, None) as session2:
assert session.conversation_id != session2.conversation_id
async with chat_session.async_get_chat_session(
hass, "something else"
) as session2:
assert session.conversation_id != session2.conversation_id
async with chat_session.async_get_chat_session(
hass, ulid_util.ulid_now()
) as session2:
assert session.conversation_id != session2.conversation_id
async def test_cleanup(
hass: HomeAssistant,
) -> None:
"""Test cleanup of the chat session."""
async with chat_session.async_get_chat_session(hass) as session:
conversation_id = session.conversation_id
# Reuse conversation ID to ensure we can chat with same session
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id == conversation_id
# Set the last updated to be older than the timeout
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
)
async_fire_time_changed(
hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
)
# Should not be cleaned up, but it should have scheduled another cleanup
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id == conversation_id
async_fire_time_changed(
hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
)
# It should be cleaned up now and we start a new conversation
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id != conversation_id