From 9679fc787851dab28742c9476d7917ceacbdb463 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Feb 2025 00:05:20 -0500 Subject: [PATCH] Chat session rev2 (#137209) * Chat Session rev 2 * Rename session to chat_log * Simplify typing * Typing * Address comments * Fix anthropic and ollama --- .../components/anthropic/conversation.py | 1 + .../components/assist_pipeline/pipeline.py | 14 +- .../components/conversation/__init__.py | 16 +- .../conversation/{session.py => chat_log.py} | 151 +++++++------ .../components/conversation/default_agent.py | 12 +- .../conversation.py | 157 ++++++++----- .../openai_conversation/conversation.py | 121 +++++----- homeassistant/helpers/llm.py | 5 +- .../components/anthropic/test_conversation.py | 2 + .../{test_session.ambr => test_chat_log.ambr} | 0 .../{test_session.py => test_chat_log.py} | 212 +++++++----------- .../test_conversation.py | 16 +- tests/components/ollama/test_conversation.py | 9 + .../openai_conversation/test_conversation.py | 2 + 14 files changed, 388 insertions(+), 330 deletions(-) rename homeassistant/components/conversation/{session.py => chat_log.py} (68%) rename tests/components/conversation/snapshots/{test_session.ambr => test_chat_log.ambr} (100%) rename tests/components/conversation/{test_session.py => test_chat_log.py} (67%) diff --git a/homeassistant/components/anthropic/conversation.py b/homeassistant/components/anthropic/conversation.py index e45e849adf6..259d1295809 100644 --- a/homeassistant/components/anthropic/conversation.py +++ b/homeassistant/components/anthropic/conversation.py @@ -272,6 +272,7 @@ class AnthropicConversationEntity( continue tool_input = llm.ToolInput( + id=tool_call.id, tool_name=tool_call.name, tool_args=cast(dict[str, Any], tool_call.input), ) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index cfc7261410a..c5f9098623a 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1063,11 +1063,11 @@ class PipelineRun: agent_id=self.intent_agent, extra_system_prompt=conversation_extra_system_prompt, ) - processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT - agent_id = user_input.agent_id + agent_id = self.intent_agent + processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT intent_response: intent.IntentResponse | None = None - if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT: + if not processed_locally: # Sentence triggers override conversation agent if ( trigger_response_text @@ -1105,13 +1105,13 @@ class PipelineRun: speech: str = intent_response.speech.get("plain", {}).get( "speech", "" ) - chat_log.async_add_message( - conversation.Content( - role="assistant", + async for _ in chat_log.async_add_assistant_content( + conversation.AssistantContent( agent_id=agent_id, content=speech, ) - ) + ): + pass conversation_result = conversation.ConversationResult( response=intent_response, conversation_id=session.conversation_id, diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 13152beff51..69e738205c5 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -30,6 +30,16 @@ from .agent_manager import ( async_get_agent, get_agent_manager, ) +from .chat_log import ( + AssistantContent, + ChatLog, + Content, + ConverseError, + SystemContent, + ToolResultContent, + UserContent, + 
async_get_chat_log, +) from .const import ( ATTR_AGENT_ID, ATTR_CONVERSATION_ID, @@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent from .entity import ConversationEntity from .http import async_setup as async_setup_conversation_http from .models import AbstractConversationAgent, ConversationInput, ConversationResult -from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log from .trace import ConversationTraceEventType, async_conversation_trace_append __all__ = [ "DOMAIN", "HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT", + "AssistantContent", "ChatLog", "Content", "ConversationEntity", @@ -63,7 +73,9 @@ __all__ = [ "ConversationResult", "ConversationTraceEventType", "ConverseError", - "NativeContent", + "SystemContent", + "ToolResultContent", + "UserContent", "async_conversation_trace_append", "async_converse", "async_get_agent_info", diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/chat_log.py similarity index 68% rename from homeassistant/components/conversation/session.py rename to homeassistant/components/conversation/chat_log.py index c32d61333a0..d053d114a11 100644 --- a/homeassistant/components/conversation/session.py +++ b/homeassistant/components/conversation/chat_log.py @@ -2,19 +2,16 @@ from __future__ import annotations -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator from contextlib import contextmanager from dataclasses import dataclass, field, replace -from datetime import datetime import logging -from typing import Literal import voluptuous as vol from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.helpers import chat_session, intent, llm, template -from homeassistant.util import dt as dt_util from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JsonObjectType @@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__) def async_get_chat_log( hass: HomeAssistant, session: chat_session.ChatSession, - user_input: ConversationInput, + user_input: ConversationInput | None = None, ) -> Generator[ChatLog]: """Return chat log for a specific chat session.""" all_history = hass.data.get(DATA_CHAT_HISTORY) @@ -42,9 +39,9 @@ def async_get_chat_log( history = all_history.get(session.conversation_id) if history: - history = replace(history, messages=history.messages.copy()) + history = replace(history, content=history.content.copy()) else: - history = ChatLog(hass, session.conversation_id, user_input.agent_id) + history = ChatLog(hass, session.conversation_id) @callback def do_cleanup() -> None: @@ -53,22 +50,19 @@ def async_get_chat_log( session.async_on_cleanup(do_cleanup) - message: Content = Content( - role="user", - agent_id=user_input.agent_id, - content=user_input.text, - ) - history.async_add_message(message) + if user_input is not None: + history.async_add_user_content(UserContent(content=user_input.text)) + + last_message = history.content[-1] yield history - if history.messages[-1] is message: + if history.content[-1] is last_message: LOGGER.debug( "History opened but no assistant message was added, ignoring update" ) return - history.last_updated = dt_util.utcnow() all_history[session.conversation_id] = history @@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError): ) -@dataclass -class Content: +@dataclass(frozen=True) +class SystemContent: """Base class for chat messages.""" - role: 
Literal["system", "assistant", "user"] - agent_id: str | None + role: str = field(init=False, default="system") content: str @dataclass(frozen=True) -class NativeContent[_NativeT]: - """Native content.""" +class UserContent: + """Assistant content.""" - role: str = field(init=False, default="native") + role: str = field(init=False, default="user") + content: str + + +@dataclass(frozen=True) +class AssistantContent: + """Assistant content.""" + + role: str = field(init=False, default="assistant") agent_id: str - content: _NativeT + content: str + tool_calls: list[llm.ToolInput] | None = None + + +@dataclass(frozen=True) +class ToolResultContent: + """Tool result content.""" + + role: str = field(init=False, default="tool_result") + agent_id: str + tool_call_id: str + tool_name: str + tool_result: JsonObjectType + + +Content = SystemContent | UserContent | AssistantContent | ToolResultContent @dataclass -class ChatLog[_NativeT]: +class ChatLog: """Class holding the chat history of a specific conversation.""" hass: HomeAssistant conversation_id: str - agent_id: str | None - user_name: str | None = None - messages: list[Content | NativeContent[_NativeT]] = field( - default_factory=lambda: [Content(role="system", agent_id=None, content="")] - ) + content: list[Content] = field(default_factory=lambda: [SystemContent(content="")]) extra_system_prompt: str | None = None llm_api: llm.APIInstance | None = None - last_updated: datetime = field(default_factory=dt_util.utcnow) @callback - def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None: - """Process intent.""" - if message.role == "system": - raise ValueError("Cannot add system messages to history") - if message.role != "native" and self.messages[-1].role == message.role: - raise ValueError("Cannot add two assistant or user messages in a row") + def async_add_user_content(self, content: UserContent) -> None: + """Add user content to the log.""" + self.content.append(content) - self.messages.append(message) + async def async_add_assistant_content( + self, content: AssistantContent + ) -> AsyncGenerator[ToolResultContent]: + """Add assistant content.""" + self.content.append(content) - @callback - def async_get_messages( - self, agent_id: str | None = None - ) -> list[Content | NativeContent[_NativeT]]: - """Get messages for a specific agent ID. + if content.tool_calls is None: + return - This will filter out any native message tied to other agent IDs. - It can still include assistant/user messages generated by other agents. 
- """ - return [ - message - for message in self.messages - if message.role != "native" or message.agent_id == agent_id - ] + if self.llm_api is None: + raise ValueError("No LLM API configured") + + for tool_input in content.tool_calls: + LOGGER.debug( + "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args + ) + + try: + tool_result = await self.llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + tool_result = {"error": type(e).__name__} + if str(e): + tool_result["error_text"] = str(e) + LOGGER.debug("Tool response: %s", tool_result) + + response_content = ToolResultContent( + agent_id=content.agent_id, + tool_call_id=tool_input.id, + tool_name=tool_input.tool_name, + tool_result=tool_result, + ) + self.content.append(response_content) + yield response_content async def async_update_llm_data( self, @@ -250,36 +275,16 @@ class ChatLog[_NativeT]: prompt = "\n".join(prompt_parts) self.llm_api = llm_api - self.user_name = user_name self.extra_system_prompt = extra_system_prompt - self.messages[0] = Content( - role="system", - agent_id=user_input.agent_id, - content=prompt, - ) + self.content[0] = SystemContent(content=prompt) - LOGGER.debug("Prompt: %s", self.messages) + LOGGER.debug("Prompt: %s", self.content) LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None) trace.async_conversation_trace_append( trace.ConversationTraceEventType.AGENT_DETAIL, { - "messages": self.messages, + "messages": self.content, "tools": self.llm_api.tools if self.llm_api else None, }, ) - - async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType: - """Invoke LLM tool for the configured LLM API.""" - if not self.llm_api: - raise ValueError("No LLM API configured") - LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args) - - try: - tool_response = await self.llm_api.async_call_tool(tool_input) - except (HomeAssistantError, vol.Invalid) as e: - tool_response = {"error": type(e).__name__} - if str(e): - tool_response["error_text"] = str(e) - LOGGER.debug("Tool response: %s", tool_response) - return tool_response diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index c4a8f7ea7eb..5e1709c0404 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import async_track_state_added_domain from homeassistant.util.json import JsonObjectType, json_loads_object +from .chat_log import AssistantContent, async_get_chat_log from .const import ( DATA_DEFAULT_ENTITY, DEFAULT_EXPOSED_ATTRIBUTES, @@ -63,7 +64,6 @@ from .const import ( ) from .entity import ConversationEntity from .models import ConversationInput, ConversationResult -from .session import Content, async_get_chat_log from .trace import ConversationTraceEventType, async_conversation_trace_append _LOGGER = logging.getLogger(__name__) @@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity): ) speech: str = response.speech.get("plain", {}).get("speech", "") - chat_log.async_add_message( - Content( - role="assistant", - agent_id=user_input.agent_id, + async for _tool_result in chat_log.async_add_assistant_content( + AssistantContent( + agent_id=user_input.agent_id, # type: ignore[arg-type] content=speech, ) - ) + ): + pass return ConversationResult( response=response, 
conversation_id=session.conversation_id diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 53ee4e1f880..8a6c5563601 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -4,7 +4,7 @@ from __future__ import annotations import codecs from collections.abc import Callable -from typing import Any, Literal +from typing import Any, Literal, cast from google.api_core.exceptions import GoogleAPIError import google.generativeai as genai @@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any: return value -def _chat_message_convert( - message: conversation.Content | conversation.NativeContent[genai_types.ContentDict], -) -> genai_types.ContentDict: - """Convert any native chat message for this agent to the native format.""" - if message.role == "native": - return message.content +def _create_google_tool_response_content( + content: list[conversation.ToolResultContent], +) -> protos.Content: + """Create a Google tool response content.""" + return protos.Content( + parts=[ + protos.Part( + function_response=protos.FunctionResponse( + name=tool_result.tool_name, response=tool_result.tool_result + ) + ) + for tool_result in content + ] + ) - role = "model" if message.role == "assistant" else message.role - return {"role": role, "parts": message.content} + +def _convert_content( + content: conversation.UserContent + | conversation.AssistantContent + | conversation.SystemContent, +) -> genai_types.ContentDict: + """Convert HA content to Google content.""" + if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr] + role = "model" if content.role == "assistant" else content.role + return {"role": role, "parts": content.content} + + # Handle the Assistant content with tool calls. 
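+    # (The assert just below narrows the Content union for mypy so the
+    # content/tool_calls attribute access that follows type-checks.)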
+ assert type(content) is conversation.AssistantContent + parts = [] + + if content.content: + parts.append(protos.Part(text=content.content)) + + if content.tool_calls: + parts.extend( + [ + protos.Part( + function_call=protos.FunctionCall( + name=tool_call.tool_name, + args=_escape_decode(tool_call.tool_args), + ) + ) + for tool_call in content.tool_calls + ] + ) + + return protos.Content({"role": "model", "parts": parts}) class GoogleGenerativeAIConversationEntity( @@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity( async def _async_handle_message( self, user_input: conversation.ConversationInput, - session: conversation.ChatLog[genai_types.ContentDict], + chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: """Call the API.""" @@ -228,7 +266,7 @@ class GoogleGenerativeAIConversationEntity( options = self.entry.options try: - await session.async_update_llm_data( + await chat_log.async_update_llm_data( DOMAIN, user_input, options.get(CONF_LLM_HASS_API), @@ -238,10 +276,10 @@ class GoogleGenerativeAIConversationEntity( return err.as_conversation_result() tools: list[dict[str, Any]] | None = None - if session.llm_api: + if chat_log.llm_api: tools = [ - _format_tool(tool, session.llm_api.custom_serializer) - for tool in session.llm_api.tools + _format_tool(tool, chat_log.llm_api.custom_serializer) + for tool in chat_log.llm_api.tools ] model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) @@ -252,9 +290,36 @@ class GoogleGenerativeAIConversationEntity( "gemini-1.0" not in model_name and "gemini-pro" not in model_name ) - prompt, *messages = [ - _chat_message_convert(message) for message in session.async_get_messages() - ] + prompt = chat_log.content[0].content # type: ignore[union-attr] + messages: list[genai_types.ContentDict] = [] + + # Google groups tool results, we do not. Group them before sending. 
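+        # Consecutive tool responses are buffered in tool_results below and
+        # flushed as a single protos.Content before the next non-tool message.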
+ tool_results: list[conversation.ToolResultContent] = [] + + for chat_content in chat_log.content[1:]: + if chat_content.role == "tool_result": + # mypy doesn't like picking a type based on checking shared property 'role' + tool_results.append(cast(conversation.ToolResultContent, chat_content)) + continue + + if tool_results: + messages.append(_create_google_tool_response_content(tool_results)) + tool_results.clear() + + messages.append( + _convert_content( + cast( + conversation.UserContent + | conversation.SystemContent + | conversation.AssistantContent, + chat_content, + ) + ) + ) + + if tool_results: + messages.append(_create_google_tool_response_content(tool_results)) + model = genai.GenerativeModel( model_name=model_name, generation_config={ @@ -282,12 +347,12 @@ class GoogleGenerativeAIConversationEntity( ), }, tools=tools or None, - system_instruction=prompt["parts"] if supports_system_instruction else None, + system_instruction=prompt if supports_system_instruction else None, ) if not supports_system_instruction: messages = [ - {"role": "user", "parts": prompt["parts"]}, + {"role": "user", "parts": prompt}, {"role": "model", "parts": "Ok"}, *messages, ] @@ -325,50 +390,40 @@ class GoogleGenerativeAIConversationEntity( content = " ".join( [part.text.strip() for part in chat_response.parts if part.text] ) - if content: - session.async_add_message( - conversation.Content( - role="assistant", - agent_id=user_input.agent_id, - content=content, - ) - ) - function_calls = [ - part.function_call for part in chat_response.parts if part.function_call - ] - - if not function_calls or not session.llm_api: - break - - tool_responses = [] - for function_call in function_calls: - tool_call = MessageToDict(function_call._pb) # noqa: SLF001 + tool_calls = [] + for part in chat_response.parts: + if not part.function_call: + continue + tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001 tool_name = tool_call["name"] tool_args = _escape_decode(tool_call["args"]) - tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args) - function_response = await session.async_call_tool(tool_input) - tool_responses.append( - protos.Part( - function_response=protos.FunctionResponse( - name=tool_name, response=function_response + tool_calls.append( + llm.ToolInput(tool_name=tool_name, tool_args=tool_args) + ) + + chat_request = _create_google_tool_response_content( + [ + tool_response + async for tool_response in chat_log.async_add_assistant_content( + conversation.AssistantContent( + agent_id=user_input.agent_id, + content=content, + tool_calls=tool_calls or None, ) ) - ) - chat_request = protos.Content(parts=tool_responses) - session.async_add_message( - conversation.NativeContent( - agent_id=user_input.agent_id, - content=chat_request, - ) + ] ) + if not tool_calls: + break + response = intent.IntentResponse(language=user_input.language) response.async_set_speech( " ".join([part.text.strip() for part in chat_response.parts if part.text]) ) return conversation.ConversationResult( - response=response, conversation_id=session.conversation_id + response=response, conversation_id=chat_log.conversation_id ) async def _async_entry_update_listener( diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index aced98eaa97..73dafa1c48d 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -70,7 +70,9 @@ def _format_tool( return 
ChatCompletionToolParam(type="function", function=tool_spec) -def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam: +def _convert_message_to_param( + message: ChatCompletionMessage, +) -> ChatCompletionMessageParam: """Convert from class to TypedDict.""" tool_calls: list[ChatCompletionMessageToolCallParam] = [] if message.tool_calls: @@ -94,20 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar return param -def _chat_message_convert( - message: conversation.Content - | conversation.NativeContent[ChatCompletionMessageParam], +def _convert_content_to_param( + content: conversation.Content, ) -> ChatCompletionMessageParam: """Convert any native chat message for this agent to the native format.""" - role = message.role - if role == "native": - # mypy doesn't understand that checking role ensures content type - return message.content # type: ignore[return-value] - if role == "system": - role = "developer" - return cast( - ChatCompletionMessageParam, - {"role": role, "content": message.content}, + if content.role == "tool_result": + assert type(content) is conversation.ToolResultContent + return ChatCompletionToolMessageParam( + role="tool", + tool_call_id=content.tool_call_id, + content=json.dumps(content.tool_result), + ) + if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr] + role = content.role + if role == "system": + role = "developer" + return cast( + ChatCompletionMessageParam, + {"role": content.role, "content": content.content}, # type: ignore[union-attr] + ) + + # Handle the Assistant content including tool calls. + assert type(content) is conversation.AssistantContent + return ChatCompletionAssistantMessageParam( + role="assistant", + content=content.content, + tool_calls=[ + ChatCompletionMessageToolCallParam( + id=tool_call.id, + function=Function( + arguments=json.dumps(tool_call.tool_args), + name=tool_call.tool_name, + ), + type="function", + ) + for tool_call in content.tool_calls + ], ) @@ -171,14 +195,14 @@ class OpenAIConversationEntity( async def _async_handle_message( self, user_input: conversation.ConversationInput, - session: conversation.ChatLog[ChatCompletionMessageParam], + chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: """Call the API.""" assert user_input.agent_id options = self.entry.options try: - await session.async_update_llm_data( + await chat_log.async_update_llm_data( DOMAIN, user_input, options.get(CONF_LLM_HASS_API), @@ -188,17 +212,14 @@ class OpenAIConversationEntity( return err.as_conversation_result() tools: list[ChatCompletionToolParam] | None = None - if session.llm_api: + if chat_log.llm_api: tools = [ - _format_tool(tool, session.llm_api.custom_serializer) - for tool in session.llm_api.tools + _format_tool(tool, chat_log.llm_api.custom_serializer) + for tool in chat_log.llm_api.tools ] model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) - - messages = [ - _chat_message_convert(message) for message in session.async_get_messages() - ] + messages = [_convert_content_to_param(content) for content in chat_log.content] client = self.entry.runtime_data @@ -213,7 +234,7 @@ class OpenAIConversationEntity( ), "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), - "user": session.conversation_id, + "user": chat_log.conversation_id, } if model.startswith("o"): @@ -229,43 +250,39 @@ class OpenAIConversationEntity( LOGGER.debug("Response %s", result) response = 
result.choices[0].message - messages.append(_message_convert(response)) + messages.append(_convert_message_to_param(response)) - session.async_add_message( - conversation.Content( - role=response.role, - agent_id=user_input.agent_id, - content=response.content or "", - ), + tool_calls: list[llm.ToolInput] | None = None + if response.tool_calls: + tool_calls = [ + llm.ToolInput( + id=tool_call.id, + tool_name=tool_call.function.name, + tool_args=json.loads(tool_call.function.arguments), + ) + for tool_call in response.tool_calls + ] + + messages.extend( + [ + _convert_content_to_param(tool_response) + async for tool_response in chat_log.async_add_assistant_content( + conversation.AssistantContent( + agent_id=user_input.agent_id, + content=response.content or "", + tool_calls=tool_calls, + ) + ) + ] ) - if not response.tool_calls or not session.llm_api: + if not tool_calls: break - for tool_call in response.tool_calls: - tool_input = llm.ToolInput( - tool_name=tool_call.function.name, - tool_args=json.loads(tool_call.function.arguments), - ) - tool_response = await session.async_call_tool(tool_input) - messages.append( - ChatCompletionToolMessageParam( - role="tool", - tool_call_id=tool_call.id, - content=json.dumps(tool_response), - ) - ) - session.async_add_message( - conversation.NativeContent( - agent_id=user_input.agent_id, - content=messages[-1], - ) - ) - intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response.content or "") return conversation.ConversationResult( - response=intent_response, conversation_id=session.conversation_id + response=intent_response, conversation_id=chat_log.conversation_id ) async def _async_entry_update_listener( diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 2bca4c8528b..b7c4951d8de 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field as dc_field from datetime import timedelta from decimal import Decimal from enum import Enum @@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.util import dt as dt_util, yaml as yaml_util from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JsonObjectType +from homeassistant.util.ulid import ulid_now from . 
import ( area_registry as ar, @@ -139,6 +140,8 @@ class ToolInput: tool_name: str tool_args: dict[str, Any] + # Using lambda for default to allow patching in tests + id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda class Tool: diff --git a/tests/components/anthropic/test_conversation.py b/tests/components/anthropic/test_conversation.py index fa5bcb8137a..bb77e2ff926 100644 --- a/tests/components/anthropic/test_conversation.py +++ b/tests/components/anthropic/test_conversation.py @@ -236,6 +236,7 @@ async def test_function_call( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="toolu_0123456789AbCdEfGhIjKlM", tool_name="test_tool", tool_args={"param1": "test_value"}, ), @@ -373,6 +374,7 @@ async def test_function_exception( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="toolu_0123456789AbCdEfGhIjKlM", tool_name="test_tool", tool_args={"param1": "test_value"}, ), diff --git a/tests/components/conversation/snapshots/test_session.ambr b/tests/components/conversation/snapshots/test_chat_log.ambr similarity index 100% rename from tests/components/conversation/snapshots/test_session.ambr rename to tests/components/conversation/snapshots/test_chat_log.ambr diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_chat_log.py similarity index 67% rename from tests/components/conversation/test_session.py rename to tests/components/conversation/test_chat_log.py index 3943f41a62b..a37d4408756 100644 --- a/tests/components/conversation/test_session.py +++ b/tests/components/conversation/test_chat_log.py @@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components.conversation import ( - Content, + AssistantContent, ConversationInput, ConverseError, - NativeContent, + ToolResultContent, async_get_chat_log, ) -from homeassistant.components.conversation.session import DATA_CHAT_HISTORY +from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import chat_session, llm @@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput: @pytest.fixture def mock_ulid() -> Generator[Mock]: """Mock the ulid library.""" - with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: + with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: mock_ulid_now.return_value = "mock-ulid" yield mock_ulid_now @@ -56,13 +56,13 @@ async def test_cleanup( ): conversation_id = session.conversation_id # Add message so it persists - chat_log.async_add_message( - Content( - role="assistant", - agent_id=mock_conversation_input.agent_id, - content="", + async for _tool_result in chat_log.async_add_assistant_content( + AssistantContent( + agent_id="mock-agent-id", + content="Hey!", ) - ) + ): + pytest.fail("should not reach here") assert conversation_id in hass.data[DATA_CHAT_HISTORY] @@ -79,7 +79,7 @@ async def test_cleanup( assert conversation_id not in hass.data[DATA_CHAT_HISTORY] -async def test_add_message( +async def test_default_content( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: """Test filtering of messages.""" @@ -87,95 +87,11 @@ async def test_add_message( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): - assert 
len(chat_log.messages) == 2 - - with pytest.raises(ValueError): - chat_log.async_add_message( - Content(role="system", agent_id=None, content="") - ) - - # No 2 user messages in a row - assert chat_log.messages[1].role == "user" - - with pytest.raises(ValueError): - chat_log.async_add_message(Content(role="user", agent_id=None, content="")) - - # No 2 assistant messages in a row - chat_log.async_add_message(Content(role="assistant", agent_id=None, content="")) - assert len(chat_log.messages) == 3 - assert chat_log.messages[-1].role == "assistant" - - with pytest.raises(ValueError): - chat_log.async_add_message( - Content(role="assistant", agent_id=None, content="") - ) - - -async def test_message_filtering( - hass: HomeAssistant, mock_conversation_input: ConversationInput -) -> None: - """Test filtering of messages.""" - with ( - chat_session.async_get_chat_session(hass) as session, - async_get_chat_log(hass, session, mock_conversation_input) as chat_log, - ): - messages = chat_log.async_get_messages(agent_id=None) - assert len(messages) == 2 - assert messages[0] == Content( - role="system", - agent_id=None, - content="", - ) - assert messages[1] == Content( - role="user", - agent_id="mock-agent-id", - content=mock_conversation_input.text, - ) - # Cannot add a second user message in a row - with pytest.raises(ValueError): - chat_log.async_add_message( - Content( - role="user", - agent_id="mock-agent-id", - content="Hey!", - ) - ) - - chat_log.async_add_message( - Content( - role="assistant", - agent_id="mock-agent-id", - content="Hey!", - ) - ) - # Different agent, native messages will be filtered out. - chat_log.async_add_message( - NativeContent(agent_id="another-mock-agent-id", content=1) - ) - chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1)) - # A non-native message from another agent is not filtered out. - chat_log.async_add_message( - Content( - role="assistant", - agent_id="another-mock-agent-id", - content="Hi!", - ) - ) - - assert len(chat_log.messages) == 6 - - messages = chat_log.async_get_messages(agent_id="mock-agent-id") - assert len(messages) == 5 - - assert messages[2] == Content( - role="assistant", - agent_id="mock-agent-id", - content="Hey!", - ) - assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1) - assert messages[4] == Content( - role="assistant", agent_id="another-mock-agent-id", content="Hi!" - ) + assert len(chat_log.content) == 2 + assert chat_log.content[0].role == "system" + assert chat_log.content[0].content == "" + assert chat_log.content[1].role == "user" + assert chat_log.content[1].content == mock_conversation_input.text async def test_llm_api( @@ -268,12 +184,10 @@ async def test_template_variables( ), ) - assert chat_log.user_name == "Test User" - - assert "The instance name is test home." in chat_log.messages[0].content - assert "The user name is Test User." in chat_log.messages[0].content - assert "The user id is 12345." in chat_log.messages[0].content - assert "The calling platform is test." in chat_log.messages[0].content + assert "The instance name is test home." in chat_log.content[0].content + assert "The user name is Test User." in chat_log.content[0].content + assert "The user id is 12345." in chat_log.content[0].content + assert "The calling platform is test." 
in chat_log.content[0].content async def test_extra_systen_prompt( @@ -296,16 +210,16 @@ async def test_extra_systen_prompt( user_llm_hass_api=None, user_llm_prompt=None, ) - chat_log.async_add_message( - Content( - role="assistant", + async for _tool_result in chat_log.async_add_assistant_content( + AssistantContent( agent_id="mock-agent-id", content="Hey!", ) - ) + ): + pytest.fail("should not reach here") assert chat_log.extra_system_prompt == extra_system_prompt - assert chat_log.messages[0].content.endswith(extra_system_prompt) + assert chat_log.content[0].content.endswith(extra_system_prompt) # Verify that follow-up conversations with no system prompt take previous one conversation_id = chat_log.conversation_id @@ -323,7 +237,7 @@ async def test_extra_systen_prompt( ) assert chat_log.extra_system_prompt == extra_system_prompt - assert chat_log.messages[0].content.endswith(extra_system_prompt) + assert chat_log.content[0].content.endswith(extra_system_prompt) # Verify that we take new system prompts mock_conversation_input.extra_system_prompt = extra_system_prompt2 @@ -338,17 +252,17 @@ async def test_extra_systen_prompt( user_llm_hass_api=None, user_llm_prompt=None, ) - chat_log.async_add_message( - Content( - role="assistant", + async for _tool_result in chat_log.async_add_assistant_content( + AssistantContent( agent_id="mock-agent-id", content="Hey!", ) - ) + ): + pytest.fail("should not reach here") assert chat_log.extra_system_prompt == extra_system_prompt2 - assert chat_log.messages[0].content.endswith(extra_system_prompt2) - assert extra_system_prompt not in chat_log.messages[0].content + assert chat_log.content[0].content.endswith(extra_system_prompt2) + assert extra_system_prompt not in chat_log.content[0].content # Verify that follow-up conversations with no system prompt take previous one mock_conversation_input.extra_system_prompt = None @@ -365,7 +279,7 @@ async def test_extra_systen_prompt( ) assert chat_log.extra_system_prompt == extra_system_prompt2 - assert chat_log.messages[0].content.endswith(extra_system_prompt2) + assert chat_log.content[0].content.endswith(extra_system_prompt2) async def test_tool_call( @@ -383,8 +297,7 @@ async def test_tool_call( mock_tool.async_call.return_value = "Test response" with patch( - "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", - return_value=[], + "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[] ) as mock_get_tools: mock_get_tools.return_value = [mock_tool] @@ -398,14 +311,29 @@ async def test_tool_call( user_llm_hass_api="assist", user_llm_prompt=None, ) - result = await chat_log.async_call_tool( - llm.ToolInput( - tool_name="test_tool", - tool_args={"param1": "Test Param"}, + result = None + async for tool_result_content in chat_log.async_add_assistant_content( + AssistantContent( + agent_id=mock_conversation_input.agent_id, + content="", + tool_calls=[ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ) + ], ) - ) + ): + assert result is None + result = tool_result_content - assert result == "Test response" + assert result == ToolResultContent( + agent_id=mock_conversation_input.agent_id, + tool_call_id="mock-tool-call-id", + tool_result="Test response", + tool_name="test_tool", + ) async def test_tool_call_exception( @@ -423,8 +351,7 @@ async def test_tool_call_exception( mock_tool.async_call.side_effect = HomeAssistantError("Test error") with patch( - 
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", - return_value=[], + "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[] ) as mock_get_tools: mock_get_tools.return_value = [mock_tool] @@ -438,11 +365,26 @@ async def test_tool_call_exception( user_llm_hass_api="assist", user_llm_prompt=None, ) - result = await chat_log.async_call_tool( - llm.ToolInput( - tool_name="test_tool", - tool_args={"param1": "Test Param"}, + result = None + async for tool_result_content in chat_log.async_add_assistant_content( + AssistantContent( + agent_id=mock_conversation_input.agent_id, + content="", + tool_calls=[ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ) + ], ) - ) + ): + assert result is None + result = tool_result_content - assert result == {"error": "HomeAssistantError", "error_text": "Test error"} + assert result == ToolResultContent( + agent_id=mock_conversation_input.agent_id, + tool_call_id="mock-tool-call-id", + tool_result={"error": "HomeAssistantError", "error_text": "Test error"}, + tool_name="test_tool", + ) diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index a87056275dc..72a5390f4b1 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -36,6 +36,13 @@ def freeze_the_time(): yield +@pytest.fixture(autouse=True) +def mock_ulid_tools(): + """Mock generated ULIDs for tool calls.""" + with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"): + yield + + @pytest.mark.parametrize( "agent_id", [None, "conversation.google_generative_ai_conversation"] ) @@ -177,6 +184,7 @@ async def test_chat_history( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" ) @pytest.mark.usefixtures("mock_init_component") +@pytest.mark.usefixtures("mock_ulid_tools") async def test_function_call( mock_get_tools, hass: HomeAssistant, @@ -256,6 +264,7 @@ async def test_function_call( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="mock-tool-call", tool_name="test_tool", tool_args={ "param1": ["test_value", "param1's value"], @@ -287,9 +296,7 @@ async def test_function_call( detail_event = trace_events[1] assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] assert [ - p.function_response.name - for p in detail_event["data"]["messages"][2]["content"].parts - if p.function_response + p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"] ] == ["test_tool"] @@ -362,6 +369,7 @@ async def test_function_call_without_parameters( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="mock-tool-call", tool_name="test_tool", tool_args={}, ), @@ -451,6 +459,7 @@ async def test_function_exception( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="mock-tool-call", tool_name="test_tool", tool_args={"param1": 1}, ), @@ -605,6 +614,7 @@ async def test_template_variables( mock_chat.send_message_async.return_value = chat_response mock_part = MagicMock() mock_part.text = "Model response" + mock_part.function_call = None chat_response.parts = [mock_part] result = await conversation.async_converse( hass, "hello", None, context, agent_id=mock_config_entry.entry_id diff --git a/tests/components/ollama/test_conversation.py 
b/tests/components/ollama/test_conversation.py index 202f7385697..b8e299f5e77 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm from tests.common import MockConfigEntry +@pytest.fixture(autouse=True) +def mock_ulid_tools(): + """Mock generated ULIDs for tool calls.""" + with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"): + yield + + @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) async def test_chat( hass: HomeAssistant, @@ -205,6 +212,7 @@ async def test_function_call( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="mock-tool-call", tool_name="test_tool", tool_args=expected_tool_args, ), @@ -285,6 +293,7 @@ async def test_function_exception( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="mock-tool-call", tool_name="test_tool", tool_args={"param1": "test_value"}, ), diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 9ee19cd330c..39ca1b53e28 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -195,6 +195,7 @@ async def test_function_call( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="call_AbCdEfGhIjKlMnOpQrStUvWx", tool_name="test_tool", tool_args={"param1": "test_value"}, ), @@ -359,6 +360,7 @@ async def test_function_exception( mock_tool.async_call.assert_awaited_once_with( hass, llm.ToolInput( + id="call_AbCdEfGhIjKlMnOpQrStUvWx", tool_name="test_tool", tool_args={"param1": "test_value"}, ),
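---

Editor's notes (appended for review; not part of the patch):

The patch replaces the per-agent message filtering of the old session.py
(async_get_messages/NativeContent) with a typed content log plus an async
generator, ChatLog.async_add_assistant_content, that executes any attached
tool calls and yields one ToolResultContent per call. Below is a minimal
sketch of how an agent drives that API; it assumes chat_log.llm_api has been
configured via async_update_llm_data first, and the domain string, tool name,
and tool arguments ("my_integration", "HassTurnOn", {"name": "kitchen"}) are
illustrative only:

    from homeassistant.components import conversation
    from homeassistant.core import HomeAssistant
    from homeassistant.helpers import chat_session, llm


    async def converse(
        hass: HomeAssistant, user_input: conversation.ConversationInput
    ) -> list[conversation.ToolResultContent]:
        assert user_input.agent_id is not None
        with (
            chat_session.async_get_chat_session(
                hass, user_input.conversation_id
            ) as session,
            conversation.async_get_chat_log(hass, session, user_input) as chat_log,
        ):
            # Configure the LLM API first; tool calls below require
            # chat_log.llm_api to be set, otherwise ValueError is raised.
            await chat_log.async_update_llm_data(
                "my_integration",  # hypothetical conversing domain
                user_input,
                user_llm_hass_api="assist",
                user_llm_prompt=None,
            )
            # One assistant turn may carry several tool calls; the async
            # generator executes them and yields one ToolResultContent each.
            return [
                tool_result
                async for tool_result in chat_log.async_add_assistant_content(
                    conversation.AssistantContent(
                        agent_id=user_input.agent_id,
                        content="Turning on the kitchen lights",
                        tool_calls=[
                            llm.ToolInput(
                                tool_name="HassTurnOn",  # illustrative
                                tool_args={"name": "kitchen"},
                            )
                        ],
                    )
                )
            ]

Yielding results as they complete lets an integration stream tool feedback
back to its model one round trip at a time, which is how the OpenAI and
Google entities in this patch consume the generator.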
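llm.ToolInput.id now defaults to a freshly generated ULID, so backends that
do not return call identifiers (e.g. Ollama) still produce a stable ID that a
ToolResultContent.tool_call_id can reference. Tests pin the value by patching
homeassistant.helpers.llm.ulid_now, exactly as the mock_ulid_tools fixtures
in this patch do; a compact sketch:

    from unittest.mock import patch

    from homeassistant.helpers import llm

    with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
        tool_input = llm.ToolInput(tool_name="test_tool", tool_args={})

    # The default_factory ran under the patch, so the ID is deterministic.
    assert tool_input.id == "mock-tool-call"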
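Because each content dataclass freezes its role field, the Content union acts
as a tagged union that integrations dispatch on when serializing the log, in
the style of _convert_content_to_param above. A hedged sketch of the pattern;
the wire format shown is made up, and like the patch itself a real
implementation needs casts or asserts because mypy cannot narrow a union on a
shared attribute:

    from homeassistant.components import conversation


    def to_wire(content: conversation.Content) -> dict:
        """Serialize one log entry for a hypothetical backend."""
        if content.role == "tool_result":
            return {
                "role": "tool",
                "tool_call_id": content.tool_call_id,
                "content": content.tool_result,
            }
        if content.role == "assistant" and content.tool_calls:
            return {
                "role": "assistant",
                "content": content.content,
                "tool_calls": [
                    {"id": t.id, "name": t.tool_name, "args": t.tool_args}
                    for t in content.tool_calls
                ],
            }
        # system/user content, or assistant content without tool calls
        return {"role": content.role, "content": content.content}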