Add Anthropic Claude 4 support (#145505)

Add Claude 4 support

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Denis Shulyaka 2025-05-23 17:31:44 +03:00 committed by GitHub
parent cbeefdaf26
commit 199c565bf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 11 deletions

View File

@ -17,4 +17,11 @@ CONF_THINKING_BUDGET = "thinking_budget"
RECOMMENDED_THINKING_BUDGET = 0
MIN_THINKING_BUDGET = 1024
THINKING_MODELS = ["claude-3-7-sonnet-20250219", "claude-3-7-sonnet-latest"]
THINKING_MODELS = [
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-latest",
"claude-opus-4-20250514",
"claude-opus-4-0",
"claude-sonnet-4-20250514",
"claude-sonnet-4-0",
]

View File

@ -294,6 +294,8 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
elif isinstance(response, RawMessageDeltaEvent):
if (usage := response.usage) is not None:
chat_log.async_trace(_create_token_stats(input_usage, usage))
if response.delta.stop_reason == "refusal":
raise HomeAssistantError("Potential policy violation detected")
elif isinstance(response, RawMessageStopEvent):
if current_message is not None:
messages.append(current_message)

View File

@ -52,7 +52,7 @@ async def stream_generator(
def create_messages(
content_blocks: list[RawMessageStreamEvent],
content_blocks: list[RawMessageStreamEvent], stop_reason="end_turn"
) -> list[RawMessageStreamEvent]:
"""Create a stream of messages with the specified content blocks."""
return [
@ -70,7 +70,7 @@ def create_messages(
*content_blocks,
RawMessageDeltaEvent(
type="message_delta",
delta=Delta(stop_reason="end_turn", stop_sequence=""),
delta=Delta(stop_reason=stop_reason, stop_sequence=""),
usage=MessageDeltaUsage(output_tokens=0),
),
RawMessageStopEvent(type="message_stop"),
@ -221,7 +221,7 @@ async def test_error_handling(
hass, "hello", None, Context(), agent_id="conversation.claude"
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == "unknown", result
@ -247,7 +247,7 @@ async def test_template_error(
hass, "hello", None, Context(), agent_id="conversation.claude"
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == "unknown", result
@ -289,9 +289,7 @@ async def test_template_variables(
hass, "hello", None, context, agent_id="conversation.claude"
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
result
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "Okay, let me take care of that for you."
@ -369,7 +367,8 @@ async def test_function_call(
"test_tool",
tool_call_json_parts,
),
]
],
stop_reason="tool_use",
)
)
@ -468,7 +467,8 @@ async def test_function_exception(
"test_tool",
['{"param1": "test_value"}'],
),
]
],
stop_reason="tool_use",
)
)
@ -629,6 +629,44 @@ async def test_conversation_id(
assert result.conversation_id == "koala"
async def test_refusal(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test refusal due to potential policy violation."""
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
return_value=stream_generator(
create_messages(
[
*create_content_block(
0,
["Certainly! To take over the world you need just a simple "],
),
],
stop_reason="refusal",
),
),
):
result = await conversation.async_converse(
hass,
"ANTHROPIC_MAGIC_STRING_TRIGGER_REFUSAL_1FAEFB6177B4672DEE07F9D3AFC62588CCD"
"2631EDCF22E8CCC1FB35B501C9C86",
None,
Context(),
agent_id="conversation.claude",
)
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == "unknown"
assert (
result.response.speech["plain"]["speech"]
== "Potential policy violation detected"
)
async def test_extended_thinking(
hass: HomeAssistant,
mock_config_entry_with_extended_thinking: MockConfigEntry,
@ -766,7 +804,8 @@ async def test_extended_thinking_tool_call(
"test_tool",
['{"para', 'm1": "test_valu', 'e"}'],
),
]
],
stop_reason="tool_use",
)
)