Mirror of https://github.com/home-assistant/core.git
Add Gemini/OpenAI token stats to the conversation trace (#141118)
* Add Gemini token stats to the conversation trace
* Add OpenAI token stats
* Revert input_tokens_details since it's not in the openai version yet
* Fix ruff lint errors
commit f14b76c54b (parent 663a204c04)
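
Taken together, the diffs below make token usage show up as extra AGENT_DETAIL events on the conversation trace. A minimal sketch of the payload shape, with illustrative numbers (only the key names come from the diffs; cached_input_tokens is reported by the Gemini path only):

    # Shape of the "stats" dict appended to the conversation trace.
    stats_event_data = {
        "stats": {
            "input_tokens": 413,       # tokens in the prompt sent to the model
            "cached_input_tokens": 0,  # Gemini only: prompt tokens served from cache
            "output_tokens": 33,       # tokens generated in the response
        }
    }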
homeassistant/components/conversation/chat_log.py
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import asdict, dataclass, field, replace
 import logging
-from typing import Literal, TypedDict
+from typing import Any, Literal, TypedDict
 
 import voluptuous as vol
 
@@ -456,10 +456,16 @@ class ChatLog:
         LOGGER.debug("Prompt: %s", self.content)
         LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
 
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
+        self.async_trace(
             {
                 "messages": self.content,
                 "tools": self.llm_api.tools if self.llm_api else None,
-            },
+            }
+        )
+
+    def async_trace(self, agent_details: dict[str, Any]) -> None:
+        """Append agent specific details to the conversation trace."""
+        trace.async_conversation_trace_append(
+            trace.ConversationTraceEventType.AGENT_DETAIL,
+            agent_details,
         )
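
The new ChatLog.async_trace helper gives integrations a single call for attaching agent-specific details to the active trace. A minimal usage sketch, assuming chat_log is a ChatLog instance (the payload values are illustrative):

    # Append provider-specific details to the active conversation trace;
    # the helper wraps trace.async_conversation_trace_append with the
    # AGENT_DETAIL event type, as defined above.
    chat_log.async_trace(
        {
            "stats": {
                "input_tokens": 100,
                "output_tokens": 30,
            }
        }
    )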
homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -403,6 +403,18 @@ class GoogleGenerativeAIConversationEntity(
             error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
             raise HomeAssistantError(error) from err
 
+        if (usage_metadata := chat_response.usage_metadata) is not None:
+            chat_log.async_trace(
+                {
+                    "stats": {
+                        "input_tokens": usage_metadata.prompt_token_count,
+                        "cached_input_tokens": usage_metadata.cached_content_token_count
+                        or 0,
+                        "output_tokens": usage_metadata.candidates_token_count,
+                    }
+                }
+            )
+
         response_parts = chat_response.candidates[0].content.parts
         if not response_parts:
             raise HomeAssistantError(
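
One detail worth calling out in the hunk above: the Gemini SDK leaves cached_content_token_count unset (None) when no context caching was used, which the `or 0` coerces to an integer. A hypothetical standalone helper showing the same mapping (the function name is not part of the commit):

    # Hypothetical helper: map Gemini usage_metadata fields onto the trace's
    # generic stats keys, coercing a missing cache counter to 0.
    def _gemini_stats(usage_metadata) -> dict[str, int]:
        return {
            "input_tokens": usage_metadata.prompt_token_count,
            "cached_input_tokens": usage_metadata.cached_content_token_count or 0,
            "output_tokens": usage_metadata.candidates_token_count,
        }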
homeassistant/components/openai_conversation/conversation.py
@@ -9,6 +9,7 @@ from openai._streaming import AsyncStream
 from openai.types.responses import (
     EasyInputMessageParam,
     FunctionToolParam,
+    ResponseCompletedEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
@@ -111,6 +112,7 @@ def _convert_content_to_param(
 
 
 async def _transform_stream(
+    chat_log: conversation.ChatLog,
     result: AsyncStream[ResponseStreamEvent],
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform an OpenAI delta stream into HA format."""
@@ -137,6 +139,18 @@ async def _transform_stream(
                 )
             ]
         }
+        elif (
+            isinstance(event, ResponseCompletedEvent)
+            and (usage := event.response.usage) is not None
+        ):
+            chat_log.async_trace(
+                {
+                    "stats": {
+                        "input_tokens": usage.input_tokens,
+                        "output_tokens": usage.output_tokens,
+                    }
+                }
+            )
 
 
 class OpenAIConversationEntity(
@@ -252,7 +266,7 @@ class OpenAIConversationEntity(
             raise HomeAssistantError("Error talking to OpenAI") from err
 
         async for content in chat_log.async_add_delta_content_stream(
-            user_input.agent_id, _transform_stream(result)
+            user_input.agent_id, _transform_stream(chat_log, result)
         ):
             messages.extend(_convert_content_to_param(content))
 
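
_transform_stream grows a chat_log parameter because the OpenAI Responses API only reports usage on the terminal ResponseCompletedEvent of a stream, so the transformer has to record the stats as a side effect while yielding deltas. A reduced sketch of that pattern (everything except ResponseCompletedEvent and its usage fields is illustrative):

    from openai.types.responses import ResponseCompletedEvent

    # Reduced sketch: pass stream events through while watching for the final
    # completed event, then hand its token totals to a callback.
    async def _watch_usage(events, on_stats):
        async for event in events:
            if (
                isinstance(event, ResponseCompletedEvent)
                and (usage := event.response.usage) is not None
            ):
                on_stats(
                    {
                        "stats": {
                            "input_tokens": usage.input_tokens,
                            "output_tokens": usage.output_tokens,
                        }
                    }
                )
            yield event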
tests/components/google_generative_ai_conversation/test_conversation.py
@@ -156,8 +156,10 @@ async def test_function_call(
     trace_events = last_trace.get("events", [])
     assert [event["event_type"] for event in trace_events] == [
         trace.ConversationTraceEventType.ASYNC_PROCESS,
-        trace.ConversationTraceEventType.AGENT_DETAIL,
+        trace.ConversationTraceEventType.AGENT_DETAIL,  # prompt and tools
+        trace.ConversationTraceEventType.AGENT_DETAIL,  # stats for response
         trace.ConversationTraceEventType.TOOL_CALL,
+        trace.ConversationTraceEventType.AGENT_DETAIL,  # stats for response
     ]
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
@@ -166,6 +168,13 @@ async def test_function_call(
         p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
     ] == ["test_tool"]
 
+    detail_event = trace_events[2]
+    assert set(detail_event["data"]["stats"].keys()) == {
+        "input_tokens",
+        "cached_input_tokens",
+        "output_tokens",
+    }
+
 
 @patch(
     "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"