mirror of
https://github.com/home-assistant/core.git
synced 2026-04-06 23:47:33 +00:00
Add external tools support for chat log (#150461)
This commit is contained in:
@@ -40,6 +40,7 @@ from .chat_log import (
|
||||
ConverseError,
|
||||
SystemContent,
|
||||
ToolResultContent,
|
||||
ToolResultContentDeltaDict,
|
||||
UserContent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
@@ -79,6 +80,7 @@ __all__ = [
|
||||
"ConverseError",
|
||||
"SystemContent",
|
||||
"ToolResultContent",
|
||||
"ToolResultContentDeltaDict",
|
||||
"UserContent",
|
||||
"async_conversation_trace_append",
|
||||
"async_converse",
|
||||
|
||||
@@ -9,7 +9,7 @@ from contextvars import ContextVar
|
||||
from dataclasses import asdict, dataclass, field, replace
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -190,6 +190,15 @@ class AssistantContentDeltaDict(TypedDict, total=False):
|
||||
native: Any
|
||||
|
||||
|
||||
class ToolResultContentDeltaDict(TypedDict, total=False):
|
||||
"""Tool result content."""
|
||||
|
||||
role: Literal["tool_result"]
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
tool_result: JsonObjectType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatLog:
|
||||
"""Class holding the chat history of a specific conversation."""
|
||||
@@ -235,17 +244,25 @@ class ChatLog:
|
||||
|
||||
@callback
|
||||
def async_add_assistant_content_without_tools(
|
||||
self, content: AssistantContent
|
||||
self, content: AssistantContent | ToolResultContent
|
||||
) -> None:
|
||||
"""Add assistant content to the log."""
|
||||
"""Add assistant content to the log.
|
||||
|
||||
Allows assistant content without tool calls or with external tool calls only,
|
||||
as well as tool results for the external tools.
|
||||
"""
|
||||
LOGGER.debug("Adding assistant content: %s", content)
|
||||
if content.tool_calls is not None:
|
||||
raise ValueError("Tool calls not allowed")
|
||||
if (
|
||||
isinstance(content, AssistantContent)
|
||||
and content.tool_calls is not None
|
||||
and any(not tool_call.external for tool_call in content.tool_calls)
|
||||
):
|
||||
raise ValueError("Non-external tool calls not allowed")
|
||||
self.content.append(content)
|
||||
|
||||
async def async_add_assistant_content(
|
||||
self,
|
||||
content: AssistantContent,
|
||||
content: AssistantContent | ToolResultContent,
|
||||
/,
|
||||
tool_call_tasks: dict[str, asyncio.Task] | None = None,
|
||||
) -> AsyncGenerator[ToolResultContent]:
|
||||
@@ -258,7 +275,11 @@ class ChatLog:
|
||||
LOGGER.debug("Adding assistant content: %s", content)
|
||||
self.content.append(content)
|
||||
|
||||
if content.tool_calls is None:
|
||||
if (
|
||||
not isinstance(content, AssistantContent)
|
||||
or content.tool_calls is None
|
||||
or all(tool_call.external for tool_call in content.tool_calls)
|
||||
):
|
||||
return
|
||||
|
||||
if self.llm_api is None:
|
||||
@@ -267,13 +288,16 @@ class ChatLog:
|
||||
if tool_call_tasks is None:
|
||||
tool_call_tasks = {}
|
||||
for tool_input in content.tool_calls:
|
||||
if tool_input.id not in tool_call_tasks:
|
||||
if tool_input.id not in tool_call_tasks and not tool_input.external:
|
||||
tool_call_tasks[tool_input.id] = self.hass.async_create_task(
|
||||
self.llm_api.async_call_tool(tool_input),
|
||||
name=f"llm_tool_{tool_input.id}",
|
||||
)
|
||||
|
||||
for tool_input in content.tool_calls:
|
||||
if tool_input.external:
|
||||
continue
|
||||
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
@@ -296,7 +320,9 @@ class ChatLog:
|
||||
yield response_content
|
||||
|
||||
async def async_add_delta_content_stream(
|
||||
self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict]
|
||||
self,
|
||||
agent_id: str,
|
||||
stream: AsyncIterable[AssistantContentDeltaDict | ToolResultContentDeltaDict],
|
||||
) -> AsyncGenerator[AssistantContent | ToolResultContent]:
|
||||
"""Stream content into the chat log.
|
||||
|
||||
@@ -320,30 +346,34 @@ class ChatLog:
|
||||
|
||||
# Indicates update to current message
|
||||
if "role" not in delta:
|
||||
if delta_content := delta.get("content"):
|
||||
# ToolResultContentDeltaDict will always have a role
|
||||
assistant_delta = cast(AssistantContentDeltaDict, delta)
|
||||
if delta_content := assistant_delta.get("content"):
|
||||
current_content += delta_content
|
||||
if delta_thinking_content := delta.get("thinking_content"):
|
||||
if delta_thinking_content := assistant_delta.get("thinking_content"):
|
||||
current_thinking_content += delta_thinking_content
|
||||
if delta_native := delta.get("native"):
|
||||
if delta_native := assistant_delta.get("native"):
|
||||
if current_native is not None:
|
||||
raise RuntimeError(
|
||||
"Native content already set, cannot overwrite"
|
||||
)
|
||||
current_native = delta_native
|
||||
if delta_tool_calls := delta.get("tool_calls"):
|
||||
if self.llm_api is None:
|
||||
raise ValueError("No LLM API configured")
|
||||
if delta_tool_calls := assistant_delta.get("tool_calls"):
|
||||
current_tool_calls += delta_tool_calls
|
||||
|
||||
# Start processing the tool calls as soon as we know about them
|
||||
for tool_call in delta_tool_calls:
|
||||
tool_call_tasks[tool_call.id] = self.hass.async_create_task(
|
||||
self.llm_api.async_call_tool(tool_call),
|
||||
name=f"llm_tool_{tool_call.id}",
|
||||
)
|
||||
if not tool_call.external:
|
||||
if self.llm_api is None:
|
||||
raise ValueError("No LLM API configured")
|
||||
|
||||
tool_call_tasks[tool_call.id] = self.hass.async_create_task(
|
||||
self.llm_api.async_call_tool(tool_call),
|
||||
name=f"llm_tool_{tool_call.id}",
|
||||
)
|
||||
if self.delta_listener:
|
||||
if filtered_delta := {
|
||||
k: v for k, v in delta.items() if k != "native"
|
||||
k: v for k, v in assistant_delta.items() if k != "native"
|
||||
}:
|
||||
# We do not want to send the native content to the listener
|
||||
# as it is not JSON serializable
|
||||
@@ -351,10 +381,6 @@ class ChatLog:
|
||||
continue
|
||||
|
||||
# Starting a new message
|
||||
|
||||
if delta["role"] != "assistant":
|
||||
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
|
||||
|
||||
# Yield the previous message if it has content
|
||||
if (
|
||||
current_content
|
||||
@@ -362,7 +388,7 @@ class ChatLog:
|
||||
or current_tool_calls
|
||||
or current_native
|
||||
):
|
||||
content = AssistantContent(
|
||||
content: AssistantContent | ToolResultContent = AssistantContent(
|
||||
agent_id=agent_id,
|
||||
content=current_content or None,
|
||||
thinking_content=current_thinking_content or None,
|
||||
@@ -376,14 +402,38 @@ class ChatLog:
|
||||
yield tool_result
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, asdict(tool_result))
|
||||
current_content = ""
|
||||
current_thinking_content = ""
|
||||
current_native = None
|
||||
current_tool_calls = []
|
||||
|
||||
current_content = delta.get("content") or ""
|
||||
current_thinking_content = delta.get("thinking_content") or ""
|
||||
current_tool_calls = delta.get("tool_calls") or []
|
||||
current_native = delta.get("native")
|
||||
if delta["role"] == "assistant":
|
||||
current_content = delta.get("content") or ""
|
||||
current_thinking_content = delta.get("thinking_content") or ""
|
||||
current_tool_calls = delta.get("tool_calls") or []
|
||||
current_native = delta.get("native")
|
||||
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||
if self.delta_listener:
|
||||
if filtered_delta := {
|
||||
k: v for k, v in delta.items() if k != "native"
|
||||
}:
|
||||
self.delta_listener(self, filtered_delta)
|
||||
elif delta["role"] == "tool_result":
|
||||
content = ToolResultContent(
|
||||
agent_id=agent_id,
|
||||
tool_call_id=delta["tool_call_id"],
|
||||
tool_name=delta["tool_name"],
|
||||
tool_result=delta["tool_result"],
|
||||
)
|
||||
yield content
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, asdict(content))
|
||||
self.async_add_assistant_content_without_tools(content)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only assistant and tool_result roles expected."
|
||||
f" Got {delta['role']}"
|
||||
)
|
||||
|
||||
if (
|
||||
current_content
|
||||
|
||||
@@ -183,6 +183,7 @@ class ToolInput:
|
||||
tool_args: dict[str, Any]
|
||||
# Using lambda for default to allow patching in tests
|
||||
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
|
||||
external: bool = False
|
||||
|
||||
|
||||
class Tool:
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
'thinking_content': "Okay, let's give it a shot. Will I pass the test?",
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'toolu_0123456789AbCdEfGhIjKlM',
|
||||
'tool_args': dict({
|
||||
'param1': 'test_value',
|
||||
|
||||
@@ -467,6 +467,7 @@
|
||||
'chat_log_delta': dict({
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'test_tool_id',
|
||||
'tool_args': dict({
|
||||
}),
|
||||
|
||||
@@ -16,6 +16,34 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas11]
|
||||
list([
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': True,
|
||||
'id': 'mock-tool-call-id',
|
||||
'tool_args': dict({
|
||||
'param1': 'Test Param 1',
|
||||
}),
|
||||
'tool_name': 'test_tool',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'role': 'tool_result',
|
||||
'tool_call_id': 'mock-tool-call-id',
|
||||
'tool_name': 'test_tool',
|
||||
'tool_result': 'Test Result',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas1]
|
||||
list([
|
||||
dict({
|
||||
@@ -58,6 +86,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'mock-tool-call-id',
|
||||
'tool_args': dict({
|
||||
'param1': 'Test Param 1',
|
||||
@@ -85,6 +114,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'mock-tool-call-id',
|
||||
'tool_args': dict({
|
||||
'param1': 'Test Param 1',
|
||||
@@ -112,6 +142,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'mock-tool-call-id',
|
||||
'tool_args': dict({
|
||||
'param1': 'Test Param 1',
|
||||
@@ -147,6 +178,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'mock-tool-call-id',
|
||||
'tool_args': dict({
|
||||
'param1': 'Test Param 1',
|
||||
@@ -154,6 +186,7 @@
|
||||
'tool_name': 'test_tool',
|
||||
}),
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'mock-tool-call-id-2',
|
||||
'tool_args': dict({
|
||||
'param1': 'Test Param 2',
|
||||
|
||||
@@ -538,6 +538,27 @@ async def test_tool_call_exception(
|
||||
{"role": "assistant"},
|
||||
{"native": object()},
|
||||
],
|
||||
# With external tool calls
|
||||
[
|
||||
{"role": "assistant"},
|
||||
{"content": "Test"},
|
||||
{
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call-id",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param 1"},
|
||||
external=True,
|
||||
)
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool_result",
|
||||
"tool_call_id": "mock-tool-call-id",
|
||||
"tool_name": "test_tool",
|
||||
"tool_result": "Test Result",
|
||||
},
|
||||
],
|
||||
],
|
||||
)
|
||||
async def test_add_delta_content_stream(
|
||||
@@ -569,7 +590,8 @@ async def test_add_delta_content_stream(
|
||||
for d in deltas:
|
||||
yield d
|
||||
if filtered_delta := {k: v for k, v in d.items() if k != "native"}:
|
||||
expected_delta.append(filtered_delta)
|
||||
if filtered_delta.get("role") != "tool_result":
|
||||
expected_delta.append(filtered_delta)
|
||||
|
||||
captured_deltas = []
|
||||
|
||||
|
||||
@@ -135,6 +135,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'call_call_1',
|
||||
'tool_args': dict({
|
||||
'param1': 'call1',
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'call_call_1',
|
||||
'tool_args': dict({
|
||||
'param1': 'call1',
|
||||
@@ -53,6 +54,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'call_call_2',
|
||||
'tool_args': dict({
|
||||
'param1': 'call2',
|
||||
@@ -144,6 +146,7 @@
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': False,
|
||||
'id': 'call_call_1',
|
||||
'tool_args': dict({
|
||||
'param1': 'call1',
|
||||
|
||||
Reference in New Issue
Block a user