diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 45393289ac8..1661d2ad30d 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -47,6 +47,7 @@ from homeassistant.util.json import JsonObjectType, json_loads_object from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature from .entity import ConversationEntity from .models import ConversationInput, ConversationResult +from .trace import ConversationTraceEventType, async_conversation_trace_append _LOGGER = logging.getLogger(__name__) _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that" @@ -348,6 +349,16 @@ class DefaultAgent(ConversationEntity): } for entity in result.entities_list } + async_conversation_trace_append( + ConversationTraceEventType.TOOL_CALL, + { + "intent_name": result.intent.name, + "slots": { + entity.name: entity.value or entity.text + for entity in result.entities_list + }, + }, + ) try: intent_response = await intent.async_handle( diff --git a/homeassistant/components/conversation/trace.py b/homeassistant/components/conversation/trace.py index 08b271d9058..6f993aa326a 100644 --- a/homeassistant/components/conversation/trace.py +++ b/homeassistant/components/conversation/trace.py @@ -22,8 +22,8 @@ class ConversationTraceEventType(enum.StrEnum): AGENT_DETAIL = "agent_detail" """Event detail added by a conversation agent.""" - LLM_TOOL_CALL = "llm_tool_call" - """An LLM Tool call""" + TOOL_CALL = "tool_call" + """A conversation agent Tool call or default agent intent call.""" @dataclass(frozen=True) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 8dec62ad26b..a5c911bb757 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -286,6 +286,7 @@ class GoogleGenerativeAIConversationEntity( if supports_system_instruction else messages[2:], "prompt": prompt, + "tools": [*llm_api.tools] if llm_api else None, }, ) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index dd42049e3d0..483b37945d6 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -225,7 +225,8 @@ class OpenAIConversationEntity( LOGGER.debug("Prompt: %s", messages) LOGGER.debug("Tools: %s", tools) trace.async_conversation_trace_append( - trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} + trace.ConversationTraceEventType.AGENT_DETAIL, + {"messages": messages, "tools": llm_api.tools if llm_api else None}, ) client = self.entry.runtime_data diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 8ad576b7ea5..4ddb00166b6 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -167,7 +167,7 @@ class APIInstance: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: """Call a LLM tool, validate args and return the response.""" async_conversation_trace_append( - ConversationTraceEventType.LLM_TOOL_CALL, + ConversationTraceEventType.TOOL_CALL, {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args}, ) diff --git a/tests/components/conversation/test_trace.py b/tests/components/conversation/test_trace.py index c586eb8865d..59cd10d2510 100644 --- a/tests/components/conversation/test_trace.py +++ b/tests/components/conversation/test_trace.py @@ -33,7 +33,7 @@ async def test_converation_trace( assert traces last_trace = traces[-1].as_dict() assert last_trace.get("events") - assert len(last_trace.get("events")) == 1 + assert len(last_trace.get("events")) == 2 trace_event = last_trace["events"][0] assert ( trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS @@ -50,6 +50,16 @@ async def test_converation_trace( == "Added apples" ) + trace_event = last_trace["events"][1] + assert trace_event.get("event_type") == trace.ConversationTraceEventType.TOOL_CALL + assert trace_event.get("data") == { + "intent_name": "HassListAddItem", + "slots": { + "name": "Shopping List", + "item": "apples ", + }, + } + async def test_converation_trace_error( hass: HomeAssistant, diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index afeb6d01faa..41f96c7b0ac 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -269,11 +269,12 @@ async def test_function_call( assert [event["event_type"] for event in trace_events] == [ trace.ConversationTraceEventType.ASYNC_PROCESS, trace.ConversationTraceEventType.AGENT_DETAIL, - trace.ConversationTraceEventType.LLM_TOOL_CALL, + trace.ConversationTraceEventType.TOOL_CALL, ] # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] assert "Answer in plain text" in detail_event["data"]["prompt"] + assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"] @patch( diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index fee1543a0d7..3364d822245 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -294,7 +294,7 @@ async def test_function_call( assert [event["event_type"] for event in trace_events] == [ trace.ConversationTraceEventType.ASYNC_PROCESS, trace.ConversationTraceEventType.AGENT_DETAIL, - trace.ConversationTraceEventType.LLM_TOOL_CALL, + trace.ConversationTraceEventType.TOOL_CALL, ] # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] @@ -303,6 +303,7 @@ async def test_function_call( "Today's date is 2024-06-03." in trace_events[1]["data"]["messages"][0]["content"] ) + assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"] # Call it again, make sure we have updated prompt with (