From 80e720f663e8f5484e752c49b41936853a022003 Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Sat, 16 Aug 2025 13:20:20 +0300 Subject: [PATCH] Add external tools support for chat log (#150461) --- .../components/conversation/__init__.py | 2 + .../components/conversation/chat_log.py | 112 +++++++++++++----- homeassistant/helpers/llm.py | 1 + .../snapshots/test_conversation.ambr | 1 + .../snapshots/test_pipeline.ambr | 1 + .../conversation/snapshots/test_chat_log.ambr | 33 ++++++ .../components/conversation/test_chat_log.py | 24 +++- .../snapshots/test_conversation.ambr | 1 + .../snapshots/test_conversation.ambr | 3 + 9 files changed, 146 insertions(+), 32 deletions(-) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 4fd3a57034f..dec26dd3215 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -40,6 +40,7 @@ from .chat_log import ( ConverseError, SystemContent, ToolResultContent, + ToolResultContentDeltaDict, UserContent, async_get_chat_log, ) @@ -79,6 +80,7 @@ __all__ = [ "ConverseError", "SystemContent", "ToolResultContent", + "ToolResultContentDeltaDict", "UserContent", "async_conversation_trace_append", "async_converse", diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 7d842b3c562..2f5e3b0cf82 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -9,7 +9,7 @@ from contextvars import ContextVar from dataclasses import asdict, dataclass, field, replace import logging from pathlib import Path -from typing import Any, Literal, TypedDict +from typing import Any, Literal, TypedDict, cast import voluptuous as vol @@ -190,6 +190,15 @@ class AssistantContentDeltaDict(TypedDict, total=False): native: Any +class ToolResultContentDeltaDict(TypedDict, total=False): + """Tool result content.""" + + role: Literal["tool_result"] + tool_call_id: str + tool_name: str + tool_result: JsonObjectType + + @dataclass class ChatLog: """Class holding the chat history of a specific conversation.""" @@ -235,17 +244,25 @@ class ChatLog: @callback def async_add_assistant_content_without_tools( - self, content: AssistantContent + self, content: AssistantContent | ToolResultContent ) -> None: - """Add assistant content to the log.""" + """Add assistant content to the log. + + Allows assistant content without tool calls or with external tool calls only, + as well as tool results for the external tools. + """ LOGGER.debug("Adding assistant content: %s", content) - if content.tool_calls is not None: - raise ValueError("Tool calls not allowed") + if ( + isinstance(content, AssistantContent) + and content.tool_calls is not None + and any(not tool_call.external for tool_call in content.tool_calls) + ): + raise ValueError("Non-external tool calls not allowed") self.content.append(content) async def async_add_assistant_content( self, - content: AssistantContent, + content: AssistantContent | ToolResultContent, /, tool_call_tasks: dict[str, asyncio.Task] | None = None, ) -> AsyncGenerator[ToolResultContent]: @@ -258,7 +275,11 @@ class ChatLog: LOGGER.debug("Adding assistant content: %s", content) self.content.append(content) - if content.tool_calls is None: + if ( + not isinstance(content, AssistantContent) + or content.tool_calls is None + or all(tool_call.external for tool_call in content.tool_calls) + ): return if self.llm_api is None: @@ -267,13 +288,16 @@ class ChatLog: if tool_call_tasks is None: tool_call_tasks = {} for tool_input in content.tool_calls: - if tool_input.id not in tool_call_tasks: + if tool_input.id not in tool_call_tasks and not tool_input.external: tool_call_tasks[tool_input.id] = self.hass.async_create_task( self.llm_api.async_call_tool(tool_input), name=f"llm_tool_{tool_input.id}", ) for tool_input in content.tool_calls: + if tool_input.external: + continue + LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args ) @@ -296,7 +320,9 @@ class ChatLog: yield response_content async def async_add_delta_content_stream( - self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict] + self, + agent_id: str, + stream: AsyncIterable[AssistantContentDeltaDict | ToolResultContentDeltaDict], ) -> AsyncGenerator[AssistantContent | ToolResultContent]: """Stream content into the chat log. @@ -320,30 +346,34 @@ class ChatLog: # Indicates update to current message if "role" not in delta: - if delta_content := delta.get("content"): + # ToolResultContentDeltaDict will always have a role + assistant_delta = cast(AssistantContentDeltaDict, delta) + if delta_content := assistant_delta.get("content"): current_content += delta_content - if delta_thinking_content := delta.get("thinking_content"): + if delta_thinking_content := assistant_delta.get("thinking_content"): current_thinking_content += delta_thinking_content - if delta_native := delta.get("native"): + if delta_native := assistant_delta.get("native"): if current_native is not None: raise RuntimeError( "Native content already set, cannot overwrite" ) current_native = delta_native - if delta_tool_calls := delta.get("tool_calls"): - if self.llm_api is None: - raise ValueError("No LLM API configured") + if delta_tool_calls := assistant_delta.get("tool_calls"): current_tool_calls += delta_tool_calls # Start processing the tool calls as soon as we know about them for tool_call in delta_tool_calls: - tool_call_tasks[tool_call.id] = self.hass.async_create_task( - self.llm_api.async_call_tool(tool_call), - name=f"llm_tool_{tool_call.id}", - ) + if not tool_call.external: + if self.llm_api is None: + raise ValueError("No LLM API configured") + + tool_call_tasks[tool_call.id] = self.hass.async_create_task( + self.llm_api.async_call_tool(tool_call), + name=f"llm_tool_{tool_call.id}", + ) if self.delta_listener: if filtered_delta := { - k: v for k, v in delta.items() if k != "native" + k: v for k, v in assistant_delta.items() if k != "native" }: # We do not want to send the native content to the listener # as it is not JSON serializable @@ -351,10 +381,6 @@ class ChatLog: continue # Starting a new message - - if delta["role"] != "assistant": - raise ValueError(f"Only assistant role expected. Got {delta['role']}") - # Yield the previous message if it has content if ( current_content @@ -362,7 +388,7 @@ class ChatLog: or current_tool_calls or current_native ): - content = AssistantContent( + content: AssistantContent | ToolResultContent = AssistantContent( agent_id=agent_id, content=current_content or None, thinking_content=current_thinking_content or None, @@ -376,14 +402,38 @@ class ChatLog: yield tool_result if self.delta_listener: self.delta_listener(self, asdict(tool_result)) + current_content = "" + current_thinking_content = "" + current_native = None + current_tool_calls = [] - current_content = delta.get("content") or "" - current_thinking_content = delta.get("thinking_content") or "" - current_tool_calls = delta.get("tool_calls") or [] - current_native = delta.get("native") + if delta["role"] == "assistant": + current_content = delta.get("content") or "" + current_thinking_content = delta.get("thinking_content") or "" + current_tool_calls = delta.get("tool_calls") or [] + current_native = delta.get("native") - if self.delta_listener: - self.delta_listener(self, delta) # type: ignore[arg-type] + if self.delta_listener: + if filtered_delta := { + k: v for k, v in delta.items() if k != "native" + }: + self.delta_listener(self, filtered_delta) + elif delta["role"] == "tool_result": + content = ToolResultContent( + agent_id=agent_id, + tool_call_id=delta["tool_call_id"], + tool_name=delta["tool_name"], + tool_result=delta["tool_result"], + ) + yield content + if self.delta_listener: + self.delta_listener(self, asdict(content)) + self.async_add_assistant_content_without_tools(content) + else: + raise ValueError( + "Only assistant and tool_result roles expected." + f" Got {delta['role']}" + ) if ( current_content diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 1ff6b188214..dc69916a728 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -183,6 +183,7 @@ class ToolInput: tool_args: dict[str, Any] # Using lambda for default to allow patching in tests id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda + external: bool = False class Tool: diff --git a/tests/components/anthropic/snapshots/test_conversation.ambr b/tests/components/anthropic/snapshots/test_conversation.ambr index 95cc02f4576..8f7a3c43f5e 100644 --- a/tests/components/anthropic/snapshots/test_conversation.ambr +++ b/tests/components/anthropic/snapshots/test_conversation.ambr @@ -40,6 +40,7 @@ 'thinking_content': "Okay, let's give it a shot. Will I pass the test?", 'tool_calls': list([ dict({ + 'external': False, 'id': 'toolu_0123456789AbCdEfGhIjKlM', 'tool_args': dict({ 'param1': 'test_value', diff --git a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr index 95415ddb902..b6354b2342b 100644 --- a/tests/components/assist_pipeline/snapshots/test_pipeline.ambr +++ b/tests/components/assist_pipeline/snapshots/test_pipeline.ambr @@ -467,6 +467,7 @@ 'chat_log_delta': dict({ 'tool_calls': list([ dict({ + 'external': False, 'id': 'test_tool_id', 'tool_args': dict({ }), diff --git a/tests/components/conversation/snapshots/test_chat_log.ambr b/tests/components/conversation/snapshots/test_chat_log.ambr index a1c53986053..787009ba614 100644 --- a/tests/components/conversation/snapshots/test_chat_log.ambr +++ b/tests/components/conversation/snapshots/test_chat_log.ambr @@ -16,6 +16,34 @@ }), ]) # --- +# name: test_add_delta_content_stream[deltas11] + list([ + dict({ + 'agent_id': 'mock-agent-id', + 'content': 'Test', + 'native': None, + 'role': 'assistant', + 'thinking_content': None, + 'tool_calls': list([ + dict({ + 'external': True, + 'id': 'mock-tool-call-id', + 'tool_args': dict({ + 'param1': 'Test Param 1', + }), + 'tool_name': 'test_tool', + }), + ]), + }), + dict({ + 'agent_id': 'mock-agent-id', + 'role': 'tool_result', + 'tool_call_id': 'mock-tool-call-id', + 'tool_name': 'test_tool', + 'tool_result': 'Test Result', + }), + ]) +# --- # name: test_add_delta_content_stream[deltas1] list([ dict({ @@ -58,6 +86,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'mock-tool-call-id', 'tool_args': dict({ 'param1': 'Test Param 1', @@ -85,6 +114,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'mock-tool-call-id', 'tool_args': dict({ 'param1': 'Test Param 1', @@ -112,6 +142,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'mock-tool-call-id', 'tool_args': dict({ 'param1': 'Test Param 1', @@ -147,6 +178,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'mock-tool-call-id', 'tool_args': dict({ 'param1': 'Test Param 1', @@ -154,6 +186,7 @@ 'tool_name': 'test_tool', }), dict({ + 'external': False, 'id': 'mock-tool-call-id-2', 'tool_args': dict({ 'param1': 'Test Param 2', diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index a5ed3146ddc..e851512b36e 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -538,6 +538,27 @@ async def test_tool_call_exception( {"role": "assistant"}, {"native": object()}, ], + # With external tool calls + [ + {"role": "assistant"}, + {"content": "Test"}, + { + "tool_calls": [ + llm.ToolInput( + id="mock-tool-call-id", + tool_name="test_tool", + tool_args={"param1": "Test Param 1"}, + external=True, + ) + ] + }, + { + "role": "tool_result", + "tool_call_id": "mock-tool-call-id", + "tool_name": "test_tool", + "tool_result": "Test Result", + }, + ], ], ) async def test_add_delta_content_stream( @@ -569,7 +590,8 @@ async def test_add_delta_content_stream( for d in deltas: yield d if filtered_delta := {k: v for k, v in d.items() if k != "native"}: - expected_delta.append(filtered_delta) + if filtered_delta.get("role") != "tool_result": + expected_delta.append(filtered_delta) captured_deltas = [] diff --git a/tests/components/open_router/snapshots/test_conversation.ambr b/tests/components/open_router/snapshots/test_conversation.ambr index b60bab02ae7..19b5785a9eb 100644 --- a/tests/components/open_router/snapshots/test_conversation.ambr +++ b/tests/components/open_router/snapshots/test_conversation.ambr @@ -135,6 +135,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'call_call_1', 'tool_args': dict({ 'param1': 'call1', diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr index 7a03c484182..d33d62214ef 100644 --- a/tests/components/openai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -30,6 +30,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'call_call_1', 'tool_args': dict({ 'param1': 'call1', @@ -53,6 +54,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'call_call_2', 'tool_args': dict({ 'param1': 'call2', @@ -144,6 +146,7 @@ 'thinking_content': None, 'tool_calls': list([ dict({ + 'external': False, 'id': 'call_call_1', 'tool_args': dict({ 'param1': 'call1',