diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 1d320d79bf2..f2b2f1c1ea4 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -33,7 +33,7 @@ from homeassistant.components.tts import ( from homeassistant.const import MATCH_ALL from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import intent +from homeassistant.helpers import chat_session, intent from homeassistant.helpers.collection import ( CHANGE_UPDATED, CollectionError, @@ -1094,13 +1094,18 @@ class PipelineRun: # It was already handled, create response and add to chat history if intent_response is not None: - async with conversation.async_get_chat_session( - self.hass, user_input - ) as chat_session: + async with ( + chat_session.async_get_chat_session( + self.hass, user_input.conversation_id + ) as session, + conversation.async_get_chat_log( + self.hass, session, user_input + ) as chat_log, + ): speech: str = intent_response.speech.get("plain", {}).get( "speech", "" ) - chat_session.async_add_message( + chat_log.async_add_message( conversation.Content( role="assistant", agent_id=agent_id, @@ -1109,7 +1114,7 @@ class PipelineRun: ) conversation_result = conversation.ConversationResult( response=intent_response, - conversation_id=chat_session.conversation_id, + conversation_id=session.conversation_id, ) else: diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index b110d53540c..13152beff51 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -48,20 +48,14 @@ from .default_agent import DefaultAgent, async_setup_default_agent from .entity import ConversationEntity from .http import async_setup as async_setup_conversation_http from .models import AbstractConversationAgent, ConversationInput, ConversationResult -from .session import ( - ChatSession, - Content, - ConverseError, - NativeContent, - async_get_chat_session, -) +from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log from .trace import ConversationTraceEventType, async_conversation_trace_append __all__ = [ "DOMAIN", "HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT", - "ChatSession", + "ChatLog", "Content", "ConversationEntity", "ConversationEntityFeature", @@ -73,7 +67,7 @@ __all__ = [ "async_conversation_trace_append", "async_converse", "async_get_agent_info", - "async_get_chat_session", + "async_get_chat_log", "async_set_agent", "async_setup", "async_unset_agent", diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index be0387555dc..99a1a09e52b 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -42,6 +42,7 @@ from homeassistant.components.homeassistant.exposed_entities import ( from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL from homeassistant.helpers import ( area_registry as ar, + chat_session, device_registry as dr, entity_registry as er, floor_registry as fr, @@ -62,7 +63,7 @@ from .const import ( ) from .entity import ConversationEntity from .models import ConversationInput, ConversationResult -from .session import Content, async_get_chat_session +from .session import Content, async_get_chat_log from .trace import ConversationTraceEventType, async_conversation_trace_append _LOGGER = logging.getLogger(__name__) @@ -348,7 +349,12 @@ class DefaultAgent(ConversationEntity): async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process a sentence.""" response: intent.IntentResponse | None = None - async with async_get_chat_session(self.hass, user_input) as chat_session: + async with ( + chat_session.async_get_chat_session( + self.hass, user_input.conversation_id + ) as session, + async_get_chat_log(self.hass, session, user_input) as chat_log, + ): # Check if a trigger matched if trigger_result := await self.async_recognize_sentence_trigger( user_input @@ -373,7 +379,7 @@ class DefaultAgent(ConversationEntity): ) speech: str = response.speech.get("plain", {}).get("speech", "") - chat_session.async_add_message( + chat_log.async_add_message( Content( role="assistant", agent_id=user_input.agent_id, @@ -382,7 +388,7 @@ class DefaultAgent(ConversationEntity): ) return ConversationResult( - response=response, conversation_id=chat_session.conversation_id + response=response, conversation_id=session.conversation_id ) async def _async_process_intent_result( diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/session.py index 43f4cbf427c..c23d6575d6c 100644 --- a/homeassistant/components/conversation/session.py +++ b/homeassistant/components/conversation/session.py @@ -5,25 +5,16 @@ from __future__ import annotations from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace -from datetime import datetime, timedelta +from datetime import datetime import logging from typing import Literal import voluptuous as vol -from homeassistant.const import EVENT_HOMEASSISTANT_STOP -from homeassistant.core import ( - CALLBACK_TYPE, - Event, - HassJob, - HassJobType, - HomeAssistant, - callback, -) +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, TemplateError -from homeassistant.helpers import intent, llm, template -from homeassistant.helpers.event import async_call_later -from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.helpers import chat_session, intent, llm, template +from homeassistant.util import dt as dt_util from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JsonObjectType @@ -31,100 +22,36 @@ from . import trace from .const import DOMAIN from .models import ConversationInput, ConversationResult -DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey( - "conversation_chat_session" -) -DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey( - "conversation_chat_session_cleanup" -) +DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log") LOGGER = logging.getLogger(__name__) -CONVERSATION_TIMEOUT = timedelta(minutes=5) - - -class SessionCleanup: - """Helper to clean up the history.""" - - unsub: CALLBACK_TYPE | None = None - - def __init__(self, hass: HomeAssistant) -> None: - """Initialize the history cleanup.""" - self.hass = hass - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop) - self.cleanup_job = HassJob( - self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback - ) - - @callback - def schedule(self) -> None: - """Schedule the cleanup.""" - if self.unsub: - return - self.unsub = async_call_later( - self.hass, - CONVERSATION_TIMEOUT.total_seconds() + 1, - self.cleanup_job, - ) - - @callback - def _on_hass_stop(self, event: Event) -> None: - """Cancel the cleanup on shutdown.""" - if self.unsub: - self.unsub() - self.unsub = None - - @callback - def _cleanup(self, now: datetime) -> None: - """Clean up the history and schedule follow-up if necessary.""" - self.unsub = None - all_history = self.hass.data[DATA_CHAT_HISTORY] - - # We mutate original object because current commands could be - # yielding history based on it. - for conversation_id, history in list(all_history.items()): - if history.last_updated + CONVERSATION_TIMEOUT < now: - del all_history[conversation_id] - - # Still conversations left, check again in timeout time. - if all_history: - self.schedule() @asynccontextmanager -async def async_get_chat_session( +async def async_get_chat_log( hass: HomeAssistant, + session: chat_session.ChatSession, user_input: ConversationInput, -) -> AsyncGenerator[ChatSession]: - """Return chat session.""" +) -> AsyncGenerator[ChatLog]: + """Return chat log for a specific chat session.""" all_history = hass.data.get(DATA_CHAT_HISTORY) if all_history is None: all_history = {} hass.data[DATA_CHAT_HISTORY] = all_history - hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass) - history: ChatSession | None = None - - if user_input.conversation_id is None: - conversation_id = ulid_util.ulid_now() - - elif history := all_history.get(user_input.conversation_id): - conversation_id = user_input.conversation_id - - else: - # Conversation IDs are ULIDs. We generate a new one if not provided. - # If an old OLID is passed in, we will generate a new one to indicate - # a new conversation was started. If the user picks their own, they - # want to track a conversation and we respect it. - try: - ulid_util.ulid_to_bytes(user_input.conversation_id) - conversation_id = ulid_util.ulid_now() - except ValueError: - conversation_id = user_input.conversation_id + history = all_history.get(session.conversation_id) if history: history = replace(history, messages=history.messages.copy()) else: - history = ChatSession(hass, conversation_id, user_input.agent_id) + history = ChatLog(hass, session.conversation_id, user_input.agent_id) + + @callback + def do_cleanup() -> None: + """Handle cleanup.""" + all_history.pop(session.conversation_id) + + session.async_on_cleanup(do_cleanup) message: Content = Content( role="user", @@ -142,8 +69,7 @@ async def async_get_chat_session( return history.last_updated = dt_util.utcnow() - all_history[conversation_id] = history - hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule() + all_history[session.conversation_id] = history class ConverseError(HomeAssistantError): @@ -187,8 +113,8 @@ class NativeContent[_NativeT]: @dataclass -class ChatSession[_NativeT]: - """Class holding all information for a specific conversation.""" +class ChatLog[_NativeT]: + """Class holding the chat history of a specific conversation.""" hass: HomeAssistant conversation_id: str diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index db2df9cddd3..d4982797e22 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -18,7 +18,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr, intent, llm +from homeassistant.helpers import chat_session, device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import ( @@ -209,15 +209,18 @@ class GoogleGenerativeAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - async with conversation.async_get_chat_session( - self.hass, user_input - ) as session: - return await self._async_handle_message(user_input, session) + async with ( + chat_session.async_get_chat_session( + self.hass, user_input.conversation_id + ) as session, + conversation.async_get_chat_log(self.hass, session, user_input) as chat_log, + ): + return await self._async_handle_message(user_input, chat_log) async def _async_handle_message( self, user_input: conversation.ConversationInput, - session: conversation.ChatSession[genai_types.ContentDict], + session: conversation.ChatLog[genai_types.ContentDict], ) -> conversation.ConversationResult: """Call the API.""" diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 2f35bea97e2..330a6fe9a34 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -23,7 +23,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr, intent, llm +from homeassistant.helpers import chat_session, device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddEntitiesCallback from . import OpenAIConfigEntry @@ -155,15 +155,18 @@ class OpenAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - async with conversation.async_get_chat_session( - self.hass, user_input - ) as session: - return await self._async_handle_message(user_input, session) + async with ( + chat_session.async_get_chat_session( + self.hass, user_input.conversation_id + ) as session, + conversation.async_get_chat_log(self.hass, session, user_input) as chat_log, + ): + return await self._async_handle_message(user_input, chat_log) async def _async_handle_message( self, user_input: conversation.ConversationInput, - session: conversation.ChatSession[ChatCompletionMessageParam], + session: conversation.ChatLog[ChatCompletionMessageParam], ) -> conversation.ConversationResult: """Call the API.""" assert user_input.agent_id diff --git a/homeassistant/helpers/chat_session.py b/homeassistant/helpers/chat_session.py new file mode 100644 index 00000000000..4cfa91bc555 --- /dev/null +++ b/homeassistant/helpers/chat_session.py @@ -0,0 +1,160 @@ +"""Helper to organize chat sessions between integrations.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field +from datetime import datetime, timedelta + +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + HassJob, + HassJobType, + HomeAssistant, + callback, +) +from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.util.hass_dict import HassKey + +from .event import async_call_later + +DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session") +DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup") + +CONVERSATION_TIMEOUT = timedelta(minutes=5) + +current_session: ContextVar[ChatSession | None] = ContextVar( + "current_session", default=None +) + + +@dataclass +class ChatSession: + """Represent a chat session.""" + + conversation_id: str + last_updated: datetime = field(default_factory=dt_util.utcnow) + _cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list) + + @callback + def async_updated(self) -> None: + """Update the last updated time.""" + self.last_updated = dt_util.utcnow() + + @callback + def async_on_cleanup(self, cb: CALLBACK_TYPE) -> None: + """Register a callback to clean up the session.""" + self._cleanup_callbacks.append(cb) + + @callback + def async_cleanup(self) -> None: + """Call all clean up callbacks.""" + for cb in self._cleanup_callbacks: + cb() + self._cleanup_callbacks.clear() + + +class SessionCleanup: + """Helper to clean up the stale sessions.""" + + unsub: CALLBACK_TYPE | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the session cleanup.""" + self.hass = hass + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop) + self.cleanup_job = HassJob( + self._cleanup, "chat_session_cleanup", job_type=HassJobType.Callback + ) + + @callback + def schedule(self) -> None: + """Schedule the cleanup.""" + if self.unsub: + return + self.unsub = async_call_later( + self.hass, + CONVERSATION_TIMEOUT.total_seconds() + 1, + self.cleanup_job, + ) + + @callback + def _on_hass_stop(self, event: Event) -> None: + """Cancel the cleanup on shutdown.""" + if self.unsub: + self.unsub() + self.unsub = None + + @callback + def _cleanup(self, now: datetime) -> None: + """Clean up the history and schedule follow-up if necessary.""" + self.unsub = None + all_sessions = self.hass.data[DATA_CHAT_SESSION] + + # We mutate original object because current commands could be + # yielding session based on it. + for conversation_id, session in list(all_sessions.items()): + if session.last_updated + CONVERSATION_TIMEOUT < now: + del all_sessions[conversation_id] + session.async_cleanup() + + # Still conversations left, check again in timeout time. + if all_sessions: + self.schedule() + + +@asynccontextmanager +async def async_get_chat_session( + hass: HomeAssistant, + conversation_id: str | None = None, +) -> AsyncGenerator[ChatSession]: + """Return a chat session.""" + if session := current_session.get(): + # If a session is already active and it's the requested conversation ID, + # return that. We won't update the last updated time in this case. + if session.conversation_id == conversation_id: + yield session + return + + # If it's not the same conversation ID, we will create a new session + # because it might be a conversation agent calling a tool that is talking + # to another LLM. + session = None + + all_sessions = hass.data.get(DATA_CHAT_SESSION) + if all_sessions is None: + all_sessions = {} + hass.data[DATA_CHAT_SESSION] = all_sessions + hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass) + + if conversation_id is None: + conversation_id = ulid_util.ulid_now() + + elif conversation_id in all_sessions: + session = all_sessions[conversation_id] + + else: + # Conversation IDs are ULIDs. We generate a new one if not provided. + # If an old ULID is passed in, we will generate a new one to indicate + # a new conversation was started. If the user picks their own, they + # want to track a conversation and we respect it. + try: + ulid_util.ulid_to_bytes(conversation_id) + conversation_id = ulid_util.ulid_now() + except ValueError: + pass + + if session is None: + session = ChatSession(conversation_id) + + current_session.set(session) + yield session + current_session.set(None) + + session.last_updated = dt_util.utcnow() + all_sessions[conversation_id] = session + hass.data[DATA_CHAT_SESSION_CLEANUP].schedule() diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_session.py index 60c7f2957b8..4a9c662fffa 100644 --- a/tests/components/conversation/test_session.py +++ b/tests/components/conversation/test_session.py @@ -8,10 +8,17 @@ import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol -from homeassistant.components.conversation import ConversationInput, session +from homeassistant.components.conversation import ( + Content, + ConversationInput, + ConverseError, + NativeContent, + async_get_chat_log, +) +from homeassistant.components.conversation.session import DATA_CHAT_HISTORY from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import llm +from homeassistant.helpers import chat_session, llm from homeassistant.util import dt as dt_util from tests.common import async_fire_time_changed @@ -38,127 +45,69 @@ def mock_ulid() -> Generator[Mock]: yield mock_ulid_now -@pytest.mark.parametrize( - ("start_id", "given_id"), - [ - (None, "mock-ulid"), - # This ULID is not known as a session - ("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"), - ("not-a-ulid", "not-a-ulid"), - ], -) -async def test_conversation_id( - hass: HomeAssistant, - mock_conversation_input: ConversationInput, - mock_ulid: Mock, - start_id: str | None, - given_id: str, -) -> None: - """Test conversation ID generation.""" - mock_conversation_input.conversation_id = start_id - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - assert chat_session.conversation_id == given_id - - async def test_cleanup( hass: HomeAssistant, mock_conversation_input: ConversationInput, ) -> None: - """Mock cleanup of the conversation session.""" - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - assert len(chat_session.messages) == 2 - conversation_id = chat_session.conversation_id - - # Generate session entry. - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - # Because we didn't add a message to the session in the last block, - # the conversation was not be persisted and we get a new ID - assert chat_session.conversation_id != conversation_id - conversation_id = chat_session.conversation_id - chat_session.async_add_message( - session.Content( + """Test cleanup of the chat log.""" + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + conversation_id = session.conversation_id + # Add message so it persists + chat_log.async_add_message( + Content( role="assistant", - agent_id="mock-agent-id", - content="Hey!", + agent_id=mock_conversation_input.agent_id, + content="", ) ) - assert len(chat_session.messages) == 3 - # Reuse conversation ID to ensure we can chat with same session - mock_conversation_input.conversation_id = conversation_id - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - assert len(chat_session.messages) == 4 - assert chat_session.conversation_id == conversation_id + assert conversation_id in hass.data[DATA_CHAT_HISTORY] # Set the last updated to be older than the timeout - hass.data[session.DATA_CHAT_HISTORY][conversation_id].last_updated = ( - dt_util.utcnow() + session.CONVERSATION_TIMEOUT + hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = ( + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT ) async_fire_time_changed( - hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1) + hass, + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1), ) - # Should not be cleaned up, but it should have scheduled another cleanup - mock_conversation_input.conversation_id = conversation_id - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - assert len(chat_session.messages) == 4 - assert chat_session.conversation_id == conversation_id - - async_fire_time_changed( - hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1) - ) - - # It should be cleaned up now and we start a new conversation - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - assert chat_session.conversation_id != conversation_id - assert len(chat_session.messages) == 2 + assert conversation_id not in hass.data[DATA_CHAT_HISTORY] async def test_add_message( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: """Test filtering of messages.""" - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - assert len(chat_session.messages) == 2 + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + assert len(chat_log.messages) == 2 with pytest.raises(ValueError): - chat_session.async_add_message( - session.Content(role="system", agent_id=None, content="") + chat_log.async_add_message( + Content(role="system", agent_id=None, content="") ) # No 2 user messages in a row - assert chat_session.messages[1].role == "user" + assert chat_log.messages[1].role == "user" with pytest.raises(ValueError): - chat_session.async_add_message( - session.Content(role="user", agent_id=None, content="") - ) + chat_log.async_add_message(Content(role="user", agent_id=None, content="")) # No 2 assistant messages in a row - chat_session.async_add_message( - session.Content(role="assistant", agent_id=None, content="") - ) - assert len(chat_session.messages) == 3 - assert chat_session.messages[-1].role == "assistant" + chat_log.async_add_message(Content(role="assistant", agent_id=None, content="")) + assert len(chat_log.messages) == 3 + assert chat_log.messages[-1].role == "assistant" with pytest.raises(ValueError): - chat_session.async_add_message( - session.Content(role="assistant", agent_id=None, content="") + chat_log.async_add_message( + Content(role="assistant", agent_id=None, content="") ) @@ -166,66 +115,65 @@ async def test_message_filtering( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: """Test filtering of messages.""" - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - messages = chat_session.async_get_messages(agent_id=None) + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + messages = chat_log.async_get_messages(agent_id=None) assert len(messages) == 2 - assert messages[0] == session.Content( + assert messages[0] == Content( role="system", agent_id=None, content="", ) - assert messages[1] == session.Content( + assert messages[1] == Content( role="user", agent_id="mock-agent-id", content=mock_conversation_input.text, ) # Cannot add a second user message in a row with pytest.raises(ValueError): - chat_session.async_add_message( - session.Content( + chat_log.async_add_message( + Content( role="user", agent_id="mock-agent-id", content="Hey!", ) ) - chat_session.async_add_message( - session.Content( + chat_log.async_add_message( + Content( role="assistant", agent_id="mock-agent-id", content="Hey!", ) ) # Different agent, native messages will be filtered out. - chat_session.async_add_message( - session.NativeContent(agent_id="another-mock-agent-id", content=1) - ) - chat_session.async_add_message( - session.NativeContent(agent_id="mock-agent-id", content=1) + chat_log.async_add_message( + NativeContent(agent_id="another-mock-agent-id", content=1) ) + chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1)) # A non-native message from another agent is not filtered out. - chat_session.async_add_message( - session.Content( + chat_log.async_add_message( + Content( role="assistant", agent_id="another-mock-agent-id", content="Hi!", ) ) - assert len(chat_session.messages) == 6 + assert len(chat_log.messages) == 6 - messages = chat_session.async_get_messages(agent_id="mock-agent-id") + messages = chat_log.async_get_messages(agent_id="mock-agent-id") assert len(messages) == 5 - assert messages[2] == session.Content( + assert messages[2] == Content( role="assistant", agent_id="mock-agent-id", content="Hey!", ) - assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1) - assert messages[4] == session.Content( + assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1) + assert messages[4] == Content( role="assistant", agent_id="another-mock-agent-id", content="Hi!" ) @@ -235,18 +183,19 @@ async def test_llm_api( mock_conversation_input: ConversationInput, ) -> None: """Test when we reference an LLM API.""" - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api="assist", user_llm_prompt=None, ) - assert isinstance(chat_session.llm_api, llm.APIInstance) - assert chat_session.llm_api.api.id == "assist" + assert isinstance(chat_log.llm_api, llm.APIInstance) + assert chat_log.llm_api.api.id == "assist" async def test_unknown_llm_api( @@ -255,11 +204,12 @@ async def test_unknown_llm_api( snapshot: SnapshotAssertion, ) -> None: """Test when we reference an LLM API that does not exists.""" - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - with pytest.raises(session.ConverseError) as exc_info: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + with pytest.raises(ConverseError) as exc_info: + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api="unknown-api", @@ -276,11 +226,12 @@ async def test_template_error( snapshot: SnapshotAssertion, ) -> None: """Test that template error handling works.""" - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - with pytest.raises(session.ConverseError) as exc_info: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + with pytest.raises(ConverseError) as exc_info: + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api=None, @@ -300,13 +251,14 @@ async def test_template_variables( mock_user.name = "Test User" mock_conversation_input.context = Context(user_id=mock_user.id) - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): with patch( "homeassistant.auth.AuthManager.async_get_user", return_value=mock_user ): - await chat_session.async_update_llm_data( + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api=None, @@ -318,12 +270,12 @@ async def test_template_variables( ), ) - assert chat_session.user_name == "Test User" + assert chat_log.user_name == "Test User" - assert "The instance name is test home." in chat_session.messages[0].content - assert "The user name is Test User." in chat_session.messages[0].content - assert "The user id is 12345." in chat_session.messages[0].content - assert "The calling platform is test." in chat_session.messages[0].content + assert "The instance name is test home." in chat_log.messages[0].content + assert "The user name is Test User." in chat_log.messages[0].content + assert "The user id is 12345." in chat_log.messages[0].content + assert "The calling platform is test." in chat_log.messages[0].content async def test_extra_systen_prompt( @@ -336,82 +288,86 @@ async def test_extra_systen_prompt( ) mock_conversation_input.extra_system_prompt = extra_system_prompt - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api=None, user_llm_prompt=None, ) - chat_session.async_add_message( - session.Content( + chat_log.async_add_message( + Content( role="assistant", agent_id="mock-agent-id", content="Hey!", ) ) - assert chat_session.extra_system_prompt == extra_system_prompt - assert chat_session.messages[0].content.endswith(extra_system_prompt) + assert chat_log.extra_system_prompt == extra_system_prompt + assert chat_log.messages[0].content.endswith(extra_system_prompt) # Verify that follow-up conversations with no system prompt take previous one - mock_conversation_input.conversation_id = chat_session.conversation_id + conversation_id = chat_log.conversation_id mock_conversation_input.extra_system_prompt = None - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass, conversation_id) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api=None, user_llm_prompt=None, ) - assert chat_session.extra_system_prompt == extra_system_prompt - assert chat_session.messages[0].content.endswith(extra_system_prompt) + assert chat_log.extra_system_prompt == extra_system_prompt + assert chat_log.messages[0].content.endswith(extra_system_prompt) # Verify that we take new system prompts mock_conversation_input.extra_system_prompt = extra_system_prompt2 - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass, conversation_id) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api=None, user_llm_prompt=None, ) - chat_session.async_add_message( - session.Content( + chat_log.async_add_message( + Content( role="assistant", agent_id="mock-agent-id", content="Hey!", ) ) - assert chat_session.extra_system_prompt == extra_system_prompt2 - assert chat_session.messages[0].content.endswith(extra_system_prompt2) - assert extra_system_prompt not in chat_session.messages[0].content + assert chat_log.extra_system_prompt == extra_system_prompt2 + assert chat_log.messages[0].content.endswith(extra_system_prompt2) + assert extra_system_prompt not in chat_log.messages[0].content # Verify that follow-up conversations with no system prompt take previous one mock_conversation_input.extra_system_prompt = None - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass, conversation_id) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api=None, user_llm_prompt=None, ) - assert chat_session.extra_system_prompt == extra_system_prompt2 - assert chat_session.messages[0].content.endswith(extra_system_prompt2) + assert chat_log.extra_system_prompt == extra_system_prompt2 + assert chat_log.messages[0].content.endswith(extra_system_prompt2) async def test_tool_call( @@ -434,16 +390,17 @@ async def test_tool_call( ) as mock_get_tools: mock_get_tools.return_value = [mock_tool] - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api="assist", user_llm_prompt=None, ) - result = await chat_session.async_call_tool( + result = await chat_log.async_call_tool( llm.ToolInput( tool_name="test_tool", tool_args={"param1": "Test Param"}, @@ -473,16 +430,17 @@ async def test_tool_call_exception( ) as mock_get_tools: mock_get_tools.return_value = [mock_tool] - async with session.async_get_chat_session( - hass, mock_conversation_input - ) as chat_session: - await chat_session.async_update_llm_data( + async with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + await chat_log.async_update_llm_data( conversing_domain="test", user_input=mock_conversation_input, user_llm_hass_api="assist", user_llm_prompt=None, ) - result = await chat_session.async_call_tool( + result = await chat_log.async_call_tool( llm.ToolInput( tool_name="test_tool", tool_args={"param1": "Test Param"}, diff --git a/tests/helpers/test_chat_session.py b/tests/helpers/test_chat_session.py new file mode 100644 index 00000000000..a11f4126886 --- /dev/null +++ b/tests/helpers/test_chat_session.py @@ -0,0 +1,98 @@ +"""Test the chat session helper.""" + +from collections.abc import Generator +from datetime import timedelta +from unittest.mock import Mock, patch + +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.helpers import chat_session +from homeassistant.util import dt as dt_util, ulid as ulid_util + +from tests.common import async_fire_time_changed + + +@pytest.fixture +def mock_ulid() -> Generator[Mock]: + """Mock the ulid library.""" + with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: + mock_ulid_now.return_value = "mock-ulid" + yield mock_ulid_now + + +@pytest.mark.parametrize( + ("start_id", "given_id"), + [ + (None, "mock-ulid"), + # This ULID is not known as a session + ("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"), + ("not-a-ulid", "not-a-ulid"), + ], +) +async def test_conversation_id( + hass: HomeAssistant, + start_id: str | None, + given_id: str, + mock_ulid: Mock, +) -> None: + """Test conversation ID generation.""" + async with chat_session.async_get_chat_session(hass, start_id) as session: + assert session.conversation_id == given_id + + +async def test_context_var(hass: HomeAssistant) -> None: + """Test context var.""" + async with chat_session.async_get_chat_session(hass) as session: + async with chat_session.async_get_chat_session( + hass, session.conversation_id + ) as session2: + assert session is session2 + + async with chat_session.async_get_chat_session(hass, None) as session2: + assert session.conversation_id != session2.conversation_id + + async with chat_session.async_get_chat_session( + hass, "something else" + ) as session2: + assert session.conversation_id != session2.conversation_id + + async with chat_session.async_get_chat_session( + hass, ulid_util.ulid_now() + ) as session2: + assert session.conversation_id != session2.conversation_id + + +async def test_cleanup( + hass: HomeAssistant, +) -> None: + """Test cleanup of the chat session.""" + async with chat_session.async_get_chat_session(hass) as session: + conversation_id = session.conversation_id + + # Reuse conversation ID to ensure we can chat with same session + async with chat_session.async_get_chat_session(hass, conversation_id) as session: + assert session.conversation_id == conversation_id + + # Set the last updated to be older than the timeout + hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = ( + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + ) + + async_fire_time_changed( + hass, + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1), + ) + + # Should not be cleaned up, but it should have scheduled another cleanup + async with chat_session.async_get_chat_session(hass, conversation_id) as session: + assert session.conversation_id == conversation_id + + async_fire_time_changed( + hass, + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1), + ) + + # It should be cleaned up now and we start a new conversation + async with chat_session.async_get_chat_session(hass, conversation_id) as session: + assert session.conversation_id != conversation_id