mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 14:27:07 +00:00
Add Gemini/OpenAI token stats to the conversation trace (#141118)
* Add gemini token status to the conversation trace * Add OpenAI Token Stats * Revert input_tokens_details since its not in the openai version yet * Fix ruff lint errors
This commit is contained in:
parent
663a204c04
commit
f14b76c54b
@ -8,7 +8,7 @@ from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import asdict, dataclass, field, replace
|
||||
import logging
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -456,10 +456,16 @@ class ChatLog:
|
||||
LOGGER.debug("Prompt: %s", self.content)
|
||||
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
|
||||
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
self.async_trace(
|
||||
{
|
||||
"messages": self.content,
|
||||
"tools": self.llm_api.tools if self.llm_api else None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def async_trace(self, agent_details: dict[str, Any]) -> None:
|
||||
"""Append agent specific details to the conversation trace."""
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
agent_details,
|
||||
)
|
||||
|
@ -403,6 +403,18 @@ class GoogleGenerativeAIConversationEntity(
|
||||
error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
||||
raise HomeAssistantError(error) from err
|
||||
|
||||
if (usage_metadata := chat_response.usage_metadata) is not None:
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": usage_metadata.prompt_token_count,
|
||||
"cached_input_tokens": usage_metadata.cached_content_token_count
|
||||
or 0,
|
||||
"output_tokens": usage_metadata.candidates_token_count,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
response_parts = chat_response.candidates[0].content.parts
|
||||
if not response_parts:
|
||||
raise HomeAssistantError(
|
||||
|
@ -9,6 +9,7 @@ from openai._streaming import AsyncStream
|
||||
from openai.types.responses import (
|
||||
EasyInputMessageParam,
|
||||
FunctionToolParam,
|
||||
ResponseCompletedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
@ -111,6 +112,7 @@ def _convert_content_to_param(
|
||||
|
||||
|
||||
async def _transform_stream(
|
||||
chat_log: conversation.ChatLog,
|
||||
result: AsyncStream[ResponseStreamEvent],
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform an OpenAI delta stream into HA format."""
|
||||
@ -137,6 +139,18 @@ async def _transform_stream(
|
||||
)
|
||||
]
|
||||
}
|
||||
elif (
|
||||
isinstance(event, ResponseCompletedEvent)
|
||||
and (usage := event.response.usage) is not None
|
||||
):
|
||||
chat_log.async_trace(
|
||||
{
|
||||
"stats": {
|
||||
"input_tokens": usage.input_tokens,
|
||||
"output_tokens": usage.output_tokens,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OpenAIConversationEntity(
|
||||
@ -252,7 +266,7 @@ class OpenAIConversationEntity(
|
||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_stream(result)
|
||||
user_input.agent_id, _transform_stream(chat_log, result)
|
||||
):
|
||||
messages.extend(_convert_content_to_param(content))
|
||||
|
||||
|
@ -156,8 +156,10 @@ async def test_function_call(
|
||||
trace_events = last_trace.get("events", [])
|
||||
assert [event["event_type"] for event in trace_events] == [
|
||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, # prompt and tools
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
|
||||
trace.ConversationTraceEventType.TOOL_CALL,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
|
||||
]
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
@ -166,6 +168,13 @@ async def test_function_call(
|
||||
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
|
||||
] == ["test_tool"]
|
||||
|
||||
detail_event = trace_events[2]
|
||||
assert set(detail_event["data"]["stats"].keys()) == {
|
||||
"input_tokens",
|
||||
"cached_input_tokens",
|
||||
"output_tokens",
|
||||
}
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
|
Loading…
x
Reference in New Issue
Block a user