Add Gemini/OpenAI token stats to the conversation trace (#141118)

* Add Gemini token stats to the conversation trace

* Add OpenAI Token Stats

* Revert input_tokens_details since it's not in the openai version yet

* Fix ruff lint errors
Allen Porter 2025-03-23 09:03:06 -07:00 committed by GitHub
parent 663a204c04
commit f14b76c54b
4 changed files with 47 additions and 6 deletions

View File

@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import asdict, dataclass, field, replace
 import logging
-from typing import Literal, TypedDict
+from typing import Any, Literal, TypedDict

 import voluptuous as vol
@@ -456,10 +456,16 @@ class ChatLog:
         LOGGER.debug("Prompt: %s", self.content)
         LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
+        self.async_trace(
             {
                 "messages": self.content,
                 "tools": self.llm_api.tools if self.llm_api else None,
-            },
+            }
         )

+    def async_trace(self, agent_details: dict[str, Any]) -> None:
+        """Append agent specific details to the conversation trace."""
+        trace.async_conversation_trace_append(
+            trace.ConversationTraceEventType.AGENT_DETAIL,
+            agent_details,
+        )
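
The new ChatLog.async_trace helper wraps trace.async_conversation_trace_append so integrations can attach agent details without importing the trace module themselves. A minimal usage sketch (the literal token counts are placeholders, not values from this PR):

    # Sketch: recording token stats through the new helper. The "stats"
    # payload shape mirrors what the Gemini and OpenAI changes below emit;
    # the numbers here are illustrative only.
    chat_log.async_trace(
        {
            "stats": {
                "input_tokens": 42,   # tokens in the prompt sent to the model
                "output_tokens": 17,  # tokens in the model's reply
            }
        }
    )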

View File

@@ -403,6 +403,18 @@ class GoogleGenerativeAIConversationEntity(
                 error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
                 raise HomeAssistantError(error) from err

+            if (usage_metadata := chat_response.usage_metadata) is not None:
+                chat_log.async_trace(
+                    {
+                        "stats": {
+                            "input_tokens": usage_metadata.prompt_token_count,
+                            "cached_input_tokens": usage_metadata.cached_content_token_count
+                            or 0,
+                            "output_tokens": usage_metadata.candidates_token_count,
+                        }
+                    }
+                )
+
             response_parts = chat_response.candidates[0].content.parts
             if not response_parts:
                 raise HomeAssistantError(
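
Gemini reports usage on chat_response.usage_metadata; the change maps the SDK field names (prompt_token_count, cached_content_token_count, candidates_token_count) onto generic stats keys, defaulting the cached count to 0 since the SDK may return None. The recorded event can be read back from the trace the same way the test below does; a sketch, assuming the trace.async_get_traces() accessor the existing tests rely on:

    # Sketch: inspecting recorded stats from the most recent conversation
    # trace. Assumes trace.async_get_traces() returning ConversationTrace
    # objects with as_dict(), per the existing test suite.
    last_trace = trace.async_get_traces()[-1].as_dict()
    for event in last_trace.get("events", []):
        if event["event_type"] == trace.ConversationTraceEventType.AGENT_DETAIL:
            print(event["data"].get("stats"))  # None for the prompt/tools event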

View File

@@ -9,6 +9,7 @@ from openai._streaming import AsyncStream
 from openai.types.responses import (
     EasyInputMessageParam,
     FunctionToolParam,
+    ResponseCompletedEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
@@ -111,6 +112,7 @@ def _convert_content_to_param(


 async def _transform_stream(
+    chat_log: conversation.ChatLog,
     result: AsyncStream[ResponseStreamEvent],
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform an OpenAI delta stream into HA format."""
@@ -137,6 +139,18 @@ async def _transform_stream(
                 )
             ]
         }
+    elif (
+        isinstance(event, ResponseCompletedEvent)
+        and (usage := event.response.usage) is not None
+    ):
+        chat_log.async_trace(
+            {
+                "stats": {
+                    "input_tokens": usage.input_tokens,
+                    "output_tokens": usage.output_tokens,
+                }
+            }
+        )


 class OpenAIConversationEntity(
@@ -252,7 +266,7 @@ class OpenAIConversationEntity(
             raise HomeAssistantError("Error talking to OpenAI") from err

         async for content in chat_log.async_add_delta_content_stream(
-            user_input.agent_id, _transform_stream(result)
+            user_input.agent_id, _transform_stream(chat_log, result)
         ):
             messages.extend(_convert_content_to_param(content))
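
For OpenAI, usage is only known once the response stream finishes, which is why _transform_stream now takes the ChatLog: the stats are recorded as a side effect when the final ResponseCompletedEvent arrives rather than being threaded back to the caller. The stats omit cached_input_tokens because, per the commit message, input_tokens_details is not available in the pinned openai package version yet. A hedged sketch of how it could be added once the SDK exposes it (field names assumed from newer openai releases, not from this PR):

    # Sketch only: cached input tokens for OpenAI, guarded so it is a no-op
    # until the installed openai package exposes input_tokens_details.
    # `usage` is the ResponseUsage taken from the ResponseCompletedEvent above.
    stats = {
        "input_tokens": usage.input_tokens,
        "output_tokens": usage.output_tokens,
    }
    if (details := getattr(usage, "input_tokens_details", None)) is not None:
        stats["cached_input_tokens"] = details.cached_tokens  # assumed field
    chat_log.async_trace({"stats": stats})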

View File

@@ -156,8 +156,10 @@ async def test_function_call(
     trace_events = last_trace.get("events", [])
     assert [event["event_type"] for event in trace_events] == [
         trace.ConversationTraceEventType.ASYNC_PROCESS,
-        trace.ConversationTraceEventType.AGENT_DETAIL,
+        trace.ConversationTraceEventType.AGENT_DETAIL,  # prompt and tools
+        trace.ConversationTraceEventType.AGENT_DETAIL,  # stats for response
         trace.ConversationTraceEventType.TOOL_CALL,
+        trace.ConversationTraceEventType.AGENT_DETAIL,  # stats for response
     ]
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
@@ -166,6 +168,13 @@ async def test_function_call(
         p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
     ] == ["test_tool"]

+    detail_event = trace_events[2]
+    assert set(detail_event["data"]["stats"].keys()) == {
+        "input_tokens",
+        "cached_input_tokens",
+        "output_tokens",
+    }
+

 @patch(
     "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"