Update anthropic to use the streaming API (#138256)

This commit is contained in:
Allen Porter 2025-02-11 16:05:23 -08:00 committed by GitHub
parent 117a71cb67
commit da1e3c29ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 262 additions and 127 deletions

View File

@ -1,16 +1,23 @@
"""Conversation support for Anthropic.""" """Conversation support for Anthropic."""
from collections.abc import Callable from collections.abc import AsyncGenerator, Callable
import json import json
from typing import Any, Literal, cast from typing import Any, Literal
import anthropic import anthropic
from anthropic import AsyncStream
from anthropic._types import NOT_GIVEN from anthropic._types import NOT_GIVEN
from anthropic.types import ( from anthropic.types import (
InputJSONDelta,
Message, Message,
MessageParam, MessageParam,
MessageStreamEvent,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
TextBlock, TextBlock,
TextBlockParam, TextBlockParam,
TextDelta,
ToolParam, ToolParam,
ToolResultBlockParam, ToolResultBlockParam,
ToolUseBlock, ToolUseBlock,
@ -109,7 +116,7 @@ def _convert_content(chat_content: conversation.Content) -> MessageParam:
type="tool_use", type="tool_use",
id=tool_call.id, id=tool_call.id,
name=tool_call.tool_name, name=tool_call.tool_name,
input=json.dumps(tool_call.tool_args), input=tool_call.tool_args,
) )
for tool_call in chat_content.tool_calls or () for tool_call in chat_content.tool_calls or ()
], ],
@ -124,6 +131,66 @@ def _convert_content(chat_content: conversation.Content) -> MessageParam:
raise ValueError(f"Unexpected content type: {type(chat_content)}") raise ValueError(f"Unexpected content type: {type(chat_content)}")
async def _transform_stream(
result: AsyncStream[MessageStreamEvent],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the response stream into HA format.
A typical stream of responses might look something like the following:
- RawMessageStartEvent with no content
- RawContentBlockStartEvent with an empty TextBlock
- RawContentBlockDeltaEvent with a TextDelta
- RawContentBlockDeltaEvent with a TextDelta
- RawContentBlockDeltaEvent with a TextDelta
- ...
- RawContentBlockStopEvent
- RawContentBlockStartEvent with ToolUseBlock specifying the function name
- RawContentBlockDeltaEvent with a InputJSONDelta
- RawContentBlockDeltaEvent with a InputJSONDelta
- ...
- RawContentBlockStopEvent
- RawMessageDeltaEvent with a stop_reason='tool_use'
- RawMessageStopEvent(type='message_stop')
"""
if result is None:
raise TypeError("Expected a stream of messages")
current_tool_call: dict | None = None
async for response in result:
LOGGER.debug("Received response: %s", response)
if isinstance(response, RawContentBlockStartEvent):
if isinstance(response.content_block, ToolUseBlock):
current_tool_call = {
"id": response.content_block.id,
"name": response.content_block.name,
"input": "",
}
elif isinstance(response.content_block, TextBlock):
yield {"role": "assistant"}
elif isinstance(response, RawContentBlockDeltaEvent):
if isinstance(response.delta, InputJSONDelta):
if current_tool_call is None:
raise ValueError("Unexpected delta without a tool call")
current_tool_call["input"] += response.delta.partial_json
elif isinstance(response.delta, TextDelta):
LOGGER.debug("yielding delta: %s", response.delta.text)
yield {"content": response.delta.text}
elif isinstance(response, RawContentBlockStopEvent):
if current_tool_call:
yield {
"tool_calls": [
llm.ToolInput(
id=current_tool_call["id"],
tool_name=current_tool_call["name"],
tool_args=json.loads(current_tool_call["input"]),
)
]
}
current_tool_call = None
class AnthropicConversationEntity( class AnthropicConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent conversation.ConversationEntity, conversation.AbstractConversationAgent
): ):
@ -206,58 +273,30 @@ class AnthropicConversationEntity(
# To prevent infinite loops, we limit the number of iterations # To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS): for _iteration in range(MAX_TOOL_ITERATIONS):
try: try:
response = await client.messages.create( stream = await client.messages.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages, messages=messages,
tools=tools or NOT_GIVEN, tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
system=system.content, system=system.content,
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
stream=True,
) )
except anthropic.AnthropicError as err: except anthropic.AnthropicError as err:
raise HomeAssistantError( raise HomeAssistantError(
f"Sorry, I had a problem talking to Anthropic: {err}" f"Sorry, I had a problem talking to Anthropic: {err}"
) from err ) from err
LOGGER.debug("Response %s", response) messages.extend(
messages.append(_message_convert(response))
text = "".join(
[ [
content.text _convert_content(content)
for content in response.content async for content in chat_log.async_add_delta_content_stream(
if isinstance(content, TextBlock) user_input.agent_id, _transform_stream(stream)
)
] ]
) )
tool_inputs = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.name,
tool_args=cast(dict[str, Any], tool_call.input),
)
for tool_call in response.content
if isinstance(tool_call, ToolUseBlock)
]
tool_results = [ if not chat_log.unresponded_tool_results:
ToolResultBlockParam(
type="tool_result",
tool_use_id=tool_response.tool_call_id,
content=json.dumps(tool_response.tool_result),
)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=text,
tool_calls=tool_inputs or None,
)
)
]
if tool_results:
messages.append(MessageParam(role="user", content=tool_results))
if not tool_inputs:
break break
response_content = chat_log.content[-1] response_content = chat_log.content[-1]

View File

@ -1,9 +1,24 @@
"""Tests for the Anthropic integration.""" """Tests for the Anthropic integration."""
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
from anthropic import RateLimitError from anthropic import RateLimitError
from anthropic.types import Message, TextBlock, ToolUseBlock, Usage from anthropic.types import (
InputJSONDelta,
Message,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
TextBlock,
TextDelta,
ToolUseBlock,
Usage,
)
from freezegun import freeze_time from freezegun import freeze_time
from httpx import URL, Request, Response from httpx import URL, Request, Response
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -20,6 +35,81 @@ from homeassistant.util import ulid as ulid_util
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def stream_generator(
responses: list[RawMessageStreamEvent],
) -> AsyncGenerator[RawMessageStreamEvent]:
"""Generate a response from the assistant."""
for msg in responses:
yield msg
def create_messages(
content_blocks: list[RawMessageStreamEvent],
) -> list[RawMessageStreamEvent]:
"""Create a stream of messages with the specified content blocks."""
return [
RawMessageStartEvent(
message=Message(
type="message",
id="msg_1234567890ABCDEFGHIJKLMN",
content=[],
role="assistant",
model="claude-3-5-sonnet-20240620",
usage=Usage(input_tokens=0, output_tokens=0),
),
type="message_start",
),
*content_blocks,
RawMessageStopEvent(type="message_stop"),
]
def create_content_block(
index: int, text_parts: list[str]
) -> list[RawMessageStreamEvent]:
"""Create a text content block with the specified deltas."""
return [
RawContentBlockStartEvent(
type="content_block_start",
content_block=TextBlock(text="", type="text"),
index=index,
),
*[
RawContentBlockDeltaEvent(
delta=TextDelta(text=text_part, type="text_delta"),
index=index,
type="content_block_delta",
)
for text_part in text_parts
],
RawContentBlockStopEvent(index=index, type="content_block_stop"),
]
def create_tool_use_block(
index: int, tool_id: str, tool_name: str, json_parts: list[str]
) -> list[RawMessageStreamEvent]:
"""Create a tool use content block with the specified deltas."""
return [
RawContentBlockStartEvent(
type="content_block_start",
content_block=ToolUseBlock(
id=tool_id, name=tool_name, input={}, type="tool_use"
),
index=index,
),
*[
RawContentBlockDeltaEvent(
delta=InputJSONDelta(partial_json=json_part, type="input_json_delta"),
index=index,
type="content_block_delta",
)
for json_part in json_parts
],
RawContentBlockStopEvent(index=index, type="content_block_stop"),
]
async def test_entity( async def test_entity(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -120,6 +210,13 @@ async def test_template_variables(
) as mock_create, ) as mock_create,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user), patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
): ):
mock_create.return_value = stream_generator(
create_messages(
create_content_block(
0, ["Okay, let", " me take care of that for you", "."]
)
)
)
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
result = await conversation.async_converse( result = await conversation.async_converse(
@ -129,6 +226,10 @@ async def test_template_variables(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, ( assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
result result
) )
assert (
result.response.speech["plain"]["speech"]
== "Okay, let me take care of that for you."
)
assert "The user name is Test User." in mock_create.mock_calls[1][2]["system"] assert "The user name is Test User." in mock_create.mock_calls[1][2]["system"]
assert "The user id is 12345." in mock_create.mock_calls[1][2]["system"] assert "The user id is 12345." in mock_create.mock_calls[1][2]["system"]
@ -168,39 +269,26 @@ async def test_function_call(
for message in messages: for message in messages:
for content in message["content"]: for content in message["content"]:
if not isinstance(content, str) and content["type"] == "tool_use": if not isinstance(content, str) and content["type"] == "tool_use":
return Message( return stream_generator(
type="message", create_messages(
id="msg_1234567890ABCDEFGHIJKLMN", create_content_block(
content=[ 0, ["I have ", "successfully called ", "the function"]
TextBlock( ),
type="text",
text="I have successfully called the function",
) )
],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
) )
return Message( return stream_generator(
type="message", create_messages(
id="msg_1234567890ABCDEFGHIJKLMN", [
content=[ *create_content_block(0, ["Certainly, calling it now!"]),
TextBlock(type="text", text="Certainly, calling it now!"), *create_tool_use_block(
ToolUseBlock( 1,
type="tool_use", "toolu_0123456789AbCdEfGhIjKlM",
id="toolu_0123456789AbCdEfGhIjKlM", "test_tool",
name="test_tool", ['{"para', 'm1": "test_valu', 'e"}'],
input={"param1": "test_value"},
), ),
], ]
model="claude-3-5-sonnet-20240620", )
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
) )
with ( with (
@ -222,6 +310,10 @@ async def test_function_call(
assert "Today's date is 2024-06-03." in mock_create.mock_calls[1][2]["system"] assert "Today's date is 2024-06-03." in mock_create.mock_calls[1][2]["system"]
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "I have successfully called the function"
)
assert mock_create.mock_calls[1][2]["messages"][2] == { assert mock_create.mock_calls[1][2]["messages"][2] == {
"role": "user", "role": "user",
"content": [ "content": [
@ -275,39 +367,27 @@ async def test_function_exception(
for message in messages: for message in messages:
for content in message["content"]: for content in message["content"]:
if not isinstance(content, str) and content["type"] == "tool_use": if not isinstance(content, str) and content["type"] == "tool_use":
return Message( return stream_generator(
type="message", create_messages(
id="msg_1234567890ABCDEFGHIJKLMN", create_content_block(
content=[ 0,
TextBlock( ["There was an error calling the function"],
type="text", )
text="There was an error calling the function",
) )
],
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
) )
return Message( return stream_generator(
type="message", create_messages(
id="msg_1234567890ABCDEFGHIJKLMN", [
content=[ *create_content_block(0, "Certainly, calling it now!"),
TextBlock(type="text", text="Certainly, calling it now!"), *create_tool_use_block(
ToolUseBlock( 1,
type="tool_use", "toolu_0123456789AbCdEfGhIjKlM",
id="toolu_0123456789AbCdEfGhIjKlM", "test_tool",
name="test_tool", ['{"param1": "test_value"}'],
input={"param1": "test_value"},
), ),
], ]
model="claude-3-5-sonnet-20240620", )
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
) )
with patch( with patch(
@ -324,6 +404,10 @@ async def test_function_exception(
) )
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "There was an error calling the function"
)
assert mock_create.mock_calls[1][2]["messages"][2] == { assert mock_create.mock_calls[1][2]["messages"][2] == {
"role": "user", "role": "user",
"content": [ "content": [
@ -376,15 +460,10 @@ async def test_assist_api_tools_conversion(
with patch( with patch(
"anthropic.resources.messages.AsyncMessages.create", "anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=Message( return_value=stream_generator(
type="message", create_messages(
id="msg_1234567890ABCDEFGHIJKLMN", create_content_block(0, "Hello, how can I help you?"),
content=[TextBlock(type="text", text="Hello, how can I help you?")], ),
model="claude-3-5-sonnet-20240620",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
usage=Usage(input_tokens=8, output_tokens=12),
), ),
) as mock_create: ) as mock_create:
await conversation.async_converse( await conversation.async_converse(
@ -425,6 +504,23 @@ async def test_conversation_id(
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test conversation ID is honored.""" """Test conversation ID is honored."""
def create_stream_generator(*args, **kwargs) -> Any:
return stream_generator(
create_messages(
create_content_block(0, "Hello, how can I help you?"),
),
)
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
side_effect=create_stream_generator,
):
result = await conversation.async_converse(
hass, "hello", "1234", Context(), agent_id="conversation.claude"
)
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, None, agent_id="conversation.claude" hass, "hello", None, None, agent_id="conversation.claude"
) )