Add external tools support for chat log (#150461)

This commit is contained in:
Denis Shulyaka
2025-08-16 13:20:20 +03:00
committed by GitHub
parent 616b031df8
commit 80e720f663
9 changed files with 146 additions and 32 deletions

View File

@@ -40,6 +40,7 @@ from .chat_log import (
ConverseError,
SystemContent,
ToolResultContent,
ToolResultContentDeltaDict,
UserContent,
async_get_chat_log,
)
@@ -79,6 +80,7 @@ __all__ = [
"ConverseError",
"SystemContent",
"ToolResultContent",
"ToolResultContentDeltaDict",
"UserContent",
"async_conversation_trace_append",
"async_converse",

View File

@@ -9,7 +9,7 @@ from contextvars import ContextVar
from dataclasses import asdict, dataclass, field, replace
import logging
from pathlib import Path
from typing import Any, Literal, TypedDict
from typing import Any, Literal, TypedDict, cast
import voluptuous as vol
@@ -190,6 +190,15 @@ class AssistantContentDeltaDict(TypedDict, total=False):
native: Any
class ToolResultContentDeltaDict(TypedDict, total=False):
"""Tool result content."""
role: Literal["tool_result"]
tool_call_id: str
tool_name: str
tool_result: JsonObjectType
@dataclass
class ChatLog:
"""Class holding the chat history of a specific conversation."""
@@ -235,17 +244,25 @@ class ChatLog:
@callback
def async_add_assistant_content_without_tools(
self, content: AssistantContent
self, content: AssistantContent | ToolResultContent
) -> None:
"""Add assistant content to the log."""
"""Add assistant content to the log.
Allows assistant content without tool calls or with external tool calls only,
as well as tool results for the external tools.
"""
LOGGER.debug("Adding assistant content: %s", content)
if content.tool_calls is not None:
raise ValueError("Tool calls not allowed")
if (
isinstance(content, AssistantContent)
and content.tool_calls is not None
and any(not tool_call.external for tool_call in content.tool_calls)
):
raise ValueError("Non-external tool calls not allowed")
self.content.append(content)
async def async_add_assistant_content(
self,
content: AssistantContent,
content: AssistantContent | ToolResultContent,
/,
tool_call_tasks: dict[str, asyncio.Task] | None = None,
) -> AsyncGenerator[ToolResultContent]:
@@ -258,7 +275,11 @@ class ChatLog:
LOGGER.debug("Adding assistant content: %s", content)
self.content.append(content)
if content.tool_calls is None:
if (
not isinstance(content, AssistantContent)
or content.tool_calls is None
or all(tool_call.external for tool_call in content.tool_calls)
):
return
if self.llm_api is None:
@@ -267,13 +288,16 @@ class ChatLog:
if tool_call_tasks is None:
tool_call_tasks = {}
for tool_input in content.tool_calls:
if tool_input.id not in tool_call_tasks:
if tool_input.id not in tool_call_tasks and not tool_input.external:
tool_call_tasks[tool_input.id] = self.hass.async_create_task(
self.llm_api.async_call_tool(tool_input),
name=f"llm_tool_{tool_input.id}",
)
for tool_input in content.tool_calls:
if tool_input.external:
continue
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
@@ -296,7 +320,9 @@ class ChatLog:
yield response_content
async def async_add_delta_content_stream(
self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict]
self,
agent_id: str,
stream: AsyncIterable[AssistantContentDeltaDict | ToolResultContentDeltaDict],
) -> AsyncGenerator[AssistantContent | ToolResultContent]:
"""Stream content into the chat log.
@@ -320,30 +346,34 @@ class ChatLog:
# Indicates update to current message
if "role" not in delta:
if delta_content := delta.get("content"):
# ToolResultContentDeltaDict will always have a role
assistant_delta = cast(AssistantContentDeltaDict, delta)
if delta_content := assistant_delta.get("content"):
current_content += delta_content
if delta_thinking_content := delta.get("thinking_content"):
if delta_thinking_content := assistant_delta.get("thinking_content"):
current_thinking_content += delta_thinking_content
if delta_native := delta.get("native"):
if delta_native := assistant_delta.get("native"):
if current_native is not None:
raise RuntimeError(
"Native content already set, cannot overwrite"
)
current_native = delta_native
if delta_tool_calls := delta.get("tool_calls"):
if self.llm_api is None:
raise ValueError("No LLM API configured")
if delta_tool_calls := assistant_delta.get("tool_calls"):
current_tool_calls += delta_tool_calls
# Start processing the tool calls as soon as we know about them
for tool_call in delta_tool_calls:
tool_call_tasks[tool_call.id] = self.hass.async_create_task(
self.llm_api.async_call_tool(tool_call),
name=f"llm_tool_{tool_call.id}",
)
if not tool_call.external:
if self.llm_api is None:
raise ValueError("No LLM API configured")
tool_call_tasks[tool_call.id] = self.hass.async_create_task(
self.llm_api.async_call_tool(tool_call),
name=f"llm_tool_{tool_call.id}",
)
if self.delta_listener:
if filtered_delta := {
k: v for k, v in delta.items() if k != "native"
k: v for k, v in assistant_delta.items() if k != "native"
}:
# We do not want to send the native content to the listener
# as it is not JSON serializable
@@ -351,10 +381,6 @@ class ChatLog:
continue
# Starting a new message
if delta["role"] != "assistant":
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
# Yield the previous message if it has content
if (
current_content
@@ -362,7 +388,7 @@ class ChatLog:
or current_tool_calls
or current_native
):
content = AssistantContent(
content: AssistantContent | ToolResultContent = AssistantContent(
agent_id=agent_id,
content=current_content or None,
thinking_content=current_thinking_content or None,
@@ -376,14 +402,38 @@ class ChatLog:
yield tool_result
if self.delta_listener:
self.delta_listener(self, asdict(tool_result))
current_content = ""
current_thinking_content = ""
current_native = None
current_tool_calls = []
current_content = delta.get("content") or ""
current_thinking_content = delta.get("thinking_content") or ""
current_tool_calls = delta.get("tool_calls") or []
current_native = delta.get("native")
if delta["role"] == "assistant":
current_content = delta.get("content") or ""
current_thinking_content = delta.get("thinking_content") or ""
current_tool_calls = delta.get("tool_calls") or []
current_native = delta.get("native")
if self.delta_listener:
self.delta_listener(self, delta) # type: ignore[arg-type]
if self.delta_listener:
if filtered_delta := {
k: v for k, v in delta.items() if k != "native"
}:
self.delta_listener(self, filtered_delta)
elif delta["role"] == "tool_result":
content = ToolResultContent(
agent_id=agent_id,
tool_call_id=delta["tool_call_id"],
tool_name=delta["tool_name"],
tool_result=delta["tool_result"],
)
yield content
if self.delta_listener:
self.delta_listener(self, asdict(content))
self.async_add_assistant_content_without_tools(content)
else:
raise ValueError(
"Only assistant and tool_result roles expected."
f" Got {delta['role']}"
)
if (
current_content

View File

@@ -183,6 +183,7 @@ class ToolInput:
tool_args: dict[str, Any]
# Using lambda for default to allow patching in tests
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
external: bool = False
class Tool:

View File

@@ -40,6 +40,7 @@
'thinking_content': "Okay, let's give it a shot. Will I pass the test?",
'tool_calls': list([
dict({
'external': False,
'id': 'toolu_0123456789AbCdEfGhIjKlM',
'tool_args': dict({
'param1': 'test_value',

View File

@@ -467,6 +467,7 @@
'chat_log_delta': dict({
'tool_calls': list([
dict({
'external': False,
'id': 'test_tool_id',
'tool_args': dict({
}),

View File

@@ -16,6 +16,34 @@
}),
])
# ---
# name: test_add_delta_content_stream[deltas11]
list([
dict({
'agent_id': 'mock-agent-id',
'content': 'Test',
'native': None,
'role': 'assistant',
'thinking_content': None,
'tool_calls': list([
dict({
'external': True,
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'mock-agent-id',
'role': 'tool_result',
'tool_call_id': 'mock-tool-call-id',
'tool_name': 'test_tool',
'tool_result': 'Test Result',
}),
])
# ---
# name: test_add_delta_content_stream[deltas1]
list([
dict({
@@ -58,6 +86,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
@@ -85,6 +114,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
@@ -112,6 +142,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
@@ -147,6 +178,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'mock-tool-call-id',
'tool_args': dict({
'param1': 'Test Param 1',
@@ -154,6 +186,7 @@
'tool_name': 'test_tool',
}),
dict({
'external': False,
'id': 'mock-tool-call-id-2',
'tool_args': dict({
'param1': 'Test Param 2',

View File

@@ -538,6 +538,27 @@ async def test_tool_call_exception(
{"role": "assistant"},
{"native": object()},
],
# With external tool calls
[
{"role": "assistant"},
{"content": "Test"},
{
"tool_calls": [
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param 1"},
external=True,
)
]
},
{
"role": "tool_result",
"tool_call_id": "mock-tool-call-id",
"tool_name": "test_tool",
"tool_result": "Test Result",
},
],
],
)
async def test_add_delta_content_stream(
@@ -569,7 +590,8 @@ async def test_add_delta_content_stream(
for d in deltas:
yield d
if filtered_delta := {k: v for k, v in d.items() if k != "native"}:
expected_delta.append(filtered_delta)
if filtered_delta.get("role") != "tool_result":
expected_delta.append(filtered_delta)
captured_deltas = []

View File

@@ -135,6 +135,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'call_call_1',
'tool_args': dict({
'param1': 'call1',

View File

@@ -30,6 +30,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'call_call_1',
'tool_args': dict({
'param1': 'call1',
@@ -53,6 +54,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'call_call_2',
'tool_args': dict({
'param1': 'call2',
@@ -144,6 +146,7 @@
'thinking_content': None,
'tool_calls': list([
dict({
'external': False,
'id': 'call_call_1',
'tool_args': dict({
'param1': 'call1',