Anthropic conversation extended thinking support (#139662)

* Anthropic conversation extended thinking support

* update conversation snapshots

* Add conversation test

* Update openai_conversation snapshots

* Removed metadata

* Removed metadata

* Removed thinking

* Cosmetic fix

* Combine user messages

* Apply suggestions from code review

* Add tests for chat_log messages conversion

* s/THINKING_BUDGET_TOKENS/THINKING_BUDGET/

* Apply suggestions from code review

* Update tests

* Update homeassistant/components/anthropic/strings.json

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Apply suggestions from code review

---------

Co-authored-by: Robert Resch <robert@resch.dev>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Denis Shulyaka 2025-03-15 05:07:59 +03:00 committed by GitHub
parent baafcf48dc
commit 07e7672b78
8 changed files with 940 additions and 104 deletions

homeassistant/components/anthropic/config_flow.py

@@ -34,10 +34,12 @@ from .const import (
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_THINKING_BUDGET,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_THINKING_BUDGET,
)
_LOGGER = logging.getLogger(__name__)
@@ -128,21 +130,29 @@ class AnthropicOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
errors: dict[str, str] = {}
if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
if user_input.get(
CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET
) >= user_input.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS):
errors[CONF_THINKING_BUDGET] = "thinking_budget_too_large"
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
if not errors:
return self.async_create_entry(title="", data=user_input)
else:
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
suggested_values = options.copy()
if not suggested_values.get(CONF_PROMPT):
@@ -156,6 +166,7 @@ class AnthropicOptionsFlow(OptionsFlow):
return self.async_show_form(
step_id="init",
data_schema=schema,
errors=errors or None,
)
@@ -205,6 +216,10 @@ def anthropic_config_option_schema(
CONF_TEMPERATURE,
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_THINKING_BUDGET,
default=RECOMMENDED_THINKING_BUDGET,
): int,
}
)
return schema
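
The validation added above rejects a thinking budget that is not strictly below the token limit. A minimal standalone sketch of the same rule (the helper name and the absolute import are illustrative, not part of this diff):

from homeassistant.components.anthropic.const import (
    CONF_MAX_TOKENS,
    CONF_THINKING_BUDGET,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_THINKING_BUDGET,
)

def thinking_budget_errors(user_input: dict) -> dict[str, str]:
    """Return form errors when the thinking budget reaches max_tokens (sketch)."""
    errors: dict[str, str] = {}
    if user_input.get(
        CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET
    ) >= user_input.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS):
        errors[CONF_THINKING_BUDGET] = "thinking_budget_too_large"
    return errors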

homeassistant/components/anthropic/const.py

@@ -13,3 +13,8 @@ CONF_MAX_TOKENS = "max_tokens"
RECOMMENDED_MAX_TOKENS = 1024
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_THINKING_BUDGET = "thinking_budget"
RECOMMENDED_THINKING_BUDGET = 0
MIN_THINKING_BUDGET = 1024
THINKING_MODELS = ["claude-3-7-sonnet-20250219", "claude-3-7-sonnet-latest"]
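
These constants feed the request construction in conversation.py below. Condensed into a single helper, the gating logic looks like this (a sketch; the helper name is illustrative, and the constants are the ones defined in this module):

from anthropic.types import (
    ThinkingConfigDisabledParam,
    ThinkingConfigEnabledParam,
)

def build_thinking_config(
    model: str, thinking_budget: int
) -> ThinkingConfigEnabledParam | ThinkingConfigDisabledParam:
    """Enable extended thinking only for supported models with budget >= 1024 (sketch)."""
    if model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET:
        return ThinkingConfigEnabledParam(type="enabled", budget_tokens=thinking_budget)
    return ThinkingConfigDisabledParam(type="disabled")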

homeassistant/components/anthropic/conversation.py

@@ -1,23 +1,32 @@
"""Conversation support for Anthropic."""
from collections.abc import AsyncGenerator, Callable
from collections.abc import AsyncGenerator, Callable, Iterable
import json
from typing import Any, Literal
from typing import Any, Literal, cast
import anthropic
from anthropic import AsyncStream
from anthropic._types import NOT_GIVEN
from anthropic.types import (
InputJSONDelta,
Message,
MessageParam,
MessageStreamEvent,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RedactedThinkingBlock,
RedactedThinkingBlockParam,
SignatureDelta,
TextBlock,
TextBlockParam,
TextDelta,
ThinkingBlock,
ThinkingBlockParam,
ThinkingConfigDisabledParam,
ThinkingConfigEnabledParam,
ThinkingDelta,
ToolParam,
ToolResultBlockParam,
ToolUseBlock,
@@ -39,11 +48,15 @@ from .const import (
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_THINKING_BUDGET,
DOMAIN,
LOGGER,
MIN_THINKING_BUDGET,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_THINKING_BUDGET,
THINKING_MODELS,
)
# Max number of back and forth with the LLM to generate a response
@@ -71,73 +84,101 @@ def _format_tool(
)
def _message_convert(
message: Message,
) -> MessageParam:
"""Convert from class to TypedDict."""
param_content: list[TextBlockParam | ToolUseBlockParam] = []
def _convert_content(
chat_content: Iterable[conversation.Content],
) -> list[MessageParam]:
"""Transform HA chat_log content into Anthropic API format."""
messages: list[MessageParam] = []
for message_content in message.content:
if isinstance(message_content, TextBlock):
param_content.append(TextBlockParam(type="text", text=message_content.text))
elif isinstance(message_content, ToolUseBlock):
param_content.append(
ToolUseBlockParam(
type="tool_use",
id=message_content.id,
name=message_content.name,
input=message_content.input,
)
for content in chat_content:
if isinstance(content, conversation.ToolResultContent):
tool_result_block = ToolResultBlockParam(
type="tool_result",
tool_use_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
return MessageParam(role=message.role, content=param_content)
def _convert_content(chat_content: conversation.Content) -> MessageParam:
"""Create tool response content."""
if isinstance(chat_content, conversation.ToolResultContent):
return MessageParam(
role="user",
content=[
ToolResultBlockParam(
type="tool_result",
tool_use_id=chat_content.tool_call_id,
content=json.dumps(chat_content.tool_result),
)
],
)
if isinstance(chat_content, conversation.AssistantContent):
return MessageParam(
role="assistant",
content=[
TextBlockParam(type="text", text=chat_content.content or ""),
*[
ToolUseBlockParam(
type="tool_use",
id=tool_call.id,
name=tool_call.tool_name,
input=tool_call.tool_args,
if not messages or messages[-1]["role"] != "user":
messages.append(
MessageParam(
role="user",
content=[tool_result_block],
)
for tool_call in chat_content.tool_calls or ()
],
],
)
if isinstance(chat_content, conversation.UserContent):
return MessageParam(
role="user",
content=chat_content.content,
)
# Note: We don't pass SystemContent here as it's passed to the API as the prompt
raise ValueError(f"Unexpected content type: {type(chat_content)}")
)
elif isinstance(messages[-1]["content"], str):
messages[-1]["content"] = [
TextBlockParam(type="text", text=messages[-1]["content"]),
tool_result_block,
]
else:
messages[-1]["content"].append(tool_result_block) # type: ignore[attr-defined]
elif isinstance(content, conversation.UserContent):
# Combine consequent user messages
if not messages or messages[-1]["role"] != "user":
messages.append(
MessageParam(
role="user",
content=content.content,
)
)
elif isinstance(messages[-1]["content"], str):
messages[-1]["content"] = [
TextBlockParam(type="text", text=messages[-1]["content"]),
TextBlockParam(type="text", text=content.content),
]
else:
messages[-1]["content"].append( # type: ignore[attr-defined]
TextBlockParam(type="text", text=content.content)
)
elif isinstance(content, conversation.AssistantContent):
# Combine consequent assistant messages
if not messages or messages[-1]["role"] != "assistant":
messages.append(
MessageParam(
role="assistant",
content=[],
)
)
if content.content:
messages[-1]["content"].append( # type: ignore[union-attr]
TextBlockParam(type="text", text=content.content)
)
if content.tool_calls:
messages[-1]["content"].extend( # type: ignore[union-attr]
[
ToolUseBlockParam(
type="tool_use",
id=tool_call.id,
name=tool_call.tool_name,
input=tool_call.tool_args,
)
for tool_call in content.tool_calls
]
)
else:
# Note: We don't pass SystemContent here as it's passed to the API as the prompt
raise TypeError(f"Unexpected content type: {type(content)}")
return messages
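# Illustrative example (not part of the original change), mirroring the
# history-conversion tests added further down: two consecutive UserContent
# entries collapse into one user message whose content is promoted from a
# plain string to a list of text blocks.
def _merge_example() -> None:
    """Sketch: consecutive user messages are combined (illustrative helper)."""
    messages = _convert_content(
        [
            conversation.UserContent("What shape is a donut?"),
            conversation.UserContent("Can you tell me?"),
        ]
    )
    assert messages == [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What shape is a donut?"},
                {"type": "text", "text": "Can you tell me?"},
            ],
        }
    ]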
async def _transform_stream(
result: AsyncStream[MessageStreamEvent],
messages: list[MessageParam],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the response stream into HA format.
A typical stream of responses might look something like the following:
- RawMessageStartEvent with no content
- RawContentBlockStartEvent with an empty ThinkingBlock (if extended thinking is enabled)
- RawContentBlockDeltaEvent with a ThinkingDelta
- RawContentBlockDeltaEvent with a ThinkingDelta
- RawContentBlockDeltaEvent with a ThinkingDelta
- ...
- RawContentBlockDeltaEvent with a SignatureDelta
- RawContentBlockStopEvent
- RawContentBlockStartEvent with a RedactedThinkingBlock (occasionally)
- RawContentBlockStopEvent (RedactedThinkingBlock does not have a delta)
- RawContentBlockStartEvent with an empty TextBlock
- RawContentBlockDeltaEvent with a TextDelta
- RawContentBlockDeltaEvent with a TextDelta
@@ -151,44 +192,103 @@ async def _transform_stream(
- RawContentBlockStopEvent
- RawMessageDeltaEvent with a stop_reason='tool_use'
- RawMessageStopEvent(type='message_stop')
Each message could contain multiple blocks of the same type.
"""
if result is None:
raise TypeError("Expected a stream of messages")
current_tool_call: dict | None = None
current_message: MessageParam | None = None
current_block: (
TextBlockParam
| ToolUseBlockParam
| ThinkingBlockParam
| RedactedThinkingBlockParam
| None
) = None
current_tool_args: str
async for response in result:
LOGGER.debug("Received response: %s", response)
if isinstance(response, RawContentBlockStartEvent):
if isinstance(response, RawMessageStartEvent):
if response.message.role != "assistant":
raise ValueError("Unexpected message role")
current_message = MessageParam(role=response.message.role, content=[])
elif isinstance(response, RawContentBlockStartEvent):
if isinstance(response.content_block, ToolUseBlock):
current_tool_call = {
"id": response.content_block.id,
"name": response.content_block.name,
"input": "",
}
current_block = ToolUseBlockParam(
type="tool_use",
id=response.content_block.id,
name=response.content_block.name,
input="",
)
current_tool_args = ""
elif isinstance(response.content_block, TextBlock):
current_block = TextBlockParam(
type="text", text=response.content_block.text
)
yield {"role": "assistant"}
if response.content_block.text:
yield {"content": response.content_block.text}
elif isinstance(response.content_block, ThinkingBlock):
current_block = ThinkingBlockParam(
type="thinking",
thinking=response.content_block.thinking,
signature=response.content_block.signature,
)
elif isinstance(response.content_block, RedactedThinkingBlock):
current_block = RedactedThinkingBlockParam(
type="redacted_thinking", data=response.content_block.data
)
LOGGER.debug(
"Some of Claude's internal reasoning has been automatically "
"encrypted for safety reasons. This doesn't affect the quality of "
"responses"
)
elif isinstance(response, RawContentBlockDeltaEvent):
if current_block is None:
raise ValueError("Unexpected delta without a block")
if isinstance(response.delta, InputJSONDelta):
if current_tool_call is None:
raise ValueError("Unexpected delta without a tool call")
current_tool_call["input"] += response.delta.partial_json
current_tool_args += response.delta.partial_json
elif isinstance(response.delta, TextDelta):
LOGGER.debug("yielding delta: %s", response.delta.text)
text_block = cast(TextBlockParam, current_block)
text_block["text"] += response.delta.text
yield {"content": response.delta.text}
elif isinstance(response.delta, ThinkingDelta):
thinking_block = cast(ThinkingBlockParam, current_block)
thinking_block["thinking"] += response.delta.thinking
elif isinstance(response.delta, SignatureDelta):
thinking_block = cast(ThinkingBlockParam, current_block)
thinking_block["signature"] += response.delta.signature
elif isinstance(response, RawContentBlockStopEvent):
if current_tool_call:
if current_block is None:
raise ValueError("Unexpected stop event without a current block")
if current_block["type"] == "tool_use":
tool_block = cast(ToolUseBlockParam, current_block)
tool_args = json.loads(current_tool_args)
tool_block["input"] = tool_args
yield {
"tool_calls": [
llm.ToolInput(
id=current_tool_call["id"],
tool_name=current_tool_call["name"],
tool_args=json.loads(current_tool_call["input"]),
id=tool_block["id"],
tool_name=tool_block["name"],
tool_args=tool_args,
)
]
}
current_tool_call = None
elif current_block["type"] == "thinking":
thinking_block = cast(ThinkingBlockParam, current_block)
LOGGER.debug("Thinking: %s", thinking_block["thinking"])
if current_message is None:
raise ValueError("Unexpected stop event without a current message")
current_message["content"].append(current_block) # type: ignore[union-attr]
current_block = None
elif isinstance(response, RawMessageStopEvent):
if current_message is not None:
messages.append(current_message)
current_message = None
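# Illustrative summary (not part of the change): for a thinking-enabled tool
# call, the event stream above reduces to HA deltas roughly as follows:
#   message_start                       -> new assistant MessageParam started
#   content_block_start ThinkingBlock   -> thinking collected, nothing yielded
#   content_block_start TextBlock       -> {"role": "assistant"}
#   content_block_delta TextDelta       -> {"content": "..."}
#   content_block_start ToolUseBlock    -> tool args accumulate from JSON deltas
#   content_block_stop (tool_use)       -> {"tool_calls": [llm.ToolInput(...)]}
#   message_stop                        -> finished MessageParam appended to messages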
class AnthropicConversationEntity(
@@ -254,34 +354,50 @@ class AnthropicConversationEntity(
system = chat_log.content[0]
if not isinstance(system, conversation.SystemContent):
raise TypeError("First message must be a system message")
messages = [_convert_content(content) for content in chat_log.content[1:]]
messages = _convert_content(chat_log.content[1:])
client = self.entry.runtime_data
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
stream = await client.messages.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
system=system.content,
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
stream=True,
model_args = {
"model": model,
"messages": messages,
"tools": tools or NOT_GIVEN,
"max_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
"system": system.content,
"stream": True,
}
if model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET:
model_args["thinking"] = ThinkingConfigEnabledParam(
type="enabled", budget_tokens=thinking_budget
)
else:
model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
model_args["temperature"] = options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
)
try:
stream = await client.messages.create(**model_args)
except anthropic.AnthropicError as err:
raise HomeAssistantError(
f"Sorry, I had a problem talking to Anthropic: {err}"
) from err
messages.extend(
[
_convert_content(content)
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_stream(stream)
)
]
_convert_content(
[
content
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_stream(stream, messages)
)
if not isinstance(content, conversation.AssistantContent)
]
)
)
if not chat_log.unresponded_tool_results:

homeassistant/components/anthropic/strings.json

@@ -23,12 +23,17 @@
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
"recommended": "Recommended model settings",
"thinking_budget_tokens": "Thinking budget"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."
"prompt": "Instruct how the LLM should respond. This can be a template.",
"thinking_budget_tokens": "The number of tokens the model can use to think about the response out of the total maximum number of tokens. Set to 1024 or greater to enable extended thinking."
}
}
},
"error": {
"thinking_budget_too_large": "Maximum tokens must be greater than the thinking budget."
}
}
}
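
Per the new description, extended thinking activates once the budget reaches 1024 tokens while staying below the token limit. For example, an options payload like the following would enable it (hypothetical values, consistent with the config-flow test below):

options = {
    "chat_model": "claude-3-7-sonnet-latest",  # must be one of THINKING_MODELS
    "max_tokens": 8192,
    "thinking_budget": 4096,  # >= 1024 enables thinking; must stay below max_tokens
}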

tests/components/anthropic/conftest.py

@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from homeassistant.components.anthropic import CONF_CHAT_MODEL
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
@@ -38,6 +39,21 @@ def mock_config_entry_with_assist(
return mock_config_entry
@pytest.fixture
def mock_config_entry_with_extended_thinking(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_CHAT_MODEL: "claude-3-7-sonnet-latest",
},
)
return mock_config_entry
@pytest.fixture
async def mock_init_component(
hass: HomeAssistant, mock_config_entry: MockConfigEntry

tests/components/anthropic/snapshots/test_conversation.ambr

@@ -1,4 +1,321 @@
# serializer version: 1
# name: test_extended_thinking_tool_call
list([
dict({
'content': '''
Current time is 16:00:00. Today's date is 2024-06-03.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'role': 'system',
}),
dict({
'content': 'Please call the test function',
'role': 'user',
}),
dict({
'agent_id': 'conversation.claude',
'content': 'Certainly, calling it now!',
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'toolu_0123456789AbCdEfGhIjKlM',
'tool_args': dict({
'param1': 'test_value',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'conversation.claude',
'role': 'tool_result',
'tool_call_id': 'toolu_0123456789AbCdEfGhIjKlM',
'tool_name': 'test_tool',
'tool_result': 'Test response',
}),
dict({
'agent_id': 'conversation.claude',
'content': 'I have successfully called the function',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---
# name: test_extended_thinking_tool_call.1
list([
dict({
'content': 'Please call the test function',
'role': 'user',
}),
dict({
'content': list([
dict({
'signature': 'ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==',
'thinking': 'The user asked me to call a test function.Is it a test? What would the function do? Would it violate any privacy or security policies?',
'type': 'thinking',
}),
dict({
'data': 'EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9KWPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeVsJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOKiKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny',
'type': 'redacted_thinking',
}),
dict({
'signature': 'ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==',
'thinking': "Okay, let's give it a shot. Will I pass the test?",
'type': 'thinking',
}),
dict({
'text': 'Certainly, calling it now!',
'type': 'text',
}),
dict({
'id': 'toolu_0123456789AbCdEfGhIjKlM',
'input': dict({
'param1': 'test_value',
}),
'name': 'test_tool',
'type': 'tool_use',
}),
]),
'role': 'assistant',
}),
dict({
'content': list([
dict({
'content': '"Test response"',
'tool_use_id': 'toolu_0123456789AbCdEfGhIjKlM',
'type': 'tool_result',
}),
]),
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'I have successfully called the function',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_history_conversion[content0]
list([
dict({
'content': 'Are you sure?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Yes, I am sure!',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_history_conversion[content1]
list([
dict({
'content': 'What shape is a donut?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'A donut is a torus.',
'type': 'text',
}),
]),
'role': 'assistant',
}),
dict({
'content': 'Are you sure?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Yes, I am sure!',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_history_conversion[content2]
list([
dict({
'content': list([
dict({
'text': 'What shape is a donut?',
'type': 'text',
}),
dict({
'text': 'Can you tell me?',
'type': 'text',
}),
]),
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'A donut is a torus.',
'type': 'text',
}),
dict({
'text': 'Hope this helps.',
'type': 'text',
}),
]),
'role': 'assistant',
}),
dict({
'content': 'Are you sure?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Yes, I am sure!',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_history_conversion[content3]
list([
dict({
'content': list([
dict({
'text': 'What shape is a donut?',
'type': 'text',
}),
dict({
'text': 'Can you tell me?',
'type': 'text',
}),
dict({
'text': 'Please?',
'type': 'text',
}),
]),
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'A donut is a torus.',
'type': 'text',
}),
dict({
'text': 'Hope this helps.',
'type': 'text',
}),
dict({
'text': 'You are welcome.',
'type': 'text',
}),
]),
'role': 'assistant',
}),
dict({
'content': 'Are you sure?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Yes, I am sure!',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_history_conversion[content4]
list([
dict({
'content': 'Turn off the lights and make me coffee',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Sure.',
'type': 'text',
}),
dict({
'id': 'mock-tool-call-id',
'input': dict({
'domain': 'light',
}),
'name': 'HassTurnOff',
'type': 'tool_use',
}),
dict({
'id': 'mock-tool-call-id-2',
'input': dict({
}),
'name': 'MakeCoffee',
'type': 'tool_use',
}),
]),
'role': 'assistant',
}),
dict({
'content': list([
dict({
'text': 'Thank you',
'type': 'text',
}),
dict({
'content': '{"success": true, "response": "Lights are off."}',
'tool_use_id': 'mock-tool-call-id',
'type': 'tool_result',
}),
dict({
'content': '{"success": false, "response": "Not enough milk."}',
'tool_use_id': 'mock-tool-call-id-2',
'type': 'tool_result',
}),
]),
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Should I add milk to the shopping list?',
'type': 'text',
}),
]),
'role': 'assistant',
}),
dict({
'content': 'Are you sure?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Yes, I am sure!',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_unknown_hass_api
dict({
'continue_conversation': False,

tests/components/anthropic/test_config_flow.py

@@ -21,9 +21,11 @@ from homeassistant.components.anthropic.const import (
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_THINKING_BUDGET,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_THINKING_BUDGET,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
@@ -94,6 +96,28 @@ async def test_options(
assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
async def test_options_thinking_budget_more_than_max(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test error about thinking budget being more than max tokens."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
"prompt": "Speak like a pirate",
"max_tokens": 8192,
"chat_model": "claude-3-7-sonnet-latest",
"temperature": 1,
"thinking_budget": 16384,
},
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.FORM
assert options["errors"] == {"thinking_budget": "thinking_budget_too_large"}
@pytest.mark.parametrize(
("side_effect", "error"),
[
@@ -186,6 +210,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
},
),
(
@@ -195,6 +220,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
},
{
CONF_RECOMMENDED: True,

tests/components/anthropic/test_conversation.py

@@ -14,13 +14,18 @@ from anthropic.types import (
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
RedactedThinkingBlock,
SignatureDelta,
TextBlock,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage,
)
from freezegun import freeze_time
from httpx import URL, Request, Response
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
@@ -28,7 +33,7 @@ from homeassistant.components import conversation
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.setup import async_setup_component
from homeassistant.util import ulid as ulid_util
@@ -86,6 +91,57 @@ def create_content_block(
]
def create_thinking_block(
index: int, thinking_parts: list[str]
) -> list[RawMessageStreamEvent]:
"""Create a thinking block with the specified deltas."""
return [
RawContentBlockStartEvent(
type="content_block_start",
content_block=ThinkingBlock(signature="", thinking="", type="thinking"),
index=index,
),
*[
RawContentBlockDeltaEvent(
delta=ThinkingDelta(thinking=thinking_part, type="thinking_delta"),
index=index,
type="content_block_delta",
)
for thinking_part in thinking_parts
],
RawContentBlockDeltaEvent(
delta=SignatureDelta(
signature="ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/N"
"oB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ"
"4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo"
"21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==",
type="signature_delta",
),
index=index,
type="content_block_delta",
),
RawContentBlockStopEvent(index=index, type="content_block_stop"),
]
def create_redacted_thinking_block(index: int) -> list[RawMessageStreamEvent]:
"""Create a redacted thinking block."""
return [
RawContentBlockStartEvent(
type="content_block_start",
content_block=RedactedThinkingBlock(
data="EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9K"
"WPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeV"
"sJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOK"
"iKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny",
type="redacted_thinking",
),
index=index,
),
RawContentBlockStopEvent(index=index, type="content_block_stop"),
]
def create_tool_use_block(
index: int, tool_id: str, tool_name: str, json_parts: list[str]
) -> list[RawMessageStreamEvent]:
@@ -381,7 +437,7 @@ async def test_function_exception(
return stream_generator(
create_messages(
[
*create_content_block(0, "Certainly, calling it now!"),
*create_content_block(0, ["Certainly, calling it now!"]),
*create_tool_use_block(
1,
"toolu_0123456789AbCdEfGhIjKlM",
@@ -464,7 +520,7 @@ async def test_assist_api_tools_conversion(
new_callable=AsyncMock,
return_value=stream_generator(
create_messages(
create_content_block(0, "Hello, how can I help you?"),
create_content_block(0, ["Hello, how can I help you?"]),
),
),
) as mock_create:
@@ -509,7 +565,7 @@ async def test_conversation_id(
def create_stream_generator(*args, **kwargs) -> Any:
return stream_generator(
create_messages(
create_content_block(0, "Hello, how can I help you?"),
create_content_block(0, ["Hello, how can I help you?"]),
),
)
@@ -547,3 +603,283 @@
)
assert result.conversation_id == "koala"
async def test_extended_thinking(
hass: HomeAssistant,
mock_config_entry_with_extended_thinking: MockConfigEntry,
mock_init_component,
) -> None:
"""Test extended thinking support."""
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
return_value=stream_generator(
create_messages(
[
*create_thinking_block(
0,
[
"The user has just",
' greeted me with "Hi".',
" This is a simple greeting an",
"d doesn't require any Home Assistant function",
" calls. I should respond with",
" a friendly greeting and let them know I'm available",
" to help with their smart home.",
],
),
*create_content_block(1, ["Hello, how can I help you today?"]),
]
),
),
):
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id="conversation.claude"
)
chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
result.conversation_id
)
assert len(chat_log.content) == 3
assert chat_log.content[1].content == "hello"
assert chat_log.content[2].content == "Hello, how can I help you today?"
async def test_redacted_thinking(
hass: HomeAssistant,
mock_config_entry_with_extended_thinking: MockConfigEntry,
mock_init_component,
) -> None:
"""Test extended thinking with redacted thinking blocks."""
with patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
return_value=stream_generator(
create_messages(
[
*create_redacted_thinking_block(0),
*create_redacted_thinking_block(1),
*create_redacted_thinking_block(2),
*create_content_block(3, ["How can I help you today?"]),
]
),
),
):
result = await conversation.async_converse(
hass,
"ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A9"
"8432ECCCE4C1253D5E2D82641AC0E52CC2876CB",
None,
Context(),
agent_id="conversation.claude",
)
chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
result.conversation_id
)
assert len(chat_log.content) == 3
assert chat_log.content[2].content == "How can I help you today?"
@patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools")
async def test_extended_thinking_tool_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_extended_thinking: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
) -> None:
"""Test that thinking blocks and their order are preserved in with tool calls."""
agent_id = "conversation.claude"
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
for content in message["content"]:
if not isinstance(content, str) and content["type"] == "tool_use":
return stream_generator(
create_messages(
create_content_block(
0, ["I have ", "successfully called ", "the function"]
),
)
)
return stream_generator(
create_messages(
[
*create_thinking_block(
0,
[
"The user asked me to",
" call a test function.",
"Is it a test? What",
" would the function",
" do? Would it violate",
" any privacy or security",
" policies?",
],
),
*create_redacted_thinking_block(1),
*create_thinking_block(
2, ["Okay, let's give it a shot.", " Will I pass the test?"]
),
*create_content_block(3, ["Certainly, calling it now!"]),
*create_tool_use_block(
1,
"toolu_0123456789AbCdEfGhIjKlM",
"test_tool",
['{"para', 'm1": "test_valu', 'e"}'],
),
]
)
)
with (
patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
side_effect=completion_result,
) as mock_create,
freeze_time("2024-06-03 23:00:00"),
):
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
result.conversation_id
)
assert chat_log.content == snapshot
assert mock_create.mock_calls[1][2]["messages"] == snapshot
@pytest.mark.parametrize(
"content",
[
[
conversation.chat_log.SystemContent("You are a helpful assistant."),
],
[
conversation.chat_log.SystemContent("You are a helpful assistant."),
conversation.chat_log.UserContent("What shape is a donut?"),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude", content="A donut is a torus."
),
],
[
conversation.chat_log.SystemContent("You are a helpful assistant."),
conversation.chat_log.UserContent("What shape is a donut?"),
conversation.chat_log.UserContent("Can you tell me?"),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude", content="A donut is a torus."
),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude", content="Hope this helps."
),
],
[
conversation.chat_log.SystemContent("You are a helpful assistant."),
conversation.chat_log.UserContent("What shape is a donut?"),
conversation.chat_log.UserContent("Can you tell me?"),
conversation.chat_log.UserContent("Please?"),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude", content="A donut is a torus."
),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude", content="Hope this helps."
),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude", content="You are welcome."
),
],
[
conversation.chat_log.SystemContent("You are a helpful assistant."),
conversation.chat_log.UserContent("Turn off the lights and make me coffee"),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude",
content="Sure.",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="HassTurnOff",
tool_args={"domain": "light"},
),
llm.ToolInput(
id="mock-tool-call-id-2",
tool_name="MakeCoffee",
tool_args={},
),
],
),
conversation.chat_log.UserContent("Thank you"),
conversation.chat_log.ToolResultContent(
agent_id="conversation.claude",
tool_call_id="mock-tool-call-id",
tool_name="HassTurnOff",
tool_result={"success": True, "response": "Lights are off."},
),
conversation.chat_log.ToolResultContent(
agent_id="conversation.claude",
tool_call_id="mock-tool-call-id-2",
tool_name="MakeCoffee",
tool_result={"success": False, "response": "Not enough milk."},
),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude",
content="Should I add milk to the shopping list?",
),
],
],
)
async def test_history_conversion(
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
content: list[conversation.chat_log.Content],
) -> None:
"""Test conversion of chat_log entries into API parameters."""
conversation_id = "conversation_id"
with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
conversation.async_get_chat_log(hass, session) as chat_log,
patch(
"anthropic.resources.messages.AsyncMessages.create",
new_callable=AsyncMock,
return_value=stream_generator(
create_messages(
[
*create_content_block(0, ["Yes, I am sure!"]),
]
),
),
) as mock_create,
):
chat_log.content = content
await conversation.async_converse(
hass,
"Are you sure?",
conversation_id,
Context(),
agent_id="conversation.claude",
)
assert mock_create.mock_calls[0][2]["messages"] == snapshot