diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 2de785dae7d..cb7b8dd22f7 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import asdict, dataclass, field, replace import logging -from typing import Literal, TypedDict +from typing import Any, Literal, TypedDict import voluptuous as vol @@ -456,10 +456,16 @@ class ChatLog: LOGGER.debug("Prompt: %s", self.content) LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None) - trace.async_conversation_trace_append( - trace.ConversationTraceEventType.AGENT_DETAIL, + self.async_trace( { "messages": self.content, "tools": self.llm_api.tools if self.llm_api else None, - }, + } + ) + + def async_trace(self, agent_details: dict[str, Any]) -> None: + """Append agent specific details to the conversation trace.""" + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, + agent_details, ) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 4648f1afb4c..e35346cc745 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -403,6 +403,18 @@ class GoogleGenerativeAIConversationEntity( error = f"Sorry, I had a problem talking to Google Generative AI: {err}" raise HomeAssistantError(error) from err + if (usage_metadata := chat_response.usage_metadata) is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": usage_metadata.prompt_token_count, + "cached_input_tokens": usage_metadata.cached_content_token_count + or 0, + "output_tokens": usage_metadata.candidates_token_count, + } + } + ) + response_parts = chat_response.candidates[0].content.parts if not response_parts: raise HomeAssistantError( diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 6767734bb00..32ac20b2680 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -9,6 +9,7 @@ from openai._streaming import AsyncStream from openai.types.responses import ( EasyInputMessageParam, FunctionToolParam, + ResponseCompletedEvent, ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionToolCall, @@ -111,6 +112,7 @@ def _convert_content_to_param( async def _transform_stream( + chat_log: conversation.ChatLog, result: AsyncStream[ResponseStreamEvent], ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: """Transform an OpenAI delta stream into HA format.""" @@ -137,6 +139,18 @@ async def _transform_stream( ) ] } + elif ( + isinstance(event, ResponseCompletedEvent) + and (usage := event.response.usage) is not None + ): + chat_log.async_trace( + { + "stats": { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + } + } + ) class OpenAIConversationEntity( @@ -252,7 +266,7 @@ class OpenAIConversationEntity( raise HomeAssistantError("Error talking to OpenAI") from err async for content in chat_log.async_add_delta_content_stream( - user_input.agent_id, _transform_stream(result) + user_input.agent_id, _transform_stream(chat_log, result) ): messages.extend(_convert_content_to_param(content)) diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 64f71c18bf2..22bc079a21f 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -156,8 +156,10 @@ async def test_function_call( trace_events = last_trace.get("events", []) assert [event["event_type"] for event in trace_events] == [ trace.ConversationTraceEventType.ASYNC_PROCESS, - trace.ConversationTraceEventType.AGENT_DETAIL, + trace.ConversationTraceEventType.AGENT_DETAIL, # prompt and tools + trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response trace.ConversationTraceEventType.TOOL_CALL, + trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response ] # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] @@ -166,6 +168,13 @@ async def test_function_call( p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"] ] == ["test_tool"] + detail_event = trace_events[2] + assert set(detail_event["data"]["stats"].keys()) == { + "input_tokens", + "cached_input_tokens", + "output_tokens", + } + @patch( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"