diff --git a/homeassistant/components/anthropic/conversation.py b/homeassistant/components/anthropic/conversation.py index 56b8031417b..7e1fda467a8 100644 --- a/homeassistant/components/anthropic/conversation.py +++ b/homeassistant/components/anthropic/conversation.py @@ -9,11 +9,13 @@ from anthropic import AsyncStream from anthropic._types import NOT_GIVEN from anthropic.types import ( InputJSONDelta, + MessageDeltaUsage, MessageParam, MessageStreamEvent, RawContentBlockDeltaEvent, RawContentBlockStartEvent, RawContentBlockStopEvent, + RawMessageDeltaEvent, RawMessageStartEvent, RawMessageStopEvent, RedactedThinkingBlock, @@ -31,6 +33,7 @@ from anthropic.types import ( ToolResultBlockParam, ToolUseBlock, ToolUseBlockParam, + Usage, ) from voluptuous_openapi import convert @@ -162,7 +165,8 @@ def _convert_content( return messages -async def _transform_stream( +async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place + chat_log: conversation.ChatLog, result: AsyncStream[MessageStreamEvent], messages: list[MessageParam], ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: @@ -207,6 +211,7 @@ async def _transform_stream( | None ) = None current_tool_args: str + input_usage: Usage | None = None async for response in result: LOGGER.debug("Received response: %s", response) @@ -215,6 +220,7 @@ async def _transform_stream( if response.message.role != "assistant": raise ValueError("Unexpected message role") current_message = MessageParam(role=response.message.role, content=[]) + input_usage = response.message.usage elif isinstance(response, RawContentBlockStartEvent): if isinstance(response.content_block, ToolUseBlock): current_block = ToolUseBlockParam( @@ -285,12 +291,34 @@ async def _transform_stream( raise ValueError("Unexpected stop event without a current message") current_message["content"].append(current_block) # type: ignore[union-attr] current_block = None + elif isinstance(response, RawMessageDeltaEvent): + if 
def _create_token_stats(
    input_usage: Usage | None, response_usage: MessageDeltaUsage
) -> dict[str, Any]:
    """Create token stats for conversation agent tracing.

    Args:
        input_usage: Usage captured from the message-start event, or None
            when no message-start event was seen before the delta.
        response_usage: Usage carried by the message-delta event; supplies
            the output token count.

    Returns:
        A dict with a single "stats" key holding "input_tokens",
        "cached_input_tokens" and "output_tokens". Input-side counters are 0
        when ``input_usage`` is None.
    """
    input_tokens = 0
    cached_input_tokens = 0
    # Explicit None check: absence of usage is the only case to skip, the
    # object's truthiness is not meaningful here.
    if input_usage is not None:
        input_tokens = input_usage.input_tokens
        # Report tokens *served from* the prompt cache. The API exposes cache
        # writes separately as cache_creation_input_tokens; using that field
        # here would report a non-zero value on a cache miss and zero on a
        # hit. The field is optional, so fall back to 0 when it is None.
        cached_input_tokens = input_usage.cache_read_input_tokens or 0
    return {
        "stats": {
            "input_tokens": input_tokens,
            "cached_input_tokens": cached_input_tokens,
            "output_tokens": response_usage.output_tokens,
        }
    }
@@ def create_messages( type="message_start", ), *content_blocks, + RawMessageDeltaEvent( + type="message_delta", + delta=Delta(stop_reason="end_turn", stop_sequence=""), + usage=MessageDeltaUsage(output_tokens=0), + ), RawMessageStopEvent(type="message_stop"), ]