diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py index 97dc5e1292e..ce3a0cf028d 100644 --- a/homeassistant/components/conversation/agent_manager.py +++ b/homeassistant/components/conversation/agent_manager.py @@ -9,7 +9,8 @@ from typing import Any import voluptuous as vol from homeassistant.core import Context, HomeAssistant, async_get_hass, callback -from homeassistant.helpers import config_validation as cv, singleton +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_validation as cv, intent, singleton from .const import ( DATA_COMPONENT, @@ -109,7 +110,19 @@ async def async_converse( dataclasses.asdict(conversation_input), ) ) - result = await method(conversation_input) + try: + result = await method(conversation_input) + except HomeAssistantError as err: + intent_response = intent.IntentResponse(language=language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + str(err), + ) + result = ConversationResult( + response=intent_response, + conversation_id=conversation_id, + ) + trace.set_result(**result.as_dict()) return result diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/session.py index 48040e8ac9c..2235459954f 100644 --- a/homeassistant/components/conversation/session.py +++ b/homeassistant/components/conversation/session.py @@ -9,6 +9,8 @@ from datetime import datetime, timedelta import logging from typing import Literal +import voluptuous as vol + from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import ( CALLBACK_TYPE, @@ -23,7 +25,9 @@ from homeassistant.helpers import intent, llm, template from homeassistant.helpers.event import async_call_later from homeassistant.util import dt as dt_util, ulid as ulid_util from homeassistant.util.hass_dict import HassKey +from homeassistant.util.json import JsonObjectType +from . import trace from .const import DOMAIN from .models import ConversationInput, ConversationResult @@ -120,7 +124,7 @@ async def async_get_chat_session( if history: history = replace(history, messages=history.messages.copy()) else: - history = ChatSession(hass, conversation_id) + history = ChatSession(hass, conversation_id, user_input.agent_id) message: ChatMessage = ChatMessage( role="user", @@ -190,6 +194,7 @@ class ChatSession[_NativeT]: hass: HomeAssistant conversation_id: str + agent_id: str | None user_name: str | None = None messages: list[ChatMessage[_NativeT]] = field( default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")] @@ -209,7 +214,9 @@ class ChatSession[_NativeT]: self.messages.append(message) @callback - def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]: + def async_get_messages( + self, agent_id: str | None = None + ) -> list[ChatMessage[_NativeT]]: """Get messages for a specific agent ID. This will filter out any native message tied to other agent IDs. @@ -326,3 +333,29 @@ class ChatSession[_NativeT]: agent_id=user_input.agent_id, content=prompt, ) + + LOGGER.debug("Prompt: %s", self.messages) + LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None) + + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, + { + "messages": self.messages, + "tools": self.llm_api.tools if self.llm_api else None, + }, + ) + + async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType: + """Invoke LLM tool for the configured LLM API.""" + if not self.llm_api: + raise ValueError("No LLM API configured") + LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args) + + try: + tool_response = await self.llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + tool_response = {"error": type(e).__name__} + if str(e): + tool_response["error_text"] = str(e) + LOGGER.debug("Tool response: %s", tool_response) + return tool_response diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index c89574bf3bd..1464f4224d7 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -16,11 +16,9 @@ from openai.types.chat import ( ) from openai.types.chat.chat_completion_message_tool_call_param import Function from openai.types.shared_params import FunctionDefinition -import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation -from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant @@ -94,6 +92,19 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar return param +def _chat_message_convert( + message: conversation.ChatMessage[ChatCompletionMessageParam], + agent_id: str | None, +) -> ChatCompletionMessageParam: + """Convert any native chat message for this agent to the native format.""" + if message.native is not None and message.agent_id == agent_id: + return message.native + return cast( + ChatCompletionMessageParam, + {"role": message.role, "content": message.content}, + ) + + class OpenAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -173,27 +184,10 @@ class OpenAIConversationEntity( for tool in session.llm_api.tools ] - messages: list[ChatCompletionMessageParam] = [] - for message in session.async_get_messages(user_input.agent_id): - if message.native is not None and message.agent_id == user_input.agent_id: - messages.append(message.native) - else: - messages.append( - cast( - ChatCompletionMessageParam, - {"role": message.role, "content": message.content}, - ) - ) - - LOGGER.debug("Prompt: %s", messages) - LOGGER.debug("Tools: %s", tools) - trace.async_conversation_trace_append( - trace.ConversationTraceEventType.AGENT_DETAIL, - { - "messages": session.messages, - "tools": session.llm_api.tools if session.llm_api else None, - }, - ) + messages = [ + _chat_message_convert(message, user_input.agent_id) + for message in session.async_get_messages() + ] client = self.entry.runtime_data @@ -211,14 +205,7 @@ class OpenAIConversationEntity( ) except openai.OpenAIError as err: LOGGER.error("Error talking to OpenAI: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - "Sorry, I had a problem talking to OpenAI", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=session.conversation_id - ) + raise HomeAssistantError("Error talking to OpenAI") from err LOGGER.debug("Response %s", result) response = result.choices[0].message @@ -241,18 +228,7 @@ class OpenAIConversationEntity( tool_name=tool_call.function.name, tool_args=json.loads(tool_call.function.arguments), ) - LOGGER.debug( - "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args - ) - - try: - tool_response = await session.llm_api.async_call_tool(tool_input) - except (HomeAssistantError, vol.Invalid) as e: - tool_response = {"error": type(e).__name__} - if str(e): - tool_response["error_text"] = str(e) - - LOGGER.debug("Tool response: %s", tool_response) + tool_response = await session.async_call_tool(tool_input) messages.append( ChatCompletionToolMessageParam( role="tool", diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_session.py index feb6ca2a9e8..bca19b3b06a 100644 --- a/tests/components/conversation/test_session.py +++ b/tests/components/conversation/test_session.py @@ -2,13 +2,15 @@ from collections.abc import Generator from datetime import timedelta -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from syrupy.assertion import SnapshotAssertion +import voluptuous as vol from homeassistant.components.conversation import ConversationInput, session from homeassistant.core import Context, HomeAssistant +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm from homeassistant.util import dt as dt_util @@ -182,7 +184,7 @@ async def test_message_filtering( ) assert messages[1] == session.ChatMessage( role="user", - agent_id=mock_conversation_input.agent_id, + agent_id="mock-agent-id", content=mock_conversation_input.text, ) # Cannot add a second user message in a row @@ -203,7 +205,7 @@ async def test_message_filtering( native="assistant-reply-native", ) ) - # Different agent, will be filtered out. + # Different agent, native messages will be filtered out. chat_session.async_add_message( session.ChatMessage( role="native", agent_id="another-mock-agent-id", content="", native=1 @@ -214,11 +216,20 @@ async def test_message_filtering( role="native", agent_id="mock-agent-id", content="", native=1 ) ) + # A non-native message from another agent is not filtered out. + chat_session.async_add_message( + session.ChatMessage( + role="assistant", + agent_id="another-mock-agent-id", + content="Hi!", + native=1, + ) + ) - assert len(chat_session.messages) == 5 + assert len(chat_session.messages) == 6 messages = chat_session.async_get_messages(agent_id="mock-agent-id") - assert len(messages) == 4 + assert len(messages) == 5 assert messages[2] == session.ChatMessage( role="assistant", @@ -229,6 +240,9 @@ async def test_message_filtering( assert messages[3] == session.ChatMessage( role="native", agent_id="mock-agent-id", content="", native=1 ) + assert messages[4] == session.ChatMessage( + role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1 + ) async def test_llm_api( @@ -413,3 +427,81 @@ async def test_extra_systen_prompt( assert chat_session.extra_system_prompt == extra_system_prompt2 assert chat_session.messages[0].content.endswith(extra_system_prompt2) + + +async def test_tool_call( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test using the session tool calling API.""" + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.return_value = "Test response" + + with patch( + "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", + return_value=[], + ) as mock_get_tools: + mock_get_tools.return_value = [mock_tool] + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="assist", + user_llm_prompt=None, + ) + result = await chat_session.async_call_tool( + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ) + ) + + assert result == "Test response" + + +async def test_tool_call_exception( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test using the session tool calling API.""" + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.side_effect = HomeAssistantError("Test error") + + with patch( + "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", + return_value=[], + ) as mock_get_tools: + mock_get_tools.return_value = [mock_tool] + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="assist", + user_llm_prompt=None, + ) + result = await chat_session.async_call_tool( + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ) + ) + + assert result == {"error": "HomeAssistantError", "error_text": "Test error"} diff --git a/tests/components/conversation/test_trace.py b/tests/components/conversation/test_trace.py index 7c00b9a80b2..a975c9b7983 100644 --- a/tests/components/conversation/test_trace.py +++ b/tests/components/conversation/test_trace.py @@ -61,18 +61,18 @@ async def test_converation_trace( } -async def test_converation_trace_error( +async def test_converation_trace_uncaught_error( hass: HomeAssistant, init_components: None, sl_setup: None, ) -> None: - """Test tracing a conversation.""" + """Test tracing a conversation that raises an uncaught error.""" with ( patch( "homeassistant.components.conversation.default_agent.DefaultAgent.async_process", - side_effect=HomeAssistantError("Failed to talk to agent"), + side_effect=ValueError("Unexpected error"), ), - pytest.raises(HomeAssistantError), + pytest.raises(ValueError), ): await conversation.async_converse( hass, "add apples to my shopping list", None, Context() @@ -87,4 +87,35 @@ async def test_converation_trace_error( assert ( trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS ) - assert last_trace.get("error") == "Failed to talk to agent" + assert last_trace.get("error") == "Unexpected error" + assert not last_trace.get("result") + + +async def test_converation_trace_homeassistant_error( + hass: HomeAssistant, + init_components: None, + sl_setup: None, +) -> None: + """Test tracing a conversation with a HomeAssistant error.""" + with ( + patch( + "homeassistant.components.conversation.default_agent.DefaultAgent.async_process", + side_effect=HomeAssistantError("Failed to talk to agent"), + ), + ): + await conversation.async_converse( + hass, "add apples to my shopping list", None, Context() + ) + + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + assert last_trace.get("events") + assert len(last_trace.get("events")) == 1 + trace_event = last_trace["events"][0] + assert ( + trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS + ) + result = last_trace.get("result") + assert result + assert result["response"]["speech"]["plain"]["speech"] == "Failed to talk to agent"