mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 07:07:28 +00:00
Improve conversation agent tracing to help with eval and data collection (#122542)
This commit is contained in:
parent
4f5eab4646
commit
8d0e998e54
@ -47,6 +47,7 @@ from homeassistant.util.json import JsonObjectType, json_loads_object
|
|||||||
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
|
from .const import DEFAULT_EXPOSED_ATTRIBUTES, DOMAIN, ConversationEntityFeature
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
|
||||||
@ -348,6 +349,16 @@ class DefaultAgent(ConversationEntity):
|
|||||||
}
|
}
|
||||||
for entity in result.entities_list
|
for entity in result.entities_list
|
||||||
}
|
}
|
||||||
|
async_conversation_trace_append(
|
||||||
|
ConversationTraceEventType.TOOL_CALL,
|
||||||
|
{
|
||||||
|
"intent_name": result.intent.name,
|
||||||
|
"slots": {
|
||||||
|
entity.name: entity.value or entity.text
|
||||||
|
for entity in result.entities_list
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
intent_response = await intent.async_handle(
|
intent_response = await intent.async_handle(
|
||||||
|
@ -22,8 +22,8 @@ class ConversationTraceEventType(enum.StrEnum):
|
|||||||
AGENT_DETAIL = "agent_detail"
|
AGENT_DETAIL = "agent_detail"
|
||||||
"""Event detail added by a conversation agent."""
|
"""Event detail added by a conversation agent."""
|
||||||
|
|
||||||
LLM_TOOL_CALL = "llm_tool_call"
|
TOOL_CALL = "tool_call"
|
||||||
"""An LLM Tool call"""
|
"""A conversation agent Tool call or default agent intent call."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
@ -286,6 +286,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
if supports_system_instruction
|
if supports_system_instruction
|
||||||
else messages[2:],
|
else messages[2:],
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
"tools": [*llm_api.tools] if llm_api else None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -225,7 +225,8 @@ class OpenAIConversationEntity(
|
|||||||
LOGGER.debug("Prompt: %s", messages)
|
LOGGER.debug("Prompt: %s", messages)
|
||||||
LOGGER.debug("Tools: %s", tools)
|
LOGGER.debug("Tools: %s", tools)
|
||||||
trace.async_conversation_trace_append(
|
trace.async_conversation_trace_append(
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
{"messages": messages, "tools": llm_api.tools if llm_api else None},
|
||||||
)
|
)
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
client = self.entry.runtime_data
|
||||||
|
@ -167,7 +167,7 @@ class APIInstance:
|
|||||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||||
"""Call a LLM tool, validate args and return the response."""
|
"""Call a LLM tool, validate args and return the response."""
|
||||||
async_conversation_trace_append(
|
async_conversation_trace_append(
|
||||||
ConversationTraceEventType.LLM_TOOL_CALL,
|
ConversationTraceEventType.TOOL_CALL,
|
||||||
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ async def test_converation_trace(
|
|||||||
assert traces
|
assert traces
|
||||||
last_trace = traces[-1].as_dict()
|
last_trace = traces[-1].as_dict()
|
||||||
assert last_trace.get("events")
|
assert last_trace.get("events")
|
||||||
assert len(last_trace.get("events")) == 1
|
assert len(last_trace.get("events")) == 2
|
||||||
trace_event = last_trace["events"][0]
|
trace_event = last_trace["events"][0]
|
||||||
assert (
|
assert (
|
||||||
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||||
@ -50,6 +50,16 @@ async def test_converation_trace(
|
|||||||
== "Added apples"
|
== "Added apples"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trace_event = last_trace["events"][1]
|
||||||
|
assert trace_event.get("event_type") == trace.ConversationTraceEventType.TOOL_CALL
|
||||||
|
assert trace_event.get("data") == {
|
||||||
|
"intent_name": "HassListAddItem",
|
||||||
|
"slots": {
|
||||||
|
"name": "Shopping List",
|
||||||
|
"item": "apples ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_converation_trace_error(
|
async def test_converation_trace_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -269,11 +269,12 @@ async def test_function_call(
|
|||||||
assert [event["event_type"] for event in trace_events] == [
|
assert [event["event_type"] for event in trace_events] == [
|
||||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
trace.ConversationTraceEventType.TOOL_CALL,
|
||||||
]
|
]
|
||||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||||
detail_event = trace_events[1]
|
detail_event = trace_events[1]
|
||||||
assert "Answer in plain text" in detail_event["data"]["prompt"]
|
assert "Answer in plain text" in detail_event["data"]["prompt"]
|
||||||
|
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
|
@ -294,7 +294,7 @@ async def test_function_call(
|
|||||||
assert [event["event_type"] for event in trace_events] == [
|
assert [event["event_type"] for event in trace_events] == [
|
||||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
trace.ConversationTraceEventType.TOOL_CALL,
|
||||||
]
|
]
|
||||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||||
detail_event = trace_events[1]
|
detail_event = trace_events[1]
|
||||||
@ -303,6 +303,7 @@ async def test_function_call(
|
|||||||
"Today's date is 2024-06-03."
|
"Today's date is 2024-06-03."
|
||||||
in trace_events[1]["data"]["messages"][0]["content"]
|
in trace_events[1]["data"]["messages"][0]["content"]
|
||||||
)
|
)
|
||||||
|
assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"]
|
||||||
|
|
||||||
# Call it again, make sure we have updated prompt
|
# Call it again, make sure we have updated prompt
|
||||||
with (
|
with (
|
||||||
|
Loading…
x
Reference in New Issue
Block a user